Commit 9d6350b1 authored by andres

adding autocast

parent f337a95c
 name: basic
 train:
   _target_: torch.utils.data.DataLoader
-  batch_size: 64
+  batch_size: 256 #64
   shuffle: True
   num_workers: 2
   pin_memory: True
...
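The `_target_` key indicates that the loader is built from this config node at runtime. Below is a minimal sketch of how such a node is commonly turned into a live DataLoader with Hydra's `instantiate`; the use of Hydra here is an assumption based on the `_target_`/`DictConfig` conventions in this repository, and the dataset is a hypothetical stand-in, not a project object.

import torch
from torch.utils.data import TensorDataset
from hydra.utils import instantiate          # assumption: Hydra is used for instantiation
from omegaconf import OmegaConf

# Hypothetical stand-in dataset so the sketch runs end to end.
train_dataset = TensorDataset(torch.randn(1024, 10), torch.randint(0, 2, (1024,)))

loader_cfg = OmegaConf.create({
    "_target_": "torch.utils.data.DataLoader",
    "batch_size": 256,
    "shuffle": True,
    "num_workers": 2,
    "pin_memory": True,
})

# The dataset is supplied at call time; everything else comes from the config node.
train_loader = instantiate(loader_cfg, dataset=train_dataset)
x, y = next(iter(train_loader))              # x.shape == (256, 10)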
 # Training configuration for the model model
 name: basic # name of the training configuration
-use_cuda: True # Default: True, flag to enable CUDA training.
-use_mps: True # Default: True, flag to enable macOS GPU training.
-dry_run: false # Perform a dry run (do not update weights) (bool)
-seed: 1 # Seed for random number generation (int)
-save_model: true # Whether to save the model to disk (bool)
+seed: 1
+device: cuda
 loss:
   _target_: torch.nn.NLLLoss
 optimizer:
@@ -15,19 +12,18 @@ scheduler:
   _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
   mode: 'min'
   factor: 0.1
-  patience: 10
+  patience: 3
-scheduler_monitor: train_loss_epoch
-min_epochs: 50
+min_epochs: 10
 max_epochs: 300
 early_stopping_config:
   monitor: valid_acc_epoch
   min_delta: 0.001
-  patience: 20
+  patience: 10
   verbose: False
   mode: max
 gradient_clip_val: 1
 metrics: ['acc', 'f1']
-gradient_accumulation_steps: 1
+gradient_accumulation_steps: 5
 loggers:
   tensorboard:
     _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
...
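For scale: with `batch_size: 256` and `gradient_accumulation_steps: 5`, the optimizer now sees an effective batch of 256 × 5 = 1280 samples per update step. The scheduler block maps directly onto `torch.optim.lr_scheduler.ReduceLROnPlateau`; the sketch below shows how the tightened `patience: 3` behaves. The linear model, optimizer, and loss values are placeholders, not taken from the project.

import torch

model = torch.nn.Linear(10, 2)                               # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)    # placeholder optimizer

# Mirrors the config: mode='min', factor=0.1, patience=3
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.1, patience=3
)

# Once the monitored value fails to improve for more than `patience`
# consecutive epochs, the learning rate is multiplied by `factor`.
for epoch, val_loss in enumerate([1.0, 0.9, 0.9, 0.9, 0.9, 0.9]):
    scheduler.step(val_loss)
    print(epoch, optimizer.param_groups[0]["lr"])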
@@ -9,7 +9,7 @@ from torch.utils.data import DataLoader
 from modules.utils import MetricAggregator, EarlyStoppingCustom
 import logging
-def train_model(model: nn.Module, train_loader: DataLoader, valid_loader: DataLoader, cfg: DictConfig, device="cpu"):
+def train_model(model: nn.Module, train_loader: DataLoader, valid_loader: DataLoader, cfg: DictConfig):
     """
     Trains and evaluates a neural network model.
@@ -27,7 +27,7 @@ def train_model(model: nn.Module, train_loader: DataLoader, valid_loader: DataLo
     ##############################
     # Device Setup
     ##############################
-    device = torch.device(cfg.device)
+    device = "cuda" if (cfg.device=="cuda" and torch.cuda.is_available()) else "cpu"
     model.to(device)
     ##############################
@@ -58,6 +58,9 @@ def train_model(model: nn.Module, train_loader: DataLoader, valid_loader: DataLo
     # Early stopping and checkpointing
     early_stopping = EarlyStoppingCustom(**cfg.early_stopping_config)
+    # scaler for automatic mixed precision
+    scaler = torch.cuda.amp.GradScaler()
     ##############################
     # Epoch Loop
     ##############################
@@ -67,24 +70,30 @@ def train_model(model: nn.Module, train_loader: DataLoader, valid_loader: DataLo
         ##############################
         model.train()
         lr = torch.tensor(optimizer.param_groups[0]['lr'])
-        for batch_idx, batch in enumerate(train_loader):
-            x, y = batch
+        for batch_idx, (x, y) in enumerate(train_loader):
             x, y = x.to(device), y.to(device)
+            with torch.autocast(device_type=device):
                 out = model(x)
                 logprob = F.log_softmax(out, dim=1)
                 y_hat_prob = torch.exp(logprob)
                 loss = criterion(logprob, y)
-            loss.backward()
+                loss = loss / cfg.gradient_accumulation_steps
+            # Accumulates scaled gradients.
+            scaler.scale(loss).backward()
             # Gradient accumulation
             if (batch_idx + 1) % cfg.gradient_accumulation_steps == 0 or (batch_idx + 1) == len(train_loader):
-                optimizer.step()
+                scaler.step(optimizer)
+                scaler.update()
                 optimizer.zero_grad()
             # Update metrics
+            with torch.autocast(device_type=device):
                 metric_aggregator.step(y_hat_prob=y_hat_prob, y=y, loss=loss, epoch=torch.tensor(epoch+1), lr=lr, phase="train")
         # Compute and log metrics
+        with torch.autocast(device_type=device):
             train_results = metric_aggregator.compute(phase="train")
         logger.info(f"Epoch {epoch+1} Train: {' '.join([f'{k}:{v:.3E}'.replace('_epoch','').replace('train_','') for k,v in train_results.items() if isinstance(v,float)])}")
@@ -93,18 +102,20 @@ def train_model(model: nn.Module, train_loader: DataLoader, valid_loader: DataLo
         ##############################
         model.eval()
         with torch.no_grad():
-            for batch in valid_loader:
-                x, y = batch
+            for batch_idx, (x, y) in enumerate(valid_loader):
                 x, y = x.to(device), y.to(device)
+                with torch.autocast(device_type=device):
                     out = model(x)
                     logprob = F.log_softmax(out, dim=1)
                     y_hat_prob = torch.exp(logprob)
                     val_loss = criterion(logprob, y)
                 # Update metrics
+                with torch.autocast(device_type=device):
                     metric_aggregator.step(y_hat_prob=y_hat_prob, y=y, loss=loss, epoch=torch.tensor(epoch+1), lr=lr, phase="valid")
         # Compute and log metrics
+        with torch.autocast(device_type=device):
             valid_results = metric_aggregator.compute(phase="valid")
         logger.info(f"Epoch {epoch+1} Valid: {' '.join([f'{k}:{v:.3E}'.replace('_epoch','').replace('valid_','') for k,v in valid_results.items() if isinstance(v,float)])}")
...
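Taken together, the training-loop changes follow the standard PyTorch recipe for automatic mixed precision combined with gradient accumulation: run the forward pass and loss under `autocast`, divide the loss by the number of accumulation steps, backpropagate through `GradScaler.scale`, and only step the optimizer and update the scaler every `gradient_accumulation_steps` batches. Below is a minimal self-contained sketch of that pattern, with a placeholder model and random data standing in for the project's objects; `accum_steps` plays the role of `cfg.gradient_accumulation_steps`.

import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(10, 2).to(device)                    # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# GradScaler only does real work on CUDA; enabled=False makes it a no-op on CPU.
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))
accum_steps = 5

# Random batches standing in for a DataLoader.
data = [(torch.randn(32, 10), torch.randint(0, 2, (32,))) for _ in range(12)]

for batch_idx, (x, y) in enumerate(data):
    x, y = x.to(device), y.to(device)
    with torch.autocast(device_type=device, enabled=(device == "cuda")):
        out = model(x)
        loss = F.cross_entropy(out, y) / accum_steps         # scale for accumulation
    scaler.scale(loss).backward()                            # accumulate scaled gradients
    if (batch_idx + 1) % accum_steps == 0 or (batch_idx + 1) == len(data):
        scaler.step(optimizer)                               # unscales grads, then steps
        scaler.update()                                      # adjusts the scale factor
        optimizer.zero_grad()

`scaler.step(optimizer)` unscales the accumulated gradients before applying them and skips the update if any are non-finite, and `scaler.update()` then adapts the scale factor, which is why both calls sit inside the accumulation branch rather than running every batch.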
@@ -41,10 +41,10 @@ class MetricAggregator:
             self.init_agg(phase=phase, metric=metric)
         for metric in ["acc", "cm", "f1"]:
             if metric in self.metrics:
-                self.aggregators[phase][metric](y_hat_prob.to(self.device), y.to(self.device))
+                self.aggregators[phase][metric].update(y_hat_prob.to(self.device), y.to(self.device))
         for k, v in kwargs.items():
             if k in self.aggregators[phase]:
-                self.aggregators[phase][k](v.to(self.device))
+                self.aggregators[phase][k].update(v.to(self.device))
         for logger in self.loggers:
             logger.log_metrics({f"{phase}_{k}_step":v.detach().cpu().tolist() for k,v in kwargs.items()}, step=self.step_num)
         self.step_num+=1
...
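Replacing the direct call with `.update()` matches the torchmetrics idiom of accumulating per-batch state and reading it out once per epoch via `.compute()`; that the aggregators here are torchmetrics objects is an assumption based on the `acc`/`f1`/`cm` names and the `compute(phase=...)` call elsewhere in the diff. A minimal sketch of that idiom with a stand-alone metric:

import torch
from torchmetrics.classification import MulticlassAccuracy

acc = MulticlassAccuracy(num_classes=3)

for _ in range(4):                           # pretend these are training steps
    probs = torch.softmax(torch.randn(8, 3), dim=1)
    target = torch.randint(0, 3, (8,))
    acc.update(probs, target)                # accumulate state only

epoch_acc = acc.compute()                    # aggregate over all updates
acc.reset()                                  # start fresh for the next epoch

Calling a torchmetrics metric directly, as in `acc(probs, target)`, also updates its state, but additionally computes and returns the batch-level value, which is wasted work when only the epoch aggregate is logged.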
@@ -91,15 +91,11 @@ def main(cfg: DictConfig) -> None:
     model_save_dir = Path(cfg.path.base_path_models) / cfg.path.results
     model_save_dir.mkdir(parents=True, exist_ok=True)
-    # Determine if CUDA or MPS should be used based on configuration and availability
-    use_cuda = cfg.training.use_cuda and torch.cuda.is_available()
-    use_mps = cfg.training.use_mps and torch.backends.mps.is_available()
     # Set the random seed for reproducibility
     seed_everything(cfg.training.seed)
     # Select the device for computation (CUDA, MPS, or CPU)
-    device = torch.device("cuda") if use_cuda else torch.device("mps") if use_mps else torch.device("cpu")
+    device = "cuda" if (cfg.training.device=="cuda" and torch.cuda.is_available()) else "cpu"
     ##############################
     # Object Instantiation
@@ -125,29 +121,16 @@
     ##############################
     # Training loop
-    train_model(model, train_loader, test_loader, cfg.training, device)
+    train_model(model, train_loader, test_loader, cfg.training)
     ##############################
     # Saving Results
     ##############################
     # Save the model checkpoint if configured to do so
-    if cfg.training.save_model:
-        model_path = model_save_dir / f"checkpoint.pt"
+    model_path = model_save_dir / f"checkpoint.ckpt"
     torch.save(model.state_dict(), model_path)
-    # Save the result dictionary
-    results = {
-        "model_name": cfg.model.name,
-        "training_name": cfg.training.name,
-        "epochs": cfg.training.epochs,
-        "seed": cfg.training.seed,
-        "final_model_path": str(model_path),
-        "timestamp": time.time()
-    }
-    with open(model_save_dir / "results.json", "w") as f:
-        json.dump(results, f, indent=4)
     # Save the configuration
     config_path = model_save_dir / "config.yaml"
     with open(config_path, "w") as f:
...
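Because only the `state_dict` is written to `checkpoint.ckpt`, loading it later requires rebuilding the model class first and then restoring the weights. A minimal sketch of that round trip; the `Net` class and the `runs/example` directory are placeholders, not names from this repository.

import torch
import torch.nn as nn
from pathlib import Path

class Net(nn.Module):                        # placeholder architecture
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)
    def forward(self, x):
        return self.fc(x)

model_save_dir = Path("runs/example")        # hypothetical save directory
model_save_dir.mkdir(parents=True, exist_ok=True)
model_path = model_save_dir / "checkpoint.ckpt"

model = Net()
torch.save(model.state_dict(), model_path)   # weights only, no architecture

restored = Net()                             # must rebuild the same class
restored.load_state_dict(torch.load(model_path, map_location="cpu"))
restored.eval()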