diff --git a/SSLGlacier/main.py b/SSLGlacier/main.py
index 135ed7ac1bd14a6e8e39ba0478cd9a0c421bdf0c..cc7d5b4b06f1cf3bda20348953614c420070015f 100644
--- a/SSLGlacier/main.py
+++ b/SSLGlacier/main.py
@@ -5,18 +5,21 @@ from torchinfo import summary as torchinfo_summery
 from torchsummary import summary as torch_summary
 from torch._inductor.config import trace
 
-from processing import datamodule_, transformers_
+from modules import datamodule_, transformers_
 import pytorch_lightning as pl
 from models import ResNet_light_mine, Swav_Orig_ResNet
 import os
 import torch
-from processing.swav_module import SwAV
+from modules.swav_module import SwAV
 from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
 import numpy as np
 from utils.utils import str2bool, getDataPath
 
+from typing import *
 
-def main(hparams):
+
+
+def main(hparams, swin_hparams):
     # TODO Ask NORA, Why do you use clip gradient? did you test it?
     ######save check points##########
     checkpoint_dirs = os.path.join('checkpoints', f'ssl_zone_segmentation_{hparams.arch}')
@@ -78,7 +81,8 @@ def main(hparams):
         max_pool1d=False,
         learning_rate=hparams.base_lr,
         just_aug_for_same_assign_views=hparams.just_aug_for_same_assign_views,
-        crops_for_assign=hparams.views_for_assign
+        crops_for_assign=hparams.views_for_assign,
+        hparams=swin_hparams
     )
 
     online_evaluator = SSLOnlineEvaluator(
@@ -139,9 +143,9 @@ if __name__ == '__main__':
     #########################
     # Which agumented image should be used for assignment
     parser.add_argument("--views_for_assign", type=int, nargs="+", default=[0, 1],
-                        help="list of agumented views id used for computing assignments")
+                        help="list of augmented views id used for computing assignments")
     parser.add_argument("--just_aug_for_same_assign_views", type=str2bool, default=True,
-                        help="If you want to consider agumented version of the assigned view for swap assignment.")
+                        help="If you want to consider augmented version of the assigned view for swap assignment.")
     parser.add_argument("--temperature", default=0.1, type=float,
                         help="temperature parameter in training loss")
     parser.add_argument("--epsilons", default=0.05, type=float,
@@ -200,8 +204,70 @@ if __name__ == '__main__':
                                                     'noise': 0, 'color_jitter': 0, 'gauss_blur': 0,
                                                     'salt_or_pepper': 0, 'poission': 0, 'speckle': 1, 'zero': 0,
                                                     'normalize': None})
-    parser.add_argument('--i_want_swav', default=False, type=str2bool, help='Do you wants Swav augmentaion or ours, '
-                                                                        'other papers agumentaions methods will be added later.')
+    parser.add_argument('--i_want_swav', default=False, type=str2bool,
+                        help='Do you want the SwAV augmentations or ours? '
+                             'Other papers\' augmentation methods will be added later.')
     hparams = parser.parse_args()
 
-    main(hparams)
+    ########### Swin Transformer parameters ###########
+    # TODO: A config file would be a better fit for these; maybe refactor the code later.
+    # Parameters originally set in the Swin config file:
+    # Swin Transformer parameters
+    # _C.MODEL.SWIN.PATCH_SIZE = 4
+    # # _C.MODEL.SWIN.PATCH_SIZE = 2
+    # _C.MODEL.SWIN.IN_CHANS = 3
+    # # _C.MODEL.SWIN.IN_CHANS = 1
+    # _C.MODEL.SWIN.EMBED_DIM = 96
+    # _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
+    # _C.MODEL.SWIN.DECODER_DEPTHS = [2, 2, 6, 2]
+    # _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
+    # _C.MODEL.SWIN.WINDOW_SIZE = 7
+    # # _C.MODEL.SWIN.WINDOW_SIZE = 2
+    # _C.MODEL.SWIN.MLP_RATIO = 4.
+    # _C.MODEL.SWIN.QKV_BIAS = True
+    # _C.MODEL.SWIN.QK_SCALE = None
+    # _C.MODEL.SWIN.APE = False
+    # _C.MODEL.SWIN.PATCH_NORM = True
+    # _C.MODEL.SWIN.FINAL_UPSAMPLE = "expand_first"
+    ##############################################
+    swin_parser = ArgumentParser()
+    swin_parser.add_argument('--hidden_mlp', type=int,
+                             help='Hidden layer dimensionality of the projection head.')
+    swin_parser.add_argument('--swin_patch_size', default=4, type=int,
+                             help='The size (resolution) of each patch.')
+    swin_parser.add_argument('--swin_in_chans', default=1, type=int,
+                             help='Number of input image channels.')
+    swin_parser.add_argument('--swin_embed_dim', default=96, type=int,
+                             help='Dimensionality of the patch embedding.')
+    swin_parser.add_argument('--swin_depths', default=[2, 2, 6, 2],
+                             type=int, nargs='+',
+                             help='Depth of each layer in the Transformer encoder.')
+    swin_parser.add_argument('--swin_num_heads', default=[3, 6, 12, 24],
+                             type=int, nargs='+',
+                             help='Number of attention heads in each layer of the Transformer encoder.')
+    swin_parser.add_argument('--swin_window_size', default=2, type=int,
+                             help='Size of the attention windows.')
+    swin_parser.add_argument('--swin_mlp_ratio', default=4.0, type=float,
+                             help='Ratio of MLP hidden dimensionality to embedding dimensionality.')
+    swin_parser.add_argument('--swin_QKV_BIAS', default=True, type=str2bool,
+                             help='Whether a learnable bias should be added to the queries, keys and values.')
+    swin_parser.add_argument('--swin_QK_SCALE', default=None, type=float,
+                             help='Override the default qk scale of head_dim ** -0.5 if set.')
+    swin_parser.add_argument('--swin_ape', default=False, type=str2bool,
+                             help='If True, add an absolute position embedding to the patch embedding.')
+    swin_parser.add_argument('--swin_patch_norm', default=True, type=str2bool,
+                             help='If True, add normalization after the patch embedding.')
+    # parse_known_args is used so the flags of the main parser above are ignored here
+    # instead of raising an "unrecognized arguments" error.
+    swin_hparams, _ = swin_parser.parse_known_args()
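+    # Illustrative sketch (assumption, not executed here): the Swin wrapper reads these values
+    # out of its **kwargs, so the flag names above must match the keys looked up in
+    # SwinUnet.__init__, roughly as in:
+    #   SwinUnet(num_classes=..., image_size=..., drop_rate=..., drop_path_rate=...,
+    #            **vars(swin_hparams))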
+
+
+    main(hparams, swin_hparams)
diff --git a/SSLGlacier/models/RESNET.py b/SSLGlacier/models/RESNET.py
index 9a094ef64799a71827ef9cde50265eb421ba8b7f..0ee1479bc44b2b36c51eda018bfee6eda5b95d2b 100644
--- a/SSLGlacier/models/RESNET.py
+++ b/SSLGlacier/models/RESNET.py
@@ -1,7 +1,7 @@
 """Adapted from: https://github.com/facebookresearch/swav/blob/master/src/resnet50.py."""
 import torch
 from torch import nn
-from SSLGlacier.models.base_net import Net
+from SSLGlacier.models.SWINNET import SwinNet
 
 
 def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
@@ -122,7 +122,7 @@ class Bottleneck(nn.Module):
         return self.relu(out)
 
 
-class ResNet(Net):
+class ResNet(SwinNet):
     def __init__(
             self,
             block,
diff --git a/SSLGlacier/models/base_net.py b/SSLGlacier/models/SWINNET.py
similarity index 99%
rename from SSLGlacier/models/base_net.py
rename to SSLGlacier/models/SWINNET.py
index 4489ec3baca6f158f7f03e4b1c84259b71d99806..9924814f641ceb58479ce6ec37daa5892bb42593 100644
--- a/SSLGlacier/models/base_net.py
+++ b/SSLGlacier/models/SWINNET.py
@@ -3,7 +3,7 @@ import torch
 from torch import nn
 
 
-class Net(nn.Module):
+class SwinNet(nn.Module):
     def __init__(
             self,
             groups=1,
diff --git a/SSLGlacier/models/UNET.py b/SSLGlacier/models/UNET.py
index ad4f65c88f11d852959a19ebf9e9a62fd59e6e5c..5061bb08e355f3a6c078f428051b9e17530aba43 100644
--- a/SSLGlacier/models/UNET.py
+++ b/SSLGlacier/models/UNET.py
@@ -1,9 +1,9 @@
 import torch
 from torch import Tensor, nn
 from torch.nn import functional as F  # noqa: N812
-from SSLGlacier.models.base_net import Net
+from SSLGlacier.models.SWINNET import SwinNet
 
-class UNet(Net):
+class UNet(SwinNet):
     """Pytorch Lightning implementation of U-Net.
     Paper: `U-Net: Convolutional Networks for Biomedical Image Segmentation
     <https://arxiv.org/abs/1505.04597>`_
diff --git a/SSLGlacier/models/hook/Swin_Transformer.py b/SSLGlacier/models/hook/Swin_Transformer.py
index 992126a2d74aad5b258785d4d9476fbc93ec407a..1315bfd20fcda9d2f12310ae640f01d093255ed1 100644
--- a/SSLGlacier/models/hook/Swin_Transformer.py
+++ b/SSLGlacier/models/hook/Swin_Transformer.py
@@ -2,12 +2,10 @@ import torch
 import torch.nn as nn
 import numpy as np
 from timm.models.layers import trunc_normal_
-from SSLGlacier.models.hook.patch_embedding import PatchEmbed
-from SSLGlacier.models.hook.patch_merging import PatchMerging
-from SSLGlacier.models.hook.patch_expand import PatchExpand, PatchExpandC, FinalPatchExpand_X4
-from SSLGlacier.models.hook.basic_swin_layer import BasicLayer
-from SSLGlacier.models.hook.basic_swin_layer_up import BasicLayer_up
-from SSLGlacier.models.hook.basic_swin_layer_cross import BasicLayer_Cross
+from SSLGlacier.models.hook.patch_processing import PatchEmbed
+from SSLGlacier.models.hook.patch_processing import PatchMerging
+from SSLGlacier.models.hook.patch_processing import PatchExpand, PatchExpandC, FinalPatchExpand_X4
+from SSLGlacier.models.hook.basic_layers import BasicLayer, BasicLayer_up, BasicLayer_Cross
 
 torch.set_printoptions(threshold=np.inf)
 torch.set_printoptions(linewidth=300)
diff --git a/SSLGlacier/models/hook/Swin_Transformer_Wrapper.py b/SSLGlacier/models/hook/Swin_Transformer_Wrapper.py
index 528757dc5facce15c00dc59694f81dead8590afc..4f961444452f693aa42731f3b19ec22e8a546f5b 100644
--- a/SSLGlacier/models/hook/Swin_Transformer_Wrapper.py
+++ b/SSLGlacier/models/hook/Swin_Transformer_Wrapper.py
@@ -1,41 +1,48 @@
 # coding=utf-8
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
 
 import copy
 import logging
-
 import torch
-import torch.nn as nn
-
+from SSLGlacier.models.SWINNET import SwinNet
 from SSLGlacier.models.hook.Swin_Transformer import SwinTransformerSys
 
 logger = logging.getLogger(__name__)
 
-class SwinUnet(nn.Module):
-    def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
-        super(SwinUnet, self).__init__()
+class SwinUnet(SwinNet):
+    def __init__(self, num_classes=21843, zero_head=False, **kwargs):
+        super().__init__(
+            groups=1,
+            widen=1,
+            width_per_group=64,
+            norm_layer=None,
+            output_dim=kwargs.get('output_dim', 0),
+            hidden_mlp=kwargs.get('hidden_mlp', 0),
+            num_prototypes=kwargs.get('num_prototypes', 0),
+            eval_mode=False,
+            normalize=kwargs.get('normalize', False)
+        )
+
         self.num_classes = num_classes
         self.zero_head = zero_head
-        self.config = config
 
-        self.swin_unet = SwinTransformerSys(img_size=config.DATA.IMG_SIZE,
-                                            patch_size=config.MODEL.SWIN.PATCH_SIZE,
-                                            in_chans=config.MODEL.SWIN.IN_CHANS,
+        self.swin_unet = SwinTransformerSys(img_size=kwargs['image_size'],
+                                            patch_size=kwargs['swin_patch_size'],
+                                            in_chans=kwargs['swin_in_chans'],
                                             num_classes=self.num_classes,
-                                            embed_dim=config.MODEL.SWIN.EMBED_DIM,
-                                            depths=config.MODEL.SWIN.DEPTHS,
-                                            num_heads=config.MODEL.SWIN.NUM_HEADS,
-                                            window_size=config.MODEL.SWIN.WINDOW_SIZE,
-                                            mlp_ratio=config.MODEL.SWIN.MLP_RATIO,
-                                            qkv_bias=config.MODEL.SWIN.QKV_BIAS,
-                                            qk_scale=config.MODEL.SWIN.QK_SCALE,
-                                            drop_rate=config.MODEL.DROP_RATE,
-                                            drop_path_rate=config.MODEL.DROP_PATH_RATE,
-                                            ape=config.MODEL.SWIN.APE,
-                                            patch_norm=config.MODEL.SWIN.PATCH_NORM,
-                                            use_checkpoint=config.TRAIN.USE_CHECKPOINT)
+                                            embed_dim=kwargs['swin_embed_dim'],
+                                            depths=kwargs['swin_depths'],
+                                            num_heads=kwargs['swin_num_heads'],
+                                            window_size=kwargs['swin_window_size'],
+                                            mlp_ratio=kwargs['swin_mlp_ratio'],
+                                            qkv_bias=kwargs['swin_QKV_BIAS'],
+                                            qk_scale=kwargs['swin_QK_SCALE'],
+                                            drop_rate=kwargs['drop_rate'],
+                                            drop_path_rate=kwargs['drop_path_rate'],
+                                            ape=kwargs['swin_ape'],
+                                            patch_norm=kwargs['swin_patch_norm'])
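+        # Sketch of the expected call (assumption, not taken from this repository): the keys above
+        # come from unpacking the argparse Namespace built in main.py plus a few extra values, e.g.
+        #   model = SwinUnet(num_classes=..., image_size=224, drop_rate=0.0, drop_path_rate=0.1,
+        #                    **vars(swin_hparams))
+        # where 0.0 / 0.1 are the usual Swin defaults rather than values defined here.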
 
     def forward(self, x, y):
         if x.size()[1] == 1:
diff --git a/SSLGlacier/models/hook/basic_layers.py b/SSLGlacier/models/hook/basic_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..079d1cc4f7944bb707292826838d065735c52f04
--- /dev/null
+++ b/SSLGlacier/models/hook/basic_layers.py
@@ -0,0 +1,229 @@
+import torch.nn as nn
+import torch.utils.checkpoint as checkpoint
+from SSLGlacier.models.hook.patch_processing import PatchExpand
+from SSLGlacier.models.hook.swin_transformer_blocks import SwinTransformerBlock, SwinTransformerBlock_Cross
+#-----------------------------------------------
+########### Basic Swin Layer Classes ###########
+#-----------------------------------------------
+class BasicLayer(nn.Module):
+    """ A basic Swin Transformer layer for one stage.
+
+    Args:
+        dim (int): Number of input channels.
+        input_resolution (tuple[int]): Input resolution.
+        depth (int): Number of blocks.
+        num_heads (int): Number of attention heads.
+        window_size (int): Local window size.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+    """
+
+    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
+                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
+                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
+
+        super().__init__()
+        self.dim = dim  # 96
+        self.input_resolution = input_resolution  # 56
+        self.depth = depth  # 2
+        self.use_checkpoint = use_checkpoint
+
+        # build blocks
+        self.blocks = nn.ModuleList([
+            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,  # 96, 56
+                                 num_heads=num_heads, window_size=window_size,  # 3, 14
+                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
+                                 mlp_ratio=mlp_ratio,  # 4
+                                 qkv_bias=qkv_bias, qk_scale=qk_scale,  # True, None
+                                 drop=drop, attn_drop=attn_drop,  # 0.2, 0
+                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+                                 norm_layer=norm_layer)  # layer-norm
+            for i in range(depth)])
+
+        # patch merging layer
+        if downsample is not None:
+            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
+        else:
+            self.downsample = None
+
+    def forward(self, x):
+        for blk in self.blocks:
+            if self.use_checkpoint:
+                x = checkpoint.checkpoint(blk, x)
+            else:
+                x = blk(x)
+        if self.downsample is not None:
+            x = self.downsample(x)
+
+        return x
+
+    def extra_repr(self) -> str:
+        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
+
+    def flops(self):
+        flops = 0
+        for blk in self.blocks:
+            flops += blk.flops()
+        if self.downsample is not None:
+            flops += self.downsample.flops()
+        return flops
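+# Usage sketch (illustrative only; shapes follow the inline comments above, i.e. an
+# embed_dim of 96 on a 56x56 token grid, with PatchMerging from patch_processing as downsample):
+#   layer = BasicLayer(dim=96, input_resolution=(56, 56), depth=2, num_heads=3,
+#                      window_size=7, downsample=PatchMerging)
+#   x = torch.randn(2, 56 * 56, 96)   # (B, H*W, C)
+#   y = layer(x)                      # (B, (H/2)*(W/2), 2*C) = (2, 784, 192)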
+#-----------------------------------------------
+####### End of Basic Swin Layer Classes ########
+#-----------------------------------------------
+
+################################################
+################################################
+
+#-----------------------------------------------
+####### Basic Swin Layer cross Classes #########
+#-----------------------------------------------
+
+class BasicLayer_Cross(nn.Module):  # Input of second attention is output of first. See paper
+    """ A basic Swin Transformer layer for one stage.
+
+    Args:
+        dim (int): Number of input channels.
+        input_resolution (tuple[int]): Input resolution.
+        depth (int): Number of blocks.
+        num_heads (int): Number of attention heads.
+        window_size (int): Local window size.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+    """
+
+    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
+                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
+                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
+
+        super().__init__()
+        self.dim = dim  # 96
+        self.input_resolution = input_resolution  # 56
+        self.depth = depth  # 2
+        self.use_checkpoint = use_checkpoint
+
+        # build blocks
+        self.blocks = nn.ModuleList([
+            SwinTransformerBlock_Cross(dim=dim, input_resolution=input_resolution,  # 96, 56
+                                       num_heads=num_heads, window_size=window_size,  # 3, 14
+                                       shift_size=0 if (i % 2 == 0) else window_size // 2,
+                                       mlp_ratio=mlp_ratio,  # 4
+                                       qkv_bias=qkv_bias, qk_scale=qk_scale,  # True, None
+                                       drop=drop, attn_drop=attn_drop,  # 0.2, 0
+                                       drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+                                       norm_layer=norm_layer)  # layer-norm
+            for i in range(depth)])
+
+        # patch merging layer
+        if downsample is not None:
+            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
+        else:
+            self.downsample = None
+
+    def forward(self, x, y):
+        for blk in self.blocks:
+            if self.use_checkpoint:
+                x = checkpoint.checkpoint(blk, x, y)
+            else:
+                x = blk(x, y)
+        if self.downsample is not None:
+            x = self.downsample(x)
+
+        return x
+
+    def extra_repr(self) -> str:
+        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
+
+    def flops(self):
+        flops = 0
+        for blk in self.blocks:
+            flops += blk.flops()
+        if self.downsample is not None:
+            flops += self.downsample.flops()
+        return flops
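+# Usage sketch (illustrative only): the cross variant attends from one stream to another,
+# so forward takes two token tensors of identical shape, e.g.
+#   cross = BasicLayer_Cross(dim=96, input_resolution=(56, 56), depth=2, num_heads=3,
+#                            window_size=7)
+#   out = cross(x, y)   # x, y: (B, 56*56, 96); out keeps x's shape when downsample is None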
+#-----------------------------------------------
+##### End of Basic Swin Layer cross Classes ####
+#-----------------------------------------------
+
+################################################
+################################################
+
+
+#-----------------------------------------------
+######### Basic Swin Layer UP Classes ##########
+#-----------------------------------------------
+
+class BasicLayer_up(nn.Module):
+    """ A basic Swin Transformer layer for one stage.
+
+    Args:
+        dim (int): Number of input channels.
+        input_resolution (tuple[int]): Input resolution.
+        depth (int): Number of blocks.
+        num_heads (int): Number of attention heads.
+        window_size (int): Local window size.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+        upsample (nn.Module | None, optional): Upsample layer at the end of the layer. Default: None
+        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+    """
+
+    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
+                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
+                 drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False):
+
+        super().__init__()
+        self.dim = dim
+        self.input_resolution = input_resolution
+        self.depth = depth
+        self.use_checkpoint = use_checkpoint
+
+        # build blocks
+        self.blocks = nn.ModuleList([
+            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
+                                 num_heads=num_heads, window_size=window_size,
+                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
+                                 mlp_ratio=mlp_ratio,
+                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
+                                 drop=drop, attn_drop=attn_drop,
+                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+                                 norm_layer=norm_layer)
+            for i in range(depth)])
+
+        # patch expanding layer
+        if upsample is not None:
+            self.upsample = PatchExpand(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer)
+        else:
+            self.upsample = None
+
+    def forward(self, x):
+        for blk in self.blocks:
+            if self.use_checkpoint:
+                x = checkpoint.checkpoint(blk, x)
+            else:
+                x = blk(x)
+        if self.upsample is not None:
+            x = self.upsample(x)
+        return x
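+# Usage sketch (illustrative only): with upsample set, PatchExpand doubles the spatial
+# resolution and halves the channel dimension, e.g.
+#   up = BasicLayer_up(dim=192, input_resolution=(28, 28), depth=2, num_heads=6,
+#                      window_size=7, upsample=PatchExpand)
+#   out = up(torch.randn(2, 28 * 28, 192))   # -> (2, 56 * 56, 96)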
+
+#-----------------------------------------------
+##### End of Basic Swin Layer UP Classes ####
+#-----------------------------------------------
diff --git a/SSLGlacier/models/hook/basic_swin_layer.py b/SSLGlacier/models/hook/basic_swin_layer.py
deleted file mode 100644
index e9570cb96d1b2302cd25658c41b1977a8a40aa88..0000000000000000000000000000000000000000
--- a/SSLGlacier/models/hook/basic_swin_layer.py
+++ /dev/null
@@ -1,74 +0,0 @@
-import torch.nn as nn
-import torch.utils.checkpoint as checkpoint
-from SSLGlacier.models.hook.swin_transformer_block import SwinTransformerBlock
-
-
-class BasicLayer(nn.Module):
-    """ A basic Swin Transformer layer for one stage.
-
-    Args:
-        dim (int): Number of input channels.
-        input_resolution (tuple[int]): Input resolution.
-        depth (int): Number of blocks.
-        num_heads (int): Number of attention heads.
-        window_size (int): Local window size.
-        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
-        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
-        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
-        drop (float, optional): Dropout rate. Default: 0.0
-        attn_drop (float, optional): Attention dropout rate. Default: 0.0
-        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
-        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
-        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
-        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
-    """
-
-    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
-                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
-                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
-
-        super().__init__()
-        self.dim = dim  # 96
-        self.input_resolution = input_resolution  # 56
-        self.depth = depth  # 2
-        self.use_checkpoint = use_checkpoint
-
-        # build blocks
-        self.blocks = nn.ModuleList([
-            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,  # 96, 56
-                                 num_heads=num_heads, window_size=window_size,  # 3, 14
-                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
-                                 mlp_ratio=mlp_ratio,  # 4
-                                 qkv_bias=qkv_bias, qk_scale=qk_scale,  # True, None
-                                 drop=drop, attn_drop=attn_drop,  # 0.2, 0
-                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
-                                 norm_layer=norm_layer)  # layer-norm
-            for i in range(depth)])
-
-        # patch merging layer
-        if downsample is not None:
-            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
-        else:
-            self.downsample = None
-
-    def forward(self, x):
-        for blk in self.blocks:
-            if self.use_checkpoint:
-                x = checkpoint.checkpoint(blk, x)
-            else:
-                x = blk(x)
-        if self.downsample is not None:
-            x = self.downsample(x)
-
-        return x
-
-    def extra_repr(self) -> str:
-        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
-
-    def flops(self):
-        flops = 0
-        for blk in self.blocks:
-            flops += blk.flops()
-        if self.downsample is not None:
-            flops += self.downsample.flops()
-        return flops
\ No newline at end of file
diff --git a/SSLGlacier/models/hook/basic_swin_layer_cross.py b/SSLGlacier/models/hook/basic_swin_layer_cross.py
deleted file mode 100644
index e9481d94d488dca538b3ac308b96bbcf3d7cc3aa..0000000000000000000000000000000000000000
--- a/SSLGlacier/models/hook/basic_swin_layer_cross.py
+++ /dev/null
@@ -1,74 +0,0 @@
-import torch.nn as nn
-import torch.utils.checkpoint as checkpoint
-from SSLGlacier.models.hook.swin_transformer_block_cross import SwinTransformerBlock_Cross
-
-
-class BasicLayer_Cross(nn.Module):  # Input of second attention is output of first. See paper
-    """ A basic Swin Transformer layer for one stage.
-
-    Args:
-        dim (int): Number of input channels.
-        input_resolution (tuple[int]): Input resolution.
-        depth (int): Number of blocks.
-        num_heads (int): Number of attention heads.
-        window_size (int): Local window size.
-        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
-        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
-        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
-        drop (float, optional): Dropout rate. Default: 0.0
-        attn_drop (float, optional): Attention dropout rate. Default: 0.0
-        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
-        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
-        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
-        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
-    """
-
-    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
-                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
-                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
-
-        super().__init__()
-        self.dim = dim  # 96
-        self.input_resolution = input_resolution  # 56
-        self.depth = depth  # 2
-        self.use_checkpoint = use_checkpoint
-
-        # build blocks
-        self.blocks = nn.ModuleList([
-            SwinTransformerBlock_Cross(dim=dim, input_resolution=input_resolution,  # 96, 56
-                                       num_heads=num_heads, window_size=window_size,  # 3, 14
-                                       shift_size=0 if (i % 2 == 0) else window_size // 2,
-                                       mlp_ratio=mlp_ratio,  # 4
-                                       qkv_bias=qkv_bias, qk_scale=qk_scale,  # True, None
-                                       drop=drop, attn_drop=attn_drop,  # 0.2, 0
-                                       drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
-                                       norm_layer=norm_layer)  # layer-norm
-            for i in range(depth)])
-
-        # patch merging layer
-        if downsample is not None:
-            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
-        else:
-            self.downsample = None
-
-    def forward(self, x, y):
-        for blk in self.blocks:
-            if self.use_checkpoint:
-                x = checkpoint.checkpoint(blk, x)
-            else:
-                x = blk(x, y)
-        if self.downsample is not None:
-            x = self.downsample(x)
-
-        return x
-
-    def extra_repr(self) -> str:
-        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
-
-    def flops(self):
-        flops = 0
-        for blk in self.blocks:
-            flops += blk.flops()
-        if self.downsample is not None:
-            flops += self.downsample.flops()
-        return flops
diff --git a/SSLGlacier/models/hook/basic_swin_layer_up.py b/SSLGlacier/models/hook/basic_swin_layer_up.py
deleted file mode 100644
index 9ee1b87675acf8a5e8a9d54a98c86f49122f8f56..0000000000000000000000000000000000000000
--- a/SSLGlacier/models/hook/basic_swin_layer_up.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import torch.nn as nn
-import torch.utils.checkpoint as checkpoint
-from SSLGlacier.models.hook.swin_transformer_block import SwinTransformerBlock
-from SSLGlacier.models.hook.patch_expand import PatchExpand
-
-
-class BasicLayer_up(nn.Module):
-    """ A basic Swin Transformer layer for one stage.
-
-    Args:
-        dim (int): Number of input channels.
-        input_resolution (tuple[int]): Input resolution.
-        depth (int): Number of blocks.
-        num_heads (int): Number of attention heads.
-        window_size (int): Local window size.
-        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
-        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
-        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
-        drop (float, optional): Dropout rate. Default: 0.0
-        attn_drop (float, optional): Attention dropout rate. Default: 0.0
-        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
-        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
-        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
-        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
-    """
-
-    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
-                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
-                 drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False):
-
-        super().__init__()
-        self.dim = dim
-        self.input_resolution = input_resolution
-        self.depth = depth
-        self.use_checkpoint = use_checkpoint
-
-        # build blocks
-        self.blocks = nn.ModuleList([
-            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
-                                 num_heads=num_heads, window_size=window_size,
-                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
-                                 mlp_ratio=mlp_ratio,
-                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
-                                 drop=drop, attn_drop=attn_drop,
-                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
-                                 norm_layer=norm_layer)
-            for i in range(depth)])
-
-        # patch merging layer
-        if upsample is not None:
-            self.upsample = PatchExpand(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer)
-        else:
-            self.upsample = None
-
-    def forward(self, x):
-        for blk in self.blocks:
-            if self.use_checkpoint:
-                x = checkpoint.checkpoint(blk, x)
-            else:
-                x = blk(x)
-        if self.upsample is not None:
-            x = self.upsample(x)
-        return x
\ No newline at end of file
diff --git a/SSLGlacier/models/hook/mlp.py b/SSLGlacier/models/hook/mlp.py
deleted file mode 100644
index 6e7730ffc7f0d4aa7c018084c4169e6b80693d36..0000000000000000000000000000000000000000
--- a/SSLGlacier/models/hook/mlp.py
+++ /dev/null
@@ -1,20 +0,0 @@
-import torch.nn as nn
-
-
-class Mlp(nn.Module):
-    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
-        super().__init__()
-        out_features = out_features or in_features
-        hidden_features = hidden_features or in_features
-        self.fc1 = nn.Linear(in_features, hidden_features)
-        self.act = act_layer()
-        self.fc2 = nn.Linear(hidden_features, out_features)
-        self.drop = nn.Dropout(drop)
-
-    def forward(self, x):
-        x = self.fc1(x)
-        x = self.act(x)
-        x = self.drop(x)
-        x = self.fc2(x)
-        x = self.drop(x)
-        return x
\ No newline at end of file
diff --git a/SSLGlacier/models/hook/patch_embedding.py b/SSLGlacier/models/hook/patch_embedding.py
deleted file mode 100644
index e9fba670109312057260b71a0371789c597c1e12..0000000000000000000000000000000000000000
--- a/SSLGlacier/models/hook/patch_embedding.py
+++ /dev/null
@@ -1,51 +0,0 @@
-import torch.nn as nn
-from timm.models.layers import to_2tuple
-
-
-class PatchEmbed(nn.Module):
-    r""" Image to Patch Embedding
-
-    Args:
-        img_size (int): Image size.  Default: 224.
-        patch_size (int): Patch token size. Default: 4.
-        in_chans (int): Number of input image channels. Default: 3.
-        embed_dim (int): Number of linear projection output channels. Default: 96.
-        norm_layer (nn.Module, optional): Normalization layer. Default: None
-    """
-
-    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
-        super().__init__()
-        img_size = to_2tuple(img_size)
-        patch_size = to_2tuple(patch_size)
-        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
-        self.img_size = img_size  # 224
-        self.patch_size = patch_size  # 4
-        self.patches_resolution = patches_resolution  # 56
-        self.num_patches = patches_resolution[0] * patches_resolution[1]  # 3136
-
-        self.in_chans = in_chans
-        self.embed_dim = embed_dim
-
-        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)  # 22, 96, 56, 56
-
-        if norm_layer is not None:
-            self.norm = norm_layer(embed_dim)
-        else:
-            self.norm = None
-
-    def forward(self, x):
-        B, C, H, W = x.shape
-        # FIXME look at relaxing size constraints
-        assert H == self.img_size[0] and W == self.img_size[1], \
-            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
-        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C (22, 3136, 96)
-        if self.norm is not None:
-            x = self.norm(x)
-        return x
-
-    def flops(self):
-        Ho, Wo = self.patches_resolution
-        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
-        if self.norm is not None:
-            flops += Ho * Wo * self.embed_dim
-        return flops
\ No newline at end of file
diff --git a/SSLGlacier/models/hook/patch_expand.py b/SSLGlacier/models/hook/patch_expand.py
deleted file mode 100644
index ef397f8d147497ab1005ee7d2a39575879354d9f..0000000000000000000000000000000000000000
--- a/SSLGlacier/models/hook/patch_expand.py
+++ /dev/null
@@ -1,155 +0,0 @@
-import torch.nn as nn
-from einops import rearrange
-# import torch
-
-
-class PatchExpand(nn.Module):
-    def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
-        super().__init__()
-        self.input_resolution = input_resolution
-        self.dim = dim  # 96
-        self.expand = nn.Linear(dim, 2 * dim, bias=False) if dim_scale == 2 else nn.Identity()
-        self.norm = norm_layer(dim // dim_scale)
-
-    def forward(self, x):
-        """
-        x: B, H*W, C
-        """
-        H, W = self.input_resolution
-        x = self.expand(x)  # 7, 49, 768; 7, 49, 1536
-        B, L, C = x.shape
-        assert L == H * W, "input feature has wrong size"
-        x = x.view(B, H, W, C)  # 7, 7, 7, 1536
-        x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C // 4)  # 7, 14, 14, 384
-        x = x.view(B, -1, C // 4)
-        x = self.norm(x)
-
-        return x
-
-
-# class PatchExpand4(nn.Module):
-#     def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
-#         super().__init__()
-#         self.input_resolution = input_resolution
-#         self.dim = dim  # 96
-#         self.expand = nn.Linear(dim, 384, bias=False) if dim_scale == 2 else nn.Identity()
-#         self.norm = norm_layer(96)
-#
-#     def forward(self, x):
-#         """
-#         x: B, H*W, C
-#         """
-#         H, W = self.input_resolution
-#         x = self.expand(x)  # 7, 49, 768; 7, 49, 1536
-#         B, L, C = x.shape
-#         assert L == H * W, "input feature has wrong size"
-#         x = x.view(B, H, W, C)  # 7, 7, 7, 1536
-#         x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C // 4)  # 7, 14, 14, 384
-#         x = x.view(B, -1, C // 4)
-#         x = self.norm(x)
-#
-#         return x
-
-
-class PatchExpandC(nn.Module):
-    def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
-        super().__init__()
-        self.input_resolution = input_resolution
-        self.dim = dim  # 96
-        self.expand = nn.Linear(dim, dim, bias=False) if dim_scale == 2 else nn.Identity()
-        self.norm = norm_layer(192)
-
-    def forward(self, x):
-        """
-        x: B, H*W, C
-        """
-        H, W = self.input_resolution
-        x = self.expand(x)  # 7, 49, 768; 7, 49, 1536
-        B, L, C = x.shape
-        assert L == H * W, "input feature has wrong size"
-        x = x.view(B, H, W, C)  # 7, 7, 7, 1536
-        x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C // 4)  # 7, 14, 14, 384
-        x = x.view(B, -1, C // 4)
-        x = self.norm(x)
-
-        return x
-
-
-# class PCAa(object):
-#     def __init__(self, n_components=2):
-#         self.n_components = n_components
-#
-#     def fit_trans(self, X):
-#         n = X.shape[1]
-#         self.mean = torch.mean(X, axis=1).unsqueeze(dim=1)
-#
-#         X = X - self.mean
-#         covariance_matrix = 1 / n * torch.matmul(X.permute(0, 2, 1), X)
-#
-#         eigenvalues, eigenvectors = torch.linalg.eig(covariance_matrix)
-#
-#         eigenvalues = torch.real(eigenvalues)
-#         eigenvectors = torch.real(eigenvectors)
-#
-#         idx = torch.argsort(-eigenvalues)
-#
-#         eigenvectors = torch.gather(eigenvectors, 2, idx.unsqueeze(dim=1).expand(eigenvectors.shape))
-#         self.proj_mat = eigenvectors[:, :, 0:self.n_components]
-#
-#         Out = X.matmul(self.proj_mat)
-#
-#         return Out
-#
-#
-# class PatchExpandPCA(nn.Module):
-#     def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
-#         super().__init__()
-#         self.input_resolution = input_resolution
-#         self.dim = dim  # 96
-#         # self.expand = nn.Linear(dim, 384, bias=False) if dim_scale == 2 else nn.Identity()
-#         self.norm = norm_layer(96)
-#
-#     def forward(self, x):
-#         """
-#         x: B, H*W, C
-#         """
-#         H, W = self.input_resolution
-#         pcaa = PCAa(n_components=384)
-#         x = pcaa.fit_trans(x)
-#
-#         B, L, C = x.shape
-#         assert L == H * W, "input feature has wrong size"
-#         x = x.view(B, H, W, C)  # 7, 7, 7, 1536
-#         x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C // 4)  # 7, 14, 14, 384
-#         x = x.view(B, -1, C // 4)
-#         x = self.norm(x)
-#
-#         return x
-
-
-class FinalPatchExpand_X4(nn.Module):
-    def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm):
-        super().__init__()
-        self.input_resolution = input_resolution
-        self.dim = dim
-        self.dim_scale = dim_scale
-        self.expand = nn.Linear(dim, 16 * dim, bias=False)
-        self.output_dim = dim
-        self.norm = norm_layer(self.output_dim)
-
-    def forward(self, x):
-        """
-        x: B, H*W, C
-        """
-        H, W = self.input_resolution
-        x = self.expand(x)
-        B, L, C = x.shape
-        assert L == H * W, "input feature has wrong size"
-
-        x = x.view(B, H, W, C)
-        x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale,
-                      c=C // (self.dim_scale ** 2))
-        x = x.view(B, -1, self.output_dim)
-        x = self.norm(x)
-
-        return x
\ No newline at end of file
diff --git a/SSLGlacier/models/hook/patch_merging.py b/SSLGlacier/models/hook/patch_merging.py
deleted file mode 100644
index 0567504b51f4936d8a7152fcfc5d37547e2f7ff9..0000000000000000000000000000000000000000
--- a/SSLGlacier/models/hook/patch_merging.py
+++ /dev/null
@@ -1,51 +0,0 @@
-import torch
-import torch.nn as nn
-
-
-class PatchMerging(nn.Module):
-    r""" Patch Merging Layer.
-
-    Args:
-        input_resolution (tuple[int]): Resolution of input feature.
-        dim (int): Number of input channels.
-        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
-    """
-
-    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
-        super().__init__()
-        self.input_resolution = input_resolution
-        self.dim = dim
-        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
-        self.norm = norm_layer(4 * dim)
-
-    def forward(self, x):
-        """
-        x: B, H*W, C
-        """
-        H, W = self.input_resolution
-        B, L, C = x.shape
-        assert L == H * W, "input feature has wrong size"
-        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
-
-        x = x.view(B, H, W, C)
-        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
-        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
-        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
-        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
-        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
-        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
-
-        x = self.norm(x)
-        x = self.reduction(x)
-
-        return x
-
-    def extra_repr(self) -> str:
-        return f"input_resolution={self.input_resolution}, dim={self.dim}"
-
-    def flops(self):
-        H, W = self.input_resolution
-        flops = H * W * self.dim
-        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
-        return flops
-
diff --git a/SSLGlacier/models/hook/patch_processing.py b/SSLGlacier/models/hook/patch_processing.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7e1f3309c2e8e8edd64b70986c1e18b5886bebc
--- /dev/null
+++ b/SSLGlacier/models/hook/patch_processing.py
@@ -0,0 +1,276 @@
+import torch
+import torch.nn as nn
+from einops import rearrange
+from timm.models.layers import to_2tuple
+#-----------------------------------------------
+############## Patch Merging Class ###########
+#-----------------------------------------------
+class PatchMerging(nn.Module):
+    r""" Patch Merging Layer.
+
+    Args:
+        input_resolution (tuple[int]): Resolution of input feature.
+        dim (int): Number of input channels.
+        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
+    """
+
+    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.input_resolution = input_resolution
+        self.dim = dim
+        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+        self.norm = norm_layer(4 * dim)
+
+    def forward(self, x):
+        """
+        x: B, H*W, C
+        """
+        H, W = self.input_resolution
+        B, L, C = x.shape
+        assert L == H * W, "input feature has wrong size"
+        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
+
+        x = x.view(B, H, W, C)
+        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
+        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
+        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
+        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
+        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
+        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
+
+        x = self.norm(x)
+        x = self.reduction(x)
+
+        return x
+
+    def extra_repr(self) -> str:
+        return f"input_resolution={self.input_resolution}, dim={self.dim}"
+
+    def flops(self):
+        H, W = self.input_resolution
+        flops = H * W * self.dim
+        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
+        return flops
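+# Shape sketch (illustrative only): merging concatenates each 2x2 neighbourhood and projects
+# 4C -> 2C, halving the spatial resolution, e.g.
+#   merge = PatchMerging(input_resolution=(56, 56), dim=96)
+#   out = merge(torch.randn(2, 56 * 56, 96))   # -> (2, 28 * 28, 192)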
+
+#-----------------------------------------------
+######### End of Patch Merging Class ###########
+#-----------------------------------------------
+
+################################################
+################################################
+
+#-----------------------------------------------
+############## Patch Embedding Class ###########
+#-----------------------------------------------
+class PatchEmbed(nn.Module):
+    r""" Image to Patch Embedding
+
+    Args:
+        img_size (int): Image size.  Default: 224.
+        patch_size (int): Patch token size. Default: 4.
+        in_chans (int): Number of input image channels. Default: 3.
+        embed_dim (int): Number of linear projection output channels. Default: 96.
+        norm_layer (nn.Module, optional): Normalization layer. Default: None
+    """
+
+    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+        self.img_size = img_size  # 224
+        self.patch_size = patch_size  # 4
+        self.patches_resolution = patches_resolution  # 56
+        self.num_patches = patches_resolution[0] * patches_resolution[1]  # 3136
+
+        self.in_chans = in_chans
+        self.embed_dim = embed_dim
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)  # 22, 96, 56, 56
+
+        if norm_layer is not None:
+            self.norm = norm_layer(embed_dim)
+        else:
+            self.norm = None
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        # FIXME look at relaxing size constraints
+        assert H == self.img_size[0] and W == self.img_size[1], \
+            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C (22, 3136, 96)
+        if self.norm is not None:
+            x = self.norm(x)
+        return x
+
+    def flops(self):
+        Ho, Wo = self.patches_resolution
+        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+        if self.norm is not None:
+            flops += Ho * Wo * self.embed_dim
+        return flops
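+# Shape sketch (illustrative only): with the defaults (224x224 input, patch_size 4,
+# embed_dim 96) the projection yields a 56x56 token grid, e.g.
+#   embed = PatchEmbed(img_size=224, patch_size=4, in_chans=3, embed_dim=96)
+#   tokens = embed(torch.randn(2, 3, 224, 224))   # -> (2, 3136, 96)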
+#-----------------------------------------------
+######### End of Patch Embedding Class ##########
+#-----------------------------------------------
+
+################################################
+################################################
+
+#-----------------------------------------------
+############# Patch Expand Classes #############
+#-----------------------------------------------
+class PatchExpand(nn.Module):
+    """
+    from:https://link.springer.com/chapter/10.1007/978-3-031-25066-8_9
+    Take the first patch expanding layer as an example, before up-sampling,
+     a linear layer is applied on the input features (W/32, H/32, 8C) to increase the feature dimension to
+    2 * the original dimension (W/32, H/32, 16C). Then, we use rearrange operation to expand the resolution
+    of the input features to 2*the input resolution and reduce the feature dimension to
+     quarter of the input dimension ((W/32, H/32, 16C) --> (W/16, H/16, 4C)).
+     We will discuss the impact of using patch expanding layer
+      to perform up-sampling inSect. 4.5 in https://link.springer.com/chapter/10.1007/978-3-031-25066-8_9.
+    """
+    def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.input_resolution = input_resolution
+        self.dim = dim  # 96
+        self.expand = nn.Linear(dim, 2 * dim, bias=False) if dim_scale == 2 else nn.Identity()
+        self.norm = norm_layer(dim // dim_scale)
+
+    def forward(self, x):
+        """
+        x: B, H*W, C
+        """
+        H, W = self.input_resolution
+        x = self.expand(x)  # 7, 49, 768; 7, 49, 1536
+        B, L, C = x.shape
+        assert L == H * W, "input feature has wrong size"
+        x = x.view(B, H, W, C)  # 7, 7, 7, 1536
+        x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C // 4)  # 7, 14, 14, 384
+        x = x.view(B, -1, C // 4)
+        x = self.norm(x)
+
+        return x
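+# Shape sketch (illustrative only): PatchExpand doubles the spatial resolution and halves the
+# channel dimension, e.g.
+#   expand = PatchExpand(input_resolution=(7, 7), dim=768)
+#   out = expand(torch.randn(2, 7 * 7, 768))   # -> (2, 14 * 14, 384)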
+
+# class PatchExpand4(nn.Module):
+#     def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
+#         super().__init__()
+#         self.input_resolution = input_resolution
+#         self.dim = dim  # 96
+#         self.expand = nn.Linear(dim, 384, bias=False) if dim_scale == 2 else nn.Identity()
+#         self.norm = norm_layer(96)
+#
+#     def forward(self, x):
+#         """
+#         x: B, H*W, C
+#         """
+#         H, W = self.input_resolution
+#         x = self.expand(x)  # 7, 49, 768; 7, 49, 1536
+#         B, L, C = x.shape
+#         assert L == H * W, "input feature has wrong size"
+#         x = x.view(B, H, W, C)  # 7, 7, 7, 1536
+#         x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C // 4)  # 7, 14, 14, 384
+#         x = x.view(B, -1, C // 4)
+#         x = self.norm(x)
+#
+#         return x
+class PatchExpandC(nn.Module):
+    def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.input_resolution = input_resolution
+        self.dim = dim  # 96
+        self.expand = nn.Linear(dim, dim, bias=False) if dim_scale == 2 else nn.Identity()
+        self.norm = norm_layer(192)
+
+    def forward(self, x):
+        """
+        x: B, H*W, C
+        """
+        H, W = self.input_resolution
+        x = self.expand(x)  # 7, 49, 768; 7, 49, 1536
+        B, L, C = x.shape
+        assert L == H * W, "input feature has wrong size"
+        x = x.view(B, H, W, C)  # 7, 7, 7, 1536
+        x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C // 4)  # 7, 14, 14, 384
+        x = x.view(B, -1, C // 4)
+        x = self.norm(x)
+
+        return x
+
+# class PCAa(object):
+#     def __init__(self, n_components=2):
+#         self.n_components = n_components
+#
+#     def fit_trans(self, X):
+#         n = X.shape[1]
+#         self.mean = torch.mean(X, axis=1).unsqueeze(dim=1)
+#
+#         X = X - self.mean
+#         covariance_matrix = 1 / n * torch.matmul(X.permute(0, 2, 1), X)
+#
+#         eigenvalues, eigenvectors = torch.linalg.eig(covariance_matrix)
+#
+#         eigenvalues = torch.real(eigenvalues)
+#         eigenvectors = torch.real(eigenvectors)
+#
+#         idx = torch.argsort(-eigenvalues)
+#
+#         eigenvectors = torch.gather(eigenvectors, 2, idx.unsqueeze(dim=1).expand(eigenvectors.shape))
+#         self.proj_mat = eigenvectors[:, :, 0:self.n_components]
+#
+#         Out = X.matmul(self.proj_mat)
+#
+#         return Out
+#
+#
+# class PatchExpandPCA(nn.Module):
+#     def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
+#         super().__init__()
+#         self.input_resolution = input_resolution
+#         self.dim = dim  # 96
+#         # self.expand = nn.Linear(dim, 384, bias=False) if dim_scale == 2 else nn.Identity()
+#         self.norm = norm_layer(96)
+#
+#     def forward(self, x):
+#         """
+#         x: B, H*W, C
+#         """
+#         H, W = self.input_resolution
+#         pcaa = PCAa(n_components=384)
+#         x = pcaa.fit_trans(x)
+#
+#         B, L, C = x.shape
+#         assert L == H * W, "input feature has wrong size"
+#         x = x.view(B, H, W, C)  # 7, 7, 7, 1536
+#         x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C // 4)  # 7, 14, 14, 384
+#         x = x.view(B, -1, C // 4)
+#         x = self.norm(x)
+#
+#         return x
+
+class FinalPatchExpand_X4(nn.Module):
+    def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.input_resolution = input_resolution
+        self.dim = dim
+        self.dim_scale = dim_scale
+        self.expand = nn.Linear(dim, 16 * dim, bias=False)
+        self.output_dim = dim
+        self.norm = norm_layer(self.output_dim)
+
+    def forward(self, x):
+        """
+        x: B, H*W, C
+        """
+        H, W = self.input_resolution
+        x = self.expand(x)
+        B, L, C = x.shape
+        assert L == H * W, "input feature has wrong size"
+
+        x = x.view(B, H, W, C)
+        x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale,
+                      c=C // (self.dim_scale ** 2))
+        x = x.view(B, -1, self.output_dim)
+        x = self.norm(x)
+
+        return x
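+
+
+# Minimal shape sanity check for the expand blocks above (illustrative sketch, not used by
+# the training code). Note that PatchExpandC hard-codes norm_layer(192), so it only fits
+# dim = 768 (192 == 768 // 4); FinalPatchExpand_X4 keeps the channel count while upsampling 4x.
+if __name__ == "__main__":
+    import torch
+
+    _x = torch.randn(2, 7 * 7, 768)                               # B, H*W, C
+    _up2 = PatchExpandC(input_resolution=(7, 7), dim=768)
+    assert _up2(_x).shape == (2, 14 * 14, 192)                    # 2x spatial, C // 4 channels
+
+    _y = torch.randn(2, 56 * 56, 96)
+    _up4 = FinalPatchExpand_X4(input_resolution=(56, 56), dim=96)
+    assert _up4(_y).shape == (2, 224 * 224, 96)                   # 4x spatial, channels kept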
\ No newline at end of file
diff --git a/SSLGlacier/models/hook/swin_transformer_block.py b/SSLGlacier/models/hook/swin_transformer_block.py
deleted file mode 100644
index dd8938ca0fd03270fff060485a4ff92cfc3ab775..0000000000000000000000000000000000000000
--- a/SSLGlacier/models/hook/swin_transformer_block.py
+++ /dev/null
@@ -1,136 +0,0 @@
-import torch
-import torch.nn as nn
-from timm.models.layers import DropPath, to_2tuple
-from models.window_attention import WindowAttention
-from models.mlp import Mlp
-from models.window_utils import window_partition, window_reverse
-
-
-class SwinTransformerBlock(nn.Module):
-    r""" Swin Transformer Block.
-
-    Args:
-        dim (int): Number of input channels.
-        input_resolution (tuple[int]): Input resulotion.
-        num_heads (int): Number of attention heads.
-        window_size (int): Window size.
-        shift_size (int): Shift size for SW-MSA.
-        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
-        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
-        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
-        drop (float, optional): Dropout rate. Default: 0.0
-        attn_drop (float, optional): Attention dropout rate. Default: 0.0
-        drop_path (float, optional): Stochastic depth rate. Default: 0.0
-        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
-        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
-    """
-
-    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
-                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
-                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
-        super().__init__()
-        self.dim = dim  # 96
-        self.input_resolution = input_resolution
-        self.num_heads = num_heads
-        self.window_size = window_size
-        self.shift_size = shift_size  # 0 or 7
-        self.mlp_ratio = mlp_ratio
-        if min(self.input_resolution) <= self.window_size:
-            # if window size is larger than input resolution, we don't partition windows
-            self.shift_size = 0
-            self.window_size = min(self.input_resolution)
-        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
-
-        self.norm1 = norm_layer(dim)
-        self.attn = WindowAttention(
-            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
-            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
-
-        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
-        self.norm2 = norm_layer(dim)
-        mlp_hidden_dim = int(dim * mlp_ratio)
-        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
-
-        if self.shift_size > 0:
-            # calculate attention mask for SW-MSA
-            H, W = self.input_resolution  # 56, 56
-            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1; 1, 56, 56, 1
-            h_slices = (slice(0, -self.window_size),  # 0, -14
-                        slice(-self.window_size, -self.shift_size),  # -14, -7
-                        slice(-self.shift_size, None))  # -7, None
-
-            w_slices = (slice(0, -self.window_size),  # 0, -14
-                        slice(-self.window_size, -self.shift_size),  # -14, -7
-                        slice(-self.shift_size, None))  # -7, None
-            cnt = 0
-            for h in h_slices:
-                for w in w_slices:
-                    img_mask[:, h, w, :] = cnt
-                    cnt += 1
-
-            mask_windows = window_partition(img_mask,
-                                            self.window_size)  # nW, window_size, window_size, 1; 64, 7, 7, 1 # 16, 14, 14, 1
-            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)  # 64, 49 # 16, 196
-            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # 64, 1, 49 - 64, 49, 1
-            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(
-                0.0))  # 64, 49, 49 # 16, 196, 196
-        else:
-            attn_mask = None
-
-        self.register_buffer("attn_mask", attn_mask)
-
-    def forward(self, x):
-        H, W = self.input_resolution
-        B, L, C = x.shape
-        assert L == H * W, "input feature has wrong size"
-
-        shortcut = x
-        x = self.norm1(x)
-        x = x.view(B, H, W, C)
-        # cyclic shift
-        if self.shift_size > 0:
-            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
-        else:
-            shifted_x = x
-
-        # partition windows
-        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C; 448, 7, 7, 96
-        x_windows = x_windows.view(-1, self.window_size * self.window_size,
-                                   C)  # nW*B, window_size*window_size, C; 448, 49, 96 # 7, 4, 4, 196, 96
-
-        # W-MSA/SW-MSA
-        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C; 448, 49, 96
-
-        # merge windows
-        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)  # 448, 7, 7, 96
-        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C; 7, 56, 56, 96
-        # reverse cyclic shift
-        if self.shift_size > 0:
-            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
-        else:
-            x = shifted_x
-
-        x = x.view(B, H * W, C)  # 7, 56*56, 96
-
-        # FFN
-        x = shortcut + self.drop_path(x)
-        x = x + self.drop_path(self.mlp(self.norm2(x)))
-        return x
-
-    def extra_repr(self) -> str:
-        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
-               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
-
-    def flops(self):
-        flops = 0
-        H, W = self.input_resolution
-        # norm1
-        flops += self.dim * H * W
-        # W-MSA/SW-MSA
-        nW = H * W / self.window_size / self.window_size
-        flops += nW * self.attn.flops(self.window_size * self.window_size)
-        # mlp
-        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
-        # norm2
-        flops += self.dim * H * W
-        return flops
diff --git a/SSLGlacier/models/hook/swin_transformer_block_cross.py b/SSLGlacier/models/hook/swin_transformer_block_cross.py
deleted file mode 100644
index 8dbfb6719a5718f3e04bbc19a728afc98fbe6219..0000000000000000000000000000000000000000
--- a/SSLGlacier/models/hook/swin_transformer_block_cross.py
+++ /dev/null
@@ -1,151 +0,0 @@
-import torch
-import torch.nn as nn
-from timm.models.layers import DropPath, to_2tuple
-from models.window_cross_attention import WindowAttention_Cross
-from models.mlp import Mlp
-from models.window_utils import window_partition, window_reverse
-
-
-class SwinTransformerBlock_Cross(nn.Module):
-    r""" Swin Transformer Block.
-
-    Args:
-        dim (int): Number of input channels.
-        input_resolution (tuple[int]): Input resulotion.
-        num_heads (int): Number of attention heads.
-        window_size (int): Window size.
-        shift_size (int): Shift size for SW-MSA.
-        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
-        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
-        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
-        drop (float, optional): Dropout rate. Default: 0.0
-        attn_drop (float, optional): Attention dropout rate. Default: 0.0
-        drop_path (float, optional): Stochastic depth rate. Default: 0.0
-        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
-        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
-    """
-
-    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
-                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
-                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
-        super().__init__()
-        self.dim = dim  # 96
-        self.input_resolution = input_resolution
-        self.num_heads = num_heads
-        self.window_size = window_size
-        self.shift_size = shift_size  # 0 or 7
-        self.mlp_ratio = mlp_ratio
-        if min(self.input_resolution) <= self.window_size:
-            # if window size is larger than input resolution, we don't partition windows
-            self.shift_size = 0
-            self.window_size = min(self.input_resolution)
-        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
-
-        self.norm1 = norm_layer(dim)
-        self.attn = WindowAttention_Cross(
-            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
-            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
-
-        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
-        self.norm2 = norm_layer(dim)
-        mlp_hidden_dim = int(dim * mlp_ratio)
-        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
-
-        if self.shift_size > 0:
-            # calculate attention mask for SW-MSA
-            H, W = self.input_resolution  # 56, 56
-            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1; 1, 56, 56, 1
-            h_slices = (slice(0, -self.window_size),  # 0, -14
-                        slice(-self.window_size, -self.shift_size),  # -14, -7
-                        slice(-self.shift_size, None))  # -7, None
-
-            w_slices = (slice(0, -self.window_size),  # 0, -14
-                        slice(-self.window_size, -self.shift_size),  # -14, -7
-                        slice(-self.shift_size, None))  # -7, None
-            cnt = 0
-            for h in h_slices:
-                for w in w_slices:
-                    img_mask[:, h, w, :] = cnt
-                    cnt += 1
-
-            mask_windows = window_partition(img_mask,
-                                            self.window_size)  # nW, window_size, window_size, 1; 64, 7, 7, 1 # 16, 14, 14, 1
-            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)  # 64, 49 # 16, 196
-            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # 64, 1, 49 - 64, 49, 1
-            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(
-                0.0))  # 64, 49, 49 # 16, 196, 196
-        else:
-            attn_mask = None
-
-        self.register_buffer("attn_mask", attn_mask)
-
-    def forward(self, target, context):
-        H, W = self.input_resolution
-        B, L, C = target.shape
-        assert L == H * W, "input feature has wrong size"
-
-        shortcut = target
-        target = self.norm1(target)
-        target = target.view(B, H, W, C)
-
-        context = self.norm1(context)
-        context = context.view(B, H, W, C)
-
-        # cyclic shift
-        if self.shift_size > 0:
-            shifted_target = torch.roll(target, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
-            shifted_context = torch.roll(context, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
-        else:
-            shifted_target = target
-            shifted_context = context
-
-        # partition windows
-        target_windows = window_partition(shifted_target,
-                                          self.window_size)  # nW*B, window_size, window_size, C; 448, 7, 7, 96
-        context_windows = window_partition(shifted_context,
-                                           self.window_size)  # nW*B, window_size, window_size, C; 448, 7, 7, 96
-        # print(x_windows.shape, 'ss')
-        target_windows = target_windows.view(-1, self.window_size * self.window_size,
-                                             C)  # nW*B, window_size*window_size, C; 448, 49, 96 # 7, 4, 4, 196, 96
-
-        context_windows = context_windows.view(-1, self.window_size * self.window_size,
-                                               C)  # nW*B, window_size*window_size, C; 448, 49, 96 # 7, 4, 4, 196, 96
-
-        # W-MSA/SW-MSA
-        attn_windows = self.attn(target_windows, context_windows,
-                                 mask=self.attn_mask)  # nW*B, window_size*window_size, C; 448, 49, 96
-
-        # merge windows
-        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)  # 448, 7, 7, 96
-        shifted_target = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C; 7, 56, 56, 96
-        # reverse cyclic shift
-        if self.shift_size > 0:
-            target = torch.roll(shifted_target, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
-        else:
-            target = shifted_target
-
-        target = target.view(B, H * W, C)  # 7, 56*56, 96
-
-        # FFN
-        target = shortcut + self.drop_path(target)
-        target = target + self.drop_path(self.mlp(self.norm2(target)))
-
-        return target
-
-    def extra_repr(self) -> str:
-        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
-               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
-
-    def flops(self):
-        flops = 0
-        H, W = self.input_resolution
-        # norm1
-        flops += self.dim * H * W
-        # W-MSA/SW-MSA
-        nW = H * W / self.window_size / self.window_size
-        flops += nW * self.attn.flops(self.window_size * self.window_size)
-        # mlp
-        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
-        # norm2
-        flops += self.dim * H * W
-        return flops
diff --git a/SSLGlacier/models/hook/swin_transformer_blocks.py b/SSLGlacier/models/hook/swin_transformer_blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6745c8a505c38f52232948aaa4dcea6edb29c71
--- /dev/null
+++ b/SSLGlacier/models/hook/swin_transformer_blocks.py
@@ -0,0 +1,296 @@
+import torch
+import torch.nn as nn
+from timm.models.layers import DropPath, to_2tuple
+from SSLGlacier.models.hook.window_attentions import WindowAttention, WindowAttention_Cross, window_partition, window_reverse
+
+class Mlp(nn.Module):
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+class SwinTransformerBlock_Cross(nn.Module):
+    r""" Swin Transformer Block.
+
+    Args:
+        dim (int): Number of input channels.
+        input_resolution (tuple[int]): Input resolution.
+        num_heads (int): Number of attention heads.
+        window_size (int): Window size.
+        shift_size (int): Shift size for SW-MSA.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float, optional): Stochastic depth rate. Default: 0.0
+        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
+    """
+
+    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
+                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
+                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.dim = dim  # 96
+        self.input_resolution = input_resolution
+        self.num_heads = num_heads
+        self.window_size = window_size
+        self.shift_size = shift_size  # 0 or 7
+        self.mlp_ratio = mlp_ratio
+        if min(self.input_resolution) <= self.window_size:
+            # if window size is larger than input resolution, we don't partition windows
+            self.shift_size = 0
+            self.window_size = min(self.input_resolution)
+        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
+
+        self.norm1 = norm_layer(dim)
+        self.attn = WindowAttention_Cross(
+            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
+            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+        if self.shift_size > 0:
+            # calculate attention mask for SW-MSA
+            H, W = self.input_resolution  # 56, 56
+            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1; 1, 56, 56, 1
+            h_slices = (slice(0, -self.window_size),  # 0, -14
+                        slice(-self.window_size, -self.shift_size),  # -14, -7
+                        slice(-self.shift_size, None))  # -7, None
+
+            w_slices = (slice(0, -self.window_size),  # 0, -14
+                        slice(-self.window_size, -self.shift_size),  # -14, -7
+                        slice(-self.shift_size, None))  # -7, None
+            cnt = 0
+            for h in h_slices:
+                for w in w_slices:
+                    img_mask[:, h, w, :] = cnt
+                    cnt += 1
+
+            mask_windows = window_partition(img_mask,
+                                            self.window_size)  # nW, window_size, window_size, 1; 64, 7, 7, 1 # 16, 14, 14, 1
+            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)  # 64, 49 # 16, 196
+            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # 64, 1, 49 - 64, 49, 1
+            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(
+                0.0))  # 64, 49, 49 # 16, 196, 196
+        else:
+            attn_mask = None
+
+        self.register_buffer("attn_mask", attn_mask)
+
+    def forward(self, target, context):
+        H, W = self.input_resolution
+        B, L, C = target.shape
+        assert L == H * W, "input feature has wrong size"
+
+        shortcut = target
+        target = self.norm1(target)
+        target = target.view(B, H, W, C)
+
+        context = self.norm1(context)
+        context = context.view(B, H, W, C)
+
+        # cyclic shift
+        if self.shift_size > 0:
+            shifted_target = torch.roll(target, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+            shifted_context = torch.roll(context, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+        else:
+            shifted_target = target
+            shifted_context = context
+
+        # partition windows
+        target_windows = window_partition(shifted_target,
+                                          self.window_size)  # nW*B, window_size, window_size, C; 448, 7, 7, 96
+        context_windows = window_partition(shifted_context,
+                                           self.window_size)  # nW*B, window_size, window_size, C; 448, 7, 7, 96
+        # print(x_windows.shape, 'ss')
+        target_windows = target_windows.view(-1, self.window_size * self.window_size,
+                                             C)  # nW*B, window_size*window_size, C; 448, 49, 96 # 7, 4, 4, 196, 96
+
+        context_windows = context_windows.view(-1, self.window_size * self.window_size,
+                                               C)  # nW*B, window_size*window_size, C; 448, 49, 96 # 7, 4, 4, 196, 96
+
+        # W-MSA/SW-MSA
+        attn_windows = self.attn(target_windows, context_windows,
+                                 mask=self.attn_mask)  # nW*B, window_size*window_size, C; 448, 49, 96
+
+        # merge windows
+        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)  # 448, 7, 7, 96
+        shifted_target = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C; 7, 56, 56, 96
+        # reverse cyclic shift
+        if self.shift_size > 0:
+            target = torch.roll(shifted_target, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+        else:
+            target = shifted_target
+
+        target = target.view(B, H * W, C)  # 7, 56*56, 96
+
+        # FFN
+        target = shortcut + self.drop_path(target)
+        target = target + self.drop_path(self.mlp(self.norm2(target)))
+
+        return target
+
+    def extra_repr(self) -> str:
+        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
+               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
+
+    def flops(self):
+        flops = 0
+        H, W = self.input_resolution
+        # norm1
+        flops += self.dim * H * W
+        # W-MSA/SW-MSA
+        nW = H * W / self.window_size / self.window_size
+        flops += nW * self.attn.flops(self.window_size * self.window_size)
+        # mlp
+        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
+        # norm2
+        flops += self.dim * H * W
+        return flops
+
+
+class SwinTransformerBlock(nn.Module):
+    r""" Swin Transformer Block.
+
+    Args:
+        dim (int): Number of input channels.
+        input_resolution (tuple[int]): Input resolution.
+        num_heads (int): Number of attention heads.
+        window_size (int): Window size.
+        shift_size (int): Shift size for SW-MSA.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float, optional): Stochastic depth rate. Default: 0.0
+        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
+    """
+
+    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
+                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
+                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.dim = dim  # 96
+        self.input_resolution = input_resolution
+        self.num_heads = num_heads
+        self.window_size = window_size
+        self.shift_size = shift_size  # 0 or 7
+        self.mlp_ratio = mlp_ratio
+        if min(self.input_resolution) <= self.window_size:
+            # if window size is larger than input resolution, we don't partition windows
+            self.shift_size = 0
+            self.window_size = min(self.input_resolution)
+        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
+
+        self.norm1 = norm_layer(dim)
+        self.attn = WindowAttention(
+            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
+            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+        if self.shift_size > 0:
+            # calculate attention mask for SW-MSA
+            H, W = self.input_resolution  # 56, 56
+            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1; 1, 56, 56, 1
+            h_slices = (slice(0, -self.window_size),  # 0, -14
+                        slice(-self.window_size, -self.shift_size),  # -14, -7
+                        slice(-self.shift_size, None))  # -7, None
+
+            w_slices = (slice(0, -self.window_size),  # 0, -14
+                        slice(-self.window_size, -self.shift_size),  # -14, -7
+                        slice(-self.shift_size, None))  # -7, None
+            cnt = 0
+            for h in h_slices:
+                for w in w_slices:
+                    img_mask[:, h, w, :] = cnt
+                    cnt += 1
+
+            mask_windows = window_partition(img_mask,
+                                            self.window_size)  # nW, window_size, window_size, 1; 64, 7, 7, 1 # 16, 14, 14, 1
+            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)  # 64, 49 # 16, 196
+            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # 64, 1, 49 - 64, 49, 1
+            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(
+                0.0))  # 64, 49, 49 # 16, 196, 196
+        else:
+            attn_mask = None
+
+        self.register_buffer("attn_mask", attn_mask)
+
+    def forward(self, x):
+        H, W = self.input_resolution
+        B, L, C = x.shape
+        assert L == H * W, "input feature has wrong size"
+
+        shortcut = x
+        x = self.norm1(x)
+        x = x.view(B, H, W, C)
+        # cyclic shift
+        if self.shift_size > 0:
+            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+        else:
+            shifted_x = x
+
+        # partition windows
+        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C; 448, 7, 7, 96
+        x_windows = x_windows.view(-1, self.window_size * self.window_size,
+                                   C)  # nW*B, window_size*window_size, C; 448, 49, 96 # 7, 4, 4, 196, 96
+
+        # W-MSA/SW-MSA
+        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C; 448, 49, 96
+
+        # merge windows
+        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)  # 448, 7, 7, 96
+        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C; 7, 56, 56, 96
+        # reverse cyclic shift
+        if self.shift_size > 0:
+            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+        else:
+            x = shifted_x
+
+        x = x.view(B, H * W, C)  # 7, 56*56, 96
+
+        # FFN
+        x = shortcut + self.drop_path(x)
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+    def extra_repr(self) -> str:
+        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
+               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
+
+    def flops(self):
+        flops = 0
+        H, W = self.input_resolution
+        # norm1
+        flops += self.dim * H * W
+        # W-MSA/SW-MSA
+        nW = H * W / self.window_size / self.window_size
+        flops += nW * self.attn.flops(self.window_size * self.window_size)
+        # mlp
+        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
+        # norm2
+        flops += self.dim * H * W
+        return flops
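+
+
+# Illustrative usage sketch (assumption: not wired into the training pipeline). Both block
+# variants take flattened (B, H*W, C) inputs; the cross variant additionally takes a context
+# tensor of the same shape whose windows provide the keys and values.
+if __name__ == "__main__":
+    _self_blk = SwinTransformerBlock(dim=96, input_resolution=(56, 56), num_heads=3,
+                                     window_size=7, shift_size=3)
+    _cross_blk = SwinTransformerBlock_Cross(dim=96, input_resolution=(56, 56), num_heads=3,
+                                            window_size=7, shift_size=0)
+    _target = torch.randn(2, 56 * 56, 96)
+    _context = torch.randn(2, 56 * 56, 96)
+    assert _self_blk(_target).shape == (2, 56 * 56, 96)
+    assert _cross_blk(_target, _context).shape == (2, 56 * 56, 96)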
diff --git a/SSLGlacier/models/hook/window_attention.py b/SSLGlacier/models/hook/window_attention.py
deleted file mode 100644
index 39e9be5a9db16d19fa596a7b60085a2d44676d6d..0000000000000000000000000000000000000000
--- a/SSLGlacier/models/hook/window_attention.py
+++ /dev/null
@@ -1,104 +0,0 @@
-import torch
-import torch.nn as nn
-from timm.models.layers import trunc_normal_
-
-
-class WindowAttention(nn.Module):
-    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
-    It supports both of shifted and non-shifted window.
-
-    Args:
-        dim (int): Number of input channels.
-        window_size (tuple[int]): The height and width of the window.
-        num_heads (int): Number of attention heads.
-        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
-        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
-        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
-        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
-    """
-
-    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
-
-        super().__init__()
-        self.dim = dim
-        self.window_size = window_size  # Wh, Ww
-        self.num_heads = num_heads  # 3, 6, 12, 24
-        head_dim = dim // num_heads  # 96 / 3 = 32; 96, 192, 384, 768
-
-        self.scale = qk_scale or head_dim ** -0.5  # 1/np.sqrt(32)
-
-        self.relative_position_bias_table = nn.Parameter(
-            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH; 169, 3
-
-        coords_h = torch.arange(self.window_size[0])
-        coords_w = torch.arange(self.window_size[1])
-
-        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, 7, 7
-
-        coords_flatten = torch.flatten(coords, 1)  # 2, 49
-        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, 49, 49
-        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # 49, 49, 2
-        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
-        relative_coords[:, :, 1] += self.window_size[1] - 1
-        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1  # 14-1
-        relative_position_index = relative_coords.sum(-1)  # 49, 49
-        self.register_buffer("relative_position_index", relative_position_index)
-
-        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)  # 96, 192, 384, 768
-        self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Linear(dim, dim)  # 96, 192, 384, 768
-        self.proj_drop = nn.Dropout(proj_drop)
-
-        trunc_normal_(self.relative_position_bias_table, std=.02)
-        self.softmax = nn.Softmax(dim=-1)
-
-    def forward(self, x, mask=None):
-        """
-        Args:
-            x: input features with shape of (num_windows*B, N, C)
-            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
-        """
-        B_, N, C = x.shape  # -1, 49, 96
-        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1,
-                                                                                         4)  # 448, 49, 3, 3, 32; 3, 448, 3, 49, 32
-        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)  # 448, 3, 49, 32
-
-        q = q * self.scale  # 1/np.sqrt(32)
-        attn = (q @ k.transpose(-2, -1))  # 448, 3, 49, 49
-        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
-            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1],
-            -1)  # Wh*Ww,Wh*Ww,nH; 49, 49, 3
-
-        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww; 3, 49, 49
-        attn = attn + relative_position_bias.unsqueeze(0)  # 488, 3, 49, 49
-
-        if mask is not None:
-            nW = mask.shape[0]
-            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(
-                0)  # 7, 64, 3, 49, 49 + 1, 64, 1, 49, 49
-            attn = attn.view(-1, self.num_heads, N, N)  # 488, 3, 49, 49
-            attn = self.softmax(attn)
-        else:
-            attn = self.softmax(attn)
-
-        attn = self.attn_drop(attn)
-        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)  # 448, 3, 49, 32; 448, 49, 3, 32; 448, 49, 96
-        x = self.proj(x)
-        x = self.proj_drop(x)
-        return x
-
-    def extra_repr(self) -> str:
-        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
-
-    def flops(self, N):
-        # calculate flops for 1 window with token length of N
-        flops = 0
-        # qkv = self.qkv(x)
-        flops += N * self.dim * 3 * self.dim
-        # attn = (q @ k.transpose(-2, -1))
-        flops += self.num_heads * N * (self.dim // self.num_heads) * N
-        #  x = (attn @ v)
-        flops += self.num_heads * N * N * (self.dim // self.num_heads)
-        # x = self.proj(x)
-        flops += N * self.dim * self.dim
-        return flops
diff --git a/SSLGlacier/models/hook/window_attentions.py b/SSLGlacier/models/hook/window_attentions.py
new file mode 100644
index 0000000000000000000000000000000000000000..422547958583f2218f1a7451ea90a600eed4748f
--- /dev/null
+++ b/SSLGlacier/models/hook/window_attentions.py
@@ -0,0 +1,240 @@
+import torch
+import torch.nn as nn
+from timm.models.layers import trunc_normal_
+
+
+class WindowAttention(nn.Module):
+    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
+    It supports both shifted and non-shifted windows.
+
+    Args:
+        dim (int): Number of input channels.
+        window_size (tuple[int]): The height and width of the window.
+        num_heads (int): Number of attention heads.
+        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+    """
+
+    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
+
+        super().__init__()
+        self.dim = dim
+        self.window_size = window_size  # Wh, Ww
+        self.num_heads = num_heads  # 3, 6, 12, 24
+        head_dim = dim // num_heads  # 96 / 3 = 32; 96, 192, 384, 768
+
+        self.scale = qk_scale or head_dim ** -0.5  # 1/np.sqrt(32)
+
+        self.relative_position_bias_table = nn.Parameter(
+            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH; 169, 3
+
+        coords_h = torch.arange(self.window_size[0])
+        coords_w = torch.arange(self.window_size[1])
+
+        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, 7, 7
+
+        coords_flatten = torch.flatten(coords, 1)  # 2, 49
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, 49, 49
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # 49, 49, 2
+        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
+        relative_coords[:, :, 1] += self.window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1  # 14-1
+        relative_position_index = relative_coords.sum(-1)  # 49, 49
+        self.register_buffer("relative_position_index", relative_position_index)
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)  # 96, 192, 384, 768
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)  # 96, 192, 384, 768
+        self.proj_drop = nn.Dropout(proj_drop)
+
+        trunc_normal_(self.relative_position_bias_table, std=.02)
+        self.softmax = nn.Softmax(dim=-1)
+
+    def forward(self, x, mask=None):
+        """
+        Args:
+            x: input features with shape of (num_windows*B, N, C)
+            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+        """
+        B_, N, C = x.shape  # -1, 49, 96
+        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1,
+                                                                                         4)  # 448, 49, 3, 3, 32; 3, 448, 3, 49, 32
+        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)  # 448, 3, 49, 32
+
+        q = q * self.scale  # 1/np.sqrt(32)
+        attn = (q @ k.transpose(-2, -1))  # 448, 3, 49, 49
+        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1],
+            -1)  # Wh*Ww,Wh*Ww,nH; 49, 49, 3
+
+        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww; 3, 49, 49
+        attn = attn + relative_position_bias.unsqueeze(0)  # 488, 3, 49, 49
+
+        if mask is not None:
+            nW = mask.shape[0]
+            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(
+                0)  # 7, 64, 3, 49, 49 + 1, 64, 1, 49, 49
+            attn = attn.view(-1, self.num_heads, N, N)  # 488, 3, 49, 49
+            attn = self.softmax(attn)
+        else:
+            attn = self.softmax(attn)
+
+        attn = self.attn_drop(attn)
+        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)  # 448, 3, 49, 32; 448, 49, 3, 32; 448, 49, 96
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+    def extra_repr(self) -> str:
+        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
+
+    def flops(self, N):
+        # calculate flops for 1 window with token length of N
+        flops = 0
+        # qkv = self.qkv(x)
+        flops += N * self.dim * 3 * self.dim
+        # attn = (q @ k.transpose(-2, -1))
+        flops += self.num_heads * N * (self.dim // self.num_heads) * N
+        #  x = (attn @ v)
+        flops += self.num_heads * N * N * (self.dim // self.num_heads)
+        # x = self.proj(x)
+        flops += N * self.dim * self.dim
+        return flops
+
+class WindowAttention_Cross(nn.Module):
+    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
+    It supports both of shifted and non-shifted window.
+
+    Args:
+        dim (int): Number of input channels.
+        window_size (tuple[int]): The height and width of the window.
+        num_heads (int): Number of attention heads.
+        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+    """
+
+    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
+
+        super().__init__()
+        self.dim = dim
+        self.window_size = window_size  # Wh, Ww
+        self.num_heads = num_heads  # 3, 6, 12, 24
+        head_dim = dim // num_heads  # 96 / 3 = 32; 96, 192, 384, 768
+
+        self.scale = qk_scale or head_dim ** -0.5  # 1/np.sqrt(32)
+
+        self.relative_position_bias_table = nn.Parameter(
+            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH; 169, 3
+
+        coords_h = torch.arange(self.window_size[0])
+        coords_w = torch.arange(self.window_size[1])
+
+        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, 7, 7
+
+        coords_flatten = torch.flatten(coords, 1)  # 2, 49
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, 49, 49
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # 49, 49, 2
+        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
+        relative_coords[:, :, 1] += self.window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1  # 14-1
+        relative_position_index = relative_coords.sum(-1)  # 49, 49
+        self.register_buffer("relative_position_index", relative_position_index)
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)  # 96, 192, 384, 768
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)  # 96, 192, 384, 768
+        self.proj_drop = nn.Dropout(proj_drop)
+
+        trunc_normal_(self.relative_position_bias_table, std=.02)
+        self.softmax = nn.Softmax(dim=-1)
+
+    def forward(self, x, y, mask=None):
+        """
+        Args:
+            x: target features (queries) with shape of (num_windows*B, N, C)
+            y: context features (keys and values) with shape of (num_windows*B, N, C)
+            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+        """
+        B_, N, C = x.shape  # -1, 49, 96
+
+        q = x.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)  # 448, 49, 3, 3, 32; 3, 448, 3, 49, 32
+        k = y.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)  # 448, 49, 3, 3, 32; 3, 448, 3, 49, 32
+        v = y.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)  # 448, 49, 3, 3, 32; 3, 448, 3, 49, 32
+
+        q = q * self.scale  # 1/np.sqrt(32)
+        attn = (q @ k.transpose(-2, -1))  # 448, 3, 49, 49
+        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1],
+            -1)  # Wh*Ww,Wh*Ww,nH; 49, 49, 3
+
+        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww; 3, 49, 49
+        attn = attn + relative_position_bias.unsqueeze(0)  # 488, 3, 49, 49
+
+        if mask is not None:
+            nW = mask.shape[0]
+            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(
+                0)  # 7, 64, 3, 49, 49 + 1, 64, 1, 49, 49
+            attn = attn.view(-1, self.num_heads, N, N)  # 488, 3, 49, 49
+            attn = self.softmax(attn)
+        else:
+            attn = self.softmax(attn)
+
+        attn = self.attn_drop(attn)
+        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)  # 448, 3, 49, 32; 448, 49, 3, 32; 448, 49, 96
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+    def extra_repr(self) -> str:
+        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
+
+    def flops(self, N):
+        # calculate flops for 1 window with token length of N
+        flops = 0
+        # qkv = self.qkv(x)
+        flops += N * self.dim * 3 * self.dim
+        # attn = (q @ k.transpose(-2, -1))
+        flops += self.num_heads * N * (self.dim // self.num_heads) * N
+        #  x = (attn @ v)
+        flops += self.num_heads * N * N * (self.dim // self.num_heads)
+        # x = self.proj(x)
+        flops += N * self.dim * self.dim
+        return flops
+
+def window_partition(x, window_size):
+    """
+    Args:
+        x: (B, H, W, C)
+        window_size (int): window size
+
+    Returns:
+        windows: (num_windows*B, window_size, window_size, C)
+    """
+    B, H, W, C = x.shape  # 7, 56, 56, 96; 1, 56, 56, 1
+    x = x.view(B, H // window_size, window_size, W // window_size, window_size,
+               C)  # 7, 8, 7, 8, 7, 96  # 1, 8, 7, 8, 7, 1 # 1, 4, 14, 4, 14, 1
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size,
+                                                            C)  # 7, 8, 8, 7, 7, 96;  -1, 7, 7, 96
+    # 1, 8, 8, 7, 7, 1 # 1, 4, 4, 14, 14, 1 # 16, 14, 14, 1
+    return windows
+
+
+def window_reverse(windows, window_size, H, W):
+    """
+    Args:
+        windows: (num_windows*B, window_size, window_size, C)
+        window_size (int): Window size
+        H (int): Height of image
+        W (int): Width of image
+
+    Returns:
+        x: (B, H, W, C)
+    """
+    B = int(windows.shape[0] / (H * W / window_size / window_size))  # 448 / (56*56/7/7) = 7
+    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)  # 7, 8, 8, 7, 7, 96
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)  # 7, 8, 7, 8, 96; 7, 56, 56, 96
+    return x
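+
+# Round-trip sanity check (illustrative sketch, not used by the training code):
+# window_partition followed by window_reverse with matching sizes reproduces the input,
+# and both attention variants preserve the (num_windows*B, N, C) shape.
+if __name__ == "__main__":
+    _x = torch.randn(2, 56, 56, 96)                        # B, H, W, C
+    _windows = window_partition(_x, 7)                      # (2*8*8, 7, 7, 96)
+    assert _windows.shape == (128, 7, 7, 96)
+    assert torch.equal(window_reverse(_windows, 7, 56, 56), _x)
+
+    _tokens = _windows.view(-1, 49, 96)                     # nW*B, Wh*Ww, C
+    _self_attn = WindowAttention(dim=96, window_size=(7, 7), num_heads=3)
+    _cross_attn = WindowAttention_Cross(dim=96, window_size=(7, 7), num_heads=3)
+    assert _self_attn(_tokens).shape == (128, 49, 96)
+    assert _cross_attn(_tokens, _tokens).shape == (128, 49, 96)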
+
diff --git a/SSLGlacier/models/hook/window_cross_attention.py b/SSLGlacier/models/hook/window_cross_attention.py
deleted file mode 100644
index 7b5ddfde667f4b5269523bb189fa743bfd3a01dd..0000000000000000000000000000000000000000
--- a/SSLGlacier/models/hook/window_cross_attention.py
+++ /dev/null
@@ -1,106 +0,0 @@
-import torch
-import torch.nn as nn
-from timm.models.layers import trunc_normal_
-
-
-class WindowAttention_Cross(nn.Module):
-    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
-    It supports both of shifted and non-shifted window.
-
-    Args:
-        dim (int): Number of input channels.
-        window_size (tuple[int]): The height and width of the window.
-        num_heads (int): Number of attention heads.
-        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
-        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
-        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
-        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
-    """
-
-    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
-
-        super().__init__()
-        self.dim = dim
-        self.window_size = window_size  # Wh, Ww
-        self.num_heads = num_heads  # 3, 6, 12, 24
-        head_dim = dim // num_heads  # 96 / 3 = 32; 96, 192, 384, 768
-
-        self.scale = qk_scale or head_dim ** -0.5  # 1/np.sqrt(32)
-
-        self.relative_position_bias_table = nn.Parameter(
-            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH; 169, 3
-
-        coords_h = torch.arange(self.window_size[0])
-        coords_w = torch.arange(self.window_size[1])
-
-        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, 7, 7
-
-        coords_flatten = torch.flatten(coords, 1)  # 2, 49
-        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, 49, 49
-        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # 49, 49, 2
-        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
-        relative_coords[:, :, 1] += self.window_size[1] - 1
-        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1  # 14-1
-        relative_position_index = relative_coords.sum(-1)  # 49, 49
-        self.register_buffer("relative_position_index", relative_position_index)
-
-        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)  # 96, 192, 384, 768
-        self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Linear(dim, dim)  # 96, 192, 384, 768
-        self.proj_drop = nn.Dropout(proj_drop)
-
-        trunc_normal_(self.relative_position_bias_table, std=.02)
-        self.softmax = nn.Softmax(dim=-1)
-
-    def forward(self, x, y, mask=None):
-        """
-        Args:
-            x: input features with shape of (num_windows*B, N, C)
-            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
-        """
-        B_, N, C = x.shape  # -1, 49, 96
-
-        q = x.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)  # 448, 49, 3, 3, 32; 3, 448, 3, 49, 32
-        k = y.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)  # 448, 49, 3, 3, 32; 3, 448, 3, 49, 32
-        v = y.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)  # 448, 49, 3, 3, 32; 3, 448, 3, 49, 32
-
-        q = q * self.scale  # 1/np.sqrt(32)
-        attn = (q @ k.transpose(-2, -1))  # 448, 3, 49, 49
-        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
-            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1],
-            -1)  # Wh*Ww,Wh*Ww,nH; 49, 49, 3
-
-        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww; 3, 49, 49
-        attn = attn + relative_position_bias.unsqueeze(0)  # 488, 3, 49, 49
-
-        if mask is not None:
-            nW = mask.shape[0]
-            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(
-                0)  # 7, 64, 3, 49, 49 + 1, 64, 1, 49, 49
-            attn = attn.view(-1, self.num_heads, N, N)  # 488, 3, 49, 49
-            attn = self.softmax(attn)
-        else:
-            attn = self.softmax(attn)
-
-        attn = self.attn_drop(attn)
-        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)  # 448, 3, 49, 32; 448, 49, 3, 32; 448, 49, 96
-        x = self.proj(x)
-        x = self.proj_drop(x)
-        return x
-
-    def extra_repr(self) -> str:
-        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
-
-    def flops(self, N):
-        # calculate flops for 1 window with token length of N
-        flops = 0
-        # qkv = self.qkv(x)
-        flops += N * self.dim * 3 * self.dim
-        # attn = (q @ k.transpose(-2, -1))
-        flops += self.num_heads * N * (self.dim // self.num_heads) * N
-        #  x = (attn @ v)
-        flops += self.num_heads * N * N * (self.dim // self.num_heads)
-        # x = self.proj(x)
-        flops += N * self.dim * self.dim
-        return flops
-
diff --git a/SSLGlacier/models/hook/window_utils.py b/SSLGlacier/models/hook/window_utils.py
deleted file mode 100644
index fe08bade963ede9938ef18fde95624f4b43a98f9..0000000000000000000000000000000000000000
--- a/SSLGlacier/models/hook/window_utils.py
+++ /dev/null
@@ -1,34 +0,0 @@
-def window_partition(x, window_size):
-    """
-    Args:
-        x: (B, H, W, C)
-        window_size (int): window size
-
-    Returns:
-        windows: (num_windows*B, window_size, window_size, C)
-    """
-    B, H, W, C = x.shape  # 7, 56, 56, 96; 1, 56, 56, 1
-    x = x.view(B, H // window_size, window_size, W // window_size, window_size,
-               C)  # 7, 8, 7, 8, 7, 96  # 1, 8, 7, 8, 7, 1 # 1, 4, 14, 4, 14, 1
-    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size,
-                                                            C)  # 7, 8, 8, 7, 7, 96;  -1, 7, 7, 96
-    # 1, 8, 8, 7, 7, 1 # 1, 4, 4, 14, 14, 1 # 16, 14, 14, 1
-    return windows
-
-
-def window_reverse(windows, window_size, H, W):
-    """
-    Args:
-        windows: (num_windows*B, window_size, window_size, C)
-        window_size (int): Window size
-        H (int): Height of image
-        W (int): Width of image
-
-    Returns:
-        x: (B, H, W, C)
-    """
-    B = int(windows.shape[0] / (H * W / window_size / window_size))  # 448 / (56*56/7/7) = 7
-    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)  # 7, 8, 8, 7, 7, 96
-    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)  # 7, 8, 7, 8, 96; 7, 56, 56, 96
-    return x
-
diff --git a/SSLGlacier/processing/agumentations_.py b/SSLGlacier/modules/agumentations_.py
similarity index 100%
rename from SSLGlacier/processing/agumentations_.py
rename to SSLGlacier/modules/agumentations_.py
diff --git a/SSLGlacier/processing/datamodule_.py b/SSLGlacier/modules/datamodule_.py
similarity index 100%
rename from SSLGlacier/processing/datamodule_.py
rename to SSLGlacier/modules/datamodule_.py
diff --git a/SSLGlacier/processing/semi_module.py b/SSLGlacier/modules/semi_module.py
similarity index 100%
rename from SSLGlacier/processing/semi_module.py
rename to SSLGlacier/modules/semi_module.py
diff --git a/SSLGlacier/processing/swav_module.py b/SSLGlacier/modules/swav_module.py
similarity index 92%
rename from SSLGlacier/processing/swav_module.py
rename to SSLGlacier/modules/swav_module.py
index 970484bc3afc2c58beb2bda198771c6a45b796b5..3b57a4ede5265846c6a4aeb4b885d3d31169c206 100644
--- a/SSLGlacier/processing/swav_module.py
+++ b/SSLGlacier/modules/swav_module.py
@@ -14,9 +14,11 @@ from SSLGlacier.models.RESNET import resnet18, resnet50
 from SSLGlacier.models.UNET import UNet #The main network
 from pl_bolts.optimizers.lars import LARS
 from pl_bolts.optimizers.lr_scheduler import linear_warmup_decay
-#from SSLGlacier.models.hook.Swin_Transformer_Wrapper import SwinUnet
-
 
+########## Imports for the hook (Swin-Unet) backbone ##########
+from SSLGlacier.models.hook.Swin_Transformer_Wrapper import SwinUnet
+###### Temporary config file for the hook backbone ######
+from utils.config import get_config
 
 class SwAV(LightningModule):
     def __init__(
@@ -51,7 +53,7 @@ class SwAV(LightningModule):
         weight_decay: float = 1e-6,
         epsilon: float = 0.05,
         just_aug_for_same_assign_views: bool = False,
-        **kwargs
+        swin_hparams = {}
     ) -> None:
         """
         Args:
@@ -148,6 +150,8 @@ class SwAV(LightningModule):
         self.train_iters_per_epoch = self.num_samples // global_batch_size
         self.queue = None
 
+        #### For the hook backbone ####
+        self.swin_hparams = swin_hparams
+
     def setup(self, stage):
         if self.queue_length > 0:
             queue_folder = os.path.join(self.logger.log_dir, self.queue_path)
@@ -176,12 +180,30 @@ class SwAV(LightningModule):
                                 first_conv=self.first_conv,
                                 maxpool1=self.maxpool1)
         elif self.arch == "hook":
-            backbone = resnet50(normalize=True,
+            backbone = SwinUnet(
+                                img_size=224,
+                                num_classes=5,
+                                normalize=True,
                                 hidden_mlp=self.hidden_mlp,
                                 output_dim=self.feat_dim,
                                 num_prototypes=self.num_prototypes,
                                 first_conv=self.first_conv,
-                                maxpool1=self.maxpool1)
+                                maxpool1=self.maxpool1,  # base-network arguments end here
+                                image_size=self.swin_hparams.image_size,  # Swin-specific arguments start here
+                                swin_patch_size = self.swin_hparams.swin_patch_size,
+                                swin_in_chans=self.swin_hparams.swin_in_chans,
+                                swin_embed_dim=self.swin_hparams.swin_embed_dim,
+                                swin_depths=self.swin_hparams.swin_depths,
+                                swin_num_heads=self.swin_hparams.swin_num_heads,
+                                swin_window_size=self.swin_hparams.swin_window_size,
+                                swin_mlp_ratio=self.swin_hparams.swin_mlp_ratio,
+                                swin_QKV_BIAS=self.swin_hparams.swin_QKV_BIAS,
+                                swin_QK_SCALE=self.swin_hparams.swin_QK_SCALE,
+                                drop_rate=self.drop_rate,
+                                drop_path_rate=self.drop_path_rate,
+                                swin_ape=self.swin_hparams.swin_ape,
+                                swin_patch_norm=self.swin_hparams.swin_path_norm
+                                )
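+            # Assumption: `swin_hparams` is a config/namespace object exposing the swin_*
+            # attributes read above; the matching Swin defaults appear to live in
+            # utils/config.py (imported via get_config).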
 
         elif self.arch == "Unet" or "unet":
             backbone = UNet(num_classes=5,
diff --git a/SSLGlacier/processing/transformers_.py b/SSLGlacier/modules/transformers_.py
similarity index 100%
rename from SSLGlacier/processing/transformers_.py
rename to SSLGlacier/modules/transformers_.py
diff --git a/SSLGlacier/utils/config.py b/SSLGlacier/utils/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0b002125d72a899336c30c614afd9a4ff0c4238
--- /dev/null
+++ b/SSLGlacier/utils/config.py
@@ -0,0 +1,236 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------
+
+import os
+import yaml
+from yacs.config import CfgNode as CN
+
+_C = CN()
+
+# Base config files
+_C.BASE = ['']
+
+# -----------------------------------------------------------------------------
+# Data settings
+# -----------------------------------------------------------------------------
+_C.DATA = CN()
+# Batch size for a single GPU, could be overwritten by command line argument
+_C.DATA.BATCH_SIZE = 128
+# Path to dataset, could be overwritten by command line argument
+_C.DATA.DATA_PATH = ''
+# Dataset name
+_C.DATA.DATASET = 'imagenet'
+# Input image size
+_C.DATA.IMG_SIZE = 224
+# _C.DATA.IMG_SIZE = 56
+# Interpolation to resize image (random, bilinear, bicubic)
+_C.DATA.INTERPOLATION = 'bicubic'
+# Use zipped dataset instead of folder dataset
+# could be overwritten by command line argument
+_C.DATA.ZIP_MODE = False
+# Cache Data in Memory, could be overwritten by command line argument
+_C.DATA.CACHE_MODE = 'part'
+# Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
+_C.DATA.PIN_MEMORY = True
+# Number of data loading threads
+_C.DATA.NUM_WORKERS = 8
+
+# -----------------------------------------------------------------------------
+# Model settings
+# -----------------------------------------------------------------------------
+_C.MODEL = CN()
+# Model type
+_C.MODEL.TYPE = 'swin'
+# Model name
+_C.MODEL.NAME = 'swin_tiny_patch4_window7_224'
+# Checkpoint to resume, could be overwritten by command line argument
+# _C.MODEL.PRETRAIN_CKPT = '/home/wf/Swin-Unet-main/output/epoch_150.pth'
+_C.MODEL.PRETRAIN_CKPT = None
+_C.MODEL.RESUME = ''
+# Number of classes, overwritten in data preparation
+_C.MODEL.NUM_CLASSES = 1000
+# Dropout rate
+_C.MODEL.DROP_RATE = 0.0
+# Drop path rate
+_C.MODEL.DROP_PATH_RATE = 0.1
+# Label Smoothing
+_C.MODEL.LABEL_SMOOTHING = 0.1
+
+# Swin Transformer parameters
+_C.MODEL.SWIN = CN()
+_C.MODEL.SWIN.PATCH_SIZE = 4
+# _C.MODEL.SWIN.PATCH_SIZE = 2
+_C.MODEL.SWIN.IN_CHANS = 3
+# _C.MODEL.SWIN.IN_CHANS = 1
+_C.MODEL.SWIN.EMBED_DIM = 96
+_C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
+_C.MODEL.SWIN.DECODER_DEPTHS = [2, 2, 6, 2]
+_C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
+_C.MODEL.SWIN.WINDOW_SIZE = 7
+# _C.MODEL.SWIN.WINDOW_SIZE = 2
+_C.MODEL.SWIN.MLP_RATIO = 4.
+_C.MODEL.SWIN.QKV_BIAS = True
+_C.MODEL.SWIN.QK_SCALE = None
+_C.MODEL.SWIN.APE = False
+_C.MODEL.SWIN.PATCH_NORM = True
+_C.MODEL.SWIN.FINAL_UPSAMPLE = "expand_first"
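+# Note: these SWIN defaults correspond to the swin_* hyperparameters consumed
+# by the "hook" backbone construction in modules/swav_module.py.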
+
+# -----------------------------------------------------------------------------
+# Training settings
+# -----------------------------------------------------------------------------
+_C.TRAIN = CN()
+_C.TRAIN.START_EPOCH = 0
+_C.TRAIN.EPOCHS = 300
+_C.TRAIN.WARMUP_EPOCHS = 20
+_C.TRAIN.WEIGHT_DECAY = 0.05
+_C.TRAIN.BASE_LR = 5e-4
+_C.TRAIN.WARMUP_LR = 5e-7
+_C.TRAIN.MIN_LR = 5e-6
+# Clip gradient norm
+_C.TRAIN.CLIP_GRAD = 5.0
+# Auto resume from latest checkpoint
+_C.TRAIN.AUTO_RESUME = True
+# Gradient accumulation steps
+# could be overwritten by command line argument
+_C.TRAIN.ACCUMULATION_STEPS = 0
+# Whether to use gradient checkpointing to save memory
+# could be overwritten by command line argument
+_C.TRAIN.USE_CHECKPOINT = False
+
+# LR scheduler
+_C.TRAIN.LR_SCHEDULER = CN()
+_C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
+# Epoch interval to decay LR, used in StepLRScheduler
+_C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
+# LR decay rate, used in StepLRScheduler
+_C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
+
+# Optimizer
+_C.TRAIN.OPTIMIZER = CN()
+_C.TRAIN.OPTIMIZER.NAME = 'adamw'       # - NOT USED
+# Optimizer Epsilon
+_C.TRAIN.OPTIMIZER.EPS = 1e-8
+# Optimizer Betas
+_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
+# SGD momentum
+_C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
+
+# -----------------------------------------------------------------------------
+# Augmentation settings - NOT USED
+# -----------------------------------------------------------------------------
+_C.AUG = CN()
+# Color jitter factor
+_C.AUG.COLOR_JITTER = 0.4
+# Use AutoAugment policy. "v0" or "original"
+_C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
+# Random erase prob
+_C.AUG.REPROB = 0.25
+# Random erase mode
+_C.AUG.REMODE = 'pixel'
+# Random erase count
+_C.AUG.RECOUNT = 1
+# Mixup alpha, mixup enabled if > 0
+_C.AUG.MIXUP = 0.8
+# Cutmix alpha, cutmix enabled if > 0
+_C.AUG.CUTMIX = 1.0
+# Cutmix min/max ratio, overrides alpha and enables cutmix if set
+_C.AUG.CUTMIX_MINMAX = None
+# Probability of performing mixup or cutmix when either/both is enabled
+_C.AUG.MIXUP_PROB = 1.0
+# Probability of switching to cutmix when both mixup and cutmix enabled
+_C.AUG.MIXUP_SWITCH_PROB = 0.5
+# How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
+_C.AUG.MIXUP_MODE = 'batch'
+
+# -----------------------------------------------------------------------------
+# Testing settings
+# -----------------------------------------------------------------------------
+_C.TEST = CN()
+# Whether to use center crop when testing
+_C.TEST.CROP = True
+
+# -----------------------------------------------------------------------------
+# Misc
+# -----------------------------------------------------------------------------
+# Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2')
+# overwritten by command line argument
+_C.AMP_OPT_LEVEL = ''
+# Path to output folder, overwritten by command line argument
+_C.OUTPUT = ''
+# Tag of experiment, overwritten by command line argument
+_C.TAG = 'default'
+# Frequency to save checkpoint
+_C.SAVE_FREQ = 1
+# Frequency for logging info
+_C.PRINT_FREQ = 10
+# Fixed random seed
+_C.SEED = 0
+# Perform evaluation only, overwritten by command line argument
+_C.EVAL_MODE = False
+# Test throughput only, overwritten by command line argument
+_C.THROUGHPUT_MODE = False
+# local rank for DistributedDataParallel, given by command line argument
+_C.LOCAL_RANK = 0
+
+
+# NOT USED FROM HERE ON
+
+def _update_config_from_file(config, cfg_file):
+    config.defrost()
+    with open(cfg_file, 'r') as f:
+        yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)
+
+    for cfg in yaml_cfg.setdefault('BASE', ['']):
+        if cfg:
+            _update_config_from_file(
+                config, os.path.join(os.path.dirname(cfg_file), cfg)
+            )
+    print('=> merge config from {}'.format(cfg_file))
+    config.merge_from_file(cfg_file)
+    config.freeze()
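+# Illustrative sketch of the BASE mechanism (file names are hypothetical):
+# a derived YAML lists its parent config under BASE, and the parent is merged
+# first before the derived file's own overrides, e.g.
+#   BASE: ['swin_tiny_patch4_window7_224.yaml']
+#   MODEL:
+#     DROP_PATH_RATE: 0.2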
+
+
+def update_config(config, args):
+    _update_config_from_file(config, args.cfg)
+
+    config.defrost()
+    if args.opts:
+        config.merge_from_list(args.opts)
+
+    # merge from specific arguments
+    if args.batch_size:
+        config.DATA.BATCH_SIZE = args.batch_size
+    if args.zip:
+        config.DATA.ZIP_MODE = True
+    if args.cache_mode:
+        config.DATA.CACHE_MODE = args.cache_mode
+    if args.resume:
+        config.MODEL.RESUME = args.resume
+    if args.accumulation_steps:
+        config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps
+    if args.use_checkpoint:
+        config.TRAIN.USE_CHECKPOINT = True
+    if args.amp_opt_level:
+        config.AMP_OPT_LEVEL = args.amp_opt_level
+    if args.tag:
+        config.TAG = args.tag
+    if args.eval:
+        config.EVAL_MODE = True
+    if args.throughput:
+        config.THROUGHPUT_MODE = True
+
+    config.freeze()
+
+
+def get_config(args):
+    """Get a yacs CfgNode object with default values."""
+    # Return a clone so that the defaults will not be altered
+    # This is for the "local variable" use pattern
+    config = _C.clone()
+    update_config(config, args)
+
+    return config
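
A hedged usage sketch for get_config: update_config expects an argparse-style namespace exposing cfg, opts, batch_size, zip, cache_mode, resume, accumulation_steps, use_checkpoint, amp_opt_level, tag, eval and throughput. The flag definitions below are illustrative only; the repository's own parser may differ.

    import argparse
    from utils.config import get_config  # same import path as used in modules/swav_module.py

    parser = argparse.ArgumentParser('Swin config sketch')
    parser.add_argument('--cfg', type=str, required=True, help='path to a YAML config file')
    parser.add_argument('--opts', nargs='+', default=None,
                        help='extra KEY VALUE override pairs, e.g. TRAIN.BASE_LR 1e-4')
    parser.add_argument('--batch_size', type=int, default=None)
    parser.add_argument('--zip', action='store_true')
    parser.add_argument('--cache_mode', type=str, default=None)
    parser.add_argument('--resume', type=str, default=None)
    parser.add_argument('--accumulation_steps', type=int, default=None)
    parser.add_argument('--use_checkpoint', action='store_true')
    parser.add_argument('--amp_opt_level', type=str, default=None)
    parser.add_argument('--tag', type=str, default=None)
    parser.add_argument('--eval', action='store_true')
    parser.add_argument('--throughput', action='store_true')

    config = get_config(parser.parse_args())
    print(config.MODEL.SWIN.EMBED_DIM)  # 96 unless the YAML file or --opts overrides it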