Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
D
Defense without Forgetting_Continual Adversarial Defense with Anisotropic and Isotropic Pseudo Replay - Reproduction
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package Registry
Container Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Mina Moshfegh
Defense without Forgetting_Continual Adversarial Defense with Anisotropic and Isotropic Pseudo Replay - Reproduction
Commits
c87adeee
Commit
c87adeee
authored
3 weeks ago
by
Mina Moshfegh
Browse files
Options
Downloads
Patches
Plain Diff
Upload New File
parent
4a14ba0d
No related branches found
Branches containing commit
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
src/main.py
+179
-0
179 additions, 0 deletions
src/main.py
with
179 additions
and
0 deletions
src/main.py
0 → 100644
+
179
−
0
View file @
c87adeee
import
os
import
torch
import
torch.optim
as
optim
import
random
import
numpy
as
np
from
utils
import
get_logger
,
evaluate_accuracy
,
evaluate_robust_accuracy
,
Config
from
data
import
get_mnist_loaders
,
get_cifar10_loaders
,
get_cifar100_loaders
from
models
import
SmallCNN
,
WideResNet
from
attacks.fgsm
import
FGSMAttack
from
attacks.pgd
import
PGDAttack
from
attacks.no_attack
import
NoOpAttack
from
train
import
(
train_air_for_epochs
,
train_lfl_for_epochs
,
train_joint_for_epochs
,
train_vanilla_at_for_epochs
,
train_feat_extraction_for_epochs
,
train_ewc_for_epochs
)
# Function to set random seed to ensure reproducibility of results
def
set_seed
(
seed
):
torch
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed_all
(
seed
)
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
def
save_model
(
model
,
dataset
,
defense_method
,
attack_name
,
task_idx
,
model_dir
=
"
./models
"
):
if
not
os
.
path
.
exists
(
model_dir
):
os
.
makedirs
(
model_dir
)
model_filename
=
f
"
{
dataset
}
_
{
defense_method
}
_
{
attack_name
}
_task
{
task_idx
}
.pt
"
model_path
=
os
.
path
.
join
(
model_dir
,
model_filename
)
torch
.
save
(
model
.
state_dict
(),
model_path
)
return
model_path
# This function helps in creating the right attack instance (FGSM or PGD) based on the attack name passed
def
make_attack
(
attack_name
,
cfg
,
model
=
None
):
if
attack_name
==
"
None
"
:
atk
=
NoOpAttack
(
model
=
model
)
return
atk
elif
attack_name
==
"
FGSM
"
:
atk
=
FGSMAttack
(
model
=
model
,
epsilon
=
cfg
.
epsilon
)
return
atk
elif
attack_name
==
"
PGD
"
:
atk
=
PGDAttack
(
model
=
model
,
epsilon
=
cfg
.
epsilon
,
alpha
=
cfg
.
alpha
,
num_steps
=
cfg
.
num_steps
,
random_init
=
cfg
.
random_init
)
return
atk
else
:
raise
ValueError
(
f
"
Unknown attack name:
{
attack_name
}
"
)
# Main training and evaluation loop across multiple tasks and attacks
def
train_and_eval_sequential_tasks
(
cfg
:
Config
,
logger
):
set_seed
(
cfg
.
seed
)
# Set random seed for reproducibility
device
=
torch
.
device
(
cfg
.
device
if
torch
.
cuda
.
is_available
()
else
"
cpu
"
)
# Set device (GPU or CPU)
# Load data based on the selected dataset
if
cfg
.
dataset
==
"
MNIST
"
:
train_loader
,
test_loader
=
get_mnist_loaders
(
batch_size
=
cfg
.
batch_size
,
root
=
cfg
.
data_root
,
num_workers
=
cfg
.
num_workers
)
model
=
SmallCNN
(
num_channels
=
1
,
num_classes
=
10
)
teacher_model
=
SmallCNN
(
num_channels
=
1
,
num_classes
=
10
)
elif
cfg
.
dataset
==
"
CIFAR10
"
:
train_loader
,
test_loader
=
get_cifar10_loaders
(
batch_size
=
cfg
.
batch_size
,
root
=
cfg
.
data_root
,
num_workers
=
cfg
.
num_workers
)
model
=
WideResNet
(
depth
=
34
,
widen_factor
=
10
,
drop_rate
=
0.0
,
num_classes
=
10
)
teacher_model
=
WideResNet
(
depth
=
34
,
widen_factor
=
10
,
drop_rate
=
0.0
,
num_classes
=
10
)
elif
cfg
.
dataset
==
"
CIFAR100
"
:
train_loader
,
test_loader
=
get_cifar100_loaders
(
batch_size
=
cfg
.
batch_size
,
root
=
cfg
.
data_root
,
num_workers
=
cfg
.
num_workers
)
model
=
WideResNet
(
depth
=
34
,
widen_factor
=
20
,
drop_rate
=
0.0
,
num_classes
=
100
)
teacher_model
=
WideResNet
(
depth
=
34
,
widen_factor
=
20
,
drop_rate
=
0.0
,
num_classes
=
100
)
else
:
raise
ValueError
(
f
"
Unknown dataset:
{
cfg
.
dataset
}
"
)
model
.
to
(
device
)
teacher_model
.
to
(
device
)
teacher_model
.
eval
()
# Set the teacher model to evaluation mode (no gradient updates)
for
p
in
teacher_model
.
parameters
():
p
.
requires_grad
=
False
# Freeze teacher model parameters
optimizer
=
optim
.
SGD
(
model
.
parameters
(),
lr
=
cfg
.
learning_rate
,
momentum
=
cfg
.
momentum
,
weight_decay
=
cfg
.
weight_decay
)
# Configuration for the AIR defense method
defense_config_AIR
=
{
'
lambda_SD
'
:
cfg
.
lambda_SD
,
'
lambda_IR
'
:
cfg
.
lambda_IR
,
'
lambda_AR
'
:
cfg
.
lambda_AR
,
'
lambda_Reg
'
:
cfg
.
lambda_Reg
,
'
alpha_range
'
:
cfg
.
alpha_range
,
'
use_rdrop
'
:
cfg
.
use_rdrop
,
'
iso_noise_std
'
:
cfg
.
iso_noise_std
,
'
iso_clamp_min
'
:
cfg
.
iso_clamp_min
,
'
iso_clamp_max
'
:
cfg
.
iso_clamp_max
,
'
iso_p_flip
'
:
cfg
.
iso_p_flip
,
'
iso_flip_dim
'
:
cfg
.
iso_flip_dim
,
'
iso_p_rotation
'
:
cfg
.
iso_p_rotation
,
'
iso_max_rotation
'
:
cfg
.
iso_max_rotation
,
'
iso_p_crop
'
:
cfg
.
iso_p_crop
,
'
iso_p_erase
'
:
cfg
.
iso_p_erase
}
# List to store results for each task
results_per_task
=
[]
# Looping through each task and attack
for
task_idx
,
atk_name
in
enumerate
(
cfg
.
attack_sequence
):
logger
.
info
(
f
"
--- Task
{
task_idx
+
1
}
/
{
len
(
cfg
.
attack_sequence
)
}
: Attack =
{
atk_name
}
---
"
)
current_attack
=
make_attack
(
atk_name
,
cfg
,
model
=
None
)
# Get the attack instance
# Training based on the selected defense method
if
cfg
.
defense_method
==
"
AIR
"
:
logger
.
info
(
"
Using AIR Defense...
"
)
model
=
train_air_for_epochs
(
student_model
=
model
,
teacher_model
=
teacher_model
,
train_loader
=
train_loader
,
test_loader
=
test_loader
,
optimizer
=
optimizer
,
defense_config
=
defense_config_AIR
,
attack
=
current_attack
,
device
=
device
,
epochs
=
cfg
.
epochs
,
logger
=
logger
)
elif
cfg
.
defense_method
==
"
LFL
"
:
logger
.
info
(
"
Using LFL Defense...
"
)
model
=
train_lfl_for_epochs
(
student_model
=
model
,
teacher_model
=
teacher_model
,
train_loader
=
train_loader
,
test_loader
=
test_loader
,
optimizer
=
optimizer
,
lambda_lfl
=
cfg
.
lambda_lfl
,
feature_lambda
=
cfg
.
feature_lambda
,
device
=
device
,
epochs
=
cfg
.
epochs
,
logger
=
logger
,
freeze_classifier
=
cfg
.
freeze_classifier
)
elif
cfg
.
defense_method
==
"
JointTraining
"
:
logger
.
info
(
"
Using Joint Training Defense...
"
)
model
=
train_joint_for_epochs
(
student_model
=
model
,
train_loader
=
train_loader
,
test_loader
=
test_loader
,
optimizer
=
optimizer
,
attack
=
current_attack
,
joint_lambda
=
cfg
.
joint_lambda
,
device
=
device
,
epochs
=
cfg
.
epochs
,
logger
=
logger
)
elif
cfg
.
defense_method
==
"
VanillaAT
"
:
logger
.
info
(
"
Using Vanilla Adversarial Training Defense...
"
)
model
=
train_vanilla_at_for_epochs
(
student_model
=
model
,
train_loader
=
train_loader
,
test_loader
=
test_loader
,
optimizer
=
optimizer
,
attack
=
current_attack
,
adv_lambda
=
cfg
.
adv_lambda
,
device
=
device
,
epochs
=
cfg
.
epochs
,
logger
=
logger
)
elif
cfg
.
defense_method
==
"
FeatExtraction
"
:
logger
.
info
(
"
Using Feature Extraction Defense...
"
)
model
=
train_feat_extraction_for_epochs
(
student_model
=
model
,
train_loader
=
train_loader
,
test_loader
=
test_loader
,
optimizer
=
optimizer
,
feat_lambda
=
cfg
.
feat_lambda
,
attack
=
current_attack
,
device
=
device
,
epochs
=
cfg
.
epochs
,
logger
=
logger
)
elif
cfg
.
defense_method
==
"
EWC
"
:
logger
.
info
(
"
Using EWC Defense...
"
)
model
=
train_ewc_for_epochs
(
student_model
=
model
,
train_loader
=
train_loader
,
test_loader
=
test_loader
,
optimizer
=
optimizer
,
device
=
device
,
epochs
=
cfg
.
epochs
,
logger
=
logger
,
lambda_ewc
=
cfg
.
lambda_ewc
)
else
:
raise
ValueError
(
f
"
Unknown defense method:
{
cfg
.
defense_method
}
"
)
model_path
=
save_model
(
model
,
cfg
.
dataset
,
cfg
.
defense_method
,
atk_name
,
task_idx
)
logger
.
info
(
f
"
Model saved at:
{
model_path
}
"
)
# After training, we evaluate the model's performance on clean data
clean_acc
,
clean_loss
=
evaluate_accuracy
(
model
,
test_loader
,
device
=
device
)
logger
.
info
(
f
"
Clean Accuracy after Task
{
task_idx
+
1
}
, Attack=
{
atk_name
}
:
{
clean_acc
:
.
2
f
}
%
"
)
# Evaluate robust accuracy for all attacks in the sequence
for
atk_name
in
cfg
.
attack_sequence
:
test_attack
=
make_attack
(
atk_name
,
cfg
,
model
=
None
)
robust_acc_dict
,
robust_loss_dict
,
_
=
evaluate_robust_accuracy
(
model
,
test_loader
,
[
test_attack
],
teacher_model
=
None
,
device
=
device
)
for
attack_name
,
acc
in
robust_acc_dict
.
items
():
logger
.
info
(
f
"
Robust Accuracy on
{
atk_name
}
after Task
{
task_idx
+
1
}
:
{
acc
:
.
2
f
}
%
"
)
logger
.
info
(
"
=== All tasks completed! ===
"
)
return
model
,
results_per_task
# Main function to run everything
def
main
():
cfg
=
Config
()
# Load configuration
logger
=
get_logger
(
name
=
f
"
{
cfg
.
dataset
}
-
{
cfg
.
defense_method
}
-
{
str
(
cfg
.
attack_sequence
)
}
"
)
# Get logger
logger
.
info
(
"
=== Configuration ===
"
)
logger
.
info
(
cfg
)
final_model
,
results
=
train_and_eval_sequential_tasks
(
cfg
,
logger
)
# Train and evaluate
logger
.
info
(
"
=== Final Results ===
"
)
logger
.
info
(
results
)
logger
.
info
(
"
Done!
"
)
if
__name__
==
"
__main__
"
:
main
()
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment