diff --git a/4_Resnet_Solar_panel_defect_Identification/train.py b/4_Resnet_Solar_panel_defect_Identification/train.py new file mode 100644 index 0000000000000000000000000000000000000000..9bacb8b7acf5cf423657eab44ad3256a90964cc4 --- /dev/null +++ b/4_Resnet_Solar_panel_defect_Identification/train.py @@ -0,0 +1,46 @@ +import torch as t +import torch.utils.data + +from data import ChallengeDataset +from trainer import Trainer +from matplotlib import pyplot as plt +import numpy as np +import model +import pandas as pd +from sklearn.model_selection import train_test_split + + +# load the data from the csv file and perform a train-test-split +# this can be accomplished using the already imported pandas and sklearn.model_selection modules +# TODO +dataframe = pd.read_csv('data.csv', sep=";") + +# set up data loading for the training and validation set each using t.utils.data.DataLoader and ChallengeDataset objects +# TODO +train_set, validation_set = train_test_split(dataframe, shuffle=True) + +# create an instance of our ResNet model +# TODO +model = model.ResNet() +batch_size = 32 +# set up a suitable loss criterion (you can find a pre-implemented loss functions in t.nn) +# set up the optimizer (see t.optim) +# create an object of type Trainer and set its early stopping criterion +# TODO +train_challenge = torch.utils.data.DataLoader(ChallengeDataset(train_set, 'train'), batch_size=batch_size) + +validation_challenge = torch.utils.data.DataLoader(ChallengeDataset(validation_set, 'val'), batch_size=batch_size) +loss_fn = t.nn.BCELoss() +optimizer = t.optim.SGD(model.parameters(), lr=0.009, momentum=0.8) +trainer = Trainer(model, loss_fn, optimizer, train_challenge, validation_challenge, cuda = True,early_stopping_patience=-1) + + +# go, go, go... call fit on trainer +res = trainer.fit(epochs=50) + +# plot the results +plt.plot(np.arange(len(res[0])), res[0], label='train loss') +plt.plot(np.arange(len(res[1])), res[1], label='val loss') +plt.yscale('log') +plt.legend() +plt.savefig('losses.png') \ No newline at end of file