From 35ca4f8bc13ad859bfc3ba1239f6cd2a9411de67 Mon Sep 17 00:00:00 2001
From: Falguni Ghosh <falguni.ghosh@fau.de>
Date: Sun, 15 Oct 2023 21:05:22 +0000
Subject: [PATCH] Upload New File

---
 2_CNN/Flatten.py | 26 ++++++++++++++++++++++++++
 1 file changed, 26 insertions(+)
 create mode 100644 2_CNN/Flatten.py

diff --git a/2_CNN/Flatten.py b/2_CNN/Flatten.py
new file mode 100644
index 0000000..3beb7e6
--- /dev/null
+++ b/2_CNN/Flatten.py
@@ -0,0 +1,26 @@
+from .Base import BaseLayer
+import numpy as np
+
+
+class Flatten(BaseLayer):
+
+    def __init__(self):
+        super().__init__()
+        self.input_size = None
+
+    def forward(self, input_tensor):
+
+        self.input_size = np.asarray(input_tensor.shape)
+
+        output_shape = (self.input_size[0],np.prod(self.input_size[1::]))
+        output = np.empty(output_shape)
+
+        for i in range(self.input_size[0]):
+            curr_img = input_tensor[i]
+            output[i] = curr_img.reshape(1, -1).squeeze()
+
+        return output
+
+
+    def backward(self, error_tensor):
+        return error_tensor.reshape(self.input_size)
-- 
GitLab