7.2.3. Example: Large Language Model (LLM) Inference

This example runs inference for a 1.1B-parameter Llama model (by default, TinyLlama/TinyLlama-1.1B-Chat-v1.0) using MLSDK.

Two configurations are provided, selected by command-line options:

1. Prefill runs on MN-Core 2 while decode runs on the CPU ("Prefill on MN-Core 2", via --compile_prefill)
2. Prefill runs on the CPU while decode runs on MN-Core 2 ("Decode on MN-Core 2", via --compile_decode)

For either configuration, you can pass a custom prompt, for example --prompt 'What is the meaning of life?'. You can also limit the number of threads used during compilation with the --num_compiler_threads option. For details, refer to Compilation Errors.
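
For example, the Prefill on MN-Core 2 configuration can be run with a custom prompt and a compiler thread limit as follows (the prompt text and thread count are illustrative):

$ cd /opt/pfn/pfcomp/codegen/MLSDK/examples/
$ ./run_llm_infer.sh --compile_prefill --prepare_attention_mask_on_cpu --device mncore2:auto --prompt 'What is the meaning of life?' --num_compiler_threads 8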

Execution Method (Prefill on MN-Core 2)

Listing 7.15 Prefill on MN-Core 2
$ cd /opt/pfn/pfcomp/codegen/MLSDK/examples/
$ ./run_llm_infer.sh --compile_prefill --prepare_attention_mask_on_cpu --device mncore2:auto

Expected Output (Prefill on MN-Core 2)

=========== Generated with compilation ==========
 </s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s><s> <|system|>
You are a friendly chatbot who is an expert on MN-Core.</s>
<|user|>
The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens. With some proper optimization, we can achieve this within a span of "just" 90 days using 16 A100-40G GPUs 🚀🚀. The training has started on 2023-09-01.</s>
<|assistant|>
Yes, that's correct. The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens using 16 A100-40G GPUs. The training has started on 2023-09
========== Generated with model.generate ==========
 <s> <|system|>
You are a friendly chatbot who is an expert on MN-Core.</s>
<|user|>
The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens. With some proper optimization, we can achieve this within a span of "just" 90 days using 16 A100-40G GPUs 🚀🚀. The training has started on 2023-09-01.</s>
<|assistant|>
Yes, that's correct. The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens using 16 A100-40G GPUs. The training has started on 2023-09
Generated outputs matched.

Execution Method (Decode on MN-Core 2)

Listing 7.16 Decode on MN-Core 2
$ cd /opt/pfn/pfcomp/codegen/MLSDK/examples/
$ ./run_llm_infer.sh --compile_decode --prepare_attention_mask_on_cpu --device mncore2:auto

Expected Output (Decode on MN-Core 2)

=========== Generated with compilation ==========
 </s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s><s> <|system|>
You are a friendly chatbot who is an expert on MN-Core.</s>
<|user|>
The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens. With some proper optimization, we can achieve this within a span of "just" 90 days using 16 A100-40G GPUs 🚀🚀. The training has started on 2023-09-01.</s>
<|assistant|>
Yes, that's correct. The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens using 16 A100-40G GPUs. The training has started on 2023-09
========== Generated with model.generate ==========
 <s> <|system|>
You are a friendly chatbot who is an expert on MN-Core.</s>
<|user|>
The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens. With some proper optimization, we can achieve this within a span of "just" 90 days using 16 A100-40G GPUs 🚀🚀. The training has started on 2023-09-01.</s>
<|assistant|>
Yes, that's correct. The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens using 16 A100-40G GPUs. The training has started on 2023-09
Generated outputs matched.
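
The script additionally supports --check_intermediate_outputs, which re-runs each step with the uncompiled model and checks the logits and the KV cache against the compiled outputs (llm_infer.py uses a looser tolerance for MN-Core 2 than for the pfvm:cpu device). A debugging invocation combining this with the CPU device might look like the following; this command is illustrative and not one of the documented configurations:

$ cd /opt/pfn/pfcomp/codegen/MLSDK/examples/
$ ./run_llm_infer.sh --compile_prefill --prepare_attention_mask_on_cpu --device_name pfvm:cpu --check_intermediate_outputs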

Scripts

Listing 7.17 /opt/pfn/pfcomp/codegen/MLSDK/examples/run_llm_infer.sh
#! /bin/bash

set -eux -o pipefail

EXAMPLE_NAME=run_llm_infer
VENVDIR=/tmp/${EXAMPLE_NAME}_venv

CURRENT_DIR=$(realpath $(dirname $0))
CODEGEN_DIR=$(realpath ${CURRENT_DIR}/../../)
BUILD_DIR=${BUILD_DIR:-${CODEGEN_DIR}/build}

if [[ ! -d ${VENVDIR} ]]; then
    python3 -m venv --system-site-packages ${VENVDIR}
fi

source ${VENVDIR}/bin/activate
# See https://discuss.huggingface.co/t/cas-bridge-xethub-hf-co-broke/158626/8 for hf_xet.
pip3 install transformers==4.44.0 huggingface-hub==0.34.4 hf_xet==v1.1.5

source "${BUILD_DIR}/codegen_pythonpath.sh"

export MNCORE_USE_LEGACY_ONNX_EXPORTER=1
export MNCORE_USE_EXTERNAL_DATA_FORMAT=1

exec python3 ${CURRENT_DIR}/llm_infer.py "$@"
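
The wrapper creates a dedicated virtualenv under /tmp, installs pinned transformers and huggingface-hub versions, and sources codegen_pythonpath.sh from the MLSDK build tree. Because BUILD_DIR only defaults to ${CODEGEN_DIR}/build, it can be overridden from the environment when a different build directory should be used, for example (the path is illustrative):

$ cd /opt/pfn/pfcomp/codegen/MLSDK/examples/
$ BUILD_DIR=/path/to/codegen/build ./run_llm_infer.sh --compile_prefill --prepare_attention_mask_on_cpu --device mncore2:auto
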
Listing 7.18 /opt/pfn/pfcomp/codegen/MLSDK/examples/llm_infer.py
import argparse
import os
from typing import Mapping

import torch
from mlsdk import CacheOptions, Context, MNDevice, storage
from mlsdk.experimental.llm.attention_mask import (
    prepare_4d_causal_attention_mask_with_cache_position,
)
from mlsdk.experimental.llm.kv_cache import (
    kv_cache_to_legacy,
    kv_cache_to_plamo,
    kv_cache_to_tensor,
)
from transformers import AutoModelForCausalLM, AutoTokenizer


def prepare_prompt(tokenizer, prompt, system_prompt):
    if tokenizer.chat_template:
        messages = [
            {
                "role": "system",
                "content": system_prompt,
            },
            {"role": "user", "content": prompt},
        ]
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    return prompt


def infer_with_generate(
    prompt: str,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    max_new_tokens: int,
) -> torch.Tensor:
    inputs = tokenizer(prompt, return_tensors="pt")
    # Greedy decoding, for simplicity and so the results can be compared with the compiled version.
    output_ids = model.generate(
        inputs["input_ids"], do_sample=False, max_new_tokens=max_new_tokens
    )
    assert isinstance(output_ids, torch.Tensor)
    return output_ids


def infer_with_compilation(  # NOQA: CFQ002, CFQ001
    *,
    prompt: str,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    max_length: int,
    max_new_tokens: int,
    compile_prefill: bool,
    compile_decode: bool,
    device_name: str,
    outdir: str,
    check_intermediate_outputs: bool,
    prepare_attention_mask_on_cpu: bool,
    disable_cache: bool,
    num_compiler_threads: int,
) -> torch.Tensor:
    is_plamo_model = any("plamo" in a.lower() for a in model.config.architectures)

    def forward(inputs: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]:
        assert all(isinstance(v, torch.Tensor) for v in inputs.values()), {
            k: type(v) for k, v in inputs.items()
        }
        if "past_key_values" in inputs:
            if is_plamo_model:
                kv_cache_func = kv_cache_to_plamo
            else:
                # @todo (hvy): Stop using the deprecated legacy KV cache format of tuples.
                kv_cache_func = kv_cache_to_legacy
            past_key_values = kv_cache_func(inputs["past_key_values"])
        else:
            past_key_values = None

        outputs = model.forward(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            position_ids=inputs["position_ids"],
            past_key_values=past_key_values,
            use_cache=True,
        )
        return {
            "logits": outputs.logits,
            "next_past_key_values": kv_cache_to_tensor(outputs.past_key_values)[
                :, :, :, :, 1:, :
            ],  # Do every operation, including the shifting, for the KV cache on device.
        }

    assert tokenizer.padding_side == "left"
    inputs = tokenizer(
        prompt, return_tensors="pt", padding="max_length", max_length=max_length
    )
    # @todo (hvy): Consider subtracting 1 from the position_ids to match modeling_llama.py.
    assert "position_ids" not in inputs
    inputs["position_ids"] = inputs["attention_mask"].cumsum(1)
    if prepare_attention_mask_on_cpu:
        inputs["attention_mask"] = prepare_4d_causal_attention_mask_with_cache_position(
            inputs["attention_mask"], inputs["position_ids"], model.dtype
        )
    output_ids = inputs["input_ids"]

    device = MNDevice(device_name)
    context = Context(device)
    Context.switch_context(context)
    context.registry.register("model", model)

    compiled_funcs = {}

    for step in range(max_new_tokens):
        if step == 0:
            if compile_prefill and "prefill" not in compiled_funcs:
                # Set codegen internal flags
                os.environ["CODEGEN_TIME_SLICE_SCATTERED_INDEXING_BCAST"] = "1"
                os.environ["CODEGEN_OP_DEF"] = "Gather=GatherBcast"

                compiled_funcs["prefill"] = context.compile(
                    forward,
                    inputs,
                    storage.path(outdir + "/prefill"),
                    cache_options=(
                        CacheOptions(outdir + "/prefill_cache")
                        if not disable_cache
                        else None
                    ),
                    num_compiler_threads=num_compiler_threads,
                )
            forward_for_step = compiled_funcs.get("prefill", forward)
        else:
            if compile_decode and "decode" not in compiled_funcs:
                compiled_funcs["decode"] = context.compile(
                    forward,
                    inputs,
                    storage.path(outdir + "/decode"),
                    cache_options=(
                        CacheOptions(outdir + "/decode_cache")
                        if not disable_cache
                        else None
                    ),
                    num_compiler_threads=num_compiler_threads,
                )
            forward_for_step = compiled_funcs.get("decode", forward)

        outputs = forward_for_step(inputs)

        if check_intermediate_outputs:
            # @todo (hvy): Consider using a more sophisticated check for the outputs.
            if "mncore" in device_name:
                atol = 1.0
            else:
                assert device_name == "pfvm:cpu"
                atol = 5e-3
            n_tokens = inputs["position_ids"].max()
            outputs_expected = forward(inputs)
            logits = outputs["logits"][:, -n_tokens:]
            logits_expected = outputs_expected["logits"][:, -n_tokens:]
            next_past_key_values = outputs["next_past_key_values"][
                :, :, :, :, -n_tokens:
            ]
            next_past_key_values_expected = outputs_expected["next_past_key_values"][
                :, :, :, :, -n_tokens:
            ]

            assert torch.allclose(logits, logits_expected, atol=atol), (
                step,
                (logits - logits_expected).abs().max(),
            )
            assert torch.allclose(
                next_past_key_values, next_past_key_values_expected, atol=atol
            ), (
                step,
                (next_past_key_values - next_past_key_values_expected).abs().max(),
            )

        next_input_ids = (
            outputs["logits"].cpu().argmax(dim=2)[:, -1:]
        )  # Greedy decoding.
        # Slide the fixed-length attention window forward by one position so that it
        # covers the newly generated token.
        if prepare_attention_mask_on_cpu:
            next_attention_mask = inputs["attention_mask"][:, :, -1:, :]
            next_attention_mask = torch.roll(next_attention_mask, shifts=-1, dims=-1)
            next_attention_mask[:, :, :, -1] = 0
        else:
            next_attention_mask = inputs["attention_mask"]
            next_attention_mask = torch.roll(next_attention_mask, shifts=-1, dims=-1)
            next_attention_mask[:, -1] = 1
        next_position_ids = inputs["position_ids"][:, -1:] + 1
        next_past_key_values = outputs["next_past_key_values"].cpu()
        inputs = {
            "input_ids": next_input_ids.detach(),
            "attention_mask": next_attention_mask.detach(),
            "position_ids": next_position_ids.detach(),
            "past_key_values": next_past_key_values.detach(),
        }

        output_ids = torch.cat([output_ids, next_input_ids], dim=1)

        if next_input_ids.item() == tokenizer.eos_token_id:
            break

    return output_ids[:, max_new_tokens:]


def main(args):
    prompt = args.prompt
    system_prompt = args.system_prompt
    model_name = args.model_name
    max_length = args.max_length
    max_new_tokens = args.max_new_tokens
    compile_prefill = args.compile_prefill
    compile_decode = args.compile_decode
    device_name = args.device_name
    outdir = args.outdir
    check_intermediate_outputs = args.check_intermediate_outputs
    prepare_attention_mask_on_cpu = args.prepare_attention_mask_on_cpu
    disable_cache = args.disable_cache
    num_compiler_threads = args.num_compiler_threads

    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
    model.eval()  # Some models do not return the KV cache in training mode.
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "left"  # For static KV caching
    tokenizer.truncation_side = "left"

    prompt = prepare_prompt(tokenizer, prompt, system_prompt)

    outputs = infer_with_compilation(
        prompt=prompt,
        model=model,
        tokenizer=tokenizer,
        max_length=max_length,
        max_new_tokens=max_new_tokens,
        compile_prefill=compile_prefill,
        compile_decode=compile_decode,
        device_name=device_name,
        outdir=outdir,
        check_intermediate_outputs=check_intermediate_outputs,
        prepare_attention_mask_on_cpu=prepare_attention_mask_on_cpu,
        disable_cache=disable_cache,
        num_compiler_threads=num_compiler_threads,
    )
    print(
        "=========== Generated with compilation ==========\n",
        tokenizer.decode(outputs[0]),
    )

    outputs_expected = infer_with_generate(prompt, model, tokenizer, max_new_tokens)
    print(
        "========== Generated with model.generate ==========\n",
        tokenizer.decode(outputs_expected[0]),
    )

    # @todo (hvy): Do not rely on `max_new_tokens` tokens always being generated?
    assert torch.equal(
        outputs[:, -max_new_tokens:], outputs_expected[:, -max_new_tokens:]
    ), "Outputs differed. Check generated outputs above."
    print("Generated outputs matched.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--prompt",
        type=str,
        default='The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens. With some proper optimization, we can achieve this within a span of "just" 90 days using 16 A100-40G GPUs 🚀🚀. The training has started on 2023-09-01.',  # NOQA
    )
    parser.add_argument(
        "--system_prompt",
        type=str,
        default="You are a friendly chatbot who is an expert on MN-Core.",
    )
    parser.add_argument(
        "--model_name", type=str, default="TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    )
    parser.add_argument("--max_length", type=int, default=256)
    parser.add_argument("--max_new_tokens", type=int, default=64)
    parser.add_argument("--num_compiler_threads", type=int, default=-1)
    parser.add_argument("--compile_prefill", action="store_true")
    parser.add_argument("--compile_decode", action="store_true")
    parser.add_argument("--device_name", type=str, default="mncore2:auto")
    parser.add_argument("--outdir", type=str, default="/tmp/mlsdk_llm_infer")
    parser.add_argument("--check_intermediate_outputs", action="store_true")
    parser.add_argument("--prepare_attention_mask_on_cpu", action="store_true")
    parser.add_argument("--disable_cache", action="store_true")
    args = parser.parse_args()
    main(args)
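
As llm_infer.py shows, the compiled prefill and decode functions are written under --outdir (default /tmp/mlsdk_llm_infer), and a compilation cache directory (prefill_cache or decode_cache) is used unless --disable_cache is passed. To keep the artifacts of separate experiments apart, the output directory can be overridden, for example (the directory name is illustrative):

$ cd /opt/pfn/pfcomp/codegen/MLSDK/examples/
$ ./run_llm_infer.sh --compile_decode --prepare_attention_mask_on_cpu --device mncore2:auto --outdir /tmp/mlsdk_llm_infer_decode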