diff --git a/SSLGlacier/main.py b/SSLGlacier/main.py
index 135ed7ac1bd14a6e8e39ba0478cd9a0c421bdf0c..cc7d5b4b06f1cf3bda20348953614c420070015f 100644
--- a/SSLGlacier/main.py
+++ b/SSLGlacier/main.py
@@ -5,18 +5,21 @@ from torchinfo import summary as torchinfo_summery
 from torchsummary import summary as torch_summary
 from torch._inductor.config import trace
-from processing import datamodule_, transformers_
+from modules import datamodule_, transformers_
 import pytorch_lightning as pl
 from models import ResNet_light_mine, Swav_Orig_ResNet
 import os
 import torch
-from processing.swav_module import SwAV
+from modules.swav_module import SwAV
 from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
 import numpy as np
 from utils.utils import str2bool, getDataPath
+from typing import *
 
-def main(hparams):
+
+
+def main(hparams, swin_hparams):
     # TODO Ask NORA, Why do you use clip gradient? did you test it?
     ######save check points##########
     checkpoint_dirs = os.path.join('checkpoints', f'ssl_zone_segmentation_{hparams.arch}')
@@ -78,7 +81,8 @@ def main(hparams):
                  max_pool1d=False,
                  learning_rate=hparams.base_lr,
                  just_aug_for_same_assign_views=hparams.just_aug_for_same_assign_views,
-                 crops_for_assign=hparams.views_for_assign
+                 crops_for_assign=hparams.views_for_assign,
+                 hparams=swin_hparams
                  )
 
     online_evaluator = SSLOnlineEvaluator(
@@ -139,9 +143,9 @@ if __name__ == '__main__':
     #########################
     # Which agumented image should be used for assignment
     parser.add_argument("--views_for_assign", type=int, nargs="+", default=[0, 1],
-                        help="list of agumented views id used for computing assignments")
+                        help="list of augmented view ids used for computing assignments")
     parser.add_argument("--just_aug_for_same_assign_views", type=str2bool, default=True,
-                        help="If you want to consider agumented version of the assigned view for swap assignment.")
+                        help="If you want to consider the augmented version of the assigned view for swap assignment.")
     parser.add_argument("--temperature", default=0.1, type=float,
                         help="temperature parameter in training loss")
     parser.add_argument("--epsilons", default=0.05, type=float,
@@ -200,8 +204,70 @@ if __name__ == '__main__':
                                  'noise': 0, 'color_jitter': 0, 'gauss_blur': 0, 'salt_or_pepper': 0,
                                  'poission': 0, 'speckle': 1, 'zero': 0, 'normalize': None})
 
-    parser.add_argument('--i_want_swav', default=False, type=str2bool, help='Do you wants Swav augmentaion or ours, '
-                                                                            'other papers agumentaions methods will be added later.')
+    parser.add_argument('--i_want_swav', default=False, type=str2bool, help='Do you want the SwAV augmentation or ours? '
+                                                                            'Augmentation methods from other papers will be added later.')
     hparams = parser.parse_args()
-    main(hparams)
 
+    ###########Swin transformer parameters########
+    # TODO: a config file would be a good idea; maybe refactor the code later.
+    # Parameters originally set in the config file:
+    # Swin Transformer parameters
+    # _C.MODEL.SWIN.PATCH_SIZE = 4
+    # # _C.MODEL.SWIN.PATCH_SIZE = 2
+    # _C.MODEL.SWIN.IN_CHANS = 3
+    # # _C.MODEL.SWIN.IN_CHANS = 1
+    # _C.MODEL.SWIN.EMBED_DIM = 96
+    # _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
+    # _C.MODEL.SWIN.DECODER_DEPTHS = [2, 2, 6, 2]
+    # _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
+    # _C.MODEL.SWIN.WINDOW_SIZE = 7
+    # # _C.MODEL.SWIN.WINDOW_SIZE = 2
+    # _C.MODEL.SWIN.MLP_RATIO = 4.
+    # _C.MODEL.SWIN.QKV_BIAS = True
+    # _C.MODEL.SWIN.QK_SCALE = None
+    # _C.MODEL.SWIN.APE = False
+    # _C.MODEL.SWIN.PATCH_NORM = True
+    # _C.MODEL.SWIN.FINAL_UPSAMPLE = "expand_first"
+    ##############################################
+    swin_parser = ArgumentParser()
+    # All Swin arguments are optional flags so that their defaults apply when they
+    # are not passed on the command line.
+    swin_parser.add_argument('--hidden_mlp', type=int,
+                             help='Hidden dimensionality of the projection-head MLP.')
+    swin_parser.add_argument('--swin_patch_size', default=4,
+                             type=int,
+                             help='The size (resolution) of each patch.')
+    swin_parser.add_argument('--swin_in_chans', default=1,
+                             type=int,
+                             help='Number of input image channels.')
+    swin_parser.add_argument('--swin_embed_dim', default=96,
+                             type=int,
+                             help='Dimensionality of patch embedding.')
+    swin_parser.add_argument('--swin_depths',
+                             default=[2, 2, 6, 2],
+                             type=int, nargs='+',
+                             help="Depth of each layer in the Transformer encoder.")
+    swin_parser.add_argument('--swin_num_heads',
+                             default=[3, 6, 12, 24],
+                             type=int, nargs='+',
+                             help='Number of attention heads in each layer of the Transformer encoder.')
+    swin_parser.add_argument('--swin_window_size',
+                             default=2,
+                             type=int,
+                             help='Size of the local attention window.')
+    swin_parser.add_argument('--swin_mlp_ratio',
+                             default=4.0, type=float,
+                             help='Ratio of MLP hidden dimensionality to embedding dimensionality.')
+    swin_parser.add_argument('--swin_QKV_BIAS', default=True,
+                             type=str2bool,
+                             help='Whether or not a learnable bias should be added to the queries, keys and values.')
+    swin_parser.add_argument('--swin_QK_SCALE', default=None,
+                             type=float,
+                             help='Override default qk scale of head_dim ** -0.5 if set.')
+    swin_parser.add_argument('--swin_ape', default=False,
+                             type=str2bool,
+                             help="If True, add absolute position embedding to the patch embedding.")
+    swin_parser.add_argument('--swin_patch_norm', default=True,
+                             type=str2bool,
+                             help="If True, add normalization after patch embedding.")
+    # parse_known_args ignores the arguments that belong to the main parser above.
+    swin_hparams, _ = swin_parser.parse_known_args()
+
+
+    main(hparams, swin_hparams)
diff --git a/SSLGlacier/models/RESNET.py b/SSLGlacier/models/RESNET.py
index 9a094ef64799a71827ef9cde50265eb421ba8b7f..0ee1479bc44b2b36c51eda018bfee6eda5b95d2b 100644
--- a/SSLGlacier/models/RESNET.py
+++ b/SSLGlacier/models/RESNET.py
@@ -1,7 +1,7 @@
 """Adapted from: https://github.com/facebookresearch/swav/blob/master/src/resnet50.py."""
 import torch
 from torch import nn
-from SSLGlacier.models.base_net import Net
+from SSLGlacier.models.SWINNET import SwinNet
 
 
 def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
@@ -122,7 +122,7 @@ class Bottleneck(nn.Module):
         return self.relu(out)
 
 
-class ResNet(Net):
+class ResNet(SwinNet):
     def __init__(
         self,
         block,
diff --git a/SSLGlacier/models/base_net.py b/SSLGlacier/models/SWINNET.py
similarity index 99%
rename from SSLGlacier/models/base_net.py
rename to SSLGlacier/models/SWINNET.py
index 4489ec3baca6f158f7f03e4b1c84259b71d99806..9924814f641ceb58479ce6ec37daa5892bb42593 100644
--- a/SSLGlacier/models/base_net.py
+++ b/SSLGlacier/models/SWINNET.py
@@ -3,7 +3,7 @@ import torch
 from torch import nn
 
 
-class Net(nn.Module):
+class SwinNet(nn.Module):
     def __init__(
         self,
         groups=1,
diff --git a/SSLGlacier/models/UNET.py b/SSLGlacier/models/UNET.py
index ad4f65c88f11d852959a19ebf9e9a62fd59e6e5c..5061bb08e355f3a6c078f428051b9e17530aba43 100644
--- a/SSLGlacier/models/UNET.py
+++ b/SSLGlacier/models/UNET.py
@@ -1,9 +1,9 @@
 import torch
 from torch import Tensor, nn
 from torch.nn import functional as F  # noqa: N812
-from SSLGlacier.models.base_net import Net
+from SSLGlacier.models.SWINNET import SwinNet
 
 
-class UNet(Net):
+class UNet(SwinNet):
     """Pytorch Lightning implementation of U-Net.
     Paper: `U-Net: Convolutional Networks for Biomedical Image Segmentation <https://arxiv.org/abs/1505.04597>`_
diff --git a/SSLGlacier/models/hook/Swin_Transformer.py b/SSLGlacier/models/hook/Swin_Transformer.py
index 992126a2d74aad5b258785d4d9476fbc93ec407a..1315bfd20fcda9d2f12310ae640f01d093255ed1 100644
--- a/SSLGlacier/models/hook/Swin_Transformer.py
+++ b/SSLGlacier/models/hook/Swin_Transformer.py
@@ -2,12 +2,10 @@ import torch
 import torch.nn as nn
 import numpy as np
 from timm.models.layers import trunc_normal_
-from SSLGlacier.models.hook.patch_embedding import PatchEmbed
-from SSLGlacier.models.hook.patch_merging import PatchMerging
-from SSLGlacier.models.hook.patch_expand import PatchExpand, PatchExpandC, FinalPatchExpand_X4
-from SSLGlacier.models.hook.basic_swin_layer import BasicLayer
-from SSLGlacier.models.hook.basic_swin_layer_up import BasicLayer_up
-from SSLGlacier.models.hook.basic_swin_layer_cross import BasicLayer_Cross
+from SSLGlacier.models.hook.patch_processing import PatchEmbed
+from SSLGlacier.models.hook.patch_processing import PatchMerging
+from SSLGlacier.models.hook.patch_processing import PatchExpand, PatchExpandC, FinalPatchExpand_X4
+from SSLGlacier.models.hook.basic_layers import BasicLayer, BasicLayer_up, BasicLayer_Cross
 
 torch.set_printoptions(threshold=np.inf)
 torch.set_printoptions(linewidth=300)
diff --git a/SSLGlacier/models/hook/Swin_Transformer_Wrapper.py b/SSLGlacier/models/hook/Swin_Transformer_Wrapper.py
index 528757dc5facce15c00dc59694f81dead8590afc..4f961444452f693aa42731f3b19ec22e8a546f5b 100644
--- a/SSLGlacier/models/hook/Swin_Transformer_Wrapper.py
+++ b/SSLGlacier/models/hook/Swin_Transformer_Wrapper.py
@@ -1,41 +1,48 @@
 # coding=utf-8
+# NOTE: "from __future__" imports must stay above every other import,
+# otherwise Python raises a SyntaxError.
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 import copy
 import logging
-
 import torch
-import torch.nn as nn
-
+from SSLGlacier.models.SWINNET import SwinNet
 from SSLGlacier.models.hook.Swin_Transformer import SwinTransformerSys
 
 logger = logging.getLogger(__name__)
 
 
-class SwinUnet(nn.Module):
-    def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
-        super(SwinUnet, self).__init__()
+class SwinUnet(SwinNet):
+    def __init__(self, num_classes=21843, zero_head=False, **kwargs):
+        super().__init__(
+            groups=1,
+            widen=1,
+            width_per_group=64,
+            norm_layer=None,
+            output_dim=0 if 'output_dim' not in kwargs else kwargs['output_dim'],
+            hidden_mlp=0 if 'hidden_mlp' not in kwargs else kwargs['hidden_mlp'],
+            num_prototypes=0 if 'num_prototypes' not in kwargs else kwargs['num_prototypes'],
+            eval_mode=False,
+            normalize=False if 'normalize' not in kwargs else kwargs['normalize']
+        )
+
         self.num_classes = num_classes
         self.zero_head = zero_head
-        self.config = config
 
-        self.swin_unet = SwinTransformerSys(img_size=config.DATA.IMG_SIZE,
-                                            patch_size=config.MODEL.SWIN.PATCH_SIZE,
-                                            in_chans=config.MODEL.SWIN.IN_CHANS,
+        # **kwargs is a plain dict, so the hyper-parameters are read with item access.
+        self.swin_unet = SwinTransformerSys(img_size=kwargs['image_size'],
+                                            patch_size=kwargs['swin_patch_size'],
+                                            in_chans=kwargs['swin_in_chans'],
                                             num_classes=self.num_classes,
-                                            embed_dim=config.MODEL.SWIN.EMBED_DIM,
-                                            depths=config.MODEL.SWIN.DEPTHS,
-                                            num_heads=config.MODEL.SWIN.NUM_HEADS,
-                                            window_size=config.MODEL.SWIN.WINDOW_SIZE,
-                                            mlp_ratio=config.MODEL.SWIN.MLP_RATIO,
-                                            qkv_bias=config.MODEL.SWIN.QKV_BIAS,
-                                            qk_scale=config.MODEL.SWIN.QK_SCALE,
-                                            drop_rate=config.MODEL.DROP_RATE,
-                                            drop_path_rate=config.MODEL.DROP_PATH_RATE,
-                                            ape=config.MODEL.SWIN.APE,
-                                            patch_norm=config.MODEL.SWIN.PATCH_NORM,
-                                            use_checkpoint=config.TRAIN.USE_CHECKPOINT)
+                                            embed_dim=kwargs['swin_embed_dim'],
+                                            depths=kwargs['swin_depths'],
+                                            num_heads=kwargs['swin_num_heads'],
+                                            window_size=kwargs['swin_window_size'],
+                                            mlp_ratio=kwargs['swin_mlp_ratio'],
+                                            qkv_bias=kwargs['swin_QKV_BIAS'],
+                                            qk_scale=kwargs['swin_QK_SCALE'],
+                                            drop_rate=kwargs['drop_rate'],
+                                            drop_path_rate=kwargs['drop_path_rate'],
+                                            ape=kwargs['swin_ape'],
+                                            patch_norm=kwargs['swin_patch_norm'])
 
     def forward(self, x, y):
         if x.size()[1] == 1:
diff --git a/SSLGlacier/models/hook/basic_layers.py b/SSLGlacier/models/hook/basic_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..079d1cc4f7944bb707292826838d065735c52f04
--- /dev/null
+++ b/SSLGlacier/models/hook/basic_layers.py
@@ -0,0 +1,229 @@
+import torch.nn as nn
+import torch.utils.checkpoint as checkpoint
+from SSLGlacier.models.hook.patch_processing import PatchExpand
+from SSLGlacier.models.hook.swin_transformer_blocks import SwinTransformerBlock, SwinTransformerBlock_Cross
+#-----------------------------------------------
+########### Basic Swin Layer Classes ###########
+#-----------------------------------------------
+class BasicLayer(nn.Module):
+    """ A basic Swin Transformer layer for one stage.
+
+    Args:
+        dim (int): Number of input channels.
+        input_resolution (tuple[int]): Input resolution.
+        depth (int): Number of blocks.
+        num_heads (int): Number of attention heads.
+        window_size (int): Local window size.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim # 96 + self.input_resolution = input_resolution # 56 + self.depth = depth # 2 + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, # 96, 56 + num_heads=num_heads, window_size=window_size, # 3, 14 + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, # 4 + qkv_bias=qkv_bias, qk_scale=qk_scale, # True, None + drop=drop, attn_drop=attn_drop, # 0.2, 0 + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) # layer-norm + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops +#----------------------------------------------- +####### End of Basic Swin Layer Classes ######## +#----------------------------------------------- + +################################################ +################################################ + +#----------------------------------------------- +####### Basic Swin Layer cross Classes ######### +#----------------------------------------------- + +class BasicLayer_Cross(nn.Module): # Input of second attention is output of first. See paper + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim # 96 + self.input_resolution = input_resolution # 56 + self.depth = depth # 2 + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock_Cross(dim=dim, input_resolution=input_resolution, # 96, 56 + num_heads=num_heads, window_size=window_size, # 3, 14 + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, # 4 + qkv_bias=qkv_bias, qk_scale=qk_scale, # True, None + drop=drop, attn_drop=attn_drop, # 0.2, 0 + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) # layer-norm + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, y): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x, y) + if self.downsample is not None: + x = self.downsample(x) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops +#----------------------------------------------- +##### End of Basic Swin Layer cross Classes #### +#----------------------------------------------- + +################################################ +################################################ + + +#----------------------------------------------- +######### Basic Swin Layer UP Classes ########## +#----------------------------------------------- + +class BasicLayer_up(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if upsample is not None: + self.upsample = PatchExpand(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer) + else: + self.upsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.upsample is not None: + x = self.upsample(x) + return x + +#----------------------------------------------- +##### End of Basic Swin Layer UP Classes #### +#----------------------------------------------- diff --git a/SSLGlacier/models/hook/basic_swin_layer.py b/SSLGlacier/models/hook/basic_swin_layer.py deleted file mode 100644 index e9570cb96d1b2302cd25658c41b1977a8a40aa88..0000000000000000000000000000000000000000 --- a/SSLGlacier/models/hook/basic_swin_layer.py +++ /dev/null @@ -1,74 +0,0 @@ -import torch.nn as nn -import torch.utils.checkpoint as checkpoint -from SSLGlacier.models.hook.swin_transformer_block import SwinTransformerBlock - - -class BasicLayer(nn.Module): - """ A basic Swin Transformer layer for one stage. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
- """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): - - super().__init__() - self.dim = dim # 96 - self.input_resolution = input_resolution # 56 - self.depth = depth # 2 - self.use_checkpoint = use_checkpoint - - # build blocks - self.blocks = nn.ModuleList([ - SwinTransformerBlock(dim=dim, input_resolution=input_resolution, # 96, 56 - num_heads=num_heads, window_size=window_size, # 3, 14 - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, # 4 - qkv_bias=qkv_bias, qk_scale=qk_scale, # True, None - drop=drop, attn_drop=attn_drop, # 0.2, 0 - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer) # layer-norm - for i in range(depth)]) - - # patch merging layer - if downsample is not None: - self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) - else: - self.downsample = None - - def forward(self, x): - for blk in self.blocks: - if self.use_checkpoint: - x = checkpoint.checkpoint(blk, x) - else: - x = blk(x) - if self.downsample is not None: - x = self.downsample(x) - - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" - - def flops(self): - flops = 0 - for blk in self.blocks: - flops += blk.flops() - if self.downsample is not None: - flops += self.downsample.flops() - return flops \ No newline at end of file diff --git a/SSLGlacier/models/hook/basic_swin_layer_cross.py b/SSLGlacier/models/hook/basic_swin_layer_cross.py deleted file mode 100644 index e9481d94d488dca538b3ac308b96bbcf3d7cc3aa..0000000000000000000000000000000000000000 --- a/SSLGlacier/models/hook/basic_swin_layer_cross.py +++ /dev/null @@ -1,74 +0,0 @@ -import torch.nn as nn -import torch.utils.checkpoint as checkpoint -from SSLGlacier.models.hook.swin_transformer_block_cross import SwinTransformerBlock_Cross - - -class BasicLayer_Cross(nn.Module): # Input of second attention is output of first. See paper - """ A basic Swin Transformer layer for one stage. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
- """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): - - super().__init__() - self.dim = dim # 96 - self.input_resolution = input_resolution # 56 - self.depth = depth # 2 - self.use_checkpoint = use_checkpoint - - # build blocks - self.blocks = nn.ModuleList([ - SwinTransformerBlock_Cross(dim=dim, input_resolution=input_resolution, # 96, 56 - num_heads=num_heads, window_size=window_size, # 3, 14 - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, # 4 - qkv_bias=qkv_bias, qk_scale=qk_scale, # True, None - drop=drop, attn_drop=attn_drop, # 0.2, 0 - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer) # layer-norm - for i in range(depth)]) - - # patch merging layer - if downsample is not None: - self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) - else: - self.downsample = None - - def forward(self, x, y): - for blk in self.blocks: - if self.use_checkpoint: - x = checkpoint.checkpoint(blk, x) - else: - x = blk(x, y) - if self.downsample is not None: - x = self.downsample(x) - - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" - - def flops(self): - flops = 0 - for blk in self.blocks: - flops += blk.flops() - if self.downsample is not None: - flops += self.downsample.flops() - return flops diff --git a/SSLGlacier/models/hook/basic_swin_layer_up.py b/SSLGlacier/models/hook/basic_swin_layer_up.py deleted file mode 100644 index 9ee1b87675acf8a5e8a9d54a98c86f49122f8f56..0000000000000000000000000000000000000000 --- a/SSLGlacier/models/hook/basic_swin_layer_up.py +++ /dev/null @@ -1,63 +0,0 @@ -import torch.nn as nn -import torch.utils.checkpoint as checkpoint -from SSLGlacier.models.hook.swin_transformer_block import SwinTransformerBlock -from SSLGlacier.models.hook.patch_expand import PatchExpand - - -class BasicLayer_up(nn.Module): - """ A basic Swin Transformer layer for one stage. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
- """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False): - - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.depth = depth - self.use_checkpoint = use_checkpoint - - # build blocks - self.blocks = nn.ModuleList([ - SwinTransformerBlock(dim=dim, input_resolution=input_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer) - for i in range(depth)]) - - # patch merging layer - if upsample is not None: - self.upsample = PatchExpand(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer) - else: - self.upsample = None - - def forward(self, x): - for blk in self.blocks: - if self.use_checkpoint: - x = checkpoint.checkpoint(blk, x) - else: - x = blk(x) - if self.upsample is not None: - x = self.upsample(x) - return x \ No newline at end of file diff --git a/SSLGlacier/models/hook/mlp.py b/SSLGlacier/models/hook/mlp.py deleted file mode 100644 index 6e7730ffc7f0d4aa7c018084c4169e6b80693d36..0000000000000000000000000000000000000000 --- a/SSLGlacier/models/hook/mlp.py +++ /dev/null @@ -1,20 +0,0 @@ -import torch.nn as nn - - -class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x \ No newline at end of file diff --git a/SSLGlacier/models/hook/patch_embedding.py b/SSLGlacier/models/hook/patch_embedding.py deleted file mode 100644 index e9fba670109312057260b71a0371789c597c1e12..0000000000000000000000000000000000000000 --- a/SSLGlacier/models/hook/patch_embedding.py +++ /dev/null @@ -1,51 +0,0 @@ -import torch.nn as nn -from timm.models.layers import to_2tuple - - -class PatchEmbed(nn.Module): - r""" Image to Patch Embedding - - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. 
Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size # 224 - self.patch_size = patch_size # 4 - self.patches_resolution = patches_resolution # 56 - self.num_patches = patches_resolution[0] * patches_resolution[1] # 3136 - - self.in_chans = in_chans - self.embed_dim = embed_dim - - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) # 22, 96, 56, 56 - - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - B, C, H, W = x.shape - # FIXME look at relaxing size constraints - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C (22, 3136, 96) - if self.norm is not None: - x = self.norm(x) - return x - - def flops(self): - Ho, Wo = self.patches_resolution - flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) - if self.norm is not None: - flops += Ho * Wo * self.embed_dim - return flops \ No newline at end of file diff --git a/SSLGlacier/models/hook/patch_expand.py b/SSLGlacier/models/hook/patch_expand.py deleted file mode 100644 index ef397f8d147497ab1005ee7d2a39575879354d9f..0000000000000000000000000000000000000000 --- a/SSLGlacier/models/hook/patch_expand.py +++ /dev/null @@ -1,155 +0,0 @@ -import torch.nn as nn -from einops import rearrange -# import torch - - -class PatchExpand(nn.Module): - def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm): - super().__init__() - self.input_resolution = input_resolution - self.dim = dim # 96 - self.expand = nn.Linear(dim, 2 * dim, bias=False) if dim_scale == 2 else nn.Identity() - self.norm = norm_layer(dim // dim_scale) - - def forward(self, x): - """ - x: B, H*W, C - """ - H, W = self.input_resolution - x = self.expand(x) # 7, 49, 768; 7, 49, 1536 - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - x = x.view(B, H, W, C) # 7, 7, 7, 1536 - x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C // 4) # 7, 14, 14, 384 - x = x.view(B, -1, C // 4) - x = self.norm(x) - - return x - - -# class PatchExpand4(nn.Module): -# def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm): -# super().__init__() -# self.input_resolution = input_resolution -# self.dim = dim # 96 -# self.expand = nn.Linear(dim, 384, bias=False) if dim_scale == 2 else nn.Identity() -# self.norm = norm_layer(96) -# -# def forward(self, x): -# """ -# x: B, H*W, C -# """ -# H, W = self.input_resolution -# x = self.expand(x) # 7, 49, 768; 7, 49, 1536 -# B, L, C = x.shape -# assert L == H * W, "input feature has wrong size" -# x = x.view(B, H, W, C) # 7, 7, 7, 1536 -# x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C // 4) # 7, 14, 14, 384 -# x = x.view(B, -1, C // 4) -# x = self.norm(x) -# -# return x - - -class PatchExpandC(nn.Module): - def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm): - super().__init__() - self.input_resolution = input_resolution - self.dim = dim # 96 - self.expand = nn.Linear(dim, dim, bias=False) if dim_scale == 2 else nn.Identity() - self.norm = norm_layer(192) - - def forward(self, 
x): - """ - x: B, H*W, C - """ - H, W = self.input_resolution - x = self.expand(x) # 7, 49, 768; 7, 49, 1536 - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - x = x.view(B, H, W, C) # 7, 7, 7, 1536 - x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C // 4) # 7, 14, 14, 384 - x = x.view(B, -1, C // 4) - x = self.norm(x) - - return x - - -# class PCAa(object): -# def __init__(self, n_components=2): -# self.n_components = n_components -# -# def fit_trans(self, X): -# n = X.shape[1] -# self.mean = torch.mean(X, axis=1).unsqueeze(dim=1) -# -# X = X - self.mean -# covariance_matrix = 1 / n * torch.matmul(X.permute(0, 2, 1), X) -# -# eigenvalues, eigenvectors = torch.linalg.eig(covariance_matrix) -# -# eigenvalues = torch.real(eigenvalues) -# eigenvectors = torch.real(eigenvectors) -# -# idx = torch.argsort(-eigenvalues) -# -# eigenvectors = torch.gather(eigenvectors, 2, idx.unsqueeze(dim=1).expand(eigenvectors.shape)) -# self.proj_mat = eigenvectors[:, :, 0:self.n_components] -# -# Out = X.matmul(self.proj_mat) -# -# return Out -# -# -# class PatchExpandPCA(nn.Module): -# def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm): -# super().__init__() -# self.input_resolution = input_resolution -# self.dim = dim # 96 -# # self.expand = nn.Linear(dim, 384, bias=False) if dim_scale == 2 else nn.Identity() -# self.norm = norm_layer(96) -# -# def forward(self, x): -# """ -# x: B, H*W, C -# """ -# H, W = self.input_resolution -# pcaa = PCAa(n_components=384) -# x = pcaa.fit_trans(x) -# -# B, L, C = x.shape -# assert L == H * W, "input feature has wrong size" -# x = x.view(B, H, W, C) # 7, 7, 7, 1536 -# x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C // 4) # 7, 14, 14, 384 -# x = x.view(B, -1, C // 4) -# x = self.norm(x) -# -# return x - - -class FinalPatchExpand_X4(nn.Module): - def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm): - super().__init__() - self.input_resolution = input_resolution - self.dim = dim - self.dim_scale = dim_scale - self.expand = nn.Linear(dim, 16 * dim, bias=False) - self.output_dim = dim - self.norm = norm_layer(self.output_dim) - - def forward(self, x): - """ - x: B, H*W, C - """ - H, W = self.input_resolution - x = self.expand(x) - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - - x = x.view(B, H, W, C) - x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, - c=C // (self.dim_scale ** 2)) - x = x.view(B, -1, self.output_dim) - x = self.norm(x) - - return x \ No newline at end of file diff --git a/SSLGlacier/models/hook/patch_merging.py b/SSLGlacier/models/hook/patch_merging.py deleted file mode 100644 index 0567504b51f4936d8a7152fcfc5d37547e2f7ff9..0000000000000000000000000000000000000000 --- a/SSLGlacier/models/hook/patch_merging.py +++ /dev/null @@ -1,51 +0,0 @@ -import torch -import torch.nn as nn - - -class PatchMerging(nn.Module): - r""" Patch Merging Layer. - - Args: - input_resolution (tuple[int]): Resolution of input feature. - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm - """ - - def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.input_resolution = input_resolution - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(4 * dim) - - def forward(self, x): - """ - x: B, H*W, C - """ - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." - - x = x.view(B, H, W, C) - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - - x = self.norm(x) - x = self.reduction(x) - - return x - - def extra_repr(self) -> str: - return f"input_resolution={self.input_resolution}, dim={self.dim}" - - def flops(self): - H, W = self.input_resolution - flops = H * W * self.dim - flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim - return flops - diff --git a/SSLGlacier/models/hook/patch_processing.py b/SSLGlacier/models/hook/patch_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..d7e1f3309c2e8e8edd64b70986c1e18b5886bebc --- /dev/null +++ b/SSLGlacier/models/hook/patch_processing.py @@ -0,0 +1,276 @@ +import torch +import torch.nn as nn +from einops import rearrange +from timm.models.layers import to_2tuple +#----------------------------------------------- +############## Patch Merging Class ########### +#----------------------------------------------- +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + +#----------------------------------------------- +######### End of Patch Merging Class ########### +#----------------------------------------------- + +################################################ +################################################ + +#----------------------------------------------- +############## Patch Embedding Class ########### +#----------------------------------------------- +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. 
Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size # 224 + self.patch_size = patch_size # 4 + self.patches_resolution = patches_resolution # 56 + self.num_patches = patches_resolution[0] * patches_resolution[1] # 3136 + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) # 22, 96, 56, 56 + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C (22, 3136, 96) + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops +#----------------------------------------------- +######### End of Patch Embeding Class ########### +#----------------------------------------------- + +################################################ +################################################ + +#----------------------------------------------- +############# Patch Expand Classes ############# +#----------------------------------------------- +class PatchExpand(nn.Module): + """ + from:https://link.springer.com/chapter/10.1007/978-3-031-25066-8_9 + Take the first patch expanding layer as an example, before up-sampling, + a linear layer is applied on the input features (W/32, H/32, 8C) to increase the feature dimension to + 2 * the original dimension (W/32, H/32, 16C). Then, we use rearrange operation to expand the resolution + of the input features to 2*the input resolution and reduce the feature dimension to + quarter of the input dimension ((W/32, H/32, 16C) --> (W/16, H/16, 4C)). + We will discuss the impact of using patch expanding layer + to perform up-sampling inSect. 4.5 in https://link.springer.com/chapter/10.1007/978-3-031-25066-8_9. 
+ """ + def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim # 96 + self.expand = nn.Linear(dim, 2 * dim, bias=False) if dim_scale == 2 else nn.Identity() + self.norm = norm_layer(dim // dim_scale) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + x = self.expand(x) # 7, 49, 768; 7, 49, 1536 + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + x = x.view(B, H, W, C) # 7, 7, 7, 1536 + x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C // 4) # 7, 14, 14, 384 + x = x.view(B, -1, C // 4) + x = self.norm(x) + + return x + +# class PatchExpand4(nn.Module): +# def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm): +# super().__init__() +# self.input_resolution = input_resolution +# self.dim = dim # 96 +# self.expand = nn.Linear(dim, 384, bias=False) if dim_scale == 2 else nn.Identity() +# self.norm = norm_layer(96) +# +# def forward(self, x): +# """ +# x: B, H*W, C +# """ +# H, W = self.input_resolution +# x = self.expand(x) # 7, 49, 768; 7, 49, 1536 +# B, L, C = x.shape +# assert L == H * W, "input feature has wrong size" +# x = x.view(B, H, W, C) # 7, 7, 7, 1536 +# x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C // 4) # 7, 14, 14, 384 +# x = x.view(B, -1, C // 4) +# x = self.norm(x) +# +# return x +class PatchExpandC(nn.Module): + def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim # 96 + self.expand = nn.Linear(dim, dim, bias=False) if dim_scale == 2 else nn.Identity() + self.norm = norm_layer(192) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + x = self.expand(x) # 7, 49, 768; 7, 49, 1536 + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + x = x.view(B, H, W, C) # 7, 7, 7, 1536 + x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C // 4) # 7, 14, 14, 384 + x = x.view(B, -1, C // 4) + x = self.norm(x) + + return x + +# class PCAa(object): +# def __init__(self, n_components=2): +# self.n_components = n_components +# +# def fit_trans(self, X): +# n = X.shape[1] +# self.mean = torch.mean(X, axis=1).unsqueeze(dim=1) +# +# X = X - self.mean +# covariance_matrix = 1 / n * torch.matmul(X.permute(0, 2, 1), X) +# +# eigenvalues, eigenvectors = torch.linalg.eig(covariance_matrix) +# +# eigenvalues = torch.real(eigenvalues) +# eigenvectors = torch.real(eigenvectors) +# +# idx = torch.argsort(-eigenvalues) +# +# eigenvectors = torch.gather(eigenvectors, 2, idx.unsqueeze(dim=1).expand(eigenvectors.shape)) +# self.proj_mat = eigenvectors[:, :, 0:self.n_components] +# +# Out = X.matmul(self.proj_mat) +# +# return Out +# +# +# class PatchExpandPCA(nn.Module): +# def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm): +# super().__init__() +# self.input_resolution = input_resolution +# self.dim = dim # 96 +# # self.expand = nn.Linear(dim, 384, bias=False) if dim_scale == 2 else nn.Identity() +# self.norm = norm_layer(96) +# +# def forward(self, x): +# """ +# x: B, H*W, C +# """ +# H, W = self.input_resolution +# pcaa = PCAa(n_components=384) +# x = pcaa.fit_trans(x) +# +# B, L, C = x.shape +# assert L == H * W, "input feature has wrong size" +# x = x.view(B, H, W, C) # 7, 7, 7, 1536 +# x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C // 4) # 7, 
14, 14, 384 +# x = x.view(B, -1, C // 4) +# x = self.norm(x) +# +# return x + +class FinalPatchExpand_X4(nn.Module): + def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.dim_scale = dim_scale + self.expand = nn.Linear(dim, 16 * dim, bias=False) + self.output_dim = dim + self.norm = norm_layer(self.output_dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + x = self.expand(x) + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, + c=C // (self.dim_scale ** 2)) + x = x.view(B, -1, self.output_dim) + x = self.norm(x) + + return x \ No newline at end of file diff --git a/SSLGlacier/models/hook/swin_transformer_block.py b/SSLGlacier/models/hook/swin_transformer_block.py deleted file mode 100644 index dd8938ca0fd03270fff060485a4ff92cfc3ab775..0000000000000000000000000000000000000000 --- a/SSLGlacier/models/hook/swin_transformer_block.py +++ /dev/null @@ -1,136 +0,0 @@ -import torch -import torch.nn as nn -from timm.models.layers import DropPath, to_2tuple -from models.window_attention import WindowAttention -from models.mlp import Mlp -from models.window_utils import window_partition, window_reverse - - -class SwinTransformerBlock(nn.Module): - r""" Swin Transformer Block. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resulotion. - num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim # 96 - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size # 0 or 7 - self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = WindowAttention( - dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) - - self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - if self.shift_size > 0: - # calculate attention mask for SW-MSA - H, W = self.input_resolution # 56, 56 - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1; 1, 56, 56, 1 - h_slices = (slice(0, -self.window_size), # 0, -14 - slice(-self.window_size, -self.shift_size), # -14, -7 - slice(-self.shift_size, None)) # -7, None - - w_slices = (slice(0, -self.window_size), # 0, -14 - slice(-self.window_size, -self.shift_size), # -14, -7 - slice(-self.shift_size, None)) # -7, None - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, - self.window_size) # nW, window_size, window_size, 1; 64, 7, 7, 1 # 16, 14, 14, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) # 64, 49 # 16, 196 - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # 64, 1, 49 - 64, 49, 1 - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float( - 0.0)) # 64, 49, 49 # 16, 196, 196 - else: - attn_mask = None - - self.register_buffer("attn_mask", attn_mask) - - def forward(self, x): - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - - shortcut = x - x = self.norm1(x) - x = x.view(B, H, W, C) - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - else: - shifted_x = x - - # partition windows - x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C; 448, 7, 7, 96 - x_windows = x_windows.view(-1, self.window_size * self.window_size, - C) # nW*B, window_size*window_size, C; 448, 49, 96 # 7, 4, 4, 196, 96 - - # W-MSA/SW-MSA - attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C; 448, 49, 96 - - # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) # 448, 7, 7, 96 - shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C; 7, 56, 56, 96 - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - else: - x = shifted_x - - x = x.view(B, H * W, C) # 7, 56*56, 96 - - # FFN - x = shortcut + self.drop_path(x) - x = x + self.drop_path(self.mlp(self.norm2(x))) - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ - f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" - - def flops(self): - flops = 0 - H, W = self.input_resolution - # norm1 - flops += self.dim * H * W - # W-MSA/SW-MSA - nW = H * W / self.window_size / self.window_size - flops += nW * self.attn.flops(self.window_size * self.window_size) - # mlp - flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio - # norm2 - flops += self.dim * H * W - return flops diff --git a/SSLGlacier/models/hook/swin_transformer_block_cross.py b/SSLGlacier/models/hook/swin_transformer_block_cross.py deleted file mode 100644 index 8dbfb6719a5718f3e04bbc19a728afc98fbe6219..0000000000000000000000000000000000000000 --- a/SSLGlacier/models/hook/swin_transformer_block_cross.py +++ /dev/null @@ -1,151 +0,0 @@ -import torch -import torch.nn as nn -from timm.models.layers import DropPath, 
to_2tuple -from models.window_cross_attention import WindowAttention_Cross -from models.mlp import Mlp -from models.window_utils import window_partition, window_reverse - - -class SwinTransformerBlock_Cross(nn.Module): - r""" Swin Transformer Block. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resulotion. - num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim # 96 - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size # 0 or 7 - self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = WindowAttention_Cross( - dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) - - self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - if self.shift_size > 0: - # calculate attention mask for SW-MSA - H, W = self.input_resolution # 56, 56 - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1; 1, 56, 56, 1 - h_slices = (slice(0, -self.window_size), # 0, -14 - slice(-self.window_size, -self.shift_size), # -14, -7 - slice(-self.shift_size, None)) # -7, None - - w_slices = (slice(0, -self.window_size), # 0, -14 - slice(-self.window_size, -self.shift_size), # -14, -7 - slice(-self.shift_size, None)) # -7, None - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, - self.window_size) # nW, window_size, window_size, 1; 64, 7, 7, 1 # 16, 14, 14, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) # 64, 49 # 16, 196 - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # 64, 1, 49 - 64, 49, 1 - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float( - 0.0)) # 64, 49, 49 # 16, 196, 196 - else: - attn_mask = None - - self.register_buffer("attn_mask", attn_mask) - - def forward(self, target, context): - H, W = self.input_resolution - B, L, C = target.shape - assert L == H * W, "input feature has wrong size" - - shortcut = target - target = self.norm1(target) - target = target.view(B, H, W, C) - - context = self.norm1(context) - context = context.view(B, H, W, C) - - # cyclic shift - if self.shift_size > 0: - shifted_target = torch.roll(target, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - shifted_context = torch.roll(context, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - else: - shifted_target = target - shifted_context = context - - # partition windows - target_windows = window_partition(shifted_target, - self.window_size) # nW*B, window_size, window_size, C; 448, 7, 7, 96 - context_windows = window_partition(shifted_context, - self.window_size) # nW*B, window_size, window_size, C; 448, 7, 7, 96 - # print(x_windows.shape, 'ss') - target_windows = target_windows.view(-1, self.window_size * self.window_size, - C) # nW*B, window_size*window_size, C; 448, 49, 96 # 7, 4, 4, 196, 96 - - context_windows = context_windows.view(-1, self.window_size * self.window_size, - C) # nW*B, window_size*window_size, C; 448, 49, 96 # 7, 4, 4, 196, 96 - - # W-MSA/SW-MSA - attn_windows = self.attn(target_windows, context_windows, - mask=self.attn_mask) # nW*B, window_size*window_size, C; 448, 49, 96 - - # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) # 448, 7, 7, 96 - shifted_target = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C; 7, 56, 56, 96 - # reverse cyclic shift - if self.shift_size > 0: - target = torch.roll(shifted_target, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - else: - target = shifted_target - - target = target.view(B, H * W, C) # 7, 56*56, 96 - - # FFN - target = shortcut + self.drop_path(target) - target = target + self.drop_path(self.mlp(self.norm2(target))) - - return target - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ - f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" - - def flops(self): - flops = 0 - H, W = self.input_resolution - # norm1 - flops += 
self.dim * H * W
-        # W-MSA/SW-MSA
-        nW = H * W / self.window_size / self.window_size
-        flops += nW * self.attn.flops(self.window_size * self.window_size)
-        # mlp
-        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
-        # norm2
-        flops += self.dim * H * W
-        return flops
diff --git a/SSLGlacier/models/hook/swin_transformer_blocks.py b/SSLGlacier/models/hook/swin_transformer_blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6745c8a505c38f52232948aaa4dcea6edb29c71
--- /dev/null
+++ b/SSLGlacier/models/hook/swin_transformer_blocks.py
@@ -0,0 +1,296 @@
+import torch
+import torch.nn as nn
+from timm.models.layers import DropPath, to_2tuple
+from SSLGlacier.models.hook.window_attentions import WindowAttention, WindowAttention_Cross, window_partition, window_reverse
+
+class Mlp(nn.Module):
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+class SwinTransformerBlock_Cross(nn.Module):
+    r""" Swin Transformer block with window-based cross attention.
+
+    Args:
+        dim (int): Number of input channels.
+        input_resolution (tuple[int]): Input resolution.
+        num_heads (int): Number of attention heads.
+        window_size (int): Window size.
+        shift_size (int): Shift size for SW-MSA.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float, optional): Stochastic depth rate. Default: 0.0
+        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+    """
+
+    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
+                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
+                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.dim = dim  # 96
+        self.input_resolution = input_resolution
+        self.num_heads = num_heads
+        self.window_size = window_size
+        self.shift_size = shift_size  # 0 or 7
+        self.mlp_ratio = mlp_ratio
+        if min(self.input_resolution) <= self.window_size:
+            # if window size is larger than input resolution, we don't partition windows
+            self.shift_size = 0
+            self.window_size = min(self.input_resolution)
+        assert 0 <= self.shift_size < self.window_size, "shift_size must be in 0-window_size"
+
+        self.norm1 = norm_layer(dim)
+        self.attn = WindowAttention_Cross(
+            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
+            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+        if self.shift_size > 0:
+            # calculate attention mask for SW-MSA
+            H, W = self.input_resolution  # 56, 56
+            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1; 1, 56, 56, 1
+            h_slices = (slice(0, -self.window_size),                 # 0, -14
+                        slice(-self.window_size, -self.shift_size),  # -14, -7
+                        slice(-self.shift_size, None))               # -7, None
+
+            w_slices = (slice(0, -self.window_size),                 # 0, -14
+                        slice(-self.window_size, -self.shift_size),  # -14, -7
+                        slice(-self.shift_size, None))               # -7, None
+            cnt = 0
+            for h in h_slices:
+                for w in w_slices:
+                    img_mask[:, h, w, :] = cnt
+                    cnt += 1
+
+            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1; 64, 7, 7, 1 # 16, 14, 14, 1
+            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)  # 64, 49 # 16, 196
+            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # 64, 1, 49 - 64, 49, 1
+            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))  # 64, 49, 49 # 16, 196, 196
+        else:
+            attn_mask = None
+
+        self.register_buffer("attn_mask", attn_mask)
+
+    def forward(self, target, context):
+        H, W = self.input_resolution
+        B, L, C = target.shape
+        assert L == H * W, "input feature has wrong size"
+
+        shortcut = target
+        target = self.norm1(target)
+        target = target.view(B, H, W, C)
+
+        context = self.norm1(context)
+        context = context.view(B, H, W, C)
+
+        # cyclic shift
+        if self.shift_size > 0:
+            shifted_target = torch.roll(target, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+            shifted_context = torch.roll(context, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+        else:
+            shifted_target = target
+            shifted_context = context
+
+        # partition windows
+        target_windows = window_partition(shifted_target, self.window_size)  # nW*B, window_size, window_size, C; 448, 7, 7, 96
+        context_windows = window_partition(shifted_context, self.window_size)  # nW*B, window_size, window_size, C; 448, 7, 7, 96
+        # print(x_windows.shape, 'ss')
+        target_windows = target_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C; 448, 49, 96 # 7, 4, 4, 196, 96
+
+        context_windows = context_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C; 448, 49, 96 # 7, 4, 4, 196, 96
+
+        # W-MSA/SW-MSA
+        attn_windows = self.attn(target_windows, context_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C; 448, 49, 96
+
+        # merge windows
+        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)  # 448, 7, 7, 96
+        shifted_target = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C; 7, 56, 56, 96
+        # reverse cyclic shift
+        if self.shift_size > 0:
+            target = torch.roll(shifted_target, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+        else:
+            target = shifted_target
+
+        target = target.view(B, H * W, C)  # 7, 56*56, 96
+
+        # FFN
+        target = shortcut + self.drop_path(target)
+        target = target + self.drop_path(self.mlp(self.norm2(target)))
+
+        return target
+
+    def extra_repr(self) -> str:
+        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
+               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
+
+    def flops(self):
+        flops = 0
+        H, W = self.input_resolution
+        # norm1
+        flops += self.dim * H * W
+        # W-MSA/SW-MSA
+        nW = H * W / self.window_size / self.window_size
+        flops += nW * self.attn.flops(self.window_size * self.window_size)
+        # mlp
+        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
+        # norm2
+        flops += self.dim * H * W
+        return flops
+
+
+class SwinTransformerBlock(nn.Module):
+    r""" Swin Transformer Block.
+
+    Args:
+        dim (int): Number of input channels.
+        input_resolution (tuple[int]): Input resolution.
+        num_heads (int): Number of attention heads.
+        window_size (int): Window size.
+        shift_size (int): Shift size for SW-MSA.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float, optional): Stochastic depth rate. Default: 0.0
+        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+    """
+
+    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
+                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
+                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.dim = dim  # 96
+        self.input_resolution = input_resolution
+        self.num_heads = num_heads
+        self.window_size = window_size
+        self.shift_size = shift_size  # 0 or 7
+        self.mlp_ratio = mlp_ratio
+        if min(self.input_resolution) <= self.window_size:
+            # if window size is larger than input resolution, we don't partition windows
+            self.shift_size = 0
+            self.window_size = min(self.input_resolution)
+        assert 0 <= self.shift_size < self.window_size, "shift_size must be in 0-window_size"
+
+        self.norm1 = norm_layer(dim)
+        self.attn = WindowAttention(
+            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
+            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+        if self.shift_size > 0:
+            # calculate attention mask for SW-MSA
+            H, W = self.input_resolution  # 56, 56
+            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1; 1, 56, 56, 1
+            h_slices = (slice(0, -self.window_size),                 # 0, -14
+                        slice(-self.window_size, -self.shift_size),  # -14, -7
+                        slice(-self.shift_size, None))               # -7, None
+
+            w_slices = (slice(0, -self.window_size),                 # 0, -14
+                        slice(-self.window_size, -self.shift_size),  # -14, -7
+                        slice(-self.shift_size, None))               # -7, None
+            cnt = 0
+            for h in h_slices:
+                for w in w_slices:
+                    img_mask[:, h, w, :] = cnt
+                    cnt += 1
+
+            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1; 64, 7, 7, 1 # 16, 14, 14, 1
+            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)  # 64, 49 # 16, 196
+            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # 64, 1, 49 - 64, 49, 1
+            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))  # 64, 49, 49 # 16, 196, 196
+        else:
+            attn_mask = None
+
+        self.register_buffer("attn_mask", attn_mask)
+
+    def forward(self, x):
+        H, W = self.input_resolution
+        B, L, C = x.shape
+        assert L == H * W, "input feature has wrong size"
+
+        shortcut = x
+        x = self.norm1(x)
+        x = x.view(B, H, W, C)
+        # cyclic shift
+        if self.shift_size > 0:
+            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+        else:
+            shifted_x = x
+
+        # partition windows
+        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C; 448, 7, 7, 96
+        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C; 448, 49, 96 # 7, 4, 4, 196, 96
+
+        # W-MSA/SW-MSA
+        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C; 448, 49, 96
+
+        # merge windows
+        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)  # 448, 7, 7, 96
+        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C; 7, 56, 56, 96
+        # reverse cyclic shift
+        if self.shift_size > 0:
+            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+        else:
+            x = shifted_x
+
+        x = x.view(B, H * W, C)  # 7, 56*56, 96
+
+        # FFN
+        x = shortcut + self.drop_path(x)
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+    def extra_repr(self) -> str:
+        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
+               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
+
+    def flops(self):
+        flops = 0
+        H, W = self.input_resolution
+        # norm1
+        flops += self.dim * H * W
+        # W-MSA/SW-MSA
+        nW = H * W / self.window_size / self.window_size
+        flops += nW * self.attn.flops(self.window_size * self.window_size)
+        # mlp
+        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
+        # norm2
+        flops += self.dim * H * W
+        return flops
diff --git a/SSLGlacier/models/hook/window_attention.py b/SSLGlacier/models/hook/window_attention.py
deleted file mode 100644
index 39e9be5a9db16d19fa596a7b60085a2d44676d6d..0000000000000000000000000000000000000000
--- a/SSLGlacier/models/hook/window_attention.py
+++ /dev/null
@@ -1,104 +0,0 @@
-import torch
-import torch.nn as nn
-from timm.models.layers import trunc_normal_
-
-
-class
WindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. Default: 0.0 - """ - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): - - super().__init__() - self.dim = dim - self.window_size = window_size # Wh, Ww - self.num_heads = num_heads # 3, 6, 12, 24 - head_dim = dim // num_heads # 96 / 3 = 32; 96, 192, 384, 768 - - self.scale = qk_scale or head_dim ** -0.5 # 1/np.sqrt(32) - - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH; 169, 3 - - coords_h = torch.arange(self.window_size[0]) - coords_w = torch.arange(self.window_size[1]) - - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, 7, 7 - - coords_flatten = torch.flatten(coords, 1) # 2, 49 - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, 49, 49 - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # 49, 49, 2 - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 # 14-1 - relative_position_index = relative_coords.sum(-1) # 49, 49 - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) # 96, 192, 384, 768 - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) # 96, 192, 384, 768 - self.proj_drop = nn.Dropout(proj_drop) - - trunc_normal_(self.relative_position_bias_table, std=.02) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask=None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape # -1, 49, 96 - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, - 4) # 448, 49, 3, 3, 32; 3, 448, 3, 49, 32 - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) # 448, 3, 49, 32 - - q = q * self.scale # 1/np.sqrt(32) - attn = (q @ k.transpose(-2, -1)) # 448, 3, 49, 49 - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], - -1) # Wh*Ww,Wh*Ww,nH; 49, 49, 3 - - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww; 3, 49, 49 - attn = attn + relative_position_bias.unsqueeze(0) # 488, 3, 49, 49 - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( - 0) # 7, 64, 3, 49, 49 + 1, 64, 1, 49, 49 - attn = attn.view(-1, self.num_heads, N, N) # 488, 3, 49, 49 - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - x = (attn @ v).transpose(1, 
2).reshape(B_, N, C) # 448, 3, 49, 32; 448, 49, 3, 32; 448, 49, 96 - x = self.proj(x) - x = self.proj_drop(x) - return x - - def extra_repr(self) -> str: - return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' - - def flops(self, N): - # calculate flops for 1 window with token length of N - flops = 0 - # qkv = self.qkv(x) - flops += N * self.dim * 3 * self.dim - # attn = (q @ k.transpose(-2, -1)) - flops += self.num_heads * N * (self.dim // self.num_heads) * N - # x = (attn @ v) - flops += self.num_heads * N * N * (self.dim // self.num_heads) - # x = self.proj(x) - flops += N * self.dim * self.dim - return flops diff --git a/SSLGlacier/models/hook/window_attentions.py b/SSLGlacier/models/hook/window_attentions.py new file mode 100644 index 0000000000000000000000000000000000000000..422547958583f2218f1a7451ea90a600eed4748f --- /dev/null +++ b/SSLGlacier/models/hook/window_attentions.py @@ -0,0 +1,240 @@ +import torch +import torch.nn as nn +from timm.models.layers import trunc_normal_ + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads # 3, 6, 12, 24 + head_dim = dim // num_heads # 96 / 3 = 32; 96, 192, 384, 768 + + self.scale = qk_scale or head_dim ** -0.5 # 1/np.sqrt(32) + + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH; 169, 3 + + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, 7, 7 + + coords_flatten = torch.flatten(coords, 1) # 2, 49 + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, 49, 49 + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # 49, 49, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 # 14-1 + relative_position_index = relative_coords.sum(-1) # 49, 49 + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) # 96, 192, 384, 768 + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) # 96, 192, 384, 768 + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape # -1, 49, 96 + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // 
self.num_heads).permute(2, 0, 3, 1, 4)  # 448, 49, 3, 3, 32; 3, 448, 3, 49, 32
+        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple) # 448, 3, 49, 32
+
+        q = q * self.scale  # 1/np.sqrt(32)
+        attn = (q @ k.transpose(-2, -1))  # 448, 3, 49, 49
+        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH; 49, 49, 3
+
+        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww; 3, 49, 49
+        attn = attn + relative_position_bias.unsqueeze(0)  # 488, 3, 49, 49
+
+        if mask is not None:
+            nW = mask.shape[0]
+            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)  # 7, 64, 3, 49, 49 + 1, 64, 1, 49, 49
+            attn = attn.view(-1, self.num_heads, N, N)  # 488, 3, 49, 49
+            attn = self.softmax(attn)
+        else:
+            attn = self.softmax(attn)
+
+        attn = self.attn_drop(attn)
+        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)  # 448, 3, 49, 32; 448, 49, 3, 32; 448, 49, 96
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+    def extra_repr(self) -> str:
+        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
+
+    def flops(self, N):
+        # calculate flops for 1 window with token length of N
+        flops = 0
+        # qkv = self.qkv(x)
+        flops += N * self.dim * 3 * self.dim
+        # attn = (q @ k.transpose(-2, -1))
+        flops += self.num_heads * N * (self.dim // self.num_heads) * N
+        # x = (attn @ v)
+        flops += self.num_heads * N * N * (self.dim // self.num_heads)
+        # x = self.proj(x)
+        flops += N * self.dim * self.dim
+        return flops
+
+class WindowAttention_Cross(nn.Module):
+    r""" Window based multi-head cross attention module with relative position bias.
+    It supports both shifted and non-shifted windows; queries are taken from the first
+    input (x) and keys/values from the second input (y).
+
+    Args:
+        dim (int): Number of input channels.
+        window_size (tuple[int]): The height and width of the window.
+        num_heads (int): Number of attention heads.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+    """
+
+    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
+
+        super().__init__()
+        self.dim = dim
+        self.window_size = window_size  # Wh, Ww
+        self.num_heads = num_heads  # 3, 6, 12, 24
+        head_dim = dim // num_heads  # 96 / 3 = 32; 96, 192, 384, 768
+
+        self.scale = qk_scale or head_dim ** -0.5  # 1/np.sqrt(32)
+
+        self.relative_position_bias_table = nn.Parameter(
+            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH; 169, 3
+
+        coords_h = torch.arange(self.window_size[0])
+        coords_w = torch.arange(self.window_size[1])
+
+        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, 7, 7
+
+        coords_flatten = torch.flatten(coords, 1)  # 2, 49
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, 49, 49
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # 49, 49, 2
+        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
+        relative_coords[:, :, 1] += self.window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1  # 14-1
+        relative_position_index = relative_coords.sum(-1)  # 49, 49
+        self.register_buffer("relative_position_index", relative_position_index)
+
+        # NOTE: kept for interface parity with WindowAttention; forward() below builds
+        # q from x and k, v from y directly and does not use this projection.
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)  # 96, 192, 384, 768
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)  # 96, 192, 384, 768
+        self.proj_drop = nn.Dropout(proj_drop)
+
+        trunc_normal_(self.relative_position_bias_table, std=.02)
+        self.softmax = nn.Softmax(dim=-1)
+
+    def forward(self, x, y, mask=None):
+        """
+        Args:
+            x: input (query) features with shape of (num_windows*B, N, C)
+            y: context features with shape of (num_windows*B, N, C); keys and values are built from y
+            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+        """
+        B_, N, C = x.shape  # -1, 49, 96
+
+        q = x.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)  # 448, 49, 3, 3, 32; 3, 448, 3, 49, 32
+        k = y.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)  # 448, 49, 3, 3, 32; 3, 448, 3, 49, 32
+        v = y.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)  # 448, 49, 3, 3, 32; 3, 448, 3, 49, 32
+
+        q = q * self.scale  # 1/np.sqrt(32)
+        attn = (q @ k.transpose(-2, -1))  # 448, 3, 49, 49
+        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH; 49, 49, 3
+
+        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww; 3, 49, 49
+        attn = attn + relative_position_bias.unsqueeze(0)  # 488, 3, 49, 49
+
+        if mask is not None:
+            nW = mask.shape[0]
+            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)  # 7, 64, 3, 49, 49 + 1, 64, 1, 49, 49
+            attn = attn.view(-1, self.num_heads, N, N)  # 488, 3, 49, 49
+            attn = self.softmax(attn)
+        else:
+            attn = self.softmax(attn)
+
+        attn = self.attn_drop(attn)
+        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)  # 448, 3, 49, 32; 448, 49, 3, 32; 448, 49, 96
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+    def extra_repr(self) -> str:
+        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
+
+    def flops(self, N):
+        # calculate flops for 1 window with token length of N
+        flops = 0
+        # qkv = self.qkv(x)
+        flops += N * self.dim * 3 * self.dim
+        # attn = (q @ k.transpose(-2, -1))
+        flops += self.num_heads * N * (self.dim // self.num_heads) * N
+        # x = (attn @ v)
+        flops +=
self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape # 7, 56, 56, 96; 1, 56, 56, 1 + x = x.view(B, H // window_size, window_size, W // window_size, window_size, + C) # 7, 8, 7, 8, 7, 96 # 1, 8, 7, 8, 7, 1 # 1, 4, 14, 4, 14, 1 + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, + C) # 7, 8, 8, 7, 7, 96; -1, 7, 7, 96 + # 1, 8, 8, 7, 7, 1 # 1, 4, 4, 14, 14, 1 # 16, 14, 14, 1 + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) # 448 / (56*56/7/7) = 7 + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) # 7, 8, 8, 7, 7, 96 + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) # 7, 8, 7, 8, 96; 7, 56, 56, 96 + return x + diff --git a/SSLGlacier/models/hook/window_cross_attention.py b/SSLGlacier/models/hook/window_cross_attention.py deleted file mode 100644 index 7b5ddfde667f4b5269523bb189fa743bfd3a01dd..0000000000000000000000000000000000000000 --- a/SSLGlacier/models/hook/window_cross_attention.py +++ /dev/null @@ -1,106 +0,0 @@ -import torch -import torch.nn as nn -from timm.models.layers import trunc_normal_ - - -class WindowAttention_Cross(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 - """ - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): - - super().__init__() - self.dim = dim - self.window_size = window_size # Wh, Ww - self.num_heads = num_heads # 3, 6, 12, 24 - head_dim = dim // num_heads # 96 / 3 = 32; 96, 192, 384, 768 - - self.scale = qk_scale or head_dim ** -0.5 # 1/np.sqrt(32) - - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH; 169, 3 - - coords_h = torch.arange(self.window_size[0]) - coords_w = torch.arange(self.window_size[1]) - - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, 7, 7 - - coords_flatten = torch.flatten(coords, 1) # 2, 49 - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, 49, 49 - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # 49, 49, 2 - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 # 14-1 - relative_position_index = relative_coords.sum(-1) # 49, 49 - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) # 96, 192, 384, 768 - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) # 96, 192, 384, 768 - self.proj_drop = nn.Dropout(proj_drop) - - trunc_normal_(self.relative_position_bias_table, std=.02) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, y, mask=None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape # -1, 49, 96 - - q = x.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # 448, 49, 3, 3, 32; 3, 448, 3, 49, 32 - k = y.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # 448, 49, 3, 3, 32; 3, 448, 3, 49, 32 - v = y.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # 448, 49, 3, 3, 32; 3, 448, 3, 49, 32 - - q = q * self.scale # 1/np.sqrt(32) - attn = (q @ k.transpose(-2, -1)) # 448, 3, 49, 49 - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], - -1) # Wh*Ww,Wh*Ww,nH; 49, 49, 3 - - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww; 3, 49, 49 - attn = attn + relative_position_bias.unsqueeze(0) # 488, 3, 49, 49 - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( - 0) # 7, 64, 3, 49, 49 + 1, 64, 1, 49, 49 - attn = attn.view(-1, self.num_heads, N, N) # 488, 3, 49, 49 - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) # 448, 3, 49, 32; 448, 49, 3, 32; 448, 49, 96 - x = self.proj(x) - x = self.proj_drop(x) - return x - - def extra_repr(self) -> str: - return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' - - def flops(self, N): - # calculate flops for 1 window with token length of N - flops = 0 - # qkv = self.qkv(x) - flops += N * self.dim * 3 * self.dim - # attn = (q @ k.transpose(-2, -1)) - flops += self.num_heads * N * (self.dim // self.num_heads) * N - # x = (attn @ v) - flops += 
self.num_heads * N * N * (self.dim // self.num_heads) - # x = self.proj(x) - flops += N * self.dim * self.dim - return flops - diff --git a/SSLGlacier/models/hook/window_utils.py b/SSLGlacier/models/hook/window_utils.py deleted file mode 100644 index fe08bade963ede9938ef18fde95624f4b43a98f9..0000000000000000000000000000000000000000 --- a/SSLGlacier/models/hook/window_utils.py +++ /dev/null @@ -1,34 +0,0 @@ -def window_partition(x, window_size): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape # 7, 56, 56, 96; 1, 56, 56, 1 - x = x.view(B, H // window_size, window_size, W // window_size, window_size, - C) # 7, 8, 7, 8, 7, 96 # 1, 8, 7, 8, 7, 1 # 1, 4, 14, 4, 14, 1 - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, - C) # 7, 8, 8, 7, 7, 96; -1, 7, 7, 96 - # 1, 8, 8, 7, 7, 1 # 1, 4, 4, 14, 14, 1 # 16, 14, 14, 1 - return windows - - -def window_reverse(windows, window_size, H, W): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image - - Returns: - x: (B, H, W, C) - """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) # 448 / (56*56/7/7) = 7 - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) # 7, 8, 8, 7, 7, 96 - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) # 7, 8, 7, 8, 96; 7, 56, 56, 96 - return x - diff --git a/SSLGlacier/processing/agumentations_.py b/SSLGlacier/modules/agumentations_.py similarity index 100% rename from SSLGlacier/processing/agumentations_.py rename to SSLGlacier/modules/agumentations_.py diff --git a/SSLGlacier/processing/datamodule_.py b/SSLGlacier/modules/datamodule_.py similarity index 100% rename from SSLGlacier/processing/datamodule_.py rename to SSLGlacier/modules/datamodule_.py diff --git a/SSLGlacier/processing/semi_module.py b/SSLGlacier/modules/semi_module.py similarity index 100% rename from SSLGlacier/processing/semi_module.py rename to SSLGlacier/modules/semi_module.py diff --git a/SSLGlacier/processing/swav_module.py b/SSLGlacier/modules/swav_module.py similarity index 92% rename from SSLGlacier/processing/swav_module.py rename to SSLGlacier/modules/swav_module.py index 970484bc3afc2c58beb2bda198771c6a45b796b5..3b57a4ede5265846c6a4aeb4b885d3d31169c206 100644 --- a/SSLGlacier/processing/swav_module.py +++ b/SSLGlacier/modules/swav_module.py @@ -14,9 +14,11 @@ from SSLGlacier.models.RESNET import resnet18, resnet50 from SSLGlacier.models.UNET import UNet #The main network from pl_bolts.optimizers.lars import LARS from pl_bolts.optimizers.lr_scheduler import linear_warmup_decay -#from SSLGlacier.models.hook.Swin_Transformer_Wrapper import SwinUnet - +##########Importing files for hook############ +from SSLGlacier.models.hook.Swin_Transformer_Wrapper import SwinUnet +######Temporary config file for hooknet####### +from utils.config import get_config class SwAV(LightningModule): def __init__( @@ -51,7 +53,7 @@ class SwAV(LightningModule): weight_decay: float = 1e-6, epsilon: float = 0.05, just_aug_for_same_assign_views: bool = False, - **kwargs + swin_hparams = {} ) -> None: """ Args: @@ -148,6 +150,8 @@ class SwAV(LightningModule): self.train_iters_per_epoch = self.num_samples // global_batch_size self.queue = None + ####For hook#### + self.swin_hparams = swin_hparams def setup(self, stage): if self.queue_length > 0: queue_folder = 
os.path.join(self.logger.log_dir, self.queue_path) @@ -176,12 +180,30 @@ class SwAV(LightningModule): first_conv=self.first_conv, maxpool1=self.maxpool1) elif self.arch == "hook": - backbone = resnet50(normalize=True, + backbone = SwinUnet( + img_size=224, + num_classes=5, + normalize=True, hidden_mlp=self.hidden_mlp, output_dim=self.feat_dim, num_prototypes=self.num_prototypes, first_conv=self.first_conv, - maxpool1=self.maxpool1) + maxpool1=self.maxpool1,######till here for basenet + image_size= self.swin_hparams.image_size,####from here for swin itself + swin_patch_size = self.swin_hparams.swin_patch_size, + swin_in_chans=self.swin_hparams.swin_in_chans, + swin_embed_dim=self.swin_hparams.swin_embed_dim, + swin_depths=self.swin_hparams.swin_depths, + swin_num_heads=self.swin_hparams.swin_num_heads, + swin_window_size=self.swin_hparams.swin_window_size, + swin_mlp_ratio=self.swin_hparams.swin_mlp_ratio, + swin_QKV_BIAS=self.swin_hparams.swin_QKV_BIAS, + swin_QK_SCALE=self.swin_hparams.swin_QK_SCALE, + drop_rate=self.drop_rate, + drop_path_rate=self.drop_path_rate, + swin_ape=self.swin_hparams.swin_ape, + swin_patch_norm=self.swin_hparams.swin_path_norm + ) elif self.arch == "Unet" or "unet": backbone = UNet(num_classes=5, diff --git a/SSLGlacier/processing/transformers_.py b/SSLGlacier/modules/transformers_.py similarity index 100% rename from SSLGlacier/processing/transformers_.py rename to SSLGlacier/modules/transformers_.py diff --git a/SSLGlacier/utils/config.py b/SSLGlacier/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..c0b002125d72a899336c30c614afd9a4ff0c4238 --- /dev/null +++ b/SSLGlacier/utils/config.py @@ -0,0 +1,236 @@ +# -------------------------------------------------------- +# Swin Transformer +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu +# --------------------------------------------------------' + +import os +import yaml +from yacs.config import CfgNode as CN + +_C = CN() + +# Base config files +_C.BASE = [''] + +# ----------------------------------------------------------------------------- +# Data settings +# ----------------------------------------------------------------------------- +_C.DATA = CN() +# Batch size for a single GPU, could be overwritten by command line argument +_C.DATA.BATCH_SIZE = 128 +# Path to dataset, could be overwritten by command line argument +_C.DATA.DATA_PATH = '' +# Dataset name +_C.DATA.DATASET = 'imagenet' +# Input image size +_C.DATA.IMG_SIZE = 224 +# _C.DATA.IMG_SIZE = 56 +# Interpolation to resize image (random, bilinear, bicubic) +_C.DATA.INTERPOLATION = 'bicubic' +# Use zipped dataset instead of folder dataset +# could be overwritten by command line argument +_C.DATA.ZIP_MODE = False +# Cache Data in Memory, could be overwritten by command line argument +_C.DATA.CACHE_MODE = 'part' +# Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU. 
+_C.DATA.PIN_MEMORY = True +# Number of data loading threads +_C.DATA.NUM_WORKERS = 8 + +# ----------------------------------------------------------------------------- +# Model settings +# ----------------------------------------------------------------------------- +_C.MODEL = CN() +# Model type +_C.MODEL.TYPE = 'swin' +# Model name +_C.MODEL.NAME = 'swin_tiny_patch4_window7_224' +# Checkpoint to resume, could be overwritten by command line argument +# _C.MODEL.PRETRAIN_CKPT = '/home/wf/Swin-Unet-main/output/epoch_150.pth' +_C.MODEL.PRETRAIN_CKPT = None +_C.MODEL.RESUME = '' +# Number of classes, overwritten in data preparation +_C.MODEL.NUM_CLASSES = 1000 +# Dropout rate +_C.MODEL.DROP_RATE = 0.0 +# Drop path rate +_C.MODEL.DROP_PATH_RATE = 0.1 +# Label Smoothing +_C.MODEL.LABEL_SMOOTHING = 0.1 + +# Swin Transformer parameters +_C.MODEL.SWIN = CN() +_C.MODEL.SWIN.PATCH_SIZE = 4 +# _C.MODEL.SWIN.PATCH_SIZE = 2 +_C.MODEL.SWIN.IN_CHANS = 3 +# _C.MODEL.SWIN.IN_CHANS = 1 +_C.MODEL.SWIN.EMBED_DIM = 96 +_C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2] +_C.MODEL.SWIN.DECODER_DEPTHS = [2, 2, 6, 2] +_C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24] +_C.MODEL.SWIN.WINDOW_SIZE = 7 +# _C.MODEL.SWIN.WINDOW_SIZE = 2 +_C.MODEL.SWIN.MLP_RATIO = 4. +_C.MODEL.SWIN.QKV_BIAS = True +_C.MODEL.SWIN.QK_SCALE = None +_C.MODEL.SWIN.APE = False +_C.MODEL.SWIN.PATCH_NORM = True +_C.MODEL.SWIN.FINAL_UPSAMPLE= "expand_first" + +# ----------------------------------------------------------------------------- +# Training settings +# ----------------------------------------------------------------------------- +_C.TRAIN = CN() +_C.TRAIN.START_EPOCH = 0 +_C.TRAIN.EPOCHS = 300 +_C.TRAIN.WARMUP_EPOCHS = 20 +_C.TRAIN.WEIGHT_DECAY = 0.05 +_C.TRAIN.BASE_LR = 5e-4 +_C.TRAIN.WARMUP_LR = 5e-7 +_C.TRAIN.MIN_LR = 5e-6 +# Clip gradient norm +_C.TRAIN.CLIP_GRAD = 5.0 +# Auto resume from latest checkpoint +_C.TRAIN.AUTO_RESUME = True +# Gradient accumulation steps +# could be overwritten by command line argument +_C.TRAIN.ACCUMULATION_STEPS = 0 +# Whether to use gradient checkpointing to save memory +# could be overwritten by command line argument +_C.TRAIN.USE_CHECKPOINT = False + +# LR scheduler +_C.TRAIN.LR_SCHEDULER = CN() +_C.TRAIN.LR_SCHEDULER.NAME = 'cosine' +# Epoch interval to decay LR, used in StepLRScheduler +_C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30 +# LR decay rate, used in StepLRScheduler +_C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1 + +# Optimizer +_C.TRAIN.OPTIMIZER = CN() +_C.TRAIN.OPTIMIZER.NAME = 'adamw' # - NOT USED +# Optimizer Epsilon +_C.TRAIN.OPTIMIZER.EPS = 1e-8 +# Optimizer Betas +_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999) +# SGD momentum +_C.TRAIN.OPTIMIZER.MOMENTUM = 0.9 + +# ----------------------------------------------------------------------------- +# Augmentation settings - NOT USED +# ----------------------------------------------------------------------------- +_C.AUG = CN() +# Color jitter factor +_C.AUG.COLOR_JITTER = 0.4 +# Use AutoAugment policy. 
"v0" or "original" +_C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1' +# Random erase prob +_C.AUG.REPROB = 0.25 +# Random erase mode +_C.AUG.REMODE = 'pixel' +# Random erase count +_C.AUG.RECOUNT = 1 +# Mixup alpha, mixup enabled if > 0 +_C.AUG.MIXUP = 0.8 +# Cutmix alpha, cutmix enabled if > 0 +_C.AUG.CUTMIX = 1.0 +# Cutmix min/max ratio, overrides alpha and enables cutmix if set +_C.AUG.CUTMIX_MINMAX = None +# Probability of performing mixup or cutmix when either/both is enabled +_C.AUG.MIXUP_PROB = 1.0 +# Probability of switching to cutmix when both mixup and cutmix enabled +_C.AUG.MIXUP_SWITCH_PROB = 0.5 +# How to apply mixup/cutmix params. Per "batch", "pair", or "elem" +_C.AUG.MIXUP_MODE = 'batch' + +# ----------------------------------------------------------------------------- +# Testing settings +# ----------------------------------------------------------------------------- +_C.TEST = CN() +# Whether to use center crop when testing +_C.TEST.CROP = True + +# ----------------------------------------------------------------------------- +# Misc +# ----------------------------------------------------------------------------- +# Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2') +# overwritten by command line argument +_C.AMP_OPT_LEVEL = '' +# Path to output folder, overwritten by command line argument +_C.OUTPUT = '' +# Tag of experiment, overwritten by command line argument +_C.TAG = 'default' +# Frequency to save checkpoint +_C.SAVE_FREQ = 1 +# Frequency to logging info +_C.PRINT_FREQ = 10 +# Fixed random seed +_C.SEED = 0 +# Perform evaluation only, overwritten by command line argument +_C.EVAL_MODE = False +# Test throughput only, overwritten by command line argument +_C.THROUGHPUT_MODE = False +# local rank for DistributedDataParallel, given by command line argument +_C.LOCAL_RANK = 0 + + +# NOT USED FROM HERE ON + +def _update_config_from_file(config, cfg_file): + config.defrost() + with open(cfg_file, 'r') as f: + yaml_cfg = yaml.load(f, Loader=yaml.FullLoader) + + for cfg in yaml_cfg.setdefault('BASE', ['']): + if cfg: + _update_config_from_file( + config, os.path.join(os.path.dirname(cfg_file), cfg) + ) + print('=> merge config from {}'.format(cfg_file)) + config.merge_from_file(cfg_file) + config.freeze() + + +def update_config(config, args): + _update_config_from_file(config, args.cfg) + + config.defrost() + if args.opts: + config.merge_from_list(args.opts) + + # merge from specific arguments + if args.batch_size: + config.DATA.BATCH_SIZE = args.batch_size + if args.zip: + config.DATA.ZIP_MODE = True + if args.cache_mode: + config.DATA.CACHE_MODE = args.cache_mode + if args.resume: + config.MODEL.RESUME = args.resume + if args.accumulation_steps: + config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps + if args.use_checkpoint: + config.TRAIN.USE_CHECKPOINT = True + if args.amp_opt_level: + config.AMP_OPT_LEVEL = args.amp_opt_level + if args.tag: + config.TAG = args.tag + if args.eval: + config.EVAL_MODE = True + if args.throughput: + config.THROUGHPUT_MODE = True + + config.freeze() + + +def get_config(args): + """Get a yacs CfgNode object with default values.""" + # Return a clone so that the defaults will not be altered + # This is for the "local variable" use pattern + config = _C.clone() + update_config(config, args) + + return config