Skip to content
Snippets Groups Projects
Commit c87adeee authored by Mina Moshfegh's avatar Mina Moshfegh
Browse files

Upload New File

parent 4a14ba0d
No related branches found
No related tags found
No related merge requests found
import os
import torch
import torch.optim as optim
import random
import numpy as np
from utils import get_logger, evaluate_accuracy, evaluate_robust_accuracy, Config
from data import get_mnist_loaders, get_cifar10_loaders, get_cifar100_loaders
from models import SmallCNN, WideResNet
from attacks.fgsm import FGSMAttack
from attacks.pgd import PGDAttack
from attacks.no_attack import NoOpAttack
from train import (train_air_for_epochs, train_lfl_for_epochs, train_joint_for_epochs, train_vanilla_at_for_epochs,
train_feat_extraction_for_epochs, train_ewc_for_epochs)
# Function to set random seed to ensure reproducibility of results
def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds PyTorch (CPU and all CUDA devices), Python's ``random`` module and
    NumPy's global generator with the same value.

    Args:
        seed: Integer seed passed unchanged to each generator.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
def save_model(model, dataset, defense_method, attack_name, task_idx, model_dir="./models"):
    """Save ``model.state_dict()`` to disk and return the file path.

    The file is named ``{dataset}_{defense_method}_{attack_name}_task{task_idx}.pt``
    and is written inside ``model_dir``, which is created if it is missing.

    Args:
        model: Object exposing ``state_dict()``.
        dataset: Dataset name used in the file name.
        defense_method: Defense name used in the file name.
        attack_name: Attack name used in the file name.
        task_idx: Task index used in the file name.
        model_dir: Target directory (default ``"./models"``).

    Returns:
        The path of the written checkpoint file.
    """
    # exist_ok=True avoids the check-then-create race of
    # `if not os.path.exists(...): os.makedirs(...)`.
    os.makedirs(model_dir, exist_ok=True)
    model_filename = f"{dataset}_{defense_method}_{attack_name}_task{task_idx}.pt"
    model_path = os.path.join(model_dir, model_filename)
    torch.save(model.state_dict(), model_path)
    return model_path
# Create the attack instance (no-op, FGSM or PGD) that matches the given attack name
def make_attack(attack_name, cfg, model=None):
    """Build the attack object selected by ``attack_name``.

    Args:
        attack_name: One of the strings ``"None"``, ``"FGSM"`` or ``"PGD"``
            (case-sensitive).
        cfg: Configuration object. ``cfg.epsilon`` is read for FGSM and PGD;
            PGD additionally reads ``cfg.alpha``, ``cfg.num_steps`` and
            ``cfg.random_init``.
        model: Forwarded to the attack constructor as ``model``.

    Returns:
        A ``NoOpAttack``, ``FGSMAttack`` or ``PGDAttack`` instance.

    Raises:
        ValueError: If ``attack_name`` is not one of the supported names.
    """
    if attack_name == "None":
        return NoOpAttack(model=model)
    if attack_name == "FGSM":
        return FGSMAttack(model=model, epsilon=cfg.epsilon)
    if attack_name == "PGD":
        return PGDAttack(
            model=model,
            epsilon=cfg.epsilon,
            alpha=cfg.alpha,
            num_steps=cfg.num_steps,
            random_init=cfg.random_init,
        )
    raise ValueError(f"Unknown attack name: {attack_name}")
# Main training and evaluation loop across multiple tasks and attacks
def _build_data_and_models(cfg):
    """Create the data loaders plus student and teacher networks for ``cfg.dataset``.

    Returns:
        ``(train_loader, test_loader, model, teacher_model)``.

    Raises:
        ValueError: If ``cfg.dataset`` is not "MNIST", "CIFAR10" or "CIFAR100".
    """
    loader_kwargs = {
        "batch_size": cfg.batch_size,
        "root": cfg.data_root,
        "num_workers": cfg.num_workers,
    }
    if cfg.dataset == "MNIST":
        train_loader, test_loader = get_mnist_loaders(**loader_kwargs)
        model = SmallCNN(num_channels=1, num_classes=10)
        teacher_model = SmallCNN(num_channels=1, num_classes=10)
    elif cfg.dataset == "CIFAR10":
        train_loader, test_loader = get_cifar10_loaders(**loader_kwargs)
        model = WideResNet(depth=34, widen_factor=10, drop_rate=0.0, num_classes=10)
        teacher_model = WideResNet(depth=34, widen_factor=10, drop_rate=0.0, num_classes=10)
    elif cfg.dataset == "CIFAR100":
        train_loader, test_loader = get_cifar100_loaders(**loader_kwargs)
        model = WideResNet(depth=34, widen_factor=20, drop_rate=0.0, num_classes=100)
        teacher_model = WideResNet(depth=34, widen_factor=20, drop_rate=0.0, num_classes=100)
    else:
        raise ValueError(f"Unknown dataset: {cfg.dataset}")
    return train_loader, test_loader, model, teacher_model


