Skip to content
Snippets Groups Projects
anisotropic_replay.py 919 B
Newer Older
Mina Moshfegh's avatar
Mina Moshfegh committed
import torch
import random


# Anisotropic replay mixes a batch of data with a shuffled version of itself
# (like a partial mixup), optionally querying a teacher model for logits.

def anisotropic_replay(X_t, teacher_model=None, alpha_range=(0.3, 0.7)):
    """Blend a batch with a randomly permuted copy of itself.

    Each sample ``X_t[i]`` is mixed with ``X_t[idx[i]]``, where ``idx`` is a
    random permutation of the batch indices:

        X_AR_t = alpha * X_t + (1 - alpha) * X_t[idx]

    A single scalar ``alpha`` is drawn for the whole call, not one per sample.
    Because a permutation can map an index to itself, some samples may be
    blended with themselves and therefore come back unchanged.

    Args:
        X_t: Tensor whose first dimension is the batch dimension.
        teacher_model: Optional callable (typically a ``torch.nn.Module``).
            If given, it is evaluated on the mixed batch under
            ``torch.no_grad()`` in eval mode. Its previous train/eval mode
            is restored afterwards.
        alpha_range: ``(low, high)`` bounds passed to ``random.uniform`` to
            draw ``alpha``. The draw uses Python's ``random`` module, so it is
            seeded by ``random.seed``, not ``torch.manual_seed``.

    Returns:
        A tuple ``(X_AR_t, teacher_logits)``. ``X_AR_t`` has the same shape as
        ``X_t``. ``teacher_logits`` is the teacher's output on ``X_AR_t``, or
        ``None`` when no teacher is given.
    """
    batch_size = X_t.size(0)
    # Random permutation of batch indices, created on X_t's device so the
    # indexing below does not trigger a cross-device copy.
    idx = torch.randperm(batch_size, device=X_t.device)
    x_shuffle_t = X_t[idx]

    # One mixing coefficient shared by every sample in this batch.
    alpha = random.uniform(*alpha_range)
    X_AR_t = alpha * X_t + (1 - alpha) * x_shuffle_t

    teacher_logits = None
    if teacher_model is not None:
        # Switch to eval mode for the query, but remember the caller's mode.
        # Otherwise a teacher that was training would be left in eval mode,
        # silently changing dropout/batch-norm behaviour for later use.
        was_training = getattr(teacher_model, "training", False)
        teacher_model.eval()
        try:
            with torch.no_grad():
                teacher_logits = teacher_model(X_AR_t)
        finally:
            if was_training:
                teacher_model.train()

    return X_AR_t, teacher_logits