bulk-inference-phi-3-mini-128k
This notebook demonstrates how to perform bulk inference of Phi-3-mini-128k-instruct on the Tracto.ai platform.
Tracto is perfect for offline batch inference:
- Easy scaling - just change the `job_count` parameter.
- Fault tolerance out of the box - if a job crashes, it gets restarted automatically. No need to handle GPU failures or infrastructure issues - Tracto takes care of it.
- Full Tracto integration - save results directly to the distributed file system and process them further on the platform.
import uuid

import yt.wrapper as yt

# Pick a working directory: the user's home directory if one is configured, otherwise //tmp.
username = yt.get_user_name()
if yt.exists(f"//sys/users/{username}/@user_info/home_path"):
    home = yt.get(f"//sys/users/{username}/@user_info/home_path")
    working_dir = f"{home}/tmp/{uuid.uuid4().hex}"
else:
    working_dir = f"//tmp/examples/{uuid.uuid4().hex}"

yt.create("map_node", working_dir, ignore_existing=True)
print(working_dir)
Prepare data for inference as a table.
from datasets import load_dataset

# Build one question per animal class in the dataset and store them as a table.
dataset = load_dataset("Rapidata/Other-Animals-10")

table_path = f"{working_dir}/questions"
yt.create("table", table_path, force=True)

questions = [
    {"question": f"Can a {animal} fly? If it can't - how can we help it take off?"}
    for animal in set(dataset["train"].features["label"].int2str(dataset["train"]["label"]))
]
yt.write_table(table_path, questions)
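As an optional sanity check, you can read the uploaded rows back from the table; `@row_count` and `read_table` are standard `yt.wrapper` calls:
# Optional sanity check: confirm the table was written and peek at the first few rows.
print(yt.get(f"{table_path}/@row_count"))
for row in list(yt.read_table(table_path))[:3]:
    print(row)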
Run bulk inference of Phi-3-mini.
from typing import Iterable

# Disable safe stream mode for pickled jobs.
yt.config["pickling"]["safe_stream_mode"] = False

result_path = f"{working_dir}/result"


# An aggregator mapper receives all rows of its job as a single iterator,
# so the model is loaded once per job and the whole batch goes through one chat call.
@yt.aggregator
def bulk_inference(records: Iterable[dict[str, str]]) -> Iterable[dict]:
    from vllm import LLM, SamplingParams

    llm = LLM(model="microsoft/Phi-3-mini-128k-instruct", trust_remote_code=True)
    sampling_params = SamplingParams(
        temperature=0.6,
        top_p=0.9,
        max_tokens=32000,
    )
    # One single-turn conversation per input row.
    conversations = [
        [
            {
                "role": "user",
                "content": record["question"],
            },
        ]
        for record in records
    ]
    outputs = llm.chat(
        messages=conversations,
        sampling_params=sampling_params,
    )
    for prompt, output in zip(conversations, outputs):
        yield {
            "prompt": prompt,
            "text": output.outputs[0].text,
        }
yt.run_map(
    bulk_inference,
    table_path,
    result_path,
    job_count=2,  # scale out by raising the number of parallel jobs
    spec={
        "pool_trees": ["gpu_h100"],
        "mapper": {
            "gpu_limit": 1,
            "memory_limit": 32212254720,  # 30 GiB per job
            "cpu_limit": 20,
        },
    },
)
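Aborted jobs are restarted automatically. If you also want to control how many failed jobs the operation tolerates before it is marked as failed, the operation spec accepts a `max_failed_job_count` option; a minimal sketch of the same operation with an explicit failure budget (the value 3 is illustrative):
yt.run_map(
    bulk_inference,
    table_path,
    result_path,
    job_count=2,
    spec={
        "pool_trees": ["gpu_h100"],
        "max_failed_job_count": 3,  # illustrative: give up after three failed jobs
        "mapper": {
            "gpu_limit": 1,
            "memory_limit": 32212254720,
            "cpu_limit": 20,
        },
    },
)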
for record in yt.read_table(result_path):
    print(f"Q: {record['prompt']}")
    print(f"A: {record['text']}")