train_cnn.py 753 B

12345678910111213141516171819202122232425262728293031323334353637383940
  1. #MACHINE LEARNING
  2. import torch
  3. import torch.nn as nn
  4. import torch.optim as optim
  5. import torchvision
  6. #GENERAL USE
  7. import numpy as np
  8. import pandas as pd
  9. from datetime import datetime
  10. #SYSTEM
  11. import tomli as toml
  12. import os
  13. #DATA PROCESSING
  14. from sklearn.model_selection import train_test_split
  15. #CUSTOM MODULES
  16. import utils.models.cnn as cnn
  17. #CONFIGURATION
  18. if os.getenv('ADL_CONFIG_PATH') is None:
  19. with open ('config.toml', 'rb') as f:
  20. config = toml.load(f)
  21. else:
  22. with open(os.getenv('ADL_CONFIG_PATH'), 'rb') as f:
  23. config = toml.load(f)
  24. #Set up the model
  25. model = cnn.CNN(config)
  26. criterion = nn.BCELoss()
  27. optimizer = optim.Adam(model.parameters(), lr = config['training']['learning_rate'])
  28. #Load datasets