🧪 Defensive Distillation
Use knowledge distillation techniques to create more robust models that are less vulnerable to adversarial attacks
🎯 Learning Objectives
- Understand the principles of defensive distillation
- Implement knowledge distillation for security
- Apply temperature scaling for robustness
- Evaluate distillation-based defenses
- Compare different distillation strategies
📚 Core Concepts
1. Knowledge Distillation Fundamentals
Knowledge distillation is a technique where a smaller "student" model learns from a larger "teacher" model by mimicking its soft predictions.
Standard Knowledge Distillation Process
1. Teacher Model: a large, complex model with high accuracy
2. Soft Targets: the teacher's probability distributions over classes
3. Student Model: a smaller, more efficient model trained to mimic the teacher
Mathematical Foundation
L_distill = α * L_soft + (1 - α) * L_hard
Where:
- L_soft: distillation loss computed from the teacher's soft targets
- L_hard: standard cross-entropy loss on the true labels
- α: weighting parameter (typically 0.5-0.7)
Temperature Scaling
import torch
import torch.nn.functional as F

def temperature_scaling(logits, temperature):
    """Apply temperature scaling to logits."""
    return logits / temperature

def distillation_loss(student_logits, teacher_logits, labels, temperature, alpha):
    """Compute the combined soft/hard distillation loss."""
    # Soft targets from the teacher (temperature-smoothed probabilities)
    soft_targets = F.softmax(temperature_scaling(teacher_logits, temperature), dim=1)
    # Soft predictions from the student (log-probabilities, as required by kl_div)
    soft_preds = F.log_softmax(temperature_scaling(student_logits, temperature), dim=1)
    # Distillation loss (KL divergence), scaled by T^2 so its gradient magnitude
    # stays comparable to the hard loss (standard practice from Hinton et al., 2015)
    soft_loss = F.kl_div(soft_preds, soft_targets, reduction='batchmean') * (temperature ** 2)
    # Hard loss (standard cross-entropy against the true labels)
    hard_loss = F.cross_entropy(student_logits, labels)
    return alpha * soft_loss + (1 - alpha) * hard_loss
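A quick way to sanity-check the loss is to call it on random tensors; the shapes and hyperparameters below are arbitrary placeholders, not values from a real experiment.

# Hypothetical smoke test for distillation_loss
batch_size, num_classes = 8, 10
student_logits = torch.randn(batch_size, num_classes, requires_grad=True)
teacher_logits = torch.randn(batch_size, num_classes)
labels = torch.randint(0, num_classes, (batch_size,))

loss = distillation_loss(student_logits, teacher_logits, labels,
                         temperature=10.0, alpha=0.7)
loss.backward()  # gradients flow only through the student logits
print(f"distillation loss: {loss.item():.4f}")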
2. Defensive Distillation
An adaptation of knowledge distillation, introduced by Papernot et al. (2016), specifically designed to improve model robustness against adversarial attacks.
Key Differences from Standard Distillation
- Higher temperatures: train with large temperature values (T > 1) to smooth the output distributions
- Train high, deploy low: in the original formulation the model is trained at high T but evaluated at T = 1, which sharply flattens the softmax gradients an attacker sees
- Adversarial training: optionally train the student model on adversarial examples as well
- Robust teacher: use a teacher model that was itself trained with adversarial training
- Gradient masking: much of the defense comes from reducing the gradient information available to attackers (which is also its main weakness; see Limitations below)
Defensive Distillation Implementation
class DefensiveDistillation:
    def __init__(self, teacher_model, student_model, temperature=10.0, alpha=0.7):
        self.teacher_model = teacher_model
        self.student_model = student_model
        self.temperature = temperature
        self.alpha = alpha

    def train_step(self, x, y, optimizer, adversarial_generator=None):
        """Single training step with defensive distillation."""
        # Optionally augment the batch with adversarial examples
        if adversarial_generator:
            x_adv = adversarial_generator.generate(x, y)
            x_combined = torch.cat([x, x_adv], dim=0)
            y_combined = torch.cat([y, y], dim=0)
        else:
            x_combined, y_combined = x, y

        # Teacher predictions (no gradients needed; distillation_loss applies the temperature)
        with torch.no_grad():
            teacher_logits = self.teacher_model(x_combined)

        # Student predictions
        student_logits = self.student_model(x_combined)

        # Combined soft/hard distillation loss
        loss = distillation_loss(
            student_logits, teacher_logits, y_combined,
            self.temperature, self.alpha
        )

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()

    def evaluate_robustness(self, test_loader, attack_method, epsilon=0.3):
        """Evaluate student accuracy on clean and adversarial inputs."""
        # epsilon is unused here; the attack_method is assumed to carry its own budget
        self.student_model.eval()
        clean_correct = 0
        adv_correct = 0
        total = 0

        for x, y in test_loader:
            # Clean accuracy
            with torch.no_grad():
                clean_pred = self.student_model(x).argmax(dim=1)
                clean_correct += (clean_pred == y).sum().item()

            # Adversarial accuracy (attack generation may need gradients, so no no_grad here)
            x_adv = attack_method.generate(x, y)
            with torch.no_grad():
                adv_pred = self.student_model(x_adv).argmax(dim=1)
                adv_correct += (adv_pred == y).sum().item()

            total += x.size(0)

        return {
            'clean_accuracy': clean_correct / total,
            'adversarial_accuracy': adv_correct / total,
            'robustness_gap': (clean_correct - adv_correct) / total
        }
3. Advanced Distillation Techniques
Enhanced distillation methods that provide additional security benefits.
Multi-Teacher Distillation
class MultiTeacherDistillation:
    def __init__(self, teachers, student, temperature=10.0):
        self.teachers = teachers
        self.student = student
        self.temperature = temperature

    def ensemble_teacher_predictions(self, x):
        """Combine predictions from multiple teachers."""
        teacher_logits = []
        for teacher in self.teachers:
            with torch.no_grad():
                logits = teacher(x)
            teacher_logits.append(logits)
        # Average the teacher logits into a single ensemble prediction
        ensemble_logits = torch.stack(teacher_logits).mean(dim=0)
        return ensemble_logits

    def compute_distillation_loss(self, x, y):
        """Compute loss using ensemble teacher predictions."""
        # Ensemble teacher predictions
        teacher_logits = self.ensemble_teacher_predictions(x)
        # Student predictions
        student_logits = self.student(x)
        # Distillation loss (KL between temperature-softened distributions)
        soft_targets = F.softmax(teacher_logits / self.temperature, dim=1)
        soft_preds = F.log_softmax(student_logits / self.temperature, dim=1)
        kl_loss = F.kl_div(soft_preds, soft_targets, reduction='batchmean')
        ce_loss = F.cross_entropy(student_logits, y)
        return 0.7 * kl_loss + 0.3 * ce_loss
Progressive Distillation
class ProgressiveDistillation:
    def __init__(self, teacher, student, temperature_schedule):
        self.teacher = teacher
        self.student = student
        self.temperature_schedule = temperature_schedule

    def train_with_progressive_temperature(self, train_loader, epochs):
        """Train with a gradually increasing temperature taken from the schedule."""
        optimizer = torch.optim.Adam(self.student.parameters())

        for epoch in range(epochs):
            # Get the temperature for this epoch
            current_temp = self.temperature_schedule[epoch]
            epoch_loss = 0

            for x, y in train_loader:
                loss = self.train_step(x, y, optimizer, current_temp)
                epoch_loss += loss

            print(f"Epoch {epoch+1}/{epochs}, Temp: {current_temp:.1f}, "
                  f"Loss: {epoch_loss/len(train_loader):.4f}")

    def train_step(self, x, y, optimizer, temperature):
        """Single training step with the specified temperature."""
        # Teacher predictions
        with torch.no_grad():
            teacher_logits = self.teacher(x)
            soft_targets = F.softmax(teacher_logits / temperature, dim=1)

        # Student predictions
        student_logits = self.student(x)

        # Distillation loss
        soft_preds = F.log_softmax(student_logits / temperature, dim=1)
        kl_loss = F.kl_div(soft_preds, soft_targets, reduction='batchmean')
        # Standard hard-label loss
        ce_loss = F.cross_entropy(student_logits, y)
        total_loss = 0.7 * kl_loss + 0.3 * ce_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        return total_loss.item()
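The temperature_schedule passed to the constructor is just a per-epoch sequence of values. A minimal sketch of how one might be built; the linear ramp, its endpoints, and the helper name are illustrative assumptions rather than anything prescribed above.

# Hypothetical linear temperature ramp: one value per training epoch
def linear_temperature_schedule(start_temp=2.0, end_temp=10.0, epochs=50):
    step = (end_temp - start_temp) / max(epochs - 1, 1)
    return [start_temp + i * step for i in range(epochs)]

schedule = linear_temperature_schedule(epochs=50)
# progressive = ProgressiveDistillation(teacher, student, schedule)
# progressive.train_with_progressive_temperature(train_loader, epochs=50)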
🔧 Implementation Strategies
1. Two-Stage Training Process
Implement defensive distillation in two stages for optimal robustness.
class TwoStageDefensiveDistillation:
    def __init__(self, teacher_model, student_model, temperature=10.0):
        self.teacher_model = teacher_model
        self.student_model = student_model
        self.temperature = temperature

    def stage1_train_teacher(self, train_loader, epochs=100):
        """Stage 1: train the teacher model with adversarial training."""
        print("Stage 1: Training robust teacher model...")
        optimizer = torch.optim.Adam(self.teacher_model.parameters())
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

        # Adversarial training loop
        for epoch in range(epochs):
            epoch_loss = 0
            for x, y in train_loader:
                # Generate adversarial examples
                # (generate_adversarial_examples is assumed to be provided;
                # a possible FGSM-based sketch follows this class)
                x_adv = self.generate_adversarial_examples(x, y)

                # Combine clean and adversarial data
                x_combined = torch.cat([x, x_adv], dim=0)
                y_combined = torch.cat([y, y], dim=0)

                # Training step on the combined batch
                loss = self.adversarial_training_step(x_combined, y_combined, optimizer)
                epoch_loss += loss

            scheduler.step()
            print(f"Teacher Epoch {epoch+1}/{epochs}, Loss: {epoch_loss/len(train_loader):.4f}")

    def stage2_distill_student(self, train_loader, epochs=50):
        """Stage 2: distill knowledge from the robust teacher to the student."""
        print("Stage 2: Distilling knowledge to student model...")
        optimizer = torch.optim.Adam(self.student_model.parameters())

        for epoch in range(epochs):
            epoch_loss = 0
            for x, y in train_loader:
                loss = self.distillation_step(x, y, optimizer)
                epoch_loss += loss
            print(f"Student Epoch {epoch+1}/{epochs}, Loss: {epoch_loss/len(train_loader):.4f}")

    def distillation_step(self, x, y, optimizer):
        """Single distillation training step."""
        # Teacher predictions at high temperature
        with torch.no_grad():
            teacher_logits = self.teacher_model(x)
            soft_targets = F.softmax(teacher_logits / self.temperature, dim=1)

        # Student predictions
        student_logits = self.student_model(x)

        # Distillation loss
        soft_preds = F.log_softmax(student_logits / self.temperature, dim=1)
        kl_loss = F.kl_div(soft_preds, soft_targets, reduction='batchmean')
        # Hard loss
        ce_loss = F.cross_entropy(student_logits, y)
        total_loss = 0.7 * kl_loss + 0.3 * ce_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        return total_loss.item()
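The class above assumes generate_adversarial_examples and adversarial_training_step exist. One possible sketch, written as standalone functions that take the model explicitly (the class methods would simply wrap them); FGSM is used purely as an illustrative attack, and the epsilon value and [0, 1] input range are assumptions.

def generate_adversarial_examples(model, x, y, epsilon=0.03):
    """FGSM sketch: one signed-gradient step on the input (assumes inputs in [0, 1])."""
    x_adv = x.clone().detach().requires_grad_(True)
    loss = F.cross_entropy(model(x_adv), y)
    grad = torch.autograd.grad(loss, x_adv)[0]
    return (x_adv + epsilon * grad.sign()).clamp(0, 1).detach()

def adversarial_training_step(model, x, y, optimizer):
    """Standard cross-entropy step on a (clean + adversarial) batch."""
    loss = F.cross_entropy(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()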
2. Adaptive Temperature Selection
Dynamically adjust temperature based on model performance and robustness metrics.
class AdaptiveTemperatureDistillation:
    def __init__(self, teacher_model, student_model, initial_temp=10.0):
        self.teacher_model = teacher_model
        self.student_model = student_model
        self.current_temp = initial_temp
        self.temp_history = []
        self.performance_history = []

    def adaptive_training(self, train_loader, val_loader, max_epochs=100):
        """Train with adaptive temperature adjustment."""
        optimizer = torch.optim.Adam(self.student_model.parameters())

        for epoch in range(max_epochs):
            # One training pass at the current temperature
            # (train_epoch, evaluate_robustness and log_progress are assumed helpers;
            # a sketch of evaluate_robustness follows this class)
            train_loss = self.train_epoch(train_loader, optimizer)

            # Validation robustness at this epoch
            val_metrics = self.evaluate_robustness(val_loader)

            # Adjust temperature based on validation performance
            self.adjust_temperature(val_metrics)

            # Log progress
            self.log_progress(epoch, train_loss, val_metrics)

            # Early stopping
            if self.should_stop():
                break

    def adjust_temperature(self, val_metrics):
        """Adjust temperature based on validation robustness."""
        current_robustness = val_metrics['robustness_score']

        if len(self.performance_history) > 0:
            prev_robustness = self.performance_history[-1]
            if current_robustness > prev_robustness:
                # Robustness improved: maintain or slightly increase temperature
                self.current_temp = min(self.current_temp * 1.05, 20.0)
            else:
                # Robustness degraded: decrease temperature
                self.current_temp = max(self.current_temp * 0.95, 1.0)

        self.temp_history.append(self.current_temp)
        self.performance_history.append(current_robustness)

    def should_stop(self):
        """Early stopping criterion: stop once robustness has plateaued."""
        if len(self.performance_history) < 10:
            return False
        # Check whether performance has plateaued over the last 10 epochs
        recent_performance = self.performance_history[-10:]
        return max(recent_performance) - min(recent_performance) < 0.01
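adjust_temperature reads a 'robustness_score' entry that the class above never produces, and evaluate_robustness itself is not defined. A plausible way to fill the gap, written as it would appear inside the class; the attack interface (attack_method.generate(x, y)) and the choice of adversarial accuracy as the score are assumptions, and with no attack supplied it falls back to clean accuracy.

    # Hypothetical helper, to be added inside AdaptiveTemperatureDistillation
    def evaluate_robustness(self, val_loader, attack_method=None):
        """Return validation metrics, including a scalar robustness score."""
        self.student_model.eval()
        clean_correct, adv_correct, total = 0, 0, 0
        for x, y in val_loader:
            with torch.no_grad():
                clean_correct += (self.student_model(x).argmax(dim=1) == y).sum().item()
            if attack_method is not None:
                x_adv = attack_method.generate(x, y)
                with torch.no_grad():
                    adv_correct += (self.student_model(x_adv).argmax(dim=1) == y).sum().item()
            total += x.size(0)
        clean_acc = clean_correct / total
        adv_acc = adv_correct / total if attack_method is not None else clean_acc
        return {
            'clean_accuracy': clean_acc,
            'adversarial_accuracy': adv_acc,
            'robustness_score': adv_acc,  # simple choice: adversarial accuracy itself
        }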
📊 Evaluation & Metrics
Robustness Metrics
- Adversarial Accuracy: model accuracy on adversarial examples
- Robustness Gap: difference between clean and adversarial accuracy
- Distillation Quality: KL divergence between teacher and student predictions
- Gradient Smoothness: measure of how much the defense reduces the gradient information available at the input
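Neither distillation quality nor gradient smoothness is computed anywhere in the code above. A minimal sketch of both; the function names and the use of the input-gradient L2 norm as a smoothness proxy are assumptions.

def distillation_quality(teacher_model, student_model, x, temperature=1.0):
    """KL divergence between teacher and student predictive distributions."""
    with torch.no_grad():
        teacher_probs = F.softmax(teacher_model(x) / temperature, dim=1)
        student_log_probs = F.log_softmax(student_model(x) / temperature, dim=1)
    return F.kl_div(student_log_probs, teacher_probs, reduction='batchmean').item()

def gradient_smoothness(model, x, y):
    """Average L2 norm of the input gradient (lower = less signal for attackers)."""
    x = x.clone().detach().requires_grad_(True)
    loss = F.cross_entropy(model(x), y)
    grad = torch.autograd.grad(loss, x)[0]
    return grad.flatten(1).norm(dim=1).mean().item()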
Comparative Analysis
import numpy as np
import matplotlib.pyplot as plt

def compare_defense_methods(models, test_loader, attacks):
    """Compare different defense methods across a set of attacks."""
    # evaluate_accuracy / evaluate_adversarial_accuracy are assumed helpers;
    # a sketch is given after the plotting function below
    results = {}

    for model_name, model in models.items():
        model_results = {}

        # Clean accuracy
        clean_acc = evaluate_accuracy(model, test_loader)
        model_results['clean_accuracy'] = clean_acc

        # Adversarial accuracy for each attack
        for attack_name, attack in attacks.items():
            adv_acc = evaluate_adversarial_accuracy(model, test_loader, attack)
            model_results[f'{attack_name}_accuracy'] = adv_acc

        # Aggregate robustness metrics
        model_results['avg_robustness'] = np.mean([
            model_results[f'{attack}_accuracy'] for attack in attacks.keys()
        ])
        model_results['robustness_gap'] = clean_acc - model_results['avg_robustness']

        results[model_name] = model_results

    return results

def visualize_robustness_comparison(results):
    """Create a bar-chart comparison of the defense methods."""
    models = list(results.keys())
    clean_accs = [results[model]['clean_accuracy'] for model in models]
    robust_accs = [results[model]['avg_robustness'] for model in models]

    x = np.arange(len(models))
    width = 0.35

    fig, ax = plt.subplots()
    ax.bar(x - width/2, clean_accs, width, label='Clean Accuracy', alpha=0.8)
    ax.bar(x + width/2, robust_accs, width, label='Adversarial Accuracy', alpha=0.8)

    ax.set_xlabel('Defense Methods')
    ax.set_ylabel('Accuracy')
    ax.set_title('Defense Method Comparison')
    ax.set_xticks(x)
    ax.set_xticklabels(models, rotation=45)
    ax.legend()

    plt.tight_layout()
    plt.show()
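compare_defense_methods relies on two evaluation helpers that are not defined in this section. A minimal sketch under the same assumed attack interface (attack.generate(x, y)):

def evaluate_accuracy(model, test_loader):
    """Top-1 accuracy on clean data."""
    model.eval()
    correct, total = 0, 0
    for x, y in test_loader:
        with torch.no_grad():
            correct += (model(x).argmax(dim=1) == y).sum().item()
        total += x.size(0)
    return correct / total

def evaluate_adversarial_accuracy(model, test_loader, attack):
    """Top-1 accuracy on adversarially perturbed data."""
    model.eval()
    correct, total = 0, 0
    for x, y in test_loader:
        x_adv = attack.generate(x, y)  # attacks may need gradients, so no no_grad here
        with torch.no_grad():
            correct += (model(x_adv).argmax(dim=1) == y).sum().item()
        total += x.size(0)
    return correct / total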
⚠️ Limitations & Considerations
Known Limitations
Gradient Masking
- Problem: High temperatures can mask gradients, making attacks appear ineffective
- Reality: Models may still be vulnerable to stronger attacks
- Mitigation: Use adaptive attacks and thorough evaluation
Temperature Sensitivity
- Problem: Performance highly dependent on temperature selection
- Challenge: Optimal temperature varies across datasets and models
- Solution: Use adaptive temperature selection or a grid search (see the sketch after this list)
Scalability Issues
- Problem: Requires training two models (teacher + student)
- Cost: Increased computational and memory requirements
- Trade-off: Balance between robustness and efficiency
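As referenced under Temperature Sensitivity, here is a minimal grid-search sketch for choosing the temperature; the candidate values, the single training pass, and the use of adversarial validation accuracy as the selection criterion are all illustrative assumptions.

def grid_search_temperature(make_student, teacher, train_loader, val_loader,
                            attack, candidate_temps=(1, 2, 5, 10, 20, 50)):
    """Train one student per candidate temperature and keep the most robust one."""
    best_temp, best_adv_acc = None, -1.0
    for temp in candidate_temps:
        student = make_student()  # factory returning a fresh, untrained student
        distiller = DefensiveDistillation(teacher, student, temperature=temp)
        optimizer = torch.optim.Adam(student.parameters())
        for x, y in train_loader:  # single pass here; use full training in practice
            distiller.train_step(x, y, optimizer)
        metrics = distiller.evaluate_robustness(val_loader, attack)
        if metrics['adversarial_accuracy'] > best_adv_acc:
            best_temp, best_adv_acc = temp, metrics['adversarial_accuracy']
    return best_temp, best_adv_acc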
Best Practices
1. Robust teacher training: ensure the teacher model is robust before distillation to avoid transferring vulnerabilities.
2. Temperature tuning: experiment with different temperature values to find the optimal balance.
3. Comprehensive evaluation: test against multiple attack types, not just the ones used during training.
4. Adaptive strategies: use adaptive temperature and progressive distillation for better results.
5. Ensemble methods: combine distillation with other defense methods for stronger protection.
🎯 Hands-on Exercise
Exercise: Implement Defensive Distillation
Build a complete defensive distillation system and compare its robustness with standard training.
Tasks:
- Train a teacher model with adversarial training
- Implement defensive distillation for student model
- Experiment with different temperature values
- Evaluate robustness against multiple attacks
- Compare with baseline model performance
Expected Outcomes:
- Understanding of distillation mechanics
- Experience with temperature tuning
- Insight into robustness trade-offs
💻 Starter Code
# TODO: Implement defensive distillation
class DefensiveDistillation:
    def __init__(self, teacher, student, temperature=10.0):
        self.teacher = teacher
        self.student = student
        self.temperature = temperature

    def distill(self, train_loader, epochs=50):
        # Your implementation here
        pass

    def evaluate_robustness(self, test_loader, attacks):
        # Your implementation here
        pass

# TODO: Compare with baseline
def compare_with_baseline():
    # Your implementation here
    pass