Skip to main content

Tractorun advanced

This notebook provides an extended demonstration of the advanced capabilities of the tractorun library. The focus will be on two key features:

  1. Checkpoints for PyTorch: how to save a checkpoint and restore training from a checkpoint.
  2. Distributed Model Training: how to run distributed training by tractorun on multiple nodes with multiple processes.

For a basic example, please refer to tractorun-torch-mnist.

import uuid
import sys

from yt import wrapper as yt
from yt import type_info

Create a base directory for examples

# Create a unique base directory on the cluster for this example run.
run_id = uuid.uuid4()
working_dir = f"//tmp/examples/tractorun-mnist-advanced_{run_id}"
# recursive=True: intermediate map nodes may not exist yet.
yt.create("map_node", working_dir, recursive=True)
print(working_dir)

Ensure torch and torchvision exist

Let's ensure that the system has installed torch and torchvision.

import torch
import torchvision

Run distributed training

Let's use the MNIST dataset. The process of uploading this data is described in the basic tractorun notebook, tractorun-torch-mnist.

# MNIST tables prepared in the basic tractorun notebook.
dataset_train_path = "//home/samples/mnist-torch-train"
dataset_test_path = "//home/samples/mnist-torch-test"

# Prefer the user's home directory when the cluster exposes one;
# otherwise fall back to //tmp.
username = yt.get_user_name()
if yt.exists(f"//sys/users/{username}/@user_info/home_path"):
    home = yt.get(f"//sys/users/{username}/@user_info/home_path")
    working_dir = f"{home}/{uuid.uuid4().hex}"
else:
    working_dir = f"//tmp/examples/{uuid.uuid4().hex}"
# recursive=True: parent nodes may not exist yet (consistent with the
# earlier yt.create call in this notebook).
yt.create("map_node", working_dir, recursive=True)
print(working_dir)

In order to run tractorun in distributed mode with checkpoints:

  1. Use toolbox.checkpoint_manager to manage checkpoints.
  2. Set distributed training configuration by tractorun.mesh.Mesh

<details> <summary>Show the full diff</summary>

@@ -6,7 +6,15 @@
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

+from tractorun.backend.tractorch import YtTensorDataset, Tractorch
+from tractorun.toolbox import Toolbox
+from tractorun.run import run
+from tractorun.mesh import Mesh
+from tractorun.resources import Resources
+from tractorun.stderr_reader import StderrMode
+from tractorun.backend.tractorch.serializer import TensorSerializer

class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
@@ -33,9 +41,12 @@
return output


-def train(args, model, device, train_loader, optimizer, epoch):
+def train(args, model, device, train_loader, optimizer, epoch, first_batch_index, checkpoint_manager):
model.train()
+ ts = TensorSerializer()
for batch_idx, (data, target) in enumerate(train_loader):
+ if batch_idx < first_batch_index:
+ continue
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
@@ -45,9 +56,18 @@
if batch_idx % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
- 100. * batch_idx / len(train_loader), loss.item()))
+ 100. * batch_idx / len(train_loader), loss.item()), file=sys.stderr)
if args.dry_run:
break
+ state_dict = {
+ "model": model.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ }
+ metadata_dict = {
+ "first_batch_index": batch_idx + 1,
+ "loss": loss.item(),
+ }
+ checkpoint_manager.save_checkpoint(ts.serialize(state_dict), metadata_dict)


def test(model, device, test_loader):
@@ -66,10 +86,10 @@

print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
- 100. * correct / len(test_loader.dataset)))
+ 100. * correct / len(test_loader.dataset)), file=sys.stderr)


-def main():
+def main(toolbox: Toolbox):
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
@@ -94,7 +114,7 @@
help='how many batches to wait before logging training status')
parser.add_argument('--save-model', action='store_true', default=False,
help='For Saving the current Model')
- args = parser.parse_args()
+ args = parser.parse_args([])
use_cuda = not args.no_cuda and torch.cuda.is_available()
use_mps = not args.no_mps and torch.backends.mps.is_available()

