Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
P
Pytorch Without Pytorch
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Container registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
GitLab community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Falguni Ghosh
Pytorch Without Pytorch
Commits
f81d62b9
Commit
f81d62b9
authored
Oct 15, 2023
by
Falguni Ghosh
Browse files
Options
Downloads
Patches
Plain Diff
Upload New File
parent
26604135
Branches
Branches containing commit
No related tags found
No related merge requests found
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
3_RNN/BatchNormalization.py
+133
-0
133 additions, 0 deletions
3_RNN/BatchNormalization.py
with
133 additions
and
0 deletions
3_RNN/BatchNormalization.py
0 → 100644
+
133
−
0
View file @
f81d62b9
import
numpy
as
np
from
.Base
import
BaseLayer
from
.Helpers
import
compute_bn_gradients
class
BatchNormalization
(
BaseLayer
):
def
__init__
(
self
,
channels
):
super
().
__init__
()
self
.
num_channels
=
channels
self
.
weights
=
None
self
.
bias
=
None
self
.
trainable
=
True
self
.
output_tensor
=
None
self
.
next_layer_conv
=
None
self
.
input_tilde
=
None
self
.
mean_m_avg
=
None
self
.
var_m_avg
=
None
self
.
m_avg_decay
=
0.8
self
.
input_tensor
=
None
self
.
input_tensor_shape
=
None
self
.
error_tensor
=
None
self
.
reformat_tensor_shape
=
None
self
.
gradient_wrt_input
=
None
self
.
_optimizer
=
None
self
.
_gradient_weights
=
None
self
.
_gradient_bias
=
None
self
.
initialize
(
None
,
None
)
@property
def
gradient_weights
(
self
):
return
self
.
_gradient_weights
@gradient_weights.setter
def
gradient_weights
(
self
,
w
):
self
.
_gradient_weights
=
w
# gradient_weights = property(get_gradient_weights, set_gradient_weights)
@property
def
gradient_bias
(
self
):
return
self
.
_gradient_bias
@gradient_bias.setter
def
set_gradient_bias
(
self
,
b
):
self
.
_gradient_bias
=
b
# gradient_bias = property(get_gradient_bias, set_gradient_bias)
@property
def
optimizer
(
self
):
return
self
.
_optimizer
@optimizer.setter
def
optimizer
(
self
,
ow
):
self
.
_optimizer
=
ow
def
initialize
(
self
,
dummy_arg_1
,
dummy_arg_2
):
self
.
weights
=
np
.
ones
(
self
.
num_channels
)
self
.
bias
=
np
.
zeros
(
self
.
num_channels
)
self
.
mean_m_avg
=
0
self
.
var_m_avg
=
0
def
forward
(
self
,
input_tensor
):
# print("a")
# print(input_tensor.shape)
self
.
input_tensor
=
input_tensor
self
.
input_tensor_shape
=
self
.
input_tensor
.
shape
if
input_tensor
.
ndim
==
4
:
#convolution layer next
self
.
next_layer_conv
=
True
self
.
input_tensor
=
self
.
reformat
(
input_tensor
)
else
:
self
.
next_layer_conv
=
False
if
not
self
.
testing_phase
:
batch_mean
=
np
.
mean
(
self
.
input_tensor
,
axis
=
0
)
batch_var
=
np
.
std
(
self
.
input_tensor
,
axis
=
0
)
**
2
#print(batch_mean.shape)
#print(batch_var.shape)
self
.
input_tilde
=
(
self
.
input_tensor
-
batch_mean
)
/
(
np
.
sqrt
(
batch_var
+
np
.
finfo
(
float
).
eps
))
self
.
output_tensor
=
self
.
weights
*
self
.
input_tilde
+
self
.
bias
if
np
.
all
((
self
.
mean_m_avg
==
0
))
and
np
.
all
((
self
.
var_m_avg
==
0
)):
self
.
mean_m_avg
=
batch_mean
self
.
var_m_avg
=
batch_var
else
:
self
.
mean_m_avg
=
self
.
m_avg_decay
*
self
.
mean_m_avg
+
(
1
-
self
.
m_avg_decay
)
*
batch_mean
self
.
var_m_avg
=
self
.
m_avg_decay
*
self
.
var_m_avg
+
(
1
-
self
.
m_avg_decay
)
*
batch_var
else
:
self
.
input_tilde
=
(
self
.
input_tensor
-
self
.
mean_m_avg
)
/
(
np
.
sqrt
(
self
.
var_m_avg
+
np
.
finfo
(
float
).
eps
))
self
.
output_tensor
=
self
.
weights
*
self
.
input_tilde
+
self
.
bias
if
self
.
next_layer_conv
:
self
.
output_tensor
=
self
.
reformat
(
self
.
output_tensor
)
return
self
.
output_tensor
def
backward
(
self
,
error_tensor
):
#print("dummy print statement")
self
.
error_tensor
=
error_tensor
if
self
.
next_layer_conv
:
self
.
error_tensor
=
self
.
reformat
(
error_tensor
)
#print(error_tensor.shape)
#if self.next_layer_conv:
self
.
_gradient_weights
=
np
.
sum
(
self
.
error_tensor
*
self
.
input_tilde
,
axis
=
0
)
self
.
_gradient_bias
=
np
.
sum
(
self
.
error_tensor
,
axis
=
0
)
self
.
gradient_wrt_input
=
compute_bn_gradients
(
self
.
error_tensor
,
self
.
input_tensor
,
self
.
weights
,
self
.
mean_m_avg
,
self
.
var_m_avg
)
if
self
.
next_layer_conv
:
self
.
gradient_wrt_input
=
self
.
reformat
(
self
.
gradient_wrt_input
)
if
not
(
self
.
optimizer
==
None
):
self
.
weights
=
self
.
optimizer
.
calculate_update
(
self
.
weights
,
self
.
_gradient_weights
)
self
.
bias
=
self
.
optimizer
.
calculate_update
(
self
.
bias
,
self
.
_gradient_bias
)
return
self
.
gradient_wrt_input
def
reformat
(
self
,
tensor
):
if
tensor
.
ndim
==
4
:
self
.
reformat_tensor_shape
=
tensor
.
shape
tensor
=
tensor
.
reshape
(
self
.
reformat_tensor_shape
[
0
],
self
.
reformat_tensor_shape
[
1
],
self
.
reformat_tensor_shape
[
2
]
*
self
.
reformat_tensor_shape
[
3
])
tensor
=
np
.
transpose
(
tensor
,
(
0
,
2
,
1
))
return
tensor
.
reshape
(
self
.
reformat_tensor_shape
[
0
]
*
self
.
reformat_tensor_shape
[
2
]
*
self
.
reformat_tensor_shape
[
3
],
self
.
reformat_tensor_shape
[
1
])
else
:
# reversing previous operations
tensor
=
tensor
.
reshape
(
self
.
reformat_tensor_shape
[
0
],
self
.
reformat_tensor_shape
[
2
]
*
self
.
reformat_tensor_shape
[
3
],
self
.
reformat_tensor_shape
[
1
])
tensor
=
np
.
transpose
(
tensor
,
(
0
,
2
,
1
))
return
tensor
.
reshape
(
self
.
reformat_tensor_shape
[
0
],
self
.
reformat_tensor_shape
[
1
],
self
.
reformat_tensor_shape
[
2
],
self
.
reformat_tensor_shape
[
3
])
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment