diff --git a/data/config/loader/basic.yaml b/data/config/loader/basic.yaml
index 52896f468e558432bf8031223a9d08238665cc15..e47c69c0944b906bbb795a944efaa94f61ddd886 100644
--- a/data/config/loader/basic.yaml
+++ b/data/config/loader/basic.yaml
@@ -1,7 +1,7 @@
 name: basic
 train:
   _target_: torch.utils.data.DataLoader
-  batch_size: 64
+  batch_size: 256  # increased from 64
   shuffle: True
   num_workers: 2
   pin_memory: True
diff --git a/data/config/training/basic.yaml b/data/config/training/basic.yaml
index a1f20749cc57debab29510909c7e6e22dd6e270a..5a1a747301d532076778db0eb88eb9b01d1aa1cb 100644
--- a/data/config/training/basic.yaml
+++ b/data/config/training/basic.yaml
@@ -1,10 +1,7 @@
-# Training configuration for the model model
+# Training configuration for the model
 name: basic # name of the training configuration
-use_cuda: True  # Default: True, flag to enable CUDA training.
-use_mps: True  # Default: True, flag to enable macOS GPU training.
-dry_run: false        # Perform a dry run (do not update weights) (bool)
-seed: 1               # Seed for random number generation (int)
-save_model: true     # Whether to save the model to disk (bool)
+seed: 1               # Seed for random number generation (int)
+device: cuda          # Device to train on; falls back to CPU if CUDA is unavailable
 loss: 
   _target_: torch.nn.NLLLoss
 optimizer:
@@ -15,19 +12,18 @@ scheduler:
   _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
   mode: 'min'
   factor: 0.1
-  patience: 10
-scheduler_monitor: train_loss_epoch
-min_epochs: 50
+  patience: 3
+min_epochs: 10
 max_epochs: 300
 early_stopping_config:
   monitor: valid_acc_epoch
   min_delta: 0.001
-  patience: 20
+  patience: 10
   verbose: False
   mode: max
 gradient_clip_val: 1
 metrics: ['acc', 'f1']
-gradient_accumulation_steps: 1
+gradient_accumulation_steps: 5  # batches to accumulate before each optimizer step
 loggers:
   tensorboard:
     _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
diff --git a/modules/training/training.py b/modules/training/training.py
index d57d1f4d4c779f0fa8dbc6c616dae3a67f091f39..a9cedbd1ee37ddb25283c1663ab87a6141a2405f 100644
--- a/modules/training/training.py
+++ b/modules/training/training.py
@@ -9,7 +9,7 @@ from torch.utils.data import DataLoader
 from modules.utils import MetricAggregator, EarlyStoppingCustom
 import logging
 
-def train_model(model: nn.Module, train_loader: DataLoader, valid_loader: DataLoader, cfg: DictConfig, device="cpu"):
+def train_model(model: nn.Module, train_loader: DataLoader, valid_loader: DataLoader, cfg: DictConfig):
     """
     Trains and evaluates a neural network model.
 
@@ -27,7 +27,7 @@ def train_model(model: nn.Module, train_loader: DataLoader, valid_loader: DataLo
     ##############################
     # Device Setup
     ##############################
-    device = torch.device(cfg.device)
+    device = "cuda" if (cfg.device == "cuda" and torch.cuda.is_available()) else "cpu"
     model.to(device)
 
     ##############################
@@ -58,6 +58,9 @@ def train_model(model: nn.Module, train_loader: DataLoader, valid_loader: DataLo
     # Early stopping and checkpointing
     early_stopping = EarlyStoppingCustom(**cfg.early_stopping_config)
 
+    # Gradient scaler for automatic mixed precision; only active when training on CUDA
+    scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))
+
     ##############################
     # Epoch Loop
     ##############################
@@ -67,25 +70,34 @@ def train_model(model: nn.Module, train_loader: DataLoader, valid_loader: DataLo
         ##############################
         model.train()
         lr = torch.tensor(optimizer.param_groups[0]['lr'])
-        for batch_idx, batch in enumerate(train_loader):
-            x, y = batch
+        for batch_idx, (x, y) in enumerate(train_loader):
             x, y = x.to(device), y.to(device)
-            out = model(x)
-            logprob = F.log_softmax(out, dim=1)
-            y_hat_prob = torch.exp(logprob)
-            loss = criterion(logprob, y)
-            loss.backward()
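+            # Run the forward pass under autocast so eligible ops use reduced precision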
+            with torch.autocast(device_type=device):
+                out = model(x)
+                logprob = F.log_softmax(out, dim=1)
+                y_hat_prob = torch.exp(logprob)
+                loss = criterion(logprob, y)
+
+            # Backprop on the loss divided by the accumulation factor so accumulated
+            # gradients average correctly; keep the unscaled `loss` for metric logging.
+            scaler.scale(loss / cfg.gradient_accumulation_steps).backward()
 
             # Gradient accumulation
             if (batch_idx + 1) % cfg.gradient_accumulation_steps == 0 or (batch_idx + 1) == len(train_loader):
-                optimizer.step()
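+                # scaler.step() unscales the gradients and skips the optimizer step if
+                # any are inf/NaN; scaler.update() then adjusts the scale factor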
+                scaler.step(optimizer)
+                scaler.update()
                 optimizer.zero_grad()
 
             # Update metrics
-            metric_aggregator.step(y_hat_prob=y_hat_prob, y=y, loss=loss, epoch=torch.tensor(epoch+1), lr=lr, phase="train")
+            with torch.autocast(device_type=device):
+                metric_aggregator.step(y_hat_prob=y_hat_prob, y=y, loss=loss, epoch=torch.tensor(epoch+1), lr=lr, phase="train")
 
         # Compute and log metrics
-        train_results = metric_aggregator.compute(phase="train")
+        with torch.autocast(device_type=device):
+            train_results = metric_aggregator.compute(phase="train")
         logger.info(f"Epoch {epoch+1} Train: {' '.join([f'{k}:{v:.3E}'.replace('_epoch','').replace('train_','') for k,v in train_results.items() if isinstance(v,float)])}")
 
         ##############################
@@ -93,19 +105,22 @@ def train_model(model: nn.Module, train_loader: DataLoader, valid_loader: DataLo
         ##############################
         model.eval()
         with torch.no_grad():
-            for batch in valid_loader:
-                x, y = batch
+            for x, y in valid_loader:
                 x, y = x.to(device), y.to(device)
-                out = model(x)
-                logprob = F.log_softmax(out, dim=1)
-                y_hat_prob = torch.exp(logprob)
-                val_loss = criterion(logprob, y)
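+                # Autocast also speeds up inference; no GradScaler is needed without a backward pass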
+                with torch.autocast(device_type=device):
+                    out = model(x)
+                    logprob = F.log_softmax(out, dim=1)
+                    y_hat_prob = torch.exp(logprob)
+                    val_loss = criterion(logprob, y)
 
                 # Update metrics
-                metric_aggregator.step(y_hat_prob=y_hat_prob, y=y, loss=loss, epoch=torch.tensor(epoch+1), lr=lr, phase="valid")
+                with torch.autocast(device_type=device):
+                    metric_aggregator.step(y_hat_prob=y_hat_prob, y=y, loss=val_loss, epoch=torch.tensor(epoch+1), lr=lr, phase="valid")
 
         # Compute and log metrics
-        valid_results = metric_aggregator.compute(phase="valid")
+        with torch.autocast(device_type=device):
+            valid_results = metric_aggregator.compute(phase="valid")
         logger.info(f"Epoch {epoch+1} Valid: {' '.join([f'{k}:{v:.3E}'.replace('_epoch','').replace('valid_','') for k,v in valid_results.items() if isinstance(v,float)])}")
 
         ##############################
diff --git a/modules/utils/metric_aggregator.py b/modules/utils/metric_aggregator.py
index d51db784f90ed8faeab1016698ed62e789c85f97..7b56d515aea4bc656d75f7410e2f3e47f55cbaf9 100644
--- a/modules/utils/metric_aggregator.py
+++ b/modules/utils/metric_aggregator.py
@@ -41,10 +41,10 @@ class MetricAggregator:
                 self.init_agg(phase=phase, metric=metric)
         for metric in ["acc", "cm", "f1"]:
             if metric in self.metrics:
-                self.aggregators[phase][metric](y_hat_prob.to(self.device), y.to(self.device))
+                self.aggregators[phase][metric].update(y_hat_prob.to(self.device), y.to(self.device))
         for k, v in kwargs.items():
             if k in self.aggregators[phase]:
-                self.aggregators[phase][k](v.to(self.device))
+                self.aggregators[phase][k].update(v.to(self.device))
         for logger in self.loggers:
             logger.log_metrics({f"{phase}_{k}_step":v.detach().cpu().tolist() for k,v in kwargs.items()}, step=self.step_num)
         self.step_num+=1
diff --git a/runs/01_train_model.py b/runs/01_train_model.py
index 8270d3bcaaa3fb2808bcbb8b969154c050ef02b4..0888e342c758f75d87daa4c6ab353b44e96f849f 100644
--- a/runs/01_train_model.py
+++ b/runs/01_train_model.py
@@ -91,15 +91,11 @@ def main(cfg: DictConfig) -> None:
     model_save_dir = Path(cfg.path.base_path_models) / cfg.path.results
     model_save_dir.mkdir(parents=True, exist_ok=True)
 
-    # Determine if CUDA or MPS should be used based on configuration and availability
-    use_cuda = cfg.training.use_cuda and torch.cuda.is_available()
-    use_mps = cfg.training.use_mps and torch.backends.mps.is_available()
-
     # Set the random seed for reproducibility
     seed_everything(cfg.training.seed)
 
-    # Select the device for computation (CUDA, MPS, or CPU)
-    device = torch.device("cuda") if use_cuda else torch.device("mps") if use_mps else torch.device("cpu")
+    # Select the device for computation (CUDA if requested and available, otherwise CPU)
+    device = "cuda" if (cfg.training.device == "cuda" and torch.cuda.is_available()) else "cpu"
 
     ##############################
     # Object Instantiation
@@ -125,28 +121,15 @@ def main(cfg: DictConfig) -> None:
     ##############################
 
     # Training loop    
-    train_model(model, train_loader, test_loader, cfg.training, device)
+    train_model(model, train_loader, test_loader, cfg.training)
 
     ##############################
     # Saving Results
     ##############################
 
-    # Save the model checkpoint if configured to do so
-    if cfg.training.save_model:
-        model_path = model_save_dir / f"checkpoint.pt"
-        torch.save(model.state_dict(), model_path)
-
-    # Save the result dictionary
-    results = {
-        "model_name": cfg.model.name,
-        "training_name": cfg.training.name,
-        "epochs": cfg.training.epochs,
-        "seed": cfg.training.seed,
-        "final_model_path": str(model_path),
-        "timestamp": time.time()
-    }
-    with open(model_save_dir / "results.json", "w") as f:
-        json.dump(results, f, indent=4)
+    # Save the final model weights (state_dict)
+    model_path = model_save_dir / "checkpoint.ckpt"
+    torch.save(model.state_dict(), model_path)
 
     # Save the configuration
     config_path = model_save_dir / "config.yaml"