# Importing required libraries
import torch
import torch.optim as optim
import random
import numpy as np
from utils import get_logger, evaluate_accuracy, evaluate_robust_accuracy, Config
from data import get_mnist_loaders, get_cifar10_loaders, get_cifar100_loaders
from models import SmallCNN, WideResNet
from attacks.fgsm import FGSMAttack
from attacks.pgd import PGDAttack
from train import (train_air_for_epochs, train_lfl_for_epochs, train_joint_for_epochs, train_vanilla_at_for_epochs,
                   train_feat_extraction_for_epochs, train_ewc_for_epochs)


# Function to set random seeds to ensure reproducibility of results
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
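

# Optional, illustrative addition (not part of the original script): stricter GPU determinism.
# torch.backends.cudnn.deterministic / benchmark are standard PyTorch flags; this helper is a
# hedged sketch and is never called elsewhere in this file.
def enable_full_determinism():
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False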


# Creates the right attack instance (FGSM or PGD) based on the attack name passed
def make_attack(attack_name, cfg, model=None):
    if attack_name == "None":
        return None
    elif attack_name == "FGSM":
        atk = FGSMAttack(model=model, epsilon=cfg.epsilon)
        return atk
    elif attack_name == "PGD":
        atk = PGDAttack(model=model, epsilon=cfg.epsilon, alpha=cfg.alpha, num_steps=cfg.num_steps,
                        random_init=cfg.random_init)
        return atk
    else:
        raise ValueError(f"Unknown attack name: {attack_name}")
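
# Illustrative usage (assumes cfg provides epsilon, alpha, num_steps, and random_init):
#   fgsm_attack = make_attack("FGSM", cfg, model=model)
#   pgd_attack = make_attack("PGD", cfg, model=model)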


# Main training and evaluation loop across multiple tasks and attacks
def train_and_eval_sequential_tasks(cfg: Config, logger):
    set_seed(cfg.seed)  # Set random seed for reproducibility
    device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")  # Set device (GPU or CPU)

    # Load data and build the student/teacher models for the selected dataset
    if cfg.dataset == "MNIST":
        train_loader, test_loader = get_mnist_loaders(batch_size=cfg.batch_size, root=cfg.data_root,
                                                      num_workers=cfg.num_workers)
        model = SmallCNN(num_channels=1, num_classes=10)
        teacher_model = SmallCNN(num_channels=1, num_classes=10)
    elif cfg.dataset == "CIFAR10":
        train_loader, test_loader = get_cifar10_loaders(batch_size=cfg.batch_size, root=cfg.data_root,
                                                        num_workers=cfg.num_workers)
        model = WideResNet(depth=34, widen_factor=10, drop_rate=0.0, num_classes=10)
        teacher_model = WideResNet(depth=34, widen_factor=10, drop_rate=0.0, num_classes=10)
    elif cfg.dataset == "CIFAR100":
        train_loader, test_loader = get_cifar100_loaders(batch_size=cfg.batch_size, root=cfg.data_root,
                                                         num_workers=cfg.num_workers)
        model = WideResNet(depth=34, widen_factor=20, drop_rate=0.0, num_classes=100)
        teacher_model = WideResNet(depth=34, widen_factor=20, drop_rate=0.0, num_classes=100)
    else:
        raise ValueError(f"Unknown dataset: {cfg.dataset}")

    model.to(device)
    teacher_model.to(device)
    teacher_model.eval()  # Set the teacher model to evaluation mode (no gradient updates)
    for p in teacher_model.parameters():
        p.requires_grad = False  # Freeze teacher model parameters

    optimizer = optim.SGD(model.parameters(), lr=cfg.learning_rate, momentum=cfg.momentum,
                          weight_decay=cfg.weight_decay)  # Optimizer for training

    # Configuration for the AIR defense method
    defense_config_AIR = {'lambda_SD': cfg.lambda_SD, 'lambda_IR': cfg.lambda_IR, 'lambda_AR': cfg.lambda_AR,
                          'lambda_Reg': cfg.lambda_Reg, 'alpha_range': cfg.alpha_range, 'use_rdrop': cfg.use_rdrop,
                          'iso_noise_std': cfg.iso_noise_std, 'iso_clamp_min': cfg.iso_clamp_min,
                          'iso_clamp_max': cfg.iso_clamp_max, 'iso_p_flip': cfg.iso_p_flip,
                          'iso_flip_dim': cfg.iso_flip_dim, 'iso_p_rotation': cfg.iso_p_rotation,
                          'iso_max_rotation': cfg.iso_max_rotation, 'iso_p_crop': cfg.iso_p_crop,
                          'iso_p_erase': cfg.iso_p_erase}

    # List to store results for each task
    results_per_task = []

    # Loop over each task and its associated attack
    for task_idx, atk_name in enumerate(cfg.attack_sequence):
        logger.info(f"--- Task {task_idx + 1} / {len(cfg.attack_sequence)}: Attack = {atk_name} ---")
        current_attack = make_attack(atk_name, cfg, model=None)  # Build the attack instance for this task

        # Train with the selected defense method
        if cfg.defense_method == "AIR":
            logger.info("Using AIR Defense...")
            train_air_for_epochs(student_model=model, teacher_model=teacher_model, train_loader=train_loader,
                                 test_loader=test_loader, optimizer=optimizer, defense_config=defense_config_AIR,
                                 attack=current_attack, device=device, epochs=cfg.epochs, logger=logger)
        elif cfg.defense_method == "LFL":
            logger.info("Using LFL Defense...")
            model = train_lfl_for_epochs(student_model=model, teacher_model=teacher_model, train_loader=train_loader,
                                         test_loader=test_loader, optimizer=optimizer, lambda_lfl=cfg.lambda_lfl,
                                         feature_lambda=cfg.feature_lambda, device=device, epochs=cfg.epochs,
                                         logger=logger, freeze_classifier=cfg.freeze_classifier)
        elif cfg.defense_method == "JointTraining":
            logger.info("Using Joint Training Defense...")
            model = train_joint_for_epochs(student_model=model, train_loader=train_loader, test_loader=test_loader,
                                           optimizer=optimizer, attack=current_attack, joint_lambda=cfg.joint_lambda,
                                           device=device, epochs=cfg.epochs, logger=logger)
        elif cfg.defense_method == "VanillaAT":
            logger.info("Using Vanilla Adversarial Training Defense...")
            model = train_vanilla_at_for_epochs(student_model=model, train_loader=train_loader,
                                                test_loader=test_loader, optimizer=optimizer, attack=current_attack,
                                                adv_lambda=cfg.adv_lambda, device=device, epochs=cfg.epochs,
                                                logger=logger)
        elif cfg.defense_method == "FeatExtraction":
            logger.info("Using Feature Extraction Defense...")
            model = train_feat_extraction_for_epochs(student_model=model, train_loader=train_loader,
                                                     test_loader=test_loader, optimizer=optimizer,
                                                     feat_lambda=cfg.feat_lambda, noise_std=cfg.noise_std,
                                                     device=device, epochs=cfg.epochs, logger=logger)
        elif cfg.defense_method == "EWC":
            logger.info("Using EWC Defense...")
            model = train_ewc_for_epochs(student_model=model, train_loader=train_loader, test_loader=test_loader,
                                         optimizer=optimizer, device=device, epochs=cfg.epochs, logger=logger,
                                         lambda_ewc=cfg.lambda_ewc)
        else:
            raise ValueError(f"Unknown defense method: {cfg.defense_method}")
        # After training, evaluate the model's performance on clean data
        clean_acc, clean_loss = evaluate_accuracy(model, test_loader, device=device)
        logger.info(f"Clean Accuracy after Task {task_idx + 1}, Attack={atk_name}: {clean_acc:.2f}%")

        # Evaluate robust accuracy if this task used an attack
        if current_attack is not None:
            test_attack = make_attack(atk_name, cfg, model=None)
            robust_acc_dict, robust_loss_dict, _ = evaluate_robust_accuracy(model, test_loader, [test_attack],
                                                                            teacher_model=None, device=device)
            for attack_name, acc in robust_acc_dict.items():
                logger.info(f"Robust Accuracy on {attack_name} after Task {task_idx + 1}: {acc:.2f}%")

        results_per_task.append({"task_idx": task_idx, "attack_name": atk_name, "clean_acc": clean_acc})

    logger.info("=== All tasks completed! ===")
    return model, results_per_task
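

# Hedged helper sketch (not part of the original script): persist the per-task results
# returned by train_and_eval_sequential_tasks to a JSON file. The function name and the
# default path are illustrative assumptions.
def save_results(results, path="results_per_task.json"):
    import json
    with open(path, "w") as f:
        json.dump(results, f, indent=2)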


# Main function to run everything
def main():
    logger = get_logger("MultiTaskMain")  # Set up the logger
    cfg = Config()  # Load configuration
    logger.info("=== Configuration ===")
    logger.info(cfg)
    final_model, results = train_and_eval_sequential_tasks(cfg, logger)  # Train and evaluate
    logger.info("=== Final Results ===")
    logger.info(results)
    logger.info("Done!")


if __name__ == "__main__":
    main()
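
# Example invocation from the command line (the filename "main.py" is an assumption;
# use the actual script name in the repository):
#   python main.py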