import torch
from utils import evaluate_accuracy, evaluate_robust_accuracy
from attacks.fgsm import FGSMAttack
from attacks.pgd import PGDAttack
from defenses.air import AIRDefense
from defenses.lfl import LFLDefense
from defenses.joint_training import JointTrainingDefense
from defenses.vanilla_at import VanillaAdversarialTrainingDefense
from defenses.feat_extraction import FeatureExtractionDefense
from defenses.ewc_defense import EWCDefense
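
# Note: FGSMAttack and PGDAttack (imported above) are the attack objects expected by the
# `attack` arguments of the training loops below; each exposes a `.model` attribute and a
# `.generate(x, y)` method returning adversarial examples, which is all this file relies on.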


# Function to train one step with AIR defense
def train_one_step_air(student_model, teacher_model, train_loader, optimizer, defense_config, attack=None,
                       device='cuda', logger=None):
    student_model.train()  # Set student model to training mode
    teacher_model.eval()  # Set teacher model to evaluation mode
    for p in teacher_model.parameters():
        p.requires_grad = False  # Freeze the teacher model parameters during training
    # Initialize AIR defense method with the provided configuration
    air_defense = AIRDefense(student_model=student_model, teacher_model=teacher_model, **defense_config)
    total_loss = 0.0  # Total loss tracker
    total_samples = 0  # Number of samples processed
    # Iterate through the training data
    for batch_idx, (x, y) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)  # Move data to the correct device (GPU or CPU)
        optimizer.zero_grad()  # Zero gradients before the backward pass
        # If an attack is provided, generate adversarial examples
        if attack is not None:
            attack.model = student_model
            x_adv = attack.generate(x, y)  # Generate adversarial examples
        else:
            x_adv = x  # If no attack, use the original data
        loss = air_defense.loss_function(x_adv, y)  # Compute loss using AIR defense method
        loss.backward()  # Backpropagate the loss
        optimizer.step()  # Update the model parameters
        batch_size = y.size(0)  # Get the batch size
        total_loss += loss.item() * batch_size  # Accumulate the loss
        total_samples += batch_size  # Accumulate the number of samples
        # Log progress every 100 batches (if logger is provided)
        if logger and (batch_idx % 100 == 0):
            logger.info(f"Batch {batch_idx}, Loss: {loss.item():.4f}")
    # Compute and return the average loss for the epoch
    avg_loss = total_loss / total_samples
    return avg_loss


# Function to train the model for multiple epochs using AIR defense
def train_air_for_epochs(student_model, teacher_model, train_loader, test_loader, optimizer, defense_config, attack,
                         device, epochs, logger=None):
    for epoch in range(1, epochs + 1):
        avg_loss = train_one_step_air(student_model, teacher_model, train_loader, optimizer, defense_config, attack,
                                      device, logger)
        if logger:
            logger.info(f"Epoch [{epoch}/{epochs}] - Train Loss: {avg_loss:.4f}")
        clean_acc, _ = evaluate_accuracy(student_model, test_loader, device=device)
        if logger:
            logger.info(f"Clean Test Accuracy: {clean_acc:.2f}%")
    return student_model


# Function to train with LFL defense for multiple epochs
def train_lfl_for_epochs(student_model, teacher_model, train_loader, test_loader, optimizer, lambda_lfl, feature_lambda,
                         device, epochs, logger=None, freeze_classifier=True):
    lfl_defense = LFLDefense(student_model=student_model, teacher_model=teacher_model, lambda_lfl=lambda_lfl,
                             freeze_classifier=freeze_classifier, feature_lambda=feature_lambda)
    for epoch in range(1, epochs + 1):
        student_model.train()  # Set student model to training mode
        teacher_model.eval()  # Set teacher model to evaluation mode
        total_loss = 0.0
        total_samples = 0
        # Iterate through the training data
        for batch_idx, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)  # Move data to the correct device
            optimizer.zero_grad()  # Zero gradients before the backward pass
            loss = lfl_defense.loss_function(x, y)  # Compute loss using LFL defense method
            loss.backward()  # Backpropagate the loss
            optimizer.step()  # Update model parameters
            batch_size = y.size(0)
            total_loss += loss.item() * batch_size  # Accumulate loss
            total_samples += batch_size  # Accumulate the number of samples
            # Log progress every 100 batches (if logger is provided)
            if logger and (batch_idx % 100 == 0):
                logger.info(f"LFL >> Epoch:{epoch} Batch:{batch_idx} Loss:{loss.item():.4f}")
        avg_loss = total_loss / total_samples  # Calculate average loss for the epoch
        if logger:
            logger.info(f"[LFL] Epoch [{epoch}/{epochs}] - Train Loss: {avg_loss:.4f}")
        # Evaluate on clean test data after each epoch
        clean_acc, clean_loss = evaluate_accuracy(student_model, test_loader, device=device)
        if logger:
            logger.info(f"[LFL] Clean Test Acc: {clean_acc:.2f}% Loss:{clean_loss:.4f}")
    return student_model


# Function to train with Joint Training defense for multiple epochs
def train_joint_for_epochs(student_model, train_loader, test_loader, optimizer, attack, joint_lambda, device, epochs,
                           logger=None):
    jt_defense = JointTrainingDefense(student_model=student_model, joint_lambda=joint_lambda)
    for epoch in range(1, epochs + 1):
        student_model.train()  # Set student model to training mode
        total_loss = 0.0
        total_samples = 0
        # Iterate through the training data
        for batch_idx, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            # If attack is provided, generate adversarial examples
            if attack is not None:
                attack.model = student_model
                x_adv = attack.generate(x, y)
            else:
                x_adv = None
            loss = jt_defense.loss_function(x=x, y=y, x_adv=x_adv)  # Compute loss using Joint Training defense
            loss.backward()  # Backpropagate the loss
            optimizer.step()  # Update model parameters
            batch_size = y.size(0)
            total_loss += loss.item() * batch_size
            total_samples += batch_size
            if logger and (batch_idx % 100 == 0):
                logger.info(f"JointTraining >> Epoch:{epoch} Batch:{batch_idx} Loss:{loss.item():.4f}")
        avg_loss = total_loss / total_samples  # Compute the average loss for the epoch
        if logger:
            logger.info(f"[JointTraining] Epoch [{epoch}/{epochs}] - Train Loss: {avg_loss:.4f}")
        # Evaluate on clean test data after each epoch
        clean_acc, clean_loss = evaluate_accuracy(student_model, test_loader, device=device)
        if logger:
            logger.info(f"[JointTraining] Clean Test Acc: {clean_acc:.2f}% Loss:{clean_loss:.4f}")
    return student_model


# Function to train with Vanilla Adversarial Training defense for multiple epochs
def train_vanilla_at_for_epochs(student_model, train_loader, test_loader, optimizer, attack, adv_lambda, device, epochs,
                                logger=None):
    defense = VanillaAdversarialTrainingDefense(student_model=student_model, adv_lambda=adv_lambda)
    for epoch in range(1, epochs + 1):
        student_model.train()  # Set student model to training mode
        total_loss = 0.0
        total_samples = 0
        # Iterate through the training data
        for batch_idx, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)  # Move data to device
            optimizer.zero_grad()  # Zero gradients before the backward pass
            # If attack is provided, generate adversarial examples
            if attack is not None:
                attack.model = student_model
                x_adv = attack.generate(x, y)
            else:
                x_adv = None
            loss = defense.loss_function(x, y, x_adv=x_adv)  # Compute loss using Vanilla Adversarial Training
            loss.backward()  # Backpropagate the loss
            optimizer.step()  # Update model parameters
            batch_size = y.size(0)
            total_loss += loss.item() * batch_size
            total_samples += batch_size
            if logger and (batch_idx % 100 == 0):
                logger.info(f"[VanillaAT] Epoch:{epoch} Batch:{batch_idx} Loss:{loss.item():.4f}")
        avg_loss = total_loss / total_samples  # Calculate average loss for the epoch
        if logger:
            logger.info(f"[VanillaAT] Epoch [{epoch}/{epochs}] - Train Loss: {avg_loss:.4f}")
        # Evaluate clean accuracy on the test set after each epoch
        clean_acc, clean_loss = evaluate_accuracy(student_model, test_loader, device=device)
        if logger:
            logger.info(f"[VanillaAT] Clean Test Acc: {clean_acc:.2f}% Loss:{clean_loss:.4f}")
    return student_model


# Function to train with Feature Extraction defense for multiple epochs
def train_feat_extraction_for_epochs(student_model, train_loader, test_loader, optimizer, feat_lambda, attack, device,
                                     epochs, logger=None):
    defense = FeatureExtractionDefense(student_model=student_model, feat_lambda=feat_lambda, attack=attack)
    for epoch in range(1, epochs + 1):
        student_model.train()  # Set student model to training mode
        total_loss = 0.0
        total_samples = 0
        # Iterate through the training data
        for batch_idx, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)  # Move data to device
            optimizer.zero_grad()  # Zero gradients before the backward pass
            loss = defense.loss_function(x, y)  # Compute loss using Feature Extraction defense
            loss.backward()  # Backpropagate the loss
            optimizer.step()  # Update model parameters
            batch_size = y.size(0)
            total_loss += loss.item() * batch_size
            total_samples += batch_size
            if logger and (batch_idx % 100 == 0):
                logger.info(f"[FeatExtraction] Epoch:{epoch} Batch:{batch_idx} Loss:{loss.item():.4f}")
        avg_loss = total_loss / total_samples  # Calculate average loss for the epoch
        if logger:
            logger.info(f"[FeatExtraction] Epoch [{epoch}/{epochs}] - Train Loss: {avg_loss:.4f}")
        # Evaluate on clean test data
        clean_acc, clean_loss = evaluate_accuracy(student_model, test_loader, device=device)
        if logger:
            logger.info(f"[FeatExtraction] Clean Test Acc: {clean_acc:.2f}% Loss:{clean_loss:.4f}")
    return student_model


# Function to compute Fisher Information (for EWC)
def compute_fisher_information(model, data_loader, device="cuda"):
    model.eval()
    fisher_dict = {}
    for n, p in model.named_parameters():
        fisher_dict[n] = torch.zeros_like(p)  # Initialize Fisher information for each parameter
    total_samples = 0
    for x, y in data_loader:
        x, y = x.to(device), y.to(device)  # Move data to device
        model.zero_grad()  # Zero gradients before forward pass
        logits = model(x)
        preds = logits.argmax(dim=1)
        loss = torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(logits, dim=1), preds, reduction='sum')
        loss.backward()  # Backpropagate the loss
        batch_size = y.size(0)
        for n, p in model.named_parameters():
            if p.grad is not None:
                fisher_dict[n] += p.grad.data.pow(2) * batch_size  # Accumulate Fisher information
        total_samples += batch_size
    # Normalize Fisher information
    for n in fisher_dict:
        fisher_dict[n] /= float(total_samples)
    return fisher_dict
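
# Note on usage: the Fisher estimate above is what the standard EWC penalty
# (Kirkpatrick et al., 2017) consumes, augmenting the task loss with a quadratic term
#     L_total = L_task + (lambda_ewc / 2) * sum_i F_i * (theta_i - theta_old_i)^2
# which anchors each parameter theta_i to its previously learned value theta_old_i in
# proportion to its estimated importance F_i. How exactly EWCDefense applies lambda_ewc
# and the 1/2 factor depends on its implementation in defenses/ewc_defense.py, which is
# not shown here.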


# Function to train with EWC defense for multiple epochs
def train_ewc_for_epochs(student_model, train_loader, test_loader, optimizer, device, epochs, logger, lambda_ewc=100.0):
    ewc_defense = EWCDefense(student_model=student_model, lambda_ewc=lambda_ewc)
    ewc_defense.set_old_params()  # Initialize old parameters for EWC
    ewc_defense.set_fisher({n: torch.zeros_like(p) for n, p in student_model.named_parameters()})
    for epoch in range(1, epochs + 1):
        student_model.train()  # Set student model to training mode
        total_loss, total_samples = 0.0, 0
        for batch_idx, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)  # Move data to device
            optimizer.zero_grad()  # Zero gradients before the backward pass
            loss = ewc_defense.loss_function(x, y)  # Compute loss using EWC defense
            loss.backward()  # Backpropagate the loss
            optimizer.step()  # Update model parameters
            bs = y.size(0)
            total_loss += loss.item() * bs
            total_samples += bs
            if logger and (batch_idx % 100 == 0):
                logger.info(f"[EWC] Epoch:{epoch} Batch:{batch_idx} Loss:{loss.item():.4f}")
        avg_loss = total_loss / total_samples  # Compute the average loss for the epoch
        if logger:
            logger.info(f"[EWC] Epoch [{epoch}/{epochs}] - Train Loss: {avg_loss:.4f}")
        clean_acc, clean_loss = evaluate_accuracy(student_model, test_loader, device=device)
        if logger:
            logger.info(f"[EWC] Clean Test Acc: {clean_acc:.2f}% Loss:{clean_loss:.4f}")
        # Refresh the Fisher estimate and anchor parameters at the end of each epoch,
        # so the next epoch is regularized toward the current solution
        fisher_dict = compute_fisher_information(student_model, train_loader, device=device)
        ewc_defense.set_fisher(fisher_dict)
        ewc_defense.set_old_params()
    return student_model
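

# --- Usage sketch (not part of the original pipeline) -------------------------------
# A minimal smoke test assuming only what this file itself exercises: a stand-in model,
# random tensors in place of a real dataset, and attack=None (which the loops above handle
# by training on clean data). A real run would pass an FGSMAttack or PGDAttack instance and
# proper dataset loaders instead; the constructor arguments of those classes and of the
# defense modules are defined elsewhere in this repository, so only the call signatures
# visible in this file are relied on here. The 3x32x32 input shape and adv_lambda value
# below are arbitrary placeholders.
if __name__ == "__main__":
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Placeholder model and data, purely to illustrate the expected call signature
    model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)).to(device)
    x_fake = torch.randn(64, 3, 32, 32)
    y_fake = torch.randint(0, 10, (64,))
    loader = DataLoader(TensorDataset(x_fake, y_fake), batch_size=32)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    train_vanilla_at_for_epochs(model, loader, loader, optimizer, attack=None,
                                adv_lambda=1.0, device=device, epochs=1, logger=None)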