@@ -120,26 +140,48 @@
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
- dataset1 = datasets.MNIST('../data', train=True, download=True,
- transform=transform)
- dataset2 = datasets.MNIST('../data', train=False,
- transform=transform)
+ dataset1 = YtTensorDataset(yt_client=toolbox.yt_client, path=dataset_train_path, columns=['data', 'labels'])
+ dataset2 = YtTensorDataset(yt_client=toolbox.yt_client, path=dataset_test_path, columns=['data', 'labels'])
+
train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

+ ts = TensorSerializer()
+ first_batch_index = 0
+ checkpoint = toolbox.checkpoint_manager.get_last_checkpoint()
+ if checkpoint is not None:
+ first_batch_index = checkpoint.metadata["first_batch_index"]
+ print(
+ "Found checkpoint with index",
+ checkpoint.index,
+ "and first batch index",
+ first_batch_index,
+ file=sys.stderr,
+ )
+
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
for epoch in range(1, args.epochs + 1):
- train(args, model, device, train_loader, optimizer, epoch)
+ train(args, model, device, train_loader, optimizer, epoch, first_batch_index, toolbox.checkpoint_manager)
test(model, device, test_loader)
scheduler.step()

if args.save_model:
- torch.save(model.state_dict(), "mnist_cnn.pt")
+ toolbox.save_model(ts.serialize(model.state_dict()), dataset_train_path, metadata={})


-if __name__ == '__main__':
- main()
+run(
+ main,
+ backend=Tractorch(),
+ yt_path=working_dir,
+ mesh=Mesh(node_count=2, process_per_node=2, gpu_per_process=0),
+ resources=Resources(
+ cpu_limit=8,
+ memory_limit=105899345920,
+ ),
+ proxy_stderr_mode=StderrMode.primary,
+)

</details>

<font color="red">IMPORTANT NOTE</font> In this example we are running tractorun directly from Jupyter notebook.

This is a convenient method for experiments and demonstrations, as tractorun uses pickle for easy serialization of the entire notebook state and transferring it to the cluster. This means that all variables will be available in the model training function, and tractorun will attempt to transfer all Python modules from the local environment to the cluster.

However, this method does not guarantee reproducibility of model training runs. For production processes, use execution via the tractorun CLI, which is described in the basic notebook.

Let's run training on 2 hosts with 2 processes each. This is not the most optimal configuration, but it demonstrates tractorun's capabilities in a simple way.

import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

from tractorun.backend.tractorch import YtTensorDataset, Tractorch
from tractorun.toolbox import Toolbox
from tractorun.run import run
from tractorun.mesh import Mesh
from tractorun.resources import Resources
from tractorun.stderr_reader import StderrMode
from tractorun.backend.tractorch.serializer import TensorSerializer

class Net(nn.Module):
    """CNN for 28x28 MNIST digits: two conv layers, dropout, two FC layers.

    Attribute names (conv1, conv2, ...) are part of the checkpoint format
    via state_dict() and must not change.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # Convolutional feature extractor.
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = self.dropout1(x)
        # Classifier head over the flattened feature map.
        x = F.relu(self.fc1(torch.flatten(x, 1)))
        x = self.fc2(self.dropout2(x))
        # Log-probabilities over the 10 digit classes.
        return F.log_softmax(x, dim=1)


def train(args, model, device, train_loader, optimizer, epoch, first_batch_index, checkpoint_manager):
    """Run one training epoch, skipping batches before `first_batch_index`
    (already covered by the last checkpoint) and saving a checkpoint every
    `args.log_interval` batches."""
    model.train()
    serializer = TensorSerializer()
    for batch_idx, (data, target) in enumerate(train_loader):
        # Resume support: batches below first_batch_index were already trained.
        if batch_idx < first_batch_index:
            continue
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval != 0:
            continue
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, batch_idx * len(data), len(train_loader.dataset),
            100. * batch_idx / len(train_loader), loss.item()), file=sys.stderr)
        if args.dry_run:
            break
        # Persist model/optimizer state plus enough metadata to resume later.
        state_dict = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        metadata_dict = {
            "first_batch_index": batch_idx + 1,
            "loss": loss.item(),
        }
        checkpoint_manager.save_checkpoint(serializer.serialize(state_dict), metadata_dict)


def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()

test_loss /= len(test_loader.dataset)

print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)), file=sys.stderr)


def main(toolbox: Toolbox):
    """Training entry point executed by tractorun on every worker process.

    Restores the last checkpoint (if any), trains for `args.epochs` epochs,
    evaluating after each one.
    """
    # Training settings. argparse is kept from the upstream MNIST example,
    # but only the defaults are used (see parse_args([]) below).
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--no-mps', action='store_true', default=False,
                        help='disables macOS GPU training')
    parser.add_argument('--dry-run', action='store_true', default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    # Parse an empty argv: sys.argv belongs to the tractorun runner, not to us.
    args = parser.parse_args([])
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    use_mps = not args.no_mps and torch.backends.mps.is_available()

    torch.manual_seed(args.seed)

    if use_cuda:
        device = torch.device("cuda")
    elif use_mps:
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    train_kwargs = {'batch_size': args.batch_size}
    test_kwargs = {'batch_size': args.test_batch_size}
    if use_cuda:
        cuda_kwargs = {'num_workers': 1,
                       'pin_memory': True,
                       'shuffle': True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    # NOTE(review): `transform` is kept from the upstream torchvision example
    # but is never applied — YtTensorDataset reads tensors straight from YT.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    # Datasets are read from YT tables via the toolbox-provided client.
    dataset1 = YtTensorDataset(yt_client=toolbox.yt_client, path=dataset_train_path, columns=['data', 'labels'])
    dataset2 = YtTensorDataset(yt_client=toolbox.yt_client, path=dataset_test_path, columns=['data', 'labels'])

    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    ts = TensorSerializer()
    first_batch_index = 0
    checkpoint = toolbox.checkpoint_manager.get_last_checkpoint()
    if checkpoint is not None:
        first_batch_index = checkpoint.metadata["first_batch_index"]
        print(
            "Found checkpoint with index",
            checkpoint.index,
            "and first batch index",
            first_batch_index,
            file=sys.stderr,
        )
        # Fix: the original referenced an undefined name `serializer`
        # (NameError on resume) — the serializer in scope is `ts`.
        # NOTE(review): `desirialize` spelling follows the tractorun
        # TensorSerializer API — confirm against the installed version.
        checkpoint_dict = ts.desirialize(checkpoint.value)
        model.load_state_dict(checkpoint_dict["model"])
        optimizer.load_state_dict(checkpoint_dict["optimizer"])

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        # Fix: skip already-trained batches only in the resumed (first) epoch;
        # subsequent epochs must process the full dataset from batch 0.
        train(args, model, device, train_loader, optimizer, epoch,
              first_batch_index if epoch == 1 else 0, toolbox.checkpoint_manager)
        test(model, device, test_loader)
        scheduler.step()

    if args.save_model:
        # NOTE(review): this writes the model to dataset_train_path, i.e. over
        # the training dataset table — looks unintended; consider a dedicated
        # model path under working_dir.
        toolbox.save_model(ts.serialize(model.state_dict()), dataset_train_path, metadata={})


# Launch the distributed training job.
# Fix: the original passed an undefined name `training_dir`; the notebook
# defines `working_dir` above. The mesh is restored to 2 nodes x 2 processes,
# matching both the surrounding text ("run training on 2 hosts with
# 2 processes") and the diff shown earlier.
run(
    main,
    backend=Tractorch(),
    yt_path=working_dir,
    mesh=Mesh(node_count=2, process_per_node=2, gpu_per_process=0),
    resources=Resources(
        cpu_limit=8,
        memory_limit=105899345920,
    ),
    # Forward only the primary process's stderr back to the notebook.
    proxy_stderr_mode=StderrMode.primary,
)