def _train_one_task(cfg, logger, *, model, teacher_model, train_loader, test_loader,
                    optimizer, attack, device, air_defense_config):
    """Run the training routine selected by ``cfg.defense_method`` and return its result.

    Raises:
        ValueError: If ``cfg.defense_method`` is not a supported defense name.
    """
    if cfg.defense_method == "AIR":
        logger.info("Using AIR Defense...")
        return train_air_for_epochs(
            student_model=model, teacher_model=teacher_model, train_loader=train_loader,
            test_loader=test_loader, optimizer=optimizer, defense_config=air_defense_config,
            attack=attack, device=device, epochs=cfg.epochs, logger=logger)
    if cfg.defense_method == "LFL":
        logger.info("Using LFL Defense...")
        return train_lfl_for_epochs(
            student_model=model, teacher_model=teacher_model, train_loader=train_loader,
            test_loader=test_loader, optimizer=optimizer, lambda_lfl=cfg.lambda_lfl,
            feature_lambda=cfg.feature_lambda, device=device, epochs=cfg.epochs,
            logger=logger, freeze_classifier=cfg.freeze_classifier)
    if cfg.defense_method == "JointTraining":
        logger.info("Using Joint Training Defense...")
        return train_joint_for_epochs(
            student_model=model, train_loader=train_loader, test_loader=test_loader,
            optimizer=optimizer, attack=attack, joint_lambda=cfg.joint_lambda,
            device=device, epochs=cfg.epochs, logger=logger)
    if cfg.defense_method == "VanillaAT":
        logger.info("Using Vanilla Adversarial Training Defense...")
        return train_vanilla_at_for_epochs(
            student_model=model, train_loader=train_loader, test_loader=test_loader,
            optimizer=optimizer, attack=attack, adv_lambda=cfg.adv_lambda,
            device=device, epochs=cfg.epochs, logger=logger)
    if cfg.defense_method == "FeatExtraction":
        logger.info("Using Feature Extraction Defense...")
        return train_feat_extraction_for_epochs(
            student_model=model, train_loader=train_loader, test_loader=test_loader,
            optimizer=optimizer, feat_lambda=cfg.feat_lambda, attack=attack,
            device=device, epochs=cfg.epochs, logger=logger)
    if cfg.defense_method == "EWC":
        logger.info("Using EWC Defense...")
        return train_ewc_for_epochs(
            student_model=model, train_loader=train_loader, test_loader=test_loader,
            optimizer=optimizer, device=device, epochs=cfg.epochs, logger=logger,
            lambda_ewc=cfg.lambda_ewc)
    raise ValueError(f"Unknown defense method: {cfg.defense_method}")


