# train.py: training loops for the AIR, LFL, Joint Training, Vanilla Adversarial Training,
# Feature Extraction, and EWC defenses.

import torch
from utils import evaluate_accuracy, evaluate_robust_accuracy
from attacks.fgsm import FGSMAttack
from attacks.pgd import PGDAttack
from defenses.air import AIRDefense
from defenses.lfl import LFLDefense
from defenses.joint_training import JointTrainingDefense
from defenses.vanilla_at import VanillaAdversarialTrainingDefense
from defenses.feat_extraction import FeatureExtractionDefense
from defenses.ewc_defense import EWCDefense


# Function to train the student for one full pass (epoch) over the training data with the AIR defense
def train_one_step_air(student_model, teacher_model, train_loader, optimizer, defense_config, attack=None,
                       device='cuda', logger=None):
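    """Run one full pass over `train_loader`, updating `student_model` with the AIR loss.

    The teacher is frozen and serves only as a reference inside `AIRDefense`. If `attack`
    is provided, each batch is first replaced by adversarial examples. Returns the
    sample-weighted average training loss over the pass.
    """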
    student_model.train()  # Set student model to training mode
    teacher_model.eval()  # Set teacher model to evaluation mode
    for p in teacher_model.parameters():
        p.requires_grad = False  # Freeze the teacher model parameters during training

    # Initialize AIR defense method with the provided configuration
    air_defense = AIRDefense(student_model=student_model, teacher_model=teacher_model, **defense_config)

    total_loss = 0.0  # Total loss tracker
    total_samples = 0  # Number of samples processed

    # Iterate through the training data
    for batch_idx, (x, y) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)  # Move data to the correct device (GPU or CPU)
        optimizer.zero_grad()  # Zero gradients before the backward pass

        # If an attack is provided, generate adversarial examples
        if attack is not None:
            attack.model = student_model
            x_adv = attack.generate(x, y)  # Generate adversarial examples
        else:
            x_adv = x  # If no attack, use the original data

        loss = air_defense.loss_function(x_adv, y)  # Compute loss using AIR defense method
        loss.backward()  # Backpropagate the loss
        optimizer.step()  # Update the model parameters

        batch_size = y.size(0)  # Get the batch size
        total_loss += loss.item() * batch_size  # Accumulate the loss
        total_samples += batch_size  # Accumulate the number of samples

        # Log progress every 100 batches (if logger is provided)
        if logger and (batch_idx % 100 == 0):
            logger.info(f"Batch {batch_idx}, Loss: {loss.item():.4f}")

    # Compute and return the average loss for the epoch
    avg_loss = total_loss / total_samples
    return avg_loss


# Function to train the model for multiple epochs using AIR defense
def train_air_for_epochs(student_model, teacher_model, train_loader, test_loader, optimizer, defense_config, attack,
                         device, epochs, logger=None):
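    """Train with the AIR defense for `epochs` passes over `train_loader`, logging the
    average training loss and clean test accuracy after every epoch."""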
    for epoch in range(1, epochs + 1):
        avg_loss = train_one_step_air(student_model, teacher_model, train_loader, optimizer, defense_config, attack,
                                      device, logger)
        if logger:
            logger.info(f"Epoch [{epoch}/{epochs}] - Train Loss: {avg_loss:.4f}")

        clean_acc, _ = evaluate_accuracy(student_model, test_loader, device=device)
        if logger:
            logger.info(f"Clean Test Accuracy: {clean_acc:.2f}%")

    return student_model


# Function to train with LFL defense for multiple epochs
def train_lfl_for_epochs(student_model, teacher_model, train_loader, test_loader, optimizer, lambda_lfl, feature_lambda,
                         device, epochs, logger=None, freeze_classifier=True):
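    """Train with the LFL defense for `epochs` passes.

    `lambda_lfl`, `feature_lambda`, and `freeze_classifier` are forwarded to `LFLDefense`;
    training uses clean data only, and clean test accuracy is logged after every epoch.
    """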
    lfl_defense = LFLDefense(student_model=student_model, teacher_model=teacher_model, lambda_lfl=lambda_lfl,
                             freeze_classifier=freeze_classifier, feature_lambda=feature_lambda)

    for epoch in range(1, epochs + 1):
        student_model.train()  # Set student model to training mode
        teacher_model.eval()  # Set teacher model to evaluation mode
        total_loss = 0.0
        total_samples = 0

        # Iterate through the training data
        for batch_idx, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)  # Move data to the correct device
            optimizer.zero_grad()  # Zero gradients before the backward pass

            loss = lfl_defense.loss_function(x, y)  # Compute loss using LFL defense method
            loss.backward()  # Backpropagate the loss
            optimizer.step()  # Update model parameters

            batch_size = y.size(0)
            total_loss += loss.item() * batch_size  # Accumulate loss
            total_samples += batch_size  # Accumulate the number of samples

            # Log progress every 100 batches (if logger is provided)
            if logger and (batch_idx % 100 == 0):
                logger.info(f"LFL >> Epoch:{epoch} Batch:{batch_idx} Loss:{loss.item():.4f}")

        avg_loss = total_loss / total_samples  # Calculate average loss for the epoch
        if logger:
            logger.info(f"[LFL] Epoch [{epoch}/{epochs}] - Train Loss: {avg_loss:.4f}")

        # Evaluate on clean test data after each epoch
        clean_acc, clean_loss = evaluate_accuracy(student_model, test_loader, device=device)
        if logger:
            logger.info(f"[LFL] Clean Test Acc: {clean_acc:.2f}% Loss:{clean_loss:.4f}")

    return student_model


# Function to train with Joint Training defense for multiple epochs
def train_joint_for_epochs(student_model, train_loader, test_loader, optimizer, attack, joint_lambda, device, epochs,
                           logger=None):
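    """Train with the Joint Training defense for `epochs` passes.

    When `attack` is provided, adversarial examples are regenerated from the current
    student for every batch and passed alongside the clean batch to
    `JointTrainingDefense.loss_function`; `joint_lambda` is forwarded to the defense.
    """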
    jt_defense = JointTrainingDefense(student_model=student_model, joint_lambda=joint_lambda)

    for epoch in range(1, epochs + 1):
        student_model.train()  # Set student model to training mode
        total_loss = 0.0
        total_samples = 0

        # Iterate through the training data
        for batch_idx, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()

            # If attack is provided, generate adversarial examples
            if attack is not None:
                attack.model = student_model
                x_adv = attack.generate(x, y)
            else:
                x_adv = None

            loss = jt_defense.loss_function(x=x, y=y, x_adv=x_adv)  # Compute loss using Joint Training defense
            loss.backward()  # Backpropagate the loss
            optimizer.step()  # Update model parameters

            batch_size = y.size(0)
            total_loss += loss.item() * batch_size
            total_samples += batch_size

            if logger and (batch_idx % 100 == 0):
                logger.info(f"JointTraining >> Epoch:{epoch} Batch:{batch_idx} Loss:{loss.item():.4f}")

        avg_loss = total_loss / total_samples  # Compute the average loss for the epoch
        if logger:
            logger.info(f"[JointTraining] Epoch [{epoch}/{epochs}] - Train Loss: {avg_loss:.4f}")

        # Evaluate on clean test data after each epoch
        clean_acc, clean_loss = evaluate_accuracy(student_model, test_loader, device=device)
        if logger:
            logger.info(f"[JointTraining] Clean Test Acc: {clean_acc:.2f}% Loss:{clean_loss:.4f}")
    return student_model


# Function to train with Vanilla Adversarial Training defense for multiple epochs
def train_vanilla_at_for_epochs(student_model, train_loader, test_loader, optimizer, attack, adv_lambda, device, epochs,
                                logger=None):
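    """Train with vanilla adversarial training for `epochs` passes.

    When `attack` is provided, adversarial examples are regenerated from the current
    student for every batch; `adv_lambda` is forwarded to
    `VanillaAdversarialTrainingDefense`.
    """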
    defense = VanillaAdversarialTrainingDefense(student_model=student_model, adv_lambda=adv_lambda)

    for epoch in range(1, epochs + 1):
        student_model.train()  # Set student model to training mode
        total_loss = 0.0
        total_samples = 0

        # Iterate through the training data
        for batch_idx, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)  # Move data to device
            optimizer.zero_grad()  # Zero gradients before the backward pass

            # If attack is provided, generate adversarial examples
            if attack is not None:
                attack.model = student_model
                x_adv = attack.generate(x, y)
            else:
                x_adv = None

            loss = defense.loss_function(x, y, x_adv=x_adv)  # Compute loss using Vanilla Adversarial Training
            loss.backward()  # Backpropagate the loss
            optimizer.step()  # Update model parameters

            batch_size = y.size(0)
            total_loss += loss.item() * batch_size
            total_samples += batch_size

            if logger and (batch_idx % 100 == 0):
                logger.info(f"[VanillaAT] Epoch:{epoch} Batch:{batch_idx} Loss:{loss.item():.4f}")

        avg_loss = total_loss / total_samples  # Calculate average loss for the epoch
        if logger:
            logger.info(f"[VanillaAT] Epoch [{epoch}/{epochs}] - Train Loss: {avg_loss:.4f}")

        # Evaluate clean accuracy on the test set after each epoch
        clean_acc, clean_loss = evaluate_accuracy(student_model, test_loader, device=device)
        if logger:
            logger.info(f"[VanillaAT] Clean Test Acc: {clean_acc:.2f}% Loss:{clean_loss:.4f}")

    return student_model


# Function to train with Feature Extraction defense for multiple epochs
def train_feat_extraction_for_epochs(student_model, train_loader, test_loader, optimizer, feat_lambda, attack, device,
                                     epochs, logger=None):
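    """Train with the Feature Extraction defense for `epochs` passes.

    `feat_lambda` and `attack` are forwarded to `FeatureExtractionDefense`; the training
    loop itself never calls the attack, so any adversarial example generation happens
    inside the defense.
    """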
    defense = FeatureExtractionDefense(student_model=student_model, feat_lambda=feat_lambda, attack=attack)

    for epoch in range(1, epochs + 1):
        student_model.train()  # Set student model to training mode
        total_loss = 0.0
        total_samples = 0

        # Iterate through the training data
        for batch_idx, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)  # Move data to device
            optimizer.zero_grad()  # Zero gradients before the backward pass

            loss = defense.loss_function(x, y)  # Compute loss using Feature Extraction defense
            loss.backward()  # Backpropagate the loss
            optimizer.step()  # Update model parameters

            batch_size = y.size(0)
            total_loss += loss.item() * batch_size
            total_samples += batch_size

            if logger and (batch_idx % 100 == 0):
                logger.info(f"[FeatExtraction] Epoch:{epoch} Batch:{batch_idx} Loss:{loss.item():.4f}")

        avg_loss = total_loss / total_samples  # Calculate average loss for the epoch
        if logger:
            logger.info(f"[FeatExtraction] Epoch [{epoch}/{epochs}] - Train Loss: {avg_loss:.4f}")

        # Evaluate on clean test data
        clean_acc, clean_loss = evaluate_accuracy(student_model, test_loader, device=device)
        if logger:
            logger.info(f"[FeatExtraction] Clean Test Acc: {clean_acc:.2f}% Loss:{clean_loss:.4f}")

    return student_model


# Function to compute Fisher Information (for EWC)
def compute_fisher_information(model, data_loader, device="cuda"):
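    """Estimate a diagonal Fisher information matrix for EWC.

    For every batch, the NLL of the model's own predictions (argmax labels) is summed,
    back-propagated, and the squared batch gradient (scaled by the batch size) is
    accumulated per parameter; the result is normalised by the total number of samples.
    This is a batch-level approximation of the per-sample empirical Fisher.
    """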
    model.eval()
    fisher_dict = {}
    for n, p in model.named_parameters():
        fisher_dict[n] = torch.zeros_like(p)  # Initialize Fisher information for each parameter

    total_samples = 0
    for x, y in data_loader:
        x, y = x.to(device), y.to(device)  # Move data to device
        model.zero_grad()  # Zero gradients before forward pass
        logits = model(x)
        preds = logits.argmax(dim=1)
        loss = torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(logits, dim=1), preds, reduction='sum')
        loss.backward()  # Backpropagate the loss

        batch_size = y.size(0)
        for n, p in model.named_parameters():
            if p.grad is not None:
                fisher_dict[n] += p.grad.detach().pow(2) * batch_size  # Accumulate Fisher information
        total_samples += batch_size

    # Normalize Fisher information
    for n in fisher_dict:
        fisher_dict[n] /= float(total_samples)
    return fisher_dict


# Function to train with EWC defense for multiple epochs
def train_ewc_for_epochs(student_model, train_loader, test_loader, optimizer, device, epochs, logger=None,
                         lambda_ewc=100.0):
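    """Train with the EWC defense for `epochs` passes.

    The EWC penalty starts from zero Fisher information; after every epoch the Fisher
    estimate and the stored reference ("old") parameters are refreshed from the current
    student and the training data.
    """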
    ewc_defense = EWCDefense(student_model=student_model, lambda_ewc=lambda_ewc)
    ewc_defense.set_old_params()  # Initialize old parameters for EWC
    ewc_defense.set_fisher({n: torch.zeros_like(p) for n, p in student_model.named_parameters()})

    for epoch in range(1, epochs + 1):
        student_model.train()  # Set student model to training mode
        total_loss, total_samples = 0.0, 0

        for batch_idx, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)  # Move data to device
            optimizer.zero_grad()  # Zero gradients before the backward pass

            loss = ewc_defense.loss_function(x, y)  # Compute loss using EWC defense
            loss.backward()  # Backpropagate the loss
            optimizer.step()  # Update model parameters

            bs = y.size(0)
            total_loss += loss.item() * bs
            total_samples += bs

            if logger and (batch_idx % 100 == 0):
                logger.info(f"[EWC] Epoch:{epoch} Batch:{batch_idx} Loss:{loss.item():.4f}")

        avg_loss = total_loss / total_samples  # Compute the average loss for the epoch
        if logger:
            logger.info(f"[EWC] Epoch [{epoch}/{epochs}] - Train Loss: {avg_loss:.4f}")

        clean_acc, clean_loss = evaluate_accuracy(student_model, test_loader, device=device)
        if logger:
            logger.info(f"[EWC] Clean Test Acc: {clean_acc:.2f}% Loss:{clean_loss:.4f}")

        fisher_dict = compute_fisher_information(student_model, train_loader, device=device)
        ewc_defense.set_fisher(fisher_dict)
        ewc_defense.set_old_params()

    return student_model
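

# Illustrative usage sketch (an assumption, not part of the original pipeline): wires one of the
# trainers above to a model, data loaders, and an attack. `build_model`, `get_loaders`, the
# PGDAttack constructor arguments, and the evaluate_robust_accuracy signature are hypothetical
# placeholders; adapt them to the actual project helpers.
if __name__ == "__main__":
    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("train")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    student_model = build_model().to(device)      # hypothetical project helper
    train_loader, test_loader = get_loaders()     # hypothetical project helper

    optimizer = torch.optim.SGD(student_model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
    attack = PGDAttack(model=student_model, eps=8 / 255, alpha=2 / 255, steps=10)  # assumed signature

    student_model = train_vanilla_at_for_epochs(student_model, train_loader, test_loader, optimizer,
                                                attack=attack, adv_lambda=1.0, device=device,
                                                epochs=10, logger=logger)

    # Final evaluation on clean and adversarial test data (evaluate_robust_accuracy signature assumed).
    clean_acc, _ = evaluate_accuracy(student_model, test_loader, device=device)
    robust_acc, _ = evaluate_robust_accuracy(student_model, test_loader, attack, device=device)
    logger.info(f"Final clean acc: {clean_acc:.2f}%  robust acc: {robust_acc:.2f}%")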