bulk-inference-deepseek-r1-zero

This notebook demonstrates how to run bulk inference with DeepSeek-R1-Zero on the Tracto.ai platform.

import yt.wrapper as yt
import uuid

# Configure how the Python client pickles code for remote jobs.
yt.config["pickling"]["dynamic_libraries"]["enable_auto_collection"] = False
yt.config["pickling"]["ignore_system_modules"] = True
yt.config["pickling"]["safe_stream_mode"] = False

# Use the user's home directory if it exists, otherwise fall back to //tmp.
username = yt.get_user_name()
if yt.exists(f"//sys/users/{username}/@user_info/home_path"):
    home = yt.get(f"//sys/users/{username}/@user_info/home_path")
    working_dir = f"{home}/{uuid.uuid4().hex}"
else:
    working_dir = f"//tmp/examples/{uuid.uuid4().hex}"
yt.create("map_node", working_dir)
print(working_dir)
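
If the client is not already pointed at a cluster (for example, when running outside a pre-configured Tracto.ai environment), set the proxy explicitly. A minimal sketch: the address below is a placeholder, and yt.wrapper also honors the YT_PROXY and YT_TOKEN environment variables.

# Placeholder endpoint: substitute your own cluster address.
yt.config["proxy"]["url"] = "your-cluster.tracto.ai"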

Prepare the input data as a YTSaurus table: one question per animal label taken from a small Hugging Face dataset.

from datasets import load_dataset

dataset = load_dataset("Rapidata/Other-Animals-10")

table_path = f"{working_dir}/questions"
yt.create("table", table_path, force=True)

# One question per unique animal label in the dataset.
questions = [
    {"question": f"Can {animal} fly?"}
    for animal in set(dataset["train"].features["label"].int2str(dataset["train"]["label"]))
]

yt.write_table(table_path, questions)
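
To verify what was written, you can read a few rows back. This is just a usage check, not a required step of the pipeline.

import itertools

# Print the first three rows of the questions table.
for row in itertools.islice(yt.read_table(table_path), 3):
    print(row)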

Run bulk inference of DeepSeek-R1-Zero on 2 nodes. The mapper is decorated with @yt.aggregator, which means each job receives its entire share of the input rows as a single iterable, so all questions can be batched into one llm.chat call; a toy sketch of that contract follows.
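
To illustrate the aggregator contract in isolation, a minimal CPU-only sketch (count_rows is a hypothetical example, not part of the pipeline):

@yt.aggregator
def count_rows(rows):
    # An aggregator receives every input row of its job at once
    # and may yield any number of output rows.
    yield {"count": sum(1 for _ in rows)}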

from typing import Iterable
import logging
import random
import sys

@yt.aggregator
def bulk_inference(records: Iterable[dict[str, str]]) -> Iterable[dict[str, str]]:
    from vllm import LLM, SamplingParams

    # A YT job must write all of its logs to stderr.
    vllm_logger = logging.getLogger("vllm")
    vllm_logger.handlers.clear()
    vllm_logger.addHandler(logging.StreamHandler(sys.stderr))

    llm = LLM(
        model="deepseek-ai/DeepSeek-R1-Zero",
        tensor_parallel_size=8,
        seed=random.randint(0, 1000000),
        trust_remote_code=True,
    )
    sampling_params = SamplingParams(
        temperature=0.6,
        top_p=0.9,
        max_tokens=32000,
    )

    # Turn every input row into a single-message chat conversation.
    conversations = [
        [
            {
                "role": "user",
                "content": record["question"],
            },
        ]
        for record in records
    ]
    outputs = llm.chat(
        messages=conversations,
        sampling_params=sampling_params,
    )
    for output in outputs:
        yield {
            "prompt": output.prompt,
            "text": output.outputs[0].text,
        }
result_path = f"{working_dir}/result"

# WARNING: the playground gives you only one host with a single H100;
# DeepSeek inference requires at least one GPU at the H200 level.

yt.run_map(
    bulk_inference,
    table_path,
    result_path,
    job_count=2,
    spec={
        "pool": "fifo",
        "pool_trees": ["gpu_h200"],
        "mapper": {
            "gpu_limit": 8,
            "memory_limit": 322122547200,  # 300 GiB
            "cpu_limit": 64,
        },
    },
)

for record in yt.read_table(result_path):
    print(record)
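
DeepSeek-R1-Zero emits its chain of thought before the final answer. If you only want the answer, a sketch that assumes the model's usual <think>...</think> convention (strip_reasoning is an illustrative helper, not part of the pipeline):

import re

def strip_reasoning(text: str) -> str:
    # Drop the <think>...</think> block, keeping only the final answer.
    return re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()

for record in yt.read_table(result_path):
    print(strip_reasoning(record["text"]))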