diff --git a/SSLGlacier/main.py b/SSLGlacier/main.py
index 73d9277c31dc350c06c993cc0d123f6e608665ad..cc7d5b4b06f1cf3bda20348953614c420070015f 100644
--- a/SSLGlacier/main.py
+++ b/SSLGlacier/main.py
@@ -5,12 +5,12 @@ from torchinfo import summary as torchinfo_summery
 from torchsummary import summary as torch_summary
 from torch._inductor.config import trace
 
-from processing import datamodule_, transformers_
+from modules import datamodule_, transformers_
 import pytorch_lightning as pl
 from models import ResNet_light_mine, Swav_Orig_ResNet
 import os
 import torch
-from processing.swav_module import SwAV
+from modules.swav_module import SwAV
 from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
 import numpy as np
 from utils.utils import str2bool, getDataPath
diff --git a/SSLGlacier/modules/agumentations_.py b/SSLGlacier/modules/agumentations_.py
new file mode 100644
index 0000000000000000000000000000000000000000..05cb774513ca1749f65dc55320f93754ab81db8a
--- /dev/null
+++ b/SSLGlacier/modules/agumentations_.py
@@ -0,0 +1,411 @@
+# PILRandomGaussianBlur and get_color_distortion are adapted from the SwAV and
+# SimCLR implementations - https://arxiv.org/abs/2002.05709
+import random
+from logging import getLogger
+
+import torch.nn
+from PIL import ImageFilter
+import numpy as np
+import torchvision.transforms as transforms
+import torchvision, tormentor
+from torchvision.transforms import functional as F
+import torch
+import matplotlib.pyplot as plt
+
+logger = getLogger()
+
+
+class Compose(object):
+    """
+    Class for chaining transforms together
+    """
+
+    def __init__(self, transforms):
+        self.transforms = transforms
+
+    def __call__(self, image, target):
+        for t in self.transforms:
+            image, target = t(image, target)
+        return [image, target]
+
+    def __getitem__(self, item):
+        return self.transforms[item]
+
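+# Minimal usage sketch (illustrative only): every transform in this module takes and
+# returns an [image, mask] pair, so they can be chained with Compose, e.g.
+#   transform = Compose([RandomCropper(96), PILRandomGaussianBlur(), ToTensorZones()])
+#   patch, patch_mask = transform(pil_image, pil_mask)
+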
+class DoNothing(torch.nn.Module):
+    def __init__(self):
+        super(DoNothing, self).__init__()
+
+    def __call__(self, img, mask):
+        return [img, mask]
+
+class Cropper(torch.nn.Module):
+    def __init__(self, i, j, h, w):
+        super(Cropper, self).__init__()
+        self.left = i
+        self.right = j
+        self.height = h
+        self.width = w
+
+    def __call__(self, img, mask):
+        # Note: despite the attribute names, the stored values follow torchvision's
+        # (top, left, height, width) order expected by F.crop and RandomCrop.get_params.
+        cropped_img = F.crop(img, self.left, self.right, self.height, self.width)
+        cropped_mask = F.crop(mask, self.left, self.right, self.height, self.width)
+        return [cropped_img, cropped_mask]
+
+
+class RandomCropper(torch.nn.Module):
+    '''
+    Returns one patch at a time; if you need more patches from one image, call it in a loop.
+    Args:
+        img: a PNG or JPG image from which one patch is cropped randomly
+        mask: the image's mask; the same area is cropped from it
+    Returns:
+        A list of two elements: the cropped patch of the image and the cropped patch of the mask
+    '''
+
+    def __init__(self, size):
+        super(RandomCropper, self).__init__()
+        self.left = 0
+        self.right = 0
+        self.height = 0
+        self.width = 0
+        self.size = size
+
+    def forward(self, img, mask):
+        # Sample a random crop window once and keep its coordinates so that the same
+        # region can be cropped again later (see transformers_.OurTrainTransformer).
+        self.left, self.right, self.height, self.width = torchvision.transforms.RandomCrop.get_params(
+            img, output_size=(self.size, self.size))
+        cropped_img = F.crop(img, self.left, self.right, self.height, self.width)
+        cropped_mask = F.crop(mask, self.left, self.right, self.height, self.width)
+        return [cropped_img, cropped_mask]
+
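+# Note: RandomCropper stores the sampled crop coordinates on the instance;
+# transformers_.OurTrainTransformer copies those values into a Cropper so that
+# subsequent views can crop the same region of the image and mask.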
+
+class PILRandomGaussianBlur(torch.nn.Module):
+    def __init__(self, radius_min=0.1, radius_max=2.):
+        """
+           Apply Gaussian Blur to the PIL image. Take the radius and probability of
+           application as the parameter.
+           This transform was used in SimCLR - https://arxiv.org/abs/2002.05709
+           """
+        super(PILRandomGaussianBlur, self).__init__()
+        self.radius_min = radius_min
+        self.radius_max = radius_max
+
+    def forward(self, img, mask):
+        return [img.filter(
+            ImageFilter.GaussianBlur(
+                radius=random.uniform(self.radius_min, self.radius_max)
+            )
+        ), mask]
+
+
+class RandomHorizontalFlip(torch.nn.Module):
+    def __init__(self):
+        super(RandomHorizontalFlip, self).__init__()
+
+    def forward(self, img, mask):
+        image = torchvision.transforms.functional.hflip(img)
+        mask = torchvision.transforms.functional.hflip(mask)
+        return [image, mask]
+
+
+class GetColorDistortion(torch.nn.Module):
+    def __init__(self, s=1.0):
+        super(GetColorDistortion, self).__init__()
+        self.s = s
+
+    def forward(self, img, mask):
+        color_jitter = transforms.ColorJitter(0.8 * self.s, 0.8 * self.s, 0.8 * self.s, 0.2 * self.s)
+        rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
+        rnd_gray = transforms.RandomGrayscale(p=0.2)
+        color_distort = transforms.Compose([rnd_color_jitter, rnd_gray])
+        # Color distortion is applied to the image only; the mask is returned unchanged.
+        return [color_distort(img), mask]
+
+
+# TODO revisit this noise transform
+class GaussNoise(torch.nn.Module):
+    def __init__(self, mean=0, var=10000):
+        super(GaussNoise, self).__init__()
+        self.mean = mean
+        self.var = var
+
+    def forward(self, img, mask):
+        # PIL's size is (width, height); the numpy view of the image is (height, width)
+        col, row = img.size
+        sigma = self.var ** 0.5
+        gauss = np.random.normal(self.mean, sigma, (row, col))
+        noisy = img + gauss
+        return [noisy, mask]
+
+
+class SaltPepperNoise(torch.nn.Module):
+    def __init__(self, salt_or_pepper=0.5):
+        super(SaltPepperNoise, self).__init__()
+        self.salt_or_pepper = salt_or_pepper
+
+    def forward(self, img, mask):
+        amount = .4
+        noisy = np.copy(img)  # numpy copy of the image, shape (height, width)
+        rows, cols = noisy.shape[:2]
+        img_size = rows * cols
+        # Salt mode (assumes 8-bit grayscale input, so white is 255)
+        num_salt = int(np.ceil(amount * img_size * self.salt_or_pepper))
+        noisy[np.random.randint(0, rows, num_salt), np.random.randint(0, cols, num_salt)] = 255
+        # Pepper mode
+        num_pepper = int(np.ceil(amount * img_size * (1. - self.salt_or_pepper)))
+        noisy[np.random.randint(0, rows, num_pepper), np.random.randint(0, cols, num_pepper)] = 0
+        return [noisy, mask]
+
+
+class PoissonNoise(torch.nn.Module):
+    def __init__(self):
+        super(PoissonNoise, self).__init__()
+
+    def forward(self, img, mask):
+        # Scale by the nearest power of two of the number of distinct grey values,
+        # draw Poisson samples, then scale back.
+        vals = len(np.unique(img))
+        vals = 2 ** np.ceil(np.log2(vals))
+        noisy = np.random.poisson(img * vals) / float(vals)
+        return [noisy, mask]
+
+
+# TODO revisit this one
+class SpeckleNoise(torch.nn.Module):
+    def __init__(self):
+        super(SpeckleNoise, self).__init__()
+
+    def forward(self, img, mask):
+        # PIL's size is (width, height); the numpy view of the image is (height, width)
+        col, row = img.size
+        gauss = np.random.randn(row, col)
+        noisy = img + img * gauss
+        return [noisy, mask]
+
+
+class SetZeroNoise(torch.nn.Module):
+    def __init__(self):
+        super(SetZeroNoise, self).__init__()
+
+    def forward(self, img, mask):
+        # Work on a numpy copy: PIL images do not support item assignment.
+        img = np.array(img)
+        row, col = img.shape[:2]
+        img_size = row * col
+        # Sample (with replacement) as many pixel coordinates as there are pixels
+        # and set them to zero.
+        random_rows = np.random.randint(0, row, img_size)
+        random_cols = np.random.randint(0, col, img_size)
+        img[random_rows, random_cols] = 0
+        return [img, mask]
+
+
+####################################################################################
+
+########### Base_code augmentation suggestions  # TODO: not needed for now ##########
+class MyWrap(torch.nn.Module):
+    """
+    Random wrap augmentation taken from tormentor
+    """
+
+    def __init__(self):
+        super(MyWrap, self).__init__()
+
+    def forward(self, img, target):
+        # This augmentation acts like many simultaneous elastic transforms with gaussian sigmas set at various harmonics
+        wrap_rand = tormentor.Wrap.override_distributions(roughness=tormentor.random.Uniform(value_range=(.1, .7)),
+                                                          intensity=tormentor.random.Uniform(value_range=(.0, 1.)))
+        wrap = wrap_rand()
+        image = wrap(img)
+        mask = wrap(target, is_mask=True)
+        return [image, mask]
+
+
+class Rotate(torch.nn.Module):
+    """
+    Random rotation augmentation
+    """
+
+    def __init__(self):
+        super(Rotate, self).__init__()
+
+    def forward(self, img, target):
+        # Pick one of three rotations; avoid shadowing the `random` module.
+        choice = np.random.randint(0, 3)
+        angle = 90
+        if choice == 1:
+            angle = 180
+        elif choice == 2:
+            angle = 270
+        image = torchvision.transforms.functional.rotate(img, angle=angle)
+        mask = torchvision.transforms.functional.rotate(target, angle=angle)
+        return [image, mask.squeeze(0)]
+
+
+class Bright(torch.nn.Module):
+    """
+    Random brightness adjustment augmentations
+    """
+
+    def __init__(self,
+                 lower_band: float = -0.2,
+                 upper_band: float = 0.2):
+        super(Bright, self).__init__()
+        self.lower_band = lower_band
+        self.upper_band = upper_band
+
+    def forward(self, img, mask):
+        bright_rand = tormentor.Brightness.override_distributions(
+            brightness=tormentor.random.Uniform((self.lower_band, self.upper_band)))
+        # bright = bright_rand()
+        # image_transformed = img.clone()
+        image_transformed = bright_rand(np.asarray(img))
+        # set NA areas back to zero
+        image_transformed.seed[img == 0] = 0.0
+
+        return [image_transformed.seed, mask]
+
+
+class Noise(torch.nn.Module):
+    """
+    Random additive noise augmentation
+    """
+
+    def __init__(self):
+        super(Noise, self).__init__()
+
+    def forward(self, img, target):
+        # Add multiplicative Gaussian noise; NA areas (zeros) stay zero, so no reset is needed.
+        noise = torch.normal(mean=0, std=0.3, size=img.size)
+        image = img + img * noise
+        image[image > 1.0] = 1.0
+        image[image < 0.0] = 0.0
+
+        return [image, target]
+
+
+class Resize(torch.nn.Module):
+    # Library code
+    """Resize the input image to the given size.
+    If the image is torch Tensor, it is expected
+    to have [..., H, W] shape, where ... means a maximum of two leading dimensions
+
+    .. warning::
+        The output image might be different depending on its type: when downsampling, the interpolation of PIL images
+        and tensors is slightly different, because PIL applies antialiasing. This may lead to significant differences
+        in the performance of a network. Therefore, it is preferable to train and serve a model with the same input
+        types. See also below the ``antialias`` parameter, which can help making the output of PIL images and tensors
+        closer.
+
+    Args:
+        size (sequence or int): Desired output size. If size is a sequence like
+            (h, w), output size will be matched to this. If size is an int,
+            smaller edge of the image will be matched to this number.
+            i.e, if height > width, then image will be rescaled to
+            (size * height / width, size).
+
+            .. note::
+                In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
+        interpolation (InterpolationMode): Desired interpolation enum defined by
+            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
+            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
+            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
+            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
+        max_size (int, optional): The maximum allowed for the longer edge of
+            the resized image. If the longer edge of the image is greater
+            than ``max_size`` after being resized according to ``size``,
+            ``size`` will be overruled so that the longer edge is equal to
+            ``max_size``.
+            As a result, the smaller edge may be shorter than ``size``. This
+            is only supported if ``size`` is an int (or a sequence of length
+            1 in torchscript mode).
+        antialias (bool, optional): Whether to apply antialiasing.
+            It only affects **tensors** with bilinear or bicubic modes and it is
+            ignored otherwise: on PIL images, antialiasing is always applied on
+            bilinear or bicubic modes; on other modes (for PIL images and
+            tensors), antialiasing makes no sense and this parameter is ignored.
+            Possible values are:
+
+            - ``True``: will apply antialiasing for bilinear or bicubic modes.
+              Other mode aren't affected. This is probably what you want to use.
+            - ``False``: will not apply antialiasing for tensors on any mode. PIL
+              images are still antialiased on bilinear or bicubic modes, because
+              PIL doesn't support no antialias.
+            - ``None``: equivalent to ``False`` for tensors and ``True`` for
+              PIL images. This value exists for legacy reasons and you probably
+              don't want to use it unless you really know what you are doing.
+
+            The current default is ``None`` **but will change to** ``True`` **in
+            v0.17** for the PIL and Tensor backends to be consistent.
+    """
+
+    def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR, max_size=None, antialias="warn"):
+        super().__init__()
+        self.size = size
+        self.max_size = max_size
+        self.interpolation = interpolation
+        self.antialias = antialias
+
+    def forward(self, img, mask):
+        """
+        Args:
+            img (PIL Image or Tensor): Image to be scaled.
+
+        Returns:
+            PIL Image or Tensor: Rescaled image.
+        """
+        return [F.resize(img, self.size, self.interpolation, self.max_size, self.antialias), mask]
+
+
+class CenterCrop(torch.nn.Module):
+    """Crops the given image at the center.
+    If the image is torch Tensor, it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
+    If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
+
+    Args:
+        size (sequence or int): Desired output size of the crop. If size is an
+            int instead of sequence like (h, w), a square crop (size, size) is
+            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
+    """
+
+    def __init__(self, size):
+        super().__init__()
+        self.size = size  # _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
+
+    def forward(self, img, mask):
+        """
+        Args:
+            img (PIL Image or Tensor): Image to be cropped.
+
+        Returns:
+            PIL Image or Tensor: Cropped image.
+        """
+        # Adopted from the HookFormer code. TODO: check it later
+        W, H = img.size
+
+        WW = (W // self.size) + 2
+        HH = (H // self.size) + 2
+
+        return [F.center_crop(img, (HH * self.size, WW * self.size)), F.center_crop(mask, (HH * self.size, WW * self.size))]
+
+
+class ToTensorZones():
+    def __call__(self, image, target):
+        image = F.to_tensor(np.array(image).astype(np.float32))
+        target = torch.from_numpy(np.array(target).astype(np.float32))
+        # value for NA area=0, stone=64, glacier=127, ocean with ice melange=254
+        target[target == 0] = 0
+        target[target == 64] = 1
+        target[target == 127] = 2
+        target[target == 254] = 3
+        # class ids for NA area=0, stone=1, glacier=2, ocean with ice melange=3
+        return [image, target]
diff --git a/SSLGlacier/modules/datamodule_.py b/SSLGlacier/modules/datamodule_.py
new file mode 100644
index 0000000000000000000000000000000000000000..edf3648e314a451703627fcfc8d5f855c90d7870
--- /dev/null
+++ b/SSLGlacier/modules/datamodule_.py
@@ -0,0 +1,149 @@
+from torchvision.transforms import functional as F
+from pl_bolts.models.self_supervised import swav
+from torch.utils.data import DataLoader
+from torch.utils.data import Dataset
+import pytorch_lightning as pl
+from torchvision import transforms as transform_lib
+from . import agumentations_ as our_transforms
+import torchvision
+import numpy as np
+from typing import Tuple, List, Any, Callable, Optional
+import torch
+import cv2
+from PIL import Image,ImageOps
+import os
+
+# TODO add MoCo, SimCLR, CPC and AMDIM transforms
+class SSLDataset(Dataset):
+    def __init__(self, parent_dir, transform, return_index=False, **kwargs):  # TODO Pass them as arguments
+        '''
+        Args:
+            parent_dir: directory in which the test, train and val folders exist
+            transform: transform applied to every (image, mask) pair
+            return_index: if True, __getitem__ also returns the sample index and file names
+            mode: one of 'test', 'train' or 'val' (keyword argument, defaults to 'train')
+        '''
+        self.mode = kwargs.get('mode') or 'train'
+        self.return_index = return_index
+        self.images_path = os.path.join(parent_dir, "sar_images", self.mode)
+        self.masks_path = os.path.join(parent_dir, 'zones', self.mode)
+        self.images = os.listdir(self.images_path)
+        self.masks = os.listdir(self.masks_path)
+        assert len(self.masks) == len(self.images), "You don't have the same number of images and masks"
+        self.transform = transform
+        # Sort so that images and masks line up
+        self.images.sort()
+        self.masks.sort()
+
+        # Shuffle once, save the indices, and reuse them on later calls if they exist
+        if not os.path.exists(os.path.join("data_processing", "data_splits")):
+            os.makedirs(os.path.join("data_processing", "data_splits"))
+        if not os.path.isfile(os.path.join("data_processing", "data_splits", "shuffle_" + ".txt")):
+            shuffle = np.random.permutation(len(self.images))
+            # Works for numpy version >= 1.5
+            np.savetxt(os.path.join("data_processing", "data_splits", "shuffle_" + ".txt"), shuffle,
+                       newline=' ')
+
+        else:
+            # use the already existing shuffle
+            shuffle = np.loadtxt(os.path.join("data_processing", "data_splits", "shuffle_" + ".txt")).astype(int)
+            # if the lengths do not match, we need to create a new permutation
+            if len(shuffle) != len(self.images):
+                shuffle = np.random.permutation(len(self.images))
+                np.savetxt(os.path.join("data_processing", "data_splits", "shuffle_" + ".txt"), shuffle,
+                           newline=' ')
+
+        # Convert to numpy arrays so the shuffle can be applied via fancy indexing,
+        # then convert back to lists.
+        self.images = np.array(self.images)
+        self.masks = np.array(self.masks)
+        self.images = list(self.images[shuffle])
+        self.masks = list(self.masks[shuffle])
+
+    def __len__(self):
+        return len(self.images)
+
+    def __getitem__(self, index):
+        img_name = self.images[index]
+        masks_name = self.masks[index]
+        assert img_name.split('.')[0] == masks_name.split('.')[0].replace("_zones", ""), \
+            "image and label name don't match. Image name: " + img_name + ". Label name: " + masks_name
+        image = Image.open(os.path.join(self.images_path, img_name))
+        image = ImageOps.grayscale(image)
+        mask = Image.open(os.path.join(self.masks_path, masks_name))
+        mask = ImageOps.grayscale(mask)
+        # TODO check that the transform really returns multiple views per image
+        # TODO different parts of the image and mask are cropped within one batch; decide how to handle it
+        # TODO this is a list of (img, mask) pairs; consider changing it later
+        if self.transform is not None:
+            augmented_imgs, masks = self.transform(image, mask)
+        else:
+            augmented_imgs, masks = image, mask
+        if self.return_index:
+            return index, augmented_imgs, masks, img_name, masks_name
+        return augmented_imgs, masks
+
+
+class SSLDataModule(pl.LightningDataModule):
+    # Base is adopted from Nora's code
+    def __init__(self, batch_size, parent_dir, args):
+        """
+        :param batch_size: batch size
+        :param target: Either 'zones' or 'front'. Tells which masks should be used.
+        """
+        super().__init__()
+        self.transforms = None
+        self.batch_size = batch_size
+        self.glacier_test = None
+        self.glacier_train = None
+        self.glacier_val = None
+        self.parent_dir = parent_dir
+        self.aug_args = args.agumentations
+        self.size_crops = args.size_crops
+        self.num_crops = args.nmb_crops
+
+    def prepare_data(self):
+        # download,
+        # only called on 1 GPU/TPU in distributed
+        pass
+
+    def __len__(self):
+        return len(self.glacier_train)
+
+    def setup(self, stage=None):
+        # process and split here
+        # make assignments here (val/train/test split)
+        # called on every process in DDP
+        if stage == 'test' or stage is None:
+            self.transforms = self._default_transforms() if self.test_transforms is None else self.test_transforms
+            self.glacier_test = SSLDataset(parent_dir=self.parent_dir, transform=self.transforms, mode=stage)
+        if stage == 'fit' or stage is None:
+            self.transforms = self._default_transforms() if self.train_transforms is None else self.train_transforms
+            self.glacier_train = SSLDataset(parent_dir=self.parent_dir, transform=self.transforms, mode='train')
+
+            self.transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms
+            self.glacier_val = SSLDataset(parent_dir=self.parent_dir, transform=self.transforms, mode='val')
+
+    def train_dataloader(self) -> DataLoader:
+        return DataLoader(self.glacier_train, batch_size=self.batch_size, num_workers=6, pin_memory=True,
+                          drop_last=True)
+
+    def val_dataloader(self) -> DataLoader:
+        return DataLoader(self.glacier_val, batch_size=self.batch_size, num_workers=6, pin_memory=True, drop_last=True)
+
+    def test_dataloader(self) -> DataLoader:
+        return DataLoader(self.glacier_test, batch_size=1, num_workers=6, pin_memory=True,
+                          drop_last=True)  # TODO self.batch_size
+
+    def _default_transforms(self) -> Callable:
+        #return transform_lib.Compose([transform_lib.ToTensor()])
+        return our_transforms.Compose([our_transforms.ToTensorZones()])
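+
+# Usage sketch (illustrative; assumes an `args` namespace providing `agumentations`,
+# `size_crops` and `nmb_crops`, as read in __init__ above):
+#   dm = SSLDataModule(batch_size=32, parent_dir="/path/to/dataset", args=args)
+#   dm.setup("fit")
+#   train_loader = dm.train_dataloader()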
diff --git a/SSLGlacier/modules/semi_module.py b/SSLGlacier/modules/semi_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..9798a3620f90f6dc9e849eb08eecafe45f256c24
--- /dev/null
+++ b/SSLGlacier/modules/semi_module.py
@@ -0,0 +1,29 @@
+from pytorch_lightning import LightningModule
+from typing import Any
+class Semi_(LightningModule):
+    def __init__(self):
+        super().__init__()
+
+    def setup(self, stage: str) -> None:
+        pass
+
+    def init_model(self):
+        pass
+
+    def forward(self, *args: Any, **kwargs: Any) -> Any:
+        pass
+
+    def on_train_epoch_start(self) -> None:
+        pass
+
+    def shared_step(self, batch):
+        inputs, y = batch
+        inputs = inputs[:, -1]
+        embedding = self.model(inputs)
+        #### Where should the noise be added?
+        #### Assume embedding[b][i] belongs to batch b, view i.
+        #### If we have two views, one with noise and one without, then
+        #### select some portion of each view belonging to different classes:
+        #### pos_set = same classes, neg_set = different classes.
+
diff --git a/SSLGlacier/modules/swav_module.py b/SSLGlacier/modules/swav_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b57a4ede5265846c6a4aeb4b885d3d31169c206
--- /dev/null
+++ b/SSLGlacier/modules/swav_module.py
@@ -0,0 +1,562 @@
+"""Adapted from official swav implementation: https://github.com/facebookresearch/swav."""
+import os
+from argparse import ArgumentParser
+
+import torch
+from pytorch_lightning import LightningModule, Trainer
+from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
+from torch import nn
+
+#from pl_bolts.models.self_supervised.swav.loss import SWAVLoss
+from SSLGlacier.models.losses import SWAVLoss
+from SSLGlacier.models.RESNET import resnet18, resnet50
+#from Base_Code.models.ParentUNet import UNet
+from SSLGlacier.models.UNET import UNet #The main network
+from pl_bolts.optimizers.lars import LARS
+from pl_bolts.optimizers.lr_scheduler import linear_warmup_decay
+
+##########Importing files for hook############
+from SSLGlacier.models.hook.Swin_Transformer_Wrapper import SwinUnet
+######Temporary config file for hooknet#######
+from utils.config import get_config
+
+class SwAV(LightningModule):
+    def __init__(
+        self,
+        gpus: int,
+        num_samples: int,
+        batch_size: int,
+        dataset: str,
+        num_nodes: int = 1,
+        arch: str = "resnet50",
+        hidden_mlp: int = 2048,
+        feat_dim: int = 128,
+        warmup_epochs: int = 10,
+        max_epochs: int = 100,
+        num_prototypes: int = 3000,
+        freeze_prototypes_epochs: int = 1,
+        temperature: float = 0.1,
+        sinkhorn_iterations: int = 3,
+        queue_length: int = 0,  # must be divisible by total batch-size
+        queue_path: str = "queue",
+        epoch_queue_starts: int = 15,
+        crops_for_assign: tuple = (0, 1),
+        num_crops: tuple = (2, 6),
+        num_augs: int= 2,
+        first_conv: bool = True,
+        maxpool1: bool = True,
+        optimizer: str = "adam",
+        exclude_bn_bias: bool = False,
+        start_lr: float = 0.0,
+        learning_rate: float = 1e-3,
+        final_lr: float = 0.0,
+        weight_decay: float = 1e-6,
+        epsilon: float = 0.05,
+        just_aug_for_same_assign_views: bool = False,
+        swin_hparams = {}
+    ) -> None:
+        """
+        Args:
+            gpus: number of gpus per node used in training, passed to SwAV module
+                to manage the queue and select distributed sinkhorn
+            num_nodes: number of nodes to train on
+            num_samples: number of image samples used for training
+            batch_size: batch size per GPU in ddp
+            dataset: dataset being used for train/val
+            arch: encoder architecture used for pre-training
+            hidden_mlp: hidden layer of non-linear projection head, set to 0
+                to use a linear projection head
+            feat_dim: output dim of the projection head
+            warmup_epochs: apply linear warmup for this many epochs
+            max_epochs: epoch count for pre-training
+            num_prototypes: count of prototype vectors
+            freeze_prototypes_epochs: epoch till which gradients of prototype layer
+                are frozen
+            temperature: loss temperature
+            sinkhorn_iterations: iterations for sinkhorn normalization
+            queue_length: set queue when batch size is small,
+                must be divisible by total batch-size (i.e. total_gpus * batch_size),
+                set to 0 to remove the queue
+            queue_path: folder within the logs directory
+            epoch_queue_starts: start using the queue after this epoch
+            crops_for_assign: list of crop ids for computing assignment
+            num_crops: number of global and local crops, ex: [2, 6]
+            first_conv: keep first conv same as the original resnet architecture,
+                if set to false it is replaced by a kernel 3, stride 1 conv (cifar-10)
+            maxpool1: keep first maxpool layer same as the original resnet architecture,
+                if set to false, first maxpool is turned off (cifar10, maybe stl10)
+            optimizer: optimizer to use
+            exclude_bn_bias: exclude batchnorm and bias layers from weight decay in optimizers
+            start_lr: starting lr for linear warmup
+            learning_rate: learning rate
+            final_lr: float = final learning rate for cosine weight decay
+            weight_decay: weight decay for optimizer
+            epsilon: epsilon val for swav assignments
+        """
+        super().__init__()
+        self.save_hyperparameters()
+
+        self.gpus = gpus
+        self.num_nodes = num_nodes
+        self.arch = arch
+        self.dataset = dataset
+        self.num_samples = num_samples
+        self.batch_size = batch_size
+
+        self.hidden_mlp = hidden_mlp
+        self.feat_dim = feat_dim
+        self.num_prototypes = num_prototypes
+        self.freeze_prototypes_epochs = freeze_prototypes_epochs
+        self.sinkhorn_iterations = sinkhorn_iterations
+
+        self.queue_length = queue_length
+        self.queue_path = queue_path
+        self.epoch_queue_starts = epoch_queue_starts
+        self.crops_for_assign = crops_for_assign
+        self.num_crops = num_crops
+        self.num_augs = num_augs
+
+        self.first_conv = first_conv
+        self.maxpool1 = maxpool1
+
+        self.optim = optimizer
+        self.exclude_bn_bias = exclude_bn_bias
+        self.weight_decay = weight_decay
+        self.epsilon = epsilon
+        self.temperature = temperature
+
+        self.start_lr = start_lr
+        self.final_lr = final_lr
+        self.learning_rate = learning_rate
+        self.warmup_epochs = warmup_epochs
+        self.max_epochs = max_epochs
+        self.just_aug_for_same_assign_views = just_aug_for_same_assign_views
+
+        self.model = self.init_model()
+        self.criterion = SWAVLoss(
+            gpus=self.gpus,
+            num_nodes=self.num_nodes,
+            temperature=self.temperature,
+            crops_for_assign=self.crops_for_assign,
+            num_crops=self.num_crops,
+            num_augs=self.num_augs,
+            sinkhorn_iterations=self.sinkhorn_iterations,
+            epsilon=self.epsilon,
+            just_aug_for_same_assign_views= self.just_aug_for_same_assign_views
+        )
+        self.use_the_queue = None
+        # compute iters per epoch
+        global_batch_size = self.num_nodes * self.gpus * self.batch_size if self.gpus > 0 else self.batch_size
+        self.train_iters_per_epoch = self.num_samples // global_batch_size
+        self.queue = None
+
+        #### For hook ####
+        self.swin_hparams = swin_hparams
+
+    def setup(self, stage):
+        if self.queue_length > 0:
+            queue_folder = os.path.join(self.logger.log_dir, self.queue_path)
+            if not os.path.exists(queue_folder):
+                os.makedirs(queue_folder)
+
+            self.queue_path = os.path.join(queue_folder, "queue" + str(self.trainer.global_rank) + ".pth")
+
+            if os.path.isfile(self.queue_path):
+                self.queue = torch.load(self.queue_path)["queue"]
+
+    def init_model(self):
+        if self.arch == "resnet18":
+            backbone = resnet18(normalize=True,
+                                hidden_mlp=self.hidden_mlp,
+                                output_dim=self.feat_dim,
+                                num_prototypes=self.num_prototypes,
+                                first_conv=self.first_conv,
+                                maxpool1=self.maxpool1)
+
+        elif self.arch == "resnet50":
+            backbone = resnet50(normalize=True,
+                                hidden_mlp=self.hidden_mlp,
+                                output_dim=self.feat_dim,
+                                num_prototypes=self.num_prototypes,
+                                first_conv=self.first_conv,
+                                maxpool1=self.maxpool1)
+        elif self.arch == "hook":
+            backbone = SwinUnet(
+                                img_size=224,
+                                num_classes=5,
+                                normalize=True,
+                                hidden_mlp=self.hidden_mlp,
+                                output_dim=self.feat_dim,
+                                num_prototypes=self.num_prototypes,
+                                first_conv=self.first_conv,
+                                maxpool1=self.maxpool1,######till here for basenet
+                                image_size= self.swin_hparams.image_size,####from here for swin itself
+                                swin_patch_size = self.swin_hparams.swin_patch_size,
+                                swin_in_chans=self.swin_hparams.swin_in_chans,
+                                swin_embed_dim=self.swin_hparams.swin_embed_dim,
+                                swin_depths=self.swin_hparams.swin_depths,
+                                swin_num_heads=self.swin_hparams.swin_num_heads,
+                                swin_window_size=self.swin_hparams.swin_window_size,
+                                swin_mlp_ratio=self.swin_hparams.swin_mlp_ratio,
+                                swin_QKV_BIAS=self.swin_hparams.swin_QKV_BIAS,
+                                swin_QK_SCALE=self.swin_hparams.swin_QK_SCALE,
+                                drop_rate=self.drop_rate,
+                                drop_path_rate=self.drop_path_rate,
+                                swin_ape=self.swin_hparams.swin_ape,
+                                swin_patch_norm=self.swin_hparams.swin_path_norm
+                                )
+
+        elif self.arch == "Unet" or "unet":
+            backbone = UNet(num_classes=5,
+                            input_channels=1,
+                            num_layers=5,
+                            features_start=64,
+                            bilinear=False,
+                            normalize=True,
+                            hidden_mlp=self.hidden_mlp,
+                            output_dim=self.feat_dim,
+                            num_prototypes=self.num_prototypes,
+                            first_conv=self.first_conv,
+                            maxpool1=self.maxpool1
+                            )
+        else:
+            raise ValueError(f"{self.arch} model is not defined")
+
+        return backbone
+
+    def forward(self, x):
+        # pass a single batch through the backbone
+        return self.model.forward_backbone(x)
+
+    def on_train_epoch_start(self):
+        if self.queue_length > 0:
+            if self.trainer.current_epoch >= self.epoch_queue_starts and self.queue is None:
+                self.queue = torch.zeros(
+                    len(self.crops_for_assign),
+                    self.queue_length // self.gpus,  # change to nodes * gpus once multi-node
+                    self.feat_dim,
+                )
+
+            if self.queue is not None:
+                self.queue = self.queue.to(self.device)
+
+        self.use_the_queue = False
+
+    def on_train_epoch_end(self) -> None:
+        if self.queue is not None:
+            torch.save({"queue": self.queue}, self.queue_path)
+
+    def on_after_backward(self):
+        if self.current_epoch < self.freeze_prototypes_epochs:
+            for name, p in self.model.named_parameters():
+                if "prototypes" in name:
+                    p.grad = None
+
+    def shared_step(self, batch):
+        if self.dataset == "stl10":
+            unlabeled_batch = batch[0]
+            batch = unlabeled_batch
+
+        inputs, y = batch
+        inputs = inputs[:-1]  # remove online train/eval transforms at this point
+
+        # 1. normalize the prototypes
+        with torch.no_grad():
+            w = self.model.prototypes.weight.data.clone()
+            w = nn.functional.normalize(w, dim=1, p=2)
+            self.model.prototypes.weight.copy_(w)
+
+        # 2. multi-res forward passes
+        embedding, output = self.model(inputs)
+        embedding = embedding.detach()
+        bs = inputs[0].size(0)
+
+        # SWAV loss computation
+        loss, queue, use_queue = self.criterion(
+            output=output,
+            embedding=embedding,
+            prototype_weights=self.model.prototypes.weight,
+            batch_size=bs,
+            queue=self.queue,
+            use_queue=self.use_the_queue,
+        )
+        self.queue = queue
+        self.use_the_queue = use_queue
+        return loss
+
+    def training_step(self, batch, batch_idx):
+        loss = self.shared_step(batch)
+
+        self.log("train_loss", loss, on_step=True, on_epoch=False)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        loss = self.shared_step(batch)
+
+        self.log("val_loss", loss, on_step=False, on_epoch=True)
+        return loss
+
+    def exclude_from_wt_decay(self, named_params, weight_decay, skip_list=("bias", "bn")):
+        params = []
+        excluded_params = []
+
+        for name, param in named_params:
+            if not param.requires_grad:
+                continue
+            if any(layer_name in name for layer_name in skip_list):
+                excluded_params.append(param)
+            else:
+                params.append(param)
+
+        return [{"params": params, "weight_decay": weight_decay}, {"params": excluded_params, "weight_decay": 0.0}]
+
+    def configure_optimizers(self):
+        if self.exclude_bn_bias:
+            params = self.exclude_from_wt_decay(self.named_parameters(), weight_decay=self.weight_decay)
+        else:
+            params = self.parameters()
+
+        if self.optim == "lars":
+            optimizer = LARS(
+                params,
+                lr=self.learning_rate,
+                momentum=0.9,
+                weight_decay=self.weight_decay,
+                trust_coefficient=0.001,
+            )
+        elif self.optim == "adam":
+            optimizer = torch.optim.Adam(params, lr=self.learning_rate, weight_decay=self.weight_decay)
+
+        warmup_steps = self.train_iters_per_epoch * self.warmup_epochs
+        total_steps = self.train_iters_per_epoch * self.max_epochs
+
+        scheduler = {
+            "scheduler": torch.optim.lr_scheduler.LambdaLR(
+                optimizer,
+                linear_warmup_decay(warmup_steps, total_steps, cosine=True),
+            ),
+            "interval": "step",
+            "frequency": 1,
+        }
+
+        return [optimizer], [scheduler]
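+    # Example (sketch, illustrative numbers): with num_samples=10_000, one GPU and
+    # batch_size=128, train_iters_per_epoch = 10_000 // 128 = 78; with warmup_epochs=10
+    # and max_epochs=100 the learning rate then warms up linearly for 780 optimizer
+    # steps and follows a cosine decay over 7_800 total steps.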
+
+    @staticmethod
+    def add_model_specific_args(parent_parser):
+        parser = ArgumentParser(parents=[parent_parser], add_help=False)
+
+        # model params
+        parser.add_argument("--arch", default="resnet50", type=str, help="convnet architecture")
+        # specify flags to store false
+        parser.add_argument("--first_conv", action="store_false")
+        parser.add_argument("--maxpool1", action="store_false")
+        parser.add_argument("--hidden_mlp", default=2048, type=int, help="hidden layer dimension in projection head")
+        parser.add_argument("--feat_dim", default=128, type=int, help="feature dimension")
+        parser.add_argument("--online_ft", action="store_true")
+        parser.add_argument("--fp32", action="store_true")
+
+        # transform params
+        parser.add_argument("--gaussian_blur", action="store_true", help="add gaussian blur")
+        parser.add_argument("--jitter_strength", type=float, default=1.0, help="jitter strength")
+        parser.add_argument("--dataset", type=str, default="stl10", help="stl10, cifar10")
+        parser.add_argument("--data_dir", type=str, default=".", help="path to download data")
+        parser.add_argument("--queue_path", type=str, default="queue", help="path for queue")
+
+        parser.add_argument(
+            "--num_crops", type=int, default=[2, 4], nargs="+", help="list of number of crops (example: [2, 6])"
+        )
+        parser.add_argument(
+            "--size_crops", type=int, default=[96, 36], nargs="+", help="crops resolutions (example: [224, 96])"
+        )
+        parser.add_argument(
+            "--min_scale_crops",
+            type=float,
+            default=[0.33, 0.10],
+            nargs="+",
+            help="argument in RandomResizedCrop (example: [0.14, 0.05])",
+        )
+        parser.add_argument(
+            "--max_scale_crops",
+            type=float,
+            default=[1, 0.33],
+            nargs="+",
+            help="argument in RandomResizedCrop (example: [1., 0.14])",
+        )
+
+        # training params
+        parser.add_argument("--fast_dev_run", default=1, type=int)
+        parser.add_argument("--num_nodes", default=1, type=int, help="number of nodes for training")
+        parser.add_argument("--gpus", default=1, type=int, help="number of gpus to train on")
+        parser.add_argument("--num_workers", default=8, type=int, help="num of workers per GPU")
+        parser.add_argument("--optimizer", default="adam", type=str, help="choose between adam/lars")
+        parser.add_argument("--exclude_bn_bias", action="store_true", help="exclude bn/bias from weight decay")
+        parser.add_argument("--max_epochs", default=100, type=int, help="number of total epochs to run")
+        parser.add_argument("--max_steps", default=-1, type=int, help="max steps")
+        parser.add_argument("--warmup_epochs", default=10, type=int, help="number of warmup epochs")
+        parser.add_argument("--batch_size", default=128, type=int, help="batch size per gpu")
+
+        parser.add_argument("--weight_decay", default=1e-6, type=float, help="weight decay")
+        parser.add_argument("--learning_rate", default=1e-3, type=float, help="base learning rate")
+        parser.add_argument("--start_lr", default=0, type=float, help="initial warmup learning rate")
+        parser.add_argument("--final_lr", type=float, default=1e-6, help="final learning rate")
+
+        # swav params
+        parser.add_argument(
+            "--crops_for_assign",
+            type=int,
+            nargs="+",
+            default=[0, 1],
+            help="list of crops id used for computing assignments",
+        )
+        parser.add_argument("--temperature", default=0.1, type=float, help="temperature parameter in training loss")
+        parser.add_argument(
+            "--epsilon", default=0.05, type=float, help="regularization parameter for Sinkhorn-Knopp algorithm"
+        )
+        parser.add_argument(
+            "--sinkhorn_iterations", default=3, type=int, help="number of iterations in Sinkhorn-Knopp algorithm"
+        )
+        parser.add_argument("--num_prototypes", default=512, type=int, help="number of prototypes")
+        parser.add_argument(
+            "--queue_length",
+            type=int,
+            default=0,
+            help="length of the queue (0 for no queue); must be divisible by total batch size",
+        )
+        parser.add_argument(
+            "--epoch_queue_starts", type=int, default=15, help="from this epoch, we start using a queue"
+        )
+        parser.add_argument(
+            "--freeze_prototypes_epochs",
+            default=1,
+            type=int,
+            help="freeze the prototypes during this many epochs from the start",
+        )
+
+        return parser
+
+
+def cli_main():
+    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
+    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
+    from pl_bolts.transforms.dataset_normalizations import (
+        cifar10_normalization,
+        imagenet_normalization,
+        stl10_normalization,
+    )
+    from pl_bolts.transforms.self_supervised.swav_transforms import SwAVEvalDataTransform, SwAVTrainDataTransform
+
+    parser = ArgumentParser()
+
+    # model args
+    parser = SwAV.add_model_specific_args(parser)
+    args = parser.parse_args()
+
+    if args.dataset == "stl10":
+        dm = STL10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)
+
+        dm.train_dataloader = dm.train_dataloader_mixed
+        dm.val_dataloader = dm.val_dataloader_mixed
+        args.num_samples = dm.num_unlabeled_samples
+
+        args.maxpool1 = False
+
+        normalization = stl10_normalization()
+    elif args.dataset == "cifar10":
+        args.batch_size = 2
+        args.num_workers = 0
+
+        dm = CIFAR10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)
+
+        args.num_samples = dm.num_samples
+
+        args.maxpool1 = False
+        args.first_conv = False
+
+        normalization = cifar10_normalization()
+
+        # cifar10 specific params
+        args.size_crops = [32, 16]
+        args.num_crops = [2, 1]
+        args.gaussian_blur = False
+    elif args.dataset == "imagenet":
+        args.maxpool1 = True
+        args.first_conv = True
+        normalization = imagenet_normalization()
+
+        args.size_crops = [224, 96]
+        args.num_crops = [2, 6]
+        args.min_scale_crops = [0.14, 0.05]
+        args.max_scale_crops = [1.0, 0.14]
+        args.gaussian_blur = True
+        args.jitter_strength = 1.0
+
+        args.batch_size = 64
+        args.num_nodes = 8
+        args.gpus = 8  # per-node
+        args.max_epochs = 800
+
+        args.optimizer = "lars"
+        args.learning_rate = 4.8
+        args.final_lr = 0.0048
+        args.start_lr = 0.3
+
+        args.num_prototypes = 3000
+        args.online_ft = True
+
+        dm = ImagenetDataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)
+
+        args.num_samples = dm.num_samples
+        args.input_height = dm.dims[-1]
+    else:
+        raise NotImplementedError("other datasets have not been implemented till now")
+
+    dm.train_transforms = SwAVTrainDataTransform(
+        normalize=normalization,
+        size_crops=args.size_crops,
+        num_crops=args.num_crops,
+        min_scale_crops=args.min_scale_crops,
+        max_scale_crops=args.max_scale_crops,
+        gaussian_blur=args.gaussian_blur,
+        jitter_strength=args.jitter_strength,
+    )
+
+    dm.val_transforms = SwAVEvalDataTransform(
+        normalize=normalization,
+        size_crops=args.size_crops,
+        num_crops=args.num_crops,
+        min_scale_crops=args.min_scale_crops,
+        max_scale_crops=args.max_scale_crops,
+        gaussian_blur=args.gaussian_blur,
+        jitter_strength=args.jitter_strength,
+    )
+
+    # swav model init
+    model = SwAV(**args.__dict__)
+
+    online_evaluator = None
+    if args.online_ft:
+        # online eval
+        online_evaluator = SSLOnlineEvaluator(
+            drop_p=0.0,
+            hidden_dim=None,
+            z_dim=args.hidden_mlp,
+            num_classes=dm.num_classes,
+            dataset=args.dataset,
+        )
+
+    lr_monitor = LearningRateMonitor(logging_interval="step")
+    model_checkpoint = ModelCheckpoint(save_last=True, save_top_k=1, monitor="val_loss")
+    callbacks = [model_checkpoint, online_evaluator] if args.online_ft else [model_checkpoint]
+    callbacks.append(lr_monitor)
+
+    trainer = Trainer(
+        max_epochs=args.max_epochs,
+        max_steps=None if args.max_steps == -1 else args.max_steps,
+        gpus=args.gpus,
+        num_nodes=args.num_nodes,
+        accelerator="ddp" if args.gpus > 1 else None,
+        sync_batchnorm=args.gpus > 1,
+        precision=32 if args.fp32 else 16,
+        callbacks=callbacks,
+        fast_dev_run=args.fast_dev_run,
+    )
+
+    trainer.fit(model, datamodule=dm)
+
+
+if __name__ == "__main__":
+    cli_main()
\ No newline at end of file
diff --git a/SSLGlacier/modules/transformers_.py b/SSLGlacier/modules/transformers_.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8290bb69b2c33649e92a9dd9ae7797595b94e85
--- /dev/null
+++ b/SSLGlacier/modules/transformers_.py
@@ -0,0 +1,248 @@
+from pl_bolts.models.self_supervised import swav
+from torchvision import transforms  # need it for validation and fine tuning
+from . import agumentations_ as ag
+from typing import Tuple, List, Dict
+from torch import Tensor
+
+
+
+class OurTrainTransformer():
+    def __init__(self,
+                 size_crops: Tuple[int, ...] = (294, 94),
+                 nmb_crops: Tuple[int, ...] = (2, 4),
+                 augs: Dict = {'SaltPepper': 1},
+                 use_hook_former: bool = True) -> None:
+        # TODO: configure this so that each augmentation method contributes at least one sample
+        self.size_crops = size_crops
+        self.num_crops = nmb_crops
+        augmented_imgs = []
+        self.use_hook_former = use_hook_former
+        self.transform = []
+        if augs.get('OrigPixelValues'):
+            augmented_imgs.append(ag.DoNothing())
+        if augs.get('RGaussianB'):
+            augmented_imgs.append(ag.PILRandomGaussianBlur(radius_min=0.1, radius_max=2.))
+        if augs.get('GaussiN'):
+            augmented_imgs.append(ag.GaussNoise(mean=0, var=10000))
+        if augs.get('SaltPepper'):
+            augmented_imgs.append(ag.SaltPepperNoise(salt_or_pepper=0.2))
+        if augs.get('ZeroN'):
+            augmented_imgs.append(ag.SetZeroNoise())
+        # Nora's augmentation methods
+        if augs.get('flip'):
+            # TODO decide how to also transform the mask when it is needed
+            augmented_imgs.append(ag.RandomHorizontalFlip())
+        if augs.get('rotate'):
+            augmented_imgs.append(ag.Rotate())
+        if augs.get('bright'):
+            augmented_imgs.append(ag.Bright())
+        if augs.get('wrap'):
+            augmented_imgs.append(ag.MyWrap())
+        if augs.get('noise'):
+            augmented_imgs.append(ag.Noise())
+
+        if augs.get('normalize') is not None:
+            # note: `normalize` must accept the (image, target) pair used by ag.Compose
+            self.final_transform = ag.Compose([ag.ToTensorZones(), augs['normalize']])
+        else:
+            self.final_transform = ag.ToTensorZones()
+
+        if self.use_hook_former:
+            self.hook_former_transformer(augmented_imgs)
+        else:
+            self.other_networks_transformers(augmented_imgs)
+
+
+
+    def __call__(self, image, mask):
+        """
+        import matplotlib.pyplot as plt
+        img1 =transformed_imgs[0].detach().cpu().numpy().squeeze()
+        plt.imshow(img1)
+        plt.show()
+        """
+        transformed_img_mask = []
+        for transform in self.transform:
+            if isinstance(transform.__getitem__(0), ag.RandomCropper):
+                transformed_img_mask.append(transform(image, mask))
+                self.i = transform.__getitem__(0).left
+                self.j = transform.__getitem__(0).right
+                self.w = transform.__getitem__(0).width
+                self.h = transform.__getitem__(0).height
+            elif isinstance(transform.__getitem__(0), ag.Cropper):
+                transform.__getitem__(0).left = self.i
+                transform.__getitem__(0).right = self.j
+                transform.__getitem__(0).width = self.w
+                transform.__getitem__(0).height = self.h
+                transformed_img_mask.append(transform(image, mask))
+            else:
+                transformed_img_mask.append(transform(image,mask))
+        transformed_imgs = [item[0] for item in transformed_img_mask]
+        transformed_masks = [item[1] for item in transformed_img_mask]
+
+
+        # fig, axs = plt.subplots(nrows=6, ncols=4, figsize=(15, 12))
+        # fig.suptitle("Patches used in training", fontsize=18, y=.95)
+        #
+        # for i, image_id in enumerate(transformed_imgs):
+        #     img1 = transformed_imgs[i*3].detach().cpu().numpy().squeeze()
+        #     mask1 = transformed_masks[i*3].detach().cpu().numpy().squeeze()
+        #     axs[0, i].imshow(img1)
+        #     axs[0, i].imshow(mask1)
+        #     axs[0, i].imshow(transformed_imgs[i*3 +1].detach().cpu().numpy().squeeze())
+        #     axs[0, i].imshow(transformed_imgs[i*3+2].detach().cpu().numpy().squeeze())
+        # plt.savefig('High.png')
+        return transformed_imgs, transformed_masks
+
+    def hook_former_transformer(self, augmented_imgs):
+        transform = []
+        for i in range(len(self.size_crops)):
+            i_transform = []
+            for ith_aug, aug in enumerate(augmented_imgs):
+                # For the first time we crop randomly, then save coordinates for further transformation
+                if ith_aug == 0:
+                    global_crop = ag.RandomCropper(self.size_crops[i])
+                    i_transform.extend(
+                        [ag.Compose([global_crop] + [aug] + [ag.ToTensorZones()])]
+                    )
+                else:  # Local view: centre crop at a quarter of the global size, then augment
+                    local_crop = ag.CenterCrop(self.size_crops[i] // 4)
+                    i_transform.extend(
+                        [ag.Compose([local_crop] + [aug] + [ag.ToTensorZones()])])
+            transform += i_transform * self.num_crops[i]
+        self.transform =  transform
+        # add online train transform of the size of the global view
+        online_train_transform = ag.Compose(
+            [ag.RandomCropper(self.size_crops[i]), ag.RandomHorizontalFlip(), self.final_transform]
+        )
+
+        self.transform.append(online_train_transform)
+        return transform
+
+    def other_networks_transformers(self, augmented_imgs):
+        transform = []
+        for i in range(len(self.size_crops)):
+            i_transform = []
+            for ith_aug, aug in enumerate(augmented_imgs):
+                # For the first time we crop randomly, then save coordinates for further transformation
+                if ith_aug == 0:
+                    random_crop = ag.RandomCropper(self.size_crops[i])
+                    i_transform.extend(
+                        [ag.Compose([random_crop] + [aug] + [ag.ToTensorZones()])]
+                    )
+                else:  # Crop same area, and do augmentation
+                    fixed_crop = ag.Cropper(random_crop.left, random_crop.right, random_crop.height, random_crop.width)
+                    i_transform.extend(
+                        [ag.Compose([fixed_crop] + [aug] + [ag.ToTensorZones()])])
+            transform += i_transform * self.num_crops[i]
+
+        self.transform = transform
+        # add online train transform of the size of the global view
+        # TODO: clarify why this online train transform is needed here
+        online_train_transform = ag.Compose(
+            [ag.RandomCropper(self.size_crops[i]), ag.RandomHorizontalFlip(), self.final_transform]
+        )
+
+        self.transform.append(online_train_transform)
+        return transform
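+
+# Usage sketch (illustrative): the keys of `augs` select which augmentations are
+# built; keys missing from `augs` are treated as disabled.
+#   transform = OurTrainTransformer(size_crops=(294, 94), nmb_crops=(2, 4),
+#                                   augs={'SaltPepper': 1, 'flip': 1},
+#                                   use_hook_former=False)
+#   views, view_masks = transform(pil_image, pil_mask)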
+
+class OurEvalTransformer(OurTrainTransformer):
+    def __init__(self,
+                 size_crops: Tuple[int, ...] = (294, 294),
+                 nmb_crops: Tuple[int, ...] = (2, 4),
+                 augs: Dict = {'SaltPepper': 1},
+                 use_hook_former: bool = True) -> None:
+        super().__init__(size_crops=size_crops,
+                         nmb_crops=nmb_crops,
+                         augs=augs,
+                         use_hook_former=use_hook_former)
+
+        input_height = self.size_crops[0]  # get global view crop
+        test_transform = ag.Compose(
+            [
+                ag.Resize(int(input_height + 0.1 * input_height)),
+                ag.CenterCrop(input_height),
+                self.final_transform,
+            ]
+        )
+
+        # replace the last transform with the eval transform in self.transform
+        self.transform[-1] = test_transform
+
+class OurFinetuneTransform:
+    def __init__(self,
+                 input_height: int = 224,
+                 normalize=None,
+                 eval_transform: bool = False) -> None:
+
+        if normalize is None:
+            final_transform = transforms.ToTensor()
+        else:
+            final_transform = transforms.Compose([transforms.ToTensor(), normalize])
+
+        # Note: this mixes torchvision transforms (single image) with the (image, mask)
+        # transforms from agumentations_; TODO unify the two interfaces.
+        if not eval_transform:  # TODO think about it
+            data_transforms = [
+                transforms.RandomResizedCrop(size=input_height),
+                ag.RandomHorizontalFlip(),
+                ag.SaltPepperNoise(salt_or_pepper=0.5),
+                transforms.RandomGrayscale(p=0.2),
+            ]
+        else:
+            data_transforms = [
+                transforms.Resize(int(input_height + 0.1 * input_height)),
+                transforms.CenterCrop(input_height),
+            ]
+
+        data_transforms.append(final_transform)
+        self.transform = ag.Compose(data_transforms)
+
+    def __call__(self, img: Tensor)->Tensor:
+        return self.transform(img)
+
+# TODO use it in the code with a flag
+class SwapDefaultTransformer():
+    def __init__(
+            self,
+            size_crops: Tuple[int, ...] = (96, 36),
+            nmb_crops: Tuple[int, ...] = (2, 4),
+            min_scale_crops: Tuple[float, ...] = (0.33, 0.10),
+            max_scale_crops: Tuple[float, ...] = (1, 0.33),
+            gaussian_blur: bool = True,
+            jitter_strength: float = 1.0,
+    ) -> None:
+
+        self.size_crops = size_crops
+        self.num_crops = nmb_crops
+        self.min_scale_crops = min_scale_crops
+        self.max_scale_crops = max_scale_crops
+        self.gaussian_blur = gaussian_blur
+        self.jitter_strength = jitter_strength
+
+    def get(self, mode):
+        if mode == 'train':
+            return swav.SwAVTrainDataTransform(
+                #normalize=self.normalization(),
+                size_crops=self.size_crops,
+                num_crops=self.num_crops,
+                min_scale_crops=self.min_scale_crops,
+                max_scale_crops=self.max_scale_crops,
+                gaussian_blur=self.gaussian_blur,
+                jitter_strength=self.jitter_strength
+            )
+        elif mode == 'val':
+            return swav.SwAVEvalDataTransform(
+                #normalize=self.normalization(),
+                size_crops=self.size_crops,
+                num_crops=self.num_crops,
+                min_scale_crops=self.min_scale_crops,
+                max_scale_crops=self.max_scale_crops,
+                gaussian_blur=self.gaussian_blur,
+                jitter_strength=self.jitter_strength
+            )
+
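+# Usage sketch (illustrative): unlike the transforms above, the default SwAV
+# transforms operate on a single PIL image and return multiple crops without masks.
+#   default_train_transform = SwapDefaultTransformer().get('train')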