/**
 * Helper to keep track of history conversations.
 */
class Conversation {
  constructor(config) {
    this.system = config.system;
    this.roles = config.roles;
    this.offset = config.offset;
    this.seps = config.seps;
    this.convId = null;
    this.contextWindowStart = 0;
  }

  /**
   * Get the prompt array with the system prompt as the first entry.
   *
   * @returns The prompt array.
   */
  getPromptArray() {
    if (this.seps.length == 0) {
      throw Error("Need seps to work");
    }
    let ret = [this.system + this.seps[0]];
    for (let i = 0; i < tvmjsGlobalEnv.workerHistoryMsg.length; ++i) {
      const item = tvmjsGlobalEnv.workerHistoryMsg[i];
      const role = item[0];
      const message = item[1];
      if (message !== undefined && message != "") {
        ret.push(role + ": " + message + this.seps[i % this.seps.length]);
      } else {
        ret.push(role + ":");
      }
    }
    return ret;
  }

  /**
   * Get the prompt array that has not yet been fed as input.
   *
   * @returns The prompt array.
   */
  getPromptArrayUnproccessed() {
    if (this.seps.length == 0) {
      throw Error("Need seps to work");
    }
    if (tvmjsGlobalEnv.workerHistoryMsg.length < 3) {
      throw Error("needs to call getLastPromptArray for the first message");
    }
    let ret = [this.seps[this.seps.length - 1]];
    for (let i = tvmjsGlobalEnv.workerHistoryMsg.length - 2; i < tvmjsGlobalEnv.workerHistoryMsg.length; ++i) {
      const item = tvmjsGlobalEnv.workerHistoryMsg[i];
      const role = item[0];
      const message = item[1];
      if (message !== undefined && message != "") {
        ret.push(role + ": " + message + this.seps[i % this.seps.length]);
      } else {
        ret.push(role + ":");
      }
    }
    return ret;
  }

  /**
   * Get the last prompt array, prefixed with the system prompt.
   *
   * @returns The prompt array.
   */
  getLastPromptArray() {
    if (this.seps.length == 0) {
      throw Error("Need seps to work");
    }
    let ret = [this.system + this.seps[0]];
    for (let i = tvmjsGlobalEnv.workerHistoryMsg.length - 2; i < tvmjsGlobalEnv.workerHistoryMsg.length; ++i) {
      const item = tvmjsGlobalEnv.workerHistoryMsg[i];
      const role = item[0];
      const message = item[1];
      if (message !== undefined && message != "") {
        ret.push(role + ": " + message + this.seps[i % this.seps.length]);
      } else {
        ret.push(role + ":");
      }
    }
    return ret;
  }

  reset() {
    tvmjsGlobalEnv.workerHistoryMsg = [];
    this.convId = null;
  }

  getStopStr() {
    return this.seps[this.seps.length - 1];
  }

  appendMessage(role, message) {
    tvmjsGlobalEnv.workerHistoryMsg.push([role, message]);
  }

  switchConversation(message) {
    tvmjsGlobalEnv.workerHistoryMsg = message;
    this.convId = tvmjsGlobalEnv.covId;
  }
}

function defaultConversation(maxWindowLength = 2048) {
  return new Conversation({
    system: "A chat between a curious user and an artificial intelligence assistant. " +
      "The assistant gives helpful, detailed, and polite answers to the user's questions. " +
      "Follow the user's instructions carefully. Respond using markdown.",
    roles: ["user", "assistant"],
    maxWindowLength: maxWindowLength,
    offset: 0,
    seps: [" ", ""],
  });
}

class LLMChatPipeline {
  constructor(tvm, tokenizer, cacheMetadata, config) {
    if (cacheMetadata == undefined) {
      throw Error("Expect cacheMetadata");
    }
    this.tvm = tvm;
    this.logger = console.log;
    this.tokenizer = tokenizer;
    this.bosTokenId = 1;
    this.eosTokenId = 2;

    this.maxWindowLength = config.maxWindowLength;
    this.maxGenLength = config.maxGenLength;
    this.meanGenLength = config.meanGenLength;
    this.streamInterval = 1;

    this.decodingTotalTime = 0;
    this.decodingTotalTokens = 0;
    this.encodingTotalTime = 0;
    this.encodingTotalTokens = 0;

    this.conversation = defaultConversation(this.maxWindowLength);

    this.device = this.tvm.webgpu();
    this.vm = this.tvm.detachFromCurrentScope(
      this.tvm.createVirtualMachine(this.device)
    );
    this.encoding = this.tvm.detachFromCurrentScope(
      this.vm.getFunction("encoding")
    );
    this.decoding = this.tvm.detachFromCurrentScope(
      this.vm.getFunction("decoding")
    );
    this.params = this.tvm.detachFromCurrentScope(
      this.tvm.getParamsFromCache("param", cacheMetadata.ParamSize)
    );
    const fcreateCache = this.vm.getFunction("create_kv_cache");
    this.fclearKVCaches = this.tvm.detachFromCurrentScope(
      this.tvm.getGlobalFunc("vm.builtin.attention_kv_cache_array_clear")
    );

    // use extern config for now
    this.kvCache = this.tvm.detachFromCurrentScope(fcreateCache());
    // fill with pad token
    this.logitsOnCPU = undefined;

    this.kvCacheLength = 0;
    this.clearCache = true;
  }

  dispose() {
    // note: tvm instance is not owned by this class
    this.params.dispose();
    this.decoding.dispose();
    this.encoding.dispose();
    this.vm.dispose();
    this.kvCache.dispose();
    this.fclearKVCaches.dispose();
    if (this.logitsOnCPU != undefined) {
      this.logitsOnCPU.dispose();
    }
  }

  #clearKVCache() {
    this.fclearKVCaches(this.kvCache);
    this.kvCacheLength = 0;
  }

  #forward(inputs, curPos) {
    this.tvm.beginScope();
    var retValue;
    const seqLenShape = this.tvm.makeShapeTuple([curPos]);
    if (inputs.shape[1] > 1) {
      retValue = this.encoding(
        inputs, seqLenShape, this.kvCache, this.params
      );
    } else {
      retValue = this.decoding(
        inputs, seqLenShape, this.kvCache, this.params
      );
    }
    const logits = this.tvm.detachFromCurrentScope(retValue.get(0));
    this.tvm.endScope();
    this.tvm.attachToCurrentScope(logits);
    return logits;
  }

  // NOTE: caller must call device.sync()
  #updateLogitsOnCPU(logits) {
    if (this.logitsOnCPU == undefined) {
      this.logitsOnCPU = this.tvm.detachFromCurrentScope(
        this.tvm.empty(logits.shape, logits.dtype, this.tvm.cpu())
      );
    } else {
      if (logits.shape[0] != this.logitsOnCPU.shape[0]) {
        throw Error("We expect the size of logits to remain unchanged");
      }
    }
    this.logitsOnCPU.copyFrom(logits);
  }

  async sampleTokenFromLogits(logits, temperature = 0.8, top_p = 0.95) {
    this.tvm.beginScope();
    this.#updateLogitsOnCPU(logits);
    this.tvm.endScope();
    await this.device.sync();
    return this.tvm.sampleTopPFromLogits(this.logitsOnCPU, temperature, top_p);
  }

  async getInputTokens() {
    let tokens = [this.bosTokenId];
    let prompts = [];
    if (tvmjsGlobalEnv.workerHistoryMsg.length <= 2) {
      prompts = this.conversation.getPromptArray();
    } else {
      tokens.pop();
      prompts = this.conversation.getPromptArrayUnproccessed();
    }
    tokens.push(...await this.tokenizer.encodeIds(prompts[0]));
    let ctxLength = tokens.length;
    let context = [];
    let need_shift_window = false;
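    // Walk the prompt segments from newest to oldest, accumulating context
    // until adding one more segment would overflow the window budget
    // (existing KV-cache entries + new context + expected generation length);
    // in that case fall back to the shift-window path below.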
    for (let i = prompts.length - 1; i > 0; --i) {
      const encoded = this.tokenizer.encodeIds(prompts[i]);
      ctxLength += encoded.length;
      if (this.kvCacheLength + ctxLength + this.meanGenLength >= this.maxWindowLength) {
        need_shift_window = true;
        break;
      }
      context.unshift(encoded);
    }
    if (!need_shift_window) {
      for (const ctx of context) {
        tokens.push(...ctx);
      }
      return tokens;
    }

    // need shift window and re-encode
    this.logger("need shift window");
    this.kvCacheLength = 0;
    this.clearCache = true;
    // abandon all tokens we collected
    tokens = [this.bosTokenId];
    let all_prompts = this.conversation.getPromptArray();
    tokens.push(...await this.tokenizer.encodeIds(all_prompts[0]));
    context = [];
    ctxLength = tokens.length;
    // only keep 10% of the window context
    const fill_factor = 0.1;
    for (let i = all_prompts.length - 1; i > 0; --i) {
      const encoded = this.tokenizer.encodeIds(all_prompts[i]);
      ctxLength += encoded.length;
      if (ctxLength >= fill_factor * this.maxWindowLength && i + 2 < all_prompts.length) {
        break;
      }
      context.unshift(encoded);
    }
    for (const ctx of context) {
      tokens.push(...ctx);
    }
    if (tokens.length + this.meanGenLength >= this.maxWindowLength) {
      throw Error("Exceeded max window length, curr=" + tokens.length);
    }
    return tokens;
  }

  resetChat() {
    if (this.conversation) {
      this.conversation.reset();
    }
    this.#clearKVCache();
    this.decodingTotalTime = 0;
    this.encodingTotalTime = 0;
    this.decodingTotalTokens = 0;
    this.encodingTotalTokens = 0;
  }

  async generate(inputPrompt, callbackUpdateResponse) {
    // switch to a new conversation if the conversation id has changed
    if (this.conversation.convId !== tvmjsGlobalEnv.covId) {
      // currently a no-op; history lives in tvmjsGlobalEnv.workerHistoryMsg
    }
    this.conversation.appendMessage(this.conversation.roles[0], inputPrompt);
    this.conversation.appendMessage(this.conversation.roles[1], "");
    const stopStr = this.conversation.getStopStr();
    const tokens = await this.getInputTokens();
    const inputTokenLength = tokens.length;

    var outputPrompt = "";
    if (this.clearCache) {
      this.#clearKVCache();
      this.clearCache = false;
    }
    const maxGenLen = Math.min(this.maxGenLength, this.maxWindowLength - tokens.length);
    if (maxGenLen < this.meanGenLength) {
      throw Error("Window size config is too small");
    }
    let step = 0;
    for (; step < maxGenLen && this.kvCacheLength + inputTokenLength + step < this.maxWindowLength; ++step) {
      this.tvm.beginScope();
      var inputData;
      let tstart = performance.now();
      if (step == 0) {
        inputData = this.tvm.empty([1, tokens.length], "int32", this.device);
        inputData.copyFrom(tokens);
      } else {
        inputData = this.tvm.empty([1, 1], "int32", this.device);
        inputData.copyFrom(tokens.slice(tokens.length - 1));
      }
      const logits = this.tvm.detachFromCurrentScope(
        this.#forward(inputData, this.kvCacheLength + inputTokenLength + step)
      );
      this.tvm.endScope();

      const nextToken = await this.sampleTokenFromLogits(logits);
      logits.dispose();

      tokens.push(nextToken);
      const outputTokens = tokens.slice(inputTokenLength);
      outputPrompt = this.tokenizer.decodeIds(outputTokens);

      if (nextToken == this.eosTokenId) break;

      // An empty stop string would match everywhere, so only check when non-empty.
      const stopPos = stopStr === "" ? -1 : outputPrompt.lastIndexOf(stopStr);
      if (stopPos != -1) {
        outputPrompt = outputPrompt.substring(0, stopPos);
        break;
      }
      let tend = performance.now();
      if (step != 0) {
        this.decodingTotalTokens += 1;
        this.decodingTotalTime += (tend - tstart) / 1000;
      } else {
        this.encodingTotalTime += (tend - tstart) / 1000;
        this.encodingTotalTokens += inputTokenLength;
      }

      if (step % this.streamInterval == 0) {
        callbackUpdateResponse(step, outputPrompt);
      }
    }
    this.kvCacheLength += tokens.length - 1;
    tvmjsGlobalEnv.workerHistoryMsg[tvmjsGlobalEnv.workerHistoryMsg.length - 1][1] = outputPrompt;
    return outputPrompt;
  }
  async evaluate() {
    // run a canonical evaluation of the flow
    this.#clearKVCache();
    const testPrompt = "The capital of Canada is";
    const ids = await this.tokenizer.encodeIds(testPrompt);
    const inputPromptSize = ids.length;
    const tokens = Array.from(ids);
    tokens.unshift(this.bosTokenId);
    if (tokens.length == 0) {
      throw Error("empty token");
    }

    this.tvm.beginScope();
    const inputData = this.tvm.empty([1, tokens.length], "int32", this.device);
    inputData.copyFrom(tokens);
    const encodingStart = performance.now();
    this.#forward(inputData, tokens.length);
    this.tvm.endScope();
    await this.device.sync();

    const decodingStart = performance.now();

    this.tvm.beginScope();
    const firstSampleToken = this.tvm.empty([1, 1], "int32", this.device).copyFrom([6234]);
    this.#updateLogitsOnCPU(this.#forward(firstSampleToken, tokens.length + 1));
    await this.device.sync();
    this.tvm.endScope();

    const decodingEnd = performance.now();
    const msg = (
      `encoding-time=${((decodingStart - encodingStart) / 1000).toFixed(4)} sec, ` +
      `decoding-time=${((decodingEnd - decodingStart) / 1000).toFixed(4)} sec`
    );
    // simply log tokens for eyeballing
    console.log("Logits:");
    console.log(this.logitsOnCPU.toArray());
    console.log(msg);
  }

  /**
   * Async preload webgpu pipelines when possible.
   */
  async asyncLoadWebGPUPiplines() {
    await this.tvm.asyncLoadWebGPUPiplines(this.vm.getInternalModule());
  }

  runtimeStatsText() {
    return (
      `encoding: ${(this.encodingTotalTokens / this.encodingTotalTime).toFixed(4)} tokens/sec, ` +
      `decoding: ${(this.decodingTotalTokens / this.decodingTotalTime).toFixed(4)} tokens/sec`
    );
  }
}

/**
 * An instance that can be used to facilitate deployment.
 */
class LLMChatInstance {
  constructor() {
    this.requestInProgress = false;
    this.config = undefined;
    this.tvm = undefined;
    this.pipeline = undefined;
    this.logger = console.log;
    this.debugTest = false;
  }

  /**
   * Initialize TVM.
   * @param wasmUrl URL to wasm source.
   * @param cacheUrl URL to NDArray cache.
   */
  async #asyncInitTVM(wasmUrl, cacheUrl) {
    if (this.tvm !== undefined) {
      return;
    }
    this.logger = console.log;

    const wasmSource = await (
      await fetch(wasmUrl)
    ).arrayBuffer();
    const tvm = await tvmjs.instantiate(
      new Uint8Array(wasmSource),
      new EmccWASI(),
      this.logger
    );
    // initialize WebGPU
    try {
      const output = await tvmjs.detectGPUDevice();
      if (output !== undefined) {
        var label = "WebGPU";
        if (output.adapterInfo.description.length != 0) {
          label += " - " + output.adapterInfo.description;
        } else {
          label += " - " + output.adapterInfo.vendor;
        }
        this.appendMessage("init", "Initialize GPU device: " + label);
        tvm.initWebGPU(output.device);
      } else {
        this.appendMessage("error", "This browser environment does not support WebGPU");
        this.reset();
        throw Error("This browser environment does not support WebGPU");
      }
    } catch (err) {
      this.appendMessage("error", "Found an error initializing the WebGPU device: " + err.toString());
      console.log(err);
      this.reset();
      throw Error("Found an error initializing WebGPU: " + err.toString());
    }
    this.tvm = tvm;

    const initProgressCallback = (report) => {
      this.updateLastMessage("initing", report.text);
    }
    tvm.registerInitProgressCallback(initProgressCallback);
    await tvm.fetchNDArrayCache(cacheUrl, tvm.webgpu());
  }

  /**
   * Async initialize instance.
   */
  async asyncInit() {
    if (this.pipeline !== undefined) return;
    await this.#asyncInitConfig();
    await this.#asyncInitTVM(this.config.wasmUrl, this.config.cacheUrl);
    await this.#asyncInitPipeline();
  }

  /**
   * Async initialize config.
   */
  async #asyncInitConfig() {
    if (this.config !== undefined) return;
    this.config = await (await fetch("/lib/WebLLM/config.json")).json();
  }

  /**
   * Initialize the pipeline.
   */
  async #asyncInitPipeline() {
    if (this.pipeline !== undefined) return;
    // initialize UX and tokenizer
    const tokenizer = await tvmjsGlobalEnv.sentencePieceProcessor(this.config.tokenizer);
    this.pipeline = this.tvm.withNewScope(() => {
      return new LLMChatPipeline(this.tvm, tokenizer, this.tvm.cacheMetadata, this.config);
    });
    await this.pipeline.asyncLoadWebGPUPiplines();
    this.appendMessage("initing", "All initialization finished.", true);
  }

  appendMessage(kind, text, ifFinish) {
    if (kind == "initing") {
      text = "[System Initialize] " + text;
    }
    console.log(`[${kind}] ${text}`);
    globalThis.postMessage({
      type: 'initing',
      action: 'append',
      msg: text,
      ifError: kind == 'error',
      ifFinish: !!ifFinish
    });
  }

  updateLastMessage(type, text, ifFinish) {
    if (type == "initing") {
      text = `[System Initialize] ${text}`;
    }
    globalThis.postMessage({
      type,
      action: 'updateLast',
      msg: text,
      ifFinish: !!ifFinish
    });
  }

  async respondTestMessage(repeat) {
    const testMessage = "I am a friendly bot. Please ask questions.";
    const encodedResult = await this.pipeline.tokenizer.encodeIds(testMessage);
    const currentIds = [];
    for (let k = 0; k < repeat; ++k) {
      for (let i = 0; i < encodedResult.length; ++i) {
        currentIds.push(encodedResult[i]);
        const msg = this.pipeline.tokenizer.decodeIds(currentIds);
        this.updateLastMessage("chatting", msg);
        await new Promise(resolve => setTimeout(resolve, 50));
      }
    }
  }

  resetChat() {
    if (this.pipeline) {
      this.pipeline.resetChat();
    }
  }

  /**
   * Run generate.
   */
  async generate() {
    if (this.requestInProgress) {
      return;
    }
    this.requestInProgress = true;

    try {
      await this.asyncInit();
    } catch (err) {
      this.appendMessage("error", "Init error, " + err.toString());
      console.log(err);
      this.reset();
      this.requestInProgress = false;
      return;
    }

    if (this.debugTest) {
      await this.pipeline.evaluate();
      this.requestInProgress = false;
      return;
    }

    const prompt = tvmjsGlobalEnv.message;
    if (prompt == "") {
      this.requestInProgress = false;
      return;
    }

    const callbackUpdateResponse = (step, msg) => {
      if (msg.endsWith("##")) {
        msg = msg.substring(0, msg.length - 2);
      } else if (msg.endsWith("#")) {
        msg = msg.substring(0, msg.length - 1);
      }
      this.updateLastMessage("chatting", msg);
    };
    try {
      const output = await this.pipeline.generate(prompt, callbackUpdateResponse);
      this.updateLastMessage("chatting", output, true);
      this.updateLastMessage("stats", this.pipeline.runtimeStatsText());
      console.log(this.pipeline.runtimeStatsText());
    } catch (err) {
      this.appendMessage("error", "Generate error, " + err.toString());
      console.log(err);
      this.reset();
    }
    this.requestInProgress = false;
  }

  /**
   * Reset the instance.
   */
  reset() {
    this.tvm = undefined;
    if (this.pipeline !== undefined) {
      this.pipeline.dispose();
    }
    this.pipeline = undefined;
  }
}

localLLMChatIntance = new LLMChatInstance();

tvmjsGlobalEnv.asyncOnGenerate = async function () {
  await localLLMChatIntance.generate();
};

tvmjsGlobalEnv.asyncOnReset = async function () {
  await localLLMChatIntance.resetChat();
};
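// Usage sketch (an assumption about the hosting page, not part of this worker):
// the page is expected to populate tvmjsGlobalEnv — e.g. tvmjsGlobalEnv.message
// with the current user prompt and tvmjsGlobalEnv.workerHistoryMsg with prior
// turns — before invoking the hooks registered above, roughly:
//
//   tvmjsGlobalEnv.message = "Hello!";
//   await tvmjsGlobalEnv.asyncOnGenerate(); // streams "chatting" updates via postMessage
//   await tvmjsGlobalEnv.asyncOnReset();    // clears history and the KV cache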