# Select Git revision
# train_model.py
# train_model.py 5.44 KiB
import argparse
import os, sys
import time
import typing
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
class Net(nn.Module):
    """Small convolutional classifier with ten output classes.

    Two 3x3 convolutions feed a max-pool, followed by two fully connected
    layers with dropout in between. ``forward`` returns per-class
    log-probabilities.
    """

    def __init__(self):
        super().__init__()
        # convolutional feature extractor: 1 -> 32 -> 64 channels
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        # regularisation after pooling and after the hidden layer
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        # 9216 = 64 channels * 12 * 12, which matches 1x28x28 inputs such as
        # the MNIST images loaded elsewhere in this file
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Map a batch of single-channel images to class log-probabilities."""
        # conv -> relu, twice
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        # downsample by 2, then regularise
        features = self.dropout1(F.max_pool2d(features, 2))
        # keep the batch dimension, flatten everything else
        hidden = F.relu(self.fc1(torch.flatten(features, 1)))
        logits = self.fc2(self.dropout2(hidden))
        return F.log_softmax(logits, dim=1)
def parse_command_line(argv=None):
    """Parse the command line options of the training script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) reads
            ``sys.argv[1:]`` as before.

    Returns:
        argparse.Namespace with the attributes ``device``, ``num_epochs``,
        ``batch_size`` and ``num_workers``.
    """
    # Default to the GPU only when one is usable; otherwise a plain
    # invocation without --device would fail later when moving the model.
    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    parser = argparse.ArgumentParser()
    parser.add_argument("--device", required=False, type=str, choices=['cpu', 'cuda'], default=default_device)
    parser.add_argument("--num_epochs", required=False, type=int, default=2)
    parser.add_argument("--batch_size", required=False, type=int, default=128)
    parser.add_argument("--num_workers", required=False, type=int, default=1)
    return parser.parse_args(argv)
def load_dataset(args):
    """Create the MNIST train and test data loaders.

    Args:
        args: Namespace providing ``batch_size``, ``num_workers`` and
            ``device``.

    Returns:
        Tuple ``(loader_train, loader_test)`` of ``DataLoader`` objects.
    """
    # define the following transformations
    # - convert the images to tensors
    # - normalize data using mean and std
    #   (presumably the MNIST training-set statistics -- verify if changed)
    trans = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    # load MNIST dataset splits for train and test and apply the transformations
    # note: the dataset is downloaded once by the train split; the test split
    # then reuses the files in the same "datasets" directory
    ds_train = MNIST("datasets", train=True, download=True, transform=trans)
    ds_test = MNIST("datasets", train=False, download=False, transform=trans)

    # arguments shared by both loaders; page-locked host memory only helps
    # (and is only supported) when batches are copied to a CUDA device
    common_kwargs = {
        "batch_size": args.batch_size,
        "num_workers": args.num_workers,
        "pin_memory": str(args.device).startswith("cuda"),
    }

    # reshuffle the training samples every epoch; keep the test order fixed
    loader_train = DataLoader(ds_train, shuffle=True, **common_kwargs)
    loader_test = DataLoader(ds_test, shuffle=False, **common_kwargs)
    return loader_train, loader_test
def train(args, model, loader_train, optimizer, epoch):
    """Train ``model`` for one epoch and print progress, runtime and accuracy.

    Args:
        args: Namespace providing ``device`` and ``num_epochs``.
        model: Module to optimise; it is switched to train mode here.
        loader_train: Iterable of ``(inputs, labels)`` batches that
            supports ``len()``.
        optimizer: Optimizer that updates the parameters of ``model``.
        epoch: Zero-based epoch index (printed one-based).
    """
    # use a CrossEntropyLoss loss function
    # NOTE: Net.forward in this file already returns log_softmax output;
    # CrossEntropyLoss applies log_softmax again, which leaves normalized
    # log-probabilities unchanged, so the loss value is still correct.
    loss_func = torch.nn.CrossEntropyLoss()

    # set model into train mode
    model.train()

    # track accuracy for complete epoch
    total, correct = 0, 0
    total_steps = len(loader_train)

    # monotonic clock: not affected by system clock adjustments
    start_time = time.perf_counter()
    for step, (x_batch, y_batch) in enumerate(loader_train):
        # transfer data to the device
        x_batch = x_batch.to(args.device, non_blocking=True)
        y_batch = y_batch.to(args.device, non_blocking=True)

        # run forward pass
        y_pred = model(x_batch)

        # calculate loss
        loss = loss_func(y_pred, y_batch)

        # run backward pass and optimizer to update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # track training accuracy
        predicted = y_pred.argmax(dim=1)
        total += y_batch.size(0)
        correct += predicted.eq(y_batch).sum().item()

        if step % 20 == 0:
            print(f"Epoch {epoch+1}/{args.num_epochs}\tStep {step:4d} / {total_steps:4d}", flush=True)

    elapsed_time = time.perf_counter() - start_time
    # an empty loader yields no samples; report nan instead of dividing by zero
    accuracy = correct / total if total else float("nan")
    print(f"Epoch {epoch+1}/{args.num_epochs}\tElapsed: {elapsed_time:.3f} sec\tAcc: {accuracy:.3f}", flush=True)
def test(args, model, loader_test, epoch):
    """Evaluate ``model`` on the test split and print the accuracy.

    Args:
        args: Namespace providing ``device`` and ``num_epochs``.
        model: Module to evaluate; it is switched to evaluation mode here.
        loader_test: Iterable of ``(inputs, labels)`` batches.
        epoch: Zero-based epoch index (printed one-based).
    """
    # set model into evaluation mode
    model.eval()

    correct, total = 0, 0
    # no gradients are needed for evaluation
    with torch.no_grad():
        for x_batch, y_batch in loader_test:
            # transfer data to the device
            x_batch = x_batch.to(args.device, non_blocking=True)
            y_batch = y_batch.to(args.device, non_blocking=True)

            # predict class
            outputs = model(x_batch)
            predicted = outputs.argmax(dim=1)

            # track test accuracy
            total += y_batch.size(0)
            correct += (predicted == y_batch).sum().item()

    # an empty loader yields no samples; report nan instead of dividing by zero
    accuracy = correct / total if total else float("nan")
    print(f"Epoch {epoch+1}/{args.num_epochs}\tTest Acc: {accuracy:.3f}", flush=True)
def setup(args) -> None:
# set gpu device on local machine
if args.device == 'cuda':
# optimization hint for torch runtime
torch.backends.cudnn.benchmark = True
print("Current configuration:")
for arg in vars(args):
print(f" --{arg}, {getattr(args, arg)}")
def cleanup(args: argparse.Namespace) -> None:
    """Release resources created by ``setup``; currently nothing to release.

    Args:
        args: Parsed command line options (unused for now).
    """
def main():
    """Parse the options, then train and evaluate the model for each epoch."""
    # parse command line arguments
    args = parse_command_line()

    # run setup (e.g., create distributed environment if desired)
    setup(args)

    # from here on the environment is set up, so always run cleanup,
    # even when data loading or training raises
    try:
        # get data loaders for train and test split
        loader_train, loader_test = load_dataset(args)

        # create model with random weights
        model = Net().to(args.device)

        # initialize optimizer with model parameters
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

        # train and test model for configured number of epochs
        for epoch in range(args.num_epochs):
            train(args, model, loader_train, optimizer, epoch)
            test(args, model, loader_test, epoch)
    finally:
        # cleanup env
        cleanup(args)
# Run the training only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()