diff --git a/3_RNN/ReLU.py b/3_RNN/ReLU.py new file mode 100644 index 0000000000000000000000000000000000000000..d3b2acf73b9822c1f34481afcd3750cf9284733b --- /dev/null +++ b/3_RNN/ReLU.py @@ -0,0 +1,19 @@ +import numpy as np +from .Base import BaseLayer + + +class ReLU(BaseLayer): + + def __init__(self): + super().__init__() + self.buffered_input = None + + def forward(self, input_tensor): + self.buffered_input = input_tensor + input_tensor[input_tensor < 0] = 0 + return input_tensor + + def backward(self, error_tensor): + error_tensor[self.buffered_input <= 0] = 0 + return error_tensor +