Skip to content
Snippets Groups Projects
Commit c26661bc authored by Mina Moshfegh's avatar Mina Moshfegh
Browse files

Upload New File

parent a3079d53
No related branches found
No related tags found
No related merge requests found
import torch
import random
# Anisotropic replay mixes a batch of data with a shuffled version of itself
# (like a partial mixup), optionally querying a teacher model for logits.
def anisotropic_replay(X_t, teacher_model=None, alpha_range=(0.3, 0.7)):
batch_size = X_t.size(0)
# Shuffle the indices to create a randomly permuted version of X_t
idx = torch.randperm(batch_size, device=X_t.device)
x_shuffle_t = X_t[idx]
# Random alpha from a range to blend the current sample with a shuffled sample
alpha = random.uniform(*alpha_range)
X_AR_t = alpha * X_t + (1 - alpha) * x_shuffle_t
# If we have a teacher model, let's get its logits on the mixed data
teacher_logits = None
if teacher_model is not None:
teacher_model.eval()
with torch.no_grad():
teacher_logits = teacher_model(X_AR_t)
return X_AR_t, teacher_logits
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment