🎯 Learning Objectives

📚 Core Concepts

1. Knowledge Distillation Fundamentals

Knowledge distillation is a technique where a smaller "student" model learns from a larger "teacher" model by mimicking its soft predictions.

Standard Knowledge Distillation Process

1. Teacher Model: Large, complex model with high accuracy
2. Soft Targets: Probability distributions over classes
3. Student Model: Smaller, more efficient model
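To see what soft targets look like, the sketch below (made-up logits, using PyTorch as the rest of this section does) softens a single teacher output at several temperatures. Raising T spreads probability onto the non-maximal classes while keeping the same top class; that extra information about how the teacher ranks the wrong classes is what the student learns from.

```python
import torch
import torch.nn.functional as F

# Made-up teacher logits for one input over three classes.
teacher_logits = torch.tensor([[4.0, 1.0, 0.5]])

def soft_targets(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """Temperature-softened class probabilities: softmax(logits / T)."""
    return F.softmax(logits / temperature, dim=1)

for T in (1.0, 5.0, 20.0):
    probs = soft_targets(teacher_logits, T)[0]
    print(f"T={T:>4}: " + ", ".join(f"{p:.3f}" for p in probs.tolist()))
```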

Mathematical Foundation

L_distill = α * L_soft + (1 - α) * L_hard

Where:

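  • L_soft: Distillation loss, i.e. the KL divergence between the teacher's and the student's temperature-softened output distributions
  • L_hard: Standard cross-entropy loss between the student's predictions and the true labels
  • α: Weighting parameter (commonly 0.5 to 0.7; the classes in this section default to 0.7)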

Temperature Scaling

import torch
import torch.nn.functional as F

def temperature_scaling(logits, temperature):
    """
    Divide logits by the temperature T; T > 1 flattens the resulting softmax
    """
    return logits / temperature

def distillation_loss(student_logits, teacher_logits, labels, temperature, alpha):
    """
    Weighted sum of the soft (distillation) loss and the hard (cross-entropy) loss
    """
    # Soft targets from the teacher: probabilities at temperature T
    soft_targets = F.softmax(temperature_scaling(teacher_logits, temperature), dim=1)

    # Student log-probabilities at the same temperature (kl_div expects log-probs)
    soft_preds = F.log_softmax(temperature_scaling(student_logits, temperature), dim=1)

    # KL(teacher || student), multiplied by T^2 because soft-target gradients
    # shrink as 1/T^2 (Hinton et al., 2015)
    soft_loss = F.kl_div(soft_preds, soft_targets, reduction='batchmean') * temperature ** 2

    # Hard loss: standard cross-entropy (T = 1) against the true labels
    hard_loss = F.cross_entropy(student_logits, labels)

    return alpha * soft_loss + (1 - alpha) * hard_loss
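One practical detail of temperature-based distillation: the gradients of the soft loss with respect to the student's logits shrink roughly as 1/T², so at T = 10 the soft term would contribute only about a hundredth of its nominal weight. Hinton et al. (2015) therefore multiply the soft loss by T² when combining it with the hard loss. The sketch below uses arbitrary example logits to measure the gradient norm at T = 1 and T = 10; it is a numerical illustration, not a proof.

```python
import torch
import torch.nn.functional as F

teacher_logits = torch.tensor([[3.0, 1.0, 0.2]])  # arbitrary example values

def soft_loss_grad_norm(student_logits: torch.Tensor, temperature: float) -> float:
    """L2 norm of d(L_soft)/d(student logits), with L_soft = KL(teacher_T || student_T)."""
    z = student_logits.clone().requires_grad_(True)
    p_teacher = F.softmax(teacher_logits / temperature, dim=1)
    log_q_student = F.log_softmax(z / temperature, dim=1)
    loss = F.kl_div(log_q_student, p_teacher, reduction="batchmean")
    loss.backward()
    return z.grad.norm().item()

student_logits = torch.tensor([[1.0, 2.0, 0.5]])
g1 = soft_loss_grad_norm(student_logits, 1.0)
g10 = soft_loss_grad_norm(student_logits, 10.0)
print(f"T=1: {g1:.4f}  T=10: {g10:.6f}  T=10 rescaled by T^2: {g10 * 100:.4f}")
```

With these values the raw T = 10 gradient is about two orders of magnitude smaller, while the T²-rescaled one lands close to the T = 1 value.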

2. Defensive Distillation

An adaptation of knowledge distillation designed specifically to improve model robustness against adversarial attacks.

Key Differences from Standard Distillation

  • Higher temperatures: Use higher temperature values (T > 1) to smooth predictions
  • Adversarial training: Train student model on adversarial examples
  • Robust teacher: Use a teacher model trained with adversarial training
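  • Gradient masking: After high-temperature training, the model's softmax saturates at T = 1, shrinking the input gradients available to gradient-based attackers (a side effect that stronger attacks can bypass; see Limitations)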
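The gradient-masking effect can be seen without training anything. A network distilled at temperature T is commonly observed to learn logits larger by roughly a factor of T; evaluated at T = 1, its softmax saturates and the loss gradient that attacks follow nearly vanishes. The sketch below imitates this by multiplying fixed, made-up logits by a scale factor, which is a simplification of what distillation does rather than a simulation of it.

```python
import torch
import torch.nn.functional as F

base_logits = torch.tensor([[2.0, 1.0, 0.1]])  # made-up logits; true class is 0
label = torch.tensor([0])

def logit_grad_norm(scale: float) -> float:
    """Norm of d(cross-entropy)/d(logits) when the logits are multiplied by `scale`."""
    z = (base_logits * scale).requires_grad_(True)
    F.cross_entropy(z, label).backward()
    return z.grad.norm().item()

for scale in (1.0, 10.0, 50.0):
    print(f"logit scale {scale:>4}: gradient norm {logit_grad_norm(scale):.2e}")
```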

Defensive Distillation Implementation

class DefensiveDistillation:
    def __init__(self, teacher_model, student_model, temperature=10.0, alpha=0.7):
        self.teacher_model = teacher_model
        self.student_model = student_model
        self.temperature = temperature
        self.alpha = alpha
        # The teacher is frozen: inference mode only
        self.teacher_model.eval()

    def train_step(self, x, y, optimizer, adversarial_generator=None):
        """
        Single training step with defensive distillation.
        adversarial_generator: optional object with a generate(x, y) method.
        """
        self.student_model.train()

        # Augment the batch with adversarial examples if a generator is provided
        if adversarial_generator is not None:
            x_adv = adversarial_generator.generate(x, y)
            x_combined = torch.cat([x, x_adv], dim=0)
            y_combined = torch.cat([y, y], dim=0)
        else:
            x_combined, y_combined = x, y

        # Teacher logits (no gradients); distillation_loss applies the temperature
        with torch.no_grad():
            teacher_logits = self.teacher_model(x_combined)

        # Student logits
        student_logits = self.student_model(x_combined)

        # Compute distillation loss
        loss = distillation_loss(
            student_logits, teacher_logits, y_combined,
            self.temperature, self.alpha
        )

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss.item()

    def evaluate_robustness(self, test_loader, attack_method):
        """
        Evaluate clean and adversarial accuracy of the student.
        The attack's strength (e.g. epsilon) is configured on attack_method itself.
        """
        self.student_model.eval()

        clean_correct = 0
        adv_correct = 0
        total = 0

        for x, y in test_loader:
            # Clean accuracy
            with torch.no_grad():
                clean_pred = self.student_model(x).argmax(dim=1)
                clean_correct += (clean_pred == y).sum().item()

            # Adversarial accuracy (generation runs outside no_grad: attacks need gradients)
            x_adv = attack_method.generate(x, y)
            with torch.no_grad():
                adv_pred = self.student_model(x_adv).argmax(dim=1)
                adv_correct += (adv_pred == y).sum().item()

            total += x.size(0)

        return {
            'clean_accuracy': clean_correct / total,
            'adversarial_accuracy': adv_correct / total,
            'robustness_gap': (clean_correct - adv_correct) / total
        }
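Both `train_step` and `evaluate_robustness` expect an attack object with a `generate(x, y)` method, which the text does not define. One possible choice is the Fast Gradient Sign Method (FGSM). The class below is a hypothetical sketch of that interface; the name `FGSMAttack`, the default `epsilon`, the [0, 1] input range, and a model that returns logits are all assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FGSMAttack:
    """Hypothetical attack exposing the generate(x, y) interface used above.

    Fast Gradient Sign Method: x_adv = clip(x + epsilon * sign(grad_x loss), 0, 1).
    Assumes inputs are scaled to [0, 1] and `model` returns logits.
    """

    def __init__(self, model: nn.Module, epsilon: float = 0.03):
        self.model = model
        self.epsilon = epsilon

    def generate(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        x = x.clone().detach().requires_grad_(True)
        loss = F.cross_entropy(self.model(x), y)
        # Gradient with respect to the input only; parameter .grad fields are untouched.
        (grad,) = torch.autograd.grad(loss, x)
        x_adv = x + self.epsilon * grad.sign()
        return x_adv.clamp(0.0, 1.0).detach()

# Usage on a toy linear classifier with random data.
torch.manual_seed(0)
model = nn.Linear(4, 3)
x = torch.rand(8, 4)
y = torch.randint(0, 3, (8,))
x_adv = FGSMAttack(model, epsilon=0.1).generate(x, y)
print((x_adv - x).abs().max().item())  # no larger than epsilon, up to float rounding
```

FGSM is a single-step attack, so it suits on-the-fly augmentation during training; for evaluation, stronger multi-step attacks are advisable.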

3. Advanced Distillation Techniques

Extensions of the basic recipe that may add robustness: combining soft targets from several teachers, or changing the temperature over the course of training.

Multi-Teacher Distillation

class MultiTeacherDistillation:
    def __init__(self, teachers, student, temperature=10.0, alpha=0.7):
        self.teachers = teachers
        self.student = student
        self.temperature = temperature
        self.alpha = alpha
        for teacher in self.teachers:
            teacher.eval()  # teachers are frozen

    def ensemble_teacher_predictions(self, x):
        """
        Combine predictions from multiple teachers by averaging their logits
        """
        with torch.no_grad():
            teacher_logits = [teacher(x) for teacher in self.teachers]
        return torch.stack(teacher_logits).mean(dim=0)

    def compute_distillation_loss(self, x, y):
        """
        Compute loss using ensemble teacher predictions
        """
        teacher_logits = self.ensemble_teacher_predictions(x)
        student_logits = self.student(x)

        # Soft loss at temperature T; the T^2 factor keeps its gradients
        # on the same scale as the hard loss
        soft_targets = F.softmax(teacher_logits / self.temperature, dim=1)
        soft_preds = F.log_softmax(student_logits / self.temperature, dim=1)
        kl_loss = F.kl_div(soft_preds, soft_targets, reduction='batchmean') * self.temperature ** 2

        # Hard loss against the true labels
        ce_loss = F.cross_entropy(student_logits, y)

        return self.alpha * kl_loss + (1 - self.alpha) * ce_loss
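`ensemble_teacher_predictions` averages the teachers' logits before softening. A common alternative is to soften each teacher first and average the probability vectors; the two generally give different soft targets. The sketch below compares them on two made-up, disagreeing teachers. Which choice works better is an empirical question that this section does not settle.

```python
import torch
import torch.nn.functional as F

# Two hypothetical teachers that disagree on a single 3-class input.
teacher_a = torch.tensor([[6.0, 0.0, 0.0]])
teacher_b = torch.tensor([[0.0, 6.0, 0.0]])
T = 2.0

# Option 1 (as in MultiTeacherDistillation): average logits, then soften.
avg_logits_targets = F.softmax(torch.stack([teacher_a, teacher_b]).mean(dim=0) / T, dim=1)

# Option 2: soften each teacher, then average the probability vectors.
avg_prob_targets = torch.stack(
    [F.softmax(t / T, dim=1) for t in (teacher_a, teacher_b)]
).mean(dim=0)

print("average of logits:", avg_logits_targets)
print("average of probs: ", avg_prob_targets)
```

Here averaging logits leaves noticeably more mass on the third class, which neither teacher favors.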

Progressive Distillation

class ProgressiveDistillation:
    def __init__(self, teacher, student, temperature_schedule, alpha=0.7):
        self.teacher = teacher
        self.student = student
        # One temperature per epoch; the schedule decides how T changes over training
        self.temperature_schedule = temperature_schedule
        self.alpha = alpha
        self.teacher.eval()  # teacher is frozen

    def train_with_progressive_temperature(self, train_loader, epochs):
        """
        Train with the per-epoch temperatures in temperature_schedule
        (gradually increasing when the schedule increases)
        """
        if len(self.temperature_schedule) < epochs:
            raise ValueError("temperature_schedule needs one entry per epoch")

        optimizer = torch.optim.Adam(self.student.parameters())
        self.student.train()

        for epoch in range(epochs):
            # Get current temperature
            current_temp = self.temperature_schedule[epoch]

            epoch_loss = 0.0
            for x, y in train_loader:
                epoch_loss += self.train_step(x, y, optimizer, current_temp)

            print(f"Epoch {epoch+1}/{epochs}, Temp: {current_temp:.1f}, Loss: {epoch_loss/len(train_loader):.4f}")

    def train_step(self, x, y, optimizer, temperature):
        """
        Single training step with specified temperature
        """
        # Teacher soft targets
        with torch.no_grad():
            teacher_logits = self.teacher(x)
            soft_targets = F.softmax(teacher_logits / temperature, dim=1)

        # Student predictions
        student_logits = self.student(x)

        # Soft loss; the T^2 factor keeps its scale comparable as T changes
        soft_preds = F.log_softmax(student_logits / temperature, dim=1)
        kl_loss = F.kl_div(soft_preds, soft_targets, reduction='batchmean') * temperature ** 2

        # Standard loss
        ce_loss = F.cross_entropy(student_logits, y)

        total_loss = self.alpha * kl_loss + (1 - self.alpha) * ce_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        return total_loss.item()
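`train_with_progressive_temperature` reads `temperature_schedule[epoch]`, so the schedule needs one entry per epoch, and any "gradually increasing" behavior comes entirely from how that list is built. A hypothetical helper (the name `linear_temperature_schedule` is illustrative) that builds a linear ramp:

```python
def linear_temperature_schedule(start: float, end: float, epochs: int) -> list[float]:
    """Return `epochs` temperatures spaced evenly from `start` to `end` (inclusive)."""
    if epochs < 1:
        raise ValueError("epochs must be at least 1")
    if epochs == 1:
        return [float(start)]
    step = (end - start) / (epochs - 1)
    return [start + i * step for i in range(epochs)]

schedule = linear_temperature_schedule(1.0, 20.0, epochs=5)
print(schedule)  # [1.0, 5.75, 10.5, 15.25, 20.0]
```

For example, `ProgressiveDistillation(teacher, student, linear_temperature_schedule(1.0, 20.0, epochs))` would ramp the temperature from 1 to 20 over training; other shapes (step or exponential) only change how the list is filled.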

🔧 Implementation Strategies

1. Two-Stage Training Process

Implement defensive distillation in two stages: first train a robust teacher with adversarial training, then distill it into the student at a high temperature.

class TwoStageDefensiveDistillation:
    def __init__(self, teacher_model, student_model, temperature=10.0, alpha=0.7, epsilon=0.03):
        self.teacher_model = teacher_model
        self.student_model = student_model
        self.temperature = temperature
        self.alpha = alpha
        # Assumption: FGSM budget for stage 1 (the attack itself is not specified)
        self.epsilon = epsilon

    def stage1_train_teacher(self, train_loader, epochs=100):
        """
        Stage 1: Train teacher model with adversarial training
        """
        print("Stage 1: Training robust teacher model...")

        optimizer = torch.optim.Adam(self.teacher_model.parameters())
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
        self.teacher_model.train()

        for epoch in range(epochs):
            epoch_loss = 0.0
            for x, y in train_loader:
                # Generate adversarial examples against the current teacher
                x_adv = self.generate_adversarial_examples(x, y)

                # Combine clean and adversarial data
                x_combined = torch.cat([x, x_adv], dim=0)
                y_combined = torch.cat([y, y], dim=0)

                # Training step
                epoch_loss += self.adversarial_training_step(x_combined, y_combined, optimizer)

            scheduler.step()
            print(f"Teacher Epoch {epoch+1}/{epochs}, Loss: {epoch_loss/len(train_loader):.4f}")

    def generate_adversarial_examples(self, x, y):
        """
        FGSM against the teacher (an assumed choice of attack; inputs assumed in [0, 1])
        """
        x = x.clone().detach().requires_grad_(True)
        loss = F.cross_entropy(self.teacher_model(x), y)
        (grad,) = torch.autograd.grad(loss, x)
        return (x + self.epsilon * grad.sign()).clamp(0.0, 1.0).detach()

    def adversarial_training_step(self, x, y, optimizer):
        """
        Cross-entropy update of the teacher on the combined clean + adversarial batch
        """
        loss = F.cross_entropy(self.teacher_model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()

    def stage2_distill_student(self, train_loader, epochs=50):
        """
        Stage 2: Distill knowledge from robust teacher to student
        """
        print("Stage 2: Distilling knowledge to student model...")

        optimizer = torch.optim.Adam(self.student_model.parameters())
        self.teacher_model.eval()  # teacher is frozen from here on
        self.student_model.train()

        for epoch in range(epochs):
            epoch_loss = 0.0
            for x, y in train_loader:
                epoch_loss += self.distillation_step(x, y, optimizer)

            print(f"Student Epoch {epoch+1}/{epochs}, Loss: {epoch_loss/len(train_loader):.4f}")

    def distillation_step(self, x, y, optimizer):
        """
        Single distillation training step
        """
        # Teacher soft targets at high temperature
        with torch.no_grad():
            teacher_logits = self.teacher_model(x)
            soft_targets = F.softmax(teacher_logits / self.temperature, dim=1)

        # Student predictions
        student_logits = self.student_model(x)

        # Soft loss, rescaled by T^2 to keep its gradients comparable to the hard loss
        soft_preds = F.log_softmax(student_logits / self.temperature, dim=1)
        kl_loss = F.kl_div(soft_preds, soft_targets, reduction='batchmean') * self.temperature ** 2

        # Hard loss
        ce_loss = F.cross_entropy(student_logits, y)

        total_loss = self.alpha * kl_loss + (1 - self.alpha) * ce_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        return total_loss.item()

2. Adaptive Temperature Selection

Dynamically adjust temperature based on model performance and robustness metrics.

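class AdaptiveTemperatureDistillation:
    def __init__(self, teacher_model, student_model, initial_temp=10.0, attack=None, alpha=0.7):
        self.teacher_model = teacher_model
        self.student_model = student_model
        self.current_temp = initial_temp
        # Assumption: any object with generate(x, y) that attacks the student
        self.attack = attack
        self.alpha = alpha
        self.temp_history = []
        self.performance_history = []
        self.teacher_model.eval()

    def adaptive_training(self, train_loader, val_loader, max_epochs=100):
        """
        Train with adaptive temperature adjustment
        """
        optimizer = torch.optim.Adam(self.student_model.parameters())

        for epoch in range(max_epochs):
            # One epoch of distillation at the current temperature
            train_loss = self.train_epoch(train_loader, optimizer)

            # Validation
            val_metrics = self.evaluate_robustness(val_loader)

            # Adjust temperature based on robustness
            self.adjust_temperature(val_metrics)

            # Log progress
            self.log_progress(epoch, train_loss, val_metrics)

            # Early stopping
            if self.should_stop():
                break

    def train_epoch(self, train_loader, optimizer):
        """
        One distillation epoch at self.current_temp; returns the mean batch loss
        """
        self.student_model.train()
        total_loss = 0.0
        for x, y in train_loader:
            with torch.no_grad():
                teacher_logits = self.teacher_model(x)
            loss = distillation_loss(self.student_model(x), teacher_logits, y,
                                     self.current_temp, self.alpha)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        return total_loss / len(train_loader)

    def evaluate_robustness(self, val_loader):
        """
        Assumption: 'robustness_score' is the adversarial accuracy under self.attack
        """
        if self.attack is None:
            raise ValueError("an attack with generate(x, y) is required")
        self.student_model.eval()
        correct, total = 0, 0
        for x, y in val_loader:
            x_adv = self.attack.generate(x, y)
            with torch.no_grad():
                correct += (self.student_model(x_adv).argmax(dim=1) == y).sum().item()
            total += y.size(0)
        return {'robustness_score': correct / total}

    def log_progress(self, epoch, train_loss, val_metrics):
        print(f"Epoch {epoch+1}, Temp: {self.current_temp:.2f}, Loss: {train_loss:.4f}, "
              f"Robustness: {val_metrics['robustness_score']:.4f}")

    def adjust_temperature(self, val_metrics):
        """
        Adjust temperature based on validation robustness
        """
        current_robustness = val_metrics['robustness_score']

        if len(self.performance_history) > 0:
            prev_robustness = self.performance_history[-1]

            # If robustness improved, increase the temperature slightly (capped at 20)
            if current_robustness > prev_robustness:
                self.current_temp = min(self.current_temp * 1.05, 20.0)
            # Otherwise (degraded or unchanged), decrease it (floored at 1)
            else:
                self.current_temp = max(self.current_temp * 0.95, 1.0)

        self.temp_history.append(self.current_temp)
        self.performance_history.append(current_robustness)

    def should_stop(self):
        """
        Early stopping: stop once the last 10 robustness scores span less than 0.01
        """
        if len(self.performance_history) < 10:
            return False

        recent_performance = self.performance_history[-10:]
        return max(recent_performance) - min(recent_performance) < 0.01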
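The update rule in `adjust_temperature` is easier to reason about in isolation. The pure-Python sketch below replays it on a made-up sequence of robustness scores, using the same factors (×1.05 up, ×0.95 down) and bounds [1.0, 20.0]; the helper name `replay_temperature_rule` is illustrative.

```python
def replay_temperature_rule(scores, initial_temp=10.0):
    """Apply the multiplicative adjust_temperature rule to a sequence of scores."""
    temp = initial_temp
    history = []
    prev = None
    for score in scores:
        if prev is not None:
            if score > prev:
                temp = min(temp * 1.05, 20.0)  # robustness improved: raise T
            else:
                temp = max(temp * 0.95, 1.0)   # no improvement: lower T
        history.append(temp)
        prev = score
    return history

print(replay_temperature_rule([0.40, 0.45, 0.43, 0.50]))
```

Because the factors are multiplicative, alternating gains and losses roughly cancel (10.0 → 10.5 → 9.975 → about 10.47 here), so the temperature drifts only under a sustained trend.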

📊 Evaluation & Metrics

Robustness Metrics

Adversarial Accuracy

Model accuracy on adversarial examples

Adv_Acc = (Correct_Adv_Predictions) / (Total_Adv_Examples)

Robustness Gap

Difference between clean and adversarial accuracy

Gap = Clean_Accuracy - Adversarial_Accuracy
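Both metrics reduce to counting matches. A toy sketch with made-up predictions for eight examples:

```python
import torch

def accuracy(preds: torch.Tensor, labels: torch.Tensor) -> float:
    """Fraction of predictions that match the labels."""
    return (preds == labels).float().mean().item()

labels = torch.tensor([0, 1, 2, 1, 0, 2, 1, 0])
clean_preds = torch.tensor([0, 1, 2, 1, 0, 2, 0, 0])  # 7 of 8 correct
adv_preds = torch.tensor([0, 2, 2, 0, 1, 2, 0, 0])    # 4 of 8 correct

clean_acc = accuracy(clean_preds, labels)
adv_acc = accuracy(adv_preds, labels)
gap = clean_acc - adv_acc
print(clean_acc, adv_acc, gap)  # 0.875 0.5 0.375
```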

Distillation Quality

KL divergence between teacher and student predictions

KL = Σ_c p_teacher(c) * log(p_teacher(c) / p_student(c)), summed over classes c
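`F.kl_div` computes this same quantity, but it takes its arguments in the opposite order from the notation (student log-probabilities first, teacher probabilities second), which is an easy place to make mistakes. A quick check on random logits that a direct translation of the formula and `F.kl_div(..., reduction="batchmean")` agree:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
teacher_logits = torch.randn(4, 5)
student_logits = torch.randn(4, 5)

p_teacher = F.softmax(teacher_logits, dim=1)
p_student = F.softmax(student_logits, dim=1)

# Direct translation of the formula, summed over classes and averaged over the batch.
kl_manual = (p_teacher * (p_teacher.log() - p_student.log())).sum(dim=1).mean()

# kl_div expects log-probabilities as input and probabilities as target.
kl_torch = F.kl_div(F.log_softmax(student_logits, dim=1), p_teacher, reduction="batchmean")

print(kl_manual.item(), kl_torch.item())
```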

Gradient Smoothness

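Ratio of the distilled model's input-gradient norm to the teacher's (f is the distilled model); values below 1 mean less gradient information reaches gradient-based attackers

Smoothness = ||∇f(x)||₂ / ||∇f_teacher(x)||₂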
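The formula leaves open which gradient is meant. Assuming it is the gradient of the cross-entropy loss with respect to the input batch, one way to compute the ratio is sketched below; the helper names are illustrative, and the two tiny linear models are random stand-ins for a distilled student and its teacher.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def input_grad_norm(model: nn.Module, x: torch.Tensor, y: torch.Tensor) -> float:
    """L2 norm of d(cross-entropy loss)/d(input) over the whole batch."""
    x = x.clone().detach().requires_grad_(True)
    loss = F.cross_entropy(model(x), y)
    (grad,) = torch.autograd.grad(loss, x)
    return grad.norm().item()

def gradient_smoothness(student: nn.Module, teacher: nn.Module, x, y) -> float:
    """||grad f(x)||_2 / ||grad f_teacher(x)||_2; below 1 means smaller student gradients."""
    return input_grad_norm(student, x, y) / input_grad_norm(teacher, x, y)

# Random stand-in models and data, for illustration only.
torch.manual_seed(0)
teacher = nn.Linear(10, 3)
student = nn.Linear(10, 3)
x = torch.randn(16, 10)
y = torch.randint(0, 3, (16,))
print(f"smoothness ratio: {gradient_smoothness(student, teacher, x, y):.3f}")
```

A low ratio on its own does not demonstrate robustness; it can equally indicate the gradient masking discussed under Limitations.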

Comparative Analysis

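import numpy as np
import torch

def evaluate_accuracy(model, test_loader):
    """
    Fraction of clean test examples classified correctly
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in test_loader:
            correct += (model(x).argmax(dim=1) == y).sum().item()
            total += y.size(0)
    return correct / total

def evaluate_adversarial_accuracy(model, test_loader, attack):
    """
    Fraction of adversarial examples (from attack.generate) classified correctly
    """
    model.eval()
    correct, total = 0, 0
    for x, y in test_loader:
        x_adv = attack.generate(x, y)  # outside no_grad: attacks need gradients
        with torch.no_grad():
            correct += (model(x_adv).argmax(dim=1) == y).sum().item()
        total += y.size(0)
    return correct / total

def compare_defense_methods(models, test_loader, attacks):
    """
    Compare different defense methods.
    Note: each attack should craft its examples against the model being evaluated;
    an attack bound to a different model measures transfer robustness instead.
    """
    results = {}

    for model_name, model in models.items():
        model_results = {}

        # Clean accuracy
        clean_acc = evaluate_accuracy(model, test_loader)
        model_results['clean_accuracy'] = clean_acc

        # Adversarial accuracy for each attack
        for attack_name, attack in attacks.items():
            adv_acc = evaluate_adversarial_accuracy(model, test_loader, attack)
            model_results[f'{attack_name}_accuracy'] = adv_acc

        # Average adversarial accuracy across attacks, and the resulting gap
        model_results['avg_robustness'] = float(np.mean([
            model_results[f'{attack_name}_accuracy'] for attack_name in attacks
        ]))
        model_results['robustness_gap'] = clean_acc - model_results['avg_robustness']

        results[model_name] = model_results

    return results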

def visualize_robustness_comparison(results):
    """
    Create visualization of defense method comparison
    """
    import matplotlib.pyplot as plt
    import numpy as np
    
    models = list(results.keys())
    clean_accs = [results[model]['clean_accuracy'] for model in models]
    robust_accs = [results[model]['avg_robustness'] for model in models]
    
    x = np.arange(len(models))
    width = 0.35
    
    fig, ax = plt.subplots()
    ax.bar(x - width/2, clean_accs, width, label='Clean Accuracy', alpha=0.8)
    ax.bar(x + width/2, robust_accs, width, label='Adversarial Accuracy', alpha=0.8)
    
    ax.set_xlabel('Defense Methods')
    ax.set_ylabel('Accuracy')
    ax.set_title('Defense Method Comparison')
    ax.set_xticks(x)
    ax.set_xticklabels(models, rotation=45)
    ax.legend()
    
    plt.tight_layout()
    plt.show()

โš ๏ธ Limitations & Considerations

Known Limitations

Gradient Masking

  • Problem: High temperatures can mask gradients, making attacks appear ineffective
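  • Reality: Models may still be vulnerable to stronger attacks; Carlini and Wagner (2016) showed that defensively distilled networks can be broken by attacks that work on the logits instead of the saturated softmax output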
  • Mitigation: Use adaptive attacks and thorough evaluation
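One way to act on this advice is to attack the logits directly, so that a saturated softmax cannot hide the gradient, in the spirit of the attacks Carlini and Wagner used against defensive distillation. The sketch below is an L∞ projected gradient descent (PGD) attack on a logit-margin loss. It is illustrative, not either paper's exact attack; the class name, step size, number of steps, [0, 1] input range, and the `generate(x, y)` interface are assumptions chosen to match the rest of this section.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LogitMarginPGD:
    """Hypothetical L-infinity PGD attack on a logit-margin (CW-style) loss.

    Maximizes (largest wrong-class logit - true-class logit) inside an epsilon-ball,
    assuming inputs in [0, 1] and a model that returns raw logits.
    """

    def __init__(self, model: nn.Module, epsilon=0.03, step_size=0.01, steps=20):
        self.model = model
        self.epsilon = epsilon
        self.step_size = step_size
        self.steps = steps

    @staticmethod
    def margin(logits: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Per-example margin; positive means the example is misclassified."""
        true_mask = F.one_hot(y, num_classes=logits.size(1)).bool()
        true_logit = logits[true_mask]
        best_wrong = logits.masked_fill(true_mask, float("-inf")).max(dim=1).values
        return best_wrong - true_logit

    def generate(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        x_adv = x.clone().detach()
        for _ in range(self.steps):
            x_adv.requires_grad_(True)
            loss = self.margin(self.model(x_adv), y).sum()
            (grad,) = torch.autograd.grad(loss, x_adv)
            with torch.no_grad():
                x_adv = x_adv + self.step_size * grad.sign()  # ascend the margin
                # Project back into the epsilon-ball around x, then into [0, 1]
                x_adv = torch.max(torch.min(x_adv, x + self.epsilon), x - self.epsilon)
                x_adv = x_adv.clamp(0.0, 1.0)
        return x_adv.detach()

# Usage on a toy linear model (easy to verify) with random data.
torch.manual_seed(0)
model = nn.Linear(4, 3)
x = torch.rand(32, 4)
y = torch.randint(0, 3, (32,))
attack = LogitMarginPGD(model, epsilon=0.1, step_size=0.02, steps=10)
x_adv = attack.generate(x, y)
with torch.no_grad():
    before = attack.margin(model(x), y).mean().item()
    after = attack.margin(model(x_adv), y).mean().item()
print(f"mean margin before: {before:.3f}, after: {after:.3f}")
```

Because the margin is computed on logits, multiplying all logits by a large factor (as distillation tends to do) scales the gradient up rather than driving it toward zero.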

Temperature Sensitivity

  • Problem: Performance highly dependent on temperature selection
  • Challenge: Optimal temperature varies across datasets and models
  • Solution: Use adaptive temperature selection or grid search
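A grid search can be framed as a constrained choice: among candidate temperatures, keep those whose clean accuracy stays above a floor, then take the one with the best adversarial accuracy. In the sketch below, `evaluate` is a hypothetical callback that would distill and evaluate a student at temperature T, and the clean-accuracy floor is an assumption; the numbers are made up and stand in for real training runs.

```python
from typing import Callable, Dict, Iterable, Tuple

def select_temperature(
    candidates: Iterable[float],
    evaluate: Callable[[float], Tuple[float, float]],
    min_clean_acc: float,
) -> float:
    """Best adversarial accuracy among temperatures meeting the clean-accuracy floor.

    `evaluate(T)` returns (clean_accuracy, adversarial_accuracy) for temperature T.
    """
    results: Dict[float, Tuple[float, float]] = {T: evaluate(T) for T in candidates}
    eligible = {T: r for T, r in results.items() if r[0] >= min_clean_acc}
    if not eligible:
        raise ValueError("no temperature met the clean-accuracy floor")
    return max(eligible, key=lambda T: eligible[T][1])

# Made-up (clean, adversarial) accuracies standing in for real training runs.
fake_results = {1.0: (0.92, 0.10), 10.0: (0.90, 0.35), 50.0: (0.81, 0.42)}
best = select_temperature(fake_results, fake_results.get, min_clean_acc=0.85)
print(best)  # 10.0
```

The adversarial accuracy inside `evaluate` should come from attacks other than the one used during training, or the search will favor gradient masking.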

Scalability Issues

  • Problem: Requires training two models (teacher + student)
  • Cost: Increased computational and memory requirements
  • Trade-off: Balance between robustness and efficiency

Best Practices

1. Robust Teacher Training

Ensure the teacher model is robust before distillation to avoid transferring vulnerabilities.

2. Temperature Tuning

Experiment with several temperature values to find the best trade-off between clean accuracy and robustness.

3. Comprehensive Evaluation

Test against multiple attack types, not just the ones used during training.

4. Adaptive Strategies

Consider adaptive temperature selection and progressive distillation, and check that any robustness gain holds up under comprehensive evaluation.

5. Ensemble Methods

Combine distillation with other defense methods for stronger protection.

🎯 Hands-on Exercise

Exercise: Implement Defensive Distillation

Build a complete defensive distillation system and compare its robustness with standard training.

Tasks:

  1. Train a teacher model with adversarial training
  2. Implement defensive distillation for student model
  3. Experiment with different temperature values
  4. Evaluate robustness against multiple attacks
  5. Compare with baseline model performance

Expected Outcomes:

  • Understanding of distillation mechanics
  • Experience with temperature tuning
  • Insight into robustness trade-offs

💻 Starter Code

# TODO: Implement defensive distillation
class DefensiveDistillation:
    def __init__(self, teacher, student, temperature=10.0):
        self.teacher = teacher
        self.student = student
        self.temperature = temperature
        
    def distill(self, train_loader, epochs=50):
        # Your implementation here
        pass
    
    def evaluate_robustness(self, test_loader, attacks):
        # Your implementation here
        pass

# TODO: Compare with baseline
def compare_with_baseline():
    # Your implementation here
    pass