Lesson 1: Adversarial Training
Learn about adversarial training, robust optimization, and defense against adversarial attacks
Learning Objectives
By the end of this lesson, you will be able to:
- Understand adversarial training principles
- Implement robust optimization techniques
- Apply different adversarial training methods
- Evaluate model robustness
- Compare defense effectiveness
- Optimize training hyperparameters
Understanding Adversarial Training
What is Adversarial Training?
Adversarial training is a defense technique that improves model robustness by training on both clean and adversarial examples. The model learns to correctly classify inputs even when they contain adversarial perturbations.
Key Principles:
- Min-Max Optimization: Train against worst-case perturbations
- Data Augmentation: Include adversarial examples in training
- Robust Optimization: Minimize loss under adversarial conditions
- Transferability: Improve robustness across attack types
Mathematical Foundation
Adversarial Training Objective:
min_θ E_{(x,y)~D} [max_δ L(f_θ(x + δ), y)]
subject to: ||δ||_p ≤ ε
Where:
- θ: Model parameters
- (x, y): Clean training examples drawn from the data distribution D
- δ: Adversarial perturbation
- ε: Perturbation budget
- L: Loss function
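The inner maximization has no closed-form solution in general, so it is approximated with iterative attacks such as PGD (implemented below). Each PGD step takes a signed-gradient ascent step on the loss and projects the perturbation back into the ε-ball:

δ_{t+1} = Π_{||δ||_p ≤ ε} (δ_t + α · sign(∇_x L(f_θ(x + δ_t), y)))

Here α is the attack step size and Π denotes projection onto the allowed perturbation set, which for the ℓ∞ norm is simply a clamp to [-ε, ε].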
Implementation Strategies
Basic Adversarial Training
The simplest approach generates adversarial examples during training and includes them in the loss function.
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


class AdversarialTrainer:
    def __init__(self, model, attack_method='fgsm', epsilon=0.03):
        """
        Initialize adversarial trainer

        Args:
            model: Model to train
            attack_method: Method for generating adversarial examples ('fgsm' or 'pgd')
            epsilon: Perturbation magnitude
        """
        self.model = model
        self.attack_method = attack_method
        self.epsilon = epsilon

    def generate_adversarial_examples(self, inputs, targets):
        """
        Generate adversarial examples using the specified attack method

        Args:
            inputs: Clean input data
            targets: True labels

        Returns:
            adversarial_inputs: Perturbed inputs
        """
        if self.attack_method == 'fgsm':
            return self.fgsm_attack(inputs, targets)
        elif self.attack_method == 'pgd':
            return self.pgd_attack(inputs, targets)
        else:
            raise ValueError(f"Unknown attack method: {self.attack_method}")

    def fgsm_attack(self, inputs, targets):
        """Fast Gradient Sign Method attack"""
        # Work on a detached copy so the caller's tensor is left untouched
        inputs = inputs.clone().detach().requires_grad_(True)

        # Forward pass
        outputs = self.model(inputs)
        loss = F.cross_entropy(outputs, targets)

        # Compute gradients with respect to the inputs
        self.model.zero_grad()
        loss.backward()

        # Take one signed-gradient step, then keep pixels in the valid [0, 1] range
        adversarial_inputs = inputs + self.epsilon * inputs.grad.sign()
        adversarial_inputs = torch.clamp(adversarial_inputs, 0, 1)

        return adversarial_inputs.detach()

    def pgd_attack(self, inputs, targets, num_steps=10, step_size=0.01):
        """Projected Gradient Descent attack"""
        adversarial_inputs = inputs.clone().detach()

        for _ in range(num_steps):
            adversarial_inputs.requires_grad = True

            # Forward pass
            outputs = self.model(adversarial_inputs)
            loss = F.cross_entropy(outputs, targets)

            # Compute gradients with respect to the current adversarial inputs
            self.model.zero_grad()
            loss.backward()

            with torch.no_grad():
                # Signed-gradient ascent step
                adversarial_inputs = adversarial_inputs + step_size * adversarial_inputs.grad.sign()

                # Project back into the epsilon ball and the valid [0, 1] range
                delta = torch.clamp(adversarial_inputs - inputs, -self.epsilon, self.epsilon)
                adversarial_inputs = torch.clamp(inputs + delta, 0, 1)

        return adversarial_inputs.detach()

    def adversarial_training_step(self, inputs, targets, optimizer):
        """
        Single step of adversarial training

        Args:
            inputs: Clean training data
            targets: True labels
            optimizer: Model optimizer

        Returns:
            total_loss: Combined clean and adversarial loss
        """
        # Generate adversarial examples for the current batch
        adversarial_inputs = self.generate_adversarial_examples(inputs, targets)

        # Forward pass on clean data
        clean_outputs = self.model(inputs)
        clean_loss = F.cross_entropy(clean_outputs, targets)

        # Forward pass on adversarial data
        adversarial_outputs = self.model(adversarial_inputs)
        adversarial_loss = F.cross_entropy(adversarial_outputs, targets)

        # Combined loss (equal weighting of clean and adversarial terms)
        total_loss = 0.5 * clean_loss + 0.5 * adversarial_loss

        # Backward pass; zero_grad also clears gradients left over from the attack
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        return total_loss.item()

    def train_epoch(self, dataloader, optimizer):
        """
        Train the model for one epoch with adversarial training

        Args:
            dataloader: Training data loader
            optimizer: Model optimizer

        Returns:
            average_loss: Average training loss
        """
        self.model.train()
        total_loss = 0
        num_batches = 0

        for inputs, targets in dataloader:
            if torch.cuda.is_available():
                inputs, targets = inputs.cuda(), targets.cuda()

            loss = self.adversarial_training_step(inputs, targets, optimizer)
            total_loss += loss
            num_batches += 1

        return total_loss / num_batches


# Example usage
def train_robust_model():
    """
    Train a robust model using adversarial training
    """
    # Load model and data
    model = create_model()              # Your model here
    train_loader = create_dataloader()  # Your training dataloader here
    test_loader = create_dataloader()   # Your test dataloader here
    if torch.cuda.is_available():
        model = model.cuda()

    # Initialize adversarial trainer
    trainer = AdversarialTrainer(model, attack_method='pgd', epsilon=0.03)
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Training loop
    num_epochs = 50
    for epoch in range(num_epochs):
        avg_loss = trainer.train_epoch(train_loader, optimizer)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

        # Evaluate robustness periodically (see the evaluate_robustness sketch below)
        if (epoch + 1) % 10 == 0:
            robust_accuracy = evaluate_robustness(model, test_loader)
            print(f"Robust accuracy: {robust_accuracy:.4f}")

    return model
```
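The training loop above calls evaluate_robustness, which is not defined in this lesson. A minimal sketch, assuming the AdversarialTrainer class above is reused to generate PGD perturbations on the test set, might look like this:

```python
# Minimal sketch of a robustness evaluation helper (an assumption, not part of
# the original lesson code): measures accuracy on PGD-perturbed test inputs.
def evaluate_robustness(model, test_loader, epsilon=0.03):
    model.eval()
    attacker = AdversarialTrainer(model, attack_method='pgd', epsilon=epsilon)

    correct = 0
    total = 0
    for inputs, targets in test_loader:
        if torch.cuda.is_available():
            inputs, targets = inputs.cuda(), targets.cuda()

        # Generating the attack needs gradients, so this part runs outside no_grad
        adversarial_inputs = attacker.pgd_attack(inputs, targets)

        with torch.no_grad():
            predictions = model(adversarial_inputs).argmax(dim=1)

        correct += (predictions == targets).sum().item()
        total += targets.size(0)

    return correct / total
```

This helper leaves the model in eval mode; train_epoch switches it back to train mode at the start of the next epoch.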
Hands-On Exercise
Exercise: Implement and Compare Adversarial Training Methods
Objective: Implement different adversarial training approaches and compare their effectiveness.
Steps:
- Setup Environment - Load the CIFAR-10 dataset and create a model
- Implement Basic Adversarial Training - FGSM-based training
- Implement PGD Adversarial Training - Multi-step training
- Compare Robustness - Test against various attacks (a starter sketch follows this list)
- Analyze Results - Compare clean vs. robust accuracy
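As a starting point for the comparison steps, one possible harness is sketched below. It reuses the placeholder create_model / create_dataloader helpers and the evaluate_robustness sketch from earlier; the epoch count and epsilon are assumptions you should tune for your setup.

```python
# Sketch of a comparison harness (not a reference solution): train one model
# per attack method and compare robust accuracy under the same PGD evaluation.
def compare_training_methods(num_epochs=20, epsilon=0.03):
    results = {}
    for attack_method in ['fgsm', 'pgd']:
        model = create_model()              # Your model here
        train_loader = create_dataloader()  # Your CIFAR-10 training loader here
        test_loader = create_dataloader()   # Your CIFAR-10 test loader here
        if torch.cuda.is_available():
            model = model.cuda()

        trainer = AdversarialTrainer(model, attack_method=attack_method, epsilon=epsilon)
        optimizer = optim.Adam(model.parameters(), lr=0.001)

        for _ in range(num_epochs):
            trainer.train_epoch(train_loader, optimizer)

        results[attack_method] = evaluate_robustness(model, test_loader, epsilon=epsilon)

    for attack_method, robust_accuracy in results.items():
        print(f"{attack_method}-trained model, robust accuracy: {robust_accuracy:.4f}")

    return results
```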
Deliverables:
- Working adversarial training implementation
- Comparison of FGSM vs PGD training
- Robustness evaluation results
- Analysis of training efficiency
- Recommendations for production use