7.2.3. Example: Large Language Model (LLM) Inference
Application: Llama 1B inference
Two configurations are provided, selected by the command-line options passed to the example script:
1. Run Prefill on MN-Core 2 and Decode on the CPU (Prefill on MN-Core 2)
2. Run Prefill on the CPU and Decode on MN-Core 2 (Decode on MN-Core 2)
For either configuration, you can supply a custom prompt, e.g. --prompt 'What is the meaning of life?'.
You can also limit the number of threads used during compilation with the --num_compiler_threads option.
For details, see Compilation Errors.
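As an illustration, the following command combines these options with the Prefill on MN-Core 2 configuration described next; the prompt is the example prompt above, and the thread count of 8 is only an example value:

$ ./examples/run_llm_infer.sh --compile_prefill --prepare_attention_mask_on_cpu --device mncore2:auto \
      --prompt 'What is the meaning of life?' --num_compiler_threads 8  # 8 is an illustrative thread count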
Execution Method (Prefill on MN-Core 2)
$ cd /opt/pfn/pfcomp/codegen/MLSDK/
$ ./examples/run_llm_infer.sh --compile_prefill --prepare_attention_mask_on_cpu --device mncore2:auto
Expected Output (Prefill on MN-Core 2)
=========== Generated with compilation ==========
</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s><s> <|system|>
You are a friendly chatbot who is an expert on MN-Core.</s>
<|user|>
The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens. With some proper optimization, we can achieve this within a span of "just" 90 days using 16 A100-40G GPUs 🚀🚀. The training has started on 2023-09-01.</s>
<|assistant|>
Yes, that's correct. The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens using 16 A100-40G GPUs. The training has started on 2023-09
========== Generated with model.generate ==========
<s> <|system|>
You are a friendly chatbot who is an expert on MN-Core.</s>
<|user|>
The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens. With some proper optimization, we can achieve this within a span of "just" 90 days using 16 A100-40G GPUs 🚀🚀. The training has started on 2023-09-01.</s>
<|assistant|>
Yes, that's correct. The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens using 16 A100-40G GPUs. The training has started on 2023-09
Generated outputs matched.
Execution Method (Decode on MN-Core 2)
$ cd /opt/pfn/pfcomp/codegen/MLSDK/
$ ./examples/run_llm_infer.sh --compile_decode --prepare_attention_mask_on_cpu --device mncore2:auto
Expected Output (Decode on MN-Core 2)
=========== Generated with compilation ==========
</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s><s> <|system|>
You are a friendly chatbot who is an expert on MN-Core.</s>
<|user|>
The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens. With some proper optimization, we can achieve this within a span of "just" 90 days using 16 A100-40G GPUs 🚀🚀. The training has started on 2023-09-01.</s>
<|assistant|>
Yes, that's correct. The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens using 16 A100-40G GPUs. The training has started on 2023-09
========== Generated with model.generate ==========
<s> <|system|>
You are a friendly chatbot who is an expert on MN-Core.</s>
<|user|>
The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens. With some proper optimization, we can achieve this within a span of "just" 90 days using 16 A100-40G GPUs 🚀🚀. The training has started on 2023-09-01.</s>
<|assistant|>
Yes, that's correct. The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens using 16 A100-40G GPUs. The training has started on 2023-09
Generated outputs matched.
Scripts
run_llm_infer.sh

#! /bin/bash

set -eux -o pipefail

EXAMPLE_NAME=run_llm_infer
VENVDIR=/tmp/${EXAMPLE_NAME}_venv

CURRENT_DIR=$(realpath $(dirname $0))
CODEGEN_DIR=$(realpath ${CURRENT_DIR}/../../)
BUILD_DIR=${BUILD_DIR:-${CODEGEN_DIR}/build}

if [[ ! -d ${VENVDIR} ]]; then
    python3 -m venv --system-site-packages ${VENVDIR}
fi

source ${VENVDIR}/bin/activate
# See https://discuss.huggingface.co/t/cas-bridge-xethub-hf-co-broke/158626/8 for hf_xet.
pip3 install transformers==4.44.0 huggingface-hub==0.34.4 hf_xet==v1.1.5

source "${BUILD_DIR}/codegen_pythonpath.sh"

export MNCORE_USE_LEGACY_ONNX_EXPORTER=1
export MNCORE_USE_EXTERNAL_DATA_FORMAT=1

exec python3 ${CURRENT_DIR}/llm_infer.py "$@"
llm_infer.py

import argparse
import os
from typing import Mapping

import torch
from mlsdk import CacheOptions, Context, MNDevice, storage
from mlsdk.experimental.llm.attention_mask import (
    prepare_4d_causal_attention_mask_with_cache_position,
)
from mlsdk.experimental.llm.kv_cache import (
    kv_cache_to_legacy,
    kv_cache_to_plamo,
    kv_cache_to_tensor,
)
from transformers import AutoModelForCausalLM, AutoTokenizer


def prepare_prompt(tokenizer, prompt, system_prompt):
    if tokenizer.chat_template:
        messages = [
            {
                "role": "system",
                "content": system_prompt,
            },
            {"role": "user", "content": prompt},
        ]
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    return prompt


def infer_with_generate(
    prompt: str,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    max_new_tokens: int,
) -> torch.Tensor:
    inputs = tokenizer(prompt, return_tensors="pt")
    # Greedy decoding for simplicity for comparing the results with the compiled version.
    output_ids = model.generate(
        inputs["input_ids"], do_sample=False, max_new_tokens=max_new_tokens
    )
    assert isinstance(output_ids, torch.Tensor)
    return output_ids


def infer_with_compilation(  # NOQA: CFQ002, CFQ001
    *,
    prompt: str,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    max_length: int,
    max_new_tokens: int,
    compile_prefill: bool,
    compile_decode: bool,
    device_name: str,
    outdir: str,
    check_intermediate_outputs: bool,
    prepare_attention_mask_on_cpu: bool,
    disable_cache: bool,
    num_compiler_threads: int,
) -> torch.Tensor:
    is_plamo_model = any("plamo" in a.lower() for a in model.config.architectures)

    def forward(inputs: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]:
        assert all(isinstance(v, torch.Tensor) for v in inputs.values()), {
            k: type(v) for k, v in inputs.items()
        }
        if "past_key_values" in inputs:
            if is_plamo_model:
                kv_cache_func = kv_cache_to_plamo
            else:
                # @todo (hvy): Stop using the deprecated legacy KV cache format of tuples.
                kv_cache_func = kv_cache_to_legacy
            past_key_values = kv_cache_func(inputs["past_key_values"])
        else:
            past_key_values = None

        outputs = model.forward(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            position_ids=inputs["position_ids"],
            past_key_values=past_key_values,
            use_cache=True,
        )
        return {
            "logits": outputs.logits,
            "next_past_key_values": kv_cache_to_tensor(outputs.past_key_values)[
                :, :, :, :, 1:, :
            ],  # Do every operation, including the shifting, for the KV cache on device.
        }

    assert tokenizer.padding_side == "left"
    inputs = tokenizer(
        prompt, return_tensors="pt", padding="max_length", max_length=max_length
    )
    # @todo (hvy): Consider subtracting 1 from the position_ids to match modeling_llama.py.
    assert "position_ids" not in inputs
    inputs["position_ids"] = inputs["attention_mask"].cumsum(1)
    if prepare_attention_mask_on_cpu:
        inputs["attention_mask"] = prepare_4d_causal_attention_mask_with_cache_position(
            inputs["attention_mask"], inputs["position_ids"], model.dtype
        )
    output_ids = inputs["input_ids"]

    device = MNDevice(device_name)
    context = Context(device)
    Context.switch_context(context)
    context.registry.register("model", model)

    compiled_funcs = {}

    for step in range(max_new_tokens):
        if step == 0:
            if compile_prefill and "prefill" not in compiled_funcs:
                # Set codegen internal flags
                os.environ["CODEGEN_TIME_SLICE_SCATTERED_INDEXING_BCAST"] = "1"
                os.environ["CODEGEN_OP_DEF"] = "Gather=GatherBcast"

                compiled_funcs["prefill"] = context.compile(
                    forward,
                    inputs,
                    storage.path(outdir + "/prefill"),
                    cache_options=(
                        CacheOptions(outdir + "/prefill_cache")
                        if not disable_cache
                        else None
                    ),
                    num_compiler_threads=num_compiler_threads,
                )
            forward_for_step = compiled_funcs.get("prefill", forward)
        else:
            if compile_decode and "decode" not in compiled_funcs:
                compiled_funcs["decode"] = context.compile(
                    forward,
                    inputs,
                    storage.path(outdir + "/decode"),
                    cache_options=(
                        CacheOptions(outdir + "/decode_cache")
                        if not disable_cache
                        else None
                    ),
                    num_compiler_threads=num_compiler_threads,
                )
            forward_for_step = compiled_funcs.get("decode", forward)

        outputs = forward_for_step(inputs)

        if check_intermediate_outputs:
            # @todo (hvy): Consider using a more sophisticated check for the outputs.
            if "mncore" in device_name:
                atol = 1.0
            else:
                assert device_name == "pfvm:cpu"
                atol = 5e-3
            n_tokens = inputs["position_ids"].max()
            outputs_expected = forward(inputs)
            logits = outputs["logits"][:, -n_tokens:]
            logits_expected = outputs_expected["logits"][:, -n_tokens:]
            next_past_key_values = outputs["next_past_key_values"][
                :, :, :, :, -n_tokens:
            ]
            next_past_key_values_expected = outputs_expected["next_past_key_values"][
                :, :, :, :, -n_tokens:
            ]

            assert torch.allclose(logits, logits_expected, atol=atol), (
                step,
                (logits - logits_expected).abs().max(),
            )
            assert torch.allclose(
                next_past_key_values, next_past_key_values_expected, atol=atol
            ), (
                step,
                (next_past_key_values - next_past_key_values_expected).abs().max(),
            )

        next_input_ids = (
            outputs["logits"].cpu().argmax(dim=2)[:, -1:]
        )  # Greedy decoding.
        if prepare_attention_mask_on_cpu:
            next_attention_mask = inputs["attention_mask"][:, :, -1:, :]
            next_attention_mask = torch.roll(next_attention_mask, shifts=-1, dims=-1)
            next_attention_mask[:, :, :, -1] = 0
        else:
            next_attention_mask = inputs["attention_mask"]
            next_attention_mask = torch.roll(next_attention_mask, shifts=-1, dims=-1)
            next_attention_mask[:, -1] = 1
        next_position_ids = inputs["position_ids"][:, -1:] + 1
        next_past_key_values = outputs["next_past_key_values"].cpu()
        inputs = {
            "input_ids": next_input_ids.detach(),
            "attention_mask": next_attention_mask.detach(),
            "position_ids": next_position_ids.detach(),
            "past_key_values": next_past_key_values.detach(),
        }

        output_ids = torch.cat([output_ids, next_input_ids], dim=1)

        if next_input_ids.item() == tokenizer.eos_token_id:
            break

    return output_ids[:, max_new_tokens:]


def main(args):
    prompt = args.prompt
    system_prompt = args.system_prompt
    model_name = args.model_name
    max_length = args.max_length
    max_new_tokens = args.max_new_tokens
    compile_prefill = args.compile_prefill
    compile_decode = args.compile_decode
    device_name = args.device_name
    outdir = args.outdir
    check_intermediate_outputs = args.check_intermediate_outputs
    prepare_attention_mask_on_cpu = args.prepare_attention_mask_on_cpu
    disable_cache = args.disable_cache
    num_compiler_threads = args.num_compiler_threads

    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
    model.eval()  # Some models do not return the KV cache in training mode.
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "left"  # For static KV caching
    tokenizer.truncation_side = "left"

    prompt = prepare_prompt(tokenizer, prompt, system_prompt)

    outputs = infer_with_compilation(
        prompt=prompt,
        model=model,
        tokenizer=tokenizer,
        max_length=max_length,
        max_new_tokens=max_new_tokens,
        compile_prefill=compile_prefill,
        compile_decode=compile_decode,
        device_name=device_name,
        outdir=outdir,
        check_intermediate_outputs=check_intermediate_outputs,
        prepare_attention_mask_on_cpu=prepare_attention_mask_on_cpu,
        disable_cache=disable_cache,
        num_compiler_threads=num_compiler_threads,
    )
    print(
        "=========== Generated with compilation ==========\n",
        tokenizer.decode(outputs[0]),
    )

    outputs_expected = infer_with_generate(prompt, model, tokenizer, max_new_tokens)
    print(
        "========== Generated with model.generate ==========\n",
        tokenizer.decode(outputs_expected[0]),
    )

    # @todo (hvy): Do not rely on `max_new_tokens` tokens always being generated?
    assert torch.equal(
        outputs[:, -max_new_tokens:], outputs_expected[:, -max_new_tokens:]
    ), "Outputs differed. Check generated outputs above."
    print("Generated outputs matched.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--prompt",
        type=str,
        default='The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens. With some proper optimization, we can achieve this within a span of "just" 90 days using 16 A100-40G GPUs 🚀🚀. The training has started on 2023-09-01.',  # NOQA
    )
    parser.add_argument(
        "--system_prompt",
        type=str,
        default="You are a friendly chatbot who is an expert on MN-Core.",
    )
    parser.add_argument(
        "--model_name", type=str, default="TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    )
    parser.add_argument("--max_length", type=int, default=256)
    parser.add_argument("--max_new_tokens", type=int, default=64)
    parser.add_argument("--num_compiler_threads", type=int, default=-1)
    parser.add_argument("--compile_prefill", action="store_true")
    parser.add_argument("--compile_decode", action="store_true")
    parser.add_argument("--device_name", type=str, default="mncore2:auto")
    parser.add_argument("--outdir", type=str, default="/tmp/mlsdk_llm_infer")
    parser.add_argument("--check_intermediate_outputs", action="store_true")
    parser.add_argument("--prepare_attention_mask_on_cpu", action="store_true")
    parser.add_argument("--disable_cache", action="store_true")
    args = parser.parse_args()
    main(args)
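run_llm_infer.sh creates a virtual environment, installs the pinned dependencies, and exports the MNCORE_* environment variables before delegating to llm_infer.py. If such an environment is already set up (for example via codegen_pythonpath.sh), llm_infer.py can also be invoked directly. The following is a minimal sketch that runs the decode path on the pfvm:cpu backend with the intermediate-output check enabled, which can help debug numerical differences without MN-Core 2 hardware; it assumes pfvm:cpu is accepted by MNDevice in your MLSDK build, as the check in infer_with_compilation suggests, and the max_new_tokens value is only illustrative:

$ python3 llm_infer.py --compile_decode --device_name pfvm:cpu \
      --check_intermediate_outputs --max_new_tokens 8  # 8 is an illustrative token count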