From 0013a50e923782dc64ee559f1591d088d7b66679 Mon Sep 17 00:00:00 2001
From: Marziyeh <marziyeh.mohammadi@fau.de>
Date: Thu, 14 Dec 2023 12:53:17 +0100
Subject: [PATCH 01/11] Add a HookNet config file, temporary: Just adding a
 config file in utils for setting HookNet parameters.

---
 SSLGlacier/utils/config.py | 236 +++++++++++++++++++++++++++++++++++++
 1 file changed, 236 insertions(+)
 create mode 100644 SSLGlacier/utils/config.py

diff --git a/SSLGlacier/utils/config.py b/SSLGlacier/utils/config.py
new file mode 100644
index 0000000..c0b0021
--- /dev/null
+++ b/SSLGlacier/utils/config.py
@@ -0,0 +1,236 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------'
+
+import os
+import yaml
+from yacs.config import CfgNode as CN
+
+_C = CN()
+
+# Base config files
+_C.BASE = ['']
+
+# -----------------------------------------------------------------------------
+# Data settings
+# -----------------------------------------------------------------------------
+_C.DATA = CN()
+# Batch size for a single GPU, could be overwritten by command line argument
+_C.DATA.BATCH_SIZE = 128
+# Path to dataset, could be overwritten by command line argument
+_C.DATA.DATA_PATH = ''
+# Dataset name
+_C.DATA.DATASET = 'imagenet'
+# Input image size
+_C.DATA.IMG_SIZE = 224
+# _C.DATA.IMG_SIZE = 56
+# Interpolation to resize image (random, bilinear, bicubic)
+_C.DATA.INTERPOLATION = 'bicubic'
+# Use zipped dataset instead of folder dataset
+# could be overwritten by command line argument
+_C.DATA.ZIP_MODE = False
+# Cache Data in Memory, could be overwritten by command line argument
+_C.DATA.CACHE_MODE = 'part'
+# Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
+_C.DATA.PIN_MEMORY = True
+# Number of data loading threads
+_C.DATA.NUM_WORKERS = 8
+
+# -----------------------------------------------------------------------------
+# Model settings
+# -----------------------------------------------------------------------------
+_C.MODEL = CN()
+# Model type
+_C.MODEL.TYPE = 'swin'
+# Model name
+_C.MODEL.NAME = 'swin_tiny_patch4_window7_224'
+# Checkpoint to resume, could be overwritten by command line argument
+# _C.MODEL.PRETRAIN_CKPT = '/home/wf/Swin-Unet-main/output/epoch_150.pth'
+_C.MODEL.PRETRAIN_CKPT = None
+_C.MODEL.RESUME = ''
+# Number of classes, overwritten in data preparation
+_C.MODEL.NUM_CLASSES = 1000
+# Dropout rate
+_C.MODEL.DROP_RATE = 0.0
+# Drop path rate
+_C.MODEL.DROP_PATH_RATE = 0.1
+# Label Smoothing
+_C.MODEL.LABEL_SMOOTHING = 0.1
+
+# Swin Transformer parameters
+_C.MODEL.SWIN = CN()
+_C.MODEL.SWIN.PATCH_SIZE = 4
+# _C.MODEL.SWIN.PATCH_SIZE = 2
+_C.MODEL.SWIN.IN_CHANS = 3
+# _C.MODEL.SWIN.IN_CHANS = 1
+_C.MODEL.SWIN.EMBED_DIM = 96
+_C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
+_C.MODEL.SWIN.DECODER_DEPTHS = [2, 2, 6, 2]
+_C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
+_C.MODEL.SWIN.WINDOW_SIZE = 7
+# _C.MODEL.SWIN.WINDOW_SIZE = 2
+_C.MODEL.SWIN.MLP_RATIO = 4.
+_C.MODEL.SWIN.QKV_BIAS = True +_C.MODEL.SWIN.QK_SCALE = None +_C.MODEL.SWIN.APE = False +_C.MODEL.SWIN.PATCH_NORM = True +_C.MODEL.SWIN.FINAL_UPSAMPLE= "expand_first" + +# ----------------------------------------------------------------------------- +# Training settings +# ----------------------------------------------------------------------------- +_C.TRAIN = CN() +_C.TRAIN.START_EPOCH = 0 +_C.TRAIN.EPOCHS = 300 +_C.TRAIN.WARMUP_EPOCHS = 20 +_C.TRAIN.WEIGHT_DECAY = 0.05 +_C.TRAIN.BASE_LR = 5e-4 +_C.TRAIN.WARMUP_LR = 5e-7 +_C.TRAIN.MIN_LR = 5e-6 +# Clip gradient norm +_C.TRAIN.CLIP_GRAD = 5.0 +# Auto resume from latest checkpoint +_C.TRAIN.AUTO_RESUME = True +# Gradient accumulation steps +# could be overwritten by command line argument +_C.TRAIN.ACCUMULATION_STEPS = 0 +# Whether to use gradient checkpointing to save memory +# could be overwritten by command line argument +_C.TRAIN.USE_CHECKPOINT = False + +# LR scheduler +_C.TRAIN.LR_SCHEDULER = CN() +_C.TRAIN.LR_SCHEDULER.NAME = 'cosine' +# Epoch interval to decay LR, used in StepLRScheduler +_C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30 +# LR decay rate, used in StepLRScheduler +_C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1 + +# Optimizer +_C.TRAIN.OPTIMIZER = CN() +_C.TRAIN.OPTIMIZER.NAME = 'adamw' # - NOT USED +# Optimizer Epsilon +_C.TRAIN.OPTIMIZER.EPS = 1e-8 +# Optimizer Betas +_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999) +# SGD momentum +_C.TRAIN.OPTIMIZER.MOMENTUM = 0.9 + +# ----------------------------------------------------------------------------- +# Augmentation settings - NOT USED +# ----------------------------------------------------------------------------- +_C.AUG = CN() +# Color jitter factor +_C.AUG.COLOR_JITTER = 0.4 +# Use AutoAugment policy. "v0" or "original" +_C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1' +# Random erase prob +_C.AUG.REPROB = 0.25 +# Random erase mode +_C.AUG.REMODE = 'pixel' +# Random erase count +_C.AUG.RECOUNT = 1 +# Mixup alpha, mixup enabled if > 0 +_C.AUG.MIXUP = 0.8 +# Cutmix alpha, cutmix enabled if > 0 +_C.AUG.CUTMIX = 1.0 +# Cutmix min/max ratio, overrides alpha and enables cutmix if set +_C.AUG.CUTMIX_MINMAX = None +# Probability of performing mixup or cutmix when either/both is enabled +_C.AUG.MIXUP_PROB = 1.0 +# Probability of switching to cutmix when both mixup and cutmix enabled +_C.AUG.MIXUP_SWITCH_PROB = 0.5 +# How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem" +_C.AUG.MIXUP_MODE = 'batch' + +# ----------------------------------------------------------------------------- +# Testing settings +# ----------------------------------------------------------------------------- +_C.TEST = CN() +# Whether to use center crop when testing +_C.TEST.CROP = True + +# ----------------------------------------------------------------------------- +# Misc +# ----------------------------------------------------------------------------- +# Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2') +# overwritten by command line argument +_C.AMP_OPT_LEVEL = '' +# Path to output folder, overwritten by command line argument +_C.OUTPUT = '' +# Tag of experiment, overwritten by command line argument +_C.TAG = 'default' +# Frequency to save checkpoint +_C.SAVE_FREQ = 1 +# Frequency to logging info +_C.PRINT_FREQ = 10 +# Fixed random seed +_C.SEED = 0 +# Perform evaluation only, overwritten by command line argument +_C.EVAL_MODE = False +# Test throughput only, overwritten by command line argument +_C.THROUGHPUT_MODE = False +# local rank for DistributedDataParallel, given by command line argument +_C.LOCAL_RANK = 0 + + +# NOT USED FROM HERE ON + +def _update_config_from_file(config, cfg_file): + config.defrost() + with open(cfg_file, 'r') as f: + yaml_cfg = yaml.load(f, Loader=yaml.FullLoader) + + for cfg in yaml_cfg.setdefault('BASE', ['']): + if cfg: + _update_config_from_file( + config, os.path.join(os.path.dirname(cfg_file), cfg) + ) + print('=> merge config from {}'.format(cfg_file)) + config.merge_from_file(cfg_file) + config.freeze() + + +def update_config(config, args): + _update_config_from_file(config, args.cfg) + + config.defrost() + if args.opts: + config.merge_from_list(args.opts) + + # merge from specific arguments + if args.batch_size: + config.DATA.BATCH_SIZE = args.batch_size + if args.zip: + config.DATA.ZIP_MODE = True + if args.cache_mode: + config.DATA.CACHE_MODE = args.cache_mode + if args.resume: + config.MODEL.RESUME = args.resume + if args.accumulation_steps: + config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps + if args.use_checkpoint: + config.TRAIN.USE_CHECKPOINT = True + if args.amp_opt_level: + config.AMP_OPT_LEVEL = args.amp_opt_level + if args.tag: + config.TAG = args.tag + if args.eval: + config.EVAL_MODE = True + if args.throughput: + config.THROUGHPUT_MODE = True + + config.freeze() + + +def get_config(args): + """Get a yacs CfgNode object with default values.""" + # Return a clone so that the defaults will not be altered + # This is for the "local variable" use pattern + config = _C.clone() + update_config(config, args) + + return config -- GitLab From 1ce7ec90c0db9ebbc2214aef59a553dbbd8e2067 Mon Sep 17 00:00:00 2001 From: Marziyeh <marziyeh.mohammadi@fau.de> Date: Thu, 14 Dec 2023 15:20:07 +0100 Subject: [PATCH 02/11] Adjustmentations: Pass parameters as agrument parset instead of reading them from config file. just for puting it in the same harmony as the rest of the code. Otherwise I like the config file idea. 
---
 .../models/hook/Swin_Transformer_Wrapper.py | 49 +++++++++++--------
 1 file changed, 29 insertions(+), 20 deletions(-)

diff --git a/SSLGlacier/models/hook/Swin_Transformer_Wrapper.py b/SSLGlacier/models/hook/Swin_Transformer_Wrapper.py
index 528757d..69e68a9 100644
--- a/SSLGlacier/models/hook/Swin_Transformer_Wrapper.py
+++ b/SSLGlacier/models/hook/Swin_Transformer_Wrapper.py
@@ -10,32 +10,41 @@ import torch
 import torch.nn as nn
 
 from SSLGlacier.models.hook.Swin_Transformer import SwinTransformerSys
-
+from SSLGlacier.models.base_net import Net
 logger = logging.getLogger(__name__)
 
 
-class SwinUnet(nn.Module):
-    def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
-        super(SwinUnet, self).__init__()
+class SwinUnet(Net):
+    def __init__(self, num_classes=21843, zero_head=False, **kwargs):
+        super().__init__(
+            groups=1,
+            widen=1,
+            width_per_group=64,
+            norm_layer=None,
+            output_dim=0 if 'output_dim' not in kwargs else kwargs['output_dim'],
+            hidden_mlp=0 if 'hidden_mlp' not in kwargs else kwargs['hidden_mlp'],
+            num_prototypes=0 if 'num_prototypes' not in kwargs else kwargs['num_prototypes'],
+            eval_mode=False,
+            normalize=False if 'normalize' not in kwargs else kwargs['normalize']
+        )
+
         self.num_classes = num_classes
         self.zero_head = zero_head
-        self.config = config
 
-        self.swin_unet = SwinTransformerSys(img_size=config.DATA.IMG_SIZE,
-                                            patch_size=config.MODEL.SWIN.PATCH_SIZE,
-                                            in_chans=config.MODEL.SWIN.IN_CHANS,
+        self.swin_unet = SwinTransformerSys(img_size=kwargs.image_size,
+                                            patch_size=kwargs.swin_patch_size,
+                                            in_chans=kwargs.swin_in_chans,
                                             num_classes=self.num_classes,
-                                            embed_dim=config.MODEL.SWIN.EMBED_DIM,
-                                            depths=config.MODEL.SWIN.DEPTHS,
-                                            num_heads=config.MODEL.SWIN.NUM_HEADS,
-                                            window_size=config.MODEL.SWIN.WINDOW_SIZE,
-                                            mlp_ratio=config.MODEL.SWIN.MLP_RATIO,
-                                            qkv_bias=config.MODEL.SWIN.QKV_BIAS,
-                                            qk_scale=config.MODEL.SWIN.QK_SCALE,
-                                            drop_rate=config.MODEL.DROP_RATE,
-                                            drop_path_rate=config.MODEL.DROP_PATH_RATE,
-                                            ape=config.MODEL.SWIN.APE,
-                                            patch_norm=config.MODEL.SWIN.PATCH_NORM,
-                                            use_checkpoint=config.TRAIN.USE_CHECKPOINT)
+                                            embed_dim=kwargs.swin_embed_dim,
+                                            depths=kwargs.swin_depths,
+                                            num_heads=kwargs.swin_num_heads,
+                                            window_size=kwargs.swin_window_size,
+                                            mlp_ratio=kwargs.swin_mlp_ratio,
+                                            qkv_bias=kwargs.swin.QKV_BIAS,
+                                            qk_scale=kwargs.swin.QK_SCALE,
+                                            drop_rate=kwargs.drop_rate,
+                                            drop_path_rate=kwargs.drop_path_rate,
+                                            ape=kwargs.swin_ape,
+                                            patch_norm=kwargs.swin_patch_norm)
 
     def forward(self, x, y):
         if x.size()[1] == 1:
-- 
GitLab


From dd971ad725aa9f886e6a7c5e5bb70b1809732fd3 Mon Sep 17 00:00:00 2001
From: Marziyeh <marziyeh.mohammadi@fau.de>
Date: Thu, 14 Dec 2023 15:21:17 +0100
Subject: [PATCH 03/11] Adjustments for HookNet: Add another argument parser
 just for the swin net.
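
One caveat with the parser sketched below: argparse treats names without a
leading '--' as required positional arguments, so their default= values are
never applied, and a bare parse_args() will also re-read sys.argv. A hedged
sketch of the optional-flag variant this would need (flag names assumed, not
taken from the code):

    from argparse import ArgumentParser

    swin_parser = ArgumentParser()
    swin_parser.add_argument('--swin_patch_size', type=int, default=4)
    swin_parser.add_argument('--swin_embed_dim', type=int, default=96)
    # parse an empty list so sys.argv (already consumed by the main parser)
    # is left alone and the defaults are used
    swin_hparams = swin_parser.parse_args([])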
---
 SSLGlacier/main.py | 54 ++++++++++++++++++++++++++++++++++++------
 1 file changed, 47 insertions(+), 7 deletions(-)

diff --git a/SSLGlacier/main.py b/SSLGlacier/main.py
index 135ed7a..61fc68d 100644
--- a/SSLGlacier/main.py
+++ b/SSLGlacier/main.py
@@ -16,7 +16,9 @@ import numpy as np
 
 from utils.utils import str2bool, getDataPath
 
-def main(hparams):
+
+
+def main(hparams, swin_hparams):
     # TODO Ask NORA, Why do you use clip gradient? did you test it?
     ######save check points##########
     checkpoint_dirs = os.path.join('checkpoints',
                                    f'ssl_zone_segmentation_{hparams.arch}')
@@ -78,7 +80,8 @@ def main(hparams):
                               max_pool1d=False,
                               learning_rate=hparams.base_lr,
                               just_aug_for_same_assign_views=hparams.just_aug_for_same_assign_views,
-                              crops_for_assign=hparams.views_for_assign
+                              crops_for_assign=hparams.views_for_assign,
+                              hparams=swin_hparams
                               )
 
     online_evaluator = SSLOnlineEvaluator(
@@ -139,9 +142,9 @@ if __name__ == '__main__':
     #########################
     # Which agumented image should be used for assignment
     parser.add_argument("--views_for_assign", type=int, nargs="+", default=[0, 1],
-                        help="list of agumented views id used for computing assignments")
+                        help="list of augmented views id used for computing assignments")
     parser.add_argument("--just_aug_for_same_assign_views", type=str2bool, default=True,
-                        help="If you want to consider agumented version of the assigned view for swap assignment.")
+                        help="If you want to consider augmented version of the assigned view for swap assignment.")
     parser.add_argument("--temperature", default=0.1, type=float,
                         help="temperature parameter in training loss")
     parser.add_argument("--epsilons", default=0.05, type=float,
@@ -200,8 +203,45 @@ if __name__ == '__main__':
                                  'noise': 0, 'color_jitter': 0, 'gauss_blur': 0,
                                  'salt_or_pepper': 0, 'poission': 0,
                                  'speckle': 1, 'zero': 0, 'normalize': None})
-    parser.add_argument('--i_want_swav', default=False, type=str2bool, help='Do you wants Swav augmentaion or ours, '
-                                                                            'other papers agumentaions methods will be added later.')
+    parser.add_argument('--i_want_swav', default=False, type=str2bool, help='Do you want Swav augmentaion or ours, '
+                                                                            'other papers agumentaions methods will be added later.')
 
     hparams = parser.parse_args()
-    main(hparams)
+    ###########Swin transformer parameters########
+    #TODO Config file: is a good idea,maybe rafactor the code later....
+    #parameters set in config file, #
+    # Swin Transformer parameters
+    # _C.MODEL.SWIN.PATCH_SIZE = 4
+    # # _C.MODEL.SWIN.PATCH_SIZE = 2
+    # _C.MODEL.SWIN.IN_CHANS = 3
+    # # _C.MODEL.SWIN.IN_CHANS = 1
+    # _C.MODEL.SWIN.EMBED_DIM = 96
+    # _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
+    # _C.MODEL.SWIN.DECODER_DEPTHS = [2, 2, 6, 2]
+    # _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
+    # _C.MODEL.SWIN.WINDOW_SIZE = 7
+    # # _C.MODEL.SWIN.WINDOW_SIZE = 2
+    # _C.MODEL.SWIN.MLP_RATIO = 4.
+    # _C.MODEL.SWIN.QKV_BIAS = True
+    # _C.MODEL.SWIN.QK_SCALE = None
+    # _C.MODEL.SWIN.APE = False
+    # _C.MODEL.SWIN.PATCH_NORM = True
+    # _C.MODEL.SWIN.FINAL_UPSAMPLE = "expand_first"
+    ##############################################
+    swin_parser =ArgumentParser()
+    swin_parser.add_argument('hidden_mlp')
+    swin_parser.add_argument('swin_patch_size', default = 4)
+    swin_parser.add_argument('swin_in_chans', default = 1)
+    swin_parser.add_argument('swin_embed_dim', default = 96)
+    swin_parser.add_argument('swin_depths', default = [2, 2, 6, 2])
+    swin_parser.add_argument('swin_num_heads', default = [3, 6, 12, 24])
+    swin_parser.add_argument('swin_window_size', default = 2)
+    swin_parser.add_argument('swin_mlp_ratio', default = 4.0)
+    swin_parser.add_argument('swin_QKV_BIAS', default = True)
+    swin_parser.add_argument('swin_QK_SCALE', default = None)
+    swin_parser.add_argument('swin_ape', default = False)
+    swin_parser.add_argument('swin_path_norm', default = True)
+    swin_hparams = swin_parser.parse_args()
+
+
+    main(hparams, swin_hparams)
-- 
GitLab


From 7cb80b33e6decea5368fe7017b56bbc293f38f65 Mon Sep 17 00:00:00 2001
From: Marziyeh <marziyeh.mohammadi@fau.de>
Date: Thu, 14 Dec 2023 15:22:28 +0100
Subject: [PATCH 04/11] Adjustments for HookNet: Add the swin net to the list
 of models in the init-model section.
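
Note on the wiring below: SwinUnet collects these keyword arguments in
**kwargs, which is a plain dict, so attribute-style reads such as
kwargs.swin_patch_size would raise AttributeError. A small sketch of the
dict-style lookup the wrapper would need (the default values here are only
illustrative):

    # inside SwinUnet.__init__(self, num_classes=21843, zero_head=False, **kwargs)
    patch_size = kwargs.get('swin_patch_size', 4)   # dict lookup, not kwargs.swin_patch_size
    qkv_bias = kwargs.get('swin_QKV_BIAS', True)
    qk_scale = kwargs.get('swin_QK_SCALE', None)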
---
 SSLGlacier/processing/swav_module.py | 32 +++++++++++++++++++++++-----
 1 file changed, 27 insertions(+), 5 deletions(-)

diff --git a/SSLGlacier/processing/swav_module.py b/SSLGlacier/processing/swav_module.py
index 970484b..3b57a4e 100644
--- a/SSLGlacier/processing/swav_module.py
+++ b/SSLGlacier/processing/swav_module.py
@@ -14,9 +14,11 @@ from SSLGlacier.models.RESNET import resnet18, resnet50
 from SSLGlacier.models.UNET import UNet #The main network
 from pl_bolts.optimizers.lars import LARS
 from pl_bolts.optimizers.lr_scheduler import linear_warmup_decay
-#from SSLGlacier.models.hook.Swin_Transformer_Wrapper import SwinUnet
-
+##########Importing files for hook############
+from SSLGlacier.models.hook.Swin_Transformer_Wrapper import SwinUnet
+######Temporary config file for hooknet#######
+from utils.config import get_config
 
 class SwAV(LightningModule):
     def __init__(
@@ -51,7 +53,7 @@ class SwAV(LightningModule):
         weight_decay: float = 1e-6,
         epsilon: float = 0.05,
         just_aug_for_same_assign_views: bool = False,
-        **kwargs
+        swin_hparams = {}
     ) -> None:
         """
         Args:
@@ -148,6 +150,8 @@ class SwAV(LightningModule):
         self.train_iters_per_epoch = self.num_samples // global_batch_size
         self.queue = None
 
+        ####For hook####
+        self.swin_hparams = swin_hparams
     def setup(self, stage):
         if self.queue_length > 0:
             queue_folder = os.path.join(self.logger.log_dir, self.queue_path)
@@ -176,12 +180,30 @@ class SwAV(LightningModule):
                                 first_conv=self.first_conv,
                                 maxpool1=self.maxpool1)
         elif self.arch == "hook":
-            backbone = resnet50(normalize=True,
+            backbone = SwinUnet(
+                img_size=224,
+                num_classes=5,
+                normalize=True,
                                 hidden_mlp=self.hidden_mlp,
                                 output_dim=self.feat_dim,
                                 num_prototypes=self.num_prototypes,
                                 first_conv=self.first_conv,
-                                maxpool1=self.maxpool1)
+                                maxpool1=self.maxpool1,######till here for basenet
+                                image_size= self.swin_hparams.image_size,####from here for swin itself
+                                swin_patch_size = self.swin_hparams.swin_patch_size,
+                                swin_in_chans=self.swin_hparams.swin_in_chans,
+                                swin_embed_dim=self.swin_hparams.swin_embed_dim,
+                                swin_depths=self.swin_hparams.swin_depths,
+                                swin_num_heads=self.swin_hparams.swin_num_heads,
+                                swin_window_size=self.swin_hparams.swin_window_size,
+                                swin_mlp_ratio=self.swin_hparams.swin_mlp_ratio,
+                                swin_QKV_BIAS=self.swin_hparams.swin_QKV_BIAS,
+                                swin_QK_SCALE=self.swin_hparams.swin_QK_SCALE,
+                                drop_rate=self.drop_rate,
+                                drop_path_rate=self.drop_path_rate,
+                                swin_ape=self.swin_hparams.swin_ape,
+                                swin_patch_norm=self.swin_hparams.swin_path_norm
+                                )
 
         elif self.arch == "Unet" or "unet":
             backbone = UNet(num_classes=5,
-- 
GitLab


From 8ab559ce004953838e241f3a8988b3b4a9c52b04 Mon Sep 17 00:00:00 2001
From: Marziyeh <marziyeh.mohammadi@fau.de>
Date: Thu, 14 Dec 2023 17:43:22 +0100
Subject: [PATCH 05/11] Adjustments: Pass parameters via an argument parser
 instead of reading them from the config file, just to keep it in harmony
 with the rest of the code. Otherwise I like the config file idea.
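
A detail on the list-valued options below: type=list[int] is not a usable
argparse converter (applied to a command-line string it would split it into
single characters rather than parse integers). The usual idiom is nargs='+'
with an element type; a sketch, assuming flag-style names:

    swin_parser.add_argument('--swin_depths', type=int, nargs='+',
                             default=[2, 2, 6, 2],
                             help='Depth of each layer in the Transformer encoder.')
    swin_parser.add_argument('--swin_num_heads', type=int, nargs='+',
                             default=[3, 6, 12, 24],
                             help='Number of attention heads in each layer.')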
---
 SSLGlacier/main.py | 48 +++++++++++++++++++++++++++++++++++-----------
 1 file changed, 37 insertions(+), 11 deletions(-)

diff --git a/SSLGlacier/main.py b/SSLGlacier/main.py
index 61fc68d..73d9277 100644
--- a/SSLGlacier/main.py
+++ b/SSLGlacier/main.py
@@ -15,6 +15,7 @@ from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
 import numpy as np
 
 from utils.utils import str2bool, getDataPath
+from typing import *
 
 
 
@@ -230,17 +231,42 @@
     ##############################################
     swin_parser =ArgumentParser()
     swin_parser.add_argument('hidden_mlp')
-    swin_parser.add_argument('swin_patch_size', default = 4)
+    swin_parser.add_argument('swin_patch_size', default = 4,
+                             type=int,
+                             help='The size (resolution) of each patch.')
-    swin_parser.add_argument('swin_in_chans', default = 1)
+    swin_parser.add_argument('swin_in_chans', default = 1,
+                             )
-    swin_parser.add_argument('swin_embed_dim', default = 96)
+    swin_parser.add_argument('swin_embed_dim', default = 96,
+                             type=int,
+                             help='Dimensionality of patch embedding.')
-    swin_parser.add_argument('swin_depths', default = [2, 2, 6, 2])
+    swin_parser.add_argument('swin_depths',
+                             default= [2, 2, 6, 2],
+                             type=list[int],
+                             help="Depth of each layer in the Transformer encoder."
+                             )
-    swin_parser.add_argument('swin_num_heads', default = [3, 6, 12, 24])
+    swin_parser.add_argument('swin_num_heads',
+                             default = [3, 6, 12, 24],
+                             type=list[int],
+                             help='Number of attention heads in each layer of the Transformer encoder.')
-    swin_parser.add_argument('swin_window_size', default = 2)
+    swin_parser.add_argument('swin_window_size',
+                             default=2,
+                             type=int,
+                             help= 'Size of windows')
-    swin_parser.add_argument('swin_mlp_ratio', default = 4.0)
+    swin_parser.add_argument('swin_mlp_ratio',
+                             default= 4.0,type=float,
+                             help='Ratio of MLP hidden dimensionality to embedding dimensionality.')
-    swin_parser.add_argument('swin_QKV_BIAS', default = True)
+    swin_parser.add_argument('swin_QKV_BIAS', default = True,
+                             type=str2bool,
+                             help='Whether or not a learnable bias should be added to the queries, keys and values.')
-    swin_parser.add_argument('swin_QK_SCALE', default = None)
+    swin_parser.add_argument('swin_QK_SCALE', default = None,
+                             type=float,
+                             help='Override default qk scale of head_dim ** -0.5 if set.')
-    swin_parser.add_argument('swin_ape', default = False)
+    swin_parser.add_argument('swin_ape', default = False,
+                             type=str2bool,
+                             help="If True, add absolute position embedding to the patch embedding.")
-    swin_parser.add_argument('swin_path_norm', default = True)
+    swin_parser.add_argument('swin_path_norm', default = True,
+                             type=str2bool,
+                             help="If True, add normalization after patch embedding.")
     swin_hparams = swin_parser.parse_args()
-- 
GitLab


From 3a920a15e2bd766ee7cf73293d1fd7d583c84d73 Mon Sep 17 00:00:00 2001
From: Marziyeh <marziyeh.mohammadi@fau.de>
Date: Thu, 14 Dec 2023 17:44:36 +0100
Subject: [PATCH 06/11] Adjustments for HookNet: Add explanations to patch
 expand: what it is and why it is there.

---
 SSLGlacier/models/hook/patch_expand.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/SSLGlacier/models/hook/patch_expand.py b/SSLGlacier/models/hook/patch_expand.py
index ef397f8..b253c80 100644
--- a/SSLGlacier/models/hook/patch_expand.py
+++ b/SSLGlacier/models/hook/patch_expand.py
@@ -4,6 +4,16 @@ from einops import rearrange
 
 
 class PatchExpand(nn.Module):
+    """
+    from:https://link.springer.com/chapter/10.1007/978-3-031-25066-8_9
+    Take the first patch expanding layer as an example, before up-sampling,
+    a linear layer is applied on the input features (W/32, H/32, 8C) to increase the feature dimension to
+    2 * the original dimension (W/32, H/32, 16C). Then, we use rearrange operation to expand the resolution
+    of the input features to 2*the input resolution and reduce the feature dimension to
+    quarter of the input dimension ((W/32, H/32, 16C) --> (W/16, H/16, 4C)).
+    We will discuss the impact of using patch expanding layer
+    to perform up-sampling in Sect. 4.5 in https://link.springer.com/chapter/10.1007/978-3-031-25066-8_9.
+    """
     def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):
         super().__init__()
         self.input_resolution = input_resolution
-- 
GitLab


From d78eac5a82771edcc08a26d042b3093cdc0261b6 Mon Sep 17 00:00:00 2001
From: Marziyeh <marziyeh.mohammadi@fau.de>
Date: Thu, 14 Dec 2023 17:45:43 +0100
Subject: [PATCH 07/11] Import problem: Solve import error by changing the
 import path.
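
Imports rooted at the package (SSLGlacier.models.hook.*) resolve regardless
of the working directory the training script is launched from, whereas the
old models.* form only worked when SSLGlacier itself was on sys.path.
Assuming the repository root is importable, for example:

    # resolves from any working directory once the repo root is on sys.path
    from SSLGlacier.models.hook.mlp import Mlp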
---
 SSLGlacier/models/hook/swin_transformer_block.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/SSLGlacier/models/hook/swin_transformer_block.py b/SSLGlacier/models/hook/swin_transformer_block.py
index dd8938c..52fa6dc 100644
--- a/SSLGlacier/models/hook/swin_transformer_block.py
+++ b/SSLGlacier/models/hook/swin_transformer_block.py
@@ -1,9 +1,9 @@
 import torch
 import torch.nn as nn
 from timm.models.layers import DropPath, to_2tuple
-from models.window_attention import WindowAttention
-from models.mlp import Mlp
-from models.window_utils import window_partition, window_reverse
+from SSLGlacier.models.hook.window_attention import WindowAttention
+from SSLGlacier.models.hook.mlp import Mlp
+from SSLGlacier.models.hook.window_utils import window_partition, window_reverse
 
 
 class SwinTransformerBlock(nn.Module):
-- 
GitLab


From 15b8f2c7c0a187162bf95e6983102571e3f6cb9a Mon Sep 17 00:00:00 2001
From: Marziyeh <marziyeh.mohammadi@fau.de>
Date: Fri, 15 Dec 2023 09:34:28 +0100
Subject: [PATCH 08/11] Refactor the code: Reduce the number of files, put
 related classes in the same file, and make things less clumsy, at least at
 the directory level.

---
 SSLGlacier/models/hook/Swin_Transformer.py    |  10 +-
 SSLGlacier/models/hook/basic_layers.py        | 229 ++++++++++++++
 SSLGlacier/models/hook/basic_swin_layer.py    |  74 -----
 .../models/hook/basic_swin_layer_cross.py     |  74 -----
 SSLGlacier/models/hook/basic_swin_layer_up.py |  63 ----
 SSLGlacier/models/hook/mlp.py                 |  20 --
 SSLGlacier/models/hook/patch_embedding.py     |  51 ---
 SSLGlacier/models/hook/patch_merging.py       |  51 ---
 .../{patch_expand.py => patch_processing.py}  | 123 +++++++-
 .../models/hook/swin_transformer_block.py     | 136 --------
 .../hook/swin_transformer_block_cross.py      | 151 ---------
 .../models/hook/swin_transformer_blocks.py    | 296 ++++++++++++++++++
 SSLGlacier/models/hook/window_attention.py    | 104 ------
 SSLGlacier/models/hook/window_attentions.py   | 240 ++++++++++++++
 .../models/hook/window_cross_attention.py     | 106 -------
 SSLGlacier/models/hook/window_utils.py        |  34 --
 16 files changed, 886 insertions(+), 876 deletions(-)
 create mode 100644 SSLGlacier/models/hook/basic_layers.py
 delete mode 100644 SSLGlacier/models/hook/basic_swin_layer.py
 delete mode 100644 SSLGlacier/models/hook/basic_swin_layer_cross.py
 delete mode 100644 SSLGlacier/models/hook/basic_swin_layer_up.py
 delete mode 100644 SSLGlacier/models/hook/mlp.py
 delete mode 100644 SSLGlacier/models/hook/patch_embedding.py
 delete mode 100644 SSLGlacier/models/hook/patch_merging.py
 rename SSLGlacier/models/hook/{patch_expand.py => patch_processing.py} (56%)
 delete mode 100644 SSLGlacier/models/hook/swin_transformer_block.py
 delete mode 100644 SSLGlacier/models/hook/swin_transformer_block_cross.py
 create mode 100644 SSLGlacier/models/hook/swin_transformer_blocks.py
 delete mode 100644 SSLGlacier/models/hook/window_attention.py
 create mode 100644 SSLGlacier/models/hook/window_attentions.py
 delete mode 100644 SSLGlacier/models/hook/window_cross_attention.py
 delete mode 100644 SSLGlacier/models/hook/window_utils.py

diff --git a/SSLGlacier/models/hook/Swin_Transformer.py b/SSLGlacier/models/hook/Swin_Transformer.py
index 992126a..1315bfd 100644
--- a/SSLGlacier/models/hook/Swin_Transformer.py
+++ b/SSLGlacier/models/hook/Swin_Transformer.py
@@ -2,12 +2,10 @@ import torch
 import torch.nn as nn
 import numpy as np
 from timm.models.layers import trunc_normal_
-from SSLGlacier.models.hook.patch_embedding import PatchEmbed
-from SSLGlacier.models.hook.patch_merging import PatchMerging
-from 
SSLGlacier.models.hook.patch_expand import PatchExpand, PatchExpandC, FinalPatchExpand_X4 -from SSLGlacier.models.hook.basic_swin_layer import BasicLayer -from SSLGlacier.models.hook.basic_swin_layer_up import BasicLayer_up -from SSLGlacier.models.hook.basic_swin_layer_cross import BasicLayer_Cross +from SSLGlacier.models.hook.patch_processing import PatchEmbed +from SSLGlacier.models.hook.patch_processing import PatchMerging +from SSLGlacier.models.hook.patch_processing import PatchExpand, PatchExpandC, FinalPatchExpand_X4 +from SSLGlacier.models.hook.basic_layers import BasicLayer, BasicLayer_up, BasicLayer_Cross torch.set_printoptions(threshold=np.inf) torch.set_printoptions(linewidth=300) diff --git a/SSLGlacier/models/hook/basic_layers.py b/SSLGlacier/models/hook/basic_layers.py new file mode 100644 index 0000000..079d1cc --- /dev/null +++ b/SSLGlacier/models/hook/basic_layers.py @@ -0,0 +1,229 @@ +import torch.nn as nn +import torch.utils.checkpoint as checkpoint +from SSLGlacier.models.hook.patch_processing import PatchExpand +from SSLGlacier.models.hook.swin_transformer_blocks import SwinTransformerBlock, SwinTransformerBlock_Cross +#----------------------------------------------- +########### Basic Swin Layer Classes ########### +#----------------------------------------------- +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim # 96 + self.input_resolution = input_resolution # 56 + self.depth = depth # 2 + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, # 96, 56 + num_heads=num_heads, window_size=window_size, # 3, 14 + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, # 4 + qkv_bias=qkv_bias, qk_scale=qk_scale, # True, None + drop=drop, attn_drop=attn_drop, # 0.2, 0 + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) # layer-norm + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops +#----------------------------------------------- +####### End of Basic Swin Layer Classes ######## +#----------------------------------------------- + +################################################ +################################################ + +#----------------------------------------------- +####### Basic Swin Layer cross Classes ######### +#----------------------------------------------- + +class BasicLayer_Cross(nn.Module): # Input of second attention is output of first. See paper + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim # 96 + self.input_resolution = input_resolution # 56 + self.depth = depth # 2 + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock_Cross(dim=dim, input_resolution=input_resolution, # 96, 56 + num_heads=num_heads, window_size=window_size, # 3, 14 + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, # 4 + qkv_bias=qkv_bias, qk_scale=qk_scale, # True, None + drop=drop, attn_drop=attn_drop, # 0.2, 0 + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) # layer-norm + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, y): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x, y) + if self.downsample is not None: + x = self.downsample(x) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops +#----------------------------------------------- +##### End of Basic Swin Layer cross Classes #### +#----------------------------------------------- + +################################################ +################################################ + + +#----------------------------------------------- +######### Basic Swin Layer UP Classes ########## +#----------------------------------------------- + +class BasicLayer_up(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if upsample is not None: + self.upsample = PatchExpand(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer) + else: + self.upsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.upsample is not None: + x = self.upsample(x) + return x + +#----------------------------------------------- +##### End of Basic Swin Layer UP Classes #### +#----------------------------------------------- diff --git a/SSLGlacier/models/hook/basic_swin_layer.py b/SSLGlacier/models/hook/basic_swin_layer.py deleted file mode 100644 index e9570cb..0000000 --- a/SSLGlacier/models/hook/basic_swin_layer.py +++ /dev/null @@ -1,74 +0,0 @@ -import torch.nn as nn -import torch.utils.checkpoint as checkpoint -from SSLGlacier.models.hook.swin_transformer_block import SwinTransformerBlock - - -class BasicLayer(nn.Module): - """ A basic Swin Transformer layer for one stage. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
- """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): - - super().__init__() - self.dim = dim # 96 - self.input_resolution = input_resolution # 56 - self.depth = depth # 2 - self.use_checkpoint = use_checkpoint - - # build blocks - self.blocks = nn.ModuleList([ - SwinTransformerBlock(dim=dim, input_resolution=input_resolution, # 96, 56 - num_heads=num_heads, window_size=window_size, # 3, 14 - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, # 4 - qkv_bias=qkv_bias, qk_scale=qk_scale, # True, None - drop=drop, attn_drop=attn_drop, # 0.2, 0 - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer) # layer-norm - for i in range(depth)]) - - # patch merging layer - if downsample is not None: - self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) - else: - self.downsample = None - - def forward(self, x): - for blk in self.blocks: - if self.use_checkpoint: - x = checkpoint.checkpoint(blk, x) - else: - x = blk(x) - if self.downsample is not None: - x = self.downsample(x) - - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" - - def flops(self): - flops = 0 - for blk in self.blocks: - flops += blk.flops() - if self.downsample is not None: - flops += self.downsample.flops() - return flops \ No newline at end of file diff --git a/SSLGlacier/models/hook/basic_swin_layer_cross.py b/SSLGlacier/models/hook/basic_swin_layer_cross.py deleted file mode 100644 index e9481d9..0000000 --- a/SSLGlacier/models/hook/basic_swin_layer_cross.py +++ /dev/null @@ -1,74 +0,0 @@ -import torch.nn as nn -import torch.utils.checkpoint as checkpoint -from SSLGlacier.models.hook.swin_transformer_block_cross import SwinTransformerBlock_Cross - - -class BasicLayer_Cross(nn.Module): # Input of second attention is output of first. See paper - """ A basic Swin Transformer layer for one stage. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
- """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): - - super().__init__() - self.dim = dim # 96 - self.input_resolution = input_resolution # 56 - self.depth = depth # 2 - self.use_checkpoint = use_checkpoint - - # build blocks - self.blocks = nn.ModuleList([ - SwinTransformerBlock_Cross(dim=dim, input_resolution=input_resolution, # 96, 56 - num_heads=num_heads, window_size=window_size, # 3, 14 - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, # 4 - qkv_bias=qkv_bias, qk_scale=qk_scale, # True, None - drop=drop, attn_drop=attn_drop, # 0.2, 0 - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer) # layer-norm - for i in range(depth)]) - - # patch merging layer - if downsample is not None: - self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) - else: - self.downsample = None - - def forward(self, x, y): - for blk in self.blocks: - if self.use_checkpoint: - x = checkpoint.checkpoint(blk, x) - else: - x = blk(x, y) - if self.downsample is not None: - x = self.downsample(x) - - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" - - def flops(self): - flops = 0 - for blk in self.blocks: - flops += blk.flops() - if self.downsample is not None: - flops += self.downsample.flops() - return flops diff --git a/SSLGlacier/models/hook/basic_swin_layer_up.py b/SSLGlacier/models/hook/basic_swin_layer_up.py deleted file mode 100644 index 9ee1b87..0000000 --- a/SSLGlacier/models/hook/basic_swin_layer_up.py +++ /dev/null @@ -1,63 +0,0 @@ -import torch.nn as nn -import torch.utils.checkpoint as checkpoint -from SSLGlacier.models.hook.swin_transformer_block import SwinTransformerBlock -from SSLGlacier.models.hook.patch_expand import PatchExpand - - -class BasicLayer_up(nn.Module): - """ A basic Swin Transformer layer for one stage. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
- """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False): - - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.depth = depth - self.use_checkpoint = use_checkpoint - - # build blocks - self.blocks = nn.ModuleList([ - SwinTransformerBlock(dim=dim, input_resolution=input_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer) - for i in range(depth)]) - - # patch merging layer - if upsample is not None: - self.upsample = PatchExpand(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer) - else: - self.upsample = None - - def forward(self, x): - for blk in self.blocks: - if self.use_checkpoint: - x = checkpoint.checkpoint(blk, x) - else: - x = blk(x) - if self.upsample is not None: - x = self.upsample(x) - return x \ No newline at end of file diff --git a/SSLGlacier/models/hook/mlp.py b/SSLGlacier/models/hook/mlp.py deleted file mode 100644 index 6e7730f..0000000 --- a/SSLGlacier/models/hook/mlp.py +++ /dev/null @@ -1,20 +0,0 @@ -import torch.nn as nn - - -class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x \ No newline at end of file diff --git a/SSLGlacier/models/hook/patch_embedding.py b/SSLGlacier/models/hook/patch_embedding.py deleted file mode 100644 index e9fba67..0000000 --- a/SSLGlacier/models/hook/patch_embedding.py +++ /dev/null @@ -1,51 +0,0 @@ -import torch.nn as nn -from timm.models.layers import to_2tuple - - -class PatchEmbed(nn.Module): - r""" Image to Patch Embedding - - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. 
Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size # 224 - self.patch_size = patch_size # 4 - self.patches_resolution = patches_resolution # 56 - self.num_patches = patches_resolution[0] * patches_resolution[1] # 3136 - - self.in_chans = in_chans - self.embed_dim = embed_dim - - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) # 22, 96, 56, 56 - - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - B, C, H, W = x.shape - # FIXME look at relaxing size constraints - assert H == self.img_size[0] and W == self.img_size[1], \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C (22, 3136, 96) - if self.norm is not None: - x = self.norm(x) - return x - - def flops(self): - Ho, Wo = self.patches_resolution - flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) - if self.norm is not None: - flops += Ho * Wo * self.embed_dim - return flops \ No newline at end of file diff --git a/SSLGlacier/models/hook/patch_merging.py b/SSLGlacier/models/hook/patch_merging.py deleted file mode 100644 index 0567504..0000000 --- a/SSLGlacier/models/hook/patch_merging.py +++ /dev/null @@ -1,51 +0,0 @@ -import torch -import torch.nn as nn - - -class PatchMerging(nn.Module): - r""" Patch Merging Layer. - - Args: - input_resolution (tuple[int]): Resolution of input feature. - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.input_resolution = input_resolution - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(4 * dim) - - def forward(self, x): - """ - x: B, H*W, C - """ - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 
- - x = x.view(B, H, W, C) - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - - x = self.norm(x) - x = self.reduction(x) - - return x - - def extra_repr(self) -> str: - return f"input_resolution={self.input_resolution}, dim={self.dim}" - - def flops(self): - H, W = self.input_resolution - flops = H * W * self.dim - flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim - return flops - diff --git a/SSLGlacier/models/hook/patch_expand.py b/SSLGlacier/models/hook/patch_processing.py similarity index 56% rename from SSLGlacier/models/hook/patch_expand.py rename to SSLGlacier/models/hook/patch_processing.py index b253c80..d7e1f33 100644 --- a/SSLGlacier/models/hook/patch_expand.py +++ b/SSLGlacier/models/hook/patch_processing.py @@ -1,8 +1,124 @@ +import torch import torch.nn as nn from einops import rearrange -# import torch +from timm.models.layers import to_2tuple +#----------------------------------------------- +############## Patch Merging Class ########### +#----------------------------------------------- +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + +#----------------------------------------------- +######### End of Patch Merging Class ########### +#----------------------------------------------- + +################################################ +################################################ + +#----------------------------------------------- +############## Patch Embedding Class ########### +#----------------------------------------------- +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size # 224 + self.patch_size = patch_size # 4 + self.patches_resolution = patches_resolution # 56 + self.num_patches = patches_resolution[0] * patches_resolution[1] # 3136 + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) # 22, 96, 56, 56 + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C (22, 3136, 96) + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops +#----------------------------------------------- +######### End of Patch Embeding Class ########### +#----------------------------------------------- + +################################################ +################################################ + +#----------------------------------------------- +############# Patch Expand Classes ############# +#----------------------------------------------- class PatchExpand(nn.Module): """ from:https://link.springer.com/chapter/10.1007/978-3-031-25066-8_9 @@ -36,7 +152,6 @@ class PatchExpand(nn.Module): return x - # class PatchExpand4(nn.Module): # def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm): # super().__init__() @@ -59,8 +174,6 @@ class PatchExpand(nn.Module): # x = self.norm(x) # # return x - - class PatchExpandC(nn.Module): def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm): super().__init__() @@ -84,7 +197,6 @@ class PatchExpandC(nn.Module): return x - # class PCAa(object): # def __init__(self, n_components=2): # self.n_components = n_components @@ -136,7 +248,6 @@ class PatchExpandC(nn.Module): # # return x - class FinalPatchExpand_X4(nn.Module): def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm): super().__init__() diff --git a/SSLGlacier/models/hook/swin_transformer_block.py b/SSLGlacier/models/hook/swin_transformer_block.py deleted file mode 100644 index 52fa6dc..0000000 --- a/SSLGlacier/models/hook/swin_transformer_block.py +++ /dev/null @@ -1,136 +0,0 @@ -import torch -import torch.nn as nn -from timm.models.layers import DropPath, to_2tuple -from SSLGlacier.models.hook.window_attention import WindowAttention -from SSLGlacier.models.hook.mlp import Mlp -from SSLGlacier.models.hook.window_utils import window_partition, window_reverse - - -class SwinTransformerBlock(nn.Module): - r""" Swin Transformer Block. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resulotion. - num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim # 96 - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size # 0 or 7 - self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = WindowAttention( - dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) - - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - if self.shift_size > 0: - # calculate attention mask for SW-MSA - H, W = self.input_resolution # 56, 56 - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1; 1, 56, 56, 1 - h_slices = (slice(0, -self.window_size), # 0, -14 - slice(-self.window_size, -self.shift_size), # -14, -7 - slice(-self.shift_size, None)) # -7, None - - w_slices = (slice(0, -self.window_size), # 0, -14 - slice(-self.window_size, -self.shift_size), # -14, -7 - slice(-self.shift_size, None)) # -7, None - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, - self.window_size) # nW, window_size, window_size, 1; 64, 7, 7, 1 # 16, 14, 14, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) # 64, 49 # 16, 196 - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # 64, 1, 49 - 64, 49, 1 - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float( - 0.0)) # 64, 49, 49 # 16, 196, 196 - else: - attn_mask = None - - self.register_buffer("attn_mask", attn_mask) - - def forward(self, x): - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - - shortcut = x - x = self.norm1(x) - x = x.view(B, H, W, C) - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - else: - shifted_x = x - - # partition windows - x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C; 448, 7, 7, 96 - x_windows = x_windows.view(-1, self.window_size * self.window_size, - C) # nW*B, window_size*window_size, C; 448, 49, 96 # 7, 4, 4, 196, 96 - - # W-MSA/SW-MSA - attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C; 448, 49, 96 - 
- # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) # 448, 7, 7, 96 - shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C; 7, 56, 56, 96 - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - else: - x = shifted_x - - x = x.view(B, H * W, C) # 7, 56*56, 96 - - # FFN - x = shortcut + self.drop_path(x) - x = x + self.drop_path(self.mlp(self.norm2(x))) - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ - f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" - - def flops(self): - flops = 0 - H, W = self.input_resolution - # norm1 - flops += self.dim * H * W - # W-MSA/SW-MSA - nW = H * W / self.window_size / self.window_size - flops += nW * self.attn.flops(self.window_size * self.window_size) - # mlp - flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio - # norm2 - flops += self.dim * H * W - return flops diff --git a/SSLGlacier/models/hook/swin_transformer_block_cross.py b/SSLGlacier/models/hook/swin_transformer_block_cross.py deleted file mode 100644 index 8dbfb67..0000000 --- a/SSLGlacier/models/hook/swin_transformer_block_cross.py +++ /dev/null @@ -1,151 +0,0 @@ -import torch -import torch.nn as nn -from timm.models.layers import DropPath, to_2tuple -from models.window_cross_attention import WindowAttention_Cross -from models.mlp import Mlp -from models.window_utils import window_partition, window_reverse - - -class SwinTransformerBlock_Cross(nn.Module): - r""" Swin Transformer Block. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resulotion. - num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim # 96 - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size # 0 or 7 - self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = WindowAttention_Cross( - dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) - - self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - if self.shift_size > 0: - # calculate attention mask for SW-MSA - H, W = self.input_resolution # 56, 56 - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1; 1, 56, 56, 1 - h_slices = (slice(0, -self.window_size), # 0, -14 - slice(-self.window_size, -self.shift_size), # -14, -7 - slice(-self.shift_size, None)) # -7, None - - w_slices = (slice(0, -self.window_size), # 0, -14 - slice(-self.window_size, -self.shift_size), # -14, -7 - slice(-self.shift_size, None)) # -7, None - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, - self.window_size) # nW, window_size, window_size, 1; 64, 7, 7, 1 # 16, 14, 14, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) # 64, 49 # 16, 196 - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # 64, 1, 49 - 64, 49, 1 - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float( - 0.0)) # 64, 49, 49 # 16, 196, 196 - else: - attn_mask = None - - self.register_buffer("attn_mask", attn_mask) - - def forward(self, target, context): - H, W = self.input_resolution - B, L, C = target.shape - assert L == H * W, "input feature has wrong size" - - shortcut = target - target = self.norm1(target) - target = target.view(B, H, W, C) - - context = self.norm1(context) - context = context.view(B, H, W, C) - - # cyclic shift - if self.shift_size > 0: - shifted_target = torch.roll(target, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - shifted_context = torch.roll(context, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - else: - shifted_target = target - shifted_context = context - - # partition windows - target_windows = window_partition(shifted_target, - self.window_size) # nW*B, window_size, window_size, C; 448, 7, 7, 96 - context_windows = window_partition(shifted_context, - self.window_size) # nW*B, window_size, window_size, C; 448, 7, 7, 96 - # print(x_windows.shape, 'ss') - target_windows = target_windows.view(-1, self.window_size * self.window_size, - C) # nW*B, window_size*window_size, C; 448, 49, 96 # 7, 4, 4, 196, 96 - - context_windows = context_windows.view(-1, self.window_size * self.window_size, - C) # nW*B, window_size*window_size, C; 448, 49, 96 # 7, 4, 4, 196, 96 - - # W-MSA/SW-MSA - attn_windows = self.attn(target_windows, context_windows, - mask=self.attn_mask) # nW*B, window_size*window_size, C; 448, 49, 96 - - # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) # 448, 7, 7, 96 - shifted_target = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C; 7, 56, 56, 96 - # reverse cyclic shift - if self.shift_size > 0: - target = torch.roll(shifted_target, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - else: - target = shifted_target - - target = target.view(B, H * W, C) # 7, 56*56, 96 - - # FFN - target = shortcut + self.drop_path(target) - target = target + self.drop_path(self.mlp(self.norm2(target))) - - return target - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ - f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" - - def flops(self): - flops = 0 - H, W = self.input_resolution - # norm1 - flops += 
self.dim * H * W - # W-MSA/SW-MSA - nW = H * W / self.window_size / self.window_size - flops += nW * self.attn.flops(self.window_size * self.window_size) - # mlp - flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio - # norm2 - flops += self.dim * H * W - return flops diff --git a/SSLGlacier/models/hook/swin_transformer_blocks.py b/SSLGlacier/models/hook/swin_transformer_blocks.py new file mode 100644 index 0000000..f6745c8 --- /dev/null +++ b/SSLGlacier/models/hook/swin_transformer_blocks.py @@ -0,0 +1,296 @@ +import torch +import torch.nn as nn +from timm.models.layers import DropPath, to_2tuple +from SSLGlacier.models.hook.window_attentions import WindowAttention, WindowAttention_Cross, window_partition, window_reverse + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + +class SwinTransformerBlock_Cross(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim # 96 + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size # 0 or 7 + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention_Cross( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution # 56, 56 + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1; 1, 56, 56, 1 + h_slices = (slice(0, -self.window_size), # 0, -14 + slice(-self.window_size, -self.shift_size), # -14, -7 + slice(-self.shift_size, None)) # -7, None + + w_slices = (slice(0, -self.window_size), # 0, -14 + slice(-self.window_size, -self.shift_size), # -14, -7 + slice(-self.shift_size, None)) # -7, None + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, + self.window_size) # nW, window_size, window_size, 1; 64, 7, 7, 1 # 16, 14, 14, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) # 64, 49 # 16, 196 + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # 64, 1, 49 - 64, 49, 1 + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float( + 0.0)) # 64, 49, 49 # 16, 196, 196 + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def forward(self, target, context): + H, W = self.input_resolution + B, L, C = target.shape + assert L == H * W, "input feature has wrong size" + + shortcut = target + target = self.norm1(target) + target = target.view(B, H, W, C) + + context = self.norm1(context) + context = context.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_target = torch.roll(target, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + shifted_context = torch.roll(context, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_target = target + shifted_context = context + + # partition windows + target_windows = window_partition(shifted_target, + self.window_size) # nW*B, window_size, window_size, C; 448, 7, 7, 96 + context_windows = window_partition(shifted_context, + self.window_size) # nW*B, window_size, window_size, C; 448, 7, 7, 96 + # print(x_windows.shape, 'ss') + target_windows = target_windows.view(-1, self.window_size * self.window_size, + C) # nW*B, window_size*window_size, C; 448, 49, 96 # 7, 4, 4, 196, 96 + + context_windows = context_windows.view(-1, self.window_size * self.window_size, + C) # nW*B, window_size*window_size, C; 448, 49, 96 # 7, 4, 4, 196, 96 + + # W-MSA/SW-MSA + attn_windows = self.attn(target_windows, context_windows, + mask=self.attn_mask) # nW*B, window_size*window_size, C; 448, 49, 96 + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) # 448, 7, 7, 96 + shifted_target = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C; 7, 56, 56, 96 + # reverse cyclic shift + if self.shift_size > 0: + target = torch.roll(shifted_target, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + target = shifted_target + + target = target.view(B, H * W, C) # 7, 56*56, 96 + + # FFN + target = shortcut + self.drop_path(target) + target = target + self.drop_path(self.mlp(self.norm2(target))) + + return target + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += 
self.dim * H * W
+        # W-MSA/SW-MSA
+        nW = H * W / self.window_size / self.window_size
+        flops += nW * self.attn.flops(self.window_size * self.window_size)
+        # mlp
+        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
+        # norm2
+        flops += self.dim * H * W
+        return flops
+
+
+class SwinTransformerBlock(nn.Module):
+    r""" Swin Transformer Block.
+
+    Args:
+        dim (int): Number of input channels.
+        input_resolution (tuple[int]): Input resolution.
+        num_heads (int): Number of attention heads.
+        window_size (int): Window size.
+        shift_size (int): Shift size for SW-MSA.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float, optional): Stochastic depth rate. Default: 0.0
+        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+    """
+
+    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
+                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
+                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.dim = dim  # 96
+        self.input_resolution = input_resolution
+        self.num_heads = num_heads
+        self.window_size = window_size
+        self.shift_size = shift_size  # 0 or 7
+        self.mlp_ratio = mlp_ratio
+        if min(self.input_resolution) <= self.window_size:
+            # if window size is larger than input resolution, we don't partition windows
+            self.shift_size = 0
+            self.window_size = min(self.input_resolution)
+        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
+
+        self.norm1 = norm_layer(dim)
+        self.attn = WindowAttention(
+            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
+            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+
+        self.drop_path = DropPath(drop_path) if drop_path > 0.
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution # 56, 56 + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1; 1, 56, 56, 1 + h_slices = (slice(0, -self.window_size), # 0, -14 + slice(-self.window_size, -self.shift_size), # -14, -7 + slice(-self.shift_size, None)) # -7, None + + w_slices = (slice(0, -self.window_size), # 0, -14 + slice(-self.window_size, -self.shift_size), # -14, -7 + slice(-self.shift_size, None)) # -7, None + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, + self.window_size) # nW, window_size, window_size, 1; 64, 7, 7, 1 # 16, 14, 14, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) # 64, 49 # 16, 196 + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # 64, 1, 49 - 64, 49, 1 + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float( + 0.0)) # 64, 49, 49 # 16, 196, 196 + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C; 448, 7, 7, 96 + x_windows = x_windows.view(-1, self.window_size * self.window_size, + C) # nW*B, window_size*window_size, C; 448, 49, 96 # 7, 4, 4, 196, 96 + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C; 448, 49, 96 + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) # 448, 7, 7, 96 + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C; 7, 56, 56, 96 + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + + x = x.view(B, H * W, C) # 7, 56*56, 96 + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops diff --git a/SSLGlacier/models/hook/window_attention.py b/SSLGlacier/models/hook/window_attention.py deleted file mode 100644 index 39e9be5..0000000 --- a/SSLGlacier/models/hook/window_attention.py +++ /dev/null @@ -1,104 +0,0 @@ -import torch -import torch.nn as nn -from timm.models.layers import trunc_normal_ - - -class WindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module 
with relative position bias. - It supports both of shifted and non-shifted window. - - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. Default: 0.0 - """ - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): - - super().__init__() - self.dim = dim - self.window_size = window_size # Wh, Ww - self.num_heads = num_heads # 3, 6, 12, 24 - head_dim = dim // num_heads # 96 / 3 = 32; 96, 192, 384, 768 - - self.scale = qk_scale or head_dim ** -0.5 # 1/np.sqrt(32) - - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH; 169, 3 - - coords_h = torch.arange(self.window_size[0]) - coords_w = torch.arange(self.window_size[1]) - - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, 7, 7 - - coords_flatten = torch.flatten(coords, 1) # 2, 49 - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, 49, 49 - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # 49, 49, 2 - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 # 14-1 - relative_position_index = relative_coords.sum(-1) # 49, 49 - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) # 96, 192, 384, 768 - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) # 96, 192, 384, 768 - self.proj_drop = nn.Dropout(proj_drop) - - trunc_normal_(self.relative_position_bias_table, std=.02) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask=None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape # -1, 49, 96 - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, - 4) # 448, 49, 3, 3, 32; 3, 448, 3, 49, 32 - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) # 448, 3, 49, 32 - - q = q * self.scale # 1/np.sqrt(32) - attn = (q @ k.transpose(-2, -1)) # 448, 3, 49, 49 - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], - -1) # Wh*Ww,Wh*Ww,nH; 49, 49, 3 - - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww; 3, 49, 49 - attn = attn + relative_position_bias.unsqueeze(0) # 488, 3, 49, 49 - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( - 0) # 7, 64, 3, 49, 49 + 1, 64, 1, 49, 49 - attn = attn.view(-1, self.num_heads, N, N) # 488, 3, 49, 49 - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) # 448, 3, 49, 32; 448, 49, 3, 32; 448, 49, 96 - x = self.proj(x) - x = 
self.proj_drop(x) - return x - - def extra_repr(self) -> str: - return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' - - def flops(self, N): - # calculate flops for 1 window with token length of N - flops = 0 - # qkv = self.qkv(x) - flops += N * self.dim * 3 * self.dim - # attn = (q @ k.transpose(-2, -1)) - flops += self.num_heads * N * (self.dim // self.num_heads) * N - # x = (attn @ v) - flops += self.num_heads * N * N * (self.dim // self.num_heads) - # x = self.proj(x) - flops += N * self.dim * self.dim - return flops diff --git a/SSLGlacier/models/hook/window_attentions.py b/SSLGlacier/models/hook/window_attentions.py new file mode 100644 index 0000000..4225479 --- /dev/null +++ b/SSLGlacier/models/hook/window_attentions.py @@ -0,0 +1,240 @@ +import torch +import torch.nn as nn +from timm.models.layers import trunc_normal_ + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads # 3, 6, 12, 24 + head_dim = dim // num_heads # 96 / 3 = 32; 96, 192, 384, 768 + + self.scale = qk_scale or head_dim ** -0.5 # 1/np.sqrt(32) + + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH; 169, 3 + + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, 7, 7 + + coords_flatten = torch.flatten(coords, 1) # 2, 49 + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, 49, 49 + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # 49, 49, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 # 14-1 + relative_position_index = relative_coords.sum(-1) # 49, 49 + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) # 96, 192, 384, 768 + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) # 96, 192, 384, 768 + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape # -1, 49, 96 + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, + 4) # 448, 49, 3, 3, 32; 3, 448, 3, 49, 32 + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 
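+        # Shape walk-through for the projection above (using the illustrative
+        # sizes from the surrounding comments: B_=448 windows, N=49 tokens,
+        # C=96 channels, num_heads=3, head_dim=32):
+        #   self.qkv(x)                          -> (448, 49, 288)
+        #   .reshape(B_, N, 3, heads, C//heads)  -> (448, 49, 3, 3, 32)
+        #   .permute(2, 0, 3, 1, 4)              -> (3, 448, 3, 49, 32)
+        # so q, k and v each come out with shape (448, 3, 49, 32).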
# 448, 3, 49, 32
+
+        q = q * self.scale  # 1/np.sqrt(32)
+        attn = (q @ k.transpose(-2, -1))  # 448, 3, 49, 49
+        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1],
+            -1)  # Wh*Ww,Wh*Ww,nH; 49, 49, 3
+
+        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww; 3, 49, 49
+        attn = attn + relative_position_bias.unsqueeze(0)  # 448, 3, 49, 49
+
+        if mask is not None:
+            nW = mask.shape[0]
+            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(
+                0)  # 7, 64, 3, 49, 49 + 1, 64, 1, 49, 49
+            attn = attn.view(-1, self.num_heads, N, N)  # 448, 3, 49, 49
+            attn = self.softmax(attn)
+        else:
+            attn = self.softmax(attn)
+
+        attn = self.attn_drop(attn)
+        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)  # 448, 3, 49, 32; 448, 49, 3, 32; 448, 49, 96
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+    def extra_repr(self) -> str:
+        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
+
+    def flops(self, N):
+        # calculate flops for 1 window with token length of N
+        flops = 0
+        # qkv = self.qkv(x)
+        flops += N * self.dim * 3 * self.dim
+        # attn = (q @ k.transpose(-2, -1))
+        flops += self.num_heads * N * (self.dim // self.num_heads) * N
+        # x = (attn @ v)
+        flops += self.num_heads * N * N * (self.dim // self.num_heads)
+        # x = self.proj(x)
+        flops += N * self.dim * self.dim
+        return flops
+
+class WindowAttention_Cross(nn.Module):
+    r""" Window based multi-head cross attention (W-MSA style) module with relative position bias.
+    It supports both shifted and non-shifted windows.
+
+    Args:
+        dim (int): Number of input channels.
+        window_size (tuple[int]): The height and width of the window.
+        num_heads (int): Number of attention heads.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+        proj_drop (float, optional): Dropout ratio of output.
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads # 3, 6, 12, 24 + head_dim = dim // num_heads # 96 / 3 = 32; 96, 192, 384, 768 + + self.scale = qk_scale or head_dim ** -0.5 # 1/np.sqrt(32) + + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH; 169, 3 + + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, 7, 7 + + coords_flatten = torch.flatten(coords, 1) # 2, 49 + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, 49, 49 + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # 49, 49, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 # 14-1 + relative_position_index = relative_coords.sum(-1) # 49, 49 + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) # 96, 192, 384, 768 + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) # 96, 192, 384, 768 + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, y, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape # -1, 49, 96 + + q = x.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # 448, 49, 3, 3, 32; 3, 448, 3, 49, 32 + k = y.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # 448, 49, 3, 3, 32; 3, 448, 3, 49, 32 + v = y.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # 448, 49, 3, 3, 32; 3, 448, 3, 49, 32 + + q = q * self.scale # 1/np.sqrt(32) + attn = (q @ k.transpose(-2, -1)) # 448, 3, 49, 49 + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], + -1) # Wh*Ww,Wh*Ww,nH; 49, 49, 3 + + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww; 3, 49, 49 + attn = attn + relative_position_bias.unsqueeze(0) # 488, 3, 49, 49 + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( + 0) # 7, 64, 3, 49, 49 + 1, 64, 1, 49, 49 + attn = attn.view(-1, self.num_heads, N, N) # 488, 3, 49, 49 + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) # 448, 3, 49, 32; 448, 49, 3, 32; 448, 49, 96 + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += 
self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape # 7, 56, 56, 96; 1, 56, 56, 1 + x = x.view(B, H // window_size, window_size, W // window_size, window_size, + C) # 7, 8, 7, 8, 7, 96 # 1, 8, 7, 8, 7, 1 # 1, 4, 14, 4, 14, 1 + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, + C) # 7, 8, 8, 7, 7, 96; -1, 7, 7, 96 + # 1, 8, 8, 7, 7, 1 # 1, 4, 4, 14, 14, 1 # 16, 14, 14, 1 + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) # 448 / (56*56/7/7) = 7 + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) # 7, 8, 8, 7, 7, 96 + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) # 7, 8, 7, 8, 96; 7, 56, 56, 96 + return x + diff --git a/SSLGlacier/models/hook/window_cross_attention.py b/SSLGlacier/models/hook/window_cross_attention.py deleted file mode 100644 index 7b5ddfd..0000000 --- a/SSLGlacier/models/hook/window_cross_attention.py +++ /dev/null @@ -1,106 +0,0 @@ -import torch -import torch.nn as nn -from timm.models.layers import trunc_normal_ - - -class WindowAttention_Cross(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 - """ - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): - - super().__init__() - self.dim = dim - self.window_size = window_size # Wh, Ww - self.num_heads = num_heads # 3, 6, 12, 24 - head_dim = dim // num_heads # 96 / 3 = 32; 96, 192, 384, 768 - - self.scale = qk_scale or head_dim ** -0.5 # 1/np.sqrt(32) - - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH; 169, 3 - - coords_h = torch.arange(self.window_size[0]) - coords_w = torch.arange(self.window_size[1]) - - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, 7, 7 - - coords_flatten = torch.flatten(coords, 1) # 2, 49 - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, 49, 49 - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # 49, 49, 2 - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 # 14-1 - relative_position_index = relative_coords.sum(-1) # 49, 49 - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) # 96, 192, 384, 768 - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) # 96, 192, 384, 768 - self.proj_drop = nn.Dropout(proj_drop) - - trunc_normal_(self.relative_position_bias_table, std=.02) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, y, mask=None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape # -1, 49, 96 - - q = x.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # 448, 49, 3, 3, 32; 3, 448, 3, 49, 32 - k = y.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # 448, 49, 3, 3, 32; 3, 448, 3, 49, 32 - v = y.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # 448, 49, 3, 3, 32; 3, 448, 3, 49, 32 - - q = q * self.scale # 1/np.sqrt(32) - attn = (q @ k.transpose(-2, -1)) # 448, 3, 49, 49 - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], - -1) # Wh*Ww,Wh*Ww,nH; 49, 49, 3 - - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww; 3, 49, 49 - attn = attn + relative_position_bias.unsqueeze(0) # 488, 3, 49, 49 - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze( - 0) # 7, 64, 3, 49, 49 + 1, 64, 1, 49, 49 - attn = attn.view(-1, self.num_heads, N, N) # 488, 3, 49, 49 - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) # 448, 3, 49, 32; 448, 49, 3, 32; 448, 49, 96 - x = self.proj(x) - x = self.proj_drop(x) - return x - - def extra_repr(self) -> str: - return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' - - def flops(self, N): - # calculate flops for 1 window with token length of N - flops = 0 - # qkv = self.qkv(x) - flops += N * self.dim * 3 * self.dim - # attn = (q @ k.transpose(-2, -1)) - flops += self.num_heads * N * (self.dim // self.num_heads) * N - # x = (attn @ v) - flops += 
self.num_heads * N * N * (self.dim // self.num_heads) - # x = self.proj(x) - flops += N * self.dim * self.dim - return flops - diff --git a/SSLGlacier/models/hook/window_utils.py b/SSLGlacier/models/hook/window_utils.py deleted file mode 100644 index fe08bad..0000000 --- a/SSLGlacier/models/hook/window_utils.py +++ /dev/null @@ -1,34 +0,0 @@ -def window_partition(x, window_size): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape # 7, 56, 56, 96; 1, 56, 56, 1 - x = x.view(B, H // window_size, window_size, W // window_size, window_size, - C) # 7, 8, 7, 8, 7, 96 # 1, 8, 7, 8, 7, 1 # 1, 4, 14, 4, 14, 1 - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, - C) # 7, 8, 8, 7, 7, 96; -1, 7, 7, 96 - # 1, 8, 8, 7, 7, 1 # 1, 4, 4, 14, 14, 1 # 16, 14, 14, 1 - return windows - - -def window_reverse(windows, window_size, H, W): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image - - Returns: - x: (B, H, W, C) - """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) # 448 / (56*56/7/7) = 7 - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) # 7, 8, 8, 7, 7, 96 - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) # 7, 8, 7, 8, 96; 7, 56, 56, 96 - return x - -- GitLab From 7f7de8cd002e873590c5d4622af3c66b6151687b Mon Sep 17 00:00:00 2001 From: Marziyeh <marziyeh.mohammadi@fau.de> Date: Fri, 15 Dec 2023 09:38:22 +0100 Subject: [PATCH 09/11] Refactor the code: Change the base class name to SWINNET, it was Net. --- SSLGlacier/models/RESNET.py | 4 ++-- SSLGlacier/models/{base_net.py => SWINNET.py} | 2 +- SSLGlacier/models/UNET.py | 4 ++-- SSLGlacier/models/hook/Swin_Transformer_Wrapper.py | 14 ++++++-------- 4 files changed, 11 insertions(+), 13 deletions(-) rename SSLGlacier/models/{base_net.py => SWINNET.py} (99%) diff --git a/SSLGlacier/models/RESNET.py b/SSLGlacier/models/RESNET.py index 9a094ef..0ee1479 100644 --- a/SSLGlacier/models/RESNET.py +++ b/SSLGlacier/models/RESNET.py @@ -1,7 +1,7 @@ """Adapted from: https://github.com/facebookresearch/swav/blob/master/src/resnet50.py.""" import torch from torch import nn -from SSLGlacier.models.base_net import Net +from SSLGlacier.models.SWINNET import SwinNet def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): @@ -122,7 +122,7 @@ class Bottleneck(nn.Module): return self.relu(out) -class ResNet(Net): +class ResNet(SwinNet): def __init__( self, block, diff --git a/SSLGlacier/models/base_net.py b/SSLGlacier/models/SWINNET.py similarity index 99% rename from SSLGlacier/models/base_net.py rename to SSLGlacier/models/SWINNET.py index 4489ec3..9924814 100644 --- a/SSLGlacier/models/base_net.py +++ b/SSLGlacier/models/SWINNET.py @@ -3,7 +3,7 @@ import torch from torch import nn -class Net(nn.Module): +class SwinNet(nn.Module): def __init__( self, groups=1, diff --git a/SSLGlacier/models/UNET.py b/SSLGlacier/models/UNET.py index ad4f65c..5061bb0 100644 --- a/SSLGlacier/models/UNET.py +++ b/SSLGlacier/models/UNET.py @@ -1,9 +1,9 @@ import torch from torch import Tensor, nn from torch.nn import functional as F # noqa: N812 -from SSLGlacier.models.base_net import Net +from SSLGlacier.models.SWINNET import SwinNet -class UNet(Net): +class UNet(SwinNet): """Pytorch Lightning implementation of U-Net. 
Paper: `U-Net: Convolutional Networks for Biomedical Image Segmentation <https://arxiv.org/abs/1505.04597>`_ diff --git a/SSLGlacier/models/hook/Swin_Transformer_Wrapper.py b/SSLGlacier/models/hook/Swin_Transformer_Wrapper.py index 69e68a9..4f96144 100644 --- a/SSLGlacier/models/hook/Swin_Transformer_Wrapper.py +++ b/SSLGlacier/models/hook/Swin_Transformer_Wrapper.py @@ -1,19 +1,17 @@ # coding=utf-8 -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function import copy import logging - import torch -import torch.nn as nn - +from __future__ import division +from __future__ import print_function +from __future__ import absolute_import +from SSLGlacier.models.SWINNET import SwinNet from SSLGlacier.models.hook.Swin_Transformer import SwinTransformerSys -from SSLGlacier.models.base_net import Net + logger = logging.getLogger(__name__) -class SwinUnet(Net): +class SwinUnet(SwinNet): def __init__(self, num_classes=21843, zero_head=False, **kwargs): super().__init__( groups=1, -- GitLab From a9f5a009a609a22d5a416f8eb8b6442182e2a926 Mon Sep 17 00:00:00 2001 From: Marziyeh <marziyeh.mohammadi@fau.de> Date: Fri, 15 Dec 2023 09:46:31 +0100 Subject: [PATCH 10/11] Refactoring: Make the folder names more clear. --- SSLGlacier/main.py | 4 +- SSLGlacier/modules/agumentations_.py | 411 ++++++++++++++++++++ SSLGlacier/modules/datamodule_.py | 149 +++++++ SSLGlacier/modules/semi_module.py | 29 ++ SSLGlacier/modules/swav_module.py | 562 +++++++++++++++++++++++++++ SSLGlacier/modules/transformers_.py | 248 ++++++++++++ 6 files changed, 1401 insertions(+), 2 deletions(-) create mode 100644 SSLGlacier/modules/agumentations_.py create mode 100644 SSLGlacier/modules/datamodule_.py create mode 100644 SSLGlacier/modules/semi_module.py create mode 100644 SSLGlacier/modules/swav_module.py create mode 100644 SSLGlacier/modules/transformers_.py diff --git a/SSLGlacier/main.py b/SSLGlacier/main.py index 73d9277..cc7d5b4 100644 --- a/SSLGlacier/main.py +++ b/SSLGlacier/main.py @@ -5,12 +5,12 @@ from torchinfo import summary as torchinfo_summery from torchsummary import summary as torch_summary from torch._inductor.config import trace -from processing import datamodule_, transformers_ +from modules import datamodule_, transformers_ import pytorch_lightning as pl from models import ResNet_light_mine, Swav_Orig_ResNet import os import torch -from processing.swav_module import SwAV +from modules.swav_module import SwAV from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator import numpy as np from utils.utils import str2bool, getDataPath diff --git a/SSLGlacier/modules/agumentations_.py b/SSLGlacier/modules/agumentations_.py new file mode 100644 index 0000000..05cb774 --- /dev/null +++ b/SSLGlacier/modules/agumentations_.py @@ -0,0 +1,411 @@ +# PILRandomGaussianBlur and get_color_distortion are used and implemented +# by Swav and SimCLR - https://arxiv.org/abs/2002.05709 +import random +from logging import getLogger + +import torch.nn +from PIL import ImageFilter +import numpy as np +import torchvision.transforms as transforms +import torchvision, tormentor +from torchvision.transforms import functional as F +import torch +import matplotlib.pyplot as plt + +logger = getLogger() + + +class Compose(object): + """ + Class for chaining transforms together + """ + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, image, target): + for t in self.transforms: + image, target = t(image, target) + return [image, target] + + def 
__getitem__(self, item): + return self.transforms[item] + +class DoNothing(torch.nn.Module): + def __init__(self): + super(DoNothing, self).__init__() + + def __call__(self, img, mask): + return [img, mask] + +class Cropper(torch.nn.Module): + def __init__(self, i, j, h, w): + super(Cropper, self).__init__() + self.left = i + self.right = j + self.height = h + self.width = w + + def __call__(self, img, mask): + cropped_img = F.crop(img, self.left, self.right, self.height, self.width) + cropped_mask = F.crop(mask, self.left, self.right, self.height, self.width) + return [cropped_img, cropped_mask] + + +class RandomCropper(torch.nn.Module): + ''' + This function returns one patch at time, if you need more patches in one image, call it in a loop + Args: + orig_img: get an png or jpg image and crop one patch randomly, in both image and mask + orig_mask: This is the images mask, we crop same area from + Returns: + A list of two argument, first cropped patch in image,second cropped patch in mask + ''' + + def __init__(self, size): + super(RandomCropper, self).__init__() + self.left = 0 + self.right = 0 + self.height = 0 + self.width = 0 + self.size = size + + def forward(self, img, mask): + self.left, self.right, self.height, self.width = torchvision.transforms.RandomCrop.get_params( + img, output_size=(self.size, self.size)) + cropped_img = F.crop(img, self.left, self.right, self.height, self.width) + cropped_mask = F.crop(mask, self.left, self.right, self.height, self.width) + return [cropped_img, cropped_mask] + + +class PILRandomGaussianBlur(torch.nn.Module): + def __init__(self, radius_min=0.1, radius_max=2.): + """ + Apply Gaussian Blur to the PIL image. Take the radius and probability of + application as the parameter. + This transform was used in SimCLR - https://arxiv.org/abs/2002.05709 + """ + super(PILRandomGaussianBlur, self).__init__() + self.radius_min = radius_min + self.radius_max = radius_max + + def forward(self, img, mask): + return [img.filter( + ImageFilter.GaussianBlur( + radius=random.uniform(self.radius_min, self.radius_max) + ) + ), mask] + + +class RandomHorizontalFlip(torch.nn.Module): + def __init__(self): + super(RandomHorizontalFlip, self).__init__() + + def forward(self, img, mask): + image = torchvision.transforms.functional.hflip(img) + mask = torchvision.transforms.functional.hflip(mask) + return [image, mask] + + +class GetColorDistortion(torch.nn.Module): + def __int__(self, s=1.0): + super(GetColorDistortion, self).__init__() + self.s = s + + def forward(self, img, mask): + color_jitter = transforms.ColorJitter(0.8 * self.s, 0.8 * self.s, 0.8 * self.s, 0.2 * self.s) + rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8) + rnd_gray = transforms.RandomGrayscale(p=0.2) + color_distort = transforms.Compose([rnd_color_jitter, rnd_gray]) + return color_distort(img, mask) + + +# TODO this net +class GaussNoise(torch.nn.Module): + def __init__(self, mean=0, var=10000): + super(GaussNoise, self).__init__() + self.mean = mean + self.var = var + + def forward(self, img, mask): + row, col = img.size + sigma = self.var ** 0.5 + gauss = np.random.normal(self.mean, sigma, (row, col)) + gauss = gauss.reshape(row, col) + noisy = img + gauss + return [noisy, mask] + + +class SaltPepperNoise(torch.nn.Module): + def __init__(self, salt_or_pepper=0.5): + super(SaltPepperNoise, self).__init__() + self.salt_or_pepper = salt_or_pepper + + def forward(self, img, mask): + if len(img.size) == 3: + img_size = img.size[1] * img.size[2] + else: + img_size = img.size[0] * 
img.size[1] + amount = .4 + noisy = np.copy(img) + # Salt mode + num_salt = np.ceil(amount * img_size * self.salt_or_pepper) + target_pixels = [np.random.randint(0, i - 1, int(num_salt)) + for i in img.size] + target_pixels = list(map(lambda coords: tuple(coords), zip(target_pixels[0], target_pixels[1]))) + for i, j in target_pixels: + noisy[i][j] = 1 + # Pepper mode + num_pepper = np.ceil(amount * img_size * (1. - self.salt_or_pepper)) + target_pixels = [np.random.randint(0, i - 1, int(num_pepper)) + for i in img.size] + target_pixels = list(map(lambda coords: tuple(coords), zip(target_pixels[0], target_pixels[1]))) + + for i, j in target_pixels: + noisy[i][j] = 0 + + # plt.imshow(img) + # plt.imshow(noisy) + # plt.show() + return [noisy, mask] + + +class PoissionNoise(torch.nn.Module): + def __init__(self): + super(PoissionNoise, self).__init__() + + def forward(self, img, mask): + vals = len(np.unique(img)) + vals = 2 ** np.ceil(np.log2(vals)) + noisy = np.random.poisson(img * vals) / float(vals) + return [noisy, mask] + + +# TODO this one +class SpeckleNoise(torch.nn.Module): + def __init__(self): + super(SpeckleNoise, self).__init__() + + def forward(self, img, mask): + row, col = img.size + gauss = np.random.randn(row, col) + gauss = gauss.reshape(row, col) + noisy = img + img * gauss + return [noisy, mask] + + +class SetZeroNoise(torch.nn.Module): + def __init__(self): + super(SetZeroNoise, self).__init__() + + def forward(self, img, mask): + row, col = img.size + img_size = row * col + random_rows = np.random.randint(0, int(row), (1, int(img_size))) + random_cols = np.random.randint(0, int(col), (1, int(img_size))) + target_pixels = list(zip(random_rows[0], random_cols[0])) + for pix in target_pixels: + img[pix[0], pix[1]] = 0 + return [img, mask] + + +#################################################################################### + +###########Base_code Agumentations suggestions: #TODO do not need them now########## +class MyWrap(torch.nn.Module): + """ + Random wrap augmentation taken from tormentor + """ + + def __init__(self): + super(MyWrap, self).__init__() + + def forward(self, img, target): + # This augmentation acts like many simultaneous elastic transforms with gaussian sigmas set at varius harmonics + wrap_rand = tormentor.Wrap.override_distributions(roughness=tormentor.random.Uniform(value_range=(.1, .7)), + intensity=tormentor.random.Uniform(value_range=(.0, 1.))) + wrap = wrap_rand() + image = wrap(img) + mask = wrap(target, is_mask=True) + return [image, mask] + + +class Rotate(torch.nn.Module): + """ + Random rotation augmentation + """ + + def __init__(self): + super(Rotate, self).__init__() + + def forward(self, img, target): + random = np.random.randint(0, 3) + angle = 90 + if random == 1: + angle = 180 + elif random == 2: + angle = 270 + image = torchvision.transforms.functional.rotate(img, angle=angle) + mask = torchvision.transforms.functional.rotate(target, angle=angle) + return [image, mask.squeeze(0)] + + +class Bright(torch.nn.Module): + """ + Random brightness adjustment augmentations + """ + + def __init__(self, + lower_band: float = -0.2, + upper_band: float = 0.2): + super(Bright, self).__init__() + self.lower_band = lower_band + self.upper_band = upper_band + + def forward(self, img, mask): + bright_rand = tormentor.Brightness.override_distributions( + brightness=tormentor.random.Uniform((self.lower_band, self.upper_band))) + # bright = bright_rand() + # image_transformed = img.clone() + image_transformed = bright_rand(np.asarray(img)) + # set NA 
areas back to zero + image_transformed.seed[img == 0] = 0.0 + + return [image_transformed.seed, mask] + + +class Noise(torch.nn.Module): + """ + Random additive noise augmentation + """ + + def __init__(self): + super(Noise, self).__init__() + + def forward(self, img, target): + # add noise. It is a multiplicative gaussian noise so no need to set na areas back to zero again + noise = torch.normal(mean=0, std=0.3, size=img.size) + image = img + img * noise + image[image > 1.0] = 1.0 + image[image < 0.0] = 0.0 + + return [image, target] + + +class Resize(torch.nn.Module): + # Library code + """Resize the input image to the given size. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means a maximum of two leading dimensions + + .. warning:: + The output image might be different depending on its type: when downsampling, the interpolation of PIL images + and tensors is slightly different, because PIL applies antialiasing. This may lead to significant differences + in the performance of a network. Therefore, it is preferable to train and serve a model with the same input + types. See also below the ``antialias`` parameter, which can help making the output of PIL images and tensors + closer. + + Args: + size (sequence or int): Desired output size. If size is a sequence like + (h, w), output size will be matched to this. If size is an int, + smaller edge of the image will be matched to this number. + i.e, if height > width, then image will be rescaled to + (size * height / width, size). + + .. note:: + In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``, + ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. + max_size (int, optional): The maximum allowed for the longer edge of + the resized image. If the longer edge of the image is greater + than ``max_size`` after being resized according to ``size``, + ``size`` will be overruled so that the longer edge is equal to + ``max_size``. + As a result, the smaller edge may be shorter than ``size``. This + is only supported if ``size`` is an int (or a sequence of length + 1 in torchscript mode). + antialias (bool, optional): Whether to apply antialiasing. + It only affects **tensors** with bilinear or bicubic modes and it is + ignored otherwise: on PIL images, antialiasing is always applied on + bilinear or bicubic modes; on other modes (for PIL images and + tensors), antialiasing makes no sense and this parameter is ignored. + Possible values are: + + - ``True``: will apply antialiasing for bilinear or bicubic modes. + Other mode aren't affected. This is probably what you want to use. + - ``False``: will not apply antialiasing for tensors on any mode. PIL + images are still antialiased on bilinear or bicubic modes, because + PIL doesn't support no antialias. + - ``None``: equivalent to ``False`` for tensors and ``True`` for + PIL images. This value exists for legacy reasons and you probably + don't want to use it unless you really know what you are doing. + + The current default is ``None`` **but will change to** ``True`` **in + v0.17** for the PIL and Tensor backends to be consistent. 
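+
+    Example (a sketch; the sizes are illustrative, and note that only the
+    image, not the mask, is resized by this wrapper)::
+
+        resize = Resize(size=224)
+        img_t, mask_t = resize(img, mask)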
+ """ + + def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR, max_size=None, antialias="warn"): + super().__init__() + self.size = size + self.max_size = max_size + self.interpolation = interpolation + self.antialias = antialias + + def forward(self, img, mask): + """ + Args: + img (PIL Image or Tensor): Image to be scaled. + + Returns: + PIL Image or Tensor: Rescaled image. + """ + return [F.resize(img, self.size, self.interpolation, self.max_size, self.antialias), mask] + + +class CenterCrop(torch.nn.Module): + """Crops the given image at the center. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. + If image size is smaller than output size along any edge, image is padded with 0 and then center cropped. + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). + """ + + def __init__(self, size): + super().__init__() + self.size = size # _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + + def forward(self, img, mask): + """ + Args: + img (PIL Image or Tensor): Image to be cropped. + + Returns: + PIL Image or Tensor: Cropped image. + """ + ###ADOPT FROM HOOKFORMER CODE, #todo CHECK IT LATER + W, H = img.size + + WW = (W // self.size) + 2 + HH = (H // self.size) + 2 + + return [F.center_crop(img, (HH * self.size, WW * self.size)), F.center_crop(mask, (HH * self.size, WW * self.size))] + + +class ToTensorZones(): + def __call__(self, image, target): + image = F.to_tensor(np.array(image).astype(np.float32)) + target = torch.from_numpy(np.array(target).astype(np.float32)) + # value for NA area=0, stone=64, glacier=127, ocean with ice melange=254 + target[target == 0] = 0 + target[target == 64] = 1 + target[target == 127] = 2 + target[target == 254] = 3 + # class ids for NA area=0, stone=1, glacier=2, ocean with ice melange=3 + return [image, target] diff --git a/SSLGlacier/modules/datamodule_.py b/SSLGlacier/modules/datamodule_.py new file mode 100644 index 0000000..edf3648 --- /dev/null +++ b/SSLGlacier/modules/datamodule_.py @@ -0,0 +1,149 @@ +from torchvision.transforms import functional as F +from pl_bolts.models.self_supervised import swav +from torch.utils.data import DataLoader +from torch.utils.data import Dataset +import pytorch_lightning as pl +from torchvision import transforms as transform_lib +from . 
import agumentations_ as our_transforms +import torchvision +import numpy as np +from typing import Tuple, List, Any, Callable, Optional +import torch +import cv2 +from PIL import Image,ImageOps +import os + +# TODO add moco, simclr,cpc and amdim transformers +class SSLDataset(Dataset): + def __init__(self, parent_dir, transform,return_index=False, **kwargs): # TODO Pass them as arguments + ''' + Args: + mode: can be tested, train, or validation + parent_dir: directory in which the folders test, train and validation exists + ''' + self.mode = kwargs['mode'] if kwargs['mode'] else 'train' + self.return_index = return_index + self.images_path = os.path.join(parent_dir, "sar_images", self.mode) + self.masks_path = os.path.join(parent_dir, 'zones', self.mode) + self.images = os.listdir(self.images_path) + self.masks = os.listdir(self.masks_path) + assert len(self.masks) == len(self.images), "You don't have the same number of images and masks" + self.transform = transform + # Sort than images and masks fit together + self.images.sort() + self.masks.sort() + + # Let shuffle and save indices, and load them in next calls if they exist + if not os.path.exists(os.path.join("data_processing", "data_splits")): + os.makedirs(os.path.join("data_processing", "data_splits")) + if not os.path.isfile(os.path.join("data_processing", "data_splits", "shuffle_" + ".txt")): + shuffle = np.random.permutation(len(self.images)) + # Works for numpy version >= 1.5 + np.savetxt(os.path.join("data_processing", "data_splits", "shuffle_" + ".txt"), shuffle, + newline=' ') + + else: + # use already existing shuffle + with open(os.path.join("data_processing", "data_splits", "shuffle_" + ".txt"), "rb") as fp: + lines = fp.readlines() + shuffle = [np.fromstring(line, dtype=int, sep=' ') for line in lines] + # if lengths do not match, we need to create a new permutation + if len(shuffle) != len(self.images): + shuffle = np.random.permutation(len(self.images)) + np.savetxt(os.path.join("data_processing", "data_splits", "shuffle_" + ".txt"), shuffle, + newline=' ') + + self.images = np.array(self.images) # Why!!!! + self.masks = np.array(self.masks) + self.images = self.images[shuffle].copy() + self.masks = self.masks[shuffle].copy() + self.images = list(self.images) + self.masks = list(self.masks) # Why!!!! + + def __len__(self): + return len(self.images) + + def __getitem__(self, index): + img_name = self.images[index] + masks_name = self.masks[index] + assert img_name.split('.')[0] == masks_name.split('.')[0].replace("_zones", ""), \ + "image and label name don't match. Image name: " + img_name + ". Label name: " + masks_name + #image = cv2.imread(os.path.join(self.images_path, img_name).__str__(), cv2.IMREAD_GRAYSCALE) + #mask = cv2.imread(os.path.join(self.masks_path, masks_name).__str__(), cv2.IMREAD_GRAYSCALE) + image = Image.open(os.path.join(self.images_path, img_name).__str__()) + image = ImageOps.grayscale(image) + #image = np.array(image).astype(np.float32) + + #rgbimg = Image.new("RGBA", image.size) + #image = rgbimg.paste(image) + #image = ImageOps.grayscale(image) + mask = Image.open(os.path.join(self.masks_path, masks_name).__str__()) + #rgbimg = Image.new("RGBA", mask.size) + #mask = rgbimg.paste(rgbimg) + mask = ImageOps.grayscale(mask) + #mask = np.array(mask).astype(np.float32) + # TODO check if it works or not, it should return multiple transformation for one image + # TODO I crop different parts of the image and mask as part of the batch, how should I handle it!! 
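+        # Note (editor's sketch, assumed from the code rather than stated
+        # upstream): self.transform is expected to follow the contract of
+        # transformers_.OurTrainTransformer, i.e. it maps one (image, mask)
+        # pair to *lists* of augmented crops:
+        #
+        #     crops, crop_masks = self.transform(image, mask)
+        #     # len(crops) == len(crop_masks), one entry per crop/augmentation
+        #
+        # Anything assigned as the dataset transform should keep this shape.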
+        # TODO this is a list of tuples of img and masks, think if you want to change it later
+        if self.transform is not None:
+            augmented_imgs, masks = self.transform(image, mask)
+        else:
+            # fall back to a plain tensor conversion so that callers always get
+            # tensors back (previously nothing was returned in this case)
+            to_tensor = our_transforms.ToTensorZones()
+            augmented_imgs, masks = to_tensor(image, mask)
+        if self.return_index:
+            return index, augmented_imgs, masks, img_name, masks_name
+        return augmented_imgs, masks
+
+
+class SSLDataModule(pl.LightningDataModule):
+    # Base is adopted from Nora's code
+    def __init__(self, batch_size, parent_dir, args):
+        """
+        :param batch_size: batch size
+        :param parent_dir: root directory of the dataset (see SSLDataset)
+        :param args: namespace holding the augmentation and crop settings
+        """
+        super().__init__()
+        self.transforms = None
+        self.batch_size = batch_size
+        self.glacier_test = None
+        self.glacier_train = None
+        self.glacier_val = None
+        self.parent_dir = parent_dir
+        self.aug_args = args.agumentations
+        self.size_crops = args.size_crops
+        self.num_crops = args.nmb_crops
+
+    def prepare_data(self):
+        # download,
+        # only called on 1 GPU/TPU in distributed
+        pass
+
+    def __len__(self):
+        return len(self.glacier_train)
+
+    def setup(self, stage=None):
+        # process and split here
+        # make assignments here (val/train/test split)
+        # called on every process in DDP
+        if stage == 'test' or stage is None:
+            self.transforms = self._default_transforms() if self.test_transforms is None else self.test_transforms
+            # was mode=stage, which is None when setup() runs without a stage
+            self.glacier_test = SSLDataset(parent_dir=self.parent_dir, transform=self.transforms, mode='test')
+        if stage == 'fit' or stage is None:
+            self.transforms = self._default_transforms() if self.train_transforms is None else self.train_transforms
+            self.glacier_train = SSLDataset(parent_dir=self.parent_dir, transform=self.transforms, mode='train')
+
+            self.transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms
+            self.glacier_val = SSLDataset(parent_dir=self.parent_dir, transform=self.transforms, mode='val')
+
+    def train_dataloader(self) -> DataLoader:
+        return DataLoader(self.glacier_train, batch_size=self.batch_size, num_workers=6, pin_memory=True,
+                          drop_last=True)
+
+    def val_dataloader(self) -> DataLoader:
+        return DataLoader(self.glacier_val, batch_size=self.batch_size, num_workers=6, pin_memory=True, drop_last=True)
+
+    def test_dataloader(self) -> DataLoader:
+        return DataLoader(self.glacier_test, batch_size=1, num_workers=6, pin_memory=True,
+                          drop_last=True)  # TODO self.batch_size
+
+    def _default_transforms(self) -> Callable:
+        return our_transforms.Compose([our_transforms.ToTensorZones()])
diff --git a/SSLGlacier/modules/semi_module.py b/SSLGlacier/modules/semi_module.py
new file mode 100644
index 0000000..9798a36
--- /dev/null
+++ b/SSLGlacier/modules/semi_module.py
@@ -0,0 +1,29 @@
+from pytorch_lightning import LightningModule
+from typing import Any
+
+
+class Semi_(LightningModule):
+    def __init__(self):
+        super().__init__()  # was `pass`; LightningModule.__init__ must run
+
+    def setup(self, stage: str) -> None:
+        pass
+
+    def init_model(self):
+        pass
+
+    def forward(self, *args: Any, **kwargs: Any) -> Any:
+        pass
+
+    def on_train_epoch_start(self) -> None:
+        pass
+
+    def shared_step(self, batch):
+        inputs, y = batch
+        inputs = inputs[:, -1]
+        embedding = self.model(inputs)
+        #### Where should I add noise....
+        #### assume embedding[b][i] belongs to batch b, view i
+        #### if we have two views, one with noise and one without, then
+        ##### select some portion of each view belonging to different classes.
+        ##### pos_set = same classes ..
+ ##### neg_set = diff classes + diff --git a/SSLGlacier/modules/swav_module.py b/SSLGlacier/modules/swav_module.py new file mode 100644 index 0000000..3b57a4e --- /dev/null +++ b/SSLGlacier/modules/swav_module.py @@ -0,0 +1,562 @@ +"""Adapted from official swav implementation: https://github.com/facebookresearch/swav.""" +import os +from argparse import ArgumentParser + +import torch +from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint +from torch import nn + +#from pl_bolts.models.self_supervised.swav.loss import SWAVLoss +from SSLGlacier.models.losses import SWAVLoss +from SSLGlacier.models.RESNET import resnet18, resnet50 +#from Base_Code.models.ParentUNet import UNet +from SSLGlacier.models.UNET import UNet #The main network +from pl_bolts.optimizers.lars import LARS +from pl_bolts.optimizers.lr_scheduler import linear_warmup_decay + +##########Importing files for hook############ +from SSLGlacier.models.hook.Swin_Transformer_Wrapper import SwinUnet +######Temporary config file for hooknet####### +from utils.config import get_config + +class SwAV(LightningModule): + def __init__( + self, + gpus: int, + num_samples: int, + batch_size: int, + dataset: str, + num_nodes: int = 1, + arch: str = "resnet50", + hidden_mlp: int = 2048, + feat_dim: int = 128, + warmup_epochs: int = 10, + max_epochs: int = 100, + num_prototypes: int = 3000, + freeze_prototypes_epochs: int = 1, + temperature: float = 0.1, + sinkhorn_iterations: int = 3, + queue_length: int = 0, # must be divisible by total batch-size + queue_path: str = "queue", + epoch_queue_starts: int = 15, + crops_for_assign: tuple = (0, 1), + num_crops: tuple = (2, 6), + num_augs: int= 2, + first_conv: bool = True, + maxpool1: bool = True, + optimizer: str = "adam", + exclude_bn_bias: bool = False, + start_lr: float = 0.0, + learning_rate: float = 1e-3, + final_lr: float = 0.0, + weight_decay: float = 1e-6, + epsilon: float = 0.05, + just_aug_for_same_assign_views: bool = False, + swin_hparams = {} + ) -> None: + """ + Args: + gpus: number of gpus per node used in training, passed to SwAV module + to manage the queue and select distributed sinkhorn + num_nodes: number of nodes to train on + num_samples: number of image samples used for training + batch_size: batch size per GPU in ddp + dataset: dataset being used for train/val + arch: encoder architecture used for pre-training + hidden_mlp: hidden layer of non-linear projection head, set to 0 + to use a linear projection head + feat_dim: output dim of the projection head + warmup_epochs: apply linear warmup for this many epochs + max_epochs: epoch count for pre-training + num_prototypes: count of prototype vectors + freeze_prototypes_epochs: epoch till which gradients of prototype layer + are frozen + temperature: loss temperature + sinkhorn_iterations: iterations for sinkhorn normalization + queue_length: set queue when batch size is small, + must be divisible by total batch-size (i.e. 
total_gpus * batch_size), + set to 0 to remove the queue + queue_path: folder within the logs directory + epoch_queue_starts: start uing the queue after this epoch + crops_for_assign: list of crop ids for computing assignment + num_crops: number of global and local crops, ex: [2, 6] + first_conv: keep first conv same as the original resnet architecture, + if set to false it is replace by a kernel 3, stride 1 conv (cifar-10) + maxpool1: keep first maxpool layer same as the original resnet architecture, + if set to false, first maxpool is turned off (cifar10, maybe stl10) + optimizer: optimizer to use + exclude_bn_bias: exclude batchnorm and bias layers from weight decay in optimizers + start_lr: starting lr for linear warmup + learning_rate: learning rate + final_lr: float = final learning rate for cosine weight decay + weight_decay: weight decay for optimizer + epsilon: epsilon val for swav assignments + """ + super().__init__() + self.save_hyperparameters() + + self.gpus = gpus + self.num_nodes = num_nodes + self.arch = arch + self.dataset = dataset + self.num_samples = num_samples + self.batch_size = batch_size + + self.hidden_mlp = hidden_mlp + self.feat_dim = feat_dim + self.num_prototypes = num_prototypes + self.freeze_prototypes_epochs = freeze_prototypes_epochs + self.sinkhorn_iterations = sinkhorn_iterations + + self.queue_length = queue_length + self.queue_path = queue_path + self.epoch_queue_starts = epoch_queue_starts + self.crops_for_assign = crops_for_assign + self.num_crops = num_crops + self.num_augs = num_augs + + self.first_conv = first_conv + self.maxpool1 = maxpool1 + + self.optim = optimizer + self.exclude_bn_bias = exclude_bn_bias + self.weight_decay = weight_decay + self.epsilon = epsilon + self.temperature = temperature + + self.start_lr = start_lr + self.final_lr = final_lr + self.learning_rate = learning_rate + self.warmup_epochs = warmup_epochs + self.max_epochs = max_epochs + self.just_aug_for_same_assign_views = just_aug_for_same_assign_views + + self.model = self.init_model() + self.criterion = SWAVLoss( + gpus=self.gpus, + num_nodes=self.num_nodes, + temperature=self.temperature, + crops_for_assign=self.crops_for_assign, + num_crops=self.num_crops, + num_augs=self.num_augs, + sinkhorn_iterations=self.sinkhorn_iterations, + epsilon=self.epsilon, + just_aug_for_same_assign_views= self.just_aug_for_same_assign_views + ) + self.use_the_queue = None + # compute iters per epoch + global_batch_size = self.num_nodes * self.gpus * self.batch_size if self.gpus > 0 else self.batch_size + self.train_iters_per_epoch = self.num_samples // global_batch_size + self.queue = None + + ####For hook#### + self.swin_hparams = swin_hparams + def setup(self, stage): + if self.queue_length > 0: + queue_folder = os.path.join(self.logger.log_dir, self.queue_path) + if not os.path.exists(queue_folder): + os.makedirs(queue_folder) + + self.queue_path = os.path.join(queue_folder, "queue" + str(self.trainer.global_rank) + ".pth") + + if os.path.isfile(self.queue_path): + self.queue = torch.load(self.queue_path)["queue"] + + def init_model(self): + if self.arch == "resnet18": + backbone = resnet18(normalize=True, + hidden_mlp=self.hidden_mlp, + output_dim=self.feat_dim, + num_prototypes=self.num_prototypes, + first_conv=self.first_conv, + maxpool1=self.maxpool1) + + elif self.arch == "resnet50": + backbone = resnet50(normalize=True, + hidden_mlp=self.hidden_mlp, + output_dim=self.feat_dim, + num_prototypes=self.num_prototypes, + first_conv=self.first_conv, + maxpool1=self.maxpool1) + elif 
self.arch == "hook": + backbone = SwinUnet( + img_size=224, + num_classes=5, + normalize=True, + hidden_mlp=self.hidden_mlp, + output_dim=self.feat_dim, + num_prototypes=self.num_prototypes, + first_conv=self.first_conv, + maxpool1=self.maxpool1,######till here for basenet + image_size= self.swin_hparams.image_size,####from here for swin itself + swin_patch_size = self.swin_hparams.swin_patch_size, + swin_in_chans=self.swin_hparams.swin_in_chans, + swin_embed_dim=self.swin_hparams.swin_embed_dim, + swin_depths=self.swin_hparams.swin_depths, + swin_num_heads=self.swin_hparams.swin_num_heads, + swin_window_size=self.swin_hparams.swin_window_size, + swin_mlp_ratio=self.swin_hparams.swin_mlp_ratio, + swin_QKV_BIAS=self.swin_hparams.swin_QKV_BIAS, + swin_QK_SCALE=self.swin_hparams.swin_QK_SCALE, + drop_rate=self.drop_rate, + drop_path_rate=self.drop_path_rate, + swin_ape=self.swin_hparams.swin_ape, + swin_patch_norm=self.swin_hparams.swin_path_norm + ) + + elif self.arch == "Unet" or "unet": + backbone = UNet(num_classes=5, + input_channels=1, + num_layers=5, + features_start=64, + bilinear=False, + normalize=True, + hidden_mlp=self.hidden_mlp, + output_dim=self.feat_dim, + num_prototypes=self.num_prototypes, + first_conv=self.first_conv, + maxpool1=self.maxpool1 + ) + else: + raise f'{self.arch} model is not defined' + + return backbone + + def forward(self, x): + # pass single batch from the resnet backbone + return self.model.forward_backbone(x) + + def on_train_epoch_start(self): + if self.queue_length > 0: + if self.trainer.current_epoch >= self.epoch_queue_starts and self.queue is None: + self.queue = torch.zeros( + len(self.crops_for_assign), + self.queue_length // self.gpus, # change to nodes * gpus once multi-node + self.feat_dim, + ) + + if self.queue is not None: + self.queue = self.queue.to(self.device) + + self.use_the_queue = False + + def on_train_epoch_end(self) -> None: + if self.queue is not None: + torch.save({"queue": self.queue}, self.queue_path) + + def on_after_backward(self): + if self.current_epoch < self.freeze_prototypes_epochs: + for name, p in self.model.named_parameters(): + if "prototypes" in name: + p.grad = None + + def shared_step(self, batch): + if self.dataset == "stl10": + unlabeled_batch = batch[0] + batch = unlabeled_batch + + inputs, y = batch + inputs = inputs[:-1] # remove online train/eval transforms at this point + + # 1. normalize the prototypes + with torch.no_grad(): + w = self.model.prototypes.weight.data.clone() + w = nn.functional.normalize(w, dim=1, p=2) + self.model.prototypes.weight.copy_(w) + + # 2. 
multi-res forward passes + embedding, output = self.model(inputs) + embedding = embedding.detach() + bs = inputs[0].size(0) + + # SWAV loss computation + loss, queue, use_queue = self.criterion( + output=output, + embedding=embedding, + prototype_weights=self.model.prototypes.weight, + batch_size=bs, + queue=self.queue, + use_queue=self.use_the_queue, + ) + self.queue = queue + self.use_the_queue = use_queue + return loss + + def training_step(self, batch, batch_idx): + loss = self.shared_step(batch) + + self.log("train_loss", loss, on_step=True, on_epoch=False) + return loss + + def validation_step(self, batch, batch_idx): + loss = self.shared_step(batch) + + self.log("val_loss", loss, on_step=False, on_epoch=True) + return loss + + def exclude_from_wt_decay(self, named_params, weight_decay, skip_list=("bias", "bn")): + params = [] + excluded_params = [] + + for name, param in named_params: + if not param.requires_grad: + continue + if any(layer_name in name for layer_name in skip_list): + excluded_params.append(param) + else: + params.append(param) + + return [{"params": params, "weight_decay": weight_decay}, {"params": excluded_params, "weight_decay": 0.0}] + + def configure_optimizers(self): + if self.exclude_bn_bias: + params = self.exclude_from_wt_decay(self.named_parameters(), weight_decay=self.weight_decay) + else: + params = self.parameters() + + if self.optim == "lars": + optimizer = LARS( + params, + lr=self.learning_rate, + momentum=0.9, + weight_decay=self.weight_decay, + trust_coefficient=0.001, + ) + elif self.optim == "adam": + optimizer = torch.optim.Adam(params, lr=self.learning_rate, weight_decay=self.weight_decay) + + warmup_steps = self.train_iters_per_epoch * self.warmup_epochs + total_steps = self.train_iters_per_epoch * self.max_epochs + + scheduler = { + "scheduler": torch.optim.lr_scheduler.LambdaLR( + optimizer, + linear_warmup_decay(warmup_steps, total_steps, cosine=True), + ), + "interval": "step", + "frequency": 1, + } + + return [optimizer], [scheduler] + + @staticmethod + def add_model_specific_args(parent_parser): + parser = ArgumentParser(parents=[parent_parser], add_help=False) + + # model params + parser.add_argument("--arch", default="resnet50", type=str, help="convnet architecture") + # specify flags to store false + parser.add_argument("--first_conv", action="store_false") + parser.add_argument("--maxpool1", action="store_false") + parser.add_argument("--hidden_mlp", default=2048, type=int, help="hidden layer dimension in projection head") + parser.add_argument("--feat_dim", default=128, type=int, help="feature dimension") + parser.add_argument("--online_ft", action="store_true") + parser.add_argument("--fp32", action="store_true") + + # transform params + parser.add_argument("--gaussian_blur", action="store_true", help="add gaussian blur") + parser.add_argument("--jitter_strength", type=float, default=1.0, help="jitter strength") + parser.add_argument("--dataset", type=str, default="stl10", help="stl10, cifar10") + parser.add_argument("--data_dir", type=str, default=".", help="path to download data") + parser.add_argument("--queue_path", type=str, default="queue", help="path for queue") + + parser.add_argument( + "--num_crops", type=int, default=[2, 4], nargs="+", help="list of number of crops (example: [2, 6])" + ) + parser.add_argument( + "--size_crops", type=int, default=[96, 36], nargs="+", help="crops resolutions (example: [224, 96])" + ) + parser.add_argument( + "--min_scale_crops", + type=float, + default=[0.33, 0.10], + nargs="+", + 
help="argument in RandomResizedCrop (example: [0.14, 0.05])", + ) + parser.add_argument( + "--max_scale_crops", + type=float, + default=[1, 0.33], + nargs="+", + help="argument in RandomResizedCrop (example: [1., 0.14])", + ) + + # training params + parser.add_argument("--fast_dev_run", default=1, type=int) + parser.add_argument("--num_nodes", default=1, type=int, help="number of nodes for training") + parser.add_argument("--gpus", default=1, type=int, help="number of gpus to train on") + parser.add_argument("--num_workers", default=8, type=int, help="num of workers per GPU") + parser.add_argument("--optimizer", default="adam", type=str, help="choose between adam/lars") + parser.add_argument("--exclude_bn_bias", action="store_true", help="exclude bn/bias from weight decay") + parser.add_argument("--max_epochs", default=100, type=int, help="number of total epochs to run") + parser.add_argument("--max_steps", default=-1, type=int, help="max steps") + parser.add_argument("--warmup_epochs", default=10, type=int, help="number of warmup epochs") + parser.add_argument("--batch_size", default=128, type=int, help="batch size per gpu") + + parser.add_argument("--weight_decay", default=1e-6, type=float, help="weight decay") + parser.add_argument("--learning_rate", default=1e-3, type=float, help="base learning rate") + parser.add_argument("--start_lr", default=0, type=float, help="initial warmup learning rate") + parser.add_argument("--final_lr", type=float, default=1e-6, help="final learning rate") + + # swav params + parser.add_argument( + "--crops_for_assign", + type=int, + nargs="+", + default=[0, 1], + help="list of crops id used for computing assignments", + ) + parser.add_argument("--temperature", default=0.1, type=float, help="temperature parameter in training loss") + parser.add_argument( + "--epsilon", default=0.05, type=float, help="regularization parameter for Sinkhorn-Knopp algorithm" + ) + parser.add_argument( + "--sinkhorn_iterations", default=3, type=int, help="number of iterations in Sinkhorn-Knopp algorithm" + ) + parser.add_argument("--num_prototypes", default=512, type=int, help="number of prototypes") + parser.add_argument( + "--queue_length", + type=int, + default=0, + help="length of the queue (0 for no queue); must be divisible by total batch size", + ) + parser.add_argument( + "--epoch_queue_starts", type=int, default=15, help="from this epoch, we start using a queue" + ) + parser.add_argument( + "--freeze_prototypes_epochs", + default=1, + type=int, + help="freeze the prototypes during this many epochs from the start", + ) + + return parser + + +def cli_main(): + from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator + from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule + from pl_bolts.transforms.self_supervised.swav_transforms import SwAVEvalDataTransform, SwAVTrainDataTransform + + parser = ArgumentParser() + + # model args + parser = SwAV.add_model_specific_args(parser) + args = parser.parse_args() + + if args.dataset == "stl10": + dm = STL10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers) + + dm.train_dataloader = dm.train_dataloader_mixed + dm.val_dataloader = dm.val_dataloader_mixed + args.num_samples = dm.num_unlabeled_samples + + args.maxpool1 = False + + normalization = stl10_normalization() + elif args.dataset == "cifar10": + args.batch_size = 2 + args.num_workers = 0 + + dm = CIFAR10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers) + + 
args.num_samples = dm.num_samples + + args.maxpool1 = False + args.first_conv = False + + normalization = cifar10_normalization() + + # cifar10 specific params + args.size_crops = [32, 16] + args.num_crops = [2, 1] + args.gaussian_blur = False + elif args.dataset == "imagenet": + args.maxpool1 = True + args.first_conv = True + normalization = imagenet_normalization() + + args.size_crops = [224, 96] + args.num_crops = [2, 6] + args.min_scale_crops = [0.14, 0.05] + args.max_scale_crops = [1.0, 0.14] + args.gaussian_blur = True + args.jitter_strength = 1.0 + + args.batch_size = 64 + args.num_nodes = 8 + args.gpus = 8 # per-node + args.max_epochs = 800 + + args.optimizer = "lars" + args.learning_rate = 4.8 + args.final_lr = 0.0048 + args.start_lr = 0.3 + + args.num_prototypes = 3000 + args.online_ft = True + + dm = ImagenetDataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers) + + args.num_samples = dm.num_samples + args.input_height = dm.dims[-1] + else: + raise NotImplementedError("other datasets have not been implemented till now") + + dm.train_transforms = SwAVTrainDataTransform( + normalize=normalization, + size_crops=args.size_crops, + num_crops=args.num_crops, + min_scale_crops=args.min_scale_crops, + max_scale_crops=args.max_scale_crops, + gaussian_blur=args.gaussian_blur, + jitter_strength=args.jitter_strength, + ) + + dm.val_transforms = SwAVEvalDataTransform( + normalize=normalization, + size_crops=args.size_crops, + num_crops=args.num_crops, + min_scale_crops=args.min_scale_crops, + max_scale_crops=args.max_scale_crops, + gaussian_blur=args.gaussian_blur, + jitter_strength=args.jitter_strength, + ) + + # swav model init + model = SwAV(**args.__dict__) + + online_evaluator = None + if args.online_ft: + # online eval + online_evaluator = SSLOnlineEvaluator( + drop_p=0.0, + hidden_dim=None, + z_dim=args.hidden_mlp, + num_classes=dm.num_classes, + dataset=args.dataset, + ) + + lr_monitor = LearningRateMonitor(logging_interval="step") + model_checkpoint = ModelCheckpoint(save_last=True, save_top_k=1, monitor="val_loss") + callbacks = [model_checkpoint, online_evaluator] if args.online_ft else [model_checkpoint] + callbacks.append(lr_monitor) + + trainer = Trainer( + max_epochs=args.max_epochs, + max_steps=None if args.max_steps == -1 else args.max_steps, + gpus=args.gpus, + num_nodes=args.num_nodes, + accelerator="ddp" if args.gpus > 1 else None, + sync_batchnorm=args.gpus > 1, + precision=32 if args.fp32 else 16, + callbacks=callbacks, + fast_dev_run=args.fast_dev_run, + ) + + trainer.fit(model, datamodule=dm) + + +if __name__ == "__main__": + cli_main() \ No newline at end of file diff --git a/SSLGlacier/modules/transformers_.py b/SSLGlacier/modules/transformers_.py new file mode 100644 index 0000000..f8290bb --- /dev/null +++ b/SSLGlacier/modules/transformers_.py @@ -0,0 +1,248 @@ +from pl_bolts.models.self_supervised import swav +from torchvision import transforms # need it for validation and fine tuning +from . 
import agumentations_ as ag
+from typing import Tuple, List, Dict
+from torch import Tensor
+
+
+class OurTrainTransformer():
+    def __init__(self,
+                 size_crops: Tuple[int] = (294, 94),
+                 nmb_crops: Tuple[int] = (2, 4),
+                 augs: Dict = {'SaltPepper': 1},
+                 use_hook_former: bool = True) -> None:
+        # TODO set it up so that every augmentation method contributes at least one sample
+        self.size_crops = size_crops
+        self.num_crops = nmb_crops
+        augmented_imgs = []
+        self.use_hook_former = use_hook_former
+        self.transform = []
+        # augs.get(...) instead of augs[...]: the default dict only carries the
+        # 'SaltPepper' key, so plain indexing raised KeyError for all other flags
+        if augs.get('OrigPixelValues'):
+            augmented_imgs.append(ag.DoNothing())
+        if augs.get('RGaussianB'):
+            augmented_imgs.append(ag.PILRandomGaussianBlur(radius_min=0.1, radius_max=2.))
+        if augs.get('GaussiN'):
+            augmented_imgs.append(ag.GaussNoise(mean=0, var=10000))
+        if augs.get('SaltPepper'):
+            augmented_imgs.append(ag.SaltPepperNoise(salt_or_pepper=0.2))
+        if augs.get('ZeroN'):
+            augmented_imgs.append(ag.SetZeroNoise())
+        # Nora's augmentation methods
+        if augs.get('flip'):
+            # TODO manage this, how to add a mask for the times we need the transformation for masks too
+            augmented_imgs.append(ag.RandomHorizontalFlip())  # was ag.RandomHorizentalFlip (typo)
+        if augs.get('rotate'):
+            augmented_imgs.append(ag.Rotate())
+        if augs.get('bright'):
+            augmented_imgs.append(ag.Bright())
+        if augs.get('wrap'):
+            augmented_imgs.append(ag.MyWrap())  # append() was called without an argument
+        if augs.get('noise'):
+            augmented_imgs.append(ag.Noise())
+
+        if augs.get('normalize') is not None:
+            # was ag.Compose([augmented_imgs.ToTensor(), ...]), which called a
+            # method that does not exist on a list
+            self.final_transform = ag.Compose([ag.ToTensorZones(), augs['normalize']])
+        else:
+            self.final_transform = ag.ToTensorZones()
+
+        if self.use_hook_former:
+            self.hook_former_transformer(augmented_imgs)
+        else:
+            self.other_networks_transformers(augmented_imgs)
+
+    def __call__(self, image, mask):
+        transformed_img_mask = []
+        for transform in self.transform:
+            if isinstance(transform[0], ag.RandomCropper):
+                transformed_img_mask.append(transform(image, mask))
+                # remember the crop window so the fixed Cropper instances reuse it
+                self.i = transform[0].left
+                self.j = transform[0].right
+                self.w = transform[0].width
+                self.h = transform[0].height
+            elif isinstance(transform[0], ag.Cropper):
+                transform[0].left = self.i
+                transform[0].right = self.j
+                transform[0].width = self.w
+                transform[0].height = self.h
+                transformed_img_mask.append(transform(image, mask))
+            else:
+                transformed_img_mask.append(transform(image, mask))
+        transformed_imgs = [item[0] for item in transformed_img_mask]
+        transformed_masks = [item[1] for item in transformed_img_mask]
+        return transformed_imgs, transformed_masks
+
+    def hook_former_transformer(self, augmented_imgs):
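+        # Sketch of what this method builds (assumed from the code below, not
+        # from any upstream docs): one Compose pipeline per (crop size, aug)
+        # pair, where the first augmentation of each size gets a fresh random
+        # crop and the remaining ones take the centre at a quarter size:
+        #
+        #   transform ~ [Compose([RandomCropper(s), aug_0, ToTensorZones()]),
+        #                Compose([CenterCrop(s // 4), aug_1, ToTensorZones()]),
+        #                ...]            # repeated num_crops[i] times per size s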
+        transform = []
+        for i in range(len(self.size_crops)):
+            i_transform = []
+            for ith_aug, aug in enumerate(augmented_imgs):
+                # The first augmentation of each size gets a fresh random crop (global view)
+                if ith_aug == 0:
+                    global_crop = ag.RandomCropper(self.size_crops[i])
+                    i_transform.extend(
+                        [ag.Compose([global_crop] + [aug] + [ag.ToTensorZones()])]
+                    )
+                else:  # take the centre context at a quarter of the crop size
+                    # CenterCrop takes a single positional size; the old call used a
+                    # non-existent `center_crop_size` keyword and a float size
+                    local_crop = ag.CenterCrop(self.size_crops[i] // 4)
+                    i_transform.extend(
+                        [ag.Compose([local_crop] + [aug] + [ag.ToTensorZones()])])
+            transform += i_transform * self.num_crops[i]
+        self.transform = transform
+        # add an online train transform of the size of the global view
+        online_train_transform = ag.Compose(
+            [ag.RandomCropper(self.size_crops[i]), ag.RandomHorizontalFlip(), self.final_transform]
+        )
+
+        self.transform.append(online_train_transform)
+        return transform
+
+    def other_networks_transformers(self, augmented_imgs):
+        transform = []
+        for i in range(len(self.size_crops)):
+            i_transform = []
+            for ith_aug, aug in enumerate(augmented_imgs):
+                # For the first augmentation we crop randomly, then save the coordinates for the further transformations
+                if ith_aug == 0:
+                    random_crop = ag.RandomCropper(self.size_crops[i])
+                    i_transform.extend(
+                        [ag.Compose([random_crop] + [aug] + [ag.ToTensorZones()])]
+                    )
+                else:  # Crop the same area, and do the augmentation
+                    fixed_crop = ag.Cropper(random_crop.left, random_crop.right, random_crop.height, random_crop.width)
+                    i_transform.extend(
+                        [ag.Compose([fixed_crop] + [aug] + [ag.ToTensorZones()])])
+            transform += i_transform * self.num_crops[i]
+
+        self.transform = transform
+        # add an online train transform of the size of the global view
+        # TODO: clarify why this extra online transform is appended
+        online_train_transform = ag.Compose(
+            [ag.RandomCropper(self.size_crops[i]), ag.RandomHorizontalFlip(), self.final_transform]
+        )
+
+        self.transform.append(online_train_transform)
+        return transform
+
+
+class OurEvalTransformer(OurTrainTransformer):
+    def __init__(self,
+                 size_crops: Tuple[int] = (294, 294),
+                 nmb_crops: Tuple[int] = (2, 4),
+                 augs: Dict = {'SaltPepper': 1},
+                 use_hook_former: bool = True) -> None:
+        super().__init__(size_crops=size_crops,
+                         nmb_crops=nmb_crops,
+                         augs=augs,
+                         use_hook_former=use_hook_former)
+
+        input_height = self.size_crops[0]  # get the global view crop
+        test_transform = ag.Compose(
+            [
+                ag.Resize(int(input_height + 0.1 * input_height)),
+                ag.CenterCrop(input_height),
+                self.final_transform,
+            ]
+        )
+
+        # replace the last transform in self.transform with the eval transform
+        self.transform[-1] = test_transform
+
+
+class OurFinetuneTransform:
+    def __init__(self,
+                 input_height: int = 224,
+                 normalize=None,
+                 eval_transform: bool = False) -> None:
+        # build the closing tensor conversion first so both branches can use it
+        # (it was previously referenced before being defined)
+        if normalize is None:
+            final_transform = transforms.ToTensor()
+        else:
+            final_transform = transforms.Compose([transforms.ToTensor(), normalize])
+
+        if not eval_transform:  # TODO think about it
+            data_transforms = [
+                transforms.RandomResizedCrop(size=input_height),  # was self.input_height, which is never set
+                transforms.RandomHorizontalFlip(),
+                # TODO: ag.SaltPepperNoise expects an (img, mask) pair and needs an
+                # adapter before it can run inside this single-image pipeline:
+                # ag.SaltPepperNoise(salt_or_pepper=0.5),
+                transforms.RandomGrayscale(p=0.2),
+            ]
+        else:
+            data_transforms = [
+                transforms.Resize(int(input_height + 0.1 * input_height)),
+                transforms.CenterCrop(input_height),
+            ]
+
+        data_transforms.append(final_transform)
+        self.transform = transforms.Compose(data_transforms)
+
+    def __call__(self, img: Tensor) -> Tensor:
+        return self.transform(img)
+
+
+# TODO use it in the code with a flag
+class SwapDefaultTransformer():
+    def __init__(
+ self, + size_crops: Tuple[int] = (96, 36), + nmb_crops: Tuple[int] = (2, 4), + min_scale_crops: Tuple[float] = (0.33, 0.10), + max_scale_crops: Tuple[float] = (1, 0.33), + gaussian_blur: bool = True, + jitter_strength: float = 1.0, + ) -> object: + + self.size_crops = size_crops + self.num_crops = nmb_crops + self.min_scale_crops = min_scale_crops + self.max_scale_crops = max_scale_crops + self.gaussian_blur = gaussian_blur + self.jitter_strength = jitter_strength + + def get(self, mode): + if mode == 'train': + return swav.SwAVTrainDataTransform( + #normalize=self.normalization(), + size_crops=self.size_crops, + num_crops=self.num_crops, + min_scale_crops=self.min_scale_crops, + max_scale_crops=self.max_scale_crops, + gaussian_blur=self.gaussian_blur, + jitter_strength=self.jitter_strength + ) + elif mode == 'val': + return swav.SwAVEvalDataTransform( + #normalize=self.normalization(), + size_crops=self.size_crops, + num_crops=self.num_crops, + min_scale_crops=self.min_scale_crops, + max_scale_crops=self.max_scale_crops, + gaussian_blur=self.gaussian_blur, + jitter_strength=self.jitter_strength + ) + -- GitLab From 5e93adb360abfc69d12690c5b400245b947b8bb0 Mon Sep 17 00:00:00 2001 From: Marziyeh <marziyeh.mohammadi@fau.de> Date: Fri, 15 Dec 2023 09:46:48 +0100 Subject: [PATCH 11/11] Refactoring: Make the folder names more clear. --- SSLGlacier/processing/agumentations_.py | 411 ----------------- SSLGlacier/processing/datamodule_.py | 149 ------- SSLGlacier/processing/semi_module.py | 29 -- SSLGlacier/processing/swav_module.py | 562 ------------------------ SSLGlacier/processing/transformers_.py | 248 ----------- 5 files changed, 1399 deletions(-) delete mode 100644 SSLGlacier/processing/agumentations_.py delete mode 100644 SSLGlacier/processing/datamodule_.py delete mode 100644 SSLGlacier/processing/semi_module.py delete mode 100644 SSLGlacier/processing/swav_module.py delete mode 100644 SSLGlacier/processing/transformers_.py diff --git a/SSLGlacier/processing/agumentations_.py b/SSLGlacier/processing/agumentations_.py deleted file mode 100644 index 05cb774..0000000 --- a/SSLGlacier/processing/agumentations_.py +++ /dev/null @@ -1,411 +0,0 @@ -# PILRandomGaussianBlur and get_color_distortion are used and implemented -# by Swav and SimCLR - https://arxiv.org/abs/2002.05709 -import random -from logging import getLogger - -import torch.nn -from PIL import ImageFilter -import numpy as np -import torchvision.transforms as transforms -import torchvision, tormentor -from torchvision.transforms import functional as F -import torch -import matplotlib.pyplot as plt - -logger = getLogger() - - -class Compose(object): - """ - Class for chaining transforms together - """ - - def __init__(self, transforms): - self.transforms = transforms - - def __call__(self, image, target): - for t in self.transforms: - image, target = t(image, target) - return [image, target] - - def __getitem__(self, item): - return self.transforms[item] - -class DoNothing(torch.nn.Module): - def __init__(self): - super(DoNothing, self).__init__() - - def __call__(self, img, mask): - return [img, mask] - -class Cropper(torch.nn.Module): - def __init__(self, i, j, h, w): - super(Cropper, self).__init__() - self.left = i - self.right = j - self.height = h - self.width = w - - def __call__(self, img, mask): - cropped_img = F.crop(img, self.left, self.right, self.height, self.width) - cropped_mask = F.crop(mask, self.left, self.right, self.height, self.width) - return [cropped_img, cropped_mask] - - -class 
RandomCropper(torch.nn.Module): - ''' - This function returns one patch at time, if you need more patches in one image, call it in a loop - Args: - orig_img: get an png or jpg image and crop one patch randomly, in both image and mask - orig_mask: This is the images mask, we crop same area from - Returns: - A list of two argument, first cropped patch in image,second cropped patch in mask - ''' - - def __init__(self, size): - super(RandomCropper, self).__init__() - self.left = 0 - self.right = 0 - self.height = 0 - self.width = 0 - self.size = size - - def forward(self, img, mask): - self.left, self.right, self.height, self.width = torchvision.transforms.RandomCrop.get_params( - img, output_size=(self.size, self.size)) - cropped_img = F.crop(img, self.left, self.right, self.height, self.width) - cropped_mask = F.crop(mask, self.left, self.right, self.height, self.width) - return [cropped_img, cropped_mask] - - -class PILRandomGaussianBlur(torch.nn.Module): - def __init__(self, radius_min=0.1, radius_max=2.): - """ - Apply Gaussian Blur to the PIL image. Take the radius and probability of - application as the parameter. - This transform was used in SimCLR - https://arxiv.org/abs/2002.05709 - """ - super(PILRandomGaussianBlur, self).__init__() - self.radius_min = radius_min - self.radius_max = radius_max - - def forward(self, img, mask): - return [img.filter( - ImageFilter.GaussianBlur( - radius=random.uniform(self.radius_min, self.radius_max) - ) - ), mask] - - -class RandomHorizontalFlip(torch.nn.Module): - def __init__(self): - super(RandomHorizontalFlip, self).__init__() - - def forward(self, img, mask): - image = torchvision.transforms.functional.hflip(img) - mask = torchvision.transforms.functional.hflip(mask) - return [image, mask] - - -class GetColorDistortion(torch.nn.Module): - def __int__(self, s=1.0): - super(GetColorDistortion, self).__init__() - self.s = s - - def forward(self, img, mask): - color_jitter = transforms.ColorJitter(0.8 * self.s, 0.8 * self.s, 0.8 * self.s, 0.2 * self.s) - rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8) - rnd_gray = transforms.RandomGrayscale(p=0.2) - color_distort = transforms.Compose([rnd_color_jitter, rnd_gray]) - return color_distort(img, mask) - - -# TODO this net -class GaussNoise(torch.nn.Module): - def __init__(self, mean=0, var=10000): - super(GaussNoise, self).__init__() - self.mean = mean - self.var = var - - def forward(self, img, mask): - row, col = img.size - sigma = self.var ** 0.5 - gauss = np.random.normal(self.mean, sigma, (row, col)) - gauss = gauss.reshape(row, col) - noisy = img + gauss - return [noisy, mask] - - -class SaltPepperNoise(torch.nn.Module): - def __init__(self, salt_or_pepper=0.5): - super(SaltPepperNoise, self).__init__() - self.salt_or_pepper = salt_or_pepper - - def forward(self, img, mask): - if len(img.size) == 3: - img_size = img.size[1] * img.size[2] - else: - img_size = img.size[0] * img.size[1] - amount = .4 - noisy = np.copy(img) - # Salt mode - num_salt = np.ceil(amount * img_size * self.salt_or_pepper) - target_pixels = [np.random.randint(0, i - 1, int(num_salt)) - for i in img.size] - target_pixels = list(map(lambda coords: tuple(coords), zip(target_pixels[0], target_pixels[1]))) - for i, j in target_pixels: - noisy[i][j] = 1 - # Pepper mode - num_pepper = np.ceil(amount * img_size * (1. 
- self.salt_or_pepper)) - target_pixels = [np.random.randint(0, i - 1, int(num_pepper)) - for i in img.size] - target_pixels = list(map(lambda coords: tuple(coords), zip(target_pixels[0], target_pixels[1]))) - - for i, j in target_pixels: - noisy[i][j] = 0 - - # plt.imshow(img) - # plt.imshow(noisy) - # plt.show() - return [noisy, mask] - - -class PoissionNoise(torch.nn.Module): - def __init__(self): - super(PoissionNoise, self).__init__() - - def forward(self, img, mask): - vals = len(np.unique(img)) - vals = 2 ** np.ceil(np.log2(vals)) - noisy = np.random.poisson(img * vals) / float(vals) - return [noisy, mask] - - -# TODO this one -class SpeckleNoise(torch.nn.Module): - def __init__(self): - super(SpeckleNoise, self).__init__() - - def forward(self, img, mask): - row, col = img.size - gauss = np.random.randn(row, col) - gauss = gauss.reshape(row, col) - noisy = img + img * gauss - return [noisy, mask] - - -class SetZeroNoise(torch.nn.Module): - def __init__(self): - super(SetZeroNoise, self).__init__() - - def forward(self, img, mask): - row, col = img.size - img_size = row * col - random_rows = np.random.randint(0, int(row), (1, int(img_size))) - random_cols = np.random.randint(0, int(col), (1, int(img_size))) - target_pixels = list(zip(random_rows[0], random_cols[0])) - for pix in target_pixels: - img[pix[0], pix[1]] = 0 - return [img, mask] - - -#################################################################################### - -###########Base_code Agumentations suggestions: #TODO do not need them now########## -class MyWrap(torch.nn.Module): - """ - Random wrap augmentation taken from tormentor - """ - - def __init__(self): - super(MyWrap, self).__init__() - - def forward(self, img, target): - # This augmentation acts like many simultaneous elastic transforms with gaussian sigmas set at varius harmonics - wrap_rand = tormentor.Wrap.override_distributions(roughness=tormentor.random.Uniform(value_range=(.1, .7)), - intensity=tormentor.random.Uniform(value_range=(.0, 1.))) - wrap = wrap_rand() - image = wrap(img) - mask = wrap(target, is_mask=True) - return [image, mask] - - -class Rotate(torch.nn.Module): - """ - Random rotation augmentation - """ - - def __init__(self): - super(Rotate, self).__init__() - - def forward(self, img, target): - random = np.random.randint(0, 3) - angle = 90 - if random == 1: - angle = 180 - elif random == 2: - angle = 270 - image = torchvision.transforms.functional.rotate(img, angle=angle) - mask = torchvision.transforms.functional.rotate(target, angle=angle) - return [image, mask.squeeze(0)] - - -class Bright(torch.nn.Module): - """ - Random brightness adjustment augmentations - """ - - def __init__(self, - lower_band: float = -0.2, - upper_band: float = 0.2): - super(Bright, self).__init__() - self.lower_band = lower_band - self.upper_band = upper_band - - def forward(self, img, mask): - bright_rand = tormentor.Brightness.override_distributions( - brightness=tormentor.random.Uniform((self.lower_band, self.upper_band))) - # bright = bright_rand() - # image_transformed = img.clone() - image_transformed = bright_rand(np.asarray(img)) - # set NA areas back to zero - image_transformed.seed[img == 0] = 0.0 - - return [image_transformed.seed, mask] - - -class Noise(torch.nn.Module): - """ - Random additive noise augmentation - """ - - def __init__(self): - super(Noise, self).__init__() - - def forward(self, img, target): - # add noise. 
It is a multiplicative gaussian noise so no need to set na areas back to zero again - noise = torch.normal(mean=0, std=0.3, size=img.size) - image = img + img * noise - image[image > 1.0] = 1.0 - image[image < 0.0] = 0.0 - - return [image, target] - - -class Resize(torch.nn.Module): - # Library code - """Resize the input image to the given size. - If the image is torch Tensor, it is expected - to have [..., H, W] shape, where ... means a maximum of two leading dimensions - - .. warning:: - The output image might be different depending on its type: when downsampling, the interpolation of PIL images - and tensors is slightly different, because PIL applies antialiasing. This may lead to significant differences - in the performance of a network. Therefore, it is preferable to train and serve a model with the same input - types. See also below the ``antialias`` parameter, which can help making the output of PIL images and tensors - closer. - - Args: - size (sequence or int): Desired output size. If size is a sequence like - (h, w), output size will be matched to this. If size is an int, - smaller edge of the image will be matched to this number. - i.e, if height > width, then image will be rescaled to - (size * height / width, size). - - .. note:: - In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``. - interpolation (InterpolationMode): Desired interpolation enum defined by - :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. - If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``, - ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported. - The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. - max_size (int, optional): The maximum allowed for the longer edge of - the resized image. If the longer edge of the image is greater - than ``max_size`` after being resized according to ``size``, - ``size`` will be overruled so that the longer edge is equal to - ``max_size``. - As a result, the smaller edge may be shorter than ``size``. This - is only supported if ``size`` is an int (or a sequence of length - 1 in torchscript mode). - antialias (bool, optional): Whether to apply antialiasing. - It only affects **tensors** with bilinear or bicubic modes and it is - ignored otherwise: on PIL images, antialiasing is always applied on - bilinear or bicubic modes; on other modes (for PIL images and - tensors), antialiasing makes no sense and this parameter is ignored. - Possible values are: - - - ``True``: will apply antialiasing for bilinear or bicubic modes. - Other mode aren't affected. This is probably what you want to use. - - ``False``: will not apply antialiasing for tensors on any mode. PIL - images are still antialiased on bilinear or bicubic modes, because - PIL doesn't support no antialias. - - ``None``: equivalent to ``False`` for tensors and ``True`` for - PIL images. This value exists for legacy reasons and you probably - don't want to use it unless you really know what you are doing. - - The current default is ``None`` **but will change to** ``True`` **in - v0.17** for the PIL and Tensor backends to be consistent. 
- """ - - def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR, max_size=None, antialias="warn"): - super().__init__() - self.size = size - self.max_size = max_size - self.interpolation = interpolation - self.antialias = antialias - - def forward(self, img, mask): - """ - Args: - img (PIL Image or Tensor): Image to be scaled. - - Returns: - PIL Image or Tensor: Rescaled image. - """ - return [F.resize(img, self.size, self.interpolation, self.max_size, self.antialias), mask] - - -class CenterCrop(torch.nn.Module): - """Crops the given image at the center. - If the image is torch Tensor, it is expected - to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. - If image size is smaller than output size along any edge, image is padded with 0 and then center cropped. - - Args: - size (sequence or int): Desired output size of the crop. If size is an - int instead of sequence like (h, w), a square crop (size, size) is - made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). - """ - - def __init__(self, size): - super().__init__() - self.size = size # _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") - - def forward(self, img, mask): - """ - Args: - img (PIL Image or Tensor): Image to be cropped. - - Returns: - PIL Image or Tensor: Cropped image. - """ - ###ADOPT FROM HOOKFORMER CODE, #todo CHECK IT LATER - W, H = img.size - - WW = (W // self.size) + 2 - HH = (H // self.size) + 2 - - return [F.center_crop(img, (HH * self.size, WW * self.size)), F.center_crop(mask, (HH * self.size, WW * self.size))] - - -class ToTensorZones(): - def __call__(self, image, target): - image = F.to_tensor(np.array(image).astype(np.float32)) - target = torch.from_numpy(np.array(target).astype(np.float32)) - # value for NA area=0, stone=64, glacier=127, ocean with ice melange=254 - target[target == 0] = 0 - target[target == 64] = 1 - target[target == 127] = 2 - target[target == 254] = 3 - # class ids for NA area=0, stone=1, glacier=2, ocean with ice melange=3 - return [image, target] diff --git a/SSLGlacier/processing/datamodule_.py b/SSLGlacier/processing/datamodule_.py deleted file mode 100644 index edf3648..0000000 --- a/SSLGlacier/processing/datamodule_.py +++ /dev/null @@ -1,149 +0,0 @@ -from torchvision.transforms import functional as F -from pl_bolts.models.self_supervised import swav -from torch.utils.data import DataLoader -from torch.utils.data import Dataset -import pytorch_lightning as pl -from torchvision import transforms as transform_lib -from . 
import agumentations_ as our_transforms -import torchvision -import numpy as np -from typing import Tuple, List, Any, Callable, Optional -import torch -import cv2 -from PIL import Image,ImageOps -import os - -# TODO add moco, simclr,cpc and amdim transformers -class SSLDataset(Dataset): - def __init__(self, parent_dir, transform,return_index=False, **kwargs): # TODO Pass them as arguments - ''' - Args: - mode: can be tested, train, or validation - parent_dir: directory in which the folders test, train and validation exists - ''' - self.mode = kwargs['mode'] if kwargs['mode'] else 'train' - self.return_index = return_index - self.images_path = os.path.join(parent_dir, "sar_images", self.mode) - self.masks_path = os.path.join(parent_dir, 'zones', self.mode) - self.images = os.listdir(self.images_path) - self.masks = os.listdir(self.masks_path) - assert len(self.masks) == len(self.images), "You don't have the same number of images and masks" - self.transform = transform - # Sort than images and masks fit together - self.images.sort() - self.masks.sort() - - # Let shuffle and save indices, and load them in next calls if they exist - if not os.path.exists(os.path.join("data_processing", "data_splits")): - os.makedirs(os.path.join("data_processing", "data_splits")) - if not os.path.isfile(os.path.join("data_processing", "data_splits", "shuffle_" + ".txt")): - shuffle = np.random.permutation(len(self.images)) - # Works for numpy version >= 1.5 - np.savetxt(os.path.join("data_processing", "data_splits", "shuffle_" + ".txt"), shuffle, - newline=' ') - - else: - # use already existing shuffle - with open(os.path.join("data_processing", "data_splits", "shuffle_" + ".txt"), "rb") as fp: - lines = fp.readlines() - shuffle = [np.fromstring(line, dtype=int, sep=' ') for line in lines] - # if lengths do not match, we need to create a new permutation - if len(shuffle) != len(self.images): - shuffle = np.random.permutation(len(self.images)) - np.savetxt(os.path.join("data_processing", "data_splits", "shuffle_" + ".txt"), shuffle, - newline=' ') - - self.images = np.array(self.images) # Why!!!! - self.masks = np.array(self.masks) - self.images = self.images[shuffle].copy() - self.masks = self.masks[shuffle].copy() - self.images = list(self.images) - self.masks = list(self.masks) # Why!!!! - - def __len__(self): - return len(self.images) - - def __getitem__(self, index): - img_name = self.images[index] - masks_name = self.masks[index] - assert img_name.split('.')[0] == masks_name.split('.')[0].replace("_zones", ""), \ - "image and label name don't match. Image name: " + img_name + ". Label name: " + masks_name - #image = cv2.imread(os.path.join(self.images_path, img_name).__str__(), cv2.IMREAD_GRAYSCALE) - #mask = cv2.imread(os.path.join(self.masks_path, masks_name).__str__(), cv2.IMREAD_GRAYSCALE) - image = Image.open(os.path.join(self.images_path, img_name).__str__()) - image = ImageOps.grayscale(image) - #image = np.array(image).astype(np.float32) - - #rgbimg = Image.new("RGBA", image.size) - #image = rgbimg.paste(image) - #image = ImageOps.grayscale(image) - mask = Image.open(os.path.join(self.masks_path, masks_name).__str__()) - #rgbimg = Image.new("RGBA", mask.size) - #mask = rgbimg.paste(rgbimg) - mask = ImageOps.grayscale(mask) - #mask = np.array(mask).astype(np.float32) - # TODO check if it works or not, it should return multiple transformation for one image - # TODO I crop different parts of the image and mask as part of the batch, how should I handle it!! 
- # TODO this is a list of tuples of img and masks, think if you want to change it later - - # to_tensor = our_transforms.ToTensorZones() - # _, mask = to_tensor(image,mask) - if self.transform is not None: - augmented_imgs, masks = self.transform(image, mask) - if self.return_index: - return index, augmented_imgs, mask, img_name, masks_name - return augmented_imgs, masks - - -class SSLDataModule(pl.LightningDataModule): - # Base is adopted from Nora's code - def __init__(self, batch_size, parent_dir, args): - """ - :param batch_size: batch size - :param target: Either 'zones' or 'front'. Tells which masks should be used. - """ - super().__init__() - self.transforms = None - self.batch_size = batch_size - self.glacier_test = None - self.glacier_train = None - self.glacier_val = None - self.parent_dir = parent_dir - self.aug_args = args.agumentations - self.size_crops = args.size_crops - self.num_crops = args.nmb_crops - - def prepare_data(self): - # download, - # only called on 1 GPU/TPU in distributed - pass - def __len__(self): - return len(self.glacier_train) - def setup(self, stage=None): - # process and split here - # make assignments here (val/train/test split) - # called on every process in DDP - if stage == 'test' or stage is None: - self.transforms = self._default_transforms() if self.test_transforms is None else self.test_transforms - self.glacier_test = SSLDataset(parent_dir=self.parent_dir, transform=self.transforms, mode=stage) - if stage == 'fit' or stage is None: - self.transforms = self._default_transforms() if self.train_transforms is None else self.train_transforms - self.glacier_train = SSLDataset(parent_dir=self.parent_dir, transform=self.transforms, mode='train') - - self.transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms - self.glacier_val = SSLDataset(parent_dir=self.parent_dir, transform=self.transforms, mode='val') - - def train_dataloader(self) -> DataLoader: - return DataLoader(self.glacier_train, batch_size=self.batch_size, num_workers=6, pin_memory=True, - drop_last=True) - - def val_dataloader(self)-> DataLoader: - return DataLoader(self.glacier_val, batch_size=self.batch_size, num_workers=6, pin_memory=True, drop_last=True) - - def test_dataloader(self)-> DataLoader: - return DataLoader(self.glacier_test, batch_size=1, num_workers=6, pin_memory=True, - drop_last=True) # TODO self.batch_size - - def _default_transforms(self) -> Callable: - #return transform_lib.Compose([transform_lib.ToTensor()]) - return our_transforms.Compose([our_transforms.ToTensorZones()]) diff --git a/SSLGlacier/processing/semi_module.py b/SSLGlacier/processing/semi_module.py deleted file mode 100644 index 9798a36..0000000 --- a/SSLGlacier/processing/semi_module.py +++ /dev/null @@ -1,29 +0,0 @@ -from pytorch_lightning import LightningModule -from typing import Any -class Semi_(LightningModule): - def __init__(self): - pass - - def setup(self, stage: str) -> None: - pass - - def init_model(self): - pass - - def forward(self, *args: Any, **kwargs: Any) -> Any: - pass - - def on_train_epoch_start(self) -> None: - pass - - def shared_step(self, batch): - inputs, y = batch - inputs = inputs[:, -1] - embedding = self.model(inputs) - ####Where should I add noise.... - ####assume embedding[b][i] belongs to batch b view i - ####if we have two view, one with noise and one without, then - ##### select some portion of each view belongs to different classes. - ##### pos_set = same classess .. 
- ##### neg_set = diff classes - diff --git a/SSLGlacier/processing/swav_module.py b/SSLGlacier/processing/swav_module.py deleted file mode 100644 index 3b57a4e..0000000 --- a/SSLGlacier/processing/swav_module.py +++ /dev/null @@ -1,562 +0,0 @@ -"""Adapted from official swav implementation: https://github.com/facebookresearch/swav.""" -import os -from argparse import ArgumentParser - -import torch -from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint -from torch import nn - -#from pl_bolts.models.self_supervised.swav.loss import SWAVLoss -from SSLGlacier.models.losses import SWAVLoss -from SSLGlacier.models.RESNET import resnet18, resnet50 -#from Base_Code.models.ParentUNet import UNet -from SSLGlacier.models.UNET import UNet #The main network -from pl_bolts.optimizers.lars import LARS -from pl_bolts.optimizers.lr_scheduler import linear_warmup_decay - -##########Importing files for hook############ -from SSLGlacier.models.hook.Swin_Transformer_Wrapper import SwinUnet -######Temporary config file for hooknet####### -from utils.config import get_config - -class SwAV(LightningModule): - def __init__( - self, - gpus: int, - num_samples: int, - batch_size: int, - dataset: str, - num_nodes: int = 1, - arch: str = "resnet50", - hidden_mlp: int = 2048, - feat_dim: int = 128, - warmup_epochs: int = 10, - max_epochs: int = 100, - num_prototypes: int = 3000, - freeze_prototypes_epochs: int = 1, - temperature: float = 0.1, - sinkhorn_iterations: int = 3, - queue_length: int = 0, # must be divisible by total batch-size - queue_path: str = "queue", - epoch_queue_starts: int = 15, - crops_for_assign: tuple = (0, 1), - num_crops: tuple = (2, 6), - num_augs: int= 2, - first_conv: bool = True, - maxpool1: bool = True, - optimizer: str = "adam", - exclude_bn_bias: bool = False, - start_lr: float = 0.0, - learning_rate: float = 1e-3, - final_lr: float = 0.0, - weight_decay: float = 1e-6, - epsilon: float = 0.05, - just_aug_for_same_assign_views: bool = False, - swin_hparams = {} - ) -> None: - """ - Args: - gpus: number of gpus per node used in training, passed to SwAV module - to manage the queue and select distributed sinkhorn - num_nodes: number of nodes to train on - num_samples: number of image samples used for training - batch_size: batch size per GPU in ddp - dataset: dataset being used for train/val - arch: encoder architecture used for pre-training - hidden_mlp: hidden layer of non-linear projection head, set to 0 - to use a linear projection head - feat_dim: output dim of the projection head - warmup_epochs: apply linear warmup for this many epochs - max_epochs: epoch count for pre-training - num_prototypes: count of prototype vectors - freeze_prototypes_epochs: epoch till which gradients of prototype layer - are frozen - temperature: loss temperature - sinkhorn_iterations: iterations for sinkhorn normalization - queue_length: set queue when batch size is small, - must be divisible by total batch-size (i.e. 
diff --git a/SSLGlacier/processing/swav_module.py b/SSLGlacier/processing/swav_module.py
deleted file mode 100644
index 3b57a4e..0000000
--- a/SSLGlacier/processing/swav_module.py
+++ /dev/null
@@ -1,562 +0,0 @@
-"""Adapted from official swav implementation: https://github.com/facebookresearch/swav."""
-import os
-from argparse import ArgumentParser
-
-import torch
-from pytorch_lightning import LightningModule, Trainer
-from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
-from torch import nn
-
-# from pl_bolts.models.self_supervised.swav.loss import SWAVLoss
-from SSLGlacier.models.losses import SWAVLoss
-from SSLGlacier.models.RESNET import resnet18, resnet50
-# from Base_Code.models.ParentUNet import UNet
-from SSLGlacier.models.UNET import UNet  # the main network
-from pl_bolts.optimizers.lars import LARS
-from pl_bolts.optimizers.lr_scheduler import linear_warmup_decay
-
-########## Importing files for hook ############
-from SSLGlacier.models.hook.Swin_Transformer_Wrapper import SwinUnet
-###### Temporary config file for HookNet #######
-from utils.config import get_config
-
-class SwAV(LightningModule):
-    def __init__(
-        self,
-        gpus: int,
-        num_samples: int,
-        batch_size: int,
-        dataset: str,
-        num_nodes: int = 1,
-        arch: str = "resnet50",
-        hidden_mlp: int = 2048,
-        feat_dim: int = 128,
-        warmup_epochs: int = 10,
-        max_epochs: int = 100,
-        num_prototypes: int = 3000,
-        freeze_prototypes_epochs: int = 1,
-        temperature: float = 0.1,
-        sinkhorn_iterations: int = 3,
-        queue_length: int = 0,  # must be divisible by total batch-size
-        queue_path: str = "queue",
-        epoch_queue_starts: int = 15,
-        crops_for_assign: tuple = (0, 1),
-        num_crops: tuple = (2, 6),
-        num_augs: int = 2,
-        first_conv: bool = True,
-        maxpool1: bool = True,
-        optimizer: str = "adam",
-        exclude_bn_bias: bool = False,
-        start_lr: float = 0.0,
-        learning_rate: float = 1e-3,
-        final_lr: float = 0.0,
-        weight_decay: float = 1e-6,
-        epsilon: float = 0.05,
-        just_aug_for_same_assign_views: bool = False,
-        swin_hparams = {}
-    ) -> None:
-        """
-        Args:
-            gpus: number of gpus per node used in training, passed to SwAV module
-                to manage the queue and select distributed sinkhorn
-            num_nodes: number of nodes to train on
-            num_samples: number of image samples used for training
-            batch_size: batch size per GPU in ddp
-            dataset: dataset being used for train/val
-            arch: encoder architecture used for pre-training
-            hidden_mlp: hidden layer of non-linear projection head, set to 0
-                to use a linear projection head
-            feat_dim: output dim of the projection head
-            warmup_epochs: apply linear warmup for this many epochs
-            max_epochs: epoch count for pre-training
-            num_prototypes: count of prototype vectors
-            freeze_prototypes_epochs: epoch till which gradients of prototype layer
-                are frozen
-            temperature: loss temperature
-            sinkhorn_iterations: iterations for sinkhorn normalization
-            queue_length: use the queue when the batch size is small;
-                must be divisible by the total batch size (i.e. total_gpus * batch_size);
-                set to 0 to remove the queue
-            queue_path: folder within the logs directory
-            epoch_queue_starts: start using the queue after this epoch
-            crops_for_assign: list of crop ids for computing assignment
-            num_crops: number of global and local crops, ex: [2, 6]
-            first_conv: keep first conv same as the original resnet architecture;
-                if set to false it is replaced by a kernel 3, stride 1 conv (cifar-10)
-            maxpool1: keep first maxpool layer same as the original resnet architecture;
-                if set to false, first maxpool is turned off (cifar10, maybe stl10)
-            optimizer: optimizer to use
-            exclude_bn_bias: exclude batchnorm and bias layers from weight decay in optimizers
-            start_lr: starting lr for linear warmup
-            learning_rate: learning rate
-            final_lr: final learning rate for cosine weight decay
-            weight_decay: weight decay for optimizer
-            epsilon: epsilon val for swav assignments
-        """
-        super().__init__()
-        self.save_hyperparameters()
-
-        self.gpus = gpus
-        self.num_nodes = num_nodes
-        self.arch = arch
-        self.dataset = dataset
-        self.num_samples = num_samples
-        self.batch_size = batch_size
-
-        self.hidden_mlp = hidden_mlp
-        self.feat_dim = feat_dim
-        self.num_prototypes = num_prototypes
-        self.freeze_prototypes_epochs = freeze_prototypes_epochs
-        self.sinkhorn_iterations = sinkhorn_iterations
-
-        self.queue_length = queue_length
-        self.queue_path = queue_path
-        self.epoch_queue_starts = epoch_queue_starts
-        self.crops_for_assign = crops_for_assign
-        self.num_crops = num_crops
-        self.num_augs = num_augs
-
-        self.first_conv = first_conv
-        self.maxpool1 = maxpool1
-
-        self.optim = optimizer
-        self.exclude_bn_bias = exclude_bn_bias
-        self.weight_decay = weight_decay
-        self.epsilon = epsilon
-        self.temperature = temperature
-
-        self.start_lr = start_lr
-        self.final_lr = final_lr
-        self.learning_rate = learning_rate
-        self.warmup_epochs = warmup_epochs
-        self.max_epochs = max_epochs
-        self.just_aug_for_same_assign_views = just_aug_for_same_assign_views
-
-        #### for hook: must be set before init_model() so the "hook" arch can read it ####
-        self.swin_hparams = swin_hparams
-
-        self.model = self.init_model()
-        self.criterion = SWAVLoss(
-            gpus=self.gpus,
-            num_nodes=self.num_nodes,
-            temperature=self.temperature,
-            crops_for_assign=self.crops_for_assign,
-            num_crops=self.num_crops,
-            num_augs=self.num_augs,
-            sinkhorn_iterations=self.sinkhorn_iterations,
-            epsilon=self.epsilon,
-            just_aug_for_same_assign_views=self.just_aug_for_same_assign_views
-        )
-        self.use_the_queue = None
-        # compute iters per epoch
-        global_batch_size = self.num_nodes * self.gpus * self.batch_size if self.gpus > 0 else self.batch_size
-        self.train_iters_per_epoch = self.num_samples // global_batch_size
-        self.queue = None
-
-    def setup(self, stage):
-        if self.queue_length > 0:
-            queue_folder = os.path.join(self.logger.log_dir, self.queue_path)
-            if not os.path.exists(queue_folder):
-                os.makedirs(queue_folder)
-
-            self.queue_path = os.path.join(queue_folder, "queue" + str(self.trainer.global_rank) + ".pth")
-
-            if os.path.isfile(self.queue_path):
-                self.queue = torch.load(self.queue_path)["queue"]
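For orientation, a hypothetical instantiation of this module with the queue enabled (all values are illustrative; in practice num_samples comes from the datamodule):

    model = SwAV(
        gpus=1,
        num_samples=50_000,
        batch_size=64,
        dataset='glacier',
        arch='resnet18',
        num_crops=(2, 4),
        queue_length=1920,         # divisible by gpus * batch_size (1 * 64)
        epoch_queue_starts=15,
        sinkhorn_iterations=3,
        epsilon=0.05,
    )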
self.arch == "hook": - backbone = SwinUnet( - img_size=224, - num_classes=5, - normalize=True, - hidden_mlp=self.hidden_mlp, - output_dim=self.feat_dim, - num_prototypes=self.num_prototypes, - first_conv=self.first_conv, - maxpool1=self.maxpool1,######till here for basenet - image_size= self.swin_hparams.image_size,####from here for swin itself - swin_patch_size = self.swin_hparams.swin_patch_size, - swin_in_chans=self.swin_hparams.swin_in_chans, - swin_embed_dim=self.swin_hparams.swin_embed_dim, - swin_depths=self.swin_hparams.swin_depths, - swin_num_heads=self.swin_hparams.swin_num_heads, - swin_window_size=self.swin_hparams.swin_window_size, - swin_mlp_ratio=self.swin_hparams.swin_mlp_ratio, - swin_QKV_BIAS=self.swin_hparams.swin_QKV_BIAS, - swin_QK_SCALE=self.swin_hparams.swin_QK_SCALE, - drop_rate=self.drop_rate, - drop_path_rate=self.drop_path_rate, - swin_ape=self.swin_hparams.swin_ape, - swin_patch_norm=self.swin_hparams.swin_path_norm - ) - - elif self.arch == "Unet" or "unet": - backbone = UNet(num_classes=5, - input_channels=1, - num_layers=5, - features_start=64, - bilinear=False, - normalize=True, - hidden_mlp=self.hidden_mlp, - output_dim=self.feat_dim, - num_prototypes=self.num_prototypes, - first_conv=self.first_conv, - maxpool1=self.maxpool1 - ) - else: - raise f'{self.arch} model is not defined' - - return backbone - - def forward(self, x): - # pass single batch from the resnet backbone - return self.model.forward_backbone(x) - - def on_train_epoch_start(self): - if self.queue_length > 0: - if self.trainer.current_epoch >= self.epoch_queue_starts and self.queue is None: - self.queue = torch.zeros( - len(self.crops_for_assign), - self.queue_length // self.gpus, # change to nodes * gpus once multi-node - self.feat_dim, - ) - - if self.queue is not None: - self.queue = self.queue.to(self.device) - - self.use_the_queue = False - - def on_train_epoch_end(self) -> None: - if self.queue is not None: - torch.save({"queue": self.queue}, self.queue_path) - - def on_after_backward(self): - if self.current_epoch < self.freeze_prototypes_epochs: - for name, p in self.model.named_parameters(): - if "prototypes" in name: - p.grad = None - - def shared_step(self, batch): - if self.dataset == "stl10": - unlabeled_batch = batch[0] - batch = unlabeled_batch - - inputs, y = batch - inputs = inputs[:-1] # remove online train/eval transforms at this point - - # 1. normalize the prototypes - with torch.no_grad(): - w = self.model.prototypes.weight.data.clone() - w = nn.functional.normalize(w, dim=1, p=2) - self.model.prototypes.weight.copy_(w) - - # 2. 
-    def training_step(self, batch, batch_idx):
-        loss = self.shared_step(batch)
-
-        self.log("train_loss", loss, on_step=True, on_epoch=False)
-        return loss
-
-    def validation_step(self, batch, batch_idx):
-        loss = self.shared_step(batch)
-
-        self.log("val_loss", loss, on_step=False, on_epoch=True)
-        return loss
-
-    def exclude_from_wt_decay(self, named_params, weight_decay, skip_list=("bias", "bn")):
-        params = []
-        excluded_params = []
-
-        for name, param in named_params:
-            if not param.requires_grad:
-                continue
-            if any(layer_name in name for layer_name in skip_list):
-                excluded_params.append(param)
-            else:
-                params.append(param)
-
-        return [{"params": params, "weight_decay": weight_decay}, {"params": excluded_params, "weight_decay": 0.0}]
-
-    def configure_optimizers(self):
-        if self.exclude_bn_bias:
-            params = self.exclude_from_wt_decay(self.named_parameters(), weight_decay=self.weight_decay)
-        else:
-            params = self.parameters()
-
-        if self.optim == "lars":
-            optimizer = LARS(
-                params,
-                lr=self.learning_rate,
-                momentum=0.9,
-                weight_decay=self.weight_decay,
-                trust_coefficient=0.001,
-            )
-        elif self.optim == "adam":
-            optimizer = torch.optim.Adam(params, lr=self.learning_rate, weight_decay=self.weight_decay)
-        else:
-            raise ValueError(f"unknown optimizer: {self.optim}")
-
-        warmup_steps = self.train_iters_per_epoch * self.warmup_epochs
-        total_steps = self.train_iters_per_epoch * self.max_epochs
-
-        scheduler = {
-            "scheduler": torch.optim.lr_scheduler.LambdaLR(
-                optimizer,
-                linear_warmup_decay(warmup_steps, total_steps, cosine=True),
-            ),
-            "interval": "step",
-            "frequency": 1,
-        }
-
-        return [optimizer], [scheduler]
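linear_warmup_decay(warmup_steps, total_steps, cosine=True) supplies the per-step multiplier for LambdaLR above; roughly, it behaves like this sketch (an approximate reconstruction, not the pl_bolts source):

    import math

    def warmup_cosine_multiplier(step: int, warmup_steps: int, total_steps: int) -> float:
        if step < warmup_steps:
            return step / max(1, warmup_steps)             # linear warmup from 0 to 1
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))  # cosine decay from 1 to 0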
help="argument in RandomResizedCrop (example: [0.14, 0.05])", - ) - parser.add_argument( - "--max_scale_crops", - type=float, - default=[1, 0.33], - nargs="+", - help="argument in RandomResizedCrop (example: [1., 0.14])", - ) - - # training params - parser.add_argument("--fast_dev_run", default=1, type=int) - parser.add_argument("--num_nodes", default=1, type=int, help="number of nodes for training") - parser.add_argument("--gpus", default=1, type=int, help="number of gpus to train on") - parser.add_argument("--num_workers", default=8, type=int, help="num of workers per GPU") - parser.add_argument("--optimizer", default="adam", type=str, help="choose between adam/lars") - parser.add_argument("--exclude_bn_bias", action="store_true", help="exclude bn/bias from weight decay") - parser.add_argument("--max_epochs", default=100, type=int, help="number of total epochs to run") - parser.add_argument("--max_steps", default=-1, type=int, help="max steps") - parser.add_argument("--warmup_epochs", default=10, type=int, help="number of warmup epochs") - parser.add_argument("--batch_size", default=128, type=int, help="batch size per gpu") - - parser.add_argument("--weight_decay", default=1e-6, type=float, help="weight decay") - parser.add_argument("--learning_rate", default=1e-3, type=float, help="base learning rate") - parser.add_argument("--start_lr", default=0, type=float, help="initial warmup learning rate") - parser.add_argument("--final_lr", type=float, default=1e-6, help="final learning rate") - - # swav params - parser.add_argument( - "--crops_for_assign", - type=int, - nargs="+", - default=[0, 1], - help="list of crops id used for computing assignments", - ) - parser.add_argument("--temperature", default=0.1, type=float, help="temperature parameter in training loss") - parser.add_argument( - "--epsilon", default=0.05, type=float, help="regularization parameter for Sinkhorn-Knopp algorithm" - ) - parser.add_argument( - "--sinkhorn_iterations", default=3, type=int, help="number of iterations in Sinkhorn-Knopp algorithm" - ) - parser.add_argument("--num_prototypes", default=512, type=int, help="number of prototypes") - parser.add_argument( - "--queue_length", - type=int, - default=0, - help="length of the queue (0 for no queue); must be divisible by total batch size", - ) - parser.add_argument( - "--epoch_queue_starts", type=int, default=15, help="from this epoch, we start using a queue" - ) - parser.add_argument( - "--freeze_prototypes_epochs", - default=1, - type=int, - help="freeze the prototypes during this many epochs from the start", - ) - - return parser - - -def cli_main(): - from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator - from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule - from pl_bolts.transforms.self_supervised.swav_transforms import SwAVEvalDataTransform, SwAVTrainDataTransform - - parser = ArgumentParser() - - # model args - parser = SwAV.add_model_specific_args(parser) - args = parser.parse_args() - - if args.dataset == "stl10": - dm = STL10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers) - - dm.train_dataloader = dm.train_dataloader_mixed - dm.val_dataloader = dm.val_dataloader_mixed - args.num_samples = dm.num_unlabeled_samples - - args.maxpool1 = False - - normalization = stl10_normalization() - elif args.dataset == "cifar10": - args.batch_size = 2 - args.num_workers = 0 - - dm = CIFAR10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers) - - 
-def cli_main():
-    from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
-    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
-    from pl_bolts.transforms.dataset_normalizations import (
-        cifar10_normalization,
-        imagenet_normalization,
-        stl10_normalization,
-    )
-    from pl_bolts.transforms.self_supervised.swav_transforms import SwAVEvalDataTransform, SwAVTrainDataTransform
-
-    parser = ArgumentParser()
-
-    # model args
-    parser = SwAV.add_model_specific_args(parser)
-    args = parser.parse_args()
-
-    if args.dataset == "stl10":
-        dm = STL10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)
-
-        dm.train_dataloader = dm.train_dataloader_mixed
-        dm.val_dataloader = dm.val_dataloader_mixed
-        args.num_samples = dm.num_unlabeled_samples
-
-        args.maxpool1 = False
-
-        normalization = stl10_normalization()
-    elif args.dataset == "cifar10":
-        args.batch_size = 2
-        args.num_workers = 0
-
-        dm = CIFAR10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)
-
-        args.num_samples = dm.num_samples
-
-        args.maxpool1 = False
-        args.first_conv = False
-
-        normalization = cifar10_normalization()
-
-        # cifar10 specific params
-        args.size_crops = [32, 16]
-        args.num_crops = [2, 1]
-        args.gaussian_blur = False
-    elif args.dataset == "imagenet":
-        args.maxpool1 = True
-        args.first_conv = True
-        normalization = imagenet_normalization()
-
-        args.size_crops = [224, 96]
-        args.num_crops = [2, 6]
-        args.min_scale_crops = [0.14, 0.05]
-        args.max_scale_crops = [1.0, 0.14]
-        args.gaussian_blur = True
-        args.jitter_strength = 1.0
-
-        args.batch_size = 64
-        args.num_nodes = 8
-        args.gpus = 8  # per-node
-        args.max_epochs = 800
-
-        args.optimizer = "lars"
-        args.learning_rate = 4.8
-        args.final_lr = 0.0048
-        args.start_lr = 0.3
-
-        args.num_prototypes = 3000
-        args.online_ft = True
-
-        dm = ImagenetDataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)
-
-        args.num_samples = dm.num_samples
-        args.input_height = dm.dims[-1]
-    else:
-        raise NotImplementedError("other datasets have not been implemented yet")
-
-    dm.train_transforms = SwAVTrainDataTransform(
-        normalize=normalization,
-        size_crops=args.size_crops,
-        num_crops=args.num_crops,
-        min_scale_crops=args.min_scale_crops,
-        max_scale_crops=args.max_scale_crops,
-        gaussian_blur=args.gaussian_blur,
-        jitter_strength=args.jitter_strength,
-    )
-
-    dm.val_transforms = SwAVEvalDataTransform(
-        normalize=normalization,
-        size_crops=args.size_crops,
-        num_crops=args.num_crops,
-        min_scale_crops=args.min_scale_crops,
-        max_scale_crops=args.max_scale_crops,
-        gaussian_blur=args.gaussian_blur,
-        jitter_strength=args.jitter_strength,
-    )
-
-    # swav model init
-    model = SwAV(**args.__dict__)
-
-    online_evaluator = None
-    if args.online_ft:
-        # online eval
-        online_evaluator = SSLOnlineEvaluator(
-            drop_p=0.0,
-            hidden_dim=None,
-            z_dim=args.hidden_mlp,
-            num_classes=dm.num_classes,
-            dataset=args.dataset,
-        )
-
-    lr_monitor = LearningRateMonitor(logging_interval="step")
-    model_checkpoint = ModelCheckpoint(save_last=True, save_top_k=1, monitor="val_loss")
-    callbacks = [model_checkpoint, online_evaluator] if args.online_ft else [model_checkpoint]
-    callbacks.append(lr_monitor)
-
-    trainer = Trainer(
-        max_epochs=args.max_epochs,
-        max_steps=None if args.max_steps == -1 else args.max_steps,
-        gpus=args.gpus,
-        num_nodes=args.num_nodes,
-        accelerator="ddp" if args.gpus > 1 else None,
-        sync_batchnorm=args.gpus > 1,
-        precision=32 if args.fp32 else 16,
-        callbacks=callbacks,
-        fast_dev_run=args.fast_dev_run,
-    )
-
-    trainer.fit(model, datamodule=dm)
-
-
-if __name__ == "__main__":
-    cli_main()
\ No newline at end of file
diff --git a/SSLGlacier/processing/transformers_.py b/SSLGlacier/processing/transformers_.py
deleted file mode 100644
index f8290bb..0000000
--- a/SSLGlacier/processing/transformers_.py
+++ /dev/null
@@ -1,248 +0,0 @@
-from pl_bolts.models.self_supervised import swav
-from torchvision import transforms  # needed for validation and fine-tuning
-from . import agumentations_ as ag
-from typing import Tuple, List, Dict
-from torch import Tensor
-
-
-class OurTrainTransformer():
-    def __init__(self,
-                 size_crops: Tuple[int] = (294, 94),
-                 nmb_crops: Tuple[int] = (2, 4),
-                 augs: Dict = {'SaltPepper': 1},
-                 use_hook_former: bool = True) -> None:
-        # TODO: ensure there is at least one sample from each augmentation method
-        self.size_crops = size_crops
-        self.num_crops = nmb_crops
-        augmented_imgs = []
-        self.use_hook_former = use_hook_former
-        self.transform = []
-        if augs['OrigPixelValues']:
-            augmented_imgs.append(ag.DoNothing())
-        if augs['RGaussianB']:
-            augmented_imgs.append(ag.PILRandomGaussianBlur(radius_min=0.1, radius_max=2.))
-        if augs['GaussiN']:
-            augmented_imgs.append(ag.GaussNoise(mean=0, var=10000))
-        if augs['SaltPepper']:
-            augmented_imgs.append(ag.SaltPepperNoise(salt_or_pepper=0.2))
-        if augs['ZeroN']:
-            augmented_imgs.append(ag.SetZeroNoise())
-        # Nora's augmentation methods
-        if augs['flip']:
-            # TODO: also apply the same transformation to the mask when required
-            augmented_imgs.append(ag.RandomHorizontalFlip())
-            # image = torchvision.transforms.functional.hflip(image)
-            # mask = torchvision.transforms.functional.hflip(mask)
-            # image, mask = ag.ToTensorZones(image=image, target=np.array(mask))
-            # augmented_imgs.append((image, mask))
-        if augs['rotate']:
-            augmented_imgs.append(ag.Rotate())
-        if augs['bright']:
-            augmented_imgs.append(ag.Bright())
-        if augs['wrap']:
-            # TODO: the wrap augmentation is not implemented yet
-            pass
-        if augs['noise']:
-            augmented_imgs.append(ag.Noise())
-
-        if augs['normalize'] is not None:
-            self.final_transform = ag.Compose([ag.ToTensorZones(), augs['normalize']])
-        else:
-            self.final_transform = ag.ToTensorZones()
-
-        if self.use_hook_former:
-            self.hook_former_transformer(augmented_imgs)
-        else:
-            self.other_networks_transformers(augmented_imgs)
-
-    def __call__(self, image, mask):
-        """Apply every configured crop/augmentation pipeline to (image, mask).
-
-        Debugging snippet:
-            import matplotlib.pyplot as plt
-            img1 = transformed_imgs[0].detach().cpu().numpy().squeeze()
-            plt.imshow(img1)
-            plt.show()
-        """
-        transformed_img_mask = []
-        for transform in self.transform:
-            first = transform[0]
-            if isinstance(first, ag.RandomCropper):
-                transformed_img_mask.append(transform(image, mask))
-                self.i = first.left
-                self.j = first.right
-                self.w = first.width
-                self.h = first.height
-            elif isinstance(first, ag.Cropper):
-                first.left = self.i
-                first.right = self.j
-                first.width = self.w
-                first.height = self.h
-                transformed_img_mask.append(transform(image, mask))
-            else:
-                transformed_img_mask.append(transform(image, mask))
-        transformed_imgs = [item[0] for item in transformed_img_mask]
-        transformed_masks = [item[1] for item in transformed_img_mask]
-
-        # fig, axs = plt.subplots(nrows=6, ncols=4, figsize=(15, 12))
-        # fig.suptitle("Patches used in training", fontsize=18, y=.95)
-        #
-        # for i, image_id in enumerate(transformed_imgs):
-        #     img1 = transformed_imgs[i*3].detach().cpu().numpy().squeeze()
-        #     mask1 = transformed_masks[i*3].detach().cpu().numpy().squeeze()
-        #     axs[0, i].imshow(img1)
-        #     axs[0, i].imshow(mask1)
-        #     axs[0, i].imshow(transformed_imgs[i*3 +1].detach().cpu().numpy().squeeze())
-        #     axs[0, i].imshow(transformed_imgs[i*3+2].detach().cpu().numpy().squeeze())
-        # plt.savefig('High.png')
-        return transformed_imgs, transformed_masks
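For orientation, a hypothetical invocation of this transformer; note that the augs dict is indexed directly, so every key checked in __init__ must be present (the image and mask below are synthetic stand-ins):

    from PIL import Image

    augs = {'OrigPixelValues': 1, 'RGaussianB': 0, 'GaussiN': 0, 'SaltPepper': 1,
            'ZeroN': 0, 'flip': 1, 'rotate': 0, 'bright': 0, 'wrap': 0, 'noise': 0,
            'normalize': None}
    t = OurTrainTransformer(size_crops=(294, 94), nmb_crops=(2, 4), augs=augs,
                            use_hook_former=True)
    image = Image.new('L', (512, 512))   # stand-in SAR patch
    mask = Image.new('L', (512, 512))    # stand-in zone mask
    views, mask_views = t(image, mask)   # lists of augmented crops and matching masks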
-    def hook_former_transformer(self, augmented_imgs):
-        transform = []
-        for i in range(len(self.size_crops)):
-            i_transform = []
-            for ith_aug, aug in enumerate(augmented_imgs):
-                # The first pipeline crops randomly; its coordinates are saved for the later pipelines
-                if ith_aug == 0:
-                    global_crop = ag.RandomCropper(self.size_crops[i])
-                    i_transform.extend(
-                        [ag.Compose([global_crop] + [aug] + [ag.ToTensorZones()])]
-                    )
-                else:  # crop the same area, then augment
-                    local_crop = ag.CenterCrop(center_crop_size=self.size_crops[i] // 4)
-                    i_transform.extend(
-                        [ag.Compose([local_crop] + [aug] + [ag.ToTensorZones()])])
-            transform += i_transform * self.num_crops[i]
-        self.transform = transform
-        # add an online train transform of the size of the global view
-        online_train_transform = ag.Compose(
-            [ag.RandomCropper(self.size_crops[i]), ag.RandomHorizontalFlip(), self.final_transform]
-        )
-
-        self.transform.append(online_train_transform)
-        return transform
-
-    def other_networks_transformers(self, augmented_imgs):
-        transform = []
-        for i in range(len(self.size_crops)):
-            i_transform = []
-            for ith_aug, aug in enumerate(augmented_imgs):
-                # The first pipeline crops randomly; its coordinates are saved for the later pipelines
-                if ith_aug == 0:
-                    random_crop = ag.RandomCropper(self.size_crops[i])
-                    i_transform.extend(
-                        [ag.Compose([random_crop] + [aug] + [ag.ToTensorZones()])]
-                    )
-                else:  # crop the same area, then augment
-                    fixed_crop = ag.Cropper(random_crop.left, random_crop.right, random_crop.height, random_crop.width)
-                    i_transform.extend(
-                        [ag.Compose([fixed_crop] + [aug] + [ag.ToTensorZones()])])
-            transform += i_transform * self.num_crops[i]
-
-        self.transform = transform
-        # add an online train transform of the size of the global view
-        # TODO: clarify why this online transform is needed
-        online_train_transform = ag.Compose(
-            [ag.RandomCropper(self.size_crops[i]), ag.RandomHorizontalFlip(), self.final_transform]
-        )
-
-        self.transform.append(online_train_transform)
-        return transform
-
-class OurEvalTransformer(OurTrainTransformer):
-    def __init__(self,
-                 size_crops: Tuple[int] = (294, 294),
-                 nmb_crops: Tuple[int] = (2, 4),
-                 augs: Dict = {'SaltPepper': 1},
-                 use_hook_former: bool = True) -> None:
-        super().__init__(size_crops=size_crops,
-                         nmb_crops=nmb_crops,
-                         augs=augs,
-                         use_hook_former=use_hook_former)
-
-        input_height = self.size_crops[0]  # size of the global view crop
-        test_transform = ag.Compose(
-            [
-                ag.Resize(int(input_height + 0.1 * input_height)),
-                ag.CenterCrop(input_height),
-                self.final_transform,
-            ]
-        )
-
-        # replace the last transform in self.transform with the eval transform
-        self.transform[-1] = test_transform
-
-class OurFinetuneTransform:
-    def __init__(self,
-                 input_height: int = 224,
-                 normalize=None,
-                 eval_transform: bool = False) -> None:
-
-        if normalize is None:
-            final_transform = transforms.ToTensor()
-        else:
-            final_transform = transforms.Compose([transforms.ToTensor(), normalize])
-
-        if not eval_transform:  # TODO think about it
-            data_transforms = [
-                transforms.RandomResizedCrop(size=input_height),
-                ag.RandomHorizontalFlip(),
-                ag.SaltPepperNoise(salt_or_pepper=0.5),
-                transforms.RandomGrayscale(p=0.2),
-            ]
-        else:
-            data_transforms = [
-                transforms.Resize(int(input_height + 0.1 * input_height)),
-                transforms.CenterCrop(input_height),
-            ]
-
-        data_transforms.append(final_transform)
-        self.transform = ag.Compose(data_transforms)
-
-    def __call__(self, img: Tensor) -> Tensor:
-        return self.transform(img)
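The RandomCropper/Cropper pairing above is a coordinate-sharing pattern: the first random crop records its window and every later pipeline reuses it, so all augmentations see the same patch. As a standalone illustration with plain torchvision (synthetic image, independent of the ag module):

    from PIL import Image
    import torchvision.transforms as T
    import torchvision.transforms.functional as TF

    img = Image.new('L', (256, 256))  # synthetic stand-in patch
    top, left, h, w = T.RandomCrop.get_params(img, output_size=(94, 94))
    view_a = TF.crop(img, top, left, h, w)                           # original pixels
    view_b = TF.crop(img.point(lambda p: 255 - p), top, left, h, w)  # augmented view, same window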
-# TODO use it in the code with a flag
-class SwapDefaultTransformer():
-    def __init__(
-        self,
-        size_crops: Tuple[int] = (96, 36),
-        nmb_crops: Tuple[int] = (2, 4),
-        min_scale_crops: Tuple[float] = (0.33, 0.10),
-        max_scale_crops: Tuple[float] = (1, 0.33),
-        gaussian_blur: bool = True,
-        jitter_strength: float = 1.0,
-    ) -> None:
-
-        self.size_crops = size_crops
-        self.num_crops = nmb_crops
-        self.min_scale_crops = min_scale_crops
-        self.max_scale_crops = max_scale_crops
-        self.gaussian_blur = gaussian_blur
-        self.jitter_strength = jitter_strength
-
-    def get(self, mode):
-        if mode == 'train':
-            return swav.SwAVTrainDataTransform(
-                # normalize=self.normalization(),
-                size_crops=self.size_crops,
-                num_crops=self.num_crops,
-                min_scale_crops=self.min_scale_crops,
-                max_scale_crops=self.max_scale_crops,
-                gaussian_blur=self.gaussian_blur,
-                jitter_strength=self.jitter_strength
-            )
-        elif mode == 'val':
-            return swav.SwAVEvalDataTransform(
-                # normalize=self.normalization(),
-                size_crops=self.size_crops,
-                num_crops=self.num_crops,
-                min_scale_crops=self.min_scale_crops,
-                max_scale_crops=self.max_scale_crops,
-                gaussian_blur=self.gaussian_blur,
-                jitter_strength=self.jitter_strength
-            )
-        else:
-            raise ValueError(f"unknown mode: {mode}")
--
GitLab