LLM parameter counts keep shrinking, which makes it practical to run models on end-user devices. Why run a model on-device? First, it saves server compute: GPU rental is still expensive, and renting a single A10 card costs over 30,000 RMB per year. Offloading part of the workload to a small on-device model frees up a large share of server capacity. Second, privacy: when the model runs on-device, all data stays local and is never uploaded to a server, so personal data never leaves the machine.
How do we actually run a model on-device? Take the browser as an example. Many inference engines already support CPU inference, such as Ollama and Llamafile, but those run on the server side. Microsoft's ONNX Runtime targets on-device inference and requires the model to be converted to the ONNX format first. ONNX Runtime Web can run inference in the browser on either the GPU or the CPU: the GPU backend uses WebGL, and the CPU backend uses WASM.
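As a minimal illustration (separate from the Transformers.js setup used in the rest of this article), the sketch below shows how ONNX Runtime Web lets you choose the WebGL (GPU) or WASM (CPU) execution provider when creating a session; './model.onnx' and the input shape are placeholders, not a real model.

import * as ort from 'onnxruntime-web';

// Run a model with an explicit execution provider:
// 'webgl' executes on the GPU, 'wasm' executes on the CPU.
async function run() {
  const session = await ort.InferenceSession.create('./model.onnx', {
    executionProviders: ['webgl', 'wasm'], // try WebGL first, fall back to WASM
  });

  // Build a placeholder input tensor and run inference.
  const input = new ort.Tensor('float32', new Float32Array(1 * 3 * 224 * 224), [1, 3, 224, 224]);
  const results = await session.run({ [session.inputNames[0]]: input });
  console.log(results);
}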
This article uses Transformers.js to load Qwen1.5. Because loading and running the model is time-consuming, the model is loaded and executed inside a Web Worker, which communicates with the UI through message passing.
worker.js
Loading the Model and Tokenizer
import { env, AutoModelForCausalLM, AutoTokenizer } from '@xenova/transformers';

// Always download models from the Hugging Face Hub instead of looking for
// local copies, and cache the downloaded files in the browser cache.
env.allowLocalModels = false;
env.useBrowserCache = true;
class TextGenerationPipeline {
  static modelId = 'Xenova/Qwen1.5-0.5B-Chat';
  static model = null;
  static tokenizer = null;

  // Lazily load the model on first use, reporting download progress.
  static async getModel(progress_callback = null) {
    if (this.model === null) {
      this.model = await AutoModelForCausalLM.from_pretrained(this.modelId, { progress_callback });
    }
    return this.model;
  }

  // Lazily load the tokenizer on first use, reporting download progress.
  static async getTokenizer(progress_callback = null) {
    if (this.tokenizer === null) {
      this.tokenizer = await AutoTokenizer.from_pretrained(this.modelId, { progress_callback });
    }
    return this.tokenizer;
  }
}
// Listen for messages from the main thread.
self.addEventListener('message', async (event) => {
  // Retrieve the model and tokenizer. The first call downloads and caches them;
  // the progress callback forwards loading events (initiate/progress/done) to the UI.
  const model = await TextGenerationPipeline.getModel((x) => {
    self.postMessage(x);
  });
  const tokenizer = await TextGenerationPipeline.getTokenizer((x) => {
    self.postMessage(x);
  });

  // Tell the UI that the model and tokenizer are ready.
  self.postMessage({ status: 'ready' });

  const prompt = "Give me a short introduction to large language model.";
  const messages = [
    { role: 'system', content: 'You are a helpful assistant.' },
    { role: 'user', content: prompt },
  ];

  // Build the chat prompt using the model's chat template.
  const text = tokenizer.apply_chat_template(messages, {
    tokenize: false,
    add_generation_prompt: true,
  });

  // Tokenize the prompt. Transformers.js returns Tensors directly,
  // so the Python-style `return_tensors: 'pt'` option is not needed.
  const model_inputs = await tokenizer(text);

  // Generate the reply, streaming partial output back to the UI as it is produced.
  const output = await model.generate(model_inputs.input_ids, {
    max_new_tokens: 512,
    callback_function: (x) => {
      self.postMessage({
        status: 'update',
        output: tokenizer.decode(x[0].output_token_ids, { skip_special_tokens: true }),
      });
    },
  });

  // Send the final output back to the main thread.
  self.postMessage({
    status: 'complete',
    output: tokenizer.decode(output[0], { skip_special_tokens: true }),
  });
});
Handling messages
The UI interacts with the worker through message callbacks.
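The handler below references a WorkerMessage type that is not shown here; a minimal sketch of what it might look like, based on the statuses the worker actually posts, is:

// A possible shape for the messages posted by worker.js (an assumption, not from the original code).
interface WorkerMessage {
  status: 'initiate' | 'progress' | 'done' | 'ready' | 'update' | 'complete';
  file?: string;      // present for model-file loading events
  progress?: number;  // present for 'progress' events
  output?: string;    // present for 'update' and 'complete' events
}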
useEffect(() => {
  const onMessageReceived = (e: MessageEvent<WorkerMessage>) => {
    switch (e.data.status) {
      case 'initiate':
        // A model file has started loading: add a new progress item to the list.
        setReady(false);
        setProgressItems(prev => [...prev, { file: e.data.file!, progress: 0 }]);
        break;
      case 'progress':
        // Model file progress: update one of the progress items.
        setProgressItems(prev =>
          prev.map(item =>
            item.file === e.data.file ? { ...item, progress: e.data.progress! } : item
          )
        );
        break;
      case 'done':
        // Model file loaded: remove the progress item from the list.
        setProgressItems(prev => prev.filter(item => item.file !== e.data.file));
        break;
      case 'ready':
        // Pipeline ready: the worker is ready to accept messages.
        setReady(true);
        break;
      case 'update':
        // Generation update: update the output text.
        setOutput(e.data.output!);
        break;
      case 'complete':
        // Generation complete: re-enable the generate button.
        setDisabled(false);
        break;
    }
  };

  if (!worker.current) {
    // Create the worker if it does not yet exist.
    worker.current = new Worker(new URL('./worker.js', import.meta.url), {
      type: 'module'
    });
  }

  // Attach the callback function as an event listener.
  worker.current.addEventListener('message', onMessageReceived);

  // Define a cleanup function for when the component is unmounted.
  return () => {
    if (worker.current) {
      worker.current.removeEventListener('message', onMessageReceived);
    }
  };
}, []);
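The snippet above only shows the receiving side. A minimal sketch of the sending side could look like the following, assuming the worker is extended to read the prompt from event.data, and that input, setDisabled, and setOutput come from the surrounding component (they are not shown in the original code):

// Hypothetical click handler for a "Generate" button; `input` is assumed
// to be component state holding the user's prompt.
const generate = () => {
  if (!worker.current) return;
  setDisabled(true);   // disable the button until the 'complete' message arrives
  setOutput('');       // clear the previous output
  worker.current.postMessage({ text: input });
};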
Summary
Transformers.js implements the classes and methods of the Hugging Face Transformers library, but not every model can currently run in the browser; the list of supported models is available on the project page: https://github.com/xenova/transformers.js