Commit f81d62b9 authored by Falguni Ghosh

Upload New File

parent 26604135
import numpy as np
from .Base import BaseLayer
from .Helpers import compute_bn_gradients


class BatchNormalization(BaseLayer):
    def __init__(self, channels):
        super().__init__()
        self.num_channels = channels
        self.weights = None           # gamma, the learnable scale
        self.bias = None              # beta, the learnable shift
        self.trainable = True
        self.output_tensor = None
        self.next_layer_conv = None   # True while handling 4D input from a conv layer
        self.input_tilde = None       # normalized input, cached for the backward pass
        self.mean_m_avg = None        # moving average of the batch mean
        self.var_m_avg = None         # moving average of the batch variance
        self.m_avg_decay = 0.8
        self.batch_mean = None        # batch statistics cached for the backward pass
        self.batch_var = None
        self.input_tensor = None
        self.input_tensor_shape = None
        self.error_tensor = None
        self.reformat_tensor_shape = None
        self.gradient_wrt_input = None
        self._optimizer = None
        self._gradient_weights = None
        self._gradient_bias = None
        self.initialize(None, None)   # batch norm needs no external initializers
    @property
    def gradient_weights(self):
        return self._gradient_weights

    @gradient_weights.setter
    def gradient_weights(self, w):
        self._gradient_weights = w

    @property
    def gradient_bias(self):
        return self._gradient_bias

    @gradient_bias.setter
    def gradient_bias(self, b):
        # the setter must reuse the property's name; the original
        # `set_gradient_bias` was unreachable through the property
        self._gradient_bias = b

    @property
    def optimizer(self):
        return self._optimizer

    @optimizer.setter
    def optimizer(self, ow):
        self._optimizer = ow
    def initialize(self, weights_initializer, bias_initializer):
        # the initializer arguments are part of the common layer interface
        # but are ignored here: gamma starts at 1 and beta at 0
        self.weights = np.ones(self.num_channels)
        self.bias = np.zeros(self.num_channels)
        self.mean_m_avg = 0
        self.var_m_avg = 0
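
    # The forward pass below implements the standard batch norm training
    # formulation (stated here for reference):
    #     x_tilde = (x - mu_B) / sqrt(sigma_B^2 + eps)
    #     y       = gamma * x_tilde + beta
    # with per-channel batch mean mu_B and variance sigma_B^2, where gamma =
    # self.weights and beta = self.bias. The moving averages are updated as
    #     mu_avg <- decay * mu_avg + (1 - decay) * mu_B
    # (decay = self.m_avg_decay) and replace the batch statistics at test time.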
    def forward(self, input_tensor):
        self.input_tensor = input_tensor
        self.input_tensor_shape = self.input_tensor.shape
        if input_tensor.ndim == 4:
            # 4D input produced by a convolutional layer: flatten the spatial
            # dimensions so the statistics are computed per channel
            self.next_layer_conv = True
            self.input_tensor = self.reformat(input_tensor)
        else:
            self.next_layer_conv = False
        if not self.testing_phase:
            batch_mean = np.mean(self.input_tensor, axis=0)
            batch_var = np.var(self.input_tensor, axis=0)
            # cache the batch statistics for the backward pass
            self.batch_mean = batch_mean
            self.batch_var = batch_var
            self.input_tilde = (self.input_tensor - batch_mean) / np.sqrt(batch_var + np.finfo(float).eps)
            self.output_tensor = self.weights * self.input_tilde + self.bias
            if np.all(self.mean_m_avg == 0) and np.all(self.var_m_avg == 0):
                # the first training batch initializes the moving averages
                self.mean_m_avg = batch_mean
                self.var_m_avg = batch_var
            else:
                self.mean_m_avg = self.m_avg_decay * self.mean_m_avg + (1 - self.m_avg_decay) * batch_mean
                self.var_m_avg = self.m_avg_decay * self.var_m_avg + (1 - self.m_avg_decay) * batch_var
        else:
            # test time: normalize with the moving averages instead
            self.input_tilde = (self.input_tensor - self.mean_m_avg) / np.sqrt(self.var_m_avg + np.finfo(float).eps)
            self.output_tensor = self.weights * self.input_tilde + self.bias
        if self.next_layer_conv:
            self.output_tensor = self.reformat(self.output_tensor)
        return self.output_tensor
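
    # The parameter gradients below are the standard batch norm results:
    #     dL/dgamma = sum over the batch of E * x_tilde
    #     dL/dbeta  = sum over the batch of E
    # where E is the incoming error tensor; the gradient with respect to the
    # input is delegated to the provided helper compute_bn_gradients.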
    def backward(self, error_tensor):
        self.error_tensor = error_tensor
        if self.next_layer_conv:
            self.error_tensor = self.reformat(error_tensor)
        self._gradient_weights = np.sum(self.error_tensor * self.input_tilde, axis=0)
        self._gradient_bias = np.sum(self.error_tensor, axis=0)
        # the input gradient needs the batch statistics from the forward pass,
        # not the moving averages (the two only coincide on the very first batch)
        self.gradient_wrt_input = compute_bn_gradients(self.error_tensor, self.input_tensor,
                                                       self.weights, self.batch_mean, self.batch_var)
        if self.next_layer_conv:
            self.gradient_wrt_input = self.reformat(self.gradient_wrt_input)
        if self.optimizer is not None:
            self.weights = self.optimizer.calculate_update(self.weights, self._gradient_weights)
            self.bias = self.optimizer.calculate_update(self.bias, self._gradient_bias)
        return self.gradient_wrt_input
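
    # Shape convention: reformat maps an image tensor (B, H, M, N) to
    # (B * M * N, H) so that every spatial position counts as one sample and
    # the statistics stay per channel; called on a 2D tensor it reverses the
    # mapping using the stored original shape.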
    def reformat(self, tensor):
        if tensor.ndim == 4:
            # (B, H, M, N) -> (B * M * N, H): flatten the spatial dimensions
            # and move the channel axis last
            self.reformat_tensor_shape = tensor.shape
            b, h, m, n = self.reformat_tensor_shape
            tensor = tensor.reshape(b, h, m * n)
            tensor = np.transpose(tensor, (0, 2, 1))
            return tensor.reshape(b * m * n, h)
        else:
            # reverse of the above, using the stored original shape:
            # (B * M * N, H) -> (B, H, M, N)
            b, h, m, n = self.reformat_tensor_shape
            tensor = tensor.reshape(b, m * n, h)
            tensor = np.transpose(tensor, (0, 2, 1))
            return tensor.reshape(b, h, m, n)
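

if __name__ == "__main__":
    # Minimal smoke test, not part of the original upload. It assumes the
    # exercise layout where this file sits in a package (e.g. `Layers`) next
    # to Base.py and Helpers.py, and that BaseLayer sets testing_phase to
    # False by default; run it from the package root, for example as
    #     python -m Layers.BatchNormalization
    # (the package name `Layers` is an assumption).
    bn = BatchNormalization(channels=8)

    x = np.random.randn(16, 8)
    y = bn.forward(x)
    assert y.shape == x.shape
    # with gamma = 1 and beta = 0 the output is normalized per channel
    assert np.allclose(y.mean(axis=0), 0.0, atol=1e-6)
    assert np.allclose(y.var(axis=0), 1.0, atol=1e-3)

    grad = bn.backward(np.ones_like(y))
    assert grad.shape == x.shape

    x_img = np.random.randn(4, 8, 5, 5)  # 4D input as produced by a conv layer
    y_img = bn.forward(x_img)
    assert y_img.shape == x_img.shape
    print("BatchNormalization smoke test passed.")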