bulk-inference-phi-3-mini-128k

This notebook demonstrates how to perform bulk inference of Phi-3-mini-128k-instruct on the Tracto.ai platform.

Tracto is well suited to offline batch inference:

  • Easy scaling: just change the job_count parameter of the operation.
  • Fault tolerance out of the box: if a job crashes, it is restarted automatically, so there is no need to handle GPU failures or other infrastructure issues yourself.
  • Full platform integration: results are saved directly to the distributed file system and can be processed further on the platform.
import yt.wrapper as yt
import uuid

# Create a unique working directory: under the user's home directory if one
# is configured, otherwise under //tmp.
username = yt.get_user_name()
if yt.exists(f"//sys/users/{username}/@user_info/home_path"):
    home = yt.get(f"//sys/users/{username}/@user_info/home_path")
    working_dir = f"{home}/tmp/{uuid.uuid4().hex}"
else:
    working_dir = f"//tmp/examples/{uuid.uuid4().hex}"
yt.create("map_node", working_dir, ignore_existing=True)
print(working_dir)

Prepare the input data for inference and write it as a table.

from datasets import load_dataset

# The dataset is only used as a convenient source of animal names for the prompts.
dataset = load_dataset("Rapidata/Other-Animals-10")

table_path = f"{working_dir}/questions"
yt.create("table", table_path, force=True)

# Build one question per distinct animal label.
questions = [
    {"question": f"Can a {animal} fly? If it can't - how can we help it take off?"}
    for animal in set(dataset["train"].features["label"].int2str(dataset["train"]["label"]))
]

yt.write_table(table_path, questions)
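
Optionally, you can sanity-check the input by reading back a couple of rows before launching the operation (islice only limits how many rows are printed):

from itertools import islice

# Optional sanity check: print the first two rows of the questions table.
for row in islice(yt.read_table(table_path), 2):
    print(row)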

Run bulk inference with Phi-3-mini.

from typing import Iterable

yt.config["pickling"]["safe_stream_mode"] = False


result_path = f"{working_dir}/result"


@yt.aggregator
def bulk_inference(records: Iterable[dict[str, str]]) -> Iterable[dict[str, str]]:
    from vllm import LLM, SamplingParams

    # The @yt.aggregator decorator means the mapper receives all rows of its job
    # as a single iterable, so the model is loaded only once per job.
    llm = LLM(model="microsoft/Phi-3-mini-128k-instruct", trust_remote_code=True)
    sampling_params = SamplingParams(
        temperature=0.6,
        top_p=0.9,
        max_tokens=32000,
    )

    # Turn every input row into a single-turn chat conversation.
    conversations = [
        [
            {
                "role": "user",
                "content": record["question"],
            },
        ]
        for record in records
    ]
    outputs = llm.chat(
        messages=conversations,
        sampling_params=sampling_params,
    )
    for conversation, output in zip(conversations, outputs):
        yield {
            "prompt": conversation[0]["content"],
            "text": output.outputs[0].text,
        }


yt.run_map(
    bulk_inference,
    table_path,
    result_path,
    job_count=2,
    spec={
        "pool_trees": ["gpu_h100"],
        "mapper": {
            "gpu_limit": 1,
            "memory_limit": 32212254720,  # 30 GiB
            "cpu_limit": 20,
        },
    },
)
for record in yt.read_table(result_path):
    print(f"Q: {record['prompt']}")
    print(f"A: {record['text']}")