train_model.py

    import argparse
    import sys
    import time
    
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torchvision.datasets import MNIST
    import torchvision.transforms as transforms
    from torch.utils.data import DataLoader
    
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 32, 3, 1)
            self.conv2 = nn.Conv2d(32, 64, 3, 1)
            self.dropout1 = nn.Dropout(0.25)
            self.dropout2 = nn.Dropout(0.5)
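            # after two 3x3 convs (28 -> 26 -> 24) and a 2x2 max pool (24 -> 12),
            # the feature maps are 64 channels of 12x12: 64 * 12 * 12 = 9216 inputs to fc1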
            self.fc1 = nn.Linear(9216, 128)
            self.fc2 = nn.Linear(128, 10)
    
        def forward(self, x):
            x = self.conv1(x)
            x = F.relu(x)
            x = self.conv2(x)
            x = F.relu(x)
            x = F.max_pool2d(x, 2)
            x = self.dropout1(x)
            x = torch.flatten(x, 1)
            x = self.fc1(x)
            x = F.relu(x)
            x = self.dropout2(x)
            x = self.fc2(x)
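            # log_softmax returns log-probabilities (numerically more stable than
            # log(softmax(x))); pair these with NLLLoss during training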
            output = F.log_softmax(x, dim=1)
            return output
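
    # quick shape check (illustrative): a 28x28 grayscale batch maps to 10 class scores
    #   Net()(torch.zeros(1, 1, 28, 28)).shape  ->  torch.Size([1, 10])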
    
    def parse_command_line():
        parser = argparse.ArgumentParser()
        parser.add_argument("--device", required=False, type=str, choices=['cpu', 'cuda'], default="cuda")
        parser.add_argument("--num_epochs", required=False, type=int, default=2)
        parser.add_argument("--batch_size", required=False, type=int, default=128)
        parser.add_argument("--num_workers", required=False, type=int, default=1)
        args = parser.parse_args()
        return args
    
    def load_dataset(args):
        # define the preprocessing transformations:
        # - convert PIL images to tensors
        # - normalize with the MNIST global mean (0.1307) and std (0.3081)
        trans = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        
        # load the MNIST train and test splits and apply the transformations
        # note: the first run downloads the dataset into "datasets"; the download
        # fetches both splits, so the test split can use download=False
        ds_train = MNIST("datasets", train=True,  download=True,  transform=trans)
        ds_test  = MNIST("datasets", train=False, download=False, transform=trans)
        
        # finally create separate data loaders for train and test; shuffle the
        # training data each epoch, and pin host memory to speed up GPU transfers
        common_kwargs = {"batch_size": args.batch_size, "num_workers": args.num_workers, "pin_memory": True}
        loader_train = DataLoader(ds_train, shuffle=True,  **common_kwargs)
        loader_test  = DataLoader(ds_test,  shuffle=False, **common_kwargs)
        
        return loader_train, loader_test
    
    def train(args, model, loader_train, optimizer, epoch):
        # the model outputs log-probabilities (log_softmax), so use NLLLoss;
        # CrossEntropyLoss expects raw logits and would apply log_softmax twice
        loss_func = torch.nn.NLLLoss()
    
        # set model into train mode
        model.train()
        
        # track accuracy for complete epoch
        total, correct = 0, 0
        total_steps = len(loader_train)
        
        elapsed_time = time.time()
        for i, (x_batch, y_batch) in enumerate(loader_train):
            # transfer data to the device (non_blocking=True overlaps the copy
            # with compute, which works because the DataLoader pins host memory)
            x_batch = x_batch.to(args.device, non_blocking=True)
            y_batch = y_batch.to(args.device, non_blocking=True)
            
            # run forward pass
            y_pred = model(x_batch)
            
            # calculate loss
            loss = loss_func(y_pred, y_batch)
            
            # run backward pass and optimizer to update weights
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            # track training accuracy
            _, predicted = y_pred.max(1)
            total += y_batch.size(0)
            correct += predicted.eq(y_batch).sum().item()
            
            if i % 20 == 0:
                print(f"Epoch {epoch+1}/{args.num_epochs}\tStep {i:4d} / {total_steps:4d}\tLoss: {loss.item():.4f}")
                sys.stdout.flush()
        elapsed_time = time.time() - elapsed_time
    
        print(f"Epoch {epoch+1}/{args.num_epochs}\tElapsed: {elapsed_time:.3f} sec\tAcc: {(correct/total):.3f}")
        sys.stdout.flush()
    
    def test(args, model, loader_test, epoch):
        # set model into evaluation mode
        model.eval()
        
        # disable gradient tracking during evaluation to save memory and compute
        with torch.no_grad():
            correct, total = 0, 0
            for i, (x_batch, y_batch) in enumerate(loader_test):
                # transfer data to the device
                x_batch = x_batch.to(args.device, non_blocking=True)
                y_batch = y_batch.to(args.device, non_blocking=True)
                
                # predict class
                outputs = model(x_batch)
                _, predicted = outputs.max(1)
                
                # track test accuracy
                total += y_batch.size(0)
                correct += (predicted == y_batch).sum().item()
            
            print(f"Epoch {epoch+1}/{args.num_epochs}\tTest Acc: {(correct/total):.3f}")
            sys.stdout.flush()
    
    def setup(args) -> None:
        # fall back to CPU if CUDA was requested but is not available
        if args.device == 'cuda' and not torch.cuda.is_available():
            print("CUDA requested but not available, falling back to CPU")
            args.device = 'cpu'

        if args.device == 'cuda':
            # optimization hint for the torch runtime: let cuDNN benchmark
            # convolution algorithms (beneficial since input shapes are fixed)
            torch.backends.cudnn.benchmark = True
    
        print("Current configuration:")
        for arg in vars(args):
            print(f"  --{arg}, {getattr(args, arg)}")
    
    def cleanup(args: argparse.Namespace) -> None:
        pass
    
    def main():
        # parse command line arguments
        args = parse_command_line()
    
        # run setup (e.g., create distributed environment if desired)
        setup(args)
        
        # get data loaders for train and test split
        loader_train, loader_test = load_dataset(args)
        
        # create model with random weights
        model = Net().to(args.device)
    
        # initialize optimizer with model parameters
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
        # train and test model for configured number of epochs
        for epoch in range(args.num_epochs):
            train(args, model, loader_train, optimizer, epoch)
            test(args, model, loader_test, epoch)
    
        # clean up the environment (no-op here; placeholder for e.g. distributed teardown)
        cleanup(args)
    
    if __name__ == "__main__":
        main()
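
    # example invocation, using the flags defined in parse_command_line:
    #   python train_model.py --device cpu --num_epochs 2 --batch_size 128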