📚 Learning Objectives

By the end of this lesson, you will be able to:

  • Explain what model extraction attacks are and how they are classified
  • Implement an API-based extraction attack that trains a substitute model
  • Apply active learning to reduce the number of queries an attack requires
  • Implement and evaluate defenses such as rate limiting, output perturbation, and input validation

🎯 Understanding Model Extraction

What is Model Extraction?

Model extraction attacks aim to steal or replicate machine learning models by querying them and using the responses to train a substitute model. This is particularly concerning for models exposed through APIs or web services.

🎯 Key Characteristics:

  • Intellectual Property Theft: Stealing proprietary model architecture and weights
  • Black-box Access: Only input/output access required
  • Query-based: Uses prediction queries to extract information
  • Substitute Models: Creates functional copies of target models

Attack Taxonomy

📊 By Access Level

  • Black-box: Only prediction API access
  • Gray-box: Limited architecture knowledge
  • White-box: Full model access (rare in extraction)

🎯 By Extraction Method

  • Query Synthesis: Generate optimal queries
  • Data-based: Use existing datasets
  • Active Learning: Iterative query optimization

🔌 API-Based Model Extraction

Basic Extraction Strategy

API-based extraction involves systematically querying a model's API to collect training data for a substitute model.

๐Ÿ“ Extraction Process:

1. Query target model API with inputs X
2. Collect predictions Y = f_target(X)
3. Train substitute model f_substitute on (X, Y)
4. Validate extraction quality
5. Iterate if necessary
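
The loop above can be sketched end-to-end without a real API — here a local scikit-learn classifier stands in for the black-box target, and names like `target` and `substitute` are illustrative:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(0)

# Step 0: a stand-in "target" model we only query, never inspect
X_train, y_train = make_classification(n_samples=500, n_features=10, random_state=0)
target = LogisticRegression(max_iter=1000).fit(X_train, y_train)

# Steps 1-2: query the target with synthetic inputs X, collect Y = f_target(X)
X_query = rng.normal(size=(2000, 10))
Y_query = target.predict(X_query)          # black-box labels only

# Step 3: train a substitute model on (X, Y)
substitute = DecisionTreeClassifier(max_depth=8, random_state=0).fit(X_query, Y_query)

# Step 4: validate extraction quality as agreement with the target on fresh inputs
X_test = rng.normal(size=(500, 10))
agreement = np.mean(substitute.predict(X_test) == target.predict(X_test))
print(f"substitute/target agreement: {agreement:.2f}")
```

Agreement (how often the substitute matches the target), rather than accuracy on ground truth, is the natural quality metric here: the attacker's goal is to replicate the target's behavior.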
                    

Implementation Example

import requests
import numpy as np
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import time

class ModelExtractionAttack:
    def __init__(self, target_api_url, max_queries=10000):
        """
        Initialize model extraction attack
        
        Args:
            target_api_url: URL of target model API
            max_queries: Maximum number of API queries allowed
        """
        self.target_api_url = target_api_url
        self.max_queries = max_queries
        self.query_count = 0
        self.extracted_data = []
        
    def query_target_model(self, inputs):
        """
        Query target model API
        
        Args:
            inputs: Input data to query
        
        Returns:
            predictions: Model predictions
        """
        if self.query_count >= self.max_queries:
            raise Exception("Maximum query limit reached")
        
        # Prepare request data
        payload = {
            'inputs': inputs.tolist() if isinstance(inputs, np.ndarray) else inputs,
            'format': 'numpy'
        }
        
        try:
            # Send request to API
            response = requests.post(
                self.target_api_url + '/predict',
                json=payload,
                timeout=30
            )
            response.raise_for_status()
            
            predictions = response.json()['predictions']
            self.query_count += len(inputs)
            
            return np.array(predictions)
            
        except requests.RequestException as e:
            print(f"API request failed: {e}")
            return None
    
    def synthesize_queries(self, input_shape, num_samples=1000, method='random'):
        """
        Synthesize query inputs for extraction
        
        Args:
            input_shape: Shape of input data
            num_samples: Number of samples to generate
            method: Method for query synthesis ('random', 'gaussian', 'uniform')
        
        Returns:
            synthetic_inputs: Generated input samples
        """
        if method == 'random':
            # Random noise inputs
            synthetic_inputs = np.random.random((num_samples, *input_shape))
        elif method == 'gaussian':
            # Gaussian distributed inputs
            synthetic_inputs = np.random.normal(0.5, 0.1, (num_samples, *input_shape))
            synthetic_inputs = np.clip(synthetic_inputs, 0, 1)
        elif method == 'uniform':
            # Uniform distributed inputs
            synthetic_inputs = np.random.uniform(0, 1, (num_samples, *input_shape))
        else:
            raise ValueError(f"Unknown synthesis method: {method}")
        
        return synthetic_inputs
    
    def extract_model(self, input_shape, substitute_model, synthesis_method='random', 
                     batch_size=100, max_iterations=10):
        """
        Extract target model by training substitute model
        
        Args:
            input_shape: Shape of input data
            substitute_model: Model to train as substitute
            synthesis_method: Method for query synthesis
            batch_size: Batch size for queries
            max_iterations: Maximum extraction iterations
        
        Returns:
            trained_model: Extracted substitute model
            extraction_stats: Statistics about extraction process
        """
        extraction_stats = {
            'total_queries': 0,
            'iterations': 0,
            'accuracy_history': []
        }
        
        # Initialize with random data
        synthetic_data = self.synthesize_queries(
            input_shape, 
            num_samples=1000, 
            method=synthesis_method
        )
        
        for iteration in range(max_iterations):
            print(f"Extraction iteration {iteration + 1}/{max_iterations}")
            
            # Query target model in batches
            batch_predictions = []
            for i in range(0, len(synthetic_data), batch_size):
                batch = synthetic_data[i:i + batch_size]
                predictions = self.query_target_model(batch)
                
                if predictions is not None:
                    batch_predictions.append(predictions)
                else:
                    break
            
            if not batch_predictions:
                print("Failed to get predictions from target model")
                break
            
            # Combine all predictions
            all_predictions = np.concatenate(batch_predictions, axis=0)
            
            # Store extracted data
            self.extracted_data.append((synthetic_data, all_predictions))
            extraction_stats['total_queries'] += len(synthetic_data)
            
            # Train substitute model
            substitute_model = self.train_substitute_model(
                substitute_model, 
                synthetic_data, 
                all_predictions
            )
            
            # Evaluate extraction quality
            if len(self.extracted_data) > 1:
                # Use previous data for validation
                val_data, val_labels = self.extracted_data[-2]
                val_pred = self.predict_substitute(substitute_model, val_data)
                accuracy = accuracy_score(val_labels.argmax(axis=1), val_pred.argmax(axis=1))
                extraction_stats['accuracy_history'].append(accuracy)
                
                print(f"Extraction accuracy: {accuracy:.4f}")
                
                # Stop if accuracy is good enough
                if accuracy > 0.9:
                    print("High accuracy achieved, stopping extraction")
                    break
            
            # Generate more diverse data for next iteration
            synthetic_data = self.synthesize_queries(
                input_shape,
                num_samples=1000,
                method=synthesis_method
            )
        
        extraction_stats['iterations'] = iteration + 1
        return substitute_model, extraction_stats
    
    def train_substitute_model(self, model, inputs, targets):
        """
        Train substitute model on extracted data
        
        Args:
            model: Substitute model to train
            inputs: Input data
            targets: Target predictions
        
        Returns:
            trained_model: Trained substitute model
        """
        # Convert to PyTorch tensors
        if isinstance(inputs, np.ndarray):
            inputs = torch.FloatTensor(inputs)
        if isinstance(targets, np.ndarray):
            targets = torch.FloatTensor(targets)
        
        # Train model (simplified training loop)
        model.train()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss()
        
        for epoch in range(10):  # Quick full-batch training
            optimizer.zero_grad()
            outputs = model(inputs)
            # Derive hard labels from the target's soft predictions
            loss = criterion(outputs, targets.argmax(dim=1))
            loss.backward()
            optimizer.step()
        
        return model
    
    def predict_substitute(self, model, inputs):
        """Get substitute model predictions for numpy inputs"""
        model.eval()
        if isinstance(inputs, np.ndarray):
            inputs = torch.FloatTensor(inputs)
        with torch.no_grad():
            return model(inputs).numpy()

# Example usage
class SimpleSubstituteModel(nn.Module):
    def __init__(self, input_size, num_classes):
        super(SimpleSubstituteModel, self).__init__()
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, num_classes)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x
                

🧠 Active Learning for Extraction

Query Optimization Strategy

Active learning techniques can significantly reduce the number of queries needed for successful model extraction by focusing on the most informative samples.

๐Ÿ“ Active Learning Selection:

For each candidate query x:
    1. Calculate uncertainty: U(x) = entropy(p_model(x))
    2. Calculate diversity: D(x) = min_distance(x, existing_queries)
    3. Select x* = argmax(α * U(x) + β * D(x))
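
A minimal numeric sketch of this scoring rule in pure NumPy — the α/β weights, the candidate pool, and the Dirichlet stand-in for the substitute's predicted probabilities are all illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)

def entropy(probs):
    """Shannon entropy of each row of class probabilities."""
    return -np.sum(probs * np.log(probs + 1e-8), axis=1)

# Candidate queries, the substitute's current class probabilities, and spent queries
candidates = rng.random((200, 5))
probs = rng.dirichlet(np.ones(3), size=200)   # stand-in for p_model(x)
existing = rng.random((50, 5))                # queries already issued

# U(x): prediction uncertainty; D(x): distance to the nearest existing query
U = entropy(probs)
D = np.min(np.linalg.norm(candidates[:, None, :] - existing[None, :, :], axis=2), axis=1)

# Normalize both to [0, 1] so the weighted sum is scale-free, then take the argmax
U = (U - U.min()) / (np.ptp(U) + 1e-8)
D = (D - D.min()) / (np.ptp(D) + 1e-8)
alpha, beta = 0.7, 0.3
scores = alpha * U + beta * D
best = int(np.argmax(scores))
print("selected candidate index:", best)
```

Normalizing U and D before combining them matters: raw entropy and raw Euclidean distance live on different scales, so without it one term would silently dominate regardless of the weights.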
                    

Implementation Example

class ActiveLearningExtraction:
    def __init__(self, target_api_url, max_queries=5000):
        """
        Initialize active learning-based extraction
        
        Args:
            target_api_url: URL of target model API
            max_queries: Maximum number of API queries
        """
        self.target_api_url = target_api_url
        self.max_queries = max_queries
        self.query_count = 0
        self.extracted_data = []
        self.query_history = []
        
    def calculate_uncertainty(self, model, inputs):
        """
        Calculate prediction uncertainty for given inputs
        
        Args:
            model: Current substitute model
            inputs: Input samples
        
        Returns:
            uncertainties: Uncertainty scores for each input
        """
        model.eval()
        if isinstance(inputs, np.ndarray):
            inputs = torch.FloatTensor(inputs)
        with torch.no_grad():
            outputs = model(inputs)
            # Entropy of the softmax distribution as the uncertainty measure
            probs = torch.softmax(outputs, dim=1)
            entropy = -torch.sum(probs * torch.log(probs + 1e-8), dim=1)
            return entropy.numpy()
    
    def calculate_diversity(self, new_inputs, existing_inputs):
        """
        Calculate diversity score for new inputs
        
        Args:
            new_inputs: New candidate inputs
            existing_inputs: Previously queried inputs
        
        Returns:
            diversity_scores: Diversity scores for new inputs
        """
        if len(existing_inputs) == 0:
            return np.ones(len(new_inputs))
        
        # Calculate minimum distance to existing queries
        diversity_scores = []
        for new_input in new_inputs:
            distances = np.linalg.norm(existing_inputs - new_input, axis=1)
            min_distance = np.min(distances)
            diversity_scores.append(min_distance)
        
        return np.array(diversity_scores)
    
    def select_informative_queries(self, candidate_inputs, model, num_select=100):
        """
        Select most informative queries using active learning
        
        Args:
            candidate_inputs: Pool of candidate inputs
            model: Current substitute model
            num_select: Number of queries to select
        
        Returns:
            selected_inputs: Most informative inputs to query
        """
        # Calculate uncertainty for all candidates
        uncertainties = self.calculate_uncertainty(model, candidate_inputs)
        
        # Calculate diversity scores
        existing_inputs = np.array(self.query_history) if self.query_history else np.array([])
        diversities = self.calculate_diversity(candidate_inputs, existing_inputs)
        
        # Combine uncertainty and diversity (weighted sum)
        alpha, beta = 0.7, 0.3  # Weights for uncertainty vs diversity
        scores = alpha * uncertainties + beta * diversities
        
        # Select top-k most informative queries
        top_indices = np.argsort(scores)[-num_select:]
        selected_inputs = candidate_inputs[top_indices]
        
        return selected_inputs
    
    def active_extraction(self, input_shape, substitute_model, 
                         initial_samples=1000, batch_size=50, 
                         num_iterations=20):
        """
        Perform active learning-based model extraction
        
        Args:
            input_shape: Shape of input data
            substitute_model: Model to train as substitute
            initial_samples: Number of initial random samples
            batch_size: Batch size for active learning queries
            num_iterations: Number of active learning iterations
        
        Returns:
            trained_model: Extracted substitute model
            extraction_stats: Statistics about extraction process
        """
        extraction_stats = {
            'total_queries': 0,
            'iterations': 0,
            'accuracy_history': [],
            'uncertainty_history': []
        }
        
        # Start with random initial samples
        print("Collecting initial random samples...")
        initial_inputs = self.synthesize_queries(input_shape, initial_samples)
        
        # Query target model for initial data
        initial_predictions = self.query_target_model(initial_inputs)
        if initial_predictions is None:
            raise Exception("Failed to get initial predictions")
        
        # Store initial data
        self.extracted_data = [(initial_inputs, initial_predictions)]
        self.query_history = initial_inputs.tolist()
        extraction_stats['total_queries'] += len(initial_inputs)
        
        # Train initial substitute model
        substitute_model = self.train_substitute_model(
            substitute_model, initial_inputs, initial_predictions
        )
        
        # Active learning iterations
        for iteration in range(num_iterations):
            print(f"Active learning iteration {iteration + 1}/{num_iterations}")
            
            if self.query_count >= self.max_queries:
                print("Maximum query limit reached")
                break
            
            # Generate candidate queries
            candidate_inputs = self.synthesize_queries(input_shape, 1000)
            
            # Select most informative queries
            selected_inputs = self.select_informative_queries(
                candidate_inputs, substitute_model, batch_size
            )
            
            # Query target model
            selected_predictions = self.query_target_model(selected_inputs)
            if selected_predictions is None:
                break
            
            # Store new data
            self.extracted_data.append((selected_inputs, selected_predictions))
            self.query_history.extend(selected_inputs.tolist())
            extraction_stats['total_queries'] += len(selected_inputs)
            
            # Retrain substitute model with all data
            all_inputs = np.concatenate([data[0] for data in self.extracted_data])
            all_predictions = np.concatenate([data[1] for data in self.extracted_data])
            
            substitute_model = self.train_substitute_model(
                substitute_model, all_inputs, all_predictions
            )
            
            # Evaluate extraction quality
            avg_uncertainty = np.mean(self.calculate_uncertainty(
                substitute_model, 
                torch.FloatTensor(selected_inputs)
            ))
            extraction_stats['uncertainty_history'].append(avg_uncertainty)
            
            print(f"Average uncertainty: {avg_uncertainty:.4f}")
            print(f"Total queries used: {extraction_stats['total_queries']}")
            
            # Check convergence
            if len(extraction_stats['uncertainty_history']) > 3:
                recent_uncertainties = extraction_stats['uncertainty_history'][-3:]
                if np.std(recent_uncertainties) < 0.01:  # Low variance indicates convergence
                    print("Convergence detected, stopping active learning")
                    break
        
        extraction_stats['iterations'] = iteration + 1
        return substitute_model, extraction_stats
    
    def synthesize_queries(self, input_shape, num_samples=1000):
        """Synthesize random query inputs (same as basic extraction)"""
        return np.random.random((num_samples, *input_shape))
    
    def query_target_model(self, inputs):
        """Query target model API (same as basic extraction)"""
        if self.query_count >= self.max_queries:
            raise Exception("Maximum query limit reached")
        payload = {'inputs': inputs.tolist() if isinstance(inputs, np.ndarray) else inputs}
        try:
            response = requests.post(self.target_api_url + '/predict', json=payload, timeout=30)
            response.raise_for_status()
            self.query_count += len(inputs)
            return np.array(response.json()['predictions'])
        except requests.RequestException as e:
            print(f"API request failed: {e}")
            return None
    
    def train_substitute_model(self, model, inputs, targets):
        """Train substitute model (same as basic extraction)"""
        inputs = torch.FloatTensor(inputs)
        targets = torch.FloatTensor(targets)
        model.train()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss()
        for epoch in range(10):
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets.argmax(dim=1))
            loss.backward()
            optimizer.step()
        return model
                

๐Ÿ›ก๏ธ Defense Mechanisms

Query Limiting and Rate Limiting

Implement strict limits on API queries to prevent systematic extraction.

class QueryLimitingDefense:
    def __init__(self, max_queries_per_ip=1000, time_window=3600):
        """
        Initialize query limiting defense
        
        Args:
            max_queries_per_ip: Maximum queries per IP address
            time_window: Time window in seconds
        """
        self.max_queries = max_queries_per_ip
        self.time_window = time_window
        self.query_log = {}  # IP -> [(timestamp, query_count)]
    
    def check_query_limit(self, client_ip):
        """
        Check if client has exceeded query limits
        
        Args:
            client_ip: Client IP address
        
        Returns:
            allowed: Whether query is allowed
        """
        current_time = time.time()
        
        # Clean old entries
        if client_ip in self.query_log:
            self.query_log[client_ip] = [
                (timestamp, count) for timestamp, count in self.query_log[client_ip]
                if current_time - timestamp < self.time_window
            ]
        else:
            self.query_log[client_ip] = []
        
        # Count queries in time window
        total_queries = sum(count for _, count in self.query_log[client_ip])
        
        if total_queries >= self.max_queries:
            return False
        
        # Log this query
        self.query_log[client_ip].append((current_time, 1))
        return True
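
The same sliding-window idea can be demonstrated with a compact, self-contained variant — `SlidingWindowLimiter` here is an illustrative rewrite using a `deque` of timestamps, and the limits are toy values:

```python
import time
from collections import defaultdict, deque

class SlidingWindowLimiter:
    def __init__(self, max_queries=5, window_seconds=60.0):
        self.max_queries = max_queries
        self.window = window_seconds
        self.log = defaultdict(deque)  # client -> deque of query timestamps

    def allow(self, client, now=None):
        now = time.time() if now is None else now
        q = self.log[client]
        # Drop timestamps that have aged out of the window
        while q and now - q[0] >= self.window:
            q.popleft()
        if len(q) >= self.max_queries:
            return False
        q.append(now)
        return True

limiter = SlidingWindowLimiter(max_queries=3, window_seconds=10.0)
results = [limiter.allow("10.0.0.1", now=t) for t in (0.0, 1.0, 2.0, 3.0)]
print(results)   # [True, True, True, False] -- fourth query denied
late = limiter.allow("10.0.0.1", now=11.0)   # window has slid past t=0 and t=1
print(late)      # True
```

Passing `now` explicitly makes the limiter deterministic and testable; in production you would omit it and rely on `time.time()`.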
                

Output Perturbation

Add noise to model outputs to make extraction more difficult.

class OutputPerturbationDefense:
    def __init__(self, noise_std=0.1, top_k=3):
        """
        Initialize output perturbation defense
        
        Args:
            noise_std: Standard deviation of added noise
            top_k: Only return top-k predictions
        """
        self.noise_std = noise_std
        self.top_k = top_k
    
    def perturb_predictions(self, predictions):
        """
        Add noise to model predictions
        
        Args:
            predictions: Original model predictions
        
        Returns:
            perturbed_predictions: Noisy predictions
        """
        # Add Gaussian noise to the raw scores
        noise = np.random.normal(0, self.noise_std, predictions.shape)
        perturbed = predictions + noise
        
        # Renormalize with softmax so each row is a valid probability distribution
        perturbed = np.exp(perturbed) / np.sum(np.exp(perturbed), axis=1, keepdims=True)
        
        # Return only top-k predictions
        top_k_indices = np.argsort(perturbed, axis=1)[:, -self.top_k:]
        top_k_values = np.take_along_axis(perturbed, top_k_indices, axis=1)
        
        return top_k_indices, top_k_values
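
As a self-contained sketch of this perturb-then-truncate step on a single logit vector — the noise level, the fixed seed, and the helper name `perturb_top_k` are all illustrative:

```python
import numpy as np

rng = np.random.default_rng(42)

def perturb_top_k(logits, noise_std=0.5, top_k=3):
    """Add Gaussian noise to logits, renormalize via softmax, keep only top-k."""
    noisy = logits + rng.normal(0.0, noise_std, logits.shape)
    shifted = noisy - noisy.max()                 # numerically stable softmax
    probs = np.exp(shifted) / np.sum(np.exp(shifted))
    top = np.argsort(probs)[-top_k:][::-1]        # best class first
    return top, probs[top]

logits = np.array([2.0, 0.5, 0.1, -1.0, 3.0])    # class 4 dominates
indices, values = perturb_top_k(logits)
print("returned classes:", indices)
print("their probabilities:", values.round(3))
```

The defender gives up precision (only k noisy probabilities) while usually preserving utility (the top class rarely changes for confident predictions) — exactly the trade-off this defense is built on.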
                

Input Validation and Anomaly Detection

Detect and block suspicious query patterns.

class InputValidationDefense:
    def __init__(self, max_input_norm=10.0, similarity_threshold=0.95):
        """
        Initialize input validation defense
        
        Args:
            max_input_norm: Maximum allowed input norm
            similarity_threshold: Threshold for similar input detection
        """
        self.max_norm = max_input_norm
        self.similarity_threshold = similarity_threshold
        self.input_history = []
    
    def validate_input(self, inputs):
        """
        Validate input for suspicious patterns
        
        Args:
            inputs: Input data to validate
        
        Returns:
            valid: Whether input is valid
            reason: Reason for rejection if invalid
        """
        # Check input norm
        input_norm = np.linalg.norm(inputs)
        if input_norm > self.max_norm:
            return False, f"Input norm too large: {input_norm}"
        
        # Check for similarity to previous inputs (cosine similarity)
        if len(self.input_history) > 0:
            similarities = []
            for prev_input in self.input_history[-100:]:  # Check last 100 inputs
                similarity = np.dot(inputs.flatten(), prev_input.flatten()) / (
                    np.linalg.norm(inputs) * np.linalg.norm(prev_input) + 1e-12
                )
                similarities.append(similarity)
            
            max_similarity = max(similarities)
            if max_similarity > self.similarity_threshold:
                return False, f"Input too similar to previous: {max_similarity}"
        
        # Store input for future similarity checks
        self.input_history.append(inputs.copy())
        if len(self.input_history) > 1000:  # Keep only recent history
            self.input_history = self.input_history[-1000:]
        
        return True, "Valid input"
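
The cosine-similarity check at the heart of this defense can be sketched in isolation — the threshold and example vectors are illustrative:

```python
import numpy as np

def cosine_similarity(a, b):
    """Cosine similarity between two flattened inputs (0.0 if either is zero)."""
    a, b = a.ravel(), b.ravel()
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    return float(np.dot(a, b) / denom) if denom > 0 else 0.0

history = [np.array([1.0, 0.0, 0.0]), np.array([0.0, 1.0, 0.0])]
threshold = 0.95

near_duplicate = np.array([0.999, 0.01, 0.0])   # almost identical to history[0]
fresh = np.array([0.0, 0.0, 1.0])               # orthogonal to both

flag_dup = max(cosine_similarity(near_duplicate, h) for h in history) > threshold
flag_fresh = max(cosine_similarity(fresh, h) for h in history) > threshold
print("near-duplicate rejected:", flag_dup)     # True
print("fresh input rejected:", flag_fresh)      # False
```

Near-duplicate queries are a hallmark of extraction attacks that probe decision boundaries, which is why similarity filtering catches them while leaving legitimate, diverse traffic alone.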
                

🧪 Hands-On Exercise

Exercise: Model Extraction Attack and Defense

Objective: Implement model extraction attacks and corresponding defense mechanisms.

📋 Steps:

  1. Setup Target Model

    Create a target model and API for extraction:

    from flask import Flask, request, jsonify
    import numpy as np
    import torch
    import torch.nn as nn
    
    # Create target model
    class TargetModel(nn.Module):
        def __init__(self):
            super(TargetModel, self).__init__()
            self.fc1 = nn.Linear(784, 128)
            self.fc2 = nn.Linear(128, 64)
            self.fc3 = nn.Linear(64, 10)
            self.relu = nn.ReLU()
        
        def forward(self, x):
            x = x.view(x.size(0), -1)
            x = self.relu(self.fc1(x))
            x = self.relu(self.fc2(x))
            x = self.fc3(x)
            return x
    
    # Initialize target model
    target_model = TargetModel()
    # Load pre-trained weights here
    
    # Create Flask API
    app = Flask(__name__)
    
    @app.route('/predict', methods=['POST'])
    def predict():
        data = request.json
        inputs = np.array(data['inputs'])
        
        # Convert to tensor and predict
        inputs_tensor = torch.FloatTensor(inputs)
        with torch.no_grad():
            outputs = target_model(inputs_tensor)
            predictions = torch.softmax(outputs, dim=1).numpy()
        
        return jsonify({'predictions': predictions.tolist()})
    
    if __name__ == '__main__':
        app.run(host='0.0.0.0', port=5000)
                                
  2. Implement Basic Extraction

    Create and test basic model extraction:

    def test_basic_extraction():
        """
        Test basic model extraction attack
        """
        # Initialize extraction attack
        extraction_attack = ModelExtractionAttack(
            target_api_url='http://localhost:5000',
            max_queries=5000
        )
        
        # Create substitute model
        substitute_model = SimpleSubstituteModel(input_size=784, num_classes=10)
        
        # Perform extraction
        extracted_model, stats = extraction_attack.extract_model(
            input_shape=(784,),
            substitute_model=substitute_model,
            synthesis_method='random',
            batch_size=100
        )
        
        print(f"Extraction completed:")
        print(f"Total queries: {stats['total_queries']}")
        print(f"Iterations: {stats['iterations']}")
        print(f"Final accuracy: {stats['accuracy_history'][-1]:.4f}")
        
        return extracted_model, stats
    
    # Run basic extraction test
    extracted_model, basic_stats = test_basic_extraction()
                                
  3. Implement Active Learning Extraction

    Compare active learning with basic extraction:

    def test_active_learning_extraction():
        """
        Test active learning-based extraction
        """
        # Initialize active learning extraction
        active_extraction = ActiveLearningExtraction(
            target_api_url='http://localhost:5000',
            max_queries=5000
        )
        
        # Create substitute model
        substitute_model = SimpleSubstituteModel(input_size=784, num_classes=10)
        
        # Perform active learning extraction
        extracted_model, stats = active_extraction.active_extraction(
            input_shape=(784,),
            substitute_model=substitute_model,
            initial_samples=500,
            batch_size=50,
            num_iterations=20
        )
        
        print(f"Active learning extraction completed:")
        print(f"Total queries: {stats['total_queries']}")
        print(f"Iterations: {stats['iterations']}")
        print(f"Final uncertainty: {stats['uncertainty_history'][-1]:.4f}")
        
        return extracted_model, stats
    
    # Run active learning extraction test
    active_model, active_stats = test_active_learning_extraction()
                                
  4. Implement Defense Mechanisms

    Create and test defense mechanisms:

    class DefendedAPI:
        def __init__(self, target_model):
            """
            Initialize defended API with multiple defense mechanisms
            """
            self.target_model = target_model
            self.query_limiter = QueryLimitingDefense(max_queries_per_ip=100)
            self.output_perturber = OutputPerturbationDefense(noise_std=0.1)
            self.input_validator = InputValidationDefense()
        
        def predict(self, inputs, client_ip):
            """
            Make prediction with defense mechanisms
            
            Args:
                inputs: Input data
                client_ip: Client IP address
            
            Returns:
                predictions: Defended predictions or error
            """
            # Check query limits
            if not self.query_limiter.check_query_limit(client_ip):
                return {"error": "Query limit exceeded"}
            
            # Validate input
            valid, reason = self.input_validator.validate_input(inputs)
            if not valid:
                return {"error": f"Invalid input: {reason}"}
            
            # Get model predictions
            inputs_tensor = torch.FloatTensor(inputs)
            with torch.no_grad():
                outputs = self.target_model(inputs_tensor)
                predictions = torch.softmax(outputs, dim=1).numpy()
            
            # Perturb outputs
            indices, values = self.output_perturber.perturb_predictions(predictions)
            
            return {
                "predictions": values.tolist(),
                "indices": indices.tolist()
            }
    
    def test_defense_effectiveness():
        """
        Test effectiveness of defense mechanisms against extraction
        """
        # Create defended API
        defended_api = DefendedAPI(target_model)
        
        # Attempt extraction against defended API
        extraction_attack = ModelExtractionAttack(
            target_api_url='http://localhost:5000',  # Would be defended API
            max_queries=1000
        )
        
        substitute_model = SimpleSubstituteModel(input_size=784, num_classes=10)
        
        stats = None
        try:
            extracted_model, stats = extraction_attack.extract_model(
                input_shape=(784,),
                substitute_model=substitute_model,
                batch_size=50
            )
            
            print("Extraction against defended API:")
            print(f"Total queries: {stats['total_queries']}")
            print(f"Success: {stats['total_queries'] > 0}")
            
        except Exception as e:
            print(f"Extraction blocked by defenses: {e}")
        
        return stats
    
    # Test defense effectiveness
    defense_stats = test_defense_effectiveness()
                                

📄 Deliverables:

  • Working model extraction attack implementation
  • Active learning extraction with uncertainty sampling
  • Multiple defense mechanisms (rate limiting, output perturbation, input validation)
  • Comparison of extraction effectiveness with and without defenses
  • Analysis of query efficiency (basic vs. active learning)
  • Recommendations for API security best practices

📊 Knowledge Check

Question 1: What is the primary advantage of active learning over random sampling for model extraction?

Question 2: Which defense mechanism is most effective against systematic extraction attacks?

Question 3: What type of information can be extracted from a black-box model API?