diff --git a/3_RNN/BatchNormalization.py b/3_RNN/BatchNormalization.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d631b7fa46d78026cde39bd2a045626c4126c77
--- /dev/null
+++ b/3_RNN/BatchNormalization.py
@@ -0,0 +1,130 @@
+import numpy as np
+from .Base import BaseLayer
+from .Helpers import compute_bn_gradients
+
+
+class BatchNormalization(BaseLayer):
+    def __init__(self, channels):
+        super().__init__()
+        self.num_channels = channels
+        self.weights = None
+        self.bias = None
+        self.trainable = True
+        self.output_tensor = None
+        self.input_is_conv = None
+        self.input_tilde = None
+        self.mean_m_avg = None
+        self.var_m_avg = None
+        self.m_avg_decay = 0.8
+        self.batch_mean = None
+        self.batch_var = None
+        self.input_tensor = None
+        self.input_tensor_shape = None
+        self.error_tensor = None
+        self.reformat_tensor_shape = None
+        self.gradient_wrt_input = None
+        self._optimizer = None
+        self._gradient_weights = None
+        self._gradient_bias = None
+        self.initialize(None, None)
+
+    @property
+    def gradient_weights(self):
+        return self._gradient_weights
+
+    @gradient_weights.setter
+    def gradient_weights(self, w):
+        self._gradient_weights = w
+
+    @property
+    def gradient_bias(self):
+        return self._gradient_bias
+
+    @gradient_bias.setter
+    def gradient_bias(self, b):
+        self._gradient_bias = b
+
+    @property
+    def optimizer(self):
+        return self._optimizer
+
+    @optimizer.setter
+    def optimizer(self, opt):
+        self._optimizer = opt
+
+    def initialize(self, weights_initializer, bias_initializer):
+        # Batch normalization ignores the initializers: gamma starts at 1, beta at 0.
+        self.weights = np.ones(self.num_channels)
+        self.bias = np.zeros(self.num_channels)
+        self.mean_m_avg = 0
+        self.var_m_avg = 0
+
+    def forward(self, input_tensor):
+        self.input_tensor = input_tensor
+        self.input_tensor_shape = self.input_tensor.shape
+
+        if input_tensor.ndim == 4:  # image-shaped input coming from a convolutional layer
+            self.input_is_conv = True
+            self.input_tensor = self.reformat(input_tensor)
+        else:
+            self.input_is_conv = False
+
+        if not self.testing_phase:
+            self.batch_mean = np.mean(self.input_tensor, axis=0)
+            self.batch_var = np.var(self.input_tensor, axis=0)
+            self.input_tilde = (self.input_tensor - self.batch_mean) / np.sqrt(self.batch_var + np.finfo(float).eps)
+            self.output_tensor = self.weights * self.input_tilde + self.bias
+
+            # Initialize the moving averages with the first batch statistics,
+            # afterwards update them with exponential decay.
+            if np.all(self.mean_m_avg == 0) and np.all(self.var_m_avg == 0):
+                self.mean_m_avg = self.batch_mean
+                self.var_m_avg = self.batch_var
+            else:
+                self.mean_m_avg = self.m_avg_decay * self.mean_m_avg + (1 - self.m_avg_decay) * self.batch_mean
+                self.var_m_avg = self.m_avg_decay * self.var_m_avg + (1 - self.m_avg_decay) * self.batch_var
+        else:
+            # At test time, normalize with the moving averages collected during training.
+            self.input_tilde = (self.input_tensor - self.mean_m_avg) / np.sqrt(self.var_m_avg + np.finfo(float).eps)
+            self.output_tensor = self.weights * self.input_tilde + self.bias
+
+        if self.input_is_conv:
+            self.output_tensor = self.reformat(self.output_tensor)
+
+        return self.output_tensor
+
+    def backward(self, error_tensor):
+        self.error_tensor = error_tensor
+        if self.input_is_conv:
+            self.error_tensor = self.reformat(error_tensor)
+
+        self._gradient_weights = np.sum(self.error_tensor * self.input_tilde, axis=0)
+        self._gradient_bias = np.sum(self.error_tensor, axis=0)
+        # The gradient w.r.t. the input must use the batch statistics from the
+        # forward pass, not the moving averages.
+        self.gradient_wrt_input = compute_bn_gradients(self.error_tensor, self.input_tensor,
+                                                       self.weights, self.batch_mean, self.batch_var)
+        if self.input_is_conv:
+            self.gradient_wrt_input = self.reformat(self.gradient_wrt_input)
+
+        if self.optimizer is not None:
+            self.weights = self.optimizer.calculate_update(self.weights, self._gradient_weights)
+            self.bias = self.optimizer.calculate_update(self.bias, self._gradient_bias)
+
+        return self.gradient_wrt_input
+
+    def reformat(self, tensor):
+        if tensor.ndim == 4:
+            # (B, C, H, W) -> (B * H * W, C): each channel is normalized over
+            # the batch and all spatial positions.
+            self.reformat_tensor_shape = tensor.shape
+            b, c, h, w = self.reformat_tensor_shape
+            tensor = tensor.reshape(b, c, h * w)
+            tensor = np.transpose(tensor, (0, 2, 1))
+            return tensor.reshape(b * h * w, c)
+        else:
+            # Reverse the operations above: (B * H * W, C) -> (B, C, H, W).
+            b, c, h, w = self.reformat_tensor_shape
+            tensor = tensor.reshape(b, h * w, c)
+            tensor = np.transpose(tensor, (0, 2, 1))
+            return tensor.reshape(b, c, h, w)
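For reference, a minimal smoke-test sketch of the layer above. The import path is hypothetical (`3_RNN` is not a valid Python module name, so the package is assumed to be importable under an alias such as `layers`), and `testing_phase` is assumed to be the flag inherited from `BaseLayer`:

```python
import numpy as np
from layers.BatchNormalization import BatchNormalization  # hypothetical alias for 3_RNN

np.random.seed(0)
bn = BatchNormalization(channels=3)

# Training step on image-shaped input (batch, channels, height, width);
# forward() reformats 4D input to 2D internally and back again.
x = np.random.randn(8, 3, 4, 4)
bn.testing_phase = False
out = bn.forward(x)                       # normalized with batch statistics
grad_in = bn.backward(np.ones_like(out))  # no optimizer set, so no weight update

# Inference: the stored moving averages are used instead of batch statistics.
bn.testing_phase = True
out_test = bn.forward(x)
print(out.shape, grad_in.shape, out_test.shape)  # (8, 3, 4, 4) each
```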