diff --git a/data/config/loader/basic.yaml b/data/config/loader/basic.yaml
index 52896f468e558432bf8031223a9d08238665cc15..e47c69c0944b906bbb795a944efaa94f61ddd886 100644
--- a/data/config/loader/basic.yaml
+++ b/data/config/loader/basic.yaml
@@ -1,7 +1,7 @@
 name: basic
 train:
   _target_: torch.utils.data.DataLoader
-  batch_size: 64
+  batch_size: 256 #64
   shuffle: True
   num_workers: 2
   pin_memory: True
diff --git a/data/config/training/basic.yaml b/data/config/training/basic.yaml
index a1f20749cc57debab29510909c7e6e22dd6e270a..5a1a747301d532076778db0eb88eb9b01d1aa1cb 100644
--- a/data/config/training/basic.yaml
+++ b/data/config/training/basic.yaml
@@ -1,10 +1,7 @@
 # Training configuration for the model
 name: basic # name of the training configuration
-use_cuda: True # Default: True, flag to enable CUDA training.
-use_mps: True # Default: True, flag to enable macOS GPU training.
-dry_run: false # Perform a dry run (do not update weights) (bool)
-seed: 1 # Seed for random number generation (int)
-save_model: true # Whether to save the model to disk (bool)
+seed: 1
+device: cuda
 loss:
   _target_: torch.nn.NLLLoss
 optimizer:
@@ -15,19 +12,18 @@ scheduler:
   _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
   mode: 'min'
   factor: 0.1
-  patience: 10
-scheduler_monitor: train_loss_epoch
-min_epochs: 50
+  patience: 3
+min_epochs: 10
 max_epochs: 300
 early_stopping_config:
   monitor: valid_acc_epoch
   min_delta: 0.001
-  patience: 20
+  patience: 10
   verbose: False
   mode: max
 gradient_clip_val: 1
 metrics: ['acc', 'f1']
-gradient_accumulation_steps: 1
+gradient_accumulation_steps: 5
 loggers:
   tensorboard:
     _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
diff --git a/modules/training/training.py b/modules/training/training.py
index d57d1f4d4c779f0fa8dbc6c616dae3a67f091f39..a9cedbd1ee37ddb25283c1663ab87a6141a2405f 100644
--- a/modules/training/training.py
+++ b/modules/training/training.py
@@ -9,7 +9,7 @@ from torch.utils.data import DataLoader
 from modules.utils import MetricAggregator, EarlyStoppingCustom
 import logging
 
-def train_model(model: nn.Module, train_loader: DataLoader, valid_loader: DataLoader, cfg: DictConfig, device="cpu"):
+def train_model(model: nn.Module, train_loader: DataLoader, valid_loader: DataLoader, cfg: DictConfig):
     """
     Trains and evaluates a neural network model.
 
@@ -27,7 +27,7 @@ def train_model(model: nn.Module, train_loader: DataLoader, valid_loader: DataLo
     ##############################
     # Device Setup
     ##############################
-    device = torch.device(cfg.device)
+    device = "cuda" if (cfg.device=="cuda" and torch.cuda.is_available()) else "cpu"
     model.to(device)
 
     ##############################
@@ -58,6 +58,9 @@ def train_model(model: nn.Module, train_loader: DataLoader, valid_loader: DataLo
     # Early stopping and checkpointing
     early_stopping = EarlyStoppingCustom(**cfg.early_stopping_config)
 
+    # Scaler for automatic mixed precision
+    scaler = torch.cuda.amp.GradScaler()
+
     ##############################
     # Epoch Loop
     ##############################
@@ -67,25 +70,31 @@ def train_model(model: nn.Module, train_loader: DataLoader, valid_loader: DataLo
         ##############################
         model.train()
         lr = torch.tensor(optimizer.param_groups[0]['lr'])
-        for batch_idx, batch in enumerate(train_loader):
-            x, y = batch
+        for batch_idx, (x, y) in enumerate(train_loader):
             x, y = x.to(device), y.to(device)
-            out = model(x)
-            logprob = F.log_softmax(out, dim=1)
-            y_hat_prob = torch.exp(logprob)
-            loss = criterion(logprob, y)
-            loss.backward()
+            with torch.autocast(device_type=device):
+                out = model(x)
+                logprob = F.log_softmax(out, dim=1)
+                y_hat_prob = torch.exp(logprob)
+                loss = criterion(logprob, y)
+                loss = loss / cfg.gradient_accumulation_steps
+
+            # Accumulates scaled gradients.
+            scaler.scale(loss).backward()
 
             # Gradient accumulation
             if (batch_idx + 1) % cfg.gradient_accumulation_steps == 0 or (batch_idx + 1) == len(train_loader):
-                optimizer.step()
+                scaler.step(optimizer)
+                scaler.update()
                 optimizer.zero_grad()
 
             # Update metrics
-            metric_aggregator.step(y_hat_prob=y_hat_prob, y=y, loss=loss, epoch=torch.tensor(epoch+1), lr=lr, phase="train")
+            with torch.autocast(device_type=device):
+                metric_aggregator.step(y_hat_prob=y_hat_prob, y=y, loss=loss, epoch=torch.tensor(epoch+1), lr=lr, phase="train")
 
         # Compute and log metrics
-        train_results = metric_aggregator.compute(phase="train")
+        with torch.autocast(device_type=device):
+            train_results = metric_aggregator.compute(phase="train")
         logger.info(f"Epoch {epoch+1} Train: {' '.join([f'{k}:{v:.3E}'.replace('_epoch','').replace('train_','') for k,v in train_results.items() if isinstance(v,float)])}")
 
         ##############################
@@ -93,19 +102,21 @@ def train_model(model: nn.Module, train_loader: DataLoader, valid_loader: DataLo
         ##############################
         model.eval()
         with torch.no_grad():
-            for batch in valid_loader:
-                x, y = batch
+            for batch_idx, (x, y) in enumerate(valid_loader):
                 x, y = x.to(device), y.to(device)
-                out = model(x)
-                logprob = F.log_softmax(out, dim=1)
-                y_hat_prob = torch.exp(logprob)
-                val_loss = criterion(logprob, y)
+                with torch.autocast(device_type=device):
+                    out = model(x)
+                    logprob = F.log_softmax(out, dim=1)
+                    y_hat_prob = torch.exp(logprob)
+                    val_loss = criterion(logprob, y)
 
                 # Update metrics
-                metric_aggregator.step(y_hat_prob=y_hat_prob, y=y, loss=loss, epoch=torch.tensor(epoch+1), lr=lr, phase="valid")
+                with torch.autocast(device_type=device):
+                    metric_aggregator.step(y_hat_prob=y_hat_prob, y=y, loss=val_loss, epoch=torch.tensor(epoch+1), lr=lr, phase="valid")
 
         # Compute and log metrics
-        valid_results = metric_aggregator.compute(phase="valid")
+        with torch.autocast(device_type=device):
+            valid_results = metric_aggregator.compute(phase="valid")
         logger.info(f"Epoch {epoch+1} Valid: {' '.join([f'{k}:{v:.3E}'.replace('_epoch','').replace('valid_','') for k,v in valid_results.items() if isinstance(v,float)])}")
 
         ##############################
diff --git a/modules/utils/metric_aggregator.py b/modules/utils/metric_aggregator.py
index d51db784f90ed8faeab1016698ed62e789c85f97..7b56d515aea4bc656d75f7410e2f3e47f55cbaf9 100644
--- a/modules/utils/metric_aggregator.py
+++ b/modules/utils/metric_aggregator.py
@@ -41,10 +41,10 @@ class MetricAggregator:
                 self.init_agg(phase=phase, metric=metric)
         for metric in ["acc", "cm", "f1"]:
             if metric in self.metrics:
-                self.aggregators[phase][metric](y_hat_prob.to(self.device), y.to(self.device))
+                self.aggregators[phase][metric].update(y_hat_prob.to(self.device), y.to(self.device))
         for k, v in kwargs.items():
             if k in self.aggregators[phase]:
-                self.aggregators[phase][k](v.to(self.device))
+                self.aggregators[phase][k].update(v.to(self.device))
         for logger in self.loggers:
             logger.log_metrics({f"{phase}_{k}_step":v.detach().cpu().tolist() for k,v in kwargs.items()}, step=self.step_num)
         self.step_num+=1
diff --git a/runs/01_train_model.py b/runs/01_train_model.py
index 8270d3bcaaa3fb2808bcbb8b969154c050ef02b4..0888e342c758f75d87daa4c6ab353b44e96f849f 100644
--- a/runs/01_train_model.py
+++ b/runs/01_train_model.py
@@ -91,15 +91,11 @@ def main(cfg: DictConfig) -> None:
     model_save_dir = Path(cfg.path.base_path_models) / cfg.path.results
     model_save_dir.mkdir(parents=True, exist_ok=True)
 
-    # Determine if CUDA or MPS should be used based on configuration and availability
-    use_cuda = cfg.training.use_cuda and torch.cuda.is_available()
-    use_mps = cfg.training.use_mps and torch.backends.mps.is_available()
-
     # Set the random seed for reproducibility
     seed_everything(cfg.training.seed)
 
     # Select the device for computation (CUDA, MPS, or CPU)
-    device = torch.device("cuda") if use_cuda else torch.device("mps") if use_mps else torch.device("cpu")
+    device = "cuda" if (cfg.training.device=="cuda" and torch.cuda.is_available()) else "cpu"
     ##############################
     # Object Instantiation
     ##############################
@@ -125,28 +121,15 @@ def main(cfg: DictConfig) -> None:
     ##############################
 
     # Training loop
-    train_model(model, train_loader, test_loader, cfg.training, device)
+    train_model(model, train_loader, test_loader, cfg.training)
 
     ##############################
     # Saving Results
     ##############################
 
     # Save the model checkpoint if configured to do so
-    if cfg.training.save_model:
-        model_path = model_save_dir / f"checkpoint.pt"
-        torch.save(model.state_dict(), model_path)
-
-        # Save the result dictionary
-        results = {
-            "model_name": cfg.model.name,
-            "training_name": cfg.training.name,
-            "epochs": cfg.training.epochs,
-            "seed": cfg.training.seed,
-            "final_model_path": str(model_path),
-            "timestamp": time.time()
-        }
-        with open(model_save_dir / "results.json", "w") as f:
-            json.dump(results, f, indent=4)
+    model_path = model_save_dir / "checkpoint.ckpt"
+    torch.save(model.state_dict(), model_path)
 
     # Save the configuration
     config_path = model_save_dir / "config.yaml"
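
Note: below is a minimal, self-contained sketch of the autocast/GradScaler + gradient-accumulation pattern that the modules/training/training.py changes above adopt. The model, optimizer, data, and accum_steps here are illustrative placeholders rather than the repo's actual objects, and autocast/scaling are simply disabled on CPU so the snippet runs anywhere:

# Sketch of the AMP + gradient accumulation pattern from the patch above.
# All objects here are hypothetical stand-ins, not the repo's real config.
import torch
import torch.nn as nn
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(10, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.NLLLoss()
# GradScaler only does real work on CUDA; enabled=False makes it a no-op on CPU.
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))
accum_steps = 4  # placeholder for cfg.gradient_accumulation_steps

batches = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(8)]
for batch_idx, (x, y) in enumerate(batches):
    x, y = x.to(device), y.to(device)
    with torch.autocast(device_type=device, enabled=(device == "cuda")):
        logprob = F.log_softmax(model(x), dim=1)
        # Divide so the accumulated gradient matches the full-batch average.
        loss = criterion(logprob, y) / accum_steps
    scaler.scale(loss).backward()  # gradients accumulate across micro-batches
    if (batch_idx + 1) % accum_steps == 0 or (batch_idx + 1) == len(batches):
        scaler.step(optimizer)  # unscales gradients, then optimizer.step()
        scaler.update()
        optimizer.zero_grad()

Dividing the loss by the accumulation step count keeps the effective update equal to the average over the full logical batch, which is why the patch scales loss by cfg.gradient_accumulation_steps before calling backward().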