diff --git a/src/utils/eval.py b/src/utils/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..d736a3eebf3b2ee2823fbc3f3ca732c4eac4b5c5
--- /dev/null
+++ b/src/utils/eval.py
@@ -0,0 +1,69 @@
+import torch
+import torch.nn.functional as F
+
+
+# Evaluate clean accuracy (no adversarial attacks) on a given data_loader
+def evaluate_accuracy(model, data_loader, device="cuda"):
+    model.eval()
+    correct, total, total_loss = 0, 0, 0.0
+
+    with torch.no_grad():
+        for x, y in data_loader:
+            x, y = x.to(device), y.to(device)
+            logits = model(x)
+            loss = F.cross_entropy(logits, y, reduction='sum')
+            total_loss += loss.item()
+
+            preds = logits.argmax(dim=1)
+            correct += (preds == y).sum().item()
+            total += y.size(0)
+
+    avg_loss = total_loss / total
+    accuracy = 100.0 * correct / total
+    return accuracy, avg_loss
+
+
+# Evaluate robustness by running each attack in attack_list
+def evaluate_robust_accuracy(model, data_loader, attack_list, teacher_model=None, device="cuda"):
+    model.eval()
+
+    robust_acc_dict = {}
+    loss_dict = {}
+    consistency_dict = {}
+
+    for attack in attack_list:
+        # Temporarily link the attack's model to 'model'
+        attack.model = model
+        correct, total, total_loss, consistency_loss = 0, 0, 0.0, 0.0
+
+        for x, y in data_loader:
+            x, y = x.to(device), y.to(device)
+
+            # Generate adversarial examples; gradient-based attacks need
+            # input gradients, so this must run outside torch.no_grad()
+            x_adv = attack.generate(x, y)
+
+            with torch.no_grad():
+                logits = model(x_adv)
+                loss = F.cross_entropy(logits, y, reduction='sum')
+                total_loss += loss.item()
+
+                preds = logits.argmax(dim=1)
+                correct += (preds == y).sum().item()
+                total += y.size(0)
+
+                # If we have a teacher model, measure KL between teacher and student
+                if teacher_model is not None:
+                    teacher_model.eval()
+                    teacher_logits = teacher_model(x_adv)
+                    # Sum per-sample KL so dividing by 'total' gives a per-sample average
+                    consistency_loss += F.kl_div(F.log_softmax(logits, dim=1),
+                                                 F.softmax(teacher_logits, dim=1),
+                                                 reduction='sum').item()
+
+        robust_acc_dict[attack.__class__.__name__] = 100.0 * correct / total
+        loss_dict[attack.__class__.__name__] = total_loss / total
+        if teacher_model is not None:
+            consistency_dict[attack.__class__.__name__] = consistency_loss / total
+
+    return robust_acc_dict, loss_dict, consistency_dict
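
For reference, a minimal usage sketch of the two helpers (not part of this diff). It assumes the attack interface the file relies on: an object exposing a `model` attribute and a `generate(x, y)` method. The `FGSM` class, the dummy model, the random data loader, and the `from src.utils.eval import ...` path are illustrative assumptions, not code that exists in this repository.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

from src.utils.eval import evaluate_accuracy, evaluate_robust_accuracy


class FGSM:
    """Toy FGSM attack matching the interface evaluate_robust_accuracy expects."""

    def __init__(self, model, eps=8 / 255):
        self.model = model
        self.eps = eps

    def generate(self, x, y):
        x = x.clone().detach().requires_grad_(True)
        loss = F.cross_entropy(self.model(x), y)
        grad = torch.autograd.grad(loss, x)[0]
        # One signed-gradient step, clamped to the valid image range
        return (x + self.eps * grad.sign()).clamp(0, 1).detach()


device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)).to(device)

# Random tensors stand in for a real test set (e.g. CIFAR-10)
loader = DataLoader(
    TensorDataset(torch.rand(64, 3, 32, 32), torch.randint(0, 10, (64,))),
    batch_size=32,
)

acc, clean_loss = evaluate_accuracy(model, loader, device=device)
robust_acc, robust_loss, _ = evaluate_robust_accuracy(model, loader, [FGSM(model)], device=device)
print(f"clean acc {acc:.2f}%  FGSM acc {robust_acc['FGSM']:.2f}%")
```

Because `attack.generate` runs outside `torch.no_grad()`, gradient-based attacks like the FGSM sketch above can take `torch.autograd.grad` of the loss with respect to the input; only the subsequent evaluation forward passes are wrapped in `no_grad`.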