From 2cd01454f1076b080231274fbec7f112322766d2 Mon Sep 17 00:00:00 2001 From: Falguni Ghosh <falguni.ghosh@fau.de> Date: Sun, 15 Oct 2023 21:21:48 +0000 Subject: [PATCH] Upload New File --- .../train.py | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 4_Resnet_Solar_panel_defect_Identification/train.py diff --git a/4_Resnet_Solar_panel_defect_Identification/train.py b/4_Resnet_Solar_panel_defect_Identification/train.py new file mode 100644 index 0000000..9bacb8b --- /dev/null +++ b/4_Resnet_Solar_panel_defect_Identification/train.py @@ -0,0 +1,46 @@ +import torch as t +import torch.utils.data + +from data import ChallengeDataset +from trainer import Trainer +from matplotlib import pyplot as plt +import numpy as np +import model +import pandas as pd +from sklearn.model_selection import train_test_split + + +# load the data from the csv file and perform a train-test-split +# this can be accomplished using the already imported pandas and sklearn.model_selection modules +# TODO +dataframe = pd.read_csv('data.csv', sep=";") + +# set up data loading for the training and validation set each using t.utils.data.DataLoader and ChallengeDataset objects +# TODO +train_set, validation_set = train_test_split(dataframe, shuffle=True) + +# create an instance of our ResNet model +# TODO +model = model.ResNet() +batch_size = 32 +# set up a suitable loss criterion (you can find a pre-implemented loss functions in t.nn) +# set up the optimizer (see t.optim) +# create an object of type Trainer and set its early stopping criterion +# TODO +train_challenge = torch.utils.data.DataLoader(ChallengeDataset(train_set, 'train'), batch_size=batch_size) + +validation_challenge = torch.utils.data.DataLoader(ChallengeDataset(validation_set, 'val'), batch_size=batch_size) +loss_fn = t.nn.BCELoss() +optimizer = t.optim.SGD(model.parameters(), lr=0.009, momentum=0.8) +trainer = Trainer(model, loss_fn, optimizer, train_challenge, validation_challenge, cuda = True,early_stopping_patience=-1) + + +# go, go, go... call fit on trainer +res = trainer.fit(epochs=50) + +# plot the results +plt.plot(np.arange(len(res[0])), res[0], label='train loss') +plt.plot(np.arange(len(res[1])), res[1], label='val loss') +plt.yscale('log') +plt.legend() +plt.savefig('losses.png') \ No newline at end of file -- GitLab