diff --git a/2_CNN/Flatten.py b/2_CNN/Flatten.py new file mode 100644 index 0000000000000000000000000000000000000000..3beb7e64284401920ab4a8ced770b11868476e24 --- /dev/null +++ b/2_CNN/Flatten.py @@ -0,0 +1,26 @@ +from .Base import BaseLayer +import numpy as np + + +class Flatten(BaseLayer): + + def __init__(self): + super().__init__() + self.input_size = None + + def forward(self, input_tensor): + + self.input_size = np.asarray(input_tensor.shape) + + output_shape = (self.input_size[0],np.prod(self.input_size[1::])) + output = np.empty(output_shape) + + for i in range(self.input_size[0]): + curr_img = input_tensor[i] + output[i] = curr_img.reshape(1, -1).squeeze() + + return output + + + def backward(self, error_tensor): + return error_tensor.reshape(self.input_size)