diff --git a/SSLGlacier/main.py b/SSLGlacier/main.py index 73d9277c31dc350c06c993cc0d123f6e608665ad..cc7d5b4b06f1cf3bda20348953614c420070015f 100644 --- a/SSLGlacier/main.py +++ b/SSLGlacier/main.py @@ -5,12 +5,12 @@ from torchinfo import summary as torchinfo_summery from torchsummary import summary as torch_summary from torch._inductor.config import trace -from processing import datamodule_, transformers_ +from modules import datamodule_, transformers_ import pytorch_lightning as pl from models import ResNet_light_mine, Swav_Orig_ResNet import os import torch -from processing.swav_module import SwAV +from modules.swav_module import SwAV from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator import numpy as np from utils.utils import str2bool, getDataPath diff --git a/SSLGlacier/modules/agumentations_.py b/SSLGlacier/modules/agumentations_.py new file mode 100644 index 0000000000000000000000000000000000000000..05cb774513ca1749f65dc55320f93754ab81db8a --- /dev/null +++ b/SSLGlacier/modules/agumentations_.py @@ -0,0 +1,411 @@ +# PILRandomGaussianBlur and get_color_distortion are used and implemented +# by Swav and SimCLR - https://arxiv.org/abs/2002.05709 +import random +from logging import getLogger + +import torch.nn +from PIL import ImageFilter +import numpy as np +import torchvision.transforms as transforms +import torchvision, tormentor +from torchvision.transforms import functional as F +import torch +import matplotlib.pyplot as plt + +logger = getLogger() + + +class Compose(object): + """ + Class for chaining transforms together + """ + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, image, target): + for t in self.transforms: + image, target = t(image, target) + return [image, target] + + def __getitem__(self, item): + return self.transforms[item] + +class DoNothing(torch.nn.Module): + def __init__(self): + super(DoNothing, self).__init__() + + def __call__(self, img, mask): + return [img, mask] + +class Cropper(torch.nn.Module): + def __init__(self, i, j, h, w): + super(Cropper, self).__init__() + self.left = i + self.right = j + self.height = h + self.width = w + + def __call__(self, img, mask): + cropped_img = F.crop(img, self.left, self.right, self.height, self.width) + cropped_mask = F.crop(mask, self.left, self.right, self.height, self.width) + return [cropped_img, cropped_mask] + + +class RandomCropper(torch.nn.Module): + ''' + This function returns one patch at time, if you need more patches in one image, call it in a loop + Args: + orig_img: get an png or jpg image and crop one patch randomly, in both image and mask + orig_mask: This is the images mask, we crop same area from + Returns: + A list of two argument, first cropped patch in image,second cropped patch in mask + ''' + + def __init__(self, size): + super(RandomCropper, self).__init__() + self.left = 0 + self.right = 0 + self.height = 0 + self.width = 0 + self.size = size + + def forward(self, img, mask): + self.left, self.right, self.height, self.width = torchvision.transforms.RandomCrop.get_params( + img, output_size=(self.size, self.size)) + cropped_img = F.crop(img, self.left, self.right, self.height, self.width) + cropped_mask = F.crop(mask, self.left, self.right, self.height, self.width) + return [cropped_img, cropped_mask] + + +class PILRandomGaussianBlur(torch.nn.Module): + def __init__(self, radius_min=0.1, radius_max=2.): + """ + Apply Gaussian Blur to the PIL image. Take the radius and probability of + application as the parameter. 
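+        (In this implementation the blur is applied on every call, with the radius drawn
+        uniformly from [radius_min, radius_max]; no separate application probability is used.)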
+ This transform was used in SimCLR - https://arxiv.org/abs/2002.05709 + """ + super(PILRandomGaussianBlur, self).__init__() + self.radius_min = radius_min + self.radius_max = radius_max + + def forward(self, img, mask): + return [img.filter( + ImageFilter.GaussianBlur( + radius=random.uniform(self.radius_min, self.radius_max) + ) + ), mask] + + +class RandomHorizontalFlip(torch.nn.Module): + def __init__(self): + super(RandomHorizontalFlip, self).__init__() + + def forward(self, img, mask): + image = torchvision.transforms.functional.hflip(img) + mask = torchvision.transforms.functional.hflip(mask) + return [image, mask] + + +class GetColorDistortion(torch.nn.Module): + def __int__(self, s=1.0): + super(GetColorDistortion, self).__init__() + self.s = s + + def forward(self, img, mask): + color_jitter = transforms.ColorJitter(0.8 * self.s, 0.8 * self.s, 0.8 * self.s, 0.2 * self.s) + rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8) + rnd_gray = transforms.RandomGrayscale(p=0.2) + color_distort = transforms.Compose([rnd_color_jitter, rnd_gray]) + return color_distort(img, mask) + + +# TODO this net +class GaussNoise(torch.nn.Module): + def __init__(self, mean=0, var=10000): + super(GaussNoise, self).__init__() + self.mean = mean + self.var = var + + def forward(self, img, mask): + row, col = img.size + sigma = self.var ** 0.5 + gauss = np.random.normal(self.mean, sigma, (row, col)) + gauss = gauss.reshape(row, col) + noisy = img + gauss + return [noisy, mask] + + +class SaltPepperNoise(torch.nn.Module): + def __init__(self, salt_or_pepper=0.5): + super(SaltPepperNoise, self).__init__() + self.salt_or_pepper = salt_or_pepper + + def forward(self, img, mask): + if len(img.size) == 3: + img_size = img.size[1] * img.size[2] + else: + img_size = img.size[0] * img.size[1] + amount = .4 + noisy = np.copy(img) + # Salt mode + num_salt = np.ceil(amount * img_size * self.salt_or_pepper) + target_pixels = [np.random.randint(0, i - 1, int(num_salt)) + for i in img.size] + target_pixels = list(map(lambda coords: tuple(coords), zip(target_pixels[0], target_pixels[1]))) + for i, j in target_pixels: + noisy[i][j] = 1 + # Pepper mode + num_pepper = np.ceil(amount * img_size * (1. 
- self.salt_or_pepper)) + target_pixels = [np.random.randint(0, i - 1, int(num_pepper)) + for i in img.size] + target_pixels = list(map(lambda coords: tuple(coords), zip(target_pixels[0], target_pixels[1]))) + + for i, j in target_pixels: + noisy[i][j] = 0 + + # plt.imshow(img) + # plt.imshow(noisy) + # plt.show() + return [noisy, mask] + + +class PoissionNoise(torch.nn.Module): + def __init__(self): + super(PoissionNoise, self).__init__() + + def forward(self, img, mask): + vals = len(np.unique(img)) + vals = 2 ** np.ceil(np.log2(vals)) + noisy = np.random.poisson(img * vals) / float(vals) + return [noisy, mask] + + +# TODO this one +class SpeckleNoise(torch.nn.Module): + def __init__(self): + super(SpeckleNoise, self).__init__() + + def forward(self, img, mask): + row, col = img.size + gauss = np.random.randn(row, col) + gauss = gauss.reshape(row, col) + noisy = img + img * gauss + return [noisy, mask] + + +class SetZeroNoise(torch.nn.Module): + def __init__(self): + super(SetZeroNoise, self).__init__() + + def forward(self, img, mask): + row, col = img.size + img_size = row * col + random_rows = np.random.randint(0, int(row), (1, int(img_size))) + random_cols = np.random.randint(0, int(col), (1, int(img_size))) + target_pixels = list(zip(random_rows[0], random_cols[0])) + for pix in target_pixels: + img[pix[0], pix[1]] = 0 + return [img, mask] + + +#################################################################################### + +###########Base_code Agumentations suggestions: #TODO do not need them now########## +class MyWrap(torch.nn.Module): + """ + Random wrap augmentation taken from tormentor + """ + + def __init__(self): + super(MyWrap, self).__init__() + + def forward(self, img, target): + # This augmentation acts like many simultaneous elastic transforms with gaussian sigmas set at varius harmonics + wrap_rand = tormentor.Wrap.override_distributions(roughness=tormentor.random.Uniform(value_range=(.1, .7)), + intensity=tormentor.random.Uniform(value_range=(.0, 1.))) + wrap = wrap_rand() + image = wrap(img) + mask = wrap(target, is_mask=True) + return [image, mask] + + +class Rotate(torch.nn.Module): + """ + Random rotation augmentation + """ + + def __init__(self): + super(Rotate, self).__init__() + + def forward(self, img, target): + random = np.random.randint(0, 3) + angle = 90 + if random == 1: + angle = 180 + elif random == 2: + angle = 270 + image = torchvision.transforms.functional.rotate(img, angle=angle) + mask = torchvision.transforms.functional.rotate(target, angle=angle) + return [image, mask.squeeze(0)] + + +class Bright(torch.nn.Module): + """ + Random brightness adjustment augmentations + """ + + def __init__(self, + lower_band: float = -0.2, + upper_band: float = 0.2): + super(Bright, self).__init__() + self.lower_band = lower_band + self.upper_band = upper_band + + def forward(self, img, mask): + bright_rand = tormentor.Brightness.override_distributions( + brightness=tormentor.random.Uniform((self.lower_band, self.upper_band))) + # bright = bright_rand() + # image_transformed = img.clone() + image_transformed = bright_rand(np.asarray(img)) + # set NA areas back to zero + image_transformed.seed[img == 0] = 0.0 + + return [image_transformed.seed, mask] + + +class Noise(torch.nn.Module): + """ + Random additive noise augmentation + """ + + def __init__(self): + super(Noise, self).__init__() + + def forward(self, img, target): + # add noise. 
It is a multiplicative gaussian noise so no need to set na areas back to zero again + noise = torch.normal(mean=0, std=0.3, size=img.size) + image = img + img * noise + image[image > 1.0] = 1.0 + image[image < 0.0] = 0.0 + + return [image, target] + + +class Resize(torch.nn.Module): + # Library code + """Resize the input image to the given size. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means a maximum of two leading dimensions + + .. warning:: + The output image might be different depending on its type: when downsampling, the interpolation of PIL images + and tensors is slightly different, because PIL applies antialiasing. This may lead to significant differences + in the performance of a network. Therefore, it is preferable to train and serve a model with the same input + types. See also below the ``antialias`` parameter, which can help making the output of PIL images and tensors + closer. + + Args: + size (sequence or int): Desired output size. If size is a sequence like + (h, w), output size will be matched to this. If size is an int, + smaller edge of the image will be matched to this number. + i.e, if height > width, then image will be rescaled to + (size * height / width, size). + + .. note:: + In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``, + ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. + max_size (int, optional): The maximum allowed for the longer edge of + the resized image. If the longer edge of the image is greater + than ``max_size`` after being resized according to ``size``, + ``size`` will be overruled so that the longer edge is equal to + ``max_size``. + As a result, the smaller edge may be shorter than ``size``. This + is only supported if ``size`` is an int (or a sequence of length + 1 in torchscript mode). + antialias (bool, optional): Whether to apply antialiasing. + It only affects **tensors** with bilinear or bicubic modes and it is + ignored otherwise: on PIL images, antialiasing is always applied on + bilinear or bicubic modes; on other modes (for PIL images and + tensors), antialiasing makes no sense and this parameter is ignored. + Possible values are: + + - ``True``: will apply antialiasing for bilinear or bicubic modes. + Other mode aren't affected. This is probably what you want to use. + - ``False``: will not apply antialiasing for tensors on any mode. PIL + images are still antialiased on bilinear or bicubic modes, because + PIL doesn't support no antialias. + - ``None``: equivalent to ``False`` for tensors and ``True`` for + PIL images. This value exists for legacy reasons and you probably + don't want to use it unless you really know what you are doing. + + The current default is ``None`` **but will change to** ``True`` **in + v0.17** for the PIL and Tensor backends to be consistent. 
+ """ + + def __init__(self, size, interpolation=F.InterpolationMode.BILINEAR, max_size=None, antialias="warn"): + super().__init__() + self.size = size + self.max_size = max_size + self.interpolation = interpolation + self.antialias = antialias + + def forward(self, img, mask): + """ + Args: + img (PIL Image or Tensor): Image to be scaled. + + Returns: + PIL Image or Tensor: Rescaled image. + """ + return [F.resize(img, self.size, self.interpolation, self.max_size, self.antialias), mask] + + +class CenterCrop(torch.nn.Module): + """Crops the given image at the center. + If the image is torch Tensor, it is expected + to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. + If image size is smaller than output size along any edge, image is padded with 0 and then center cropped. + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). + """ + + def __init__(self, size): + super().__init__() + self.size = size # _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + + def forward(self, img, mask): + """ + Args: + img (PIL Image or Tensor): Image to be cropped. + + Returns: + PIL Image or Tensor: Cropped image. + """ + ###ADOPT FROM HOOKFORMER CODE, #todo CHECK IT LATER + W, H = img.size + + WW = (W // self.size) + 2 + HH = (H // self.size) + 2 + + return [F.center_crop(img, (HH * self.size, WW * self.size)), F.center_crop(mask, (HH * self.size, WW * self.size))] + + +class ToTensorZones(): + def __call__(self, image, target): + image = F.to_tensor(np.array(image).astype(np.float32)) + target = torch.from_numpy(np.array(target).astype(np.float32)) + # value for NA area=0, stone=64, glacier=127, ocean with ice melange=254 + target[target == 0] = 0 + target[target == 64] = 1 + target[target == 127] = 2 + target[target == 254] = 3 + # class ids for NA area=0, stone=1, glacier=2, ocean with ice melange=3 + return [image, target] diff --git a/SSLGlacier/modules/datamodule_.py b/SSLGlacier/modules/datamodule_.py new file mode 100644 index 0000000000000000000000000000000000000000..edf3648e314a451703627fcfc8d5f855c90d7870 --- /dev/null +++ b/SSLGlacier/modules/datamodule_.py @@ -0,0 +1,149 @@ +from torchvision.transforms import functional as F +from pl_bolts.models.self_supervised import swav +from torch.utils.data import DataLoader +from torch.utils.data import Dataset +import pytorch_lightning as pl +from torchvision import transforms as transform_lib +from . 
import agumentations_ as our_transforms +import torchvision +import numpy as np +from typing import Tuple, List, Any, Callable, Optional +import torch +import cv2 +from PIL import Image,ImageOps +import os + +# TODO add moco, simclr,cpc and amdim transformers +class SSLDataset(Dataset): + def __init__(self, parent_dir, transform,return_index=False, **kwargs): # TODO Pass them as arguments + ''' + Args: + mode: can be tested, train, or validation + parent_dir: directory in which the folders test, train and validation exists + ''' + self.mode = kwargs['mode'] if kwargs['mode'] else 'train' + self.return_index = return_index + self.images_path = os.path.join(parent_dir, "sar_images", self.mode) + self.masks_path = os.path.join(parent_dir, 'zones', self.mode) + self.images = os.listdir(self.images_path) + self.masks = os.listdir(self.masks_path) + assert len(self.masks) == len(self.images), "You don't have the same number of images and masks" + self.transform = transform + # Sort than images and masks fit together + self.images.sort() + self.masks.sort() + + # Let shuffle and save indices, and load them in next calls if they exist + if not os.path.exists(os.path.join("data_processing", "data_splits")): + os.makedirs(os.path.join("data_processing", "data_splits")) + if not os.path.isfile(os.path.join("data_processing", "data_splits", "shuffle_" + ".txt")): + shuffle = np.random.permutation(len(self.images)) + # Works for numpy version >= 1.5 + np.savetxt(os.path.join("data_processing", "data_splits", "shuffle_" + ".txt"), shuffle, + newline=' ') + + else: + # use already existing shuffle + with open(os.path.join("data_processing", "data_splits", "shuffle_" + ".txt"), "rb") as fp: + lines = fp.readlines() + shuffle = [np.fromstring(line, dtype=int, sep=' ') for line in lines] + # if lengths do not match, we need to create a new permutation + if len(shuffle) != len(self.images): + shuffle = np.random.permutation(len(self.images)) + np.savetxt(os.path.join("data_processing", "data_splits", "shuffle_" + ".txt"), shuffle, + newline=' ') + + self.images = np.array(self.images) # Why!!!! + self.masks = np.array(self.masks) + self.images = self.images[shuffle].copy() + self.masks = self.masks[shuffle].copy() + self.images = list(self.images) + self.masks = list(self.masks) # Why!!!! + + def __len__(self): + return len(self.images) + + def __getitem__(self, index): + img_name = self.images[index] + masks_name = self.masks[index] + assert img_name.split('.')[0] == masks_name.split('.')[0].replace("_zones", ""), \ + "image and label name don't match. Image name: " + img_name + ". Label name: " + masks_name + #image = cv2.imread(os.path.join(self.images_path, img_name).__str__(), cv2.IMREAD_GRAYSCALE) + #mask = cv2.imread(os.path.join(self.masks_path, masks_name).__str__(), cv2.IMREAD_GRAYSCALE) + image = Image.open(os.path.join(self.images_path, img_name).__str__()) + image = ImageOps.grayscale(image) + #image = np.array(image).astype(np.float32) + + #rgbimg = Image.new("RGBA", image.size) + #image = rgbimg.paste(image) + #image = ImageOps.grayscale(image) + mask = Image.open(os.path.join(self.masks_path, masks_name).__str__()) + #rgbimg = Image.new("RGBA", mask.size) + #mask = rgbimg.paste(rgbimg) + mask = ImageOps.grayscale(mask) + #mask = np.array(mask).astype(np.float32) + # TODO check if it works or not, it should return multiple transformation for one image + # TODO I crop different parts of the image and mask as part of the batch, how should I handle it!! 
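+        # Note: self.transform (e.g. OurTrainTransformer) is expected to return two parallel
+        # lists -- the augmented image crops and the corresponding mask crops for this sample.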
+ # TODO this is a list of tuples of img and masks, think if you want to change it later + + # to_tensor = our_transforms.ToTensorZones() + # _, mask = to_tensor(image,mask) + if self.transform is not None: + augmented_imgs, masks = self.transform(image, mask) + if self.return_index: + return index, augmented_imgs, mask, img_name, masks_name + return augmented_imgs, masks + + +class SSLDataModule(pl.LightningDataModule): + # Base is adopted from Nora's code + def __init__(self, batch_size, parent_dir, args): + """ + :param batch_size: batch size + :param target: Either 'zones' or 'front'. Tells which masks should be used. + """ + super().__init__() + self.transforms = None + self.batch_size = batch_size + self.glacier_test = None + self.glacier_train = None + self.glacier_val = None + self.parent_dir = parent_dir + self.aug_args = args.agumentations + self.size_crops = args.size_crops + self.num_crops = args.nmb_crops + + def prepare_data(self): + # download, + # only called on 1 GPU/TPU in distributed + pass + def __len__(self): + return len(self.glacier_train) + def setup(self, stage=None): + # process and split here + # make assignments here (val/train/test split) + # called on every process in DDP + if stage == 'test' or stage is None: + self.transforms = self._default_transforms() if self.test_transforms is None else self.test_transforms + self.glacier_test = SSLDataset(parent_dir=self.parent_dir, transform=self.transforms, mode=stage) + if stage == 'fit' or stage is None: + self.transforms = self._default_transforms() if self.train_transforms is None else self.train_transforms + self.glacier_train = SSLDataset(parent_dir=self.parent_dir, transform=self.transforms, mode='train') + + self.transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms + self.glacier_val = SSLDataset(parent_dir=self.parent_dir, transform=self.transforms, mode='val') + + def train_dataloader(self) -> DataLoader: + return DataLoader(self.glacier_train, batch_size=self.batch_size, num_workers=6, pin_memory=True, + drop_last=True) + + def val_dataloader(self)-> DataLoader: + return DataLoader(self.glacier_val, batch_size=self.batch_size, num_workers=6, pin_memory=True, drop_last=True) + + def test_dataloader(self)-> DataLoader: + return DataLoader(self.glacier_test, batch_size=1, num_workers=6, pin_memory=True, + drop_last=True) # TODO self.batch_size + + def _default_transforms(self) -> Callable: + #return transform_lib.Compose([transform_lib.ToTensor()]) + return our_transforms.Compose([our_transforms.ToTensorZones()]) diff --git a/SSLGlacier/modules/semi_module.py b/SSLGlacier/modules/semi_module.py new file mode 100644 index 0000000000000000000000000000000000000000..9798a3620f90f6dc9e849eb08eecafe45f256c24 --- /dev/null +++ b/SSLGlacier/modules/semi_module.py @@ -0,0 +1,29 @@ +from pytorch_lightning import LightningModule +from typing import Any +class Semi_(LightningModule): + def __init__(self): + pass + + def setup(self, stage: str) -> None: + pass + + def init_model(self): + pass + + def forward(self, *args: Any, **kwargs: Any) -> Any: + pass + + def on_train_epoch_start(self) -> None: + pass + + def shared_step(self, batch): + inputs, y = batch + inputs = inputs[:, -1] + embedding = self.model(inputs) + ####Where should I add noise.... + ####assume embedding[b][i] belongs to batch b view i + ####if we have two view, one with noise and one without, then + ##### select some portion of each view belongs to different classes. + ##### pos_set = same classess .. 
+ ##### neg_set = diff classes + diff --git a/SSLGlacier/modules/swav_module.py b/SSLGlacier/modules/swav_module.py new file mode 100644 index 0000000000000000000000000000000000000000..3b57a4ede5265846c6a4aeb4b885d3d31169c206 --- /dev/null +++ b/SSLGlacier/modules/swav_module.py @@ -0,0 +1,562 @@ +"""Adapted from official swav implementation: https://github.com/facebookresearch/swav.""" +import os +from argparse import ArgumentParser + +import torch +from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint +from torch import nn + +#from pl_bolts.models.self_supervised.swav.loss import SWAVLoss +from SSLGlacier.models.losses import SWAVLoss +from SSLGlacier.models.RESNET import resnet18, resnet50 +#from Base_Code.models.ParentUNet import UNet +from SSLGlacier.models.UNET import UNet #The main network +from pl_bolts.optimizers.lars import LARS +from pl_bolts.optimizers.lr_scheduler import linear_warmup_decay + +##########Importing files for hook############ +from SSLGlacier.models.hook.Swin_Transformer_Wrapper import SwinUnet +######Temporary config file for hooknet####### +from utils.config import get_config + +class SwAV(LightningModule): + def __init__( + self, + gpus: int, + num_samples: int, + batch_size: int, + dataset: str, + num_nodes: int = 1, + arch: str = "resnet50", + hidden_mlp: int = 2048, + feat_dim: int = 128, + warmup_epochs: int = 10, + max_epochs: int = 100, + num_prototypes: int = 3000, + freeze_prototypes_epochs: int = 1, + temperature: float = 0.1, + sinkhorn_iterations: int = 3, + queue_length: int = 0, # must be divisible by total batch-size + queue_path: str = "queue", + epoch_queue_starts: int = 15, + crops_for_assign: tuple = (0, 1), + num_crops: tuple = (2, 6), + num_augs: int= 2, + first_conv: bool = True, + maxpool1: bool = True, + optimizer: str = "adam", + exclude_bn_bias: bool = False, + start_lr: float = 0.0, + learning_rate: float = 1e-3, + final_lr: float = 0.0, + weight_decay: float = 1e-6, + epsilon: float = 0.05, + just_aug_for_same_assign_views: bool = False, + swin_hparams = {} + ) -> None: + """ + Args: + gpus: number of gpus per node used in training, passed to SwAV module + to manage the queue and select distributed sinkhorn + num_nodes: number of nodes to train on + num_samples: number of image samples used for training + batch_size: batch size per GPU in ddp + dataset: dataset being used for train/val + arch: encoder architecture used for pre-training + hidden_mlp: hidden layer of non-linear projection head, set to 0 + to use a linear projection head + feat_dim: output dim of the projection head + warmup_epochs: apply linear warmup for this many epochs + max_epochs: epoch count for pre-training + num_prototypes: count of prototype vectors + freeze_prototypes_epochs: epoch till which gradients of prototype layer + are frozen + temperature: loss temperature + sinkhorn_iterations: iterations for sinkhorn normalization + queue_length: set queue when batch size is small, + must be divisible by total batch-size (i.e. 
total_gpus * batch_size), + set to 0 to remove the queue + queue_path: folder within the logs directory + epoch_queue_starts: start uing the queue after this epoch + crops_for_assign: list of crop ids for computing assignment + num_crops: number of global and local crops, ex: [2, 6] + first_conv: keep first conv same as the original resnet architecture, + if set to false it is replace by a kernel 3, stride 1 conv (cifar-10) + maxpool1: keep first maxpool layer same as the original resnet architecture, + if set to false, first maxpool is turned off (cifar10, maybe stl10) + optimizer: optimizer to use + exclude_bn_bias: exclude batchnorm and bias layers from weight decay in optimizers + start_lr: starting lr for linear warmup + learning_rate: learning rate + final_lr: float = final learning rate for cosine weight decay + weight_decay: weight decay for optimizer + epsilon: epsilon val for swav assignments + """ + super().__init__() + self.save_hyperparameters() + + self.gpus = gpus + self.num_nodes = num_nodes + self.arch = arch + self.dataset = dataset + self.num_samples = num_samples + self.batch_size = batch_size + + self.hidden_mlp = hidden_mlp + self.feat_dim = feat_dim + self.num_prototypes = num_prototypes + self.freeze_prototypes_epochs = freeze_prototypes_epochs + self.sinkhorn_iterations = sinkhorn_iterations + + self.queue_length = queue_length + self.queue_path = queue_path + self.epoch_queue_starts = epoch_queue_starts + self.crops_for_assign = crops_for_assign + self.num_crops = num_crops + self.num_augs = num_augs + + self.first_conv = first_conv + self.maxpool1 = maxpool1 + + self.optim = optimizer + self.exclude_bn_bias = exclude_bn_bias + self.weight_decay = weight_decay + self.epsilon = epsilon + self.temperature = temperature + + self.start_lr = start_lr + self.final_lr = final_lr + self.learning_rate = learning_rate + self.warmup_epochs = warmup_epochs + self.max_epochs = max_epochs + self.just_aug_for_same_assign_views = just_aug_for_same_assign_views + + self.model = self.init_model() + self.criterion = SWAVLoss( + gpus=self.gpus, + num_nodes=self.num_nodes, + temperature=self.temperature, + crops_for_assign=self.crops_for_assign, + num_crops=self.num_crops, + num_augs=self.num_augs, + sinkhorn_iterations=self.sinkhorn_iterations, + epsilon=self.epsilon, + just_aug_for_same_assign_views= self.just_aug_for_same_assign_views + ) + self.use_the_queue = None + # compute iters per epoch + global_batch_size = self.num_nodes * self.gpus * self.batch_size if self.gpus > 0 else self.batch_size + self.train_iters_per_epoch = self.num_samples // global_batch_size + self.queue = None + + ####For hook#### + self.swin_hparams = swin_hparams + def setup(self, stage): + if self.queue_length > 0: + queue_folder = os.path.join(self.logger.log_dir, self.queue_path) + if not os.path.exists(queue_folder): + os.makedirs(queue_folder) + + self.queue_path = os.path.join(queue_folder, "queue" + str(self.trainer.global_rank) + ".pth") + + if os.path.isfile(self.queue_path): + self.queue = torch.load(self.queue_path)["queue"] + + def init_model(self): + if self.arch == "resnet18": + backbone = resnet18(normalize=True, + hidden_mlp=self.hidden_mlp, + output_dim=self.feat_dim, + num_prototypes=self.num_prototypes, + first_conv=self.first_conv, + maxpool1=self.maxpool1) + + elif self.arch == "resnet50": + backbone = resnet50(normalize=True, + hidden_mlp=self.hidden_mlp, + output_dim=self.feat_dim, + num_prototypes=self.num_prototypes, + first_conv=self.first_conv, + maxpool1=self.maxpool1) + elif 
self.arch == "hook": + backbone = SwinUnet( + img_size=224, + num_classes=5, + normalize=True, + hidden_mlp=self.hidden_mlp, + output_dim=self.feat_dim, + num_prototypes=self.num_prototypes, + first_conv=self.first_conv, + maxpool1=self.maxpool1,######till here for basenet + image_size= self.swin_hparams.image_size,####from here for swin itself + swin_patch_size = self.swin_hparams.swin_patch_size, + swin_in_chans=self.swin_hparams.swin_in_chans, + swin_embed_dim=self.swin_hparams.swin_embed_dim, + swin_depths=self.swin_hparams.swin_depths, + swin_num_heads=self.swin_hparams.swin_num_heads, + swin_window_size=self.swin_hparams.swin_window_size, + swin_mlp_ratio=self.swin_hparams.swin_mlp_ratio, + swin_QKV_BIAS=self.swin_hparams.swin_QKV_BIAS, + swin_QK_SCALE=self.swin_hparams.swin_QK_SCALE, + drop_rate=self.drop_rate, + drop_path_rate=self.drop_path_rate, + swin_ape=self.swin_hparams.swin_ape, + swin_patch_norm=self.swin_hparams.swin_path_norm + ) + + elif self.arch == "Unet" or "unet": + backbone = UNet(num_classes=5, + input_channels=1, + num_layers=5, + features_start=64, + bilinear=False, + normalize=True, + hidden_mlp=self.hidden_mlp, + output_dim=self.feat_dim, + num_prototypes=self.num_prototypes, + first_conv=self.first_conv, + maxpool1=self.maxpool1 + ) + else: + raise f'{self.arch} model is not defined' + + return backbone + + def forward(self, x): + # pass single batch from the resnet backbone + return self.model.forward_backbone(x) + + def on_train_epoch_start(self): + if self.queue_length > 0: + if self.trainer.current_epoch >= self.epoch_queue_starts and self.queue is None: + self.queue = torch.zeros( + len(self.crops_for_assign), + self.queue_length // self.gpus, # change to nodes * gpus once multi-node + self.feat_dim, + ) + + if self.queue is not None: + self.queue = self.queue.to(self.device) + + self.use_the_queue = False + + def on_train_epoch_end(self) -> None: + if self.queue is not None: + torch.save({"queue": self.queue}, self.queue_path) + + def on_after_backward(self): + if self.current_epoch < self.freeze_prototypes_epochs: + for name, p in self.model.named_parameters(): + if "prototypes" in name: + p.grad = None + + def shared_step(self, batch): + if self.dataset == "stl10": + unlabeled_batch = batch[0] + batch = unlabeled_batch + + inputs, y = batch + inputs = inputs[:-1] # remove online train/eval transforms at this point + + # 1. normalize the prototypes + with torch.no_grad(): + w = self.model.prototypes.weight.data.clone() + w = nn.functional.normalize(w, dim=1, p=2) + self.model.prototypes.weight.copy_(w) + + # 2. 
multi-res forward passes + embedding, output = self.model(inputs) + embedding = embedding.detach() + bs = inputs[0].size(0) + + # SWAV loss computation + loss, queue, use_queue = self.criterion( + output=output, + embedding=embedding, + prototype_weights=self.model.prototypes.weight, + batch_size=bs, + queue=self.queue, + use_queue=self.use_the_queue, + ) + self.queue = queue + self.use_the_queue = use_queue + return loss + + def training_step(self, batch, batch_idx): + loss = self.shared_step(batch) + + self.log("train_loss", loss, on_step=True, on_epoch=False) + return loss + + def validation_step(self, batch, batch_idx): + loss = self.shared_step(batch) + + self.log("val_loss", loss, on_step=False, on_epoch=True) + return loss + + def exclude_from_wt_decay(self, named_params, weight_decay, skip_list=("bias", "bn")): + params = [] + excluded_params = [] + + for name, param in named_params: + if not param.requires_grad: + continue + if any(layer_name in name for layer_name in skip_list): + excluded_params.append(param) + else: + params.append(param) + + return [{"params": params, "weight_decay": weight_decay}, {"params": excluded_params, "weight_decay": 0.0}] + + def configure_optimizers(self): + if self.exclude_bn_bias: + params = self.exclude_from_wt_decay(self.named_parameters(), weight_decay=self.weight_decay) + else: + params = self.parameters() + + if self.optim == "lars": + optimizer = LARS( + params, + lr=self.learning_rate, + momentum=0.9, + weight_decay=self.weight_decay, + trust_coefficient=0.001, + ) + elif self.optim == "adam": + optimizer = torch.optim.Adam(params, lr=self.learning_rate, weight_decay=self.weight_decay) + + warmup_steps = self.train_iters_per_epoch * self.warmup_epochs + total_steps = self.train_iters_per_epoch * self.max_epochs + + scheduler = { + "scheduler": torch.optim.lr_scheduler.LambdaLR( + optimizer, + linear_warmup_decay(warmup_steps, total_steps, cosine=True), + ), + "interval": "step", + "frequency": 1, + } + + return [optimizer], [scheduler] + + @staticmethod + def add_model_specific_args(parent_parser): + parser = ArgumentParser(parents=[parent_parser], add_help=False) + + # model params + parser.add_argument("--arch", default="resnet50", type=str, help="convnet architecture") + # specify flags to store false + parser.add_argument("--first_conv", action="store_false") + parser.add_argument("--maxpool1", action="store_false") + parser.add_argument("--hidden_mlp", default=2048, type=int, help="hidden layer dimension in projection head") + parser.add_argument("--feat_dim", default=128, type=int, help="feature dimension") + parser.add_argument("--online_ft", action="store_true") + parser.add_argument("--fp32", action="store_true") + + # transform params + parser.add_argument("--gaussian_blur", action="store_true", help="add gaussian blur") + parser.add_argument("--jitter_strength", type=float, default=1.0, help="jitter strength") + parser.add_argument("--dataset", type=str, default="stl10", help="stl10, cifar10") + parser.add_argument("--data_dir", type=str, default=".", help="path to download data") + parser.add_argument("--queue_path", type=str, default="queue", help="path for queue") + + parser.add_argument( + "--num_crops", type=int, default=[2, 4], nargs="+", help="list of number of crops (example: [2, 6])" + ) + parser.add_argument( + "--size_crops", type=int, default=[96, 36], nargs="+", help="crops resolutions (example: [224, 96])" + ) + parser.add_argument( + "--min_scale_crops", + type=float, + default=[0.33, 0.10], + nargs="+", + 
help="argument in RandomResizedCrop (example: [0.14, 0.05])", + ) + parser.add_argument( + "--max_scale_crops", + type=float, + default=[1, 0.33], + nargs="+", + help="argument in RandomResizedCrop (example: [1., 0.14])", + ) + + # training params + parser.add_argument("--fast_dev_run", default=1, type=int) + parser.add_argument("--num_nodes", default=1, type=int, help="number of nodes for training") + parser.add_argument("--gpus", default=1, type=int, help="number of gpus to train on") + parser.add_argument("--num_workers", default=8, type=int, help="num of workers per GPU") + parser.add_argument("--optimizer", default="adam", type=str, help="choose between adam/lars") + parser.add_argument("--exclude_bn_bias", action="store_true", help="exclude bn/bias from weight decay") + parser.add_argument("--max_epochs", default=100, type=int, help="number of total epochs to run") + parser.add_argument("--max_steps", default=-1, type=int, help="max steps") + parser.add_argument("--warmup_epochs", default=10, type=int, help="number of warmup epochs") + parser.add_argument("--batch_size", default=128, type=int, help="batch size per gpu") + + parser.add_argument("--weight_decay", default=1e-6, type=float, help="weight decay") + parser.add_argument("--learning_rate", default=1e-3, type=float, help="base learning rate") + parser.add_argument("--start_lr", default=0, type=float, help="initial warmup learning rate") + parser.add_argument("--final_lr", type=float, default=1e-6, help="final learning rate") + + # swav params + parser.add_argument( + "--crops_for_assign", + type=int, + nargs="+", + default=[0, 1], + help="list of crops id used for computing assignments", + ) + parser.add_argument("--temperature", default=0.1, type=float, help="temperature parameter in training loss") + parser.add_argument( + "--epsilon", default=0.05, type=float, help="regularization parameter for Sinkhorn-Knopp algorithm" + ) + parser.add_argument( + "--sinkhorn_iterations", default=3, type=int, help="number of iterations in Sinkhorn-Knopp algorithm" + ) + parser.add_argument("--num_prototypes", default=512, type=int, help="number of prototypes") + parser.add_argument( + "--queue_length", + type=int, + default=0, + help="length of the queue (0 for no queue); must be divisible by total batch size", + ) + parser.add_argument( + "--epoch_queue_starts", type=int, default=15, help="from this epoch, we start using a queue" + ) + parser.add_argument( + "--freeze_prototypes_epochs", + default=1, + type=int, + help="freeze the prototypes during this many epochs from the start", + ) + + return parser + + +def cli_main(): + from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator + from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule + from pl_bolts.transforms.self_supervised.swav_transforms import SwAVEvalDataTransform, SwAVTrainDataTransform + + parser = ArgumentParser() + + # model args + parser = SwAV.add_model_specific_args(parser) + args = parser.parse_args() + + if args.dataset == "stl10": + dm = STL10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers) + + dm.train_dataloader = dm.train_dataloader_mixed + dm.val_dataloader = dm.val_dataloader_mixed + args.num_samples = dm.num_unlabeled_samples + + args.maxpool1 = False + + normalization = stl10_normalization() + elif args.dataset == "cifar10": + args.batch_size = 2 + args.num_workers = 0 + + dm = CIFAR10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers) + + 
args.num_samples = dm.num_samples + + args.maxpool1 = False + args.first_conv = False + + normalization = cifar10_normalization() + + # cifar10 specific params + args.size_crops = [32, 16] + args.num_crops = [2, 1] + args.gaussian_blur = False + elif args.dataset == "imagenet": + args.maxpool1 = True + args.first_conv = True + normalization = imagenet_normalization() + + args.size_crops = [224, 96] + args.num_crops = [2, 6] + args.min_scale_crops = [0.14, 0.05] + args.max_scale_crops = [1.0, 0.14] + args.gaussian_blur = True + args.jitter_strength = 1.0 + + args.batch_size = 64 + args.num_nodes = 8 + args.gpus = 8 # per-node + args.max_epochs = 800 + + args.optimizer = "lars" + args.learning_rate = 4.8 + args.final_lr = 0.0048 + args.start_lr = 0.3 + + args.num_prototypes = 3000 + args.online_ft = True + + dm = ImagenetDataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers) + + args.num_samples = dm.num_samples + args.input_height = dm.dims[-1] + else: + raise NotImplementedError("other datasets have not been implemented till now") + + dm.train_transforms = SwAVTrainDataTransform( + normalize=normalization, + size_crops=args.size_crops, + num_crops=args.num_crops, + min_scale_crops=args.min_scale_crops, + max_scale_crops=args.max_scale_crops, + gaussian_blur=args.gaussian_blur, + jitter_strength=args.jitter_strength, + ) + + dm.val_transforms = SwAVEvalDataTransform( + normalize=normalization, + size_crops=args.size_crops, + num_crops=args.num_crops, + min_scale_crops=args.min_scale_crops, + max_scale_crops=args.max_scale_crops, + gaussian_blur=args.gaussian_blur, + jitter_strength=args.jitter_strength, + ) + + # swav model init + model = SwAV(**args.__dict__) + + online_evaluator = None + if args.online_ft: + # online eval + online_evaluator = SSLOnlineEvaluator( + drop_p=0.0, + hidden_dim=None, + z_dim=args.hidden_mlp, + num_classes=dm.num_classes, + dataset=args.dataset, + ) + + lr_monitor = LearningRateMonitor(logging_interval="step") + model_checkpoint = ModelCheckpoint(save_last=True, save_top_k=1, monitor="val_loss") + callbacks = [model_checkpoint, online_evaluator] if args.online_ft else [model_checkpoint] + callbacks.append(lr_monitor) + + trainer = Trainer( + max_epochs=args.max_epochs, + max_steps=None if args.max_steps == -1 else args.max_steps, + gpus=args.gpus, + num_nodes=args.num_nodes, + accelerator="ddp" if args.gpus > 1 else None, + sync_batchnorm=args.gpus > 1, + precision=32 if args.fp32 else 16, + callbacks=callbacks, + fast_dev_run=args.fast_dev_run, + ) + + trainer.fit(model, datamodule=dm) + + +if __name__ == "__main__": + cli_main() \ No newline at end of file diff --git a/SSLGlacier/modules/transformers_.py b/SSLGlacier/modules/transformers_.py new file mode 100644 index 0000000000000000000000000000000000000000..f8290bb69b2c33649e92a9dd9ae7797595b94e85 --- /dev/null +++ b/SSLGlacier/modules/transformers_.py @@ -0,0 +1,248 @@ +from pl_bolts.models.self_supervised import swav +from torchvision import transforms # need it for validation and fine tuning +from . 
import agumentations_ as ag +from typing import Tuple, List, Dict +from torch import Tensor + + + +class OurTrainTransformer(): + def __init__(self, + size_crops: Tuple[int] = (294, 94), + nmb_crops: Tuple[int] = (2, 4), + augs:Dict = {'SaltPepper':1}, + use_hook_former: bool = True) -> List: + # TODO you can set it in a way that making sure you have one sample from each argumentation method + self.size_crops = size_crops + self.num_crops = nmb_crops + augmented_imgs = [] + self.use_hook_former = use_hook_former + self.transform = [] + if augs['OrigPixelValues']: + augmented_imgs.append(ag.DoNothing()) + if augs['RGaussianB']: + augmented_imgs.append(ag.PILRandomGaussianBlur(radius_min=0.1, radius_max=2.)) + if augs['GaussiN']: + augmented_imgs.append(ag.GaussNoise(mean=0, var=10000)) + if augs['SaltPepper']: + augmented_imgs.append(ag.SaltPepperNoise(salt_or_pepper=0.2)) + if augs['ZeroN']: + augmented_imgs.append(ag.SetZeroNoise()) + # Noras augmentations methods + if augs['flip']: + # TODO manage this, how to add mask for the times we need transformation for masks too + augmented_imgs.append(ag.RandomHorizentalFlip()) + # image = torchvision.transforms.functional.hflip(image) + # mask = torchvision.transforms.functional.hflip(mask) + # image, mask = ag.ToTensorZones(image=image, target=np.array(mask)) + # augmented_imgs.append((image, mask)) + if augs['rotate']: + augmented_imgs.append(ag.Rotate()) + if augs['bright']: + augmented_imgs.append(ag.Bright()) + if augs['wrap']: + augmented_imgs.append() + if augs['noise']: + augmented_imgs.append(ag.Noise()) + + if augs['normalize'] is not None: + self.final_transform = ag.Compose([augmented_imgs.ToTensor(), augs['normalize']]) + else: + self.final_transform = ag.ToTensorZones() + + if self.use_hook_former: + self.hook_former_transformer(augmented_imgs) + else: + self.other_networks_transformers(augmented_imgs) + + + + def __call__(self, image, mask): + """ + import matplotlib.pyplot as plt + img1 =transformed_imgs[0].detach().cpu().numpy().squeeze() + plt.imshow(img1) + plt.show() + """ + transformed_img_mask = [] + for transform in self.transform: + if isinstance(transform.__getitem__(0), ag.RandomCropper): + transformed_img_mask.append(transform(image, mask)) + self.i = transform.__getitem__(0).left + self.j = transform.__getitem__(0).right + self.w = transform.__getitem__(0).width + self.h = transform.__getitem__(0).height + elif isinstance(transform.__getitem__(0), ag.Cropper): + transform.__getitem__(0).left = self.i + transform.__getitem__(0).right = self.j + transform.__getitem__(0).width = self.w + transform.__getitem__(0).height = self.h + transformed_img_mask.append(transform(image, mask)) + else: + transformed_img_mask.append(transform(image,mask)) + transformed_imgs = [item[0] for item in transformed_img_mask] + transformed_masks = [item[1] for item in transformed_img_mask] + + + # fig, axs = plt.subplots(nrows=6, ncols=4, figsize=(15, 12)) + # fig.suptitle("Patches used in training", fontsize=18, y=.95) + # + # for i, image_id in enumerate(transformed_imgs): + # img1 = transformed_imgs[i*3].detach().cpu().numpy().squeeze() + # mask1 = transformed_masks[i*3].detach().cpu().numpy().squeeze() + # axs[0, i].imshow(img1) + # axs[0, i].imshow(mask1) + # axs[0, i].imshow(transformed_imgs[i*3 +1].detach().cpu().numpy().squeeze()) + # axs[0, i].imshow(transformed_imgs[i*3+2].detach().cpu().numpy().squeeze()) + # plt.savefig('High.png') + return transformed_imgs, transformed_masks + + def hook_former_transformer(self, augmented_imgs): + 
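+        """Build the multi-crop transform list used with the hook/Swin backbone.
+
+        For each entry in self.size_crops, the first augmentation is wrapped in a
+        RandomCropper of that size (global view); every remaining augmentation is
+        wrapped in a CenterCrop of a quarter of that size (local view). Each
+        per-size list is repeated self.num_crops[i] times, and a final online-train
+        transform (random crop + horizontal flip + self.final_transform) is appended.
+        """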
transform = [] + for i in range(len(self.size_crops)): + i_transform = [] + for ith_aug, aug in enumerate(augmented_imgs): + # For the first time we crop randomly, then save coordinates for further transformation + if ith_aug == 0: + global_crop = ag.RandomCropper(self.size_crops[i]) + i_transform.extend( + [ag.Compose([global_crop] + [aug] + [ag.ToTensorZones()])] + ) + else: # Crop same area, and do augmentation + local_crop = ag.CenterCrop(center_crop_size = self.size_crops[i]/4) + i_transform.extend( + [ag.Compose([local_crop] + [aug] + [ag.ToTensorZones()])]) + transform += i_transform * self.num_crops[i] + self.transform = transform + # add online train transform of the size of global view + online_train_transform = ag.Compose( + [ag.RandomCropper(self.size_crops[i]), ag.RandomHorizontalFlip(), self.final_transform] + ) + + self.transform.append(online_train_transform) + return transform + + def other_networks_transformers(self, augmented_imgs): + transform = [] + for i in range(len(self.size_crops)): + i_transform = [] + for ith_aug, aug in enumerate(augmented_imgs): + # For the first time we crop randomly, then save coordinates for further transformation + if ith_aug == 0: + random_crop = ag.RandomCropper(self.size_crops[i]) + i_transform.extend( + [ag.Compose([random_crop] + [aug] + [ag.ToTensorZones()])] + ) + else: # Crop same area, and do augmentation + fixed_crop = ag.Cropper(random_crop.left, random_crop.right, random_crop.height, random_crop.width) + i_transform.extend( + [ag.Compose([fixed_crop] + [aug] + [ag.ToTensorZones()])]) + transform += i_transform * self.num_crops[i] + + self.transform = transform + # add online train transform of the size of global view + #What was that for?? I forgot! + online_train_transform = ag.Compose( + [ag.RandomCropper(self.size_crops[i]), ag.RandomHorizontalFlip(), self.final_transform] + ) + + self.transform.append(online_train_transform) + return transform + +class OurEvalTransformer(OurTrainTransformer): + def __init__(self, + size_crops: Tuple[int] = (294, 294), + nmb_crops: Tuple[int] = (2, 4), + augs:Dict = {'SaltPepper':1}, + use_hook_former: bool =True) -> List: + super().__init__(size_crops=size_crops, + nmb_crops=nmb_crops, + augs = augs, + use_hook_former=use_hook_former) + + input_height = self.size_crops[0] # get global view crop + test_transform = ag.Compose( + [ + ag.Resize(int(input_height + 0.1 * input_height)), + ag.CenterCrop(input_height), + self.final_transform, + ] + ) + + # replace last transform to eval transform in self.transform list + self.transform[-1] = test_transform + +class OurFinetuneTransform: + def __init__(self, + input_height: int = 224, + normalize = None, + eval_transform: bool = False) -> None : + + if not eval_transform:# TODO think about it + data_transforms = [ + transforms.RandomResizedCrop(size=self.input_height), + ag.RandomHorizontalFlip(), + ag.SaltPepperNoise(salt_or_pepper=0.5), + transforms.RandomGrayscale(p=0.2), + ] + else: + data_transforms = ag.Compose( + [ + transforms.Resize(int(input_height + 0.1 * input_height)), + transforms.CenterCrop(input_height), + self.final_transform, + ] + ) + + if normalize is None: + final_transform = transforms.ToTensor() + else: + final_transform = transforms.Compose([transforms.ToTensor(), normalize]) + + data_transforms.append(final_transform) + self.transform = ag.Compose(data_transforms) + + def __call__(self, img: Tensor)->Tensor: + return self.transform(img) + +# TODO use it in the code with a flag +class SwapDefaultTransformer(): + def __init__( 
+            self,
+            size_crops: Tuple[int] = (96, 36),
+            nmb_crops: Tuple[int] = (2, 4),
+            min_scale_crops: Tuple[float] = (0.33, 0.10),
+            max_scale_crops: Tuple[float] = (1, 0.33),
+            gaussian_blur: bool = True,
+            jitter_strength: float = 1.0,
+    ) -> None:
+
+        self.size_crops = size_crops
+        self.num_crops = nmb_crops
+        self.min_scale_crops = min_scale_crops
+        self.max_scale_crops = max_scale_crops
+        self.gaussian_blur = gaussian_blur
+        self.jitter_strength = jitter_strength
+
+    def get(self, mode: str):
+        if mode == 'train':
+            return swav.SwAVTrainDataTransform(
+                #normalize=self.normalization(),
+                size_crops=self.size_crops,
+                num_crops=self.num_crops,
+                min_scale_crops=self.min_scale_crops,
+                max_scale_crops=self.max_scale_crops,
+                gaussian_blur=self.gaussian_blur,
+                jitter_strength=self.jitter_strength
+            )
+        elif mode == 'val':
+            return swav.SwAVEvalDataTransform(
+                #normalize=self.normalization(),
+                size_crops=self.size_crops,
+                num_crops=self.num_crops,
+                min_scale_crops=self.min_scale_crops,
+                max_scale_crops=self.max_scale_crops,
+                gaussian_blur=self.gaussian_blur,
+                jitter_strength=self.jitter_strength
+            )
+        else:
+            raise ValueError(f"Unknown transform mode: {mode!r}; expected 'train' or 'val'.")
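+
+
+# ---------------------------------------------------------------------------
+# Minimal usage sketch (illustration only; file names and flag values below are
+# placeholders, and flip/wrap are left disabled here):
+#
+# from PIL import Image
+#
+# augs = {
+#     'OrigPixelValues': 1, 'RGaussianB': 1, 'GaussiN': 0, 'SaltPepper': 1,
+#     'ZeroN': 0, 'flip': 0, 'rotate': 1, 'bright': 0, 'wrap': 0, 'noise': 1,
+#     'normalize': None,
+# }
+# transform = OurTrainTransformer(size_crops=(294, 94), nmb_crops=(2, 4),
+#                                 augs=augs, use_hook_former=False)
+# image = Image.open('sample_sar_image.png')
+# mask = Image.open('sample_zones_mask.png')
+# crops, mask_crops = transform(image, mask)  # parallel lists of augmented crops
+# ---------------------------------------------------------------------------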