diff --git a/SSLGlacier/processing/agumentations_.py b/SSLGlacier/processing/agumentations_.py
deleted file mode 100644
index 05cb774513ca1749f65dc55320f93754ab81db8a..0000000000000000000000000000000000000000
--- a/SSLGlacier/processing/agumentations_.py
+++ /dev/null
@@ -1,411 +0,0 @@
-# PILRandomGaussianBlur and GetColorDistortion are adapted from the SwAV and
-# SimCLR implementations - https://arxiv.org/abs/2002.05709
-import random
-from logging import getLogger
-
-import torch.nn
-from PIL import ImageFilter
-import numpy as np
-import torchvision.transforms as transforms
-import torchvision, tormentor
-from torchvision.transforms import functional as F
-import torch
-import matplotlib.pyplot as plt
-
-logger = getLogger()
-
-
-class Compose(object):
-    """
-    Class for chaining transforms together
-    """
-
-    def __init__(self, transforms):
-        self.transforms = transforms
-
-    def __call__(self, image, target):
-        for t in self.transforms:
-            image, target = t(image, target)
-        return [image, target]
-
-    def __getitem__(self, item):
-        return self.transforms[item]
-
-class DoNothing(torch.nn.Module):
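-    """Identity transform: returns the image and mask unchanged."""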
-    def __init__(self):
-        super(DoNothing, self).__init__()
-
-    def __call__(self, img, mask):
-        return [img, mask]
-
-class Cropper(torch.nn.Module):
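-    # Note: torchvision's F.crop expects (top, left, height, width); the attributes named
-    # "left" and "right" below actually hold the (top, left) corner returned by RandomCrop.get_params.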
-    def __init__(self, i, j, h, w):
-        super(Cropper, self).__init__()
-        self.left = i
-        self.right = j
-        self.height = h
-        self.width = w
-
-    def __call__(self, img, mask):
-        cropped_img = F.crop(img, self.left, self.right, self.height, self.width)
-        cropped_mask = F.crop(mask, self.left, self.right, self.height, self.width)
-        return [cropped_img, cropped_mask]
-
-
-class RandomCropper(torch.nn.Module):
-    '''
-          Returns one patch at a time; if you need more patches from one image, call it in a loop.
-          Args:
-              img: PIL image (png or jpg) from which one patch is cropped randomly
-              mask: the image's mask; the same area is cropped from it
-          Returns:
-              A list of two elements: the cropped image patch and the cropped mask patch
-          '''
-
-    def __init__(self, size):
-        super(RandomCropper, self).__init__()
-        self.left = 0
-        self.right = 0
-        self.height = 0
-        self.width = 0
-        self.size = size
-
-    def forward(self, img, mask):
-        self.left, self.right, self.height, self.width = torchvision.transforms.RandomCrop.get_params(
-            img, output_size=(self.size, self.size))
-        cropped_img = F.crop(img, self.left, self.right, self.height, self.width)
-        cropped_mask = F.crop(mask, self.left, self.right, self.height, self.width)
-        return [cropped_img, cropped_mask]
-
-
-class PILRandomGaussianBlur(torch.nn.Module):
-    def __init__(self, radius_min=0.1, radius_max=2.):
-        """
-           Apply Gaussian Blur to the PIL image. Take the radius and probability of
-           application as the parameter.
-           This transform was used in SimCLR - https://arxiv.org/abs/2002.05709
-           """
-        super(PILRandomGaussianBlur, self).__init__()
-        self.radius_min = radius_min
-        self.radius_max = radius_max
-
-    def forward(self, img, mask):
-        return [img.filter(
-            ImageFilter.GaussianBlur(
-                radius=random.uniform(self.radius_min, self.radius_max)
-            )
-        ), mask]
-
-
-class RandomHorizontalFlip(torch.nn.Module):
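-    # Flips image and mask unconditionally; wrap it in e.g. transforms.RandomApply if the
-    # flip should only be applied with some probability.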
-    def __init__(self):
-        super(RandomHorizontalFlip, self).__init__()
-
-    def forward(self, img, mask):
-        image = torchvision.transforms.functional.hflip(img)
-        mask = torchvision.transforms.functional.hflip(mask)
-        return [image, mask]
-
-
-class GetColorDistortion(torch.nn.Module):
-    def __init__(self, s=1.0):
-        super(GetColorDistortion, self).__init__()
-        self.s = s
-
-    def forward(self, img, mask):
-        color_jitter = transforms.ColorJitter(0.8 * self.s, 0.8 * self.s, 0.8 * self.s, 0.2 * self.s)
-        rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
-        rnd_gray = transforms.RandomGrayscale(p=0.2)
-        color_distort = transforms.Compose([rnd_color_jitter, rnd_gray])
-        # color jitter is applied to the image only; the mask is returned unchanged
-        return [color_distort(img), mask]
-
-
-# TODO check this one
-class GaussNoise(torch.nn.Module):
-    def __init__(self, mean=0, var=10000):
-        super(GaussNoise, self).__init__()
-        self.mean = mean
-        self.var = var
-
-    def forward(self, img, mask):
-        width, height = img.size
-        sigma = self.var ** 0.5
-        gauss = np.random.normal(self.mean, sigma, (height, width))
-        noisy = np.array(img, dtype=np.float32) + gauss
-        return [noisy, mask]
-
-
-class SaltPepperNoise(torch.nn.Module):
-    def __init__(self, salt_or_pepper=0.5):
-        super(SaltPepperNoise, self).__init__()
-        self.salt_or_pepper = salt_or_pepper
-
-    def forward(self, img, mask):
-        noisy = np.array(img, dtype=np.float32)
-        # arrays are H x W (grayscale) or H x W x C
-        img_size = noisy.shape[0] * noisy.shape[1]
-        amount = .4
-        # Salt mode
-        num_salt = int(np.ceil(amount * img_size * self.salt_or_pepper))
-        rows = np.random.randint(0, noisy.shape[0], num_salt)
-        cols = np.random.randint(0, noisy.shape[1], num_salt)
-        noisy[rows, cols] = 1
-        # Pepper mode
-        num_pepper = int(np.ceil(amount * img_size * (1. - self.salt_or_pepper)))
-        rows = np.random.randint(0, noisy.shape[0], num_pepper)
-        cols = np.random.randint(0, noisy.shape[1], num_pepper)
-        noisy[rows, cols] = 0
-
-        # plt.imshow(img)
-        # plt.imshow(noisy)
-        # plt.show()
-        return [noisy, mask]
-
-
-class PoissionNoise(torch.nn.Module):
-    def __init__(self):
-        super(PoissionNoise, self).__init__()
-
-    def forward(self, img, mask):
-        img_arr = np.array(img, dtype=np.float32)
-        vals = len(np.unique(img_arr))
-        vals = 2 ** np.ceil(np.log2(vals))
-        noisy = np.random.poisson(img_arr * vals) / float(vals)
-        return [noisy, mask]
-
-
-# TODO this one
-class SpeckleNoise(torch.nn.Module):
-    def __init__(self):
-        super(SpeckleNoise, self).__init__()
-
-    def forward(self, img, mask):
-        width, height = img.size
-        gauss = np.random.randn(height, width)
-        img_arr = np.array(img, dtype=np.float32)
-        noisy = img_arr + img_arr * gauss
-        return [noisy, mask]
-
-
-class SetZeroNoise(torch.nn.Module):
-    def __init__(self):
-        super(SetZeroNoise, self).__init__()
-
-    def forward(self, img, mask):
-        img_arr = np.array(img, dtype=np.float32)
-        row, col = img_arr.shape[:2]
-        img_size = row * col
-        random_rows = np.random.randint(0, row, img_size)
-        random_cols = np.random.randint(0, col, img_size)
-        img_arr[random_rows, random_cols] = 0
-        return [img_arr, mask]
-
-
-####################################################################################
-
-########### Base_code augmentation suggestions  # TODO: not needed for now ##########
-class MyWrap(torch.nn.Module):
-    """
-    Random wrap augmentation taken from tormentor
-    """
-
-    def __init__(self):
-        super(MyWrap, self).__init__()
-
-    def forward(self, img, target):
-        # This augmentation acts like many simultaneous elastic transforms with Gaussian sigmas set at various harmonics
-        wrap_rand = tormentor.Wrap.override_distributions(roughness=tormentor.random.Uniform(value_range=(.1, .7)),
-                                                          intensity=tormentor.random.Uniform(value_range=(.0, 1.)))
-        wrap = wrap_rand()
-        image = wrap(img)
-        mask = wrap(target, is_mask=True)
-        return [image, mask]
-
-
-class Rotate(torch.nn.Module):
-    """
-    Random rotation augmentation
-    """
-
-    def __init__(self):
-        super(Rotate, self).__init__()
-
-    def forward(self, img, target):
-        choice = np.random.randint(0, 3)
-        angle = 90
-        if choice == 1:
-            angle = 180
-        elif choice == 2:
-            angle = 270
-        image = torchvision.transforms.functional.rotate(img, angle=angle)
-        mask = torchvision.transforms.functional.rotate(target, angle=angle)
-        return [image, mask.squeeze(0)]
-
-
-class Bright(torch.nn.Module):
-    """
-    Random brightness adjustment augmentations
-    """
-
-    def __init__(self,
-                 lower_band: float = -0.2,
-                 upper_band: float = 0.2):
-        super(Bright, self).__init__()
-        self.lower_band = lower_band
-        self.upper_band = upper_band
-
-    def forward(self, img, mask):
-        bright_rand = tormentor.Brightness.override_distributions(
-            brightness=tormentor.random.Uniform((self.lower_band, self.upper_band)))
-        # bright = bright_rand()
-        # image_transformed = img.clone()
-        image_transformed = bright_rand(np.asarray(img))
-        # set NA areas back to zero
-        image_transformed.seed[img == 0] = 0.0
-
-        return [image_transformed.seed, mask]
-
-
-class Noise(torch.nn.Module):
-    """
-    Random additive noise augmentation
-    """
-
-    def __init__(self):
-        super(Noise, self).__init__()
-
-    def forward(self, img, target):
-        # add noise; it is multiplicative Gaussian noise, so NA areas (zeros) stay zero
-        noise = torch.normal(mean=0, std=0.3, size=tuple(img.shape))
-        image = img + img * noise
-        image[image > 1.0] = 1.0
-        image[image < 0.0] = 0.0
-
-        return [image, target]
-
-
-class Resize(torch.nn.Module):
-    # Library code
-    """Resize the input image to the given size.
-    If the image is torch Tensor, it is expected
-    to have [..., H, W] shape, where ... means a maximum of two leading dimensions
-
-    .. warning::
-        The output image might be different depending on its type: when downsampling, the interpolation of PIL images
-        and tensors is slightly different, because PIL applies antialiasing. This may lead to significant differences
-        in the performance of a network. Therefore, it is preferable to train and serve a model with the same input
-        types. See also below the ``antialias`` parameter, which can help making the output of PIL images and tensors
-        closer.
-
-    Args:
-        size (sequence or int): Desired output size. If size is a sequence like
-            (h, w), output size will be matched to this. If size is an int,
-            smaller edge of the image will be matched to this number.
-            i.e, if height > width, then image will be rescaled to
-            (size * height / width, size).
-
-            .. note::
-                In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
-        interpolation (InterpolationMode): Desired interpolation enum defined by
-            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
-            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
-            ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
-            The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
-        max_size (int, optional): The maximum allowed for the longer edge of
-            the resized image. If the longer edge of the image is greater
-            than ``max_size`` after being resized according to ``size``,
-            ``size`` will be overruled so that the longer edge is equal to
-            ``max_size``.
-            As a result, the smaller edge may be shorter than ``size``. This
-            is only supported if ``size`` is an int (or a sequence of length
-            1 in torchscript mode).
-        antialias (bool, optional): Whether to apply antialiasing.
-            It only affects **tensors** with bilinear or bicubic modes and it is
-            ignored otherwise: on PIL images, antialiasing is always applied on
-            bilinear or bicubic modes; on other modes (for PIL images and
-            tensors), antialiasing makes no sense and this parameter is ignored.
-            Possible values are:
-
-            - ``True``: will apply antialiasing for bilinear or bicubic modes.
-              Other mode aren't affected. This is probably what you want to use.
-            - ``False``: will not apply antialiasing for tensors on any mode. PIL
-              images are still antialiased on bilinear or bicubic modes, because
-              PIL doesn't support no antialias.
-            - ``None``: equivalent to ``False`` for tensors and ``True`` for
-              PIL images. This value exists for legacy reasons and you probably
-              don't want to use it unless you really know what you are doing.
-
-            The current default is ``None`` **but will change to** ``True`` **in
-            v0.17** for the PIL and Tensor backends to be consistent.
-    """
-
-    def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR, max_size=None, antialias="warn"):
-        super().__init__()
-        self.size = size
-        self.max_size = max_size
-        self.interpolation = interpolation
-        self.antialias = antialias
-
-    def forward(self, img, mask):
-        """
-        Args:
-            img (PIL Image or Tensor): Image to be scaled.
-
-        Returns:
-            PIL Image or Tensor: Rescaled image.
-        """
-        return [F.resize(img, self.size, self.interpolation, self.max_size, self.antialias), mask]
-
-
-class CenterCrop(torch.nn.Module):
-    """Crops the given image at the center.
-    If the image is torch Tensor, it is expected
-    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
-    If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
-
-    Args:
-        size (sequence or int): Desired output size of the crop. If size is an
-            int instead of sequence like (h, w), a square crop (size, size) is
-            made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
-    """
-
-    def __init__(self, size):
-        super().__init__()
-        self.size = size  # _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
-
-    def forward(self, img, mask):
-        """
-        Args:
-            img (PIL Image or Tensor): Image to be cropped.
-
-        Returns:
-            PIL Image or Tensor: Cropped image.
-        """
-        ### Adopted from HookFormer code  # TODO check it later
-        W, H = img.size
-
-        WW = (W // self.size) + 2
-        HH = (H // self.size) + 2
-
-        return [F.center_crop(img, (HH * self.size, WW * self.size)), F.center_crop(mask, (HH * self.size, WW * self.size))]
-
-
-class ToTensorZones():
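-    """Convert the image and zone mask to tensors and map the mask's grey values to class ids."""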
-    def __call__(self, image, target):
-        image = F.to_tensor(np.array(image).astype(np.float32))
-        target = torch.from_numpy(np.array(target).astype(np.float32))
-        # value for NA area=0, stone=64, glacier=127, ocean with ice melange=254
-        target[target == 0] = 0
-        target[target == 64] = 1
-        target[target == 127] = 2
-        target[target == 254] = 3
-        # class ids for NA area=0, stone=1, glacier=2, ocean with ice melange=3
-        return [image, target]
diff --git a/SSLGlacier/processing/datamodule_.py b/SSLGlacier/processing/datamodule_.py
deleted file mode 100644
index edf3648e314a451703627fcfc8d5f855c90d7870..0000000000000000000000000000000000000000
--- a/SSLGlacier/processing/datamodule_.py
+++ /dev/null
@@ -1,149 +0,0 @@
-from torchvision.transforms import functional as F
-from pl_bolts.models.self_supervised import swav
-from torch.utils.data import DataLoader
-from torch.utils.data import Dataset
-import pytorch_lightning as pl
-from torchvision import transforms as transform_lib
-from . import agumentations_ as our_transforms
-import torchvision
-import numpy as np
-from typing import Tuple, List, Any, Callable, Optional
-import torch
-import cv2
-from PIL import Image,ImageOps
-import os
-
-# TODO add MoCo, SimCLR, CPC and AMDIM transforms
-class SSLDataset(Dataset):
-    def __init__(self, parent_dir, transform, return_index=False, **kwargs):  # TODO Pass them as arguments
-        '''
-        Args:
-            mode: can be 'train', 'val', or 'test'
-            parent_dir: directory in which the train, val and test folders exist
-        '''
-        self.mode = kwargs.get('mode') or 'train'
-        self.return_index = return_index
-        self.images_path = os.path.join(parent_dir, "sar_images", self.mode)
-        self.masks_path = os.path.join(parent_dir, 'zones', self.mode)
-        self.images = os.listdir(self.images_path)
-        self.masks = os.listdir(self.masks_path)
-        assert len(self.masks) == len(self.images), "You don't have the same number of images and masks"
-        self.transform = transform
-        # Sort so that images and masks fit together
-        self.images.sort()
-        self.masks.sort()
-
-        # Shuffle and save the indices once, and load them on later calls if they exist
-        if not os.path.exists(os.path.join("data_processing", "data_splits")):
-            os.makedirs(os.path.join("data_processing", "data_splits"))
-        if not os.path.isfile(os.path.join("data_processing", "data_splits", "shuffle_" + ".txt")):
-            shuffle = np.random.permutation(len(self.images))
-            # Works for numpy version >= 1.5
-            np.savetxt(os.path.join("data_processing", "data_splits", "shuffle_" + ".txt"), shuffle,
-                       newline=' ')
-
-        else:
-            # use already existing shuffle
-            shuffle = np.loadtxt(os.path.join("data_processing", "data_splits", "shuffle_" + ".txt")).astype(int)
-            # if lengths do not match, we need to create a new permutation
-            if len(shuffle) != len(self.images):
-                shuffle = np.random.permutation(len(self.images))
-                np.savetxt(os.path.join("data_processing", "data_splits", "shuffle_" + ".txt"), shuffle,
-                           newline=' ')
-
-        # numpy arrays allow fancy indexing with the shuffle permutation
-        self.images = np.array(self.images)
-        self.masks = np.array(self.masks)
-        self.images = list(self.images[shuffle])
-        self.masks = list(self.masks[shuffle])
-
-    def __len__(self):
-        return len(self.images)
-
-    def __getitem__(self, index):
-        img_name = self.images[index]
-        masks_name = self.masks[index]
-        assert img_name.split('.')[0] == masks_name.split('.')[0].replace("_zones", ""), \
-            "image and label name don't match. Image name: " + img_name + ". Label name: " + masks_name
-        #image = cv2.imread(os.path.join(self.images_path, img_name).__str__(), cv2.IMREAD_GRAYSCALE)
-        #mask = cv2.imread(os.path.join(self.masks_path, masks_name).__str__(), cv2.IMREAD_GRAYSCALE)
-        image = Image.open(os.path.join(self.images_path, img_name).__str__())
-        image = ImageOps.grayscale(image)
-        #image = np.array(image).astype(np.float32)
-
-        #rgbimg = Image.new("RGBA", image.size)
-        #image = rgbimg.paste(image)
-        #image = ImageOps.grayscale(image)
-        mask = Image.open(os.path.join(self.masks_path, masks_name).__str__())
-        #rgbimg = Image.new("RGBA", mask.size)
-        #mask = rgbimg.paste(rgbimg)
-        mask = ImageOps.grayscale(mask)
-        #mask = np.array(mask).astype(np.float32)
-        # TODO check if it works or not; it should return multiple transformations for one image
-        # TODO I crop different parts of the image and mask as part of the batch, how should I handle it!!
-        # TODO this is a list of tuples of img and masks, think if you want to change it later
-
-        # to_tensor = our_transforms.ToTensorZones()
-        # _, mask = to_tensor(image, mask)
-        if self.transform is not None:
-            augmented_imgs, masks = self.transform(image, mask)
-        else:
-            augmented_imgs, masks = image, mask
-        if self.return_index:
-            return index, augmented_imgs, mask, img_name, masks_name
-        return augmented_imgs, masks
-
-
-class SSLDataModule(pl.LightningDataModule):
-    # Base is adopted from Nora's code
-    def __init__(self, batch_size, parent_dir, args):
-        """
-        :param batch_size: batch size
-        :param target: Either 'zones' or 'front'. Tells which masks should be used.
-        """
-        super().__init__()
-        self.transforms = None
-        self.batch_size = batch_size
-        self.glacier_test = None
-        self.glacier_train = None
-        self.glacier_val = None
-        self.parent_dir = parent_dir
-        self.aug_args = args.agumentations
-        self.size_crops = args.size_crops
-        self.num_crops = args.nmb_crops
-
-    def prepare_data(self):
-        # download,
-        # only called on 1 GPU/TPU in distributed
-        pass
-
-    def __len__(self):
-        return len(self.glacier_train)
-
-    def setup(self, stage=None):
-        # process and split here
-        # make assignments here (val/train/test split)
-        # called on every process in DDP
-        if stage == 'test' or stage is None:
-            self.transforms = self._default_transforms() if self.test_transforms is None else self.test_transforms
-            self.glacier_test = SSLDataset(parent_dir=self.parent_dir, transform=self.transforms, mode=stage)
-        if stage == 'fit' or stage is None:
-            self.transforms = self._default_transforms() if self.train_transforms is None else self.train_transforms
-            self.glacier_train = SSLDataset(parent_dir=self.parent_dir, transform=self.transforms, mode='train')
-
-            self.transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms
-            self.glacier_val = SSLDataset(parent_dir=self.parent_dir, transform=self.transforms, mode='val')
-
-    def train_dataloader(self) -> DataLoader:
-        return DataLoader(self.glacier_train, batch_size=self.batch_size, num_workers=6, pin_memory=True,
-                          drop_last=True)
-
-    def val_dataloader(self) -> DataLoader:
-        return DataLoader(self.glacier_val, batch_size=self.batch_size, num_workers=6, pin_memory=True, drop_last=True)
-
-    def test_dataloader(self) -> DataLoader:
-        return DataLoader(self.glacier_test, batch_size=1, num_workers=6, pin_memory=True,
-                          drop_last=True)  # TODO self.batch_size
-
-    def _default_transforms(self) -> Callable:
-        #return transform_lib.Compose([transform_lib.ToTensor()])
-        return our_transforms.Compose([our_transforms.ToTensorZones()])
diff --git a/SSLGlacier/processing/semi_module.py b/SSLGlacier/processing/semi_module.py
deleted file mode 100644
index 9798a3620f90f6dc9e849eb08eecafe45f256c24..0000000000000000000000000000000000000000
--- a/SSLGlacier/processing/semi_module.py
+++ /dev/null
@@ -1,29 +0,0 @@
-from pytorch_lightning import LightningModule
-from typing import Any
-class Semi_(LightningModule):
-    def __init__(self):
-        super().__init__()
-
-    def setup(self, stage: str) -> None:
-        pass
-
-    def init_model(self):
-        pass
-
-    def forward(self, *args: Any, **kwargs: Any) -> Any:
-        pass
-
-    def on_train_epoch_start(self) -> None:
-        pass
-
-    def shared_step(self, batch):
-        inputs, y = batch
-        inputs = inputs[:, -1]
-        embedding = self.model(inputs)
-        #### Where should I add the noise?
-        #### Assume embedding[b][i] belongs to batch b, view i.
-        #### If we have two views, one with noise and one without, then
-        ##### select some portion of each view belonging to different classes:
-        ##### pos_set = same classes ...
-        ##### neg_set = different classes
-
diff --git a/SSLGlacier/processing/swav_module.py b/SSLGlacier/processing/swav_module.py
deleted file mode 100644
index 3b57a4ede5265846c6a4aeb4b885d3d31169c206..0000000000000000000000000000000000000000
--- a/SSLGlacier/processing/swav_module.py
+++ /dev/null
@@ -1,562 +0,0 @@
-"""Adapted from official swav implementation: https://github.com/facebookresearch/swav."""
-import os
-from argparse import ArgumentParser
-
-import torch
-from pytorch_lightning import LightningModule, Trainer
-from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
-from torch import nn
-
-#from pl_bolts.models.self_supervised.swav.loss import SWAVLoss
-from SSLGlacier.models.losses import SWAVLoss
-from SSLGlacier.models.RESNET import resnet18, resnet50
-#from Base_Code.models.ParentUNet import UNet
-from SSLGlacier.models.UNET import UNet #The main network
-from pl_bolts.optimizers.lars import LARS
-from pl_bolts.optimizers.lr_scheduler import linear_warmup_decay
-
-##########Importing files for hook############
-from SSLGlacier.models.hook.Swin_Transformer_Wrapper import SwinUnet
-######Temporary config file for hooknet#######
-from utils.config import get_config
-
-class SwAV(LightningModule):
-    def __init__(
-        self,
-        gpus: int,
-        num_samples: int,
-        batch_size: int,
-        dataset: str,
-        num_nodes: int = 1,
-        arch: str = "resnet50",
-        hidden_mlp: int = 2048,
-        feat_dim: int = 128,
-        warmup_epochs: int = 10,
-        max_epochs: int = 100,
-        num_prototypes: int = 3000,
-        freeze_prototypes_epochs: int = 1,
-        temperature: float = 0.1,
-        sinkhorn_iterations: int = 3,
-        queue_length: int = 0,  # must be divisible by total batch-size
-        queue_path: str = "queue",
-        epoch_queue_starts: int = 15,
-        crops_for_assign: tuple = (0, 1),
-        num_crops: tuple = (2, 6),
-        num_augs: int= 2,
-        first_conv: bool = True,
-        maxpool1: bool = True,
-        optimizer: str = "adam",
-        exclude_bn_bias: bool = False,
-        start_lr: float = 0.0,
-        learning_rate: float = 1e-3,
-        final_lr: float = 0.0,
-        weight_decay: float = 1e-6,
-        epsilon: float = 0.05,
-        just_aug_for_same_assign_views: bool = False,
-        swin_hparams = {},
-        **kwargs,  # tolerate extra entries passed via cli_main's args namespace
-    ) -> None:
-        """
-        Args:
-            gpus: number of gpus per node used in training, passed to SwAV module
-                to manage the queue and select distributed sinkhorn
-            num_nodes: number of nodes to train on
-            num_samples: number of image samples used for training
-            batch_size: batch size per GPU in ddp
-            dataset: dataset being used for train/val
-            arch: encoder architecture used for pre-training
-            hidden_mlp: hidden layer of non-linear projection head, set to 0
-                to use a linear projection head
-            feat_dim: output dim of the projection head
-            warmup_epochs: apply linear warmup for this many epochs
-            max_epochs: epoch count for pre-training
-            num_prototypes: count of prototype vectors
-            freeze_prototypes_epochs: epoch till which gradients of prototype layer
-                are frozen
-            temperature: loss temperature
-            sinkhorn_iterations: iterations for sinkhorn normalization
-            queue_length: set queue when batch size is small,
-                must be divisible by total batch-size (i.e. total_gpus * batch_size),
-                set to 0 to remove the queue
-            queue_path: folder within the logs directory
-            epoch_queue_starts: start using the queue after this epoch
-            crops_for_assign: list of crop ids for computing assignment
-            num_crops: number of global and local crops, ex: [2, 6]
-            first_conv: keep first conv same as the original resnet architecture,
-                if set to false it is replaced by a kernel 3, stride 1 conv (cifar-10)
-            maxpool1: keep first maxpool layer same as the original resnet architecture,
-                if set to false, first maxpool is turned off (cifar10, maybe stl10)
-            optimizer: optimizer to use
-            exclude_bn_bias: exclude batchnorm and bias layers from weight decay in optimizers
-            start_lr: starting lr for linear warmup
-            learning_rate: learning rate
-            final_lr: final learning rate for cosine weight decay
-            weight_decay: weight decay for optimizer
-            epsilon: epsilon val for swav assignments
-        """
-        super().__init__()
-        self.save_hyperparameters()
-
-        self.gpus = gpus
-        self.num_nodes = num_nodes
-        self.arch = arch
-        self.dataset = dataset
-        self.num_samples = num_samples
-        self.batch_size = batch_size
-
-        self.hidden_mlp = hidden_mlp
-        self.feat_dim = feat_dim
-        self.num_prototypes = num_prototypes
-        self.freeze_prototypes_epochs = freeze_prototypes_epochs
-        self.sinkhorn_iterations = sinkhorn_iterations
-
-        self.queue_length = queue_length
-        self.queue_path = queue_path
-        self.epoch_queue_starts = epoch_queue_starts
-        self.crops_for_assign = crops_for_assign
-        self.num_crops = num_crops
-        self.num_augs = num_augs
-
-        self.first_conv = first_conv
-        self.maxpool1 = maxpool1
-
-        self.optim = optimizer
-        self.exclude_bn_bias = exclude_bn_bias
-        self.weight_decay = weight_decay
-        self.epsilon = epsilon
-        self.temperature = temperature
-
-        self.start_lr = start_lr
-        self.final_lr = final_lr
-        self.learning_rate = learning_rate
-        self.warmup_epochs = warmup_epochs
-        self.max_epochs = max_epochs
-        self.just_aug_for_same_assign_views = just_aug_for_same_assign_views
-
-        self.model = self.init_model()
-        self.criterion = SWAVLoss(
-            gpus=self.gpus,
-            num_nodes=self.num_nodes,
-            temperature=self.temperature,
-            crops_for_assign=self.crops_for_assign,
-            num_crops=self.num_crops,
-            num_augs=self.num_augs,
-            sinkhorn_iterations=self.sinkhorn_iterations,
-            epsilon=self.epsilon,
-            just_aug_for_same_assign_views= self.just_aug_for_same_assign_views
-        )
-        self.use_the_queue = None
-        # compute iters per epoch
-        global_batch_size = self.num_nodes * self.gpus * self.batch_size if self.gpus > 0 else self.batch_size
-        self.train_iters_per_epoch = self.num_samples // global_batch_size
-        self.queue = None
-
-        ####For hook####
-        self.swin_hparams = swin_hparams
-    def setup(self, stage):
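-        # restore the queue saved for this rank in a previous run, if it exists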
-        if self.queue_length > 0:
-            queue_folder = os.path.join(self.logger.log_dir, self.queue_path)
-            if not os.path.exists(queue_folder):
-                os.makedirs(queue_folder)
-
-            self.queue_path = os.path.join(queue_folder, "queue" + str(self.trainer.global_rank) + ".pth")
-
-            if os.path.isfile(self.queue_path):
-                self.queue = torch.load(self.queue_path)["queue"]
-
-    def init_model(self):
-        if self.arch == "resnet18":
-            backbone = resnet18(normalize=True,
-                                hidden_mlp=self.hidden_mlp,
-                                output_dim=self.feat_dim,
-                                num_prototypes=self.num_prototypes,
-                                first_conv=self.first_conv,
-                                maxpool1=self.maxpool1)
-
-        elif self.arch == "resnet50":
-            backbone = resnet50(normalize=True,
-                                hidden_mlp=self.hidden_mlp,
-                                output_dim=self.feat_dim,
-                                num_prototypes=self.num_prototypes,
-                                first_conv=self.first_conv,
-                                maxpool1=self.maxpool1)
-        elif self.arch == "hook":
-            backbone = SwinUnet(
-                                img_size=224,
-                                num_classes=5,
-                                normalize=True,
-                                hidden_mlp=self.hidden_mlp,
-                                output_dim=self.feat_dim,
-                                num_prototypes=self.num_prototypes,
-                                first_conv=self.first_conv,
-                                maxpool1=self.maxpool1,######till here for basenet
-                                image_size= self.swin_hparams.image_size,####from here for swin itself
-                                swin_patch_size = self.swin_hparams.swin_patch_size,
-                                swin_in_chans=self.swin_hparams.swin_in_chans,
-                                swin_embed_dim=self.swin_hparams.swin_embed_dim,
-                                swin_depths=self.swin_hparams.swin_depths,
-                                swin_num_heads=self.swin_hparams.swin_num_heads,
-                                swin_window_size=self.swin_hparams.swin_window_size,
-                                swin_mlp_ratio=self.swin_hparams.swin_mlp_ratio,
-                                swin_QKV_BIAS=self.swin_hparams.swin_QKV_BIAS,
-                                swin_QK_SCALE=self.swin_hparams.swin_QK_SCALE,
-                                drop_rate=self.swin_hparams.drop_rate,  # assumed to sit on swin_hparams like the other swin settings
-                                drop_path_rate=self.swin_hparams.drop_path_rate,
-                                swin_ape=self.swin_hparams.swin_ape,
-                                swin_patch_norm=self.swin_hparams.swin_path_norm
-                                )
-
-        elif self.arch == "Unet" or "unet":
-            backbone = UNet(num_classes=5,
-                            input_channels=1,
-                            num_layers=5,
-                            features_start=64,
-                            bilinear=False,
-                            normalize=True,
-                            hidden_mlp=self.hidden_mlp,
-                            output_dim=self.feat_dim,
-                            num_prototypes=self.num_prototypes,
-                            first_conv=self.first_conv,
-                            maxpool1=self.maxpool1
-                            )
-        else:
-            raise ValueError(f"{self.arch} model is not defined")
-
-        return backbone
-
-    def forward(self, x):
-        # pass single batch from the resnet backbone
-        return self.model.forward_backbone(x)
-
-    def on_train_epoch_start(self):
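-        # The queue holds embeddings from previous batches so the Sinkhorn assignment sees
-        # more samples than a single small batch; it is only allocated once the current
-        # epoch reaches epoch_queue_starts.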
-        if self.queue_length > 0:
-            if self.trainer.current_epoch >= self.epoch_queue_starts and self.queue is None:
-                self.queue = torch.zeros(
-                    len(self.crops_for_assign),
-                    self.queue_length // self.gpus,  # change to nodes * gpus once multi-node
-                    self.feat_dim,
-                )
-
-            if self.queue is not None:
-                self.queue = self.queue.to(self.device)
-
-        self.use_the_queue = False
-
-    def on_train_epoch_end(self) -> None:
-        if self.queue is not None:
-            torch.save({"queue": self.queue}, self.queue_path)
-
-    def on_after_backward(self):
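-        # Setting the prototype gradients to None freezes the prototype layer for the first
-        # freeze_prototypes_epochs epochs, as in the official SwAV implementation.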
-        if self.current_epoch < self.freeze_prototypes_epochs:
-            for name, p in self.model.named_parameters():
-                if "prototypes" in name:
-                    p.grad = None
-
-    def shared_step(self, batch):
-        if self.dataset == "stl10":
-            unlabeled_batch = batch[0]
-            batch = unlabeled_batch
-
-        inputs, y = batch
-        inputs = inputs[:-1]  # remove online train/eval transforms at this point
-
-        # 1. normalize the prototypes
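-        #    (prototypes are kept unit-norm so that the following dot products with the
-        #     normalized embeddings behave like cosine similarities)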
-        with torch.no_grad():
-            w = self.model.prototypes.weight.data.clone()
-            w = nn.functional.normalize(w, dim=1, p=2)
-            self.model.prototypes.weight.copy_(w)
-
-        # 2. multi-res forward passes
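-        #    (the backbone returns the projection-head embeddings and the prototype scores
-        #     for all crops)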
-        embedding, output = self.model(inputs)
-        embedding = embedding.detach()
-        bs = inputs[0].size(0)
-
-        # SWAV loss computation
-        loss, queue, use_queue = self.criterion(
-            output=output,
-            embedding=embedding,
-            prototype_weights=self.model.prototypes.weight,
-            batch_size=bs,
-            queue=self.queue,
-            use_queue=self.use_the_queue,
-        )
-        self.queue = queue
-        self.use_the_queue = use_queue
-        return loss
-
-    def training_step(self, batch, batch_idx):
-        loss = self.shared_step(batch)
-
-        self.log("train_loss", loss, on_step=True, on_epoch=False)
-        return loss
-
-    def validation_step(self, batch, batch_idx):
-        loss = self.shared_step(batch)
-
-        self.log("val_loss", loss, on_step=False, on_epoch=True)
-        return loss
-
-    def exclude_from_wt_decay(self, named_params, weight_decay, skip_list=("bias", "bn")):
-        params = []
-        excluded_params = []
-
-        for name, param in named_params:
-            if not param.requires_grad:
-                continue
-            if any(layer_name in name for layer_name in skip_list):
-                excluded_params.append(param)
-            else:
-                params.append(param)
-
-        return [{"params": params, "weight_decay": weight_decay}, {"params": excluded_params, "weight_decay": 0.0}]
-
-    def configure_optimizers(self):
-        if self.exclude_bn_bias:
-            params = self.exclude_from_wt_decay(self.named_parameters(), weight_decay=self.weight_decay)
-        else:
-            params = self.parameters()
-
-        if self.optim == "lars":
-            optimizer = LARS(
-                params,
-                lr=self.learning_rate,
-                momentum=0.9,
-                weight_decay=self.weight_decay,
-                trust_coefficient=0.001,
-            )
-        elif self.optim == "adam":
-            optimizer = torch.optim.Adam(params, lr=self.learning_rate, weight_decay=self.weight_decay)
-
-        warmup_steps = self.train_iters_per_epoch * self.warmup_epochs
-        total_steps = self.train_iters_per_epoch * self.max_epochs
-
-        scheduler = {
-            "scheduler": torch.optim.lr_scheduler.LambdaLR(
-                optimizer,
-                linear_warmup_decay(warmup_steps, total_steps, cosine=True),
-            ),
-            "interval": "step",
-            "frequency": 1,
-        }
-
-        return [optimizer], [scheduler]
-
-    @staticmethod
-    def add_model_specific_args(parent_parser):
-        parser = ArgumentParser(parents=[parent_parser], add_help=False)
-
-        # model params
-        parser.add_argument("--arch", default="resnet50", type=str, help="convnet architecture")
-        # specify flags to store false
-        parser.add_argument("--first_conv", action="store_false")
-        parser.add_argument("--maxpool1", action="store_false")
-        parser.add_argument("--hidden_mlp", default=2048, type=int, help="hidden layer dimension in projection head")
-        parser.add_argument("--feat_dim", default=128, type=int, help="feature dimension")
-        parser.add_argument("--online_ft", action="store_true")
-        parser.add_argument("--fp32", action="store_true")
-
-        # transform params
-        parser.add_argument("--gaussian_blur", action="store_true", help="add gaussian blur")
-        parser.add_argument("--jitter_strength", type=float, default=1.0, help="jitter strength")
-        parser.add_argument("--dataset", type=str, default="stl10", help="stl10, cifar10")
-        parser.add_argument("--data_dir", type=str, default=".", help="path to download data")
-        parser.add_argument("--queue_path", type=str, default="queue", help="path for queue")
-
-        parser.add_argument(
-            "--num_crops", type=int, default=[2, 4], nargs="+", help="list of number of crops (example: [2, 6])"
-        )
-        parser.add_argument(
-            "--size_crops", type=int, default=[96, 36], nargs="+", help="crops resolutions (example: [224, 96])"
-        )
-        parser.add_argument(
-            "--min_scale_crops",
-            type=float,
-            default=[0.33, 0.10],
-            nargs="+",
-            help="argument in RandomResizedCrop (example: [0.14, 0.05])",
-        )
-        parser.add_argument(
-            "--max_scale_crops",
-            type=float,
-            default=[1, 0.33],
-            nargs="+",
-            help="argument in RandomResizedCrop (example: [1., 0.14])",
-        )
-
-        # training params
-        parser.add_argument("--fast_dev_run", default=1, type=int)
-        parser.add_argument("--num_nodes", default=1, type=int, help="number of nodes for training")
-        parser.add_argument("--gpus", default=1, type=int, help="number of gpus to train on")
-        parser.add_argument("--num_workers", default=8, type=int, help="num of workers per GPU")
-        parser.add_argument("--optimizer", default="adam", type=str, help="choose between adam/lars")
-        parser.add_argument("--exclude_bn_bias", action="store_true", help="exclude bn/bias from weight decay")
-        parser.add_argument("--max_epochs", default=100, type=int, help="number of total epochs to run")
-        parser.add_argument("--max_steps", default=-1, type=int, help="max steps")
-        parser.add_argument("--warmup_epochs", default=10, type=int, help="number of warmup epochs")
-        parser.add_argument("--batch_size", default=128, type=int, help="batch size per gpu")
-
-        parser.add_argument("--weight_decay", default=1e-6, type=float, help="weight decay")
-        parser.add_argument("--learning_rate", default=1e-3, type=float, help="base learning rate")
-        parser.add_argument("--start_lr", default=0, type=float, help="initial warmup learning rate")
-        parser.add_argument("--final_lr", type=float, default=1e-6, help="final learning rate")
-
-        # swav params
-        parser.add_argument(
-            "--crops_for_assign",
-            type=int,
-            nargs="+",
-            default=[0, 1],
-            help="list of crops id used for computing assignments",
-        )
-        parser.add_argument("--temperature", default=0.1, type=float, help="temperature parameter in training loss")
-        parser.add_argument(
-            "--epsilon", default=0.05, type=float, help="regularization parameter for Sinkhorn-Knopp algorithm"
-        )
-        parser.add_argument(
-            "--sinkhorn_iterations", default=3, type=int, help="number of iterations in Sinkhorn-Knopp algorithm"
-        )
-        parser.add_argument("--num_prototypes", default=512, type=int, help="number of prototypes")
-        parser.add_argument(
-            "--queue_length",
-            type=int,
-            default=0,
-            help="length of the queue (0 for no queue); must be divisible by total batch size",
-        )
-        parser.add_argument(
-            "--epoch_queue_starts", type=int, default=15, help="from this epoch, we start using a queue"
-        )
-        parser.add_argument(
-            "--freeze_prototypes_epochs",
-            default=1,
-            type=int,
-            help="freeze the prototypes during this many epochs from the start",
-        )
-
-        return parser
-
-
-def cli_main():
-    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
-    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
-    from pl_bolts.transforms.dataset_normalizations import (
-        cifar10_normalization,
-        imagenet_normalization,
-        stl10_normalization,
-    )
-    from pl_bolts.transforms.self_supervised.swav_transforms import SwAVEvalDataTransform, SwAVTrainDataTransform
-
-    parser = ArgumentParser()
-
-    # model args
-    parser = SwAV.add_model_specific_args(parser)
-    args = parser.parse_args()
-
-    if args.dataset == "stl10":
-        dm = STL10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)
-
-        dm.train_dataloader = dm.train_dataloader_mixed
-        dm.val_dataloader = dm.val_dataloader_mixed
-        args.num_samples = dm.num_unlabeled_samples
-
-        args.maxpool1 = False
-
-        normalization = stl10_normalization()
-    elif args.dataset == "cifar10":
-        args.batch_size = 2
-        args.num_workers = 0
-
-        dm = CIFAR10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)
-
-        args.num_samples = dm.num_samples
-
-        args.maxpool1 = False
-        args.first_conv = False
-
-        normalization = cifar10_normalization()
-
-        # cifar10 specific params
-        args.size_crops = [32, 16]
-        args.num_crops = [2, 1]
-        args.gaussian_blur = False
-    elif args.dataset == "imagenet":
-        args.maxpool1 = True
-        args.first_conv = True
-        normalization = imagenet_normalization()
-
-        args.size_crops = [224, 96]
-        args.num_crops = [2, 6]
-        args.min_scale_crops = [0.14, 0.05]
-        args.max_scale_crops = [1.0, 0.14]
-        args.gaussian_blur = True
-        args.jitter_strength = 1.0
-
-        args.batch_size = 64
-        args.num_nodes = 8
-        args.gpus = 8  # per-node
-        args.max_epochs = 800
-
-        args.optimizer = "lars"
-        args.learning_rate = 4.8
-        args.final_lr = 0.0048
-        args.start_lr = 0.3
-
-        args.num_prototypes = 3000
-        args.online_ft = True
-
-        dm = ImagenetDataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)
-
-        args.num_samples = dm.num_samples
-        args.input_height = dm.dims[-1]
-    else:
-        raise NotImplementedError("other datasets have not been implemented till now")
-
-    dm.train_transforms = SwAVTrainDataTransform(
-        normalize=normalization,
-        size_crops=args.size_crops,
-        num_crops=args.num_crops,
-        min_scale_crops=args.min_scale_crops,
-        max_scale_crops=args.max_scale_crops,
-        gaussian_blur=args.gaussian_blur,
-        jitter_strength=args.jitter_strength,
-    )
-
-    dm.val_transforms = SwAVEvalDataTransform(
-        normalize=normalization,
-        size_crops=args.size_crops,
-        num_crops=args.num_crops,
-        min_scale_crops=args.min_scale_crops,
-        max_scale_crops=args.max_scale_crops,
-        gaussian_blur=args.gaussian_blur,
-        jitter_strength=args.jitter_strength,
-    )
-
-    # swav model init
-    model = SwAV(**args.__dict__)
-
-    online_evaluator = None
-    if args.online_ft:
-        # online eval
-        online_evaluator = SSLOnlineEvaluator(
-            drop_p=0.0,
-            hidden_dim=None,
-            z_dim=args.hidden_mlp,
-            num_classes=dm.num_classes,
-            dataset=args.dataset,
-        )
-
-    lr_monitor = LearningRateMonitor(logging_interval="step")
-    model_checkpoint = ModelCheckpoint(save_last=True, save_top_k=1, monitor="val_loss")
-    callbacks = [model_checkpoint, online_evaluator] if args.online_ft else [model_checkpoint]
-    callbacks.append(lr_monitor)
-
-    trainer = Trainer(
-        max_epochs=args.max_epochs,
-        max_steps=None if args.max_steps == -1 else args.max_steps,
-        gpus=args.gpus,
-        num_nodes=args.num_nodes,
-        accelerator="ddp" if args.gpus > 1 else None,
-        sync_batchnorm=args.gpus > 1,
-        precision=32 if args.fp32 else 16,
-        callbacks=callbacks,
-        fast_dev_run=args.fast_dev_run,
-    )
-
-    trainer.fit(model, datamodule=dm)
-
-
-if __name__ == "__main__":
-    cli_main()
\ No newline at end of file
diff --git a/SSLGlacier/processing/transformers_.py b/SSLGlacier/processing/transformers_.py
deleted file mode 100644
index f8290bb69b2c33649e92a9dd9ae7797595b94e85..0000000000000000000000000000000000000000
--- a/SSLGlacier/processing/transformers_.py
+++ /dev/null
@@ -1,248 +0,0 @@
-from pl_bolts.models.self_supervised import swav
-from torchvision import transforms  # need it for validation and fine tuning
-from . import agumentations_ as ag
-from typing import Tuple, List, Dict
-from torch import Tensor
-
-
-
-class OurTrainTransformer():
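-    """Builds the SwAV-style multi-crop augmentation pipelines that transform an image and its mask together."""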
-    def __init__(self,
-                 size_crops: Tuple[int] = (294, 94),
-                 nmb_crops: Tuple[int] = (2, 4),
-                 augs: Dict = {'SaltPepper': 1},
-                 use_hook_former: bool = True) -> None:
-        # TODO you could set this up to make sure there is one sample from each augmentation method
-        self.size_crops = size_crops
-        self.num_crops = nmb_crops
-        augmented_imgs = []
-        self.use_hook_former = use_hook_former
-        self.transform = []
-        if augs.get('OrigPixelValues'):
-            augmented_imgs.append(ag.DoNothing())
-        if augs.get('RGaussianB'):
-            augmented_imgs.append(ag.PILRandomGaussianBlur(radius_min=0.1, radius_max=2.))
-        if augs.get('GaussiN'):
-            augmented_imgs.append(ag.GaussNoise(mean=0, var=10000))
-        if augs.get('SaltPepper'):
-            augmented_imgs.append(ag.SaltPepperNoise(salt_or_pepper=0.2))
-        if augs.get('ZeroN'):
-            augmented_imgs.append(ag.SetZeroNoise())
-        # Nora's augmentation methods
-        if augs.get('flip'):
-            # TODO manage this, how to add mask for the times we need transformation for masks too
-            augmented_imgs.append(ag.RandomHorizontalFlip())
-            # image = torchvision.transforms.functional.hflip(image)
-        # mask = torchvision.transforms.functional.hflip(mask)
-        # image, mask = ag.ToTensorZones(image=image, target=np.array(mask))
-        # augmented_imgs.append((image, mask))
-        if augs.get('rotate'):
-            augmented_imgs.append(ag.Rotate())
-        if augs.get('bright'):
-            augmented_imgs.append(ag.Bright())
-        if augs.get('wrap'):
-            augmented_imgs.append(ag.MyWrap())
-        if augs.get('noise'):
-            augmented_imgs.append(ag.Noise())
-
-        if augs.get('normalize') is not None:
-            self.final_transform = ag.Compose([ag.ToTensorZones(), augs['normalize']])
-        else:
-            self.final_transform = ag.ToTensorZones()
-
-        if self.use_hook_former:
-            self.hook_former_transformer(augmented_imgs)
-        else:
-            self.other_networks_transformers(augmented_imgs)
-
-
-
-    def __call__(self, image, mask):
-        """
-        import matplotlib.pyplot as plt
-        img1 =transformed_imgs[0].detach().cpu().numpy().squeeze()
-        plt.imshow(img1)
-        plt.show()
-        """
-        transformed_img_mask = []
-        for transform in self.transform:
-            first = transform[0]
-            if isinstance(first, ag.RandomCropper):
-                transformed_img_mask.append(transform(image, mask))
-                self.i = first.left
-                self.j = first.right
-                self.w = first.width
-                self.h = first.height
-            elif isinstance(first, ag.Cropper):
-                first.left = self.i
-                first.right = self.j
-                first.width = self.w
-                first.height = self.h
-                transformed_img_mask.append(transform(image, mask))
-            else:
-                transformed_img_mask.append(transform(image, mask))
-        transformed_imgs = [item[0] for item in transformed_img_mask]
-        transformed_masks = [item[1] for item in transformed_img_mask]
-
-
-        # fig, axs = plt.subplots(nrows=6, ncols=4, figsize=(15, 12))
-        # fig.suptitle("Patches used in training", fontsize=18, y=.95)
-        #
-        # for i, image_id in enumerate(transformed_imgs):
-        #     img1 = transformed_imgs[i*3].detach().cpu().numpy().squeeze()
-        #     mask1 = transformed_masks[i*3].detach().cpu().numpy().squeeze()
-        #     axs[0, i].imshow(img1)
-        #     axs[0, i].imshow(mask1)
-        #     axs[0, i].imshow(transformed_imgs[i*3 +1].detach().cpu().numpy().squeeze())
-        #     axs[0, i].imshow(transformed_imgs[i*3+2].detach().cpu().numpy().squeeze())
-        # plt.savefig('High.png')
-        return transformed_imgs, transformed_masks
-
-    def hook_former_transformer(self, augmented_imgs):
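-        # For every crop size, the first augmentation uses a random crop and the remaining
-        # augmentations reuse a (smaller) center crop of the same image; each crop is followed
-        # by the augmentation itself and conversion to tensors, repeated num_crops[i] times.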
-        transform = []
-        for i in range(len(self.size_crops)):
-            i_transform = []
-            for ith_aug, aug in enumerate(augmented_imgs):
-                # For the first time we crop randomly, then save coordinates for further transformation
-                if ith_aug == 0:
-                    global_crop = ag.RandomCropper(self.size_crops[i])
-                    i_transform.extend(
-                        [ag.Compose([global_crop] + [aug] + [ag.ToTensorZones()])]
-                    )
-                else:  # Crop same area, and do augmentation
-                    local_crop = ag.CenterCrop(self.size_crops[i] // 4)  # CenterCrop takes the target size as its only argument
-                    i_transform.extend(
-                        [ag.Compose([local_crop] + [aug] + [ag.ToTensorZones()])])
-            transform += i_transform * self.num_crops[i]
-        self.transform =  transform
-        # add online train transform of the size of global view
-        online_train_transform = ag.Compose(
-            [ag.RandomCropper(self.size_crops[0]), ag.RandomHorizontalFlip(), self.final_transform]
-        )
-
-        self.transform.append(online_train_transform)
-        return transform
-
-    def other_networks_transformers(self, augmented_imgs):
-        transform = []
-        for i in range(len(self.size_crops)):
-            i_transform = []
-            for ith_aug, aug in enumerate(augmented_imgs):
-                # For the first time we crop randomly, then save coordinates for further transformation
-                if ith_aug == 0:
-                    random_crop = ag.RandomCropper(self.size_crops[i])
-                    i_transform.extend(
-                        [ag.Compose([random_crop] + [aug] + [ag.ToTensorZones()])]
-                    )
-                else:  # Crop same area, and do augmentation
-                    fixed_crop = ag.Cropper(random_crop.left, random_crop.right, random_crop.height, random_crop.width)
-                    i_transform.extend(
-                        [ag.Compose([fixed_crop] + [aug] + [ag.ToTensorZones()])])
-            transform += i_transform * self.num_crops[i]
-
-        self.transform = transform
-        # add online train transform of the size of global view
-        #What was that for?? I forgot!
-        online_train_transform = ag.Compose(
-            [ag.RandomCropper(self.size_crops[0]), ag.RandomHorizontalFlip(), self.final_transform]
-        )
-
-        self.transform.append(online_train_transform)
-        return transform
-
-class OurEvalTransformer(OurTrainTransformer):
-    def __init__(self,
-                 size_crops: Tuple[int] = (294, 294),
-                 nmb_crops: Tuple[int] = (2, 4),
-                 augs: Dict = {'SaltPepper': 1},
-                 use_hook_former: bool = True) -> None:
-        super().__init__(size_crops=size_crops,
-                         nmb_crops=nmb_crops,
-                         augs=augs,
-                         use_hook_former=use_hook_former)
-
-        input_height = self.size_crops[0]  # get global view crop
-        test_transform = ag.Compose(
-            [
-                ag.Resize(int(input_height + 0.1 * input_height)),
-                ag.CenterCrop(input_height),
-                self.final_transform,
-            ]
-        )
-
-        # replace the last transform with the eval transform in the self.transform list
-        self.transform[-1] = test_transform
-
-class OurFinetuneTransform:
-    def __init__(self,
-                 input_height: int = 224,
-                 normalize=None,
-                 eval_transform: bool = False) -> None:
-
-        if normalize is None:
-            final_transform = transforms.ToTensor()
-        else:
-            final_transform = transforms.Compose([transforms.ToTensor(), normalize])
-
-        if not eval_transform:  # TODO think about it
-            data_transforms = [
-                transforms.RandomResizedCrop(size=input_height),
-                ag.RandomHorizontalFlip(),
-                ag.SaltPepperNoise(salt_or_pepper=0.5),
-                transforms.RandomGrayscale(p=0.2),
-            ]
-        else:
-            data_transforms = [
-                transforms.Resize(int(input_height + 0.1 * input_height)),
-                transforms.CenterCrop(input_height),
-            ]
-
-        data_transforms.append(final_transform)
-        self.transform = ag.Compose(data_transforms)
-
-    def __call__(self, img: Tensor) -> Tensor:
-        return self.transform(img)
-
-# TODO use it in the code with a flag
-class SwapDefaultTransformer():
-    def __init__(
-            self,
-            size_crops: Tuple[int] = (96, 36),
-            nmb_crops: Tuple[int] = (2, 4),
-            min_scale_crops: Tuple[float] = (0.33, 0.10),
-            max_scale_crops: Tuple[float] = (1, 0.33),
-            gaussian_blur: bool = True,
-            jitter_strength: float = 1.0,
-    ) -> None:
-
-        self.size_crops = size_crops
-        self.num_crops = nmb_crops
-        self.min_scale_crops = min_scale_crops
-        self.max_scale_crops = max_scale_crops
-        self.gaussian_blur = gaussian_blur
-        self.jitter_strength = jitter_strength
-
-    def get(self, mode):
-        if mode == 'train':
-            return swav.SwAVTrainDataTransform(
-                #normalize=self.normalization(),
-                size_crops=self.size_crops,
-                num_crops=self.num_crops,
-                min_scale_crops=self.min_scale_crops,
-                max_scale_crops=self.max_scale_crops,
-                gaussian_blur=self.gaussian_blur,
-                jitter_strength=self.jitter_strength
-            )
-        elif mode == 'val':
-            return swav.SwAVEvalDataTransform(
-                #normalize=self.normalization(),
-                size_crops=self.size_crops,
-                num_crops=self.num_crops,
-                min_scale_crops=self.min_scale_crops,
-                max_scale_crops=self.max_scale_crops,
-                gaussian_blur=self.gaussian_blur,
-                jitter_strength=self.jitter_strength
-            )
-