def train_and_eval_sequential_tasks(cfg: Config, logger):
    """Train one model sequentially on every attack in ``cfg.attack_sequence``.

    For each task the model is trained with the defense named by
    ``cfg.defense_method``, its weights are saved via ``save_model``, and it is
    evaluated on clean data and against every attack of the sequence.

    Args:
        cfg: Run configuration (dataset, defense method, attack sequence and
            the hyper-parameters read below).
        logger: Logger used for progress and metric messages.

    Returns:
        ``(model, results_per_task)`` where ``results_per_task`` holds one dict
        per task with the keys ``"task_idx"``, ``"attack"``, ``"model_path"``,
        ``"clean_acc"``, ``"clean_loss"`` and ``"robust_acc"`` (a dict mapping
        each evaluated attack name to its accuracy).

    Raises:
        ValueError: For an unknown dataset, defense method or attack name.
    """
    set_seed(cfg.seed)  # Set random seed for reproducibility
    # Falls back to CPU whenever CUDA is unavailable, whatever cfg.device says.
    device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")

    train_loader, test_loader, model, teacher_model = _build_data_and_models(cfg)
    model.to(device)
    teacher_model.to(device)

    # The teacher is frozen: eval mode and no gradients.
    # NOTE(review): nothing in this function loads weights into teacher_model
    # or syncs it with the student between tasks; confirm whether the
    # train_* routines take care of that.
    teacher_model.eval()
    for param in teacher_model.parameters():
        param.requires_grad = False

    optimizer = optim.SGD(model.parameters(), lr=cfg.learning_rate, momentum=cfg.momentum,
                          weight_decay=cfg.weight_decay)

    # Configuration for the AIR defense method (built eagerly, as before, so
    # the same cfg attributes are read regardless of the chosen defense).
    air_defense_config = {
        'lambda_SD': cfg.lambda_SD, 'lambda_IR': cfg.lambda_IR, 'lambda_AR': cfg.lambda_AR,
        'lambda_Reg': cfg.lambda_Reg, 'alpha_range': cfg.alpha_range, 'use_rdrop': cfg.use_rdrop,
        'iso_noise_std': cfg.iso_noise_std, 'iso_clamp_min': cfg.iso_clamp_min,
        'iso_clamp_max': cfg.iso_clamp_max, 'iso_p_flip': cfg.iso_p_flip,
        'iso_flip_dim': cfg.iso_flip_dim, 'iso_p_rotation': cfg.iso_p_rotation,
        'iso_max_rotation': cfg.iso_max_rotation, 'iso_p_crop': cfg.iso_p_crop,
        'iso_p_erase': cfg.iso_p_erase,
    }

    results_per_task = []
    num_tasks = len(cfg.attack_sequence)
    for task_idx, atk_name in enumerate(cfg.attack_sequence):
        logger.info(f"--- Task {task_idx + 1} / {num_tasks}: Attack = {atk_name} ---")
        current_attack = make_attack(atk_name, cfg, model=None)

        model = _train_one_task(
            cfg, logger, model=model, teacher_model=teacher_model,
            train_loader=train_loader, test_loader=test_loader, optimizer=optimizer,
            attack=current_attack, device=device, air_defense_config=air_defense_config)

        model_path = save_model(model, cfg.dataset, cfg.defense_method, atk_name, task_idx)
        logger.info(f"Model saved at: {model_path}")

        # After training, evaluate the model's performance on clean data.
        clean_acc, clean_loss = evaluate_accuracy(model, test_loader, device=device)
        logger.info(f"Clean Accuracy after Task {task_idx + 1}, Attack={atk_name}: {clean_acc:.2f}%")

        # Evaluate robust accuracy for all attacks in the sequence. A distinct
        # loop variable keeps the current task's `atk_name` intact.
        robust_acc = {}
        for eval_atk_name in cfg.attack_sequence:
            test_attack = make_attack(eval_atk_name, cfg, model=None)
            robust_acc_dict, _, _ = evaluate_robust_accuracy(
                model, test_loader, [test_attack], teacher_model=None, device=device)
            for acc in robust_acc_dict.values():
                logger.info(f"Robust Accuracy on {eval_atk_name} after Task {task_idx + 1}: {acc:.2f}%")
                robust_acc[eval_atk_name] = acc

        # Record this task's metrics so the caller gets real results.
        results_per_task.append({
            "task_idx": task_idx,
            "attack": atk_name,
            "model_path": model_path,
            "clean_acc": clean_acc,
            "clean_loss": clean_loss,
            "robust_acc": robust_acc,
        })

    logger.info("=== All tasks completed! ===")
    return model, results_per_task
# Main function to run everything
def main():
    """Load the configuration, run sequential training/evaluation and log the results."""
    cfg = Config()  # Load configuration
    # The logger name encodes dataset, defense method and attack sequence.
    logger = get_logger(name=f"{cfg.dataset}-{cfg.defense_method}-{str(cfg.attack_sequence)}")
    logger.info("=== Configuration ===")
    logger.info(cfg)
    final_model, results = train_and_eval_sequential_tasks(cfg, logger)  # Train and evaluate
    logger.info("=== Final Results ===")
    logger.info(results)
    logger.info("Done!")
# Run the experiment only when executed as a script, not on import.
if __name__ == "__main__":
    main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment