From 35ca4f8bc13ad859bfc3ba1239f6cd2a9411de67 Mon Sep 17 00:00:00 2001 From: Falguni Ghosh <falguni.ghosh@fau.de> Date: Sun, 15 Oct 2023 21:05:22 +0000 Subject: [PATCH] Upload New File --- 2_CNN/Flatten.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 2_CNN/Flatten.py diff --git a/2_CNN/Flatten.py b/2_CNN/Flatten.py new file mode 100644 index 0000000..3beb7e6 --- /dev/null +++ b/2_CNN/Flatten.py @@ -0,0 +1,26 @@ +from .Base import BaseLayer +import numpy as np + + +class Flatten(BaseLayer): + + def __init__(self): + super().__init__() + self.input_size = None + + def forward(self, input_tensor): + + self.input_size = np.asarray(input_tensor.shape) + + output_shape = (self.input_size[0],np.prod(self.input_size[1::])) + output = np.empty(output_shape) + + for i in range(self.input_size[0]): + curr_img = input_tensor[i] + output[i] = curr_img.reshape(1, -1).squeeze() + + return output + + + def backward(self, error_tensor): + return error_tensor.reshape(self.input_size) -- GitLab