From 0013a50e923782dc64ee559f1591d088d7b66679 Mon Sep 17 00:00:00 2001
From: Marziyeh <marziyeh.mohammadi@fau.de>
Date: Thu, 14 Dec 2023 12:53:17 +0100
Subject: [PATCH 01/11] Add a HookNet config file, temporary: Just adding a
 config file in utils for setting HookNet parameters.

---
 SSLGlacier/utils/config.py | 236 +++++++++++++++++++++++++++++++++++++
 1 file changed, 236 insertions(+)
 create mode 100644 SSLGlacier/utils/config.py

diff --git a/SSLGlacier/utils/config.py b/SSLGlacier/utils/config.py
new file mode 100644
index 0000000..c0b0021
--- /dev/null
+++ b/SSLGlacier/utils/config.py
@@ -0,0 +1,236 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------'
+
+import os
+import yaml
+from yacs.config import CfgNode as CN
+
+_C = CN()
+
+# Base config files
+_C.BASE = ['']
+
+# -----------------------------------------------------------------------------
+# Data settings
+# -----------------------------------------------------------------------------
+_C.DATA = CN()
+# Batch size for a single GPU, could be overwritten by command line argument
+_C.DATA.BATCH_SIZE = 128
+# Path to dataset, could be overwritten by command line argument
+_C.DATA.DATA_PATH = ''
+# Dataset name
+_C.DATA.DATASET = 'imagenet'
+# Input image size
+_C.DATA.IMG_SIZE = 224
+# _C.DATA.IMG_SIZE = 56
+# Interpolation to resize image (random, bilinear, bicubic)
+_C.DATA.INTERPOLATION = 'bicubic'
+# Use zipped dataset instead of folder dataset
+# could be overwritten by command line argument
+_C.DATA.ZIP_MODE = False
+# Cache Data in Memory, could be overwritten by command line argument
+_C.DATA.CACHE_MODE = 'part'
+# Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
+_C.DATA.PIN_MEMORY = True
+# Number of data loading threads
+_C.DATA.NUM_WORKERS = 8
+
+# -----------------------------------------------------------------------------
+# Model settings
+# -----------------------------------------------------------------------------
+_C.MODEL = CN()
+# Model type
+_C.MODEL.TYPE = 'swin'
+# Model name
+_C.MODEL.NAME = 'swin_tiny_patch4_window7_224'
+# Checkpoint to resume, could be overwritten by command line argument
+# _C.MODEL.PRETRAIN_CKPT = '/home/wf/Swin-Unet-main/output/epoch_150.pth'
+_C.MODEL.PRETRAIN_CKPT = None
+_C.MODEL.RESUME = ''
+# Number of classes, overwritten in data preparation
+_C.MODEL.NUM_CLASSES = 1000
+# Dropout rate
+_C.MODEL.DROP_RATE = 0.0
+# Drop path rate
+_C.MODEL.DROP_PATH_RATE = 0.1
+# Label Smoothing
+_C.MODEL.LABEL_SMOOTHING = 0.1
+
+# Swin Transformer parameters
+_C.MODEL.SWIN = CN()
+_C.MODEL.SWIN.PATCH_SIZE = 4
+# _C.MODEL.SWIN.PATCH_SIZE = 2
+_C.MODEL.SWIN.IN_CHANS = 3
+# _C.MODEL.SWIN.IN_CHANS = 1
+_C.MODEL.SWIN.EMBED_DIM = 96
+_C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
+_C.MODEL.SWIN.DECODER_DEPTHS = [2, 2, 6, 2]
+_C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
+_C.MODEL.SWIN.WINDOW_SIZE = 7
+# _C.MODEL.SWIN.WINDOW_SIZE = 2
+_C.MODEL.SWIN.MLP_RATIO = 4.
+_C.MODEL.SWIN.QKV_BIAS = True +_C.MODEL.SWIN.QK_SCALE = None +_C.MODEL.SWIN.APE = False +_C.MODEL.SWIN.PATCH_NORM = True +_C.MODEL.SWIN.FINAL_UPSAMPLE= "expand_first" + +# ----------------------------------------------------------------------------- +# Training settings +# ----------------------------------------------------------------------------- +_C.TRAIN = CN() +_C.TRAIN.START_EPOCH = 0 +_C.TRAIN.EPOCHS = 300 +_C.TRAIN.WARMUP_EPOCHS = 20 +_C.TRAIN.WEIGHT_DECAY = 0.05 +_C.TRAIN.BASE_LR = 5e-4 +_C.TRAIN.WARMUP_LR = 5e-7 +_C.TRAIN.MIN_LR = 5e-6 +# Clip gradient norm +_C.TRAIN.CLIP_GRAD = 5.0 +# Auto resume from latest checkpoint +_C.TRAIN.AUTO_RESUME = True +# Gradient accumulation steps +# could be overwritten by command line argument +_C.TRAIN.ACCUMULATION_STEPS = 0 +# Whether to use gradient checkpointing to save memory +# could be overwritten by command line argument +_C.TRAIN.USE_CHECKPOINT = False + +# LR scheduler +_C.TRAIN.LR_SCHEDULER = CN() +_C.TRAIN.LR_SCHEDULER.NAME = 'cosine' +# Epoch interval to decay LR, used in StepLRScheduler +_C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30 +# LR decay rate, used in StepLRScheduler +_C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1 + +# Optimizer +_C.TRAIN.OPTIMIZER = CN() +_C.TRAIN.OPTIMIZER.NAME = 'adamw' # - NOT USED +# Optimizer Epsilon +_C.TRAIN.OPTIMIZER.EPS = 1e-8 +# Optimizer Betas +_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999) +# SGD momentum +_C.TRAIN.OPTIMIZER.MOMENTUM = 0.9 + +# ----------------------------------------------------------------------------- +# Augmentation settings - NOT USED +# ----------------------------------------------------------------------------- +_C.AUG = CN() +# Color jitter factor +_C.AUG.COLOR_JITTER = 0.4 +# Use AutoAugment policy. "v0" or "original" +_C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1' +# Random erase prob +_C.AUG.REPROB = 0.25 +# Random erase mode +_C.AUG.REMODE = 'pixel' +# Random erase count +_C.AUG.RECOUNT = 1 +# Mixup alpha, mixup enabled if > 0 +_C.AUG.MIXUP = 0.8 +# Cutmix alpha, cutmix enabled if > 0 +_C.AUG.CUTMIX = 1.0 +# Cutmix min/max ratio, overrides alpha and enables cutmix if set +_C.AUG.CUTMIX_MINMAX = None +# Probability of performing mixup or cutmix when either/both is enabled +_C.AUG.MIXUP_PROB = 1.0 +# Probability of switching to cutmix when both mixup and cutmix enabled +_C.AUG.MIXUP_SWITCH_PROB = 0.5 +# How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem" +_C.AUG.MIXUP_MODE = 'batch' + +# ----------------------------------------------------------------------------- +# Testing settings +# ----------------------------------------------------------------------------- +_C.TEST = CN() +# Whether to use center crop when testing +_C.TEST.CROP = True + +# ----------------------------------------------------------------------------- +# Misc +# ----------------------------------------------------------------------------- +# Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2') +# overwritten by command line argument +_C.AMP_OPT_LEVEL = '' +# Path to output folder, overwritten by command line argument +_C.OUTPUT = '' +# Tag of experiment, overwritten by command line argument +_C.TAG = 'default' +# Frequency to save checkpoint +_C.SAVE_FREQ = 1 +# Frequency to logging info +_C.PRINT_FREQ = 10 +# Fixed random seed +_C.SEED = 0 +# Perform evaluation only, overwritten by command line argument +_C.EVAL_MODE = False +# Test throughput only, overwritten by command line argument +_C.THROUGHPUT_MODE = False +# local rank for DistributedDataParallel, given by command line argument +_C.LOCAL_RANK = 0 + + +# NOT USED FROM HERE ON + +def _update_config_from_file(config, cfg_file): + config.defrost() + with open(cfg_file, 'r') as f: + yaml_cfg = yaml.load(f, Loader=yaml.FullLoader) + + for cfg in yaml_cfg.setdefault('BASE', ['']): + if cfg: + _update_config_from_file( + config, os.path.join(os.path.dirname(cfg_file), cfg) + ) + print('=> merge config from {}'.format(cfg_file)) + config.merge_from_file(cfg_file) + config.freeze() + + +def update_config(config, args): + _update_config_from_file(config, args.cfg) + + config.defrost() + if args.opts: + config.merge_from_list(args.opts) + + # merge from specific arguments + if args.batch_size: + config.DATA.BATCH_SIZE = args.batch_size + if args.zip: + config.DATA.ZIP_MODE = True + if args.cache_mode: + config.DATA.CACHE_MODE = args.cache_mode + if args.resume: + config.MODEL.RESUME = args.resume + if args.accumulation_steps: + config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps + if args.use_checkpoint: + config.TRAIN.USE_CHECKPOINT = True + if args.amp_opt_level: + config.AMP_OPT_LEVEL = args.amp_opt_level + if args.tag: + config.TAG = args.tag + if args.eval: + config.EVAL_MODE = True + if args.throughput: + config.THROUGHPUT_MODE = True + + config.freeze() + + +def get_config(args): + """Get a yacs CfgNode object with default values.""" + # Return a clone so that the defaults will not be altered + # This is for the "local variable" use pattern + config = _C.clone() + update_config(config, args) + + return config -- GitLab From 1ce7ec90c0db9ebbc2214aef59a553dbbd8e2067 Mon Sep 17 00:00:00 2001 From: Marziyeh <marziyeh.mohammadi@fau.de> Date: Thu, 14 Dec 2023 15:20:07 +0100 Subject: [PATCH 02/11] Adjustmentations: Pass parameters as agrument parset instead of reading them from config file. just for puting it in the same harmony as the rest of the code. Otherwise I like the config file idea. 
---
 .../models/hook/Swin_Transformer_Wrapper.py | 49 +++++++++++--------
 1 file changed, 29 insertions(+), 20 deletions(-)

diff --git a/SSLGlacier/models/hook/Swin_Transformer_Wrapper.py b/SSLGlacier/models/hook/Swin_Transformer_Wrapper.py
index 528757d..69e68a9 100644
--- a/SSLGlacier/models/hook/Swin_Transformer_Wrapper.py
+++ b/SSLGlacier/models/hook/Swin_Transformer_Wrapper.py
@@ -10,32 +10,41 @@ import torch
 import torch.nn as nn
 
 from SSLGlacier.models.hook.Swin_Transformer import SwinTransformerSys
-
+from SSLGlacier.models.base_net import Net
 logger = logging.getLogger(__name__)
 
 
-class SwinUnet(nn.Module):
-    def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
-        super(SwinUnet, self).__init__()
+class SwinUnet(Net):
+    def __init__(self, num_classes=21843, zero_head=False, **kwargs):
+        super().__init__(
+            groups=1,
+            widen=1,
+            width_per_group=64,
+            norm_layer=None,
+            output_dim=0 if 'output_dim' not in kwargs else kwargs['output_dim'],
+            hidden_mlp=0 if 'hidden_mlp' not in kwargs else kwargs['hidden_mlp'],
+            num_prototypes=0 if 'num_prototypes' not in kwargs else kwargs['num_prototypes'],
+            eval_mode=False,
+            normalize=False if 'normalize' not in kwargs else kwargs['normalize']
+        )
+
         self.num_classes = num_classes
         self.zero_head = zero_head
-        self.config = config
 
-        self.swin_unet = SwinTransformerSys(img_size=config.DATA.IMG_SIZE,
-                                            patch_size=config.MODEL.SWIN.PATCH_SIZE,
-                                            in_chans=config.MODEL.SWIN.IN_CHANS,
+        self.swin_unet = SwinTransformerSys(img_size=kwargs.image_size,
+                                            patch_size=kwargs.swin_patch_size,
+                                            in_chans=kwargs.swin_in_chans,
                                             num_classes=self.num_classes,
-                                            embed_dim=config.MODEL.SWIN.EMBED_DIM,
-                                            depths=config.MODEL.SWIN.DEPTHS,
-                                            num_heads=config.MODEL.SWIN.NUM_HEADS,
-                                            window_size=config.MODEL.SWIN.WINDOW_SIZE,
-                                            mlp_ratio=config.MODEL.SWIN.MLP_RATIO,
-                                            qkv_bias=config.MODEL.SWIN.QKV_BIAS,
-                                            qk_scale=config.MODEL.SWIN.QK_SCALE,
-                                            drop_rate=config.MODEL.DROP_RATE,
-                                            drop_path_rate=config.MODEL.DROP_PATH_RATE,
-                                            ape=config.MODEL.SWIN.APE,
-                                            patch_norm=config.MODEL.SWIN.PATCH_NORM,
-                                            use_checkpoint=config.TRAIN.USE_CHECKPOINT)
+                                            embed_dim=kwargs.swin_embed_dim,
+                                            depths=kwargs.swin_depths,
+                                            num_heads=kwargs.swin_num_heads,
+                                            window_size=kwargs.swin_window_size,
+                                            mlp_ratio=kwargs.swin_mlp_ratio,
+                                            qkv_bias=kwargs.swin.QKV_BIAS,
+                                            qk_scale=kwargs.swin.QK_SCALE,
+                                            drop_rate=kwargs.drop_rate,
+                                            drop_path_rate=kwargs.drop_path_rate,
+                                            ape=kwargs.swin_ape,
+                                            patch_norm=kwargs.swin_patch_norm)
 
     def forward(self, x, y):
         if x.size()[1] == 1:
-- 
GitLab


From dd971ad725aa9f886e6a7c5e5bb70b1809732fd3 Mon Sep 17 00:00:00 2001
From: Marziyeh <marziyeh.mohammadi@fau.de>
Date: Thu, 14 Dec 2023 15:21:17 +0100
Subject: [PATCH 03/11] Adjustments for HookNet: Add another argument parser
 just for the swin net.
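
One caveat with the parser sketched below: argparse treats names without a
leading '--' as required positional arguments, so their default= values are
never applied, and a bare parse_args() will also re-read sys.argv. A hedged
sketch of the optional-flag variant this would need (flag names assumed, not
taken from the code):

    from argparse import ArgumentParser

    swin_parser = ArgumentParser()
    swin_parser.add_argument('--swin_patch_size', type=int, default=4)
    swin_parser.add_argument('--swin_embed_dim', type=int, default=96)
    # parse an empty list so sys.argv (already consumed by the main parser)
    # is left alone and the defaults are used
    swin_hparams = swin_parser.parse_args([])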
---
 SSLGlacier/main.py | 54 ++++++++++++++++++++++++++++++++++++------
 1 file changed, 47 insertions(+), 7 deletions(-)

diff --git a/SSLGlacier/main.py b/SSLGlacier/main.py
index 135ed7a..61fc68d 100644
--- a/SSLGlacier/main.py
+++ b/SSLGlacier/main.py
@@ -16,7 +16,9 @@ import numpy as np
 
 from utils.utils import str2bool, getDataPath
 
-def main(hparams):
+
+
+def main(hparams, swin_hparams):
     # TODO Ask NORA, Why do you use clip gradient? did you test it?
     ######save check points##########
     checkpoint_dirs = os.path.join('checkpoints',
                                    f'ssl_zone_segmentation_{hparams.arch}')
@@ -78,7 +80,8 @@ def main(hparams):
                               max_pool1d=False,
                               learning_rate=hparams.base_lr,
                               just_aug_for_same_assign_views=hparams.just_aug_for_same_assign_views,
-                              crops_for_assign=hparams.views_for_assign
+                              crops_for_assign=hparams.views_for_assign,
+                              hparams=swin_hparams
                               )
 
     online_evaluator = SSLOnlineEvaluator(
@@ -139,9 +142,9 @@ if __name__ == '__main__':
     #########################
     # Which agumented image should be used for assignment
     parser.add_argument("--views_for_assign", type=int, nargs="+", default=[0, 1],
-                        help="list of agumented views id used for computing assignments")
+                        help="list of augmented views id used for computing assignments")
     parser.add_argument("--just_aug_for_same_assign_views", type=str2bool, default=True,
-                        help="If you want to consider agumented version of the assigned view for swap assignment.")
+                        help="If you want to consider augmented version of the assigned view for swap assignment.")
     parser.add_argument("--temperature", default=0.1, type=float,
                         help="temperature parameter in training loss")
     parser.add_argument("--epsilons", default=0.05, type=float,
@@ -200,8 +203,45 @@ if __name__ == '__main__':
                                  'noise': 0, 'color_jitter': 0, 'gauss_blur': 0,
                                  'salt_or_pepper': 0, 'poission': 0,
                                  'speckle': 1, 'zero': 0, 'normalize': None})
-    parser.add_argument('--i_want_swav', default=False, type=str2bool, help='Do you wants Swav augmentaion or ours, '
-                                                                            'other papers agumentaions methods will be added later.')
+    parser.add_argument('--i_want_swav', default=False, type=str2bool, help='Do you want Swav augmentaion or ours, '
+                                                                            'other papers agumentaions methods will be added later.')
 
     hparams = parser.parse_args()
-    main(hparams)
+    ###########Swin transformer parameters########
+    #TODO Config file: is a good idea,maybe rafactor the code later....
+    #parameters set in config file, #
+    # Swin Transformer parameters
+    # _C.MODEL.SWIN.PATCH_SIZE = 4
+    # # _C.MODEL.SWIN.PATCH_SIZE = 2
+    # _C.MODEL.SWIN.IN_CHANS = 3
+    # # _C.MODEL.SWIN.IN_CHANS = 1
+    # _C.MODEL.SWIN.EMBED_DIM = 96
+    # _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
+    # _C.MODEL.SWIN.DECODER_DEPTHS = [2, 2, 6, 2]
+    # _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
+    # _C.MODEL.SWIN.WINDOW_SIZE = 7
+    # # _C.MODEL.SWIN.WINDOW_SIZE = 2
+    # _C.MODEL.SWIN.MLP_RATIO = 4.
+    # _C.MODEL.SWIN.QKV_BIAS = True
+    # _C.MODEL.SWIN.QK_SCALE = None
+    # _C.MODEL.SWIN.APE = False
+    # _C.MODEL.SWIN.PATCH_NORM = True
+    # _C.MODEL.SWIN.FINAL_UPSAMPLE = "expand_first"
+    ##############################################
+    swin_parser =ArgumentParser()
+    swin_parser.add_argument('hidden_mlp')
+    swin_parser.add_argument('swin_patch_size', default = 4)
+    swin_parser.add_argument('swin_in_chans', default = 1)
+    swin_parser.add_argument('swin_embed_dim', default = 96)
+    swin_parser.add_argument('swin_depths', default = [2, 2, 6, 2])
+    swin_parser.add_argument('swin_num_heads', default = [3, 6, 12, 24])
+    swin_parser.add_argument('swin_window_size', default = 2)
+    swin_parser.add_argument('swin_mlp_ratio', default = 4.0)
+    swin_parser.add_argument('swin_QKV_BIAS', default = True)
+    swin_parser.add_argument('swin_QK_SCALE', default = None)
+    swin_parser.add_argument('swin_ape', default = False)
+    swin_parser.add_argument('swin_path_norm', default = True)
+    swin_hparams = swin_parser.parse_args()
+
+
+    main(hparams, swin_hparams)
-- 
GitLab


From 7cb80b33e6decea5368fe7017b56bbc293f38f65 Mon Sep 17 00:00:00 2001
From: Marziyeh <marziyeh.mohammadi@fau.de>
Date: Thu, 14 Dec 2023 15:22:28 +0100
Subject: [PATCH 04/11] Adjustments for HookNet: Add the swin net to the list
 of models in the init-model section.
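
Note on the wiring below: SwinUnet collects these keyword arguments in
**kwargs, which is a plain dict, so attribute-style reads such as
kwargs.swin_patch_size would raise AttributeError. A small sketch of the
dict-style lookup the wrapper would need (the default values here are only
illustrative):

    # inside SwinUnet.__init__(self, num_classes=21843, zero_head=False, **kwargs)
    patch_size = kwargs.get('swin_patch_size', 4)   # dict lookup, not kwargs.swin_patch_size
    qkv_bias = kwargs.get('swin_QKV_BIAS', True)
    qk_scale = kwargs.get('swin_QK_SCALE', None)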
---
 SSLGlacier/processing/swav_module.py | 32 +++++++++++++++++++++++-----
 1 file changed, 27 insertions(+), 5 deletions(-)

diff --git a/SSLGlacier/processing/swav_module.py b/SSLGlacier/processing/swav_module.py
index 970484b..3b57a4e 100644
--- a/SSLGlacier/processing/swav_module.py
+++ b/SSLGlacier/processing/swav_module.py
@@ -14,9 +14,11 @@ from SSLGlacier.models.RESNET import resnet18, resnet50
 from SSLGlacier.models.UNET import UNet #The main network
 from pl_bolts.optimizers.lars import LARS
 from pl_bolts.optimizers.lr_scheduler import linear_warmup_decay
-#from SSLGlacier.models.hook.Swin_Transformer_Wrapper import SwinUnet
-
+##########Importing files for hook############
+from SSLGlacier.models.hook.Swin_Transformer_Wrapper import SwinUnet
+######Temporary config file for hooknet#######
+from utils.config import get_config
 
 class SwAV(LightningModule):
     def __init__(
@@ -51,7 +53,7 @@ class SwAV(LightningModule):
         weight_decay: float = 1e-6,
         epsilon: float = 0.05,
         just_aug_for_same_assign_views: bool = False,
-        **kwargs
+        swin_hparams = {}
     ) -> None:
         """
         Args:
@@ -148,6 +150,8 @@ class SwAV(LightningModule):
         self.train_iters_per_epoch = self.num_samples // global_batch_size
         self.queue = None
 
+        ####For hook####
+        self.swin_hparams = swin_hparams
     def setup(self, stage):
         if self.queue_length > 0:
             queue_folder = os.path.join(self.logger.log_dir, self.queue_path)
@@ -176,12 +180,30 @@ class SwAV(LightningModule):
                                 first_conv=self.first_conv,
                                 maxpool1=self.maxpool1)
         elif self.arch == "hook":
-            backbone = resnet50(normalize=True,
+            backbone = SwinUnet(
+                img_size=224,
+                num_classes=5,
+                normalize=True,
                                 hidden_mlp=self.hidden_mlp,
                                 output_dim=self.feat_dim,
                                 num_prototypes=self.num_prototypes,
                                 first_conv=self.first_conv,
-                                maxpool1=self.maxpool1)
+                                maxpool1=self.maxpool1,######till here for basenet
+                                image_size= self.swin_hparams.image_size,####from here for swin itself
+                                swin_patch_size = self.swin_hparams.swin_patch_size,
+                                swin_in_chans=self.swin_hparams.swin_in_chans,
+                                swin_embed_dim=self.swin_hparams.swin_embed_dim,
+                                swin_depths=self.swin_hparams.swin_depths,
+                                swin_num_heads=self.swin_hparams.swin_num_heads,
+                                swin_window_size=self.swin_hparams.swin_window_size,
+                                swin_mlp_ratio=self.swin_hparams.swin_mlp_ratio,
+                                swin_QKV_BIAS=self.swin_hparams.swin_QKV_BIAS,
+                                swin_QK_SCALE=self.swin_hparams.swin_QK_SCALE,
+                                drop_rate=self.drop_rate,
+                                drop_path_rate=self.drop_path_rate,
+                                swin_ape=self.swin_hparams.swin_ape,
+                                swin_patch_norm=self.swin_hparams.swin_path_norm
+                                )
 
         elif self.arch == "Unet" or "unet":
             backbone = UNet(num_classes=5,
-- 
GitLab


From 8ab559ce004953838e241f3a8988b3b4a9c52b04 Mon Sep 17 00:00:00 2001
From: Marziyeh <marziyeh.mohammadi@fau.de>
Date: Thu, 14 Dec 2023 17:43:22 +0100
Subject: [PATCH 05/11] Adjustments: Pass parameters via an argument parser
 instead of reading them from the config file, just to keep it in harmony
 with the rest of the code. Otherwise I like the config file idea.
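
A detail on the list-valued options below: type=list[int] is not a usable
argparse converter (applied to a command-line string it would split it into
single characters rather than parse integers). The usual idiom is nargs='+'
with an element type; a sketch, assuming flag-style names:

    swin_parser.add_argument('--swin_depths', type=int, nargs='+',
                             default=[2, 2, 6, 2],
                             help='Depth of each layer in the Transformer encoder.')
    swin_parser.add_argument('--swin_num_heads', type=int, nargs='+',
                             default=[3, 6, 12, 24],
                             help='Number of attention heads in each layer.')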
---
 SSLGlacier/main.py | 48 +++++++++++++++++++++++++++++++++++-----------
 1 file changed, 37 insertions(+), 11 deletions(-)

diff --git a/SSLGlacier/main.py b/SSLGlacier/main.py
index 61fc68d..73d9277 100644
--- a/SSLGlacier/main.py
+++ b/SSLGlacier/main.py
@@ -15,6 +15,7 @@ from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
 import numpy as np
 
 from utils.utils import str2bool, getDataPath
+from typing import *
 
 
 
@@ -230,17 +231,42 @@
     ##############################################
     swin_parser =ArgumentParser()
     swin_parser.add_argument('hidden_mlp')
-    swin_parser.add_argument('swin_patch_size', default = 4)
+    swin_parser.add_argument('swin_patch_size', default = 4,
+                             type=int,
+                             help='The size (resolution) of each patch.')
-    swin_parser.add_argument('swin_in_chans', default = 1)
+    swin_parser.add_argument('swin_in_chans', default = 1,
+                             )
-    swin_parser.add_argument('swin_embed_dim', default = 96)
+    swin_parser.add_argument('swin_embed_dim', default = 96,
+                             type=int,
+                             help='Dimensionality of patch embedding.')
-    swin_parser.add_argument('swin_depths', default = [2, 2, 6, 2])
+    swin_parser.add_argument('swin_depths',
+                             default= [2, 2, 6, 2],
+                             type=list[int],
+                             help="Depth of each layer in the Transformer encoder."
+                             )
-    swin_parser.add_argument('swin_num_heads', default = [3, 6, 12, 24])
+    swin_parser.add_argument('swin_num_heads',
+                             default = [3, 6, 12, 24],
+                             type=list[int],
+                             help='Number of attention heads in each layer of the Transformer encoder.')
-    swin_parser.add_argument('swin_window_size', default = 2)
+    swin_parser.add_argument('swin_window_size',
+                             default=2,
+                             type=int,
+                             help= 'Size of windows')
-    swin_parser.add_argument('swin_mlp_ratio', default = 4.0)
+    swin_parser.add_argument('swin_mlp_ratio',
+                             default= 4.0,type=float,
+                             help='Ratio of MLP hidden dimensionality to embedding dimensionality.')
-    swin_parser.add_argument('swin_QKV_BIAS', default = True)
+    swin_parser.add_argument('swin_QKV_BIAS', default = True,
+                             type=str2bool,
+                             help='Whether or not a learnable bias should be added to the queries, keys and values.')
-    swin_parser.add_argument('swin_QK_SCALE', default = None)
+    swin_parser.add_argument('swin_QK_SCALE', default = None,
+                             type=float,
+                             help='Override default qk scale of head_dim ** -0.5 if set.')
-    swin_parser.add_argument('swin_ape', default = False)
+    swin_parser.add_argument('swin_ape', default = False,
+                             type=str2bool,
+                             help="If True, add absolute position embedding to the patch embedding.")
-    swin_parser.add_argument('swin_path_norm', default = True)
+    swin_parser.add_argument('swin_path_norm', default = True,
+                             type=str2bool,
+                             help="If True, add normalization after patch embedding.")
     swin_hparams = swin_parser.parse_args()
-- 
GitLab


From 3a920a15e2bd766ee7cf73293d1fd7d583c84d73 Mon Sep 17 00:00:00 2001
From: Marziyeh <marziyeh.mohammadi@fau.de>
Date: Thu, 14 Dec 2023 17:44:36 +0100
Subject: [PATCH 06/11] Adjustments for HookNet: Add explanations to patch
 expand: what it is and why it is there.

---
 SSLGlacier/models/hook/patch_expand.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/SSLGlacier/models/hook/patch_expand.py b/SSLGlacier/models/hook/patch_expand.py
index ef397f8..b253c80 100644
--- a/SSLGlacier/models/hook/patch_expand.py
+++ b/SSLGlacier/models/hook/patch_expand.py
@@ -4,6 +4,16 @@ from einops import rearrange
 
 
 class PatchExpand(nn.Module):
+    """
+    from:https://link.springer.com/chapter/10.1007/978-3-031-25066-8_9
+    Take the first patch expanding layer as an example, before up-sampling,
+    a linear layer is applied on the input features (W/32, H/32, 8C) to increase the feature dimension to
+    2 * the original dimension (W/32, H/32, 16C). Then, we use rearrange operation to expand the resolution
+    of the input features to 2*the input resolution and reduce the feature dimension to
+    quarter of the input dimension ((W/32, H/32, 16C) --> (W/16, H/16, 4C)).
+    We will discuss the impact of using patch expanding layer
+    to perform up-sampling in Sect. 4.5 in https://link.springer.com/chapter/10.1007/978-3-031-25066-8_9.
+    """
     def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
         super().__init__()
         self.input_resolution = input_resolution
-- 
GitLab


From d78eac5a82771edcc08a26d042b3093cdc0261b6 Mon Sep 17 00:00:00 2001
From: Marziyeh <marziyeh.mohammadi@fau.de>
Date: Thu, 14 Dec 2023 17:45:43 +0100
Subject: [PATCH 07/11] Import problem: Solve import error by changing the
 import path.
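
Imports rooted at the package (SSLGlacier.models.hook.*) resolve regardless
of the working directory the training script is launched from, whereas the
old models.* form only worked when SSLGlacier itself was on sys.path.
Assuming the repository root is importable, for example:

    # resolves from any working directory once the repo root is on sys.path
    from SSLGlacier.models.hook.mlp import Mlp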
---
 SSLGlacier/models/hook/swin_transformer_block.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/SSLGlacier/models/hook/swin_transformer_block.py b/SSLGlacier/models/hook/swin_transformer_block.py
index dd8938c..52fa6dc 100644
--- a/SSLGlacier/models/hook/swin_transformer_block.py
+++ b/SSLGlacier/models/hook/swin_transformer_block.py
@@ -1,9 +1,9 @@
 import torch
 import torch.nn as nn
 from timm.models.layers import DropPath, to_2tuple
-from models.window_attention import WindowAttention
-from models.mlp import Mlp
-from models.window_utils import window_partition, window_reverse
+from SSLGlacier.models.hook.window_attention import WindowAttention
+from SSLGlacier.models.hook.mlp import Mlp
+from SSLGlacier.models.hook.window_utils import window_partition, window_reverse
 
 
 class SwinTransformerBlock(nn.Module):
-- 
GitLab


From 15b8f2c7c0a187162bf95e6983102571e3f6cb9a Mon Sep 17 00:00:00 2001
From: Marziyeh <marziyeh.mohammadi@fau.de>
Date: Fri, 15 Dec 2023 09:34:28 +0100
Subject: [PATCH 08/11] Refactor the code: Reduce the number of files, put
 related classes in the same file, and make things less clumsy, at least at
 the directory level.

---
 SSLGlacier/models/hook/Swin_Transformer.py    |  10 +-
 SSLGlacier/models/hook/basic_layers.py        | 229 ++++++++++++++
 SSLGlacier/models/hook/basic_swin_layer.py    |  74 -----
 .../models/hook/basic_swin_layer_cross.py     |  74 -----
 SSLGlacier/models/hook/basic_swin_layer_up.py |  63 ----
 SSLGlacier/models/hook/mlp.py                 |  20 --
 SSLGlacier/models/hook/patch_embedding.py     |  51 ---
 SSLGlacier/models/hook/patch_merging.py       |  51 ---
 .../{patch_expand.py => patch_processing.py}  | 123 +++++++-
 .../models/hook/swin_transformer_block.py     | 136 --------
 .../hook/swin_transformer_block_cross.py      | 151 ---------
 .../models/hook/swin_transformer_blocks.py    | 296 ++++++++++++++++++
 SSLGlacier/models/hook/window_attention.py    | 104 ------
 SSLGlacier/models/hook/window_attentions.py   | 240 ++++++++++++++
 .../models/hook/window_cross_attention.py     | 106 -------
 SSLGlacier/models/hook/window_utils.py        |  34 --
 16 files changed, 886 insertions(+), 876 deletions(-)
 create mode 100644 SSLGlacier/models/hook/basic_layers.py
 delete mode 100644 SSLGlacier/models/hook/basic_swin_layer.py
 delete mode 100644 SSLGlacier/models/hook/basic_swin_layer_cross.py
 delete mode 100644 SSLGlacier/models/hook/basic_swin_layer_up.py
 delete mode 100644 SSLGlacier/models/hook/mlp.py
 delete mode 100644 SSLGlacier/models/hook/patch_embedding.py
 delete mode 100644 SSLGlacier/models/hook/patch_merging.py
 rename SSLGlacier/models/hook/{patch_expand.py => patch_processing.py} (56%)
 delete mode 100644 SSLGlacier/models/hook/swin_transformer_block.py
 delete mode 100644 SSLGlacier/models/hook/swin_transformer_block_cross.py
 create mode 100644 SSLGlacier/models/hook/swin_transformer_blocks.py
 delete mode 100644 SSLGlacier/models/hook/window_attention.py
 create mode 100644 SSLGlacier/models/hook/window_attentions.py
 delete mode 100644 SSLGlacier/models/hook/window_cross_attention.py
 delete mode 100644 SSLGlacier/models/hook/window_utils.py

diff --git a/SSLGlacier/models/hook/Swin_Transformer.py b/SSLGlacier/models/hook/Swin_Transformer.py
index 992126a..1315bfd 100644
--- a/SSLGlacier/models/hook/Swin_Transformer.py
+++ b/SSLGlacier/models/hook/Swin_Transformer.py
@@ -2,12 +2,10 @@ import torch
 import torch.nn as nn
 import numpy as np
 from timm.models.layers import trunc_normal_
-from SSLGlacier.models.hook.patch_embedding import PatchEmbed
-from SSLGlacier.models.hook.patch_merging import PatchMerging
-from 
SSLGlacier.models.hook.patch_expand import PatchExpand, PatchExpandC, FinalPatchExpand_X4 -from SSLGlacier.models.hook.basic_swin_layer import BasicLayer -from SSLGlacier.models.hook.basic_swin_layer_up import BasicLayer_up -from SSLGlacier.models.hook.basic_swin_layer_cross import BasicLayer_Cross +from SSLGlacier.models.hook.patch_processing import PatchEmbed +from SSLGlacier.models.hook.patch_processing import PatchMerging +from SSLGlacier.models.hook.patch_processing import PatchExpand, PatchExpandC, FinalPatchExpand_X4 +from SSLGlacier.models.hook.basic_layers import BasicLayer, BasicLayer_up, BasicLayer_Cross torch.set_printoptions(threshold=np.inf) torch.set_printoptions(linewidth=300) diff --git a/SSLGlacier/models/hook/basic_layers.py b/SSLGlacier/models/hook/basic_layers.py new file mode 100644 index 0000000..079d1cc --- /dev/null +++ b/SSLGlacier/models/hook/basic_layers.py @@ -0,0 +1,229 @@ +import torch.nn as nn +import torch.utils.checkpoint as checkpoint +from SSLGlacier.models.hook.patch_processing import PatchExpand +from SSLGlacier.models.hook.swin_transformer_blocks import SwinTransformerBlock, SwinTransformerBlock_Cross +#----------------------------------------------- +########### Basic Swin Layer Classes ########### +#----------------------------------------------- +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim # 96 + self.input_resolution = input_resolution # 56 + self.depth = depth # 2 + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, # 96, 56 + num_heads=num_heads, window_size=window_size, # 3, 14 + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, # 4 + qkv_bias=qkv_bias, qk_scale=qk_scale, # True, None + drop=drop, attn_drop=attn_drop, # 0.2, 0 + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) # layer-norm + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops +#----------------------------------------------- +####### End of Basic Swin Layer Classes ######## +#----------------------------------------------- + +################################################ +################################################ + +#----------------------------------------------- +####### Basic Swin Layer cross Classes ######### +#----------------------------------------------- + +class BasicLayer_Cross(nn.Module): # Input of second attention is output of first. See paper + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim # 96 + self.input_resolution = input_resolution # 56 + self.depth = depth # 2 + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock_Cross(dim=dim, input_resolution=input_resolution, # 96, 56 + num_heads=num_heads, window_size=window_size, # 3, 14 + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, # 4 + qkv_bias=qkv_bias, qk_scale=qk_scale, # True, None + drop=drop, attn_drop=attn_drop, # 0.2, 0 + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) # layer-norm + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, y): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x, y) + if self.downsample is not None: + x = self.downsample(x) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops +#----------------------------------------------- +##### End of Basic Swin Layer cross Classes #### +#----------------------------------------------- + +################################################ +################################################ + + +#----------------------------------------------- +######### Basic Swin Layer UP Classes ########## +#----------------------------------------------- + +class BasicLayer_up(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if upsample is not None: + self.upsample = PatchExpand(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer) + else: + self.upsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.upsample is not None: + x = self.upsample(x) + return x + +#----------------------------------------------- +##### End of Basic Swin Layer UP Classes #### +#----------------------------------------------- diff --git a/SSLGlacier/models/hook/basic_swin_layer.py b/SSLGlacier/models/hook/basic_swin_layer.py deleted file mode 100644 index e9570cb..0000000 --- a/SSLGlacier/models/hook/basic_swin_layer.py +++ /dev/null @@ -1,74 +0,0 @@ -import torch.nn as nn -import torch.utils.checkpoint as checkpoint -from SSLGlacier.models.hook.swin_transformer_block import SwinTransformerBlock - - -class BasicLayer(nn.Module): - """ A basic Swin Transformer layer for one stage. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
- """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): - - super().__init__() - self.dim = dim # 96 - self.input_resolution = input_resolution # 56 - self.depth = depth # 2 - self.use_checkpoint = use_checkpoint - - # build blocks - self.blocks = nn.ModuleList([ - SwinTransformerBlock(dim=dim, input_resolution=input_resolution, # 96, 56 - num_heads=num_heads, window_size=window_size, # 3, 14 - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, # 4 - qkv_bias=qkv_bias, qk_scale=qk_scale, # True, None - drop=drop, attn_drop=attn_drop, # 0.2, 0 - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer) # layer-norm - for i in range(depth)]) - - # patch merging layer - if downsample is not None: - self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) - else: - self.downsample = None - - def forward(self, x): - for blk in self.blocks: - if self.use_checkpoint: - x = checkpoint.checkpoint(blk, x) - else: - x = blk(x) - if self.downsample is not None: - x = self.downsample(x) - - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" - - def flops(self): - flops = 0 - for blk in self.blocks: - flops += blk.flops() - if self.downsample is not None: - flops += self.downsample.flops() - return flops \ No newline at end of file diff --git a/SSLGlacier/models/hook/basic_swin_layer_cross.py b/SSLGlacier/models/hook/basic_swin_layer_cross.py deleted file mode 100644 index e9481d9..0000000 --- a/SSLGlacier/models/hook/basic_swin_layer_cross.py +++ /dev/null @@ -1,74 +0,0 @@ -import torch.nn as nn -import torch.utils.checkpoint as checkpoint -from SSLGlacier.models.hook.swin_transformer_block_cross import SwinTransformerBlock_Cross - - -class BasicLayer_Cross(nn.Module): # Input of second attention is output of first. See paper - """ A basic Swin Transformer layer for one stage. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
- """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): - - super().__init__() - self.dim = dim # 96 - self.input_resolution = input_resolution # 56 - self.depth = depth # 2 - self.use_checkpoint = use_checkpoint - - # build blocks - self.blocks = nn.ModuleList([ - SwinTransformerBlock_Cross(dim=dim, input_resolution=input_resolution, # 96, 56 - num_heads=num_heads, window_size=window_size, # 3, 14 - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, # 4 - qkv_bias=qkv_bias, qk_scale=qk_scale, # True, None - drop=drop, attn_drop=attn_drop, # 0.2, 0 - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer) # layer-norm - for i in range(depth)]) - - # patch merging layer - if downsample is not None: - self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) - else: - self.downsample = None - - def forward(self, x, y): - for blk in self.blocks: - if self.use_checkpoint: - x = checkpoint.checkpoint(blk, x) - else: - x = blk(x, y) - if self.downsample is not None: - x = self.downsample(x) - - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" - - def flops(self): - flops = 0 - for blk in self.blocks: - flops += blk.flops() - if self.downsample is not None: - flops += self.downsample.flops() - return flops diff --git a/SSLGlacier/models/hook/basic_swin_layer_up.py b/SSLGlacier/models/hook/basic_swin_layer_up.py deleted file mode 100644 index 9ee1b87..0000000 --- a/SSLGlacier/models/hook/basic_swin_layer_up.py +++ /dev/null @@ -1,63 +0,0 @@ -import torch.nn as nn -import torch.utils.checkpoint as checkpoint -from SSLGlacier.models.hook.swin_transformer_block import SwinTransformerBlock -from SSLGlacier.models.hook.patch_expand import PatchExpand - - -class BasicLayer_up(nn.Module): - """ A basic Swin Transformer layer for one stage. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
- """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False): - - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.depth = depth - self.use_checkpoint = use_checkpoint - - # build blocks - self.blocks = nn.ModuleList([ - SwinTransformerBlock(dim=dim, input_resolution=input_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer) - for i in range(depth)]) - - # patch merging layer - if upsample is not None: - self.upsample = PatchExpand(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer) - else: - self.upsample = None - - def forward(self, x): - for blk in self.blocks: - if self.use_checkpoint: - x = checkpoint.checkpoint(blk, x) - else: - x = blk(x) - if self.upsample is not None: - x = self.upsample(x) - return x \ No newline at end of file diff --git a/SSLGlacier/models/hook/mlp.py b/SSLGlacier/models/hook/mlp.py deleted file mode 100644 index 6e7730f..0000000 --- a/SSLGlacier/models/hook/mlp.py +++ /dev/null @@ -1,20 +0,0 @@ -import torch.nn as nn - - -class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x \ No newline at end of file diff --git a/SSLGlacier/models/hook/patch_embedding.py b/SSLGlacier/models/hook/patch_embedding.py deleted file mode 100644 index e9fba67..0000000 --- a/SSLGlacier/models/hook/patch_embedding.py +++ /dev/null @@ -1,51 +0,0 @@ -import torch.nn as nn -from timm.models.layers import to_2tuple - - -class PatchEmbed(nn.Module): - r""" Image to Patch Embedding - - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. 
Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size # 224 - self.patch_size = patch_size # 4 - self.patches_resolution = patches_resolution # 56 - self.num_patches = patches_resolution[0] * patches_resolution[1] # 3136 - - self.in_chans = in_chans - self.embed_dim = embed_dim - - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) # 22, 96, 56, 56 - - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - B, C, H, W = x.shape - # FIXME look at relaxing size constraints - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C (22, 3136, 96) - if self.norm is not None: - x = self.norm(x) - return x - - def flops(self): - Ho, Wo = self.patches_resolution - flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) - if self.norm is not None: - flops += Ho * Wo * self.embed_dim - return flops \ No newline at end of file diff --git a/SSLGlacier/models/hook/patch_merging.py b/SSLGlacier/models/hook/patch_merging.py deleted file mode 100644 index 0567504..0000000 --- a/SSLGlacier/models/hook/patch_merging.py +++ /dev/null @@ -1,51 +0,0 @@ -import torch -import torch.nn as nn - - -class PatchMerging(nn.Module): - r""" Patch Merging Layer. - - Args: - input_resolution (tuple[int]): Resolution of input feature. - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.input_resolution = input_resolution - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(4 * dim) - - def forward(self, x): - """ - x: B, H*W, C - """ - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 
- - x = x.view(B, H, W, C) - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - - x = self.norm(x) - x = self.reduction(x) - - return x - - def extra_repr(self) -> str: - return f"input_resolution={self.input_resolution}, dim={self.dim}" - - def flops(self): - H, W = self.input_resolution - flops = H * W * self.dim - flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim - return flops - diff --git a/SSLGlacier/models/hook/patch_expand.py b/SSLGlacier/models/hook/patch_processing.py similarity index 56% rename from SSLGlacier/models/hook/patch_expand.py rename to SSLGlacier/models/hook/patch_processing.py index b253c80..d7e1f33 100644 --- a/SSLGlacier/models/hook/patch_expand.py +++ b/SSLGlacier/models/hook/patch_processing.py @@ -1,8 +1,124 @@ +import torch import torch.nn as nn from einops import rearrange -# import torch +from timm.models.layers import to_2tuple +#----------------------------------------------- +############## Patch Merging Class ########### +#----------------------------------------------- +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + +#----------------------------------------------- +######### End of Patch Merging Class ########### +#----------------------------------------------- + +################################################ +################################################ + +#----------------------------------------------- +############## Patch Embedding Class ########### +#----------------------------------------------- +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size # 224 + self.patch_size = patch_size # 4 + self.patches_resolution = patches_resolution # 56 + self.num_patches = patches_resolution[0] * patches_resolution[1] # 3136 + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) # 22, 96, 56, 56 + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C (22, 3136, 96) + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops +#----------------------------------------------- +######### End of Patch Embeding Class ########### +#----------------------------------------------- + +################################################ +################################################ + +#----------------------------------------------- +############# Patch Expand Classes ############# +#----------------------------------------------- class PatchExpand(nn.Module): """ from:https://link.springer.com/chapter/10.1007/978-3-031-25066-8_9 @@ -36,7 +152,6 @@ class PatchExpand(nn.Module): return x - # class PatchExpand4(nn.Module): # def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm): # super().__init__() @@ -59,8 +174,6 @@ class PatchExpand(nn.Module): # x = self.norm(x) # # return x - - class PatchExpandC(nn.Module): def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm): super().__init__() @@ -84,7 +197,6 @@ class PatchExpandC(nn.Module): return x - # class PCAa(object): # def __init__(self, n_components=2): # self.n_components = n_components @@ -136,7 +248,6 @@ class PatchExpandC(nn.Module): # # return x - class FinalPatchExpand_X4(nn.Module): def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm): super().__init__() diff --git a/SSLGlacier/models/hook/swin_transformer_block.py b/SSLGlacier/models/hook/swin_transformer_block.py deleted file mode 100644 index 52fa6dc..0000000 --- a/SSLGlacier/models/hook/swin_transformer_block.py +++ /dev/null @@ -1,136 +0,0 @@ -import torch -import torch.nn as nn -from timm.models.layers import DropPath, to_2tuple -from SSLGlacier.models.hook.window_attention import WindowAttention -from SSLGlacier.models.hook.mlp import Mlp -from SSLGlacier.models.hook.window_utils import window_partition, window_reverse - - -class SwinTransformerBlock(nn.Module): - r""" Swin Transformer Block. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resulotion. - num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim # 96 - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size # 0 or 7 - self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = WindowAttention( - dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) - - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - if self.shift_size > 0: - # calculate attention mask for SW-MSA - H, W = self.input_resolution # 56, 56 - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1; 1, 56, 56, 1 - h_slices = (slice(0, -self.window_size), # 0, -14 - slice(-self.window_size, -self.shift_size), # -14, -7 - slice(-self.shift_size, None)) # -7, None - - w_slices = (slice(0, -self.window_size), # 0, -14 - slice(-self.window_size, -self.shift_size), # -14, -7 - slice(-self.shift_size, None)) # -7, None - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, - self.window_size) # nW, window_size, window_size, 1; 64, 7, 7, 1 # 16, 14, 14, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) # 64, 49 # 16, 196 - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # 64, 1, 49 - 64, 49, 1 - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float( - 0.0)) # 64, 49, 49 # 16, 196, 196 - else: - attn_mask = None - - self.register_buffer("attn_mask", attn_mask) - - def forward(self, x): - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - - shortcut = x - x = self.norm1(x) - x = x.view(B, H, W, C) - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - else: - shifted_x = x - - # partition windows - x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C; 448, 7, 7, 96 - x_windows = x_windows.view(-1, self.window_size * self.window_size, - C) # nW*B, window_size*window_size, C; 448, 49, 96 # 7, 4, 4, 196, 96 - - # W-MSA/SW-MSA - attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C; 448, 49, 96 - 
- # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) # 448, 7, 7, 96 - shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C; 7, 56, 56, 96 - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - else: - x = shifted_x - - x = x.view(B, H * W, C) # 7, 56*56, 96 - - # FFN - x = shortcut + self.drop_path(x) - x = x + self.drop_path(self.mlp(self.norm2(x))) - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ - f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" - - def flops(self): - flops = 0 - H, W = self.input_resolution - # norm1 - flops += self.dim * H * W - # W-MSA/SW-MSA - nW = H * W / self.window_size / self.window_size - flops += nW * self.attn.flops(self.window_size * self.window_size) - # mlp - flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio - # norm2 - flops += self.dim * H * W - return flops diff --git a/SSLGlacier/models/hook/swin_transformer_block_cross.py b/SSLGlacier/models/hook/swin_transformer_block_cross.py deleted file mode 100644 index 8dbfb67..0000000 --- a/SSLGlacier/models/hook/swin_transformer_block_cross.py +++ /dev/null @@ -1,151 +0,0 @@ -import torch -import torch.nn as nn -from timm.models.layers import DropPath, to_2tuple -from models.window_cross_attention import WindowAttention_Cross -from models.mlp import Mlp -from models.window_utils import window_partition, window_reverse - - -class SwinTransformerBlock_Cross(nn.Module): - r""" Swin Transformer Block. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resulotion. - num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim # 96 - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size # 0 or 7 - self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = WindowAttention_Cross( - dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) - - self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - if self.shift_size > 0: - # calculate attention mask for SW-MSA - H, W = self.input_resolution # 56, 56 - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1; 1, 56, 56, 1 - h_slices = (slice(0, -self.window_size), # 0, -14 - slice(-self.window_size, -self.shift_size), # -14, -7 - slice(-self.shift_size, None)) # -7, None - - w_slices = (slice(0, -self.window_size), # 0, -14 - slice(-self.window_size, -self.shift_size), # -14, -7 - slice(-self.shift_size, None)) # -7, None - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, - self.window_size) # nW, window_size, window_size, 1; 64, 7, 7, 1 # 16, 14, 14, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) # 64, 49 # 16, 196 - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # 64, 1, 49 - 64, 49, 1 - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float( - 0.0)) # 64, 49, 49 # 16, 196, 196 - else: - attn_mask = None - - self.register_buffer("attn_mask", attn_mask) - - def forward(self, target, context): - H, W = self.input_resolution - B, L, C = target.shape - assert L == H * W, "input feature has wrong size" - - shortcut = target - target = self.norm1(target) - target = target.view(B, H, W, C) - - context = self.norm1(context) - context = context.view(B, H, W, C) - - # cyclic shift - if self.shift_size > 0: - shifted_target = torch.roll(target, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - shifted_context = torch.roll(context, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - else: - shifted_target = target - shifted_context = context - - # partition windows - target_windows = window_partition(shifted_target, - self.window_size) # nW*B, window_size, window_size, C; 448, 7, 7, 96 - context_windows = window_partition(shifted_context, - self.window_size) # nW*B, window_size, window_size, C; 448, 7, 7, 96 - # print(x_windows.shape, 'ss') - target_windows = target_windows.view(-1, self.window_size * self.window_size, - C) # nW*B, window_size*window_size, C; 448, 49, 96 # 7, 4, 4, 196, 96 - - context_windows = context_windows.view(-1, self.window_size * self.window_size, - C) # nW*B, window_size*window_size, C; 448, 49, 96 # 7, 4, 4, 196, 96 - - # W-MSA/SW-MSA - attn_windows = self.attn(target_windows, context_windows, - mask=self.attn_mask) # nW*B, window_size*window_size, C; 448, 49, 96 - - # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) # 448, 7, 7, 96 - shifted_target = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C; 7, 56, 56, 96 - # reverse cyclic shift - if self.shift_size > 0: - target = torch.roll(shifted_target, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - else: - target = shifted_target - - target = target.view(B, H * W, C) # 7, 56*56, 96 - - # FFN - target = shortcut + self.drop_path(target) - target = target + self.drop_path(self.mlp(self.norm2(target))) - - return target - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ - f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" - - def flops(self): - flops = 0 - H, W = self.input_resolution - # norm1 - flops += 
self.dim * H * W - # W-MSA/SW-MSA - nW = H * W / self.window_size / self.window_size - flops += nW * self.attn.flops(self.window_size * self.window_size) - # mlp - flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio - # norm2 - flops += self.dim * H * W - return flops diff --git a/SSLGlacier/models/hook/swin_transformer_blocks.py b/SSLGlacier/models/hook/swin_transformer_blocks.py new file mode 100644 index 0000000..f6745c8 --- /dev/null +++ b/SSLGlacier/models/hook/swin_transformer_blocks.py @@ -0,0 +1,296 @@ +import torch +import torch.nn as nn +from timm.models.layers import DropPath, to_2tuple +from SSLGlacier.models.hook.window_attentions import WindowAttention, WindowAttention_Cross, window_partition, window_reverse + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + +class SwinTransformerBlock_Cross(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim # 96 + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size # 0 or 7 + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention_Cross( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution # 56, 56 + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1; 1, 56, 56, 1 + h_slices = (slice(0, -self.window_size), # 0, -14 + slice(-self.window_size, -self.shift_size), # -14, -7 + slice(-self.shift_size, None)) # -7, None + + w_slices = (slice(0, -self.window_size), # 0, -14 + slice(-self.window_size, -self.shift_size), # -14, -7 + slice(-self.shift_size, None)) # -7, None + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, + self.window_size) # nW, window_size, window_size, 1; 64, 7, 7, 1 # 16, 14, 14, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) # 64, 49 # 16, 196 + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # 64, 1, 49 - 64, 49, 1 + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float( + 0.0)) # 64, 49, 49 # 16, 196, 196 + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def forward(self, target, context): + H, W = self.input_resolution + B, L, C = target.shape + assert L == H * W, "input feature has wrong size" + + shortcut = target + target = self.norm1(target) + target = target.view(B, H, W, C) + + context = self.norm1(context) + context = context.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_target = torch.roll(target, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + shifted_context = torch.roll(context, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_target = target + shifted_context = context + + # partition windows + target_windows = window_partition(shifted_target, + self.window_size) # nW*B, window_size, window_size, C; 448, 7, 7, 96 + context_windows = window_partition(shifted_context, + self.window_size) # nW*B, window_size, window_size, C; 448, 7, 7, 96 + # print(x_windows.shape, 'ss') + target_windows = target_windows.view(-1, self.window_size * self.window_size, + C) # nW*B, window_size*window_size, C; 448, 49, 96 # 7, 4, 4, 196, 96 + + context_windows = context_windows.view(-1, self.window_size * self.window_size, + C) # nW*B, window_size*window_size, C; 448, 49, 96 # 7, 4, 4, 196, 96 + + # W-MSA/SW-MSA + attn_windows = self.attn(target_windows, context_windows, + mask=self.attn_mask) # nW*B, window_size*window_size, C; 448, 49, 96 + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) # 448, 7, 7, 96 + shifted_target = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C; 7, 56, 56, 96 + # reverse cyclic shift + if self.shift_size > 0: + target = torch.roll(shifted_target, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + target = shifted_target + + target = target.view(B, H * W, C) # 7, 56*56, 96 + + # FFN + target = shortcut + self.drop_path(target) + target = target + self.drop_path(self.mlp(self.norm2(target))) + + return target + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += 
self.dim * H * W
+        # W-MSA/SW-MSA
+        nW = H * W / self.window_size / self.window_size
+        flops += nW * self.attn.flops(self.window_size * self.window_size)
+        # mlp
+        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
+        # norm2
+        flops += self.dim * H * W
+        return flops
+
+
+class SwinTransformerBlock(nn.Module):
+    r""" Swin Transformer Block.
+
+    Args:
+        dim (int): Number of input channels.
+        input_resolution (tuple[int]): Input resolution.
+        num_heads (int): Number of attention heads.
+        window_size (int): Window size.
+        shift_size (int): Shift size for SW-MSA.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float, optional): Stochastic depth rate. Default: 0.0
+        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+    """
+
+    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
+                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
+                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.dim = dim  # 96
+        self.input_resolution = input_resolution
+        self.num_heads = num_heads
+        self.window_size = window_size
+        self.shift_size = shift_size  # 0 or 7
+        self.mlp_ratio = mlp_ratio
+        if min(self.input_resolution) <= self.window_size:
+            # if window size is larger than input resolution, we don't partition windows
+            self.shift_size = 0
+            self.window_size = min(self.input_resolution)
+        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
+
+        self.norm1 = norm_layer(dim)
+        self.attn = WindowAttention(
+            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
+            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+
+        self.drop_path = DropPath(drop_path) if drop_path > 0.
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution # 56, 56 + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1; 1, 56, 56, 1 + h_slices = (slice(0, -self.window_size), # 0, -14 + slice(-self.window_size, -self.shift_size), # -14, -7 + slice(-self.shift_size, None)) # -7, None + + w_slices = (slice(0, -self.window_size), # 0, -14 + slice(-self.window_size, -self.shift_size), # -14, -7 + slice(-self.shift_size, None)) # -7, None + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, + self.window_size) # nW, window_size, window_size, 1; 64, 7, 7, 1 # 16, 14, 14, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) # 64, 49 # 16, 196 + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # 64, 1, 49 - 64, 49, 1 + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float( + 0.0)) # 64, 49, 49 # 16, 196, 196 + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C; 448, 7, 7, 96 + x_windows = x_windows.view(-1, self.window_size * self.window_size, + C) # nW*B, window_size*window_size, C; 448, 49, 96 # 7, 4, 4, 196, 96 + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C; 448, 49, 96 + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) # 448, 7, 7, 96 + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C; 7, 56, 56, 96 + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + + x = x.view(B, H * W, C) # 7, 56*56, 96 + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops diff --git a/SSLGlacier/models/hook/window_attention.py b/SSLGlacier/models/hook/window_attention.py deleted file mode 100644 index 39e9be5..0000000 --- a/SSLGlacier/models/hook/window_attention.py +++ /dev/null @@ -1,104 +0,0 @@ -import torch -import torch.nn as nn -from timm.models.layers import trunc_normal_ - - -class WindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module 
with relative position bias. - It supports both of shifted and non-shifted window. - - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. Default: 0.0 - """ - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): - - super().__init__() - self.dim = dim - self.window_size = window_size # Wh, Ww - self.num_heads = num_heads # 3, 6, 12, 24 - head_dim = dim // num_heads # 96 / 3 = 32; 96, 192, 384, 768 - - self.scale = qk_scale or head_dim ** -0.5 # 1/np.sqrt(32) - - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH; 169, 3 - - coords_h = torch.arange(self.window_size[0]) - coords_w = torch.arange(self.window_size[1]) - - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, 7, 7 - - coords_flatten = torch.flatten(coords, 1) # 2, 49 - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, 49, 49 - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # 49, 49, 2 - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 # 14-1 - relative_position_index = relative_coords.sum(-1) # 49, 49 - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) # 96, 192, 384, 768 - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) # 96, 192, 384, 768 - self.proj_drop = nn.Dropout(proj_drop) - - trunc_normal_(self.relative_position_bias_table, std=.02) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask=None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape # -1, 49, 96 - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, - 4) # 448, 49, 3, 3, 32; 3, 448, 3, 49, 32 - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) # 448, 3, 49, 32 - - q = q * self.scale # 1/np.sqrt(32) - attn = (q @ k.transpose(-2, -1)) # 448, 3, 49, 49 - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], - -1) # Wh*Ww,Wh*Ww,nH; 49, 49, 3 - - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww; 3, 49, 49 - attn = attn + relative_position_bias.unsqueeze(0) # 488, 3, 49, 49 - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( - 0) # 7, 64, 3, 49, 49 + 1, 64, 1, 49, 49 - attn = attn.view(-1, self.num_heads, N, N) # 488, 3, 49, 49 - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) # 448, 3, 49, 32; 448, 49, 3, 32; 448, 49, 96 - x = self.proj(x) - x = 
self.proj_drop(x) - return x - - def extra_repr(self) -> str: - return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' - - def flops(self, N): - # calculate flops for 1 window with token length of N - flops = 0 - # qkv = self.qkv(x) - flops += N * self.dim * 3 * self.dim - # attn = (q @ k.transpose(-2, -1)) - flops += self.num_heads * N * (self.dim // self.num_heads) * N - # x = (attn @ v) - flops += self.num_heads * N * N * (self.dim // self.num_heads) - # x = self.proj(x) - flops += N * self.dim * self.dim - return flops diff --git a/SSLGlacier/models/hook/window_attentions.py b/SSLGlacier/models/hook/window_attentions.py new file mode 100644 index 0000000..4225479 --- /dev/null +++ b/SSLGlacier/models/hook/window_attentions.py @@ -0,0 +1,240 @@ +import torch +import torch.nn as nn +from timm.models.layers import trunc_normal_ + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads # 3, 6, 12, 24 + head_dim = dim // num_heads # 96 / 3 = 32; 96, 192, 384, 768 + + self.scale = qk_scale or head_dim ** -0.5 # 1/np.sqrt(32) + + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH; 169, 3 + + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, 7, 7 + + coords_flatten = torch.flatten(coords, 1) # 2, 49 + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, 49, 49 + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # 49, 49, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 # 14-1 + relative_position_index = relative_coords.sum(-1) # 49, 49 + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) # 96, 192, 384, 768 + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) # 96, 192, 384, 768 + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape # -1, 49, 96 + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, + 4) # 448, 49, 3, 3, 32; 3, 448, 3, 49, 32 + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 
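+        # Shape walk-through for the projection above (using the illustrative
+        # sizes from the surrounding comments: B_=448 windows, N=49 tokens,
+        # C=96 channels, num_heads=3, head_dim=32):
+        #   self.qkv(x)                          -> (448, 49, 288)
+        #   .reshape(B_, N, 3, heads, C//heads)  -> (448, 49, 3, 3, 32)
+        #   .permute(2, 0, 3, 1, 4)              -> (3, 448, 3, 49, 32)
+        # so q, k and v each come out with shape (448, 3, 49, 32).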
# 448, 3, 49, 32
+
+        q = q * self.scale  # 1/np.sqrt(32)
+        attn = (q @ k.transpose(-2, -1))  # 448, 3, 49, 49
+        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1],
+            -1)  # Wh*Ww,Wh*Ww,nH; 49, 49, 3
+
+        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww; 3, 49, 49
+        attn = attn + relative_position_bias.unsqueeze(0)  # 448, 3, 49, 49
+
+        if mask is not None:
+            nW = mask.shape[0]
+            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(
+                0)  # 7, 64, 3, 49, 49 + 1, 64, 1, 49, 49
+            attn = attn.view(-1, self.num_heads, N, N)  # 448, 3, 49, 49
+            attn = self.softmax(attn)
+        else:
+            attn = self.softmax(attn)
+
+        attn = self.attn_drop(attn)
+        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)  # 448, 3, 49, 32; 448, 49, 3, 32; 448, 49, 96
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+    def extra_repr(self) -> str:
+        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
+
+    def flops(self, N):
+        # calculate flops for 1 window with token length of N
+        flops = 0
+        # qkv = self.qkv(x)
+        flops += N * self.dim * 3 * self.dim
+        # attn = (q @ k.transpose(-2, -1))
+        flops += self.num_heads * N * (self.dim // self.num_heads) * N
+        # x = (attn @ v)
+        flops += self.num_heads * N * N * (self.dim // self.num_heads)
+        # x = self.proj(x)
+        flops += N * self.dim * self.dim
+        return flops
+
+class WindowAttention_Cross(nn.Module):
+    r""" Window based multi-head cross attention (W-MSA style) module with relative position bias.
+    It supports both shifted and non-shifted windows.
+
+    Args:
+        dim (int): Number of input channels.
+        window_size (tuple[int]): The height and width of the window.
+        num_heads (int): Number of attention heads.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+        proj_drop (float, optional): Dropout ratio of output.
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads # 3, 6, 12, 24 + head_dim = dim // num_heads # 96 / 3 = 32; 96, 192, 384, 768 + + self.scale = qk_scale or head_dim ** -0.5 # 1/np.sqrt(32) + + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH; 169, 3 + + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, 7, 7 + + coords_flatten = torch.flatten(coords, 1) # 2, 49 + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, 49, 49 + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # 49, 49, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 # 14-1 + relative_position_index = relative_coords.sum(-1) # 49, 49 + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) # 96, 192, 384, 768 + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) # 96, 192, 384, 768 + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, y, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape # -1, 49, 96 + + q = x.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # 448, 49, 3, 3, 32; 3, 448, 3, 49, 32 + k = y.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # 448, 49, 3, 3, 32; 3, 448, 3, 49, 32 + v = y.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # 448, 49, 3, 3, 32; 3, 448, 3, 49, 32 + + q = q * self.scale # 1/np.sqrt(32) + attn = (q @ k.transpose(-2, -1)) # 448, 3, 49, 49 + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], + -1) # Wh*Ww,Wh*Ww,nH; 49, 49, 3 + + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww; 3, 49, 49 + attn = attn + relative_position_bias.unsqueeze(0) # 488, 3, 49, 49 + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # 7, 64, 3, 49, 49 + 1, 64, 1, 49, 49 + attn = attn.view(-1, self.num_heads, N, N) # 488, 3, 49, 49 + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) # 448, 3, 49, 32; 448, 49, 3, 32; 448, 49, 96 + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += 
self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape # 7, 56, 56, 96; 1, 56, 56, 1 + x = x.view(B, H // window_size, window_size, W // window_size, window_size, + C) # 7, 8, 7, 8, 7, 96 # 1, 8, 7, 8, 7, 1 # 1, 4, 14, 4, 14, 1 + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, + C) # 7, 8, 8, 7, 7, 96; -1, 7, 7, 96 + # 1, 8, 8, 7, 7, 1 # 1, 4, 4, 14, 14, 1 # 16, 14, 14, 1 + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) # 448 / (56*56/7/7) = 7 + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) # 7, 8, 8, 7, 7, 96 + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) # 7, 8, 7, 8, 96; 7, 56, 56, 96 + return x + diff --git a/SSLGlacier/models/hook/window_cross_attention.py b/SSLGlacier/models/hook/window_cross_attention.py deleted file mode 100644 index 7b5ddfd..0000000 --- a/SSLGlacier/models/hook/window_cross_attention.py +++ /dev/null @@ -1,106 +0,0 @@ -import torch -import torch.nn as nn -from timm.models.layers import trunc_normal_ - - -class WindowAttention_Cross(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 - """ - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): - - super().__init__() - self.dim = dim - self.window_size = window_size # Wh, Ww - self.num_heads = num_heads # 3, 6, 12, 24 - head_dim = dim // num_heads # 96 / 3 = 32; 96, 192, 384, 768 - - self.scale = qk_scale or head_dim ** -0.5 # 1/np.sqrt(32) - - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH; 169, 3 - - coords_h = torch.arange(self.window_size[0]) - coords_w = torch.arange(self.window_size[1]) - - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, 7, 7 - - coords_flatten = torch.flatten(coords, 1) # 2, 49 - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, 49, 49 - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # 49, 49, 2 - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 # 14-1 - relative_position_index = relative_coords.sum(-1) # 49, 49 - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) # 96, 192, 384, 768 - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) # 96, 192, 384, 768 - self.proj_drop = nn.Dropout(proj_drop) - - trunc_normal_(self.relative_position_bias_table, std=.02) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, y, mask=None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape # -1, 49, 96 - - q = x.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # 448, 49, 3, 3, 32; 3, 448, 3, 49, 32 - k = y.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # 448, 49, 3, 3, 32; 3, 448, 3, 49, 32 - v = y.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # 448, 49, 3, 3, 32; 3, 448, 3, 49, 32 - - q = q * self.scale # 1/np.sqrt(32) - attn = (q @ k.transpose(-2, -1)) # 448, 3, 49, 49 - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], - -1) # Wh*Ww,Wh*Ww,nH; 49, 49, 3 - - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww; 3, 49, 49 - attn = attn + relative_position_bias.unsqueeze(0) # 488, 3, 49, 49 - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( - 0) # 7, 64, 3, 49, 49 + 1, 64, 1, 49, 49 - attn = attn.view(-1, self.num_heads, N, N) # 488, 3, 49, 49 - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) # 448, 3, 49, 32; 448, 49, 3, 32; 448, 49, 96 - x = self.proj(x) - x = self.proj_drop(x) - return x - - def extra_repr(self) -> str: - return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' - - def flops(self, N): - # calculate flops for 1 window with token length of N - flops = 0 - # qkv = self.qkv(x) - flops += N * self.dim * 3 * self.dim - # attn = (q @ k.transpose(-2, -1)) - flops += self.num_heads * N * (self.dim // self.num_heads) * N - # x = (attn @ v) - flops += 
self.num_heads * N * N * (self.dim // self.num_heads) - # x = self.proj(x) - flops += N * self.dim * self.dim - return flops - diff --git a/SSLGlacier/models/hook/window_utils.py b/SSLGlacier/models/hook/window_utils.py deleted file mode 100644 index fe08bad..0000000 --- a/SSLGlacier/models/hook/window_utils.py +++ /dev/null @@ -1,34 +0,0 @@ -def window_partition(x, window_size): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape # 7, 56, 56, 96; 1, 56, 56, 1 - x = x.view(B, H // window_size, window_size, W // window_size, window_size, - C) # 7, 8, 7, 8, 7, 96 # 1, 8, 7, 8, 7, 1 # 1, 4, 14, 4, 14, 1 - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, - C) # 7, 8, 8, 7, 7, 96; -1, 7, 7, 96 - # 1, 8, 8, 7, 7, 1 # 1, 4, 4, 14, 14, 1 # 16, 14, 14, 1 - return windows - - -def window_reverse(windows, window_size, H, W): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image - - Returns: - x: (B, H, W, C) - """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) # 448 / (56*56/7/7) = 7 - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) # 7, 8, 8, 7, 7, 96 - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) # 7, 8, 7, 8, 96; 7, 56, 56, 96 - return x - -- GitLab From 7f7de8cd002e873590c5d4622af3c66b6151687b Mon Sep 17 00:00:00 2001 From: Marziyeh <marziyeh.mohammadi@fau.de> Date: Fri, 15 Dec 2023 09:38:22 +0100 Subject: [PATCH 09/11] Refactor the code: Change the base class name to SWINNET, it was Net. --- SSLGlacier/models/RESNET.py | 4 ++-- SSLGlacier/models/{base_net.py => SWINNET.py} | 2 +- SSLGlacier/models/UNET.py | 4 ++-- SSLGlacier/models/hook/Swin_Transformer_Wrapper.py | 14 ++++++-------- 4 files changed, 11 insertions(+), 13 deletions(-) rename SSLGlacier/models/{base_net.py => SWINNET.py} (99%) diff --git a/SSLGlacier/models/RESNET.py b/SSLGlacier/models/RESNET.py index 9a094ef..0ee1479 100644 --- a/SSLGlacier/models/RESNET.py +++ b/SSLGlacier/models/RESNET.py @@ -1,7 +1,7 @@ """Adapted from: https://github.com/facebookresearch/swav/blob/master/src/resnet50.py.""" import torch from torch import nn -from SSLGlacier.models.base_net import Net +from SSLGlacier.models.SWINNET import SwinNet def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): @@ -122,7 +122,7 @@ class Bottleneck(nn.Module): return self.relu(out) -class ResNet(Net): +class ResNet(SwinNet): def __init__( self, block, diff --git a/SSLGlacier/models/base_net.py b/SSLGlacier/models/SWINNET.py similarity index 99% rename from SSLGlacier/models/base_net.py rename to SSLGlacier/models/SWINNET.py index 4489ec3..9924814 100644 --- a/SSLGlacier/models/base_net.py +++ b/SSLGlacier/models/SWINNET.py @@ -3,7 +3,7 @@ import torch from torch import nn -class Net(nn.Module): +class SwinNet(nn.Module): def __init__( self, groups=1, diff --git a/SSLGlacier/models/UNET.py b/SSLGlacier/models/UNET.py index ad4f65c..5061bb0 100644 --- a/SSLGlacier/models/UNET.py +++ b/SSLGlacier/models/UNET.py @@ -1,9 +1,9 @@ import torch from torch import Tensor, nn from torch.nn import functional as F # noqa: N812 -from SSLGlacier.models.base_net import Net +from SSLGlacier.models.SWINNET import SwinNet -class UNet(Net): +class UNet(SwinNet): """Pytorch Lightning implementation of U-Net. 
Paper: `U-Net: Convolutional Networks for Biomedical Image Segmentation <https://arxiv.org/abs/1505.04597>`_ diff --git a/SSLGlacier/models/hook/Swin_Transformer_Wrapper.py b/SSLGlacier/models/hook/Swin_Transformer_Wrapper.py index 69e68a9..4f96144 100644 --- a/SSLGlacier/models/hook/Swin_Transformer_Wrapper.py +++ b/SSLGlacier/models/hook/Swin_Transformer_Wrapper.py @@ -1,19 +1,17 @@ # coding=utf-8 -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function import copy import logging - import torch -import torch.nn as nn - +from __future__ import division +from __future__ import print_function +from __future__ import absolute_import +from SSLGlacier.models.SWINNET import SwinNet from SSLGlacier.models.hook.Swin_Transformer import SwinTransformerSys -from SSLGlacier.models.base_net import Net + logger = logging.getLogger(__name__) -class SwinUnet(Net): +class SwinUnet(SwinNet): def __init__(self, num_classes=21843, zero_head=False, **kwargs): super().__init__( groups=1, -- GitLab From a9f5a009a609a22d5a416f8eb8b6442182e2a926 Mon Sep 17 00:00:00 2001 From: Marziyeh <marziyeh.mohammadi@fau.de> Date: Fri, 15 Dec 2023 09:46:31 +0100 Subject: [PATCH 10/11] Refactoring: Make the folder names more clear. --- SSLGlacier/main.py | 4 +- SSLGlacier/modules/agumentations_.py | 411 ++++++++++++++++++++ SSLGlacier/modules/datamodule_.py | 149 +++++++ SSLGlacier/modules/semi_module.py | 29 ++ SSLGlacier/modules/swav_module.py | 562 +++++++++++++++++++++++++++ SSLGlacier/modules/transformers_.py | 248 ++++++++++++ 6 files changed, 1401 insertions(+), 2 deletions(-) create mode 100644 SSLGlacier/modules/agumentations_.py create mode 100644 SSLGlacier/modules/datamodule_.py create mode 100644 SSLGlacier/modules/semi_module.py create mode 100644 SSLGlacier/modules/swav_module.py create mode 100644 SSLGlacier/modules/transformers_.py diff --git a/SSLGlacier/main.py b/SSLGlacier/main.py index 73d9277..cc7d5b4 100644 --- a/SSLGlacier/main.py +++ b/SSLGlacier/main.py @@ -5,12 +5,12 @@ from torchinfo import summary as torchinfo_summery from torchsummary import summary as torch_summary from torch._inductor.config import trace -from processing import datamodule_, transformers_ +from modules import datamodule_, transformers_ import pytorch_lightning as pl from models import ResNet_light_mine, Swav_Orig_ResNet import os import torch -from processing.swav_module import SwAV +from modules.swav_module import SwAV from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator import numpy as np from utils.utils import str2bool, getDataPath diff --git a/SSLGlacier/modules/agumentations_.py b/SSLGlacier/modules/agumentations_.py new file mode 100644 index 0000000..05cb774 --- /dev/null +++ b/SSLGlacier/modules/agumentations_.py @@ -0,0 +1,411 @@ +# PILRandomGaussianBlur and get_color_distortion are used and implemented +# by Swav and SimCLR - https://arxiv.org/abs/2002.05709 +import random +from logging import getLogger + +import torch.nn +from PIL import ImageFilter +import numpy as np +import torchvision.transforms as transforms +import torchvision, tormentor +from torchvision.transforms import functional as F +import torch +import matplotlib.pyplot as plt + +logger = getLogger() + + +class Compose(object): + """ + Class for chaining transforms together + """ + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, image, target): + for t in self.transforms: + image, target = t(image, target) + return [image, target] + + def 
__getitem__(self, item): + return self.transforms[item] + +class DoNothing(torch.nn.Module): + def __init__(self): + super(DoNothing, self).__init__() + + def __call__(self, img, mask): + return [img, mask] + +class Cropper(torch.nn.Module): + def __init__(self, i, j, h, w): + super(Cropper, self).__init__() + self.left = i + self.right = j + self.height = h + self.width = w + + def __call__(self, img, mask): + cropped_img = F.crop(img, self.left, self.right, self.height, self.width) + cropped_mask = F.crop(mask, self.left, self.right, self.height, self.width) + return [cropped_img, cropped_mask] + + +class RandomCropper(torch.nn.Module): + ''' + This function returns one patch at time, if you need more patches in one image, call it in a loop + Args: + orig_img: get an png or jpg image and crop one patch randomly, in both image and mask + orig_mask: This is the images mask, we crop same area from + Returns: + A list of two argument, first cropped patch in image,second cropped patch in mask + ''' + + def __init__(self, size): + super(RandomCropper, self).__init__() + self.left = 0 + self.right = 0 + self.height = 0 + self.width = 0 + self.size = size + + def forward(self, img, mask): + self.left, self.right, self.height, self.width = torchvision.transforms.RandomCrop.get_params( + img, output_size=(self.size, self.size)) + cropped_img = F.crop(img, self.left, self.right, self.height, self.width) + cropped_mask = F.crop(mask, self.left, self.right, self.height, self.width) + return [cropped_img, cropped_mask] + + +class PILRandomGaussianBlur(torch.nn.Module): + def __init__(self, radius_min=0.1, radius_max=2.): + """ + Apply Gaussian Blur to the PIL image. Take the radius and probability of + application as the parameter. + This transform was used in SimCLR - https://arxiv.org/abs/2002.05709 + """ + super(PILRandomGaussianBlur, self).__init__() + self.radius_min = radius_min + self.radius_max = radius_max + + def forward(self, img, mask): + return [img.filter( + ImageFilter.GaussianBlur( + radius=random.uniform(self.radius_min, self.radius_max) + ) + ), mask] + + +class RandomHorizontalFlip(torch.nn.Module): + def __init__(self): + super(RandomHorizontalFlip, self).__init__() + + def forward(self, img, mask): + image = torchvision.transforms.functional.hflip(img) + mask = torchvision.transforms.functional.hflip(mask) + return [image, mask] + + +class GetColorDistortion(torch.nn.Module): + def __int__(self, s=1.0): + super(GetColorDistortion, self).__init__() + self.s = s + + def forward(self, img, mask): + color_jitter = transforms.ColorJitter(0.8 * self.s, 0.8 * self.s, 0.8 * self.s, 0.2 * self.s) + rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8) + rnd_gray = transforms.RandomGrayscale(p=0.2) + color_distort = transforms.Compose([rnd_color_jitter, rnd_gray]) + return color_distort(img, mask) + + +# TODO this net +class GaussNoise(torch.nn.Module): + def __init__(self, mean=0, var=10000): + super(GaussNoise, self).__init__() + self.mean = mean + self.var = var + + def forward(self, img, mask): + row, col = img.size + sigma = self.var ** 0.5 + gauss = np.random.normal(self.mean, sigma, (row, col)) + gauss = gauss.reshape(row, col) + noisy = img + gauss + return [noisy, mask] + + +class SaltPepperNoise(torch.nn.Module): + def __init__(self, salt_or_pepper=0.5): + super(SaltPepperNoise, self).__init__() + self.salt_or_pepper = salt_or_pepper + + def forward(self, img, mask): + if len(img.size) == 3: + img_size = img.size[1] * img.size[2] + else: + img_size = img.size[0] * 
img.size[1] + amount = .4 + noisy = np.copy(img) + # Salt mode + num_salt = np.ceil(amount * img_size * self.salt_or_pepper) + target_pixels = [np.random.randint(0, i - 1, int(num_salt)) + for i in img.size] + target_pixels = list(map(lambda coords: tuple(coords), zip(target_pixels[0], target_pixels[1]))) + for i, j in target_pixels: + noisy[i][j] = 1 + # Pepper mode + num_pepper = np.ceil(amount * img_size * (1. - self.salt_or_pepper)) + target_pixels = [np.random.randint(0, i - 1, int(num_pepper)) + for i in img.size] + target_pixels = list(map(lambda coords: tuple(coords), zip(target_pixels[0], target_pixels[1]))) + + for i, j in target_pixels: + noisy[i][j] = 0 + + # plt.imshow(img) + # plt.imshow(noisy) + # plt.show() + return [noisy, mask] + + +class PoissionNoise(torch.nn.Module): + def __init__(self): + super(PoissionNoise, self).__init__() + + def forward(self, img, mask): + vals = len(np.unique(img)) + vals = 2 ** np.ceil(np.log2(vals)) + noisy = np.random.poisson(img * vals) / float(vals) + return [noisy, mask] + + +# TODO this one +class SpeckleNoise(torch.nn.Module): + def __init__(self): + super(SpeckleNoise, self).__init__() + + def forward(self, img, mask): + row, col = img.size + gauss = np.random.randn(row, col) + gauss = gauss.reshape(row, col) + noisy = img + img * gauss + return [noisy, mask] + + +class SetZeroNoise(torch.nn.Module): + def __init__(self): + super(SetZeroNoise, self).__init__() + + def forward(self, img, mask): + row, col = img.size + img_size = row * col + random_rows = np.random.randint(0, int(row), (1, int(img_size))) + random_cols = np.random.randint(0, int(col), (1, int(img_size))) + target_pixels = list(zip(random_rows[0], random_cols[0])) + for pix in target_pixels: + img[pix[0], pix[1]] = 0 + return [img, mask] + + +#################################################################################### + +###########Base_code Agumentations suggestions: #TODO do not need them now########## +class MyWrap(torch.nn.Module): + """ + Random wrap augmentation taken from tormentor + """ + + def __init__(self): + super(MyWrap, self).__init__() + + def forward(self, img, target): + # This augmentation acts like many simultaneous elastic transforms with gaussian sigmas set at varius harmonics + wrap_rand = tormentor.Wrap.override_distributions(roughness=tormentor.random.Uniform(value_range=(.1, .7)), + intensity=tormentor.random.Uniform(value_range=(.0, 1.))) + wrap = wrap_rand() + image = wrap(img) + mask = wrap(target, is_mask=True) + return [image, mask] + + +class Rotate(torch.nn.Module): + """ + Random rotation augmentation + """ + + def __init__(self): + super(Rotate, self).__init__() + + def forward(self, img, target): + random = np.random.randint(0, 3) + angle = 90 + if random == 1: + angle = 180 + elif random == 2: + angle = 270 + image = torchvision.transforms.functional.rotate(img, angle=angle) + mask = torchvision.transforms.functional.rotate(target, angle=angle) + return [image, mask.squeeze(0)] + + +class Bright(torch.nn.Module): + """ + Random brightness adjustment augmentations + """ + + def __init__(self, + lower_band: float = -0.2, + upper_band: float = 0.2): + super(Bright, self).__init__() + self.lower_band = lower_band + self.upper_band = upper_band + + def forward(self, img, mask): + bright_rand = tormentor.Brightness.override_distributions( + brightness=tormentor.random.Uniform((self.lower_band, self.upper_band))) + # bright = bright_rand() + # image_transformed = img.clone() + image_transformed = bright_rand(np.asarray(img)) + # set NA 
areas back to zero + image_transformed.seed[img == 0] = 0.0 + + return [image_transformed.seed, mask] + + +class Noise(torch.nn.Module): + """ + Random additive noise augmentation + """ + + def __init__(self): + super(Noise, self).__init__() + + def forward(self, img, target): + # add noise. It is a multiplicative gaussian noise so no need to set na areas back to zero again + noise = torch.normal(mean=0, std=0.3, size=img.size) + image = img + img * noise + image[image > 1.0] = 1.0 + image[image < 0.0] = 0.0 + + return [image, target] + + +class Resize(torch.nn.Module): + # Library code + """Resize the input image to the given size. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means a maximum of two leading dimensions + + .. warning:: + The output image might be different depending on its type: when downsampling, the interpolation of PIL images + and tensors is slightly different, because PIL applies antialiasing. This may lead to significant differences + in the performance of a network. Therefore, it is preferable to train and serve a model with the same input + types. See also below the ``antialias`` parameter, which can help making the output of PIL images and tensors + closer. + + Args: + size (sequence or int): Desired output size. If size is a sequence like + (h, w), output size will be matched to this. If size is an int, + smaller edge of the image will be matched to this number. + i.e, if height > width, then image will be rescaled to + (size * height / width, size). + + .. note:: + In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``, + ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. + max_size (int, optional): The maximum allowed for the longer edge of + the resized image. If the longer edge of the image is greater + than ``max_size`` after being resized according to ``size``, + ``size`` will be overruled so that the longer edge is equal to + ``max_size``. + As a result, the smaller edge may be shorter than ``size``. This + is only supported if ``size`` is an int (or a sequence of length + 1 in torchscript mode). + antialias (bool, optional): Whether to apply antialiasing. + It only affects **tensors** with bilinear or bicubic modes and it is + ignored otherwise: on PIL images, antialiasing is always applied on + bilinear or bicubic modes; on other modes (for PIL images and + tensors), antialiasing makes no sense and this parameter is ignored. + Possible values are: + + - ``True``: will apply antialiasing for bilinear or bicubic modes. + Other mode aren't affected. This is probably what you want to use. + - ``False``: will not apply antialiasing for tensors on any mode. PIL + images are still antialiased on bilinear or bicubic modes, because + PIL doesn't support no antialias. + - ``None``: equivalent to ``False`` for tensors and ``True`` for + PIL images. This value exists for legacy reasons and you probably + don't want to use it unless you really know what you are doing. + + The current default is ``None`` **but will change to** ``True`` **in + v0.17** for the PIL and Tensor backends to be consistent. 
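+
+    Example (a sketch; the sizes are illustrative, and note that only the
+    image, not the mask, is resized by this wrapper)::
+
+        resize = Resize(size=224)
+        img_t, mask_t = resize(img, mask)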
+ """ + + def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR, max_size=None, antialias="warn"): + super().__init__() + self.size = size + self.max_size = max_size + self.interpolation = interpolation + self.antialias = antialias + + def forward(self, img, mask): + """ + Args: + img (PIL Image or Tensor): Image to be scaled. + + Returns: + PIL Image or Tensor: Rescaled image. + """ + return [F.resize(img, self.size, self.interpolation, self.max_size, self.antialias), mask] + + +class CenterCrop(torch.nn.Module): + """Crops the given image at the center. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. + If image size is smaller than output size along any edge, image is padded with 0 and then center cropped. + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). + """ + + def __init__(self, size): + super().__init__() + self.size = size # _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + + def forward(self, img, mask): + """ + Args: + img (PIL Image or Tensor): Image to be cropped. + + Returns: + PIL Image or Tensor: Cropped image. + """ + ###ADOPT FROM HOOKFORMER CODE, #todo CHECK IT LATER + W, H = img.size + + WW = (W // self.size) + 2 + HH = (H // self.size) + 2 + + return [F.center_crop(img, (HH * self.size, WW * self.size)), F.center_crop(mask, (HH * self.size, WW * self.size))] + + +class ToTensorZones(): + def __call__(self, image, target): + image = F.to_tensor(np.array(image).astype(np.float32)) + target = torch.from_numpy(np.array(target).astype(np.float32)) + # value for NA area=0, stone=64, glacier=127, ocean with ice melange=254 + target[target == 0] = 0 + target[target == 64] = 1 + target[target == 127] = 2 + target[target == 254] = 3 + # class ids for NA area=0, stone=1, glacier=2, ocean with ice melange=3 + return [image, target] diff --git a/SSLGlacier/modules/datamodule_.py b/SSLGlacier/modules/datamodule_.py new file mode 100644 index 0000000..edf3648 --- /dev/null +++ b/SSLGlacier/modules/datamodule_.py @@ -0,0 +1,149 @@ +from torchvision.transforms import functional as F +from pl_bolts.models.self_supervised import swav +from torch.utils.data import DataLoader +from torch.utils.data import Dataset +import pytorch_lightning as pl +from torchvision import transforms as transform_lib +from . 
import agumentations_ as our_transforms +import torchvision +import numpy as np +from typing import Tuple, List, Any, Callable, Optional +import torch +import cv2 +from PIL import Image,ImageOps +import os + +# TODO add moco, simclr,cpc and amdim transformers +class SSLDataset(Dataset): + def __init__(self, parent_dir, transform,return_index=False, **kwargs): # TODO Pass them as arguments + ''' + Args: + mode: can be tested, train, or validation + parent_dir: directory in which the folders test, train and validation exists + ''' + self.mode = kwargs['mode'] if kwargs['mode'] else 'train' + self.return_index = return_index + self.images_path = os.path.join(parent_dir, "sar_images", self.mode) + self.masks_path = os.path.join(parent_dir, 'zones', self.mode) + self.images = os.listdir(self.images_path) + self.masks = os.listdir(self.masks_path) + assert len(self.masks) == len(self.images), "You don't have the same number of images and masks" + self.transform = transform + # Sort than images and masks fit together + self.images.sort() + self.masks.sort() + + # Let shuffle and save indices, and load them in next calls if they exist + if not os.path.exists(os.path.join("data_processing", "data_splits")): + os.makedirs(os.path.join("data_processing", "data_splits")) + if not os.path.isfile(os.path.join("data_processing", "data_splits", "shuffle_" + ".txt")): + shuffle = np.random.permutation(len(self.images)) + # Works for numpy version >= 1.5 + np.savetxt(os.path.join("data_processing", "data_splits", "shuffle_" + ".txt"), shuffle, + newline=' ') + + else: + # use already existing shuffle + with open(os.path.join("data_processing", "data_splits", "shuffle_" + ".txt"), "rb") as fp: + lines = fp.readlines() + shuffle = [np.fromstring(line, dtype=int, sep=' ') for line in lines] + # if lengths do not match, we need to create a new permutation + if len(shuffle) != len(self.images): + shuffle = np.random.permutation(len(self.images)) + np.savetxt(os.path.join("data_processing", "data_splits", "shuffle_" + ".txt"), shuffle, + newline=' ') + + self.images = np.array(self.images) # Why!!!! + self.masks = np.array(self.masks) + self.images = self.images[shuffle].copy() + self.masks = self.masks[shuffle].copy() + self.images = list(self.images) + self.masks = list(self.masks) # Why!!!! + + def __len__(self): + return len(self.images) + + def __getitem__(self, index): + img_name = self.images[index] + masks_name = self.masks[index] + assert img_name.split('.')[0] == masks_name.split('.')[0].replace("_zones", ""), \ + "image and label name don't match. Image name: " + img_name + ". Label name: " + masks_name + #image = cv2.imread(os.path.join(self.images_path, img_name).__str__(), cv2.IMREAD_GRAYSCALE) + #mask = cv2.imread(os.path.join(self.masks_path, masks_name).__str__(), cv2.IMREAD_GRAYSCALE) + image = Image.open(os.path.join(self.images_path, img_name).__str__()) + image = ImageOps.grayscale(image) + #image = np.array(image).astype(np.float32) + + #rgbimg = Image.new("RGBA", image.size) + #image = rgbimg.paste(image) + #image = ImageOps.grayscale(image) + mask = Image.open(os.path.join(self.masks_path, masks_name).__str__()) + #rgbimg = Image.new("RGBA", mask.size) + #mask = rgbimg.paste(rgbimg) + mask = ImageOps.grayscale(mask) + #mask = np.array(mask).astype(np.float32) + # TODO check if it works or not, it should return multiple transformation for one image + # TODO I crop different parts of the image and mask as part of the batch, how should I handle it!! 
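+        # Note (editor's sketch, assumed from the code rather than stated
+        # upstream): self.transform is expected to follow the contract of
+        # transformers_.OurTrainTransformer, i.e. it maps one (image, mask)
+        # pair to *lists* of augmented crops:
+        #
+        #     crops, crop_masks = self.transform(image, mask)
+        #     # len(crops) == len(crop_masks), one entry per crop/augmentation
+        #
+        # Anything assigned as the dataset transform should keep this shape.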
+        # TODO this is a list of tuples of img and masks, think if you want to change it later
+        if self.transform is not None:
+            augmented_imgs, masks = self.transform(image, mask)
+        else:
+            # fall back to a plain tensor conversion so that callers always get
+            # tensors back (previously nothing was returned in this case)
+            to_tensor = our_transforms.ToTensorZones()
+            augmented_imgs, masks = to_tensor(image, mask)
+        if self.return_index:
+            return index, augmented_imgs, masks, img_name, masks_name
+        return augmented_imgs, masks
+
+
+class SSLDataModule(pl.LightningDataModule):
+    # Base is adopted from Nora's code
+    def __init__(self, batch_size, parent_dir, args):
+        """
+        :param batch_size: batch size
+        :param parent_dir: root directory of the dataset (see SSLDataset)
+        :param args: namespace holding the augmentation and crop settings
+        """
+        super().__init__()
+        self.transforms = None
+        self.batch_size = batch_size
+        self.glacier_test = None
+        self.glacier_train = None
+        self.glacier_val = None
+        self.parent_dir = parent_dir
+        self.aug_args = args.agumentations
+        self.size_crops = args.size_crops
+        self.num_crops = args.nmb_crops
+
+    def prepare_data(self):
+        # download,
+        # only called on 1 GPU/TPU in distributed
+        pass
+
+    def __len__(self):
+        return len(self.glacier_train)
+
+    def setup(self, stage=None):
+        # process and split here
+        # make assignments here (val/train/test split)
+        # called on every process in DDP
+        if stage == 'test' or stage is None:
+            self.transforms = self._default_transforms() if self.test_transforms is None else self.test_transforms
+            # was mode=stage, which is None when setup() runs without a stage
+            self.glacier_test = SSLDataset(parent_dir=self.parent_dir, transform=self.transforms, mode='test')
+        if stage == 'fit' or stage is None:
+            self.transforms = self._default_transforms() if self.train_transforms is None else self.train_transforms
+            self.glacier_train = SSLDataset(parent_dir=self.parent_dir, transform=self.transforms, mode='train')
+
+            self.transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms
+            self.glacier_val = SSLDataset(parent_dir=self.parent_dir, transform=self.transforms, mode='val')
+
+    def train_dataloader(self) -> DataLoader:
+        return DataLoader(self.glacier_train, batch_size=self.batch_size, num_workers=6, pin_memory=True,
+                          drop_last=True)
+
+    def val_dataloader(self) -> DataLoader:
+        return DataLoader(self.glacier_val, batch_size=self.batch_size, num_workers=6, pin_memory=True, drop_last=True)
+
+    def test_dataloader(self) -> DataLoader:
+        return DataLoader(self.glacier_test, batch_size=1, num_workers=6, pin_memory=True,
+                          drop_last=True)  # TODO self.batch_size
+
+    def _default_transforms(self) -> Callable:
+        return our_transforms.Compose([our_transforms.ToTensorZones()])
diff --git a/SSLGlacier/modules/semi_module.py b/SSLGlacier/modules/semi_module.py
new file mode 100644
index 0000000..9798a36
--- /dev/null
+++ b/SSLGlacier/modules/semi_module.py
@@ -0,0 +1,29 @@
+from pytorch_lightning import LightningModule
+from typing import Any
+
+
+class Semi_(LightningModule):
+    def __init__(self):
+        super().__init__()  # was `pass`; LightningModule.__init__ must run
+
+    def setup(self, stage: str) -> None:
+        pass
+
+    def init_model(self):
+        pass
+
+    def forward(self, *args: Any, **kwargs: Any) -> Any:
+        pass
+
+    def on_train_epoch_start(self) -> None:
+        pass
+
+    def shared_step(self, batch):
+        inputs, y = batch
+        inputs = inputs[:, -1]
+        embedding = self.model(inputs)
+        #### Where should I add noise....
+        #### assume embedding[b][i] belongs to batch b, view i
+        #### if we have two views, one with noise and one without, then
+        ##### select some portion of each view belonging to different classes.
+        ##### pos_set = same classes ..
+ ##### neg_set = diff classes + diff --git a/SSLGlacier/modules/swav_module.py b/SSLGlacier/modules/swav_module.py new file mode 100644 index 0000000..3b57a4e --- /dev/null +++ b/SSLGlacier/modules/swav_module.py @@ -0,0 +1,562 @@ +"""Adapted from official swav implementation: https://github.com/facebookresearch/swav.""" +import os +from argparse import ArgumentParser + +import torch +from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint +from torch import nn + +#from pl_bolts.models.self_supervised.swav.loss import SWAVLoss +from SSLGlacier.models.losses import SWAVLoss +from SSLGlacier.models.RESNET import resnet18, resnet50 +#from Base_Code.models.ParentUNet import UNet +from SSLGlacier.models.UNET import UNet #The main network +from pl_bolts.optimizers.lars import LARS +from pl_bolts.optimizers.lr_scheduler import linear_warmup_decay + +##########Importing files for hook############ +from SSLGlacier.models.hook.Swin_Transformer_Wrapper import SwinUnet +######Temporary config file for hooknet####### +from utils.config import get_config + +class SwAV(LightningModule): + def __init__( + self, + gpus: int, + num_samples: int, + batch_size: int, + dataset: str, + num_nodes: int = 1, + arch: str = "resnet50", + hidden_mlp: int = 2048, + feat_dim: int = 128, + warmup_epochs: int = 10, + max_epochs: int = 100, + num_prototypes: int = 3000, + freeze_prototypes_epochs: int = 1, + temperature: float = 0.1, + sinkhorn_iterations: int = 3, + queue_length: int = 0, # must be divisible by total batch-size + queue_path: str = "queue", + epoch_queue_starts: int = 15, + crops_for_assign: tuple = (0, 1), + num_crops: tuple = (2, 6), + num_augs: int= 2, + first_conv: bool = True, + maxpool1: bool = True, + optimizer: str = "adam", + exclude_bn_bias: bool = False, + start_lr: float = 0.0, + learning_rate: float = 1e-3, + final_lr: float = 0.0, + weight_decay: float = 1e-6, + epsilon: float = 0.05, + just_aug_for_same_assign_views: bool = False, + swin_hparams = {} + ) -> None: + """ + Args: + gpus: number of gpus per node used in training, passed to SwAV module + to manage the queue and select distributed sinkhorn + num_nodes: number of nodes to train on + num_samples: number of image samples used for training + batch_size: batch size per GPU in ddp + dataset: dataset being used for train/val + arch: encoder architecture used for pre-training + hidden_mlp: hidden layer of non-linear projection head, set to 0 + to use a linear projection head + feat_dim: output dim of the projection head + warmup_epochs: apply linear warmup for this many epochs + max_epochs: epoch count for pre-training + num_prototypes: count of prototype vectors + freeze_prototypes_epochs: epoch till which gradients of prototype layer + are frozen + temperature: loss temperature + sinkhorn_iterations: iterations for sinkhorn normalization + queue_length: set queue when batch size is small, + must be divisible by total batch-size (i.e. 
total_gpus * batch_size), + set to 0 to remove the queue + queue_path: folder within the logs directory + epoch_queue_starts: start uing the queue after this epoch + crops_for_assign: list of crop ids for computing assignment + num_crops: number of global and local crops, ex: [2, 6] + first_conv: keep first conv same as the original resnet architecture, + if set to false it is replace by a kernel 3, stride 1 conv (cifar-10) + maxpool1: keep first maxpool layer same as the original resnet architecture, + if set to false, first maxpool is turned off (cifar10, maybe stl10) + optimizer: optimizer to use + exclude_bn_bias: exclude batchnorm and bias layers from weight decay in optimizers + start_lr: starting lr for linear warmup + learning_rate: learning rate + final_lr: float = final learning rate for cosine weight decay + weight_decay: weight decay for optimizer + epsilon: epsilon val for swav assignments + """ + super().__init__() + self.save_hyperparameters() + + self.gpus = gpus + self.num_nodes = num_nodes + self.arch = arch + self.dataset = dataset + self.num_samples = num_samples + self.batch_size = batch_size + + self.hidden_mlp = hidden_mlp + self.feat_dim = feat_dim + self.num_prototypes = num_prototypes + self.freeze_prototypes_epochs = freeze_prototypes_epochs + self.sinkhorn_iterations = sinkhorn_iterations + + self.queue_length = queue_length + self.queue_path = queue_path + self.epoch_queue_starts = epoch_queue_starts + self.crops_for_assign = crops_for_assign + self.num_crops = num_crops + self.num_augs = num_augs + + self.first_conv = first_conv + self.maxpool1 = maxpool1 + + self.optim = optimizer + self.exclude_bn_bias = exclude_bn_bias + self.weight_decay = weight_decay + self.epsilon = epsilon + self.temperature = temperature + + self.start_lr = start_lr + self.final_lr = final_lr + self.learning_rate = learning_rate + self.warmup_epochs = warmup_epochs + self.max_epochs = max_epochs + self.just_aug_for_same_assign_views = just_aug_for_same_assign_views + + self.model = self.init_model() + self.criterion = SWAVLoss( + gpus=self.gpus, + num_nodes=self.num_nodes, + temperature=self.temperature, + crops_for_assign=self.crops_for_assign, + num_crops=self.num_crops, + num_augs=self.num_augs, + sinkhorn_iterations=self.sinkhorn_iterations, + epsilon=self.epsilon, + just_aug_for_same_assign_views= self.just_aug_for_same_assign_views + ) + self.use_the_queue = None + # compute iters per epoch + global_batch_size = self.num_nodes * self.gpus * self.batch_size if self.gpus > 0 else self.batch_size + self.train_iters_per_epoch = self.num_samples // global_batch_size + self.queue = None + + ####For hook#### + self.swin_hparams = swin_hparams + def setup(self, stage): + if self.queue_length > 0: + queue_folder = os.path.join(self.logger.log_dir, self.queue_path) + if not os.path.exists(queue_folder): + os.makedirs(queue_folder) + + self.queue_path = os.path.join(queue_folder, "queue" + str(self.trainer.global_rank) + ".pth") + + if os.path.isfile(self.queue_path): + self.queue = torch.load(self.queue_path)["queue"] + + def init_model(self): + if self.arch == "resnet18": + backbone = resnet18(normalize=True, + hidden_mlp=self.hidden_mlp, + output_dim=self.feat_dim, + num_prototypes=self.num_prototypes, + first_conv=self.first_conv, + maxpool1=self.maxpool1) + + elif self.arch == "resnet50": + backbone = resnet50(normalize=True, + hidden_mlp=self.hidden_mlp, + output_dim=self.feat_dim, + num_prototypes=self.num_prototypes, + first_conv=self.first_conv, + maxpool1=self.maxpool1) + elif 
self.arch == "hook": + backbone = SwinUnet( + img_size=224, + num_classes=5, + normalize=True, + hidden_mlp=self.hidden_mlp, + output_dim=self.feat_dim, + num_prototypes=self.num_prototypes, + first_conv=self.first_conv, + maxpool1=self.maxpool1,######till here for basenet + image_size= self.swin_hparams.image_size,####from here for swin itself + swin_patch_size = self.swin_hparams.swin_patch_size, + swin_in_chans=self.swin_hparams.swin_in_chans, + swin_embed_dim=self.swin_hparams.swin_embed_dim, + swin_depths=self.swin_hparams.swin_depths, + swin_num_heads=self.swin_hparams.swin_num_heads, + swin_window_size=self.swin_hparams.swin_window_size, + swin_mlp_ratio=self.swin_hparams.swin_mlp_ratio, + swin_QKV_BIAS=self.swin_hparams.swin_QKV_BIAS, + swin_QK_SCALE=self.swin_hparams.swin_QK_SCALE, + drop_rate=self.drop_rate, + drop_path_rate=self.drop_path_rate, + swin_ape=self.swin_hparams.swin_ape, + swin_patch_norm=self.swin_hparams.swin_path_norm + ) + + elif self.arch == "Unet" or "unet": + backbone = UNet(num_classes=5, + input_channels=1, + num_layers=5, + features_start=64, + bilinear=False, + normalize=True, + hidden_mlp=self.hidden_mlp, + output_dim=self.feat_dim, + num_prototypes=self.num_prototypes, + first_conv=self.first_conv, + maxpool1=self.maxpool1 + ) + else: + raise f'{self.arch} model is not defined' + + return backbone + + def forward(self, x): + # pass single batch from the resnet backbone + return self.model.forward_backbone(x) + + def on_train_epoch_start(self): + if self.queue_length > 0: + if self.trainer.current_epoch >= self.epoch_queue_starts and self.queue is None: + self.queue = torch.zeros( + len(self.crops_for_assign), + self.queue_length // self.gpus, # change to nodes * gpus once multi-node + self.feat_dim, + ) + + if self.queue is not None: + self.queue = self.queue.to(self.device) + + self.use_the_queue = False + + def on_train_epoch_end(self) -> None: + if self.queue is not None: + torch.save({"queue": self.queue}, self.queue_path) + + def on_after_backward(self): + if self.current_epoch < self.freeze_prototypes_epochs: + for name, p in self.model.named_parameters(): + if "prototypes" in name: + p.grad = None + + def shared_step(self, batch): + if self.dataset == "stl10": + unlabeled_batch = batch[0] + batch = unlabeled_batch + + inputs, y = batch + inputs = inputs[:-1] # remove online train/eval transforms at this point + + # 1. normalize the prototypes + with torch.no_grad(): + w = self.model.prototypes.weight.data.clone() + w = nn.functional.normalize(w, dim=1, p=2) + self.model.prototypes.weight.copy_(w) + + # 2. 
multi-res forward passes + embedding, output = self.model(inputs) + embedding = embedding.detach() + bs = inputs[0].size(0) + + # SWAV loss computation + loss, queue, use_queue = self.criterion( + output=output, + embedding=embedding, + prototype_weights=self.model.prototypes.weight, + batch_size=bs, + queue=self.queue, + use_queue=self.use_the_queue, + ) + self.queue = queue + self.use_the_queue = use_queue + return loss + + def training_step(self, batch, batch_idx): + loss = self.shared_step(batch) + + self.log("train_loss", loss, on_step=True, on_epoch=False) + return loss + + def validation_step(self, batch, batch_idx): + loss = self.shared_step(batch) + + self.log("val_loss", loss, on_step=False, on_epoch=True) + return loss + + def exclude_from_wt_decay(self, named_params, weight_decay, skip_list=("bias", "bn")): + params = [] + excluded_params = [] + + for name, param in named_params: + if not param.requires_grad: + continue + if any(layer_name in name for layer_name in skip_list): + excluded_params.append(param) + else: + params.append(param) + + return [{"params": params, "weight_decay": weight_decay}, {"params": excluded_params, "weight_decay": 0.0}] + + def configure_optimizers(self): + if self.exclude_bn_bias: + params = self.exclude_from_wt_decay(self.named_parameters(), weight_decay=self.weight_decay) + else: + params = self.parameters() + + if self.optim == "lars": + optimizer = LARS( + params, + lr=self.learning_rate, + momentum=0.9, + weight_decay=self.weight_decay, + trust_coefficient=0.001, + ) + elif self.optim == "adam": + optimizer = torch.optim.Adam(params, lr=self.learning_rate, weight_decay=self.weight_decay) + + warmup_steps = self.train_iters_per_epoch * self.warmup_epochs + total_steps = self.train_iters_per_epoch * self.max_epochs + + scheduler = { + "scheduler": torch.optim.lr_scheduler.LambdaLR( + optimizer, + linear_warmup_decay(warmup_steps, total_steps, cosine=True), + ), + "interval": "step", + "frequency": 1, + } + + return [optimizer], [scheduler] + + @staticmethod + def add_model_specific_args(parent_parser): + parser = ArgumentParser(parents=[parent_parser], add_help=False) + + # model params + parser.add_argument("--arch", default="resnet50", type=str, help="convnet architecture") + # specify flags to store false + parser.add_argument("--first_conv", action="store_false") + parser.add_argument("--maxpool1", action="store_false") + parser.add_argument("--hidden_mlp", default=2048, type=int, help="hidden layer dimension in projection head") + parser.add_argument("--feat_dim", default=128, type=int, help="feature dimension") + parser.add_argument("--online_ft", action="store_true") + parser.add_argument("--fp32", action="store_true") + + # transform params + parser.add_argument("--gaussian_blur", action="store_true", help="add gaussian blur") + parser.add_argument("--jitter_strength", type=float, default=1.0, help="jitter strength") + parser.add_argument("--dataset", type=str, default="stl10", help="stl10, cifar10") + parser.add_argument("--data_dir", type=str, default=".", help="path to download data") + parser.add_argument("--queue_path", type=str, default="queue", help="path for queue") + + parser.add_argument( + "--num_crops", type=int, default=[2, 4], nargs="+", help="list of number of crops (example: [2, 6])" + ) + parser.add_argument( + "--size_crops", type=int, default=[96, 36], nargs="+", help="crops resolutions (example: [224, 96])" + ) + parser.add_argument( + "--min_scale_crops", + type=float, + default=[0.33, 0.10], + nargs="+", + 
help="argument in RandomResizedCrop (example: [0.14, 0.05])", + ) + parser.add_argument( + "--max_scale_crops", + type=float, + default=[1, 0.33], + nargs="+", + help="argument in RandomResizedCrop (example: [1., 0.14])", + ) + + # training params + parser.add_argument("--fast_dev_run", default=1, type=int) + parser.add_argument("--num_nodes", default=1, type=int, help="number of nodes for training") + parser.add_argument("--gpus", default=1, type=int, help="number of gpus to train on") + parser.add_argument("--num_workers", default=8, type=int, help="num of workers per GPU") + parser.add_argument("--optimizer", default="adam", type=str, help="choose between adam/lars") + parser.add_argument("--exclude_bn_bias", action="store_true", help="exclude bn/bias from weight decay") + parser.add_argument("--max_epochs", default=100, type=int, help="number of total epochs to run") + parser.add_argument("--max_steps", default=-1, type=int, help="max steps") + parser.add_argument("--warmup_epochs", default=10, type=int, help="number of warmup epochs") + parser.add_argument("--batch_size", default=128, type=int, help="batch size per gpu") + + parser.add_argument("--weight_decay", default=1e-6, type=float, help="weight decay") + parser.add_argument("--learning_rate", default=1e-3, type=float, help="base learning rate") + parser.add_argument("--start_lr", default=0, type=float, help="initial warmup learning rate") + parser.add_argument("--final_lr", type=float, default=1e-6, help="final learning rate") + + # swav params + parser.add_argument( + "--crops_for_assign", + type=int, + nargs="+", + default=[0, 1], + help="list of crops id used for computing assignments", + ) + parser.add_argument("--temperature", default=0.1, type=float, help="temperature parameter in training loss") + parser.add_argument( + "--epsilon", default=0.05, type=float, help="regularization parameter for Sinkhorn-Knopp algorithm" + ) + parser.add_argument( + "--sinkhorn_iterations", default=3, type=int, help="number of iterations in Sinkhorn-Knopp algorithm" + ) + parser.add_argument("--num_prototypes", default=512, type=int, help="number of prototypes") + parser.add_argument( + "--queue_length", + type=int, + default=0, + help="length of the queue (0 for no queue); must be divisible by total batch size", + ) + parser.add_argument( + "--epoch_queue_starts", type=int, default=15, help="from this epoch, we start using a queue" + ) + parser.add_argument( + "--freeze_prototypes_epochs", + default=1, + type=int, + help="freeze the prototypes during this many epochs from the start", + ) + + return parser + + +def cli_main(): + from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator + from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule + from pl_bolts.transforms.self_supervised.swav_transforms import SwAVEvalDataTransform, SwAVTrainDataTransform + + parser = ArgumentParser() + + # model args + parser = SwAV.add_model_specific_args(parser) + args = parser.parse_args() + + if args.dataset == "stl10": + dm = STL10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers) + + dm.train_dataloader = dm.train_dataloader_mixed + dm.val_dataloader = dm.val_dataloader_mixed + args.num_samples = dm.num_unlabeled_samples + + args.maxpool1 = False + + normalization = stl10_normalization() + elif args.dataset == "cifar10": + args.batch_size = 2 + args.num_workers = 0 + + dm = CIFAR10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers) + + 
args.num_samples = dm.num_samples + + args.maxpool1 = False + args.first_conv = False + + normalization = cifar10_normalization() + + # cifar10 specific params + args.size_crops = [32, 16] + args.num_crops = [2, 1] + args.gaussian_blur = False + elif args.dataset == "imagenet": + args.maxpool1 = True + args.first_conv = True + normalization = imagenet_normalization() + + args.size_crops = [224, 96] + args.num_crops = [2, 6] + args.min_scale_crops = [0.14, 0.05] + args.max_scale_crops = [1.0, 0.14] + args.gaussian_blur = True + args.jitter_strength = 1.0 + + args.batch_size = 64 + args.num_nodes = 8 + args.gpus = 8 # per-node + args.max_epochs = 800 + + args.optimizer = "lars" + args.learning_rate = 4.8 + args.final_lr = 0.0048 + args.start_lr = 0.3 + + args.num_prototypes = 3000 + args.online_ft = True + + dm = ImagenetDataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers) + + args.num_samples = dm.num_samples + args.input_height = dm.dims[-1] + else: + raise NotImplementedError("other datasets have not been implemented till now") + + dm.train_transforms = SwAVTrainDataTransform( + normalize=normalization, + size_crops=args.size_crops, + num_crops=args.num_crops, + min_scale_crops=args.min_scale_crops, + max_scale_crops=args.max_scale_crops, + gaussian_blur=args.gaussian_blur, + jitter_strength=args.jitter_strength, + ) + + dm.val_transforms = SwAVEvalDataTransform( + normalize=normalization, + size_crops=args.size_crops, + num_crops=args.num_crops, + min_scale_crops=args.min_scale_crops, + max_scale_crops=args.max_scale_crops, + gaussian_blur=args.gaussian_blur, + jitter_strength=args.jitter_strength, + ) + + # swav model init + model = SwAV(**args.__dict__) + + online_evaluator = None + if args.online_ft: + # online eval + online_evaluator = SSLOnlineEvaluator( + drop_p=0.0, + hidden_dim=None, + z_dim=args.hidden_mlp, + num_classes=dm.num_classes, + dataset=args.dataset, + ) + + lr_monitor = LearningRateMonitor(logging_interval="step") + model_checkpoint = ModelCheckpoint(save_last=True, save_top_k=1, monitor="val_loss") + callbacks = [model_checkpoint, online_evaluator] if args.online_ft else [model_checkpoint] + callbacks.append(lr_monitor) + + trainer = Trainer( + max_epochs=args.max_epochs, + max_steps=None if args.max_steps == -1 else args.max_steps, + gpus=args.gpus, + num_nodes=args.num_nodes, + accelerator="ddp" if args.gpus > 1 else None, + sync_batchnorm=args.gpus > 1, + precision=32 if args.fp32 else 16, + callbacks=callbacks, + fast_dev_run=args.fast_dev_run, + ) + + trainer.fit(model, datamodule=dm) + + +if __name__ == "__main__": + cli_main() \ No newline at end of file diff --git a/SSLGlacier/modules/transformers_.py b/SSLGlacier/modules/transformers_.py new file mode 100644 index 0000000..f8290bb --- /dev/null +++ b/SSLGlacier/modules/transformers_.py @@ -0,0 +1,248 @@ +from pl_bolts.models.self_supervised import swav +from torchvision import transforms # need it for validation and fine tuning +from . 
import agumentations_ as ag
+from typing import Tuple, List, Dict
+from torch import Tensor
+
+
+class OurTrainTransformer():
+    def __init__(self,
+                 size_crops: Tuple[int] = (294, 94),
+                 nmb_crops: Tuple[int] = (2, 4),
+                 augs: Dict = {'SaltPepper': 1},
+                 use_hook_former: bool = True) -> None:
+        # TODO set it up so that every augmentation method contributes at least one sample
+        self.size_crops = size_crops
+        self.num_crops = nmb_crops
+        augmented_imgs = []
+        self.use_hook_former = use_hook_former
+        self.transform = []
+        # augs.get(...) instead of augs[...]: the default dict only carries the
+        # 'SaltPepper' key, so plain indexing raised KeyError for all other flags
+        if augs.get('OrigPixelValues'):
+            augmented_imgs.append(ag.DoNothing())
+        if augs.get('RGaussianB'):
+            augmented_imgs.append(ag.PILRandomGaussianBlur(radius_min=0.1, radius_max=2.))
+        if augs.get('GaussiN'):
+            augmented_imgs.append(ag.GaussNoise(mean=0, var=10000))
+        if augs.get('SaltPepper'):
+            augmented_imgs.append(ag.SaltPepperNoise(salt_or_pepper=0.2))
+        if augs.get('ZeroN'):
+            augmented_imgs.append(ag.SetZeroNoise())
+        # Nora's augmentation methods
+        if augs.get('flip'):
+            # TODO manage this, how to add a mask for the times we need the transformation for masks too
+            augmented_imgs.append(ag.RandomHorizontalFlip())  # was ag.RandomHorizentalFlip (typo)
+        if augs.get('rotate'):
+            augmented_imgs.append(ag.Rotate())
+        if augs.get('bright'):
+            augmented_imgs.append(ag.Bright())
+        if augs.get('wrap'):
+            augmented_imgs.append(ag.MyWrap())  # append() was called without an argument
+        if augs.get('noise'):
+            augmented_imgs.append(ag.Noise())
+
+        if augs.get('normalize') is not None:
+            # was ag.Compose([augmented_imgs.ToTensor(), ...]), which called a
+            # method that does not exist on a list
+            self.final_transform = ag.Compose([ag.ToTensorZones(), augs['normalize']])
+        else:
+            self.final_transform = ag.ToTensorZones()
+
+        if self.use_hook_former:
+            self.hook_former_transformer(augmented_imgs)
+        else:
+            self.other_networks_transformers(augmented_imgs)
+
+    def __call__(self, image, mask):
+        transformed_img_mask = []
+        for transform in self.transform:
+            if isinstance(transform[0], ag.RandomCropper):
+                transformed_img_mask.append(transform(image, mask))
+                # remember the crop window so the fixed Cropper instances reuse it
+                self.i = transform[0].left
+                self.j = transform[0].right
+                self.w = transform[0].width
+                self.h = transform[0].height
+            elif isinstance(transform[0], ag.Cropper):
+                transform[0].left = self.i
+                transform[0].right = self.j
+                transform[0].width = self.w
+                transform[0].height = self.h
+                transformed_img_mask.append(transform(image, mask))
+            else:
+                transformed_img_mask.append(transform(image, mask))
+        transformed_imgs = [item[0] for item in transformed_img_mask]
+        transformed_masks = [item[1] for item in transformed_img_mask]
+        return transformed_imgs, transformed_masks
+
+    def hook_former_transformer(self, augmented_imgs):
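+        # Sketch of what this method builds (assumed from the code below, not
+        # from any upstream docs): one Compose pipeline per (crop size, aug)
+        # pair, where the first augmentation of each size gets a fresh random
+        # crop and the remaining ones take the centre at a quarter size:
+        #
+        #   transform ~ [Compose([RandomCropper(s), aug_0, ToTensorZones()]),
+        #                Compose([CenterCrop(s // 4), aug_1, ToTensorZones()]),
+        #                ...]            # repeated num_crops[i] times per size s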
+        transform = []
+        for i in range(len(self.size_crops)):
+            i_transform = []
+            for ith_aug, aug in enumerate(augmented_imgs):
+                # The first augmentation of each size gets a fresh random crop (global view)
+                if ith_aug == 0:
+                    global_crop = ag.RandomCropper(self.size_crops[i])
+                    i_transform.extend(
+                        [ag.Compose([global_crop] + [aug] + [ag.ToTensorZones()])]
+                    )
+                else:  # take the centre context at a quarter of the crop size
+                    # CenterCrop takes a single positional size; the old call used a
+                    # non-existent `center_crop_size` keyword and a float size
+                    local_crop = ag.CenterCrop(self.size_crops[i] // 4)
+                    i_transform.extend(
+                        [ag.Compose([local_crop] + [aug] + [ag.ToTensorZones()])])
+            transform += i_transform * self.num_crops[i]
+        self.transform = transform
+        # add an online train transform of the size of the global view
+        online_train_transform = ag.Compose(
+            [ag.RandomCropper(self.size_crops[i]), ag.RandomHorizontalFlip(), self.final_transform]
+        )
+
+        self.transform.append(online_train_transform)
+        return transform
+
+    def other_networks_transformers(self, augmented_imgs):
+        transform = []
+        for i in range(len(self.size_crops)):
+            i_transform = []
+            for ith_aug, aug in enumerate(augmented_imgs):
+                # For the first augmentation we crop randomly, then save the coordinates for the further transformations
+                if ith_aug == 0:
+                    random_crop = ag.RandomCropper(self.size_crops[i])
+                    i_transform.extend(
+                        [ag.Compose([random_crop] + [aug] + [ag.ToTensorZones()])]
+                    )
+                else:  # Crop the same area, and do the augmentation
+                    fixed_crop = ag.Cropper(random_crop.left, random_crop.right, random_crop.height, random_crop.width)
+                    i_transform.extend(
+                        [ag.Compose([fixed_crop] + [aug] + [ag.ToTensorZones()])])
+            transform += i_transform * self.num_crops[i]
+
+        self.transform = transform
+        # add an online train transform of the size of the global view
+        # TODO: clarify why this extra online transform is appended
+        online_train_transform = ag.Compose(
+            [ag.RandomCropper(self.size_crops[i]), ag.RandomHorizontalFlip(), self.final_transform]
+        )
+
+        self.transform.append(online_train_transform)
+        return transform
+
+
+class OurEvalTransformer(OurTrainTransformer):
+    def __init__(self,
+                 size_crops: Tuple[int] = (294, 294),
+                 nmb_crops: Tuple[int] = (2, 4),
+                 augs: Dict = {'SaltPepper': 1},
+                 use_hook_former: bool = True) -> None:
+        super().__init__(size_crops=size_crops,
+                         nmb_crops=nmb_crops,
+                         augs=augs,
+                         use_hook_former=use_hook_former)
+
+        input_height = self.size_crops[0]  # get the global view crop
+        test_transform = ag.Compose(
+            [
+                ag.Resize(int(input_height + 0.1 * input_height)),
+                ag.CenterCrop(input_height),
+                self.final_transform,
+            ]
+        )
+
+        # replace the last transform in self.transform with the eval transform
+        self.transform[-1] = test_transform
+
+
+class OurFinetuneTransform:
+    def __init__(self,
+                 input_height: int = 224,
+                 normalize=None,
+                 eval_transform: bool = False) -> None:
+        # build the closing tensor conversion first so both branches can use it
+        # (it was previously referenced before being defined)
+        if normalize is None:
+            final_transform = transforms.ToTensor()
+        else:
+            final_transform = transforms.Compose([transforms.ToTensor(), normalize])
+
+        if not eval_transform:  # TODO think about it
+            data_transforms = [
+                transforms.RandomResizedCrop(size=input_height),  # was self.input_height, which is never set
+                transforms.RandomHorizontalFlip(),
+                # TODO: ag.SaltPepperNoise expects an (img, mask) pair and needs an
+                # adapter before it can run inside this single-image pipeline:
+                # ag.SaltPepperNoise(salt_or_pepper=0.5),
+                transforms.RandomGrayscale(p=0.2),
+            ]
+        else:
+            data_transforms = [
+                transforms.Resize(int(input_height + 0.1 * input_height)),
+                transforms.CenterCrop(input_height),
+            ]
+
+        data_transforms.append(final_transform)
+        self.transform = transforms.Compose(data_transforms)
+
+    def __call__(self, img: Tensor) -> Tensor:
+        return self.transform(img)
+
+
+# TODO use it in the code with a flag
+class SwapDefaultTransformer():
+    def __init__(
+ self, + size_crops: Tuple[int] = (96, 36), + nmb_crops: Tuple[int] = (2, 4), + min_scale_crops: Tuple[float] = (0.33, 0.10), + max_scale_crops: Tuple[float] = (1, 0.33), + gaussian_blur: bool = True, + jitter_strength: float = 1.0, + ) -> object: + + self.size_crops = size_crops + self.num_crops = nmb_crops + self.min_scale_crops = min_scale_crops + self.max_scale_crops = max_scale_crops + self.gaussian_blur = gaussian_blur + self.jitter_strength = jitter_strength + + def get(self, mode): + if mode == 'train': + return swav.SwAVTrainDataTransform( + #normalize=self.normalization(), + size_crops=self.size_crops, + num_crops=self.num_crops, + min_scale_crops=self.min_scale_crops, + max_scale_crops=self.max_scale_crops, + gaussian_blur=self.gaussian_blur, + jitter_strength=self.jitter_strength + ) + elif mode == 'val': + return swav.SwAVEvalDataTransform( + #normalize=self.normalization(), + size_crops=self.size_crops, + num_crops=self.num_crops, + min_scale_crops=self.min_scale_crops, + max_scale_crops=self.max_scale_crops, + gaussian_blur=self.gaussian_blur, + jitter_strength=self.jitter_strength + ) + -- GitLab From 5e93adb360abfc69d12690c5b400245b947b8bb0 Mon Sep 17 00:00:00 2001 From: Marziyeh <marziyeh.mohammadi@fau.de> Date: Fri, 15 Dec 2023 09:46:48 +0100 Subject: [PATCH 11/11] Refactoring: Make the folder names more clear. --- SSLGlacier/processing/agumentations_.py | 411 ----------------- SSLGlacier/processing/datamodule_.py | 149 ------- SSLGlacier/processing/semi_module.py | 29 -- SSLGlacier/processing/swav_module.py | 562 ------------------------ SSLGlacier/processing/transformers_.py | 248 ----------- 5 files changed, 1399 deletions(-) delete mode 100644 SSLGlacier/processing/agumentations_.py delete mode 100644 SSLGlacier/processing/datamodule_.py delete mode 100644 SSLGlacier/processing/semi_module.py delete mode 100644 SSLGlacier/processing/swav_module.py delete mode 100644 SSLGlacier/processing/transformers_.py diff --git a/SSLGlacier/processing/agumentations_.py b/SSLGlacier/processing/agumentations_.py deleted file mode 100644 index 05cb774..0000000 --- a/SSLGlacier/processing/agumentations_.py +++ /dev/null @@ -1,411 +0,0 @@ -# PILRandomGaussianBlur and get_color_distortion are used and implemented -# by Swav and SimCLR - https://arxiv.org/abs/2002.05709 -import random -from logging import getLogger - -import torch.nn -from PIL import ImageFilter -import numpy as np -import torchvision.transforms as transforms -import torchvision, tormentor -from torchvision.transforms import functional as F -import torch -import matplotlib.pyplot as plt - -logger = getLogger() - - -class Compose(object): - """ - Class for chaining transforms together - """ - - def __init__(self, transforms): - self.transforms = transforms - - def __call__(self, image, target): - for t in self.transforms: - image, target = t(image, target) - return [image, target] - - def __getitem__(self, item): - return self.transforms[item] - -class DoNothing(torch.nn.Module): - def __init__(self): - super(DoNothing, self).__init__() - - def __call__(self, img, mask): - return [img, mask] - -class Cropper(torch.nn.Module): - def __init__(self, i, j, h, w): - super(Cropper, self).__init__() - self.left = i - self.right = j - self.height = h - self.width = w - - def __call__(self, img, mask): - cropped_img = F.crop(img, self.left, self.right, self.height, self.width) - cropped_mask = F.crop(mask, self.left, self.right, self.height, self.width) - return [cropped_img, cropped_mask] - - -class 
RandomCropper(torch.nn.Module): - ''' - This function returns one patch at time, if you need more patches in one image, call it in a loop - Args: - orig_img: get an png or jpg image and crop one patch randomly, in both image and mask - orig_mask: This is the images mask, we crop same area from - Returns: - A list of two argument, first cropped patch in image,second cropped patch in mask - ''' - - def __init__(self, size): - super(RandomCropper, self).__init__() - self.left = 0 - self.right = 0 - self.height = 0 - self.width = 0 - self.size = size - - def forward(self, img, mask): - self.left, self.right, self.height, self.width = torchvision.transforms.RandomCrop.get_params( - img, output_size=(self.size, self.size)) - cropped_img = F.crop(img, self.left, self.right, self.height, self.width) - cropped_mask = F.crop(mask, self.left, self.right, self.height, self.width) - return [cropped_img, cropped_mask] - - -class PILRandomGaussianBlur(torch.nn.Module): - def __init__(self, radius_min=0.1, radius_max=2.): - """ - Apply Gaussian Blur to the PIL image. Take the radius and probability of - application as the parameter. - This transform was used in SimCLR - https://arxiv.org/abs/2002.05709 - """ - super(PILRandomGaussianBlur, self).__init__() - self.radius_min = radius_min - self.radius_max = radius_max - - def forward(self, img, mask): - return [img.filter( - ImageFilter.GaussianBlur( - radius=random.uniform(self.radius_min, self.radius_max) - ) - ), mask] - - -class RandomHorizontalFlip(torch.nn.Module): - def __init__(self): - super(RandomHorizontalFlip, self).__init__() - - def forward(self, img, mask): - image = torchvision.transforms.functional.hflip(img) - mask = torchvision.transforms.functional.hflip(mask) - return [image, mask] - - -class GetColorDistortion(torch.nn.Module): - def __int__(self, s=1.0): - super(GetColorDistortion, self).__init__() - self.s = s - - def forward(self, img, mask): - color_jitter = transforms.ColorJitter(0.8 * self.s, 0.8 * self.s, 0.8 * self.s, 0.2 * self.s) - rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8) - rnd_gray = transforms.RandomGrayscale(p=0.2) - color_distort = transforms.Compose([rnd_color_jitter, rnd_gray]) - return color_distort(img, mask) - - -# TODO this net -class GaussNoise(torch.nn.Module): - def __init__(self, mean=0, var=10000): - super(GaussNoise, self).__init__() - self.mean = mean - self.var = var - - def forward(self, img, mask): - row, col = img.size - sigma = self.var ** 0.5 - gauss = np.random.normal(self.mean, sigma, (row, col)) - gauss = gauss.reshape(row, col) - noisy = img + gauss - return [noisy, mask] - - -class SaltPepperNoise(torch.nn.Module): - def __init__(self, salt_or_pepper=0.5): - super(SaltPepperNoise, self).__init__() - self.salt_or_pepper = salt_or_pepper - - def forward(self, img, mask): - if len(img.size) == 3: - img_size = img.size[1] * img.size[2] - else: - img_size = img.size[0] * img.size[1] - amount = .4 - noisy = np.copy(img) - # Salt mode - num_salt = np.ceil(amount * img_size * self.salt_or_pepper) - target_pixels = [np.random.randint(0, i - 1, int(num_salt)) - for i in img.size] - target_pixels = list(map(lambda coords: tuple(coords), zip(target_pixels[0], target_pixels[1]))) - for i, j in target_pixels: - noisy[i][j] = 1 - # Pepper mode - num_pepper = np.ceil(amount * img_size * (1. 
- self.salt_or_pepper)) - target_pixels = [np.random.randint(0, i - 1, int(num_pepper)) - for i in img.size] - target_pixels = list(map(lambda coords: tuple(coords), zip(target_pixels[0], target_pixels[1]))) - - for i, j in target_pixels: - noisy[i][j] = 0 - - # plt.imshow(img) - # plt.imshow(noisy) - # plt.show() - return [noisy, mask] - - -class PoissionNoise(torch.nn.Module): - def __init__(self): - super(PoissionNoise, self).__init__() - - def forward(self, img, mask): - vals = len(np.unique(img)) - vals = 2 ** np.ceil(np.log2(vals)) - noisy = np.random.poisson(img * vals) / float(vals) - return [noisy, mask] - - -# TODO this one -class SpeckleNoise(torch.nn.Module): - def __init__(self): - super(SpeckleNoise, self).__init__() - - def forward(self, img, mask): - row, col = img.size - gauss = np.random.randn(row, col) - gauss = gauss.reshape(row, col) - noisy = img + img * gauss - return [noisy, mask] - - -class SetZeroNoise(torch.nn.Module): - def __init__(self): - super(SetZeroNoise, self).__init__() - - def forward(self, img, mask): - row, col = img.size - img_size = row * col - random_rows = np.random.randint(0, int(row), (1, int(img_size))) - random_cols = np.random.randint(0, int(col), (1, int(img_size))) - target_pixels = list(zip(random_rows[0], random_cols[0])) - for pix in target_pixels: - img[pix[0], pix[1]] = 0 - return [img, mask] - - -#################################################################################### - -###########Base_code Agumentations suggestions: #TODO do not need them now########## -class MyWrap(torch.nn.Module): - """ - Random wrap augmentation taken from tormentor - """ - - def __init__(self): - super(MyWrap, self).__init__() - - def forward(self, img, target): - # This augmentation acts like many simultaneous elastic transforms with gaussian sigmas set at varius harmonics - wrap_rand = tormentor.Wrap.override_distributions(roughness=tormentor.random.Uniform(value_range=(.1, .7)), - intensity=tormentor.random.Uniform(value_range=(.0, 1.))) - wrap = wrap_rand() - image = wrap(img) - mask = wrap(target, is_mask=True) - return [image, mask] - - -class Rotate(torch.nn.Module): - """ - Random rotation augmentation - """ - - def __init__(self): - super(Rotate, self).__init__() - - def forward(self, img, target): - random = np.random.randint(0, 3) - angle = 90 - if random == 1: - angle = 180 - elif random == 2: - angle = 270 - image = torchvision.transforms.functional.rotate(img, angle=angle) - mask = torchvision.transforms.functional.rotate(target, angle=angle) - return [image, mask.squeeze(0)] - - -class Bright(torch.nn.Module): - """ - Random brightness adjustment augmentations - """ - - def __init__(self, - lower_band: float = -0.2, - upper_band: float = 0.2): - super(Bright, self).__init__() - self.lower_band = lower_band - self.upper_band = upper_band - - def forward(self, img, mask): - bright_rand = tormentor.Brightness.override_distributions( - brightness=tormentor.random.Uniform((self.lower_band, self.upper_band))) - # bright = bright_rand() - # image_transformed = img.clone() - image_transformed = bright_rand(np.asarray(img)) - # set NA areas back to zero - image_transformed.seed[img == 0] = 0.0 - - return [image_transformed.seed, mask] - - -class Noise(torch.nn.Module): - """ - Random additive noise augmentation - """ - - def __init__(self): - super(Noise, self).__init__() - - def forward(self, img, target): - # add noise. 
It is a multiplicative gaussian noise so no need to set na areas back to zero again - noise = torch.normal(mean=0, std=0.3, size=img.size) - image = img + img * noise - image[image > 1.0] = 1.0 - image[image < 0.0] = 0.0 - - return [image, target] - - -class Resize(torch.nn.Module): - # Library code - """Resize the input image to the given size. - If the image is torch Tensor, it is expected - to have [..., H, W] shape, where ... means a maximum of two leading dimensions - - .. warning:: - The output image might be different depending on its type: when downsampling, the interpolation of PIL images - and tensors is slightly different, because PIL applies antialiasing. This may lead to significant differences - in the performance of a network. Therefore, it is preferable to train and serve a model with the same input - types. See also below the ``antialias`` parameter, which can help making the output of PIL images and tensors - closer. - - Args: - size (sequence or int): Desired output size. If size is a sequence like - (h, w), output size will be matched to this. If size is an int, - smaller edge of the image will be matched to this number. - i.e, if height > width, then image will be rescaled to - (size * height / width, size). - - .. note:: - In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``. - interpolation (InterpolationMode): Desired interpolation enum defined by - :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. - If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``, - ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported. - The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. - max_size (int, optional): The maximum allowed for the longer edge of - the resized image. If the longer edge of the image is greater - than ``max_size`` after being resized according to ``size``, - ``size`` will be overruled so that the longer edge is equal to - ``max_size``. - As a result, the smaller edge may be shorter than ``size``. This - is only supported if ``size`` is an int (or a sequence of length - 1 in torchscript mode). - antialias (bool, optional): Whether to apply antialiasing. - It only affects **tensors** with bilinear or bicubic modes and it is - ignored otherwise: on PIL images, antialiasing is always applied on - bilinear or bicubic modes; on other modes (for PIL images and - tensors), antialiasing makes no sense and this parameter is ignored. - Possible values are: - - - ``True``: will apply antialiasing for bilinear or bicubic modes. - Other mode aren't affected. This is probably what you want to use. - - ``False``: will not apply antialiasing for tensors on any mode. PIL - images are still antialiased on bilinear or bicubic modes, because - PIL doesn't support no antialias. - - ``None``: equivalent to ``False`` for tensors and ``True`` for - PIL images. This value exists for legacy reasons and you probably - don't want to use it unless you really know what you are doing. - - The current default is ``None`` **but will change to** ``True`` **in - v0.17** for the PIL and Tensor backends to be consistent. 
- """ - - def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR, max_size=None, antialias="warn"): - super().__init__() - self.size = size - self.max_size = max_size - self.interpolation = interpolation - self.antialias = antialias - - def forward(self, img, mask): - """ - Args: - img (PIL Image or Tensor): Image to be scaled. - - Returns: - PIL Image or Tensor: Rescaled image. - """ - return [F.resize(img, self.size, self.interpolation, self.max_size, self.antialias), mask] - - -class CenterCrop(torch.nn.Module): - """Crops the given image at the center. - If the image is torch Tensor, it is expected - to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. - If image size is smaller than output size along any edge, image is padded with 0 and then center cropped. - - Args: - size (sequence or int): Desired output size of the crop. If size is an - int instead of sequence like (h, w), a square crop (size, size) is - made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). - """ - - def __init__(self, size): - super().__init__() - self.size = size # _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") - - def forward(self, img, mask): - """ - Args: - img (PIL Image or Tensor): Image to be cropped. - - Returns: - PIL Image or Tensor: Cropped image. - """ - ###ADOPT FROM HOOKFORMER CODE, #todo CHECK IT LATER - W, H = img.size - - WW = (W // self.size) + 2 - HH = (H // self.size) + 2 - - return [F.center_crop(img, (HH * self.size, WW * self.size)), F.center_crop(mask, (HH * self.size, WW * self.size))] - - -class ToTensorZones(): - def __call__(self, image, target): - image = F.to_tensor(np.array(image).astype(np.float32)) - target = torch.from_numpy(np.array(target).astype(np.float32)) - # value for NA area=0, stone=64, glacier=127, ocean with ice melange=254 - target[target == 0] = 0 - target[target == 64] = 1 - target[target == 127] = 2 - target[target == 254] = 3 - # class ids for NA area=0, stone=1, glacier=2, ocean with ice melange=3 - return [image, target] diff --git a/SSLGlacier/processing/datamodule_.py b/SSLGlacier/processing/datamodule_.py deleted file mode 100644 index edf3648..0000000 --- a/SSLGlacier/processing/datamodule_.py +++ /dev/null @@ -1,149 +0,0 @@ -from torchvision.transforms import functional as F -from pl_bolts.models.self_supervised import swav -from torch.utils.data import DataLoader -from torch.utils.data import Dataset -import pytorch_lightning as pl -from torchvision import transforms as transform_lib -from . 
import agumentations_ as our_transforms -import torchvision -import numpy as np -from typing import Tuple, List, Any, Callable, Optional -import torch -import cv2 -from PIL import Image,ImageOps -import os - -# TODO add moco, simclr,cpc and amdim transformers -class SSLDataset(Dataset): - def __init__(self, parent_dir, transform,return_index=False, **kwargs): # TODO Pass them as arguments - ''' - Args: - mode: can be tested, train, or validation - parent_dir: directory in which the folders test, train and validation exists - ''' - self.mode = kwargs['mode'] if kwargs['mode'] else 'train' - self.return_index = return_index - self.images_path = os.path.join(parent_dir, "sar_images", self.mode) - self.masks_path = os.path.join(parent_dir, 'zones', self.mode) - self.images = os.listdir(self.images_path) - self.masks = os.listdir(self.masks_path) - assert len(self.masks) == len(self.images), "You don't have the same number of images and masks" - self.transform = transform - # Sort than images and masks fit together - self.images.sort() - self.masks.sort() - - # Let shuffle and save indices, and load them in next calls if they exist - if not os.path.exists(os.path.join("data_processing", "data_splits")): - os.makedirs(os.path.join("data_processing", "data_splits")) - if not os.path.isfile(os.path.join("data_processing", "data_splits", "shuffle_" + ".txt")): - shuffle = np.random.permutation(len(self.images)) - # Works for numpy version >= 1.5 - np.savetxt(os.path.join("data_processing", "data_splits", "shuffle_" + ".txt"), shuffle, - newline=' ') - - else: - # use already existing shuffle - with open(os.path.join("data_processing", "data_splits", "shuffle_" + ".txt"), "rb") as fp: - lines = fp.readlines() - shuffle = [np.fromstring(line, dtype=int, sep=' ') for line in lines] - # if lengths do not match, we need to create a new permutation - if len(shuffle) != len(self.images): - shuffle = np.random.permutation(len(self.images)) - np.savetxt(os.path.join("data_processing", "data_splits", "shuffle_" + ".txt"), shuffle, - newline=' ') - - self.images = np.array(self.images) # Why!!!! - self.masks = np.array(self.masks) - self.images = self.images[shuffle].copy() - self.masks = self.masks[shuffle].copy() - self.images = list(self.images) - self.masks = list(self.masks) # Why!!!! - - def __len__(self): - return len(self.images) - - def __getitem__(self, index): - img_name = self.images[index] - masks_name = self.masks[index] - assert img_name.split('.')[0] == masks_name.split('.')[0].replace("_zones", ""), \ - "image and label name don't match. Image name: " + img_name + ". Label name: " + masks_name - #image = cv2.imread(os.path.join(self.images_path, img_name).__str__(), cv2.IMREAD_GRAYSCALE) - #mask = cv2.imread(os.path.join(self.masks_path, masks_name).__str__(), cv2.IMREAD_GRAYSCALE) - image = Image.open(os.path.join(self.images_path, img_name).__str__()) - image = ImageOps.grayscale(image) - #image = np.array(image).astype(np.float32) - - #rgbimg = Image.new("RGBA", image.size) - #image = rgbimg.paste(image) - #image = ImageOps.grayscale(image) - mask = Image.open(os.path.join(self.masks_path, masks_name).__str__()) - #rgbimg = Image.new("RGBA", mask.size) - #mask = rgbimg.paste(rgbimg) - mask = ImageOps.grayscale(mask) - #mask = np.array(mask).astype(np.float32) - # TODO check if it works or not, it should return multiple transformation for one image - # TODO I crop different parts of the image and mask as part of the batch, how should I handle it!! 
- # TODO this is a list of tuples of img and masks, think if you want to change it later - - # to_tensor = our_transforms.ToTensorZones() - # _, mask = to_tensor(image,mask) - if self.transform is not None: - augmented_imgs, masks = self.transform(image, mask) - if self.return_index: - return index, augmented_imgs, mask, img_name, masks_name - return augmented_imgs, masks - - -class SSLDataModule(pl.LightningDataModule): - # Base is adopted from Nora's code - def __init__(self, batch_size, parent_dir, args): - """ - :param batch_size: batch size - :param target: Either 'zones' or 'front'. Tells which masks should be used. - """ - super().__init__() - self.transforms = None - self.batch_size = batch_size - self.glacier_test = None - self.glacier_train = None - self.glacier_val = None - self.parent_dir = parent_dir - self.aug_args = args.agumentations - self.size_crops = args.size_crops - self.num_crops = args.nmb_crops - - def prepare_data(self): - # download, - # only called on 1 GPU/TPU in distributed - pass - def __len__(self): - return len(self.glacier_train) - def setup(self, stage=None): - # process and split here - # make assignments here (val/train/test split) - # called on every process in DDP - if stage == 'test' or stage is None: - self.transforms = self._default_transforms() if self.test_transforms is None else self.test_transforms - self.glacier_test = SSLDataset(parent_dir=self.parent_dir, transform=self.transforms, mode=stage) - if stage == 'fit' or stage is None: - self.transforms = self._default_transforms() if self.train_transforms is None else self.train_transforms - self.glacier_train = SSLDataset(parent_dir=self.parent_dir, transform=self.transforms, mode='train') - - self.transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms - self.glacier_val = SSLDataset(parent_dir=self.parent_dir, transform=self.transforms, mode='val') - - def train_dataloader(self) -> DataLoader: - return DataLoader(self.glacier_train, batch_size=self.batch_size, num_workers=6, pin_memory=True, - drop_last=True) - - def val_dataloader(self)-> DataLoader: - return DataLoader(self.glacier_val, batch_size=self.batch_size, num_workers=6, pin_memory=True, drop_last=True) - - def test_dataloader(self)-> DataLoader: - return DataLoader(self.glacier_test, batch_size=1, num_workers=6, pin_memory=True, - drop_last=True) # TODO self.batch_size - - def _default_transforms(self) -> Callable: - #return transform_lib.Compose([transform_lib.ToTensor()]) - return our_transforms.Compose([our_transforms.ToTensorZones()]) diff --git a/SSLGlacier/processing/semi_module.py b/SSLGlacier/processing/semi_module.py deleted file mode 100644 index 9798a36..0000000 --- a/SSLGlacier/processing/semi_module.py +++ /dev/null @@ -1,29 +0,0 @@ -from pytorch_lightning import LightningModule -from typing import Any -class Semi_(LightningModule): - def __init__(self): - pass - - def setup(self, stage: str) -> None: - pass - - def init_model(self): - pass - - def forward(self, *args: Any, **kwargs: Any) -> Any: - pass - - def on_train_epoch_start(self) -> None: - pass - - def shared_step(self, batch): - inputs, y = batch - inputs = inputs[:, -1] - embedding = self.model(inputs) - ####Where should I add noise.... - ####assume embedding[b][i] belongs to batch b view i - ####if we have two view, one with noise and one without, then - ##### select some portion of each view belongs to different classes. - ##### pos_set = same classess .. 
- ##### neg_set = diff classes - diff --git a/SSLGlacier/processing/swav_module.py b/SSLGlacier/processing/swav_module.py deleted file mode 100644 index 3b57a4e..0000000 --- a/SSLGlacier/processing/swav_module.py +++ /dev/null @@ -1,562 +0,0 @@ -"""Adapted from official swav implementation: https://github.com/facebookresearch/swav.""" -import os -from argparse import ArgumentParser - -import torch -from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint -from torch import nn - -#from pl_bolts.models.self_supervised.swav.loss import SWAVLoss -from SSLGlacier.models.losses import SWAVLoss -from SSLGlacier.models.RESNET import resnet18, resnet50 -#from Base_Code.models.ParentUNet import UNet -from SSLGlacier.models.UNET import UNet #The main network -from pl_bolts.optimizers.lars import LARS -from pl_bolts.optimizers.lr_scheduler import linear_warmup_decay - -##########Importing files for hook############ -from SSLGlacier.models.hook.Swin_Transformer_Wrapper import SwinUnet -######Temporary config file for hooknet####### -from utils.config import get_config - -class SwAV(LightningModule): - def __init__( - self, - gpus: int, - num_samples: int, - batch_size: int, - dataset: str, - num_nodes: int = 1, - arch: str = "resnet50", - hidden_mlp: int = 2048, - feat_dim: int = 128, - warmup_epochs: int = 10, - max_epochs: int = 100, - num_prototypes: int = 3000, - freeze_prototypes_epochs: int = 1, - temperature: float = 0.1, - sinkhorn_iterations: int = 3, - queue_length: int = 0, # must be divisible by total batch-size - queue_path: str = "queue", - epoch_queue_starts: int = 15, - crops_for_assign: tuple = (0, 1), - num_crops: tuple = (2, 6), - num_augs: int= 2, - first_conv: bool = True, - maxpool1: bool = True, - optimizer: str = "adam", - exclude_bn_bias: bool = False, - start_lr: float = 0.0, - learning_rate: float = 1e-3, - final_lr: float = 0.0, - weight_decay: float = 1e-6, - epsilon: float = 0.05, - just_aug_for_same_assign_views: bool = False, - swin_hparams = {} - ) -> None: - """ - Args: - gpus: number of gpus per node used in training, passed to SwAV module - to manage the queue and select distributed sinkhorn - num_nodes: number of nodes to train on - num_samples: number of image samples used for training - batch_size: batch size per GPU in ddp - dataset: dataset being used for train/val - arch: encoder architecture used for pre-training - hidden_mlp: hidden layer of non-linear projection head, set to 0 - to use a linear projection head - feat_dim: output dim of the projection head - warmup_epochs: apply linear warmup for this many epochs - max_epochs: epoch count for pre-training - num_prototypes: count of prototype vectors - freeze_prototypes_epochs: epoch till which gradients of prototype layer - are frozen - temperature: loss temperature - sinkhorn_iterations: iterations for sinkhorn normalization - queue_length: set queue when batch size is small, - must be divisible by total batch-size (i.e. 
diff --git a/SSLGlacier/processing/swav_module.py b/SSLGlacier/processing/swav_module.py
deleted file mode 100644
index 3b57a4e..0000000
--- a/SSLGlacier/processing/swav_module.py
+++ /dev/null
@@ -1,562 +0,0 @@
-"""Adapted from official swav implementation: https://github.com/facebookresearch/swav."""
-import os
-from argparse import ArgumentParser
-
-import torch
-from pytorch_lightning import LightningModule, Trainer
-from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
-from torch import nn
-
-# from pl_bolts.models.self_supervised.swav.loss import SWAVLoss
-from SSLGlacier.models.losses import SWAVLoss
-from SSLGlacier.models.RESNET import resnet18, resnet50
-# from Base_Code.models.ParentUNet import UNet
-from SSLGlacier.models.UNET import UNet  # the main network
-from pl_bolts.optimizers.lars import LARS
-from pl_bolts.optimizers.lr_scheduler import linear_warmup_decay
-
-########## Importing files for hook ############
-from SSLGlacier.models.hook.Swin_Transformer_Wrapper import SwinUnet
-###### Temporary config file for HookNet #######
-from utils.config import get_config
-
-class SwAV(LightningModule):
-    def __init__(
-        self,
-        gpus: int,
-        num_samples: int,
-        batch_size: int,
-        dataset: str,
-        num_nodes: int = 1,
-        arch: str = "resnet50",
-        hidden_mlp: int = 2048,
-        feat_dim: int = 128,
-        warmup_epochs: int = 10,
-        max_epochs: int = 100,
-        num_prototypes: int = 3000,
-        freeze_prototypes_epochs: int = 1,
-        temperature: float = 0.1,
-        sinkhorn_iterations: int = 3,
-        queue_length: int = 0,  # must be divisible by total batch-size
-        queue_path: str = "queue",
-        epoch_queue_starts: int = 15,
-        crops_for_assign: tuple = (0, 1),
-        num_crops: tuple = (2, 6),
-        num_augs: int = 2,
-        first_conv: bool = True,
-        maxpool1: bool = True,
-        optimizer: str = "adam",
-        exclude_bn_bias: bool = False,
-        start_lr: float = 0.0,
-        learning_rate: float = 1e-3,
-        final_lr: float = 0.0,
-        weight_decay: float = 1e-6,
-        epsilon: float = 0.05,
-        just_aug_for_same_assign_views: bool = False,
-        swin_hparams = {}
-    ) -> None:
-        """
-        Args:
-            gpus: number of gpus per node used in training, passed to SwAV module
-                to manage the queue and select distributed sinkhorn
-            num_nodes: number of nodes to train on
-            num_samples: number of image samples used for training
-            batch_size: batch size per GPU in ddp
-            dataset: dataset being used for train/val
-            arch: encoder architecture used for pre-training
-            hidden_mlp: hidden layer of non-linear projection head, set to 0
-                to use a linear projection head
-            feat_dim: output dim of the projection head
-            warmup_epochs: apply linear warmup for this many epochs
-            max_epochs: epoch count for pre-training
-            num_prototypes: count of prototype vectors
-            freeze_prototypes_epochs: epoch till which gradients of prototype layer
-                are frozen
-            temperature: loss temperature
-            sinkhorn_iterations: iterations for sinkhorn normalization
-            queue_length: use the queue when the batch size is small;
-                must be divisible by the total batch size (i.e. total_gpus * batch_size);
-                set to 0 to remove the queue
-            queue_path: folder within the logs directory
-            epoch_queue_starts: start using the queue after this epoch
-            crops_for_assign: list of crop ids for computing assignment
-            num_crops: number of global and local crops, ex: [2, 6]
-            first_conv: keep first conv same as the original resnet architecture;
-                if set to false it is replaced by a kernel 3, stride 1 conv (cifar-10)
-            maxpool1: keep first maxpool layer same as the original resnet architecture;
-                if set to false, first maxpool is turned off (cifar10, maybe stl10)
-            optimizer: optimizer to use
-            exclude_bn_bias: exclude batchnorm and bias layers from weight decay in optimizers
-            start_lr: starting lr for linear warmup
-            learning_rate: learning rate
-            final_lr: final learning rate for cosine weight decay
-            weight_decay: weight decay for optimizer
-            epsilon: epsilon val for swav assignments
-        """
-        super().__init__()
-        self.save_hyperparameters()
-
-        self.gpus = gpus
-        self.num_nodes = num_nodes
-        self.arch = arch
-        self.dataset = dataset
-        self.num_samples = num_samples
-        self.batch_size = batch_size
-
-        self.hidden_mlp = hidden_mlp
-        self.feat_dim = feat_dim
-        self.num_prototypes = num_prototypes
-        self.freeze_prototypes_epochs = freeze_prototypes_epochs
-        self.sinkhorn_iterations = sinkhorn_iterations
-
-        self.queue_length = queue_length
-        self.queue_path = queue_path
-        self.epoch_queue_starts = epoch_queue_starts
-        self.crops_for_assign = crops_for_assign
-        self.num_crops = num_crops
-        self.num_augs = num_augs
-
-        self.first_conv = first_conv
-        self.maxpool1 = maxpool1
-
-        self.optim = optimizer
-        self.exclude_bn_bias = exclude_bn_bias
-        self.weight_decay = weight_decay
-        self.epsilon = epsilon
-        self.temperature = temperature
-
-        self.start_lr = start_lr
-        self.final_lr = final_lr
-        self.learning_rate = learning_rate
-        self.warmup_epochs = warmup_epochs
-        self.max_epochs = max_epochs
-        self.just_aug_for_same_assign_views = just_aug_for_same_assign_views
-
-        #### for hook: must be set before init_model() so the "hook" arch can read it ####
-        self.swin_hparams = swin_hparams
-
-        self.model = self.init_model()
-        self.criterion = SWAVLoss(
-            gpus=self.gpus,
-            num_nodes=self.num_nodes,
-            temperature=self.temperature,
-            crops_for_assign=self.crops_for_assign,
-            num_crops=self.num_crops,
-            num_augs=self.num_augs,
-            sinkhorn_iterations=self.sinkhorn_iterations,
-            epsilon=self.epsilon,
-            just_aug_for_same_assign_views=self.just_aug_for_same_assign_views
-        )
-        self.use_the_queue = None
-        # compute iters per epoch
-        global_batch_size = self.num_nodes * self.gpus * self.batch_size if self.gpus > 0 else self.batch_size
-        self.train_iters_per_epoch = self.num_samples // global_batch_size
-        self.queue = None
-
-    def setup(self, stage):
-        if self.queue_length > 0:
-            queue_folder = os.path.join(self.logger.log_dir, self.queue_path)
-            if not os.path.exists(queue_folder):
-                os.makedirs(queue_folder)
-
-            self.queue_path = os.path.join(queue_folder, "queue" + str(self.trainer.global_rank) + ".pth")
-
-            if os.path.isfile(self.queue_path):
-                self.queue = torch.load(self.queue_path)["queue"]
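For orientation, a hypothetical instantiation of this module with the queue enabled (all values are illustrative; in practice num_samples comes from the datamodule):

    model = SwAV(
        gpus=1,
        num_samples=50_000,
        batch_size=64,
        dataset='glacier',
        arch='resnet18',
        num_crops=(2, 4),
        queue_length=1920,         # divisible by gpus * batch_size (1 * 64)
        epoch_queue_starts=15,
        sinkhorn_iterations=3,
        epsilon=0.05,
    )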
self.arch == "hook": - backbone = SwinUnet( - img_size=224, - num_classes=5, - normalize=True, - hidden_mlp=self.hidden_mlp, - output_dim=self.feat_dim, - num_prototypes=self.num_prototypes, - first_conv=self.first_conv, - maxpool1=self.maxpool1,######till here for basenet - image_size= self.swin_hparams.image_size,####from here for swin itself - swin_patch_size = self.swin_hparams.swin_patch_size, - swin_in_chans=self.swin_hparams.swin_in_chans, - swin_embed_dim=self.swin_hparams.swin_embed_dim, - swin_depths=self.swin_hparams.swin_depths, - swin_num_heads=self.swin_hparams.swin_num_heads, - swin_window_size=self.swin_hparams.swin_window_size, - swin_mlp_ratio=self.swin_hparams.swin_mlp_ratio, - swin_QKV_BIAS=self.swin_hparams.swin_QKV_BIAS, - swin_QK_SCALE=self.swin_hparams.swin_QK_SCALE, - drop_rate=self.drop_rate, - drop_path_rate=self.drop_path_rate, - swin_ape=self.swin_hparams.swin_ape, - swin_patch_norm=self.swin_hparams.swin_path_norm - ) - - elif self.arch == "Unet" or "unet": - backbone = UNet(num_classes=5, - input_channels=1, - num_layers=5, - features_start=64, - bilinear=False, - normalize=True, - hidden_mlp=self.hidden_mlp, - output_dim=self.feat_dim, - num_prototypes=self.num_prototypes, - first_conv=self.first_conv, - maxpool1=self.maxpool1 - ) - else: - raise f'{self.arch} model is not defined' - - return backbone - - def forward(self, x): - # pass single batch from the resnet backbone - return self.model.forward_backbone(x) - - def on_train_epoch_start(self): - if self.queue_length > 0: - if self.trainer.current_epoch >= self.epoch_queue_starts and self.queue is None: - self.queue = torch.zeros( - len(self.crops_for_assign), - self.queue_length // self.gpus, # change to nodes * gpus once multi-node - self.feat_dim, - ) - - if self.queue is not None: - self.queue = self.queue.to(self.device) - - self.use_the_queue = False - - def on_train_epoch_end(self) -> None: - if self.queue is not None: - torch.save({"queue": self.queue}, self.queue_path) - - def on_after_backward(self): - if self.current_epoch < self.freeze_prototypes_epochs: - for name, p in self.model.named_parameters(): - if "prototypes" in name: - p.grad = None - - def shared_step(self, batch): - if self.dataset == "stl10": - unlabeled_batch = batch[0] - batch = unlabeled_batch - - inputs, y = batch - inputs = inputs[:-1] # remove online train/eval transforms at this point - - # 1. normalize the prototypes - with torch.no_grad(): - w = self.model.prototypes.weight.data.clone() - w = nn.functional.normalize(w, dim=1, p=2) - self.model.prototypes.weight.copy_(w) - - # 2. 
-    def training_step(self, batch, batch_idx):
-        loss = self.shared_step(batch)
-
-        self.log("train_loss", loss, on_step=True, on_epoch=False)
-        return loss
-
-    def validation_step(self, batch, batch_idx):
-        loss = self.shared_step(batch)
-
-        self.log("val_loss", loss, on_step=False, on_epoch=True)
-        return loss
-
-    def exclude_from_wt_decay(self, named_params, weight_decay, skip_list=("bias", "bn")):
-        params = []
-        excluded_params = []
-
-        for name, param in named_params:
-            if not param.requires_grad:
-                continue
-            if any(layer_name in name for layer_name in skip_list):
-                excluded_params.append(param)
-            else:
-                params.append(param)
-
-        return [{"params": params, "weight_decay": weight_decay}, {"params": excluded_params, "weight_decay": 0.0}]
-
-    def configure_optimizers(self):
-        if self.exclude_bn_bias:
-            params = self.exclude_from_wt_decay(self.named_parameters(), weight_decay=self.weight_decay)
-        else:
-            params = self.parameters()
-
-        if self.optim == "lars":
-            optimizer = LARS(
-                params,
-                lr=self.learning_rate,
-                momentum=0.9,
-                weight_decay=self.weight_decay,
-                trust_coefficient=0.001,
-            )
-        elif self.optim == "adam":
-            optimizer = torch.optim.Adam(params, lr=self.learning_rate, weight_decay=self.weight_decay)
-        else:
-            raise ValueError(f"unknown optimizer: {self.optim}")
-
-        warmup_steps = self.train_iters_per_epoch * self.warmup_epochs
-        total_steps = self.train_iters_per_epoch * self.max_epochs
-
-        scheduler = {
-            "scheduler": torch.optim.lr_scheduler.LambdaLR(
-                optimizer,
-                linear_warmup_decay(warmup_steps, total_steps, cosine=True),
-            ),
-            "interval": "step",
-            "frequency": 1,
-        }
-
-        return [optimizer], [scheduler]
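linear_warmup_decay(warmup_steps, total_steps, cosine=True) supplies the per-step multiplier for LambdaLR above; roughly, it behaves like this sketch (an approximate reconstruction, not the pl_bolts source):

    import math

    def warmup_cosine_multiplier(step: int, warmup_steps: int, total_steps: int) -> float:
        if step < warmup_steps:
            return step / max(1, warmup_steps)             # linear warmup from 0 to 1
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))  # cosine decay from 1 to 0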
help="argument in RandomResizedCrop (example: [0.14, 0.05])", - ) - parser.add_argument( - "--max_scale_crops", - type=float, - default=[1, 0.33], - nargs="+", - help="argument in RandomResizedCrop (example: [1., 0.14])", - ) - - # training params - parser.add_argument("--fast_dev_run", default=1, type=int) - parser.add_argument("--num_nodes", default=1, type=int, help="number of nodes for training") - parser.add_argument("--gpus", default=1, type=int, help="number of gpus to train on") - parser.add_argument("--num_workers", default=8, type=int, help="num of workers per GPU") - parser.add_argument("--optimizer", default="adam", type=str, help="choose between adam/lars") - parser.add_argument("--exclude_bn_bias", action="store_true", help="exclude bn/bias from weight decay") - parser.add_argument("--max_epochs", default=100, type=int, help="number of total epochs to run") - parser.add_argument("--max_steps", default=-1, type=int, help="max steps") - parser.add_argument("--warmup_epochs", default=10, type=int, help="number of warmup epochs") - parser.add_argument("--batch_size", default=128, type=int, help="batch size per gpu") - - parser.add_argument("--weight_decay", default=1e-6, type=float, help="weight decay") - parser.add_argument("--learning_rate", default=1e-3, type=float, help="base learning rate") - parser.add_argument("--start_lr", default=0, type=float, help="initial warmup learning rate") - parser.add_argument("--final_lr", type=float, default=1e-6, help="final learning rate") - - # swav params - parser.add_argument( - "--crops_for_assign", - type=int, - nargs="+", - default=[0, 1], - help="list of crops id used for computing assignments", - ) - parser.add_argument("--temperature", default=0.1, type=float, help="temperature parameter in training loss") - parser.add_argument( - "--epsilon", default=0.05, type=float, help="regularization parameter for Sinkhorn-Knopp algorithm" - ) - parser.add_argument( - "--sinkhorn_iterations", default=3, type=int, help="number of iterations in Sinkhorn-Knopp algorithm" - ) - parser.add_argument("--num_prototypes", default=512, type=int, help="number of prototypes") - parser.add_argument( - "--queue_length", - type=int, - default=0, - help="length of the queue (0 for no queue); must be divisible by total batch size", - ) - parser.add_argument( - "--epoch_queue_starts", type=int, default=15, help="from this epoch, we start using a queue" - ) - parser.add_argument( - "--freeze_prototypes_epochs", - default=1, - type=int, - help="freeze the prototypes during this many epochs from the start", - ) - - return parser - - -def cli_main(): - from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator - from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule - from pl_bolts.transforms.self_supervised.swav_transforms import SwAVEvalDataTransform, SwAVTrainDataTransform - - parser = ArgumentParser() - - # model args - parser = SwAV.add_model_specific_args(parser) - args = parser.parse_args() - - if args.dataset == "stl10": - dm = STL10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers) - - dm.train_dataloader = dm.train_dataloader_mixed - dm.val_dataloader = dm.val_dataloader_mixed - args.num_samples = dm.num_unlabeled_samples - - args.maxpool1 = False - - normalization = stl10_normalization() - elif args.dataset == "cifar10": - args.batch_size = 2 - args.num_workers = 0 - - dm = CIFAR10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers) - - 
-def cli_main():
-    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
-    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
-    from pl_bolts.transforms.dataset_normalizations import (
-        cifar10_normalization,
-        imagenet_normalization,
-        stl10_normalization,
-    )
-    from pl_bolts.transforms.self_supervised.swav_transforms import SwAVEvalDataTransform, SwAVTrainDataTransform
-
-    parser = ArgumentParser()
-
-    # model args
-    parser = SwAV.add_model_specific_args(parser)
-    args = parser.parse_args()
-
-    if args.dataset == "stl10":
-        dm = STL10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)
-
-        dm.train_dataloader = dm.train_dataloader_mixed
-        dm.val_dataloader = dm.val_dataloader_mixed
-        args.num_samples = dm.num_unlabeled_samples
-
-        args.maxpool1 = False
-
-        normalization = stl10_normalization()
-    elif args.dataset == "cifar10":
-        args.batch_size = 2
-        args.num_workers = 0
-
-        dm = CIFAR10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)
-
-        args.num_samples = dm.num_samples
-
-        args.maxpool1 = False
-        args.first_conv = False
-
-        normalization = cifar10_normalization()
-
-        # cifar10 specific params
-        args.size_crops = [32, 16]
-        args.num_crops = [2, 1]
-        args.gaussian_blur = False
-    elif args.dataset == "imagenet":
-        args.maxpool1 = True
-        args.first_conv = True
-        normalization = imagenet_normalization()
-
-        args.size_crops = [224, 96]
-        args.num_crops = [2, 6]
-        args.min_scale_crops = [0.14, 0.05]
-        args.max_scale_crops = [1.0, 0.14]
-        args.gaussian_blur = True
-        args.jitter_strength = 1.0
-
-        args.batch_size = 64
-        args.num_nodes = 8
-        args.gpus = 8  # per-node
-        args.max_epochs = 800
-
-        args.optimizer = "lars"
-        args.learning_rate = 4.8
-        args.final_lr = 0.0048
-        args.start_lr = 0.3
-
-        args.num_prototypes = 3000
-        args.online_ft = True
-
-        dm = ImagenetDataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)
-
-        args.num_samples = dm.num_samples
-        args.input_height = dm.dims[-1]
-    else:
-        raise NotImplementedError("other datasets have not been implemented yet")
-
-    dm.train_transforms = SwAVTrainDataTransform(
-        normalize=normalization,
-        size_crops=args.size_crops,
-        num_crops=args.num_crops,
-        min_scale_crops=args.min_scale_crops,
-        max_scale_crops=args.max_scale_crops,
-        gaussian_blur=args.gaussian_blur,
-        jitter_strength=args.jitter_strength,
-    )
-
-    dm.val_transforms = SwAVEvalDataTransform(
-        normalize=normalization,
-        size_crops=args.size_crops,
-        num_crops=args.num_crops,
-        min_scale_crops=args.min_scale_crops,
-        max_scale_crops=args.max_scale_crops,
-        gaussian_blur=args.gaussian_blur,
-        jitter_strength=args.jitter_strength,
-    )
-
-    # swav model init
-    model = SwAV(**args.__dict__)
-
-    online_evaluator = None
-    if args.online_ft:
-        # online eval
-        online_evaluator = SSLOnlineEvaluator(
-            drop_p=0.0,
-            hidden_dim=None,
-            z_dim=args.hidden_mlp,
-            num_classes=dm.num_classes,
-            dataset=args.dataset,
-        )
-
-    lr_monitor = LearningRateMonitor(logging_interval="step")
-    model_checkpoint = ModelCheckpoint(save_last=True, save_top_k=1, monitor="val_loss")
-    callbacks = [model_checkpoint, online_evaluator] if args.online_ft else [model_checkpoint]
-    callbacks.append(lr_monitor)
-
-    trainer = Trainer(
-        max_epochs=args.max_epochs,
-        max_steps=None if args.max_steps == -1 else args.max_steps,
-        gpus=args.gpus,
-        num_nodes=args.num_nodes,
-        accelerator="ddp" if args.gpus > 1 else None,
-        sync_batchnorm=args.gpus > 1,
-        precision=32 if args.fp32 else 16,
-        callbacks=callbacks,
-        fast_dev_run=args.fast_dev_run,
-    )
-
-    trainer.fit(model, datamodule=dm)
-
-
-if __name__ == "__main__":
-    cli_main()
\ No newline at end of file
diff --git a/SSLGlacier/processing/transformers_.py b/SSLGlacier/processing/transformers_.py
deleted file mode 100644
index f8290bb..0000000
--- a/SSLGlacier/processing/transformers_.py
+++ /dev/null
@@ -1,248 +0,0 @@
-from pl_bolts.models.self_supervised import swav
-from torchvision import transforms  # needed for validation and fine-tuning
-from . import agumentations_ as ag
-from typing import Tuple, List, Dict
-from torch import Tensor
-
-
-class OurTrainTransformer():
-    def __init__(self,
-                 size_crops: Tuple[int] = (294, 94),
-                 nmb_crops: Tuple[int] = (2, 4),
-                 augs: Dict = {'SaltPepper': 1},
-                 use_hook_former: bool = True) -> None:
-        # TODO: ensure there is at least one sample from each augmentation method
-        self.size_crops = size_crops
-        self.num_crops = nmb_crops
-        augmented_imgs = []
-        self.use_hook_former = use_hook_former
-        self.transform = []
-        if augs['OrigPixelValues']:
-            augmented_imgs.append(ag.DoNothing())
-        if augs['RGaussianB']:
-            augmented_imgs.append(ag.PILRandomGaussianBlur(radius_min=0.1, radius_max=2.))
-        if augs['GaussiN']:
-            augmented_imgs.append(ag.GaussNoise(mean=0, var=10000))
-        if augs['SaltPepper']:
-            augmented_imgs.append(ag.SaltPepperNoise(salt_or_pepper=0.2))
-        if augs['ZeroN']:
-            augmented_imgs.append(ag.SetZeroNoise())
-        # Nora's augmentation methods
-        if augs['flip']:
-            # TODO: also apply the same transformation to the mask when required
-            augmented_imgs.append(ag.RandomHorizontalFlip())
-            # image = torchvision.transforms.functional.hflip(image)
-            # mask = torchvision.transforms.functional.hflip(mask)
-            # image, mask = ag.ToTensorZones(image=image, target=np.array(mask))
-            # augmented_imgs.append((image, mask))
-        if augs['rotate']:
-            augmented_imgs.append(ag.Rotate())
-        if augs['bright']:
-            augmented_imgs.append(ag.Bright())
-        if augs['wrap']:
-            # TODO: the wrap augmentation is not implemented yet
-            pass
-        if augs['noise']:
-            augmented_imgs.append(ag.Noise())
-
-        if augs['normalize'] is not None:
-            self.final_transform = ag.Compose([ag.ToTensorZones(), augs['normalize']])
-        else:
-            self.final_transform = ag.ToTensorZones()
-
-        if self.use_hook_former:
-            self.hook_former_transformer(augmented_imgs)
-        else:
-            self.other_networks_transformers(augmented_imgs)
-
-    def __call__(self, image, mask):
-        """Apply every configured crop/augmentation pipeline to (image, mask).
-
-        Debugging snippet:
-            import matplotlib.pyplot as plt
-            img1 = transformed_imgs[0].detach().cpu().numpy().squeeze()
-            plt.imshow(img1)
-            plt.show()
-        """
-        transformed_img_mask = []
-        for transform in self.transform:
-            first = transform[0]
-            if isinstance(first, ag.RandomCropper):
-                transformed_img_mask.append(transform(image, mask))
-                self.i = first.left
-                self.j = first.right
-                self.w = first.width
-                self.h = first.height
-            elif isinstance(first, ag.Cropper):
-                first.left = self.i
-                first.right = self.j
-                first.width = self.w
-                first.height = self.h
-                transformed_img_mask.append(transform(image, mask))
-            else:
-                transformed_img_mask.append(transform(image, mask))
-        transformed_imgs = [item[0] for item in transformed_img_mask]
-        transformed_masks = [item[1] for item in transformed_img_mask]
-
-        # fig, axs = plt.subplots(nrows=6, ncols=4, figsize=(15, 12))
-        # fig.suptitle("Patches used in training", fontsize=18, y=.95)
-        #
-        # for i, image_id in enumerate(transformed_imgs):
-        #     img1 = transformed_imgs[i*3].detach().cpu().numpy().squeeze()
-        #     mask1 = transformed_masks[i*3].detach().cpu().numpy().squeeze()
-        #     axs[0, i].imshow(img1)
-        #     axs[0, i].imshow(mask1)
-        #     axs[0, i].imshow(transformed_imgs[i*3 +1].detach().cpu().numpy().squeeze())
-        #     axs[0, i].imshow(transformed_imgs[i*3+2].detach().cpu().numpy().squeeze())
-        # plt.savefig('High.png')
-        return transformed_imgs, transformed_masks
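For orientation, a hypothetical invocation of this transformer; note that the augs dict is indexed directly, so every key checked in __init__ must be present (the image and mask below are synthetic stand-ins):

    from PIL import Image

    augs = {'OrigPixelValues': 1, 'RGaussianB': 0, 'GaussiN': 0, 'SaltPepper': 1,
            'ZeroN': 0, 'flip': 1, 'rotate': 0, 'bright': 0, 'wrap': 0, 'noise': 0,
            'normalize': None}
    t = OurTrainTransformer(size_crops=(294, 94), nmb_crops=(2, 4), augs=augs,
                            use_hook_former=True)
    image = Image.new('L', (512, 512))   # stand-in SAR patch
    mask = Image.new('L', (512, 512))    # stand-in zone mask
    views, mask_views = t(image, mask)   # lists of augmented crops and matching masks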
-    def hook_former_transformer(self, augmented_imgs):
-        transform = []
-        for i in range(len(self.size_crops)):
-            i_transform = []
-            for ith_aug, aug in enumerate(augmented_imgs):
-                # The first pipeline crops randomly; its coordinates are saved for the later pipelines
-                if ith_aug == 0:
-                    global_crop = ag.RandomCropper(self.size_crops[i])
-                    i_transform.extend(
-                        [ag.Compose([global_crop] + [aug] + [ag.ToTensorZones()])]
-                    )
-                else:  # crop the same area, then augment
-                    local_crop = ag.CenterCrop(center_crop_size=self.size_crops[i] // 4)
-                    i_transform.extend(
-                        [ag.Compose([local_crop] + [aug] + [ag.ToTensorZones()])])
-            transform += i_transform * self.num_crops[i]
-        self.transform = transform
-        # add an online train transform of the size of the global view
-        online_train_transform = ag.Compose(
-            [ag.RandomCropper(self.size_crops[i]), ag.RandomHorizontalFlip(), self.final_transform]
-        )
-
-        self.transform.append(online_train_transform)
-        return transform
-
-    def other_networks_transformers(self, augmented_imgs):
-        transform = []
-        for i in range(len(self.size_crops)):
-            i_transform = []
-            for ith_aug, aug in enumerate(augmented_imgs):
-                # The first pipeline crops randomly; its coordinates are saved for the later pipelines
-                if ith_aug == 0:
-                    random_crop = ag.RandomCropper(self.size_crops[i])
-                    i_transform.extend(
-                        [ag.Compose([random_crop] + [aug] + [ag.ToTensorZones()])]
-                    )
-                else:  # crop the same area, then augment
-                    fixed_crop = ag.Cropper(random_crop.left, random_crop.right, random_crop.height, random_crop.width)
-                    i_transform.extend(
-                        [ag.Compose([fixed_crop] + [aug] + [ag.ToTensorZones()])])
-            transform += i_transform * self.num_crops[i]
-
-        self.transform = transform
-        # add an online train transform of the size of the global view
-        # TODO: clarify why this online transform is needed
-        online_train_transform = ag.Compose(
-            [ag.RandomCropper(self.size_crops[i]), ag.RandomHorizontalFlip(), self.final_transform]
-        )
-
-        self.transform.append(online_train_transform)
-        return transform
-
-class OurEvalTransformer(OurTrainTransformer):
-    def __init__(self,
-                 size_crops: Tuple[int] = (294, 294),
-                 nmb_crops: Tuple[int] = (2, 4),
-                 augs: Dict = {'SaltPepper': 1},
-                 use_hook_former: bool = True) -> None:
-        super().__init__(size_crops=size_crops,
-                         nmb_crops=nmb_crops,
-                         augs=augs,
-                         use_hook_former=use_hook_former)
-
-        input_height = self.size_crops[0]  # size of the global view crop
-        test_transform = ag.Compose(
-            [
-                ag.Resize(int(input_height + 0.1 * input_height)),
-                ag.CenterCrop(input_height),
-                self.final_transform,
-            ]
-        )
-
-        # replace the last transform in self.transform with the eval transform
-        self.transform[-1] = test_transform
-
-class OurFinetuneTransform:
-    def __init__(self,
-                 input_height: int = 224,
-                 normalize=None,
-                 eval_transform: bool = False) -> None:
-
-        if normalize is None:
-            final_transform = transforms.ToTensor()
-        else:
-            final_transform = transforms.Compose([transforms.ToTensor(), normalize])
-
-        if not eval_transform:  # TODO think about it
-            data_transforms = [
-                transforms.RandomResizedCrop(size=input_height),
-                ag.RandomHorizontalFlip(),
-                ag.SaltPepperNoise(salt_or_pepper=0.5),
-                transforms.RandomGrayscale(p=0.2),
-            ]
-        else:
-            data_transforms = [
-                transforms.Resize(int(input_height + 0.1 * input_height)),
-                transforms.CenterCrop(input_height),
-            ]
-
-        data_transforms.append(final_transform)
-        self.transform = ag.Compose(data_transforms)
-
-    def __call__(self, img: Tensor) -> Tensor:
-        return self.transform(img)
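The RandomCropper/Cropper pairing above is a coordinate-sharing pattern: the first random crop records its window and every later pipeline reuses it, so all augmentations see the same patch. As a standalone illustration with plain torchvision (synthetic image, independent of the ag module):

    from PIL import Image
    import torchvision.transforms as T
    import torchvision.transforms.functional as TF

    img = Image.new('L', (256, 256))  # synthetic stand-in patch
    top, left, h, w = T.RandomCrop.get_params(img, output_size=(94, 94))
    view_a = TF.crop(img, top, left, h, w)                           # original pixels
    view_b = TF.crop(img.point(lambda p: 255 - p), top, left, h, w)  # augmented view, same window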
-# TODO use it in the code with a flag
-class SwapDefaultTransformer():
-    def __init__(
-        self,
-        size_crops: Tuple[int] = (96, 36),
-        nmb_crops: Tuple[int] = (2, 4),
-        min_scale_crops: Tuple[float] = (0.33, 0.10),
-        max_scale_crops: Tuple[float] = (1, 0.33),
-        gaussian_blur: bool = True,
-        jitter_strength: float = 1.0,
-    ) -> None:
-
-        self.size_crops = size_crops
-        self.num_crops = nmb_crops
-        self.min_scale_crops = min_scale_crops
-        self.max_scale_crops = max_scale_crops
-        self.gaussian_blur = gaussian_blur
-        self.jitter_strength = jitter_strength
-
-    def get(self, mode):
-        if mode == 'train':
-            return swav.SwAVTrainDataTransform(
-                # normalize=self.normalization(),
-                size_crops=self.size_crops,
-                num_crops=self.num_crops,
-                min_scale_crops=self.min_scale_crops,
-                max_scale_crops=self.max_scale_crops,
-                gaussian_blur=self.gaussian_blur,
-                jitter_strength=self.jitter_strength
-            )
-        elif mode == 'val':
-            return swav.SwAVEvalDataTransform(
-                # normalize=self.normalization(),
-                size_crops=self.size_crops,
-                num_crops=self.num_crops,
-                min_scale_crops=self.min_scale_crops,
-                max_scale_crops=self.max_scale_crops,
-                gaussian_blur=self.gaussian_blur,
-                jitter_strength=self.jitter_strength
-            )
-        else:
-            raise ValueError(f"unknown mode: {mode}")
--
GitLab