
Ran formatting and worked on the model

Nicholas Schense · 5 months ago
commit 21912ecf80
6 changed files with 306 additions and 117 deletions
  1. .gitignore (+162 −0)
  2. train_cnn.py (+59 −28)
  3. utils/data/datasets.py (+4 −0)
  4. utils/models/cnn.py (+13 −25)
  5. utils/models/layers.py (+33 −38)
  6. utils/training.py (+35 −26)

.gitignore (+162 −0)

@@ -0,0 +1,162 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+#   For a library or package, you might want to ignore these files since the code is
+#   intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+#   However, in case of collaboration, if having platform-specific dependencies or dependencies
+#   having no cross-platform support, pipenv may install dependencies that don't work, or not
+#   install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+#   This is especially recommended for binary packages to ensure reproducibility, and is more
+#   commonly ignored for libraries.
+#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+#   in version control.
+#   https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+#  and can be added to the global gitignore or merged into this file.  For a more nuclear
+#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/

train_cnn.py (+59 −28)

@@ -1,50 +1,81 @@
-#MACHINE LEARNING
+# MACHINE LEARNING
 import torch
 import torch.nn as nn
 import torch.optim as optim
 
-#GENERAL USE
+# GENERAL USE
 import random as rand
 
-#SYSTEM
+# SYSTEM
 import tomli as toml
 import os
 
-#DATA PROCESSING
+# DATA PROCESSING
 from sklearn.model_selection import train_test_split
 
-#CUSTOM MODULES
+# CUSTOM MODULES
 import utils.models.cnn as cnn
-from utils.data.datasets import prepare_datasets, initalize_dataloaders 
+from utils.data.datasets import prepare_datasets, initalize_dataloaders
 import utils.training as train
 
-#CONFIGURATION
-if os.getenv('ADL_CONFIG_PATH') is None:
-    with open ('config.toml', 'rb') as f:
+# CONFIGURATION
+if os.getenv("ADL_CONFIG_PATH") is None:
+    with open("config.toml", "rb") as f:
         config = toml.load(f)
 else:
-    with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
+    with open(os.getenv("ADL_CONFIG_PATH"), "rb") as f:
         config = toml.load(f)
-        
-for i in range(config['training']['runs']):           
-    #Set up the model
-    model = cnn.CNN(config['model']['image_channels'], config['model']['clin_data_channels'], config['hyperparameters']['droprate']).float()
+
+for i in range(config["training"]["runs"]):
+    # Set up the model
+    model = cnn.CNN(
+        config["model"]["image_channels"],
+        config["model"]["clin_data_channels"],
+        config["hyperparameters"]["droprate"],
+    ).float()
     criterion = nn.BCELoss()
-    optimizer = optim.Adam(model.parameters(), lr = config['hyperparameters']['learning_rate'])
+    optimizer = optim.Adam(
+        model.parameters(), lr=config["hyperparameters"]["learning_rate"]
+    )
 
-    #Generate seed for each run
+    # Generate seed for each run
     seed = rand.randint(0, 1000)
 
-    #Prepare data
-    train_dataset, val_dataset, test_dataset = prepare_datasets(config['paths']['mri_data'], config['paths']['xls_data'], config['dataset']['validation_split'], seed)
-    train_dataloader, val_dataloader, test_dataloader = initalize_dataloaders(train_dataset, val_dataset, test_dataset, config['hyperparameters']['batch_size'])
-    
-    #Train the model
-    history = train.train_model(model, train_dataloader, val_dataloader, criterion, optimizer, config)
-    
-    #Save model
-    if not os.path.exists(config['paths']['model_output'] + "/" + str(config['model']['name'])):
-        os.makedirs(config['paths']['model_output'] + "/" + str(config['model']['name']))
-    
-    torch.save(model, config['paths']['model_output'] + "/" + str(config['model']['name']) + "/" + str(i) + "_" + "s-" + str(seed) + ".pt")
+    # Prepare data
+    train_dataset, val_dataset, test_dataset = prepare_datasets(
+        config["paths"]["mri_data"],
+        config["paths"]["xls_data"],
+        config["dataset"]["validation_split"],
+        seed,
+    )
+    train_dataloader, val_dataloader, test_dataloader = initalize_dataloaders(
+        train_dataset,
+        val_dataset,
+        test_dataset,
+        config["hyperparameters"]["batch_size"],
+    )
+
+    # Train the model
+    history = train.train_model(
+        model, train_dataloader, val_dataloader, criterion, optimizer, config
+    )
+
+    # Save model
+    if not os.path.exists(
+        config["paths"]["model_output"] + "/" + str(config["model"]["name"])
+    ):
+        os.makedirs(
+            config["paths"]["model_output"] + "/" + str(config["model"]["name"])
+        )
 
+    torch.save(
+        model,
+        config["paths"]["model_output"]
+        + "/"
+        + str(config["model"]["name"])
+        + "/"
+        + str(i)
+        + "_s-"
+        + str(seed)
+        + ".pt",
+    )
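
Note: train_cnn.py pulls every setting from config.toml (or the file named by the ADL_CONFIG_PATH environment variable). The config file itself is not part of this commit, so the following is a hypothetical Python-dict view of the structure implied by the config[...] lookups in this diff; every value is a placeholder.

# Hypothetical sketch of the expected config structure, inferred from the
# config[...] lookups in train_cnn.py and utils/training.py. All values
# below are placeholders, not taken from this commit.
config = {
    "training": {"runs": 1, "max_epochs": 10},
    "model": {"name": "cnn", "image_channels": 1, "clin_data_channels": 2},
    "hyperparameters": {"learning_rate": 1e-4, "droprate": 0.5, "batch_size": 4},
    "dataset": {"validation_split": 0.2},
    "paths": {"mri_data": "...", "xls_data": "...", "model_output": "..."},
    "operation": {"silent": False},
}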

utils/data/datasets.py (+4 −0)

@@ -121,6 +121,10 @@ class ADNIDataset(Dataset):
         #Convert to one-hot and squeeze
         class_id = torch.nn.functional.one_hot(class_id, num_classes=2).squeeze(0)
         
+        #Convert to float
+        mri_tensor = mri_tensor.float()
+        xls_tensor = xls_tensor.float()
+        class_id = class_id.float()
 
         return (mri_tensor, xls_tensor), class_id
     

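The new .float() casts are a correctness fix, not cosmetics: torch.nn.functional.one_hot returns int64, and nn.BCELoss (the criterion in train_cnn.py) rejects integer targets, while the mri/xls casts keep the inputs aligned with the model's float32 parameters. A minimal sketch of the label dtype behavior:

# one_hot yields int64 labels; BCELoss needs float targets, hence the cast.
import torch
class_id = torch.nn.functional.one_hot(torch.tensor([1]), num_classes=2).squeeze(0)
print(class_id.dtype)          # torch.int64
print(class_id.float().dtype)  # torch.float32
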
utils/models/cnn.py (+13 −25)

@@ -27,37 +27,29 @@ class Parameters:
 
 
 class CNN(nn.Module):
-
     def __init__(self, image_channels, clin_data_channels, droprate):
         super().__init__()
 
-        #Image Section
+        # Image Section
         self.image_section = CNN_Image_Section(image_channels, droprate)
 
-        #Data Layers, fully connected
+        # Data Layers, fully connected
         self.fc_clin1 = ly.FullConnBlock(clin_data_channels, 64, droprate=droprate)
         self.fc_clin2 = ly.FullConnBlock(64, 20, droprate=droprate)
-        
 
-        #Final Dense Layer
+        # Final Dense Layer
         self.dense1 = nn.Linear(40, 5)
         self.dense2 = nn.Linear(5, 2)
-        self.softmax = nn.Softmax(dim = 1)
-
-       
+        self.softmax = nn.Softmax(dim=1)
 
     def forward(self, x):
-
         image, clin_data = x
 
         image = self.image_section(image)
 
-        
-
         clin_data = self.fc_clin1(clin_data)
         clin_data = self.fc_clin2(clin_data)
 
-
         x = torch.cat((image, clin_data), dim=1)
         x = self.dense1(x)
         x = self.dense2(x)
@@ -65,29 +57,27 @@ class CNN(nn.Module):
         return x
 
 
-
-
-
 class CNN_Image_Section(nn.Module):
     def __init__(self, image_channels, droprate):
         super().__init__()
-            # Initial Convolutional Blocks
+        # Initial Convolutional Blocks
         self.conv1 = ly.ConvBlock(
-            image_channels, 192, (11, 13, 11), stride=(4, 4, 4), droprate=droprate, pool=False
-        )
-        self.conv2 = ly.ConvBlock(
-            192, 384, (5, 6, 5), droprate=droprate, pool=False
+            image_channels,
+            192,
+            (11, 13, 11),
+            stride=(4, 4, 4),
+            droprate=droprate,
+            pool=False,
         )
+        self.conv2 = ly.ConvBlock(192, 384, (5, 6, 5), droprate=droprate, pool=False)
 
         # Midflow Block
         self.midflow = ly.MidFlowBlock(384, droprate)
 
-        
-
         # Split Convolutional Block
         self.splitconv = ly.SplitConvBlock(384, 192, 96, 1, droprate)
 
-        #Fully Connected Block
+        # Fully Connected Block
         self.fc_image = ly.FullConnBlock(227136, 20, droprate=droprate)
 
     def forward(self, x):
@@ -99,5 +89,3 @@ class CNN_Image_Section(nn.Module):
         x = self.fc_image(x)
 
         return x
-
-
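
For context, CNN.forward unpacks a tuple rather than a single tensor, matching the (mri_tensor, xls_tensor) pairs ADNIDataset returns. A hypothetical call sketch follows (channel counts and droprate are placeholders, and mri_tensor / xls_tensor stand in for one batch from the dataloader; the hard-coded 227136 in_features of fc_image ties the model to one specific MRI volume size that this diff does not state):

# Hypothetical usage sketch; sizes are placeholders, not from this commit.
model = CNN(image_channels=1, clin_data_channels=2, droprate=0.5).float()
output = model((mri_tensor, xls_tensor))  # tuple input -> (batch, 2) softmax scores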

utils/models/layers.py (+33 −38)

@@ -26,7 +26,7 @@ class SepConv3d(nn.Module):
     def forward(self, x):
         x = self.depthwise(x)
         return x
-    
+
 
 class SplitConvBlock(nn.Module):
     def __init__(self, in_channels, mid_channels, out_channels, split_dim, drop_rate):
@@ -35,25 +35,22 @@ class SplitConvBlock(nn.Module):
         self.split_dim = split_dim
 
         self.leftconv_1 = SepConvBlock(
-            in_channels //2, mid_channels //2, (3, 4, 3),  droprate=drop_rate
+            in_channels // 2, mid_channels // 2, (3, 4, 3), droprate=drop_rate
         )
         self.rightconv_1 = SepConvBlock(
-            in_channels //2, mid_channels //2, (3, 4, 3),  droprate=drop_rate
+            in_channels // 2, mid_channels // 2, (3, 4, 3), droprate=drop_rate
         )
 
         self.leftconv_2 = SepConvBlock(
-            mid_channels //2, out_channels //2, (3, 4, 3),  droprate=drop_rate
+            mid_channels // 2, out_channels // 2, (3, 4, 3), droprate=drop_rate
        )
         self.rightconv_2 = SepConvBlock(
-            mid_channels //2, out_channels //2, (3, 4, 3),  droprate=drop_rate
+            mid_channels // 2, out_channels // 2, (3, 4, 3), droprate=drop_rate
         )
 
-        
-
     def forward(self, x):
         (left, right) = torch.tensor_split(x, 2, dim=self.split_dim)
 
-
         self.leftblock = nn.Sequential(self.leftconv_1, self.leftconv_2)
         self.rightblock = nn.Sequential(self.rightconv_1, self.rightconv_2)
 
@@ -61,47 +58,47 @@ class SplitConvBlock(nn.Module):
         right = self.rightblock(right)
         x = torch.cat((left, right), dim=self.split_dim)
         return x
-    
+
 
 class MidFlowBlock(nn.Module):
     def __init__(self, channels, drop_rate):
         super(MidFlowBlock, self).__init__()
 
         self.conv1 = ConvBlock(
-            channels, channels, (3, 3, 3),  droprate=drop_rate, padding="same"
+            channels, channels, (3, 3, 3), droprate=drop_rate, padding="same"
         )
         self.conv2 = ConvBlock(
             channels, channels, (3, 3, 3), droprate=drop_rate, padding="same"
         )
         self.conv3 = ConvBlock(
-            channels, channels, (3, 3, 3),  droprate=drop_rate, padding="same"
+            channels, channels, (3, 3, 3), droprate=drop_rate, padding="same"
         )
 
-        #self.block = nn.Sequential(self.conv1, self.conv2, self.conv3)
+        # self.block = nn.Sequential(self.conv1, self.conv2, self.conv3)
         self.block = self.conv1
 
     def forward(self, x):
         x = nn.ELU()(self.block(x) + x)
         return x
 
-        
+
 class ConvBlock(nn.Module):
     def __init__(
-            self,
-            in_channels,
-            out_channels,
-            kernel_size,
-            stride=(1, 1, 1),
-            padding="valid",
-            droprate=None,
-            pool=False,
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=(1, 1, 1),
+        padding="valid",
+        droprate=None,
+        pool=False,
     ):
         super(ConvBlock, self).__init__()
         self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding)
         self.norm = nn.BatchNorm3d(out_channels)
         self.elu = nn.ELU()
         self.dropout = nn.Dropout(droprate)
-        
+
         if pool:
             self.maxpool = nn.MaxPool3d(3, stride=2)
         else:
@@ -111,13 +108,12 @@ class ConvBlock(nn.Module):
         x = self.conv(x)
         x = self.norm(x)
         x = self.elu(x)
-       
-        
+
         if self.maxpool:
             x = self.maxpool(x)
 
         x = self.dropout(x)
-            
+
         return x
 
 
@@ -135,25 +131,25 @@ class FullConnBlock(nn.Module):
         x = self.elu(x)
         x = self.dropout(x)
         return x
-    
+
 
 class SepConvBlock(nn.Module):
     def __init__(
-      self,
-      in_channels,
-      out_channels,
-      kernel_size,
-      stride = (1, 1, 1),
-      padding = "valid",
-      droprate = None,
-      pool = False,      
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=(1, 1, 1),
+        padding="valid",
+        droprate=None,
+        pool=False,
     ):
         super(SepConvBlock, self).__init__()
         self.conv = SepConv3d(in_channels, out_channels, kernel_size, stride, padding)
         self.norm = nn.BatchNorm3d(out_channels)
         self.elu = nn.ELU()
         self.dropout = nn.Dropout(droprate)
-        
+
         if pool:
             self.maxpool = nn.MaxPool3d(3, stride=2)
         else:
@@ -163,11 +159,10 @@ class SepConvBlock(nn.Module):
         x = self.conv(x)
         x = self.norm(x)
         x = self.elu(x)
-        
-        
+
         if self.maxpool:
             x = self.maxpool(x)
 
         x = self.dropout(x)
 
-        return x
+        return x

utils/training.py (+35 −26)

@@ -1,28 +1,34 @@
-import torch 
+import torch
 import torch.nn as nn
 import torch.optim as optim
 import torchvision
 from torch.utils.data import DataLoader
 import pandas as pd
+from tqdm import tqdm
 
 
-
-def train_epoch(model, train_loader, val_loader, criterion, optimizer):
+def train_epoch(model, train_loader, val_loader, criterion, optimizer, config, epoch):
     model.train()
     train_loss = 0
     val_loss = 0
-    for i, (data, target) in enumerate(train_loader):
+    for i, (data, target) in tqdm(
+        enumerate(train_loader),
+        total=len(train_loader),
+        desc="Epoch " + str(epoch + 1) + "/" + str(config["training"]["max_epochs"]),
+        unit="batch",
+        disable=config["operation"]["silent"],
+    ):
         optimizer.zero_grad()
-        
+
         output = model(data)
         loss = criterion(output, target)
         loss.backward()
         optimizer.step()
-        
+
         train_loss += loss.item()
-        
+
     train_loss /= len(train_loader)
-    
+
     model.eval()
     with torch.no_grad():
         for i, (data, target) in enumerate(val_loader):
@@ -30,47 +36,50 @@ def train_epoch(model, train_loader, val_loader, criterion, optimizer):
             loss = criterion(output, target)
             val_loss += loss.item()
         val_loss /= len(val_loader)
-        
+
     return train_loss, val_loss
 
+
 def evaluate_accuracy(model, loader):
     model.eval()
     correct = 0
     total = 0
     predictions = []
     actual = []
-    
+
     with torch.no_grad():
         for data, target in loader:
             output = model(data)
             _, predicted = torch.max(output.data, 1)
             total += target.size(0)
             correct += (predicted == target).sum().item()
-            
+
             out = output[:, 1].tolist()
             predictions.extend(out)
-            
+
             act = target[:, 1].tolist()
             actual.extend(act)
-            
+
     return correct / total, predictions, actual
 
+
 def train_model(model, train_loader, val_loader, criterion, optimizer, config):
-    
-    history = pd.DataFrame(columns = ["Epoch", "Train Loss", "Val Loss", "Train Acc","Val Acc"]).set_index("Epoch")
-    
-    
+    history = pd.DataFrame(
+        columns=["Epoch", "Train Loss", "Val Loss", "Train Acc", "Val Acc"]
+    ).set_index("Epoch")
+
     for epoch in range(config["training"]["max_epochs"]):
-        train_loss, val_loss = train_epoch(model, train_loader, val_loader, criterion, optimizer)
-        if config["operation"]["silent"] is False: print(f"Epoch {epoch + 1} - Train Loss: {train_loss} - Val Loss: {val_loss}")
-        
+        train_loss, val_loss = train_epoch(
+            model, train_loader, val_loader, criterion, optimizer, config, epoch
+        )
+        if config["operation"]["silent"] is False:
+            print(
+                f"Epoch {epoch + 1} - Train Loss: {train_loss} - Val Loss: {val_loss}"
+            )
+
         train_acc, _, _ = evaluate_accuracy(model, train_loader)
        val_acc, _, _ = evaluate_accuracy(model, val_loader)
-        
+
         history.loc[epoch] = [train_loss, val_loss, train_acc, val_acc]
-        
+
     return history
-        
-    
-        
-
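
With these changes, train_model reports batch progress through tqdm and returns a per-epoch pandas DataFrame instead of only printing losses, so callers can persist the learning curves. A minimal follow-up sketch (mirrors the call in train_cnn.py; the CSV path is a hypothetical example, nothing in this commit writes it):

# Hypothetical follow-up: persist the history DataFrame returned by train_model.
history = train.train_model(
    model, train_dataloader, val_dataloader, criterion, optimizer, config
)
history.to_csv("history_run_" + str(i) + ".csv")  # example path only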