Commit 47f66038 authored by Mina Moshfegh

Delete eval.py

parent 93d5ac34
import torch
import torch.nn.functional as F


# Evaluate clean accuracy (no adversarial attacks) on a given data_loader
def evaluate_accuracy(model, data_loader, device="cuda"):
    model.eval()
    correct, total, total_loss = 0, 0, 0.0
    with torch.no_grad():
        for x, y in data_loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = F.cross_entropy(logits, y, reduction='sum')
            total_loss += loss.item()
            preds = logits.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    avg_loss = total_loss / total
    accuracy = 100.0 * correct / total
    return accuracy, avg_loss
# Evaluate robustness by running each attack in attack_list
def evaluate_robust_accuracy(model, data_loader, attack_list, teacher_model=None, device="cuda"):
    model.eval()
    if teacher_model is not None:
        teacher_model.eval()
    robust_acc_dict = {}
    loss_dict = {}
    consistency_dict = {}
    for attack in attack_list:
        # Temporarily link the attack's model to 'model'
        attack.model = model
        correct, total, total_loss, consistency_loss = 0, 0, 0.0, 0.0
        for x, y in data_loader:
            x, y = x.to(device), y.to(device)
            # Generate adversarial examples; gradient-based attacks need
            # autograd, so generation must stay outside torch.no_grad()
            x_adv = attack.generate(x, y)
            # Score the adversarial examples without tracking gradients
            with torch.no_grad():
                logits = model(x_adv)
                loss = F.cross_entropy(logits, y, reduction='sum')
                total_loss += loss.item()
                preds = logits.argmax(dim=1)
                correct += (preds == y).sum().item()
                total += y.size(0)
                # If we have a teacher model, measure KL between teacher and student
                if teacher_model is not None:
                    teacher_logits = teacher_model(x_adv)
                    # Sum per-sample KL so the average over 'total' below is well defined
                    consistency_loss += F.kl_div(F.log_softmax(logits, dim=1),
                                                 F.softmax(teacher_logits, dim=1),
                                                 reduction='sum').item()
        attack_name = attack.__class__.__name__
        robust_acc_dict[attack_name] = 100.0 * correct / total
        loss_dict[attack_name] = total_loss / total
        if teacher_model is not None:
            consistency_dict[attack_name] = consistency_loss / total
    return robust_acc_dict, loss_dict, consistency_dict
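
For context, here is a minimal sketch of how these helpers might be driven. FGSMAttack is hypothetical, a stand-in for whatever attack classes the project actually defines; the only interface assumed is the generate(x, y) method and the .model attribute that evaluate_robust_accuracy relies on. The tiny linear model and random batches are placeholders purely to make the sketch self-contained, not part of the original file.

import torch
import torch.nn.functional as F


class FGSMAttack:
    # Hypothetical single-step FGSM attack matching the assumed interface
    def __init__(self, model=None, epsilon=8 / 255):
        self.model = model  # set by evaluate_robust_accuracy before use
        self.epsilon = epsilon

    def generate(self, x, y):
        # FGSM needs input gradients, which is why the evaluation loop
        # must not wrap generate() in torch.no_grad()
        x_adv = x.clone().detach().requires_grad_(True)
        loss = F.cross_entropy(self.model(x_adv), y)
        grad = torch.autograd.grad(loss, x_adv)[0]
        return (x_adv + self.epsilon * grad.sign()).clamp(0, 1).detach()


# Placeholder model and data, just to make the sketch runnable end to end
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10))
loader = [(torch.rand(8, 3, 32, 32), torch.randint(0, 10, (8,))) for _ in range(2)]

clean_acc, clean_loss = evaluate_accuracy(model, loader, device="cpu")
robust_acc, robust_loss, _ = evaluate_robust_accuracy(model, loader, [FGSMAttack()], device="cpu")
print(clean_acc, robust_acc)

The robust-accuracy dictionaries are keyed by the attack's class name (here "FGSMAttack"), so each attack in attack_list should use a distinct class.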