import torch
import random

# Anisotropic replay mixes a batch of data with a shuffled version of itself
# (like a partial mixup), optionally querying a teacher model for logits.
def anisotropic_replay(X_t, teacher_model=None, alpha_range=(0.3, 0.7)):
    batch_size = X_t.size(0)

    # Shuffle the indices to create a randomly permuted version of X_t
    idx = torch.randperm(batch_size, device=X_t.device)
    x_shuffle_t = X_t[idx]

    # Random alpha from a range to blend the current sample with a shuffled sample
    alpha = random.uniform(*alpha_range)
    X_AR_t = alpha * X_t + (1 - alpha) * x_shuffle_t

    # If we have a teacher model, let's get its logits on the mixed data
    teacher_logits = None
    if teacher_model is not None:
        teacher_model.eval()
        with torch.no_grad():
            teacher_logits = teacher_model(X_AR_t)

    return X_AR_t, teacher_logits
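
As a quick sanity check, here is a minimal usage sketch. The batch tensor, the tiny stand-in teacher, and the shapes are all hypothetical choices for illustration; any model that maps the mixed batch to logits would slot in the same way.

import torch
import torch.nn as nn

# Hypothetical setup: a batch of 32 flattened 28x28 inputs and a toy teacher network.
X_t = torch.randn(32, 784)
teacher = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))

# Mix the batch with a shuffled copy of itself and get the teacher's logits on the result.
X_AR_t, teacher_logits = anisotropic_replay(X_t, teacher_model=teacher)

print(X_AR_t.shape)          # torch.Size([32, 784]) -- same shape as the input batch
print(teacher_logits.shape)  # torch.Size([32, 10])  -- one logit vector per mixed sample

# Without a teacher, only the mixed batch is useful; teacher_logits comes back as None.
X_AR_only, no_logits = anisotropic_replay(X_t)
assert no_logits is None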