📚 Lesson 3: Model Extraction
Learn about model extraction attacks, API-based extraction, and intellectual property theft in machine learning
🎯 Learning Objectives
By the end of this lesson, you will be able to:
- Understand model extraction attack vectors
- Implement API-based model extraction
- Perform black-box model stealing
- Execute query synthesis attacks
- Evaluate extraction attack effectiveness
- Implement defense mechanisms against extraction
🎯 Understanding Model Extraction
What is Model Extraction?
Model extraction attacks aim to steal or replicate machine learning models by querying them and using the responses to train a substitute model. This is particularly concerning for models exposed through APIs or web services.
🎯 Key Characteristics:
- Intellectual Property Theft: Stealing proprietary model architecture and weights
- Black-box Access: Only input/output access required
- Query-based: Uses prediction queries to extract information
- Substitute Models: Creates functional copies of target models
Attack Taxonomy
🔑 By Access Level
- Black-box: Only prediction API access
- Gray-box: Limited architecture knowledge
- White-box: Full model access (rare in extraction)
🎯 By Extraction Method
- Query Synthesis: Generate optimal queries
- Data-based: Use existing datasets
- Active Learning: Iterative query optimization
🌐 API-Based Model Extraction
Basic Extraction Strategy
API-based extraction involves systematically querying a model's API to collect training data for a substitute model.
📋 Extraction Process:
1. Query target model API with inputs X
2. Collect predictions Y = f_target(X)
3. Train substitute model f_substitute on (X, Y)
4. Validate extraction quality
5. Iterate if necessary
Implementation Example
import requests
import numpy as np
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import time
class ModelExtractionAttack:
    """Black-box model extraction attack.

    Queries a remote prediction API, collects (input, prediction) pairs,
    and trains a local substitute model that mimics the target.
    """

    def __init__(self, target_api_url, max_queries=10000):
        """
        Initialize model extraction attack.

        Args:
            target_api_url: URL of target model API
            max_queries: Maximum number of API queries allowed
        """
        self.target_api_url = target_api_url
        self.max_queries = max_queries
        self.query_count = 0
        # List of (inputs, predictions) tuples collected per iteration.
        self.extracted_data = []

    def query_target_model(self, inputs):
        """
        Query target model API.

        Args:
            inputs: Input data to query (numpy array or list)

        Returns:
            predictions: Model predictions as a numpy array, or None if the
                HTTP request failed.

        Raises:
            RuntimeError: If the query budget is already exhausted.
        """
        if self.query_count >= self.max_queries:
            raise RuntimeError("Maximum query limit reached")

        # Prepare request data; the API expects JSON-serializable inputs.
        payload = {
            'inputs': inputs.tolist() if isinstance(inputs, np.ndarray) else inputs,
            'format': 'numpy'
        }
        try:
            response = requests.post(
                self.target_api_url + '/predict',
                json=payload,
                timeout=30
            )
            response.raise_for_status()
            predictions = response.json()['predictions']
            # Budget is charged per sample, not per HTTP request.
            self.query_count += len(inputs)
            return np.array(predictions)
        except requests.RequestException as e:
            print(f"API request failed: {e}")
            return None

    def synthesize_queries(self, input_shape, num_samples=1000, method='random'):
        """
        Synthesize query inputs for extraction.

        Args:
            input_shape: Shape of a single input sample
            num_samples: Number of samples to generate
            method: Synthesis method ('random', 'gaussian', 'uniform')

        Returns:
            synthetic_inputs: Generated input samples in [0, 1]

        Raises:
            ValueError: If the synthesis method is unknown.
        """
        if method == 'random':
            # Random noise inputs in [0, 1)
            synthetic_inputs = np.random.random((num_samples, *input_shape))
        elif method == 'gaussian':
            # Gaussian around 0.5, clipped to stay in the valid input range
            synthetic_inputs = np.random.normal(0.5, 0.1, (num_samples, *input_shape))
            synthetic_inputs = np.clip(synthetic_inputs, 0, 1)
        elif method == 'uniform':
            synthetic_inputs = np.random.uniform(0, 1, (num_samples, *input_shape))
        else:
            raise ValueError(f"Unknown synthesis method: {method}")
        return synthetic_inputs

    def extract_model(self, input_shape, substitute_model, synthesis_method='random',
                      batch_size=100, max_iterations=10):
        """
        Extract target model by iteratively training a substitute model.

        Args:
            input_shape: Shape of a single input sample
            substitute_model: torch.nn.Module to train as substitute
            synthesis_method: Method for query synthesis
            batch_size: Batch size for API queries
            max_iterations: Maximum extraction iterations

        Returns:
            trained_model: Extracted substitute model
            extraction_stats: Statistics about the extraction process
        """
        extraction_stats = {
            'total_queries': 0,
            'iterations': 0,
            'accuracy_history': []
        }
        # Initialize with random data
        synthetic_data = self.synthesize_queries(
            input_shape,
            num_samples=1000,
            method=synthesis_method
        )
        for iteration in range(max_iterations):
            print(f"Extraction iteration {iteration + 1}/{max_iterations}")
            # Query target model in batches
            batch_predictions = []
            for i in range(0, len(synthetic_data), batch_size):
                batch = synthetic_data[i:i + batch_size]
                predictions = self.query_target_model(batch)
                if predictions is None:
                    break
                batch_predictions.append(predictions)
            if not batch_predictions:
                print("Failed to get predictions from target model")
                break
            all_predictions = np.concatenate(batch_predictions, axis=0)
            # BUG FIX: if a batch query failed mid-run, trim inputs so they
            # stay aligned one-to-one with the predictions we actually got.
            queried_inputs = synthetic_data[:len(all_predictions)]
            self.extracted_data.append((queried_inputs, all_predictions))
            extraction_stats['total_queries'] += len(queried_inputs)
            # Train substitute model on this iteration's data
            substitute_model = self.train_substitute_model(
                substitute_model,
                queried_inputs,
                all_predictions
            )
            # Evaluate extraction quality against the previous iteration's data
            if len(self.extracted_data) > 1:
                val_data, val_labels = self.extracted_data[-2]
                val_pred = self.predict_substitute(substitute_model, val_data)
                # Compare hard labels (argmax of soft predictions)
                accuracy = accuracy_score(val_labels.argmax(axis=1), val_pred.argmax(axis=1))
                extraction_stats['accuracy_history'].append(accuracy)
                print(f"Extraction accuracy: {accuracy:.4f}")
                # Stop early once agreement with the target is high enough
                if accuracy > 0.9:
                    print("High accuracy achieved, stopping extraction")
                    break
            # Generate fresh data for the next iteration
            synthetic_data = self.synthesize_queries(
                input_shape,
                num_samples=1000,
                method=synthesis_method
            )
        extraction_stats['iterations'] = iteration + 1
        return substitute_model, extraction_stats

    def predict_substitute(self, model, inputs):
        """
        Run the substitute model on inputs in inference mode.

        BUG FIX: this method was called by extract_model but never defined,
        which crashed every multi-iteration extraction run.

        Args:
            model: Trained substitute model
            inputs: Input data (numpy array or torch tensor)

        Returns:
            predictions: Raw model outputs as a numpy array
        """
        if isinstance(inputs, np.ndarray):
            inputs = torch.FloatTensor(inputs)
        model.eval()
        with torch.no_grad():
            outputs = model(inputs)
        return outputs.numpy()

    def train_substitute_model(self, model, inputs, targets):
        """
        Train substitute model on extracted data.

        Args:
            model: Substitute model to train
            inputs: Input data (numpy array or torch tensor)
            targets: Target predictions (soft labels from the target model)

        Returns:
            trained_model: Trained substitute model
        """
        if isinstance(inputs, np.ndarray):
            inputs = torch.FloatTensor(inputs)
        if isinstance(targets, np.ndarray):
            targets = torch.FloatTensor(targets)
        # Simplified full-batch training loop
        model.train()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss()
        for epoch in range(10):  # Quick training
            optimizer.zero_grad()
            outputs = model(inputs)
            # Train on hard labels derived from the target's soft outputs
            loss = criterion(outputs, targets.argmax(dim=1))
            loss.backward()
            optimizer.step()
        return model
# Example usage
class SimpleSubstituteModel(nn.Module):
    """Small fully-connected network used as the extraction substitute."""

    def __init__(self, input_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        # Collapse everything except the batch dimension for the dense stack.
        flat = x.view(x.size(0), -1)
        hidden = self.relu(self.fc2(self.relu(self.fc1(flat))))
        return self.fc3(hidden)
🧠 Active Learning for Extraction
Query Optimization Strategy
Active learning techniques can significantly reduce the number of queries needed for successful model extraction by focusing on the most informative samples.
📊 Active Learning Selection:
For each candidate query x:
1. Calculate uncertainty: U(x) = entropy(p_model(x))
2. Calculate diversity: D(x) = min_distance(x, existing_queries)
3. Select x* = argmax(α * U(x) + β * D(x))
Implementation Example
class ActiveLearningExtraction:
    """Query-efficient model extraction.

    Selects API queries by combining prediction uncertainty (entropy of the
    current substitute) with diversity (distance to previous queries), so
    fewer queries are needed than with purely random synthesis.
    """

    def __init__(self, target_api_url, max_queries=5000):
        """
        Initialize active learning-based extraction.

        Args:
            target_api_url: URL of target model API
            max_queries: Maximum number of API queries
        """
        self.target_api_url = target_api_url
        self.max_queries = max_queries
        self.query_count = 0
        self.extracted_data = []
        # Flat list of previously queried inputs (for diversity scoring).
        self.query_history = []

    def calculate_uncertainty(self, model, inputs):
        """
        Calculate prediction uncertainty (entropy) for given inputs.

        BUG FIX: previously required a torch tensor, but
        select_informative_queries passed raw numpy arrays, crashing the
        selection step. Now accepts either.

        Args:
            model: Current substitute model
            inputs: Input samples (numpy array or torch tensor)

        Returns:
            uncertainties: Entropy score per input (numpy array)
        """
        if isinstance(inputs, np.ndarray):
            inputs = torch.FloatTensor(inputs)
        model.eval()
        with torch.no_grad():
            outputs = model(inputs)
            probs = torch.softmax(outputs, dim=1)
            # Small epsilon avoids log(0).
            entropy = -torch.sum(probs * torch.log(probs + 1e-8), dim=1)
        return entropy.numpy()

    def calculate_diversity(self, new_inputs, existing_inputs):
        """
        Calculate diversity score (distance to nearest previous query).

        Args:
            new_inputs: New candidate inputs
            existing_inputs: Previously queried inputs

        Returns:
            diversity_scores: Minimum Euclidean distance per candidate
        """
        if len(existing_inputs) == 0:
            # No history yet: every candidate is equally "diverse".
            return np.ones(len(new_inputs))
        # Flatten so multi-dimensional inputs (e.g. images) also work.
        existing_flat = np.asarray(existing_inputs).reshape(len(existing_inputs), -1)
        new_flat = np.asarray(new_inputs).reshape(len(new_inputs), -1)
        diversity_scores = []
        for row in new_flat:
            distances = np.linalg.norm(existing_flat - row, axis=1)
            diversity_scores.append(np.min(distances))
        return np.array(diversity_scores)

    def select_informative_queries(self, candidate_inputs, model, num_select=100):
        """
        Select most informative queries using active learning.

        Args:
            candidate_inputs: Pool of candidate inputs
            model: Current substitute model
            num_select: Number of queries to select

        Returns:
            selected_inputs: Most informative inputs to query
        """
        uncertainties = self.calculate_uncertainty(model, candidate_inputs)
        existing_inputs = np.array(self.query_history) if self.query_history else np.array([])
        diversities = self.calculate_diversity(candidate_inputs, existing_inputs)
        # Weighted combination favoring uncertainty over diversity.
        alpha, beta = 0.7, 0.3
        scores = alpha * uncertainties + beta * diversities
        # Top-k highest-scoring candidates.
        top_indices = np.argsort(scores)[-num_select:]
        return candidate_inputs[top_indices]

    def active_extraction(self, input_shape, substitute_model,
                          initial_samples=1000, batch_size=50,
                          num_iterations=20):
        """
        Perform active learning-based model extraction.

        Args:
            input_shape: Shape of a single input sample
            substitute_model: Model to train as substitute
            initial_samples: Number of initial random samples
            batch_size: Batch size for active learning queries
            num_iterations: Number of active learning iterations

        Returns:
            trained_model: Extracted substitute model
            extraction_stats: Statistics about the extraction process
        """
        extraction_stats = {
            'total_queries': 0,
            'iterations': 0,
            'accuracy_history': [],
            'uncertainty_history': []
        }
        # Bootstrap with random samples before active selection kicks in.
        print("Collecting initial random samples...")
        initial_inputs = self.synthesize_queries(input_shape, initial_samples)
        initial_predictions = self.query_target_model(initial_inputs)
        if initial_predictions is None:
            raise Exception("Failed to get initial predictions")
        self.extracted_data = [(initial_inputs, initial_predictions)]
        self.query_history = initial_inputs.tolist()
        extraction_stats['total_queries'] += len(initial_inputs)
        substitute_model = self.train_substitute_model(
            substitute_model, initial_inputs, initial_predictions
        )
        for iteration in range(num_iterations):
            print(f"Active learning iteration {iteration + 1}/{num_iterations}")
            if self.query_count >= self.max_queries:
                print("Maximum query limit reached")
                break
            # Generate candidates, then keep only the most informative ones.
            candidate_inputs = self.synthesize_queries(input_shape, 1000)
            selected_inputs = self.select_informative_queries(
                candidate_inputs, substitute_model, batch_size
            )
            selected_predictions = self.query_target_model(selected_inputs)
            if selected_predictions is None:
                break
            self.extracted_data.append((selected_inputs, selected_predictions))
            self.query_history.extend(selected_inputs.tolist())
            extraction_stats['total_queries'] += len(selected_inputs)
            # Retrain on everything collected so far.
            all_inputs = np.concatenate([data[0] for data in self.extracted_data])
            all_predictions = np.concatenate([data[1] for data in self.extracted_data])
            substitute_model = self.train_substitute_model(
                substitute_model, all_inputs, all_predictions
            )
            # Track average uncertainty as a convergence signal.
            avg_uncertainty = np.mean(self.calculate_uncertainty(
                substitute_model, selected_inputs
            ))
            extraction_stats['uncertainty_history'].append(avg_uncertainty)
            print(f"Average uncertainty: {avg_uncertainty:.4f}")
            print(f"Total queries used: {extraction_stats['total_queries']}")
            # Stop when uncertainty stops moving.
            if len(extraction_stats['uncertainty_history']) > 3:
                recent_uncertainties = extraction_stats['uncertainty_history'][-3:]
                if np.std(recent_uncertainties) < 0.01:
                    print("Convergence detected, stopping active learning")
                    break
        extraction_stats['iterations'] = iteration + 1
        return substitute_model, extraction_stats

    def synthesize_queries(self, input_shape, num_samples=1000):
        """Generate uniformly random query inputs in [0, 1)."""
        return np.random.random((num_samples, *input_shape))

    def query_target_model(self, inputs):
        """
        Query target model API.

        BUG FIX: this was a stub returning None, which made
        active_extraction fail unconditionally. Now mirrors the basic
        extraction implementation (including query accounting).

        Args:
            inputs: Input data to query

        Returns:
            predictions: numpy array of predictions, or None on HTTP failure

        Raises:
            RuntimeError: If the query budget is already exhausted.
        """
        if self.query_count >= self.max_queries:
            raise RuntimeError("Maximum query limit reached")
        payload = {
            'inputs': inputs.tolist() if isinstance(inputs, np.ndarray) else inputs,
            'format': 'numpy'
        }
        try:
            response = requests.post(
                self.target_api_url + '/predict',
                json=payload,
                timeout=30
            )
            response.raise_for_status()
            predictions = response.json()['predictions']
            self.query_count += len(inputs)
            return np.array(predictions)
        except requests.RequestException as e:
            print(f"API request failed: {e}")
            return None

    def train_substitute_model(self, model, inputs, targets):
        """
        Train substitute model on extracted data.

        BUG FIX: this was a stub that returned the model untrained, so the
        substitute never learned anything. Now mirrors the basic extraction
        training loop.

        Args:
            model: Substitute model to train
            inputs: Input data (numpy array or torch tensor)
            targets: Target predictions (soft labels)

        Returns:
            trained_model: Trained substitute model
        """
        if isinstance(inputs, np.ndarray):
            inputs = torch.FloatTensor(inputs)
        if isinstance(targets, np.ndarray):
            targets = torch.FloatTensor(targets)
        model.train()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss()
        for epoch in range(10):  # Quick training
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets.argmax(dim=1))
            loss.backward()
            optimizer.step()
        return model
🛡️ Defense Mechanisms
Query Limiting and Rate Limiting
Implement strict limits on API queries to prevent systematic extraction.
class QueryLimitingDefense:
    """Per-IP sliding-window query throttle to slow systematic extraction."""

    def __init__(self, max_queries_per_ip=1000, time_window=3600):
        """
        Initialize query limiting defense.

        Args:
            max_queries_per_ip: Maximum queries per IP address
            time_window: Time window in seconds
        """
        self.max_queries = max_queries_per_ip
        self.time_window = time_window
        self.query_log = {}  # IP -> [(timestamp, query_count)]

    def check_query_limit(self, client_ip):
        """
        Check if client has exceeded query limits.

        Args:
            client_ip: Client IP address

        Returns:
            allowed: Whether query is allowed
        """
        now = time.time()
        # Keep only entries still inside the sliding window.
        history = [
            (ts, n)
            for ts, n in self.query_log.get(client_ip, [])
            if now - ts < self.time_window
        ]
        self.query_log[client_ip] = history
        used = sum(n for _, n in history)
        if used >= self.max_queries:
            # Rejected queries are not logged against the budget.
            return False
        history.append((now, 1))
        return True
Output Perturbation
Add noise to model outputs to make extraction more difficult.
class OutputPerturbationDefense:
    """Perturb API outputs (noise + top-k truncation) to hinder extraction."""

    def __init__(self, noise_std=0.1, top_k=3):
        """
        Initialize output perturbation defense.

        Args:
            noise_std: Standard deviation of added Gaussian noise
            top_k: Only return top-k predictions
        """
        self.noise_std = noise_std
        self.top_k = top_k

    def perturb_predictions(self, predictions):
        """
        Add noise to model predictions and truncate to top-k.

        BUG FIX: the original re-normalization applied a softmax to values
        that were already probabilities, heavily distorting the output even
        with zero noise. Now negatives are clipped and rows are divided by
        their sum, so noise_std=0 leaves the distribution unchanged.

        Args:
            predictions: Original probability predictions, shape (batch, classes)

        Returns:
            top_k_indices: Class indices of the top-k entries, most probable
                first
            top_k_values: Corresponding (renormalized) probabilities
        """
        predictions = np.asarray(predictions, dtype=float)
        # Add Gaussian noise
        noise = np.random.normal(0, self.noise_std, predictions.shape)
        perturbed = predictions + noise
        # Renormalize: clip to nonnegative, then divide by row sums so each
        # row is a valid probability distribution again.
        perturbed = np.clip(perturbed, 1e-12, None)
        perturbed = perturbed / np.sum(perturbed, axis=1, keepdims=True)
        # Top-k in descending probability order.
        top_k_indices = np.argsort(perturbed, axis=1)[:, ::-1][:, :self.top_k]
        top_k_values = np.take_along_axis(perturbed, top_k_indices, axis=1)
        return top_k_indices, top_k_values
Input Validation and Anomaly Detection
Detect and block suspicious query patterns.
class InputValidationDefense:
    """Reject anomalous or near-duplicate queries typical of extraction."""

    def __init__(self, max_input_norm=10.0, similarity_threshold=0.95):
        """
        Initialize input validation defense.

        Args:
            max_input_norm: Maximum allowed input norm
            similarity_threshold: Cosine-similarity threshold for duplicate
                detection
        """
        self.max_norm = max_input_norm
        self.similarity_threshold = similarity_threshold
        # Only accepted inputs are remembered for similarity checks.
        self.input_history = []

    def validate_input(self, inputs):
        """
        Validate input for suspicious patterns.

        BUG FIX: the cosine-similarity check divided by the input norms
        without guarding against zero, producing NaN similarities that
        silently passed validation. Zero-norm vectors are now skipped.

        Args:
            inputs: Input data to validate

        Returns:
            valid: Whether input is valid
            reason: Reason for rejection if invalid
        """
        inputs = np.asarray(inputs, dtype=float)
        # Check input norm
        input_norm = np.linalg.norm(inputs)
        if input_norm > self.max_norm:
            return False, f"Input norm too large: {input_norm}"
        # Check for similarity to previous inputs (cosine similarity).
        # Zero-norm inputs cannot be compared meaningfully and are skipped.
        if input_norm > 0 and len(self.input_history) > 0:
            flat = inputs.flatten()
            max_similarity = 0.0
            for prev_input in self.input_history[-100:]:  # Check last 100 inputs
                prev_flat = np.asarray(prev_input).flatten()
                prev_norm = np.linalg.norm(prev_flat)
                if prev_norm == 0:
                    continue  # avoid division by zero / NaN
                similarity = np.dot(flat, prev_flat) / (input_norm * prev_norm)
                max_similarity = max(max_similarity, similarity)
            if max_similarity > self.similarity_threshold:
                return False, f"Input too similar to previous: {max_similarity}"
        # Store input for future similarity checks
        self.input_history.append(inputs.copy())
        if len(self.input_history) > 1000:  # Keep only recent history
            self.input_history = self.input_history[-1000:]
        return True, "Valid input"
🧪 Hands-On Exercise
Exercise: Model Extraction Attack and Defense
Objective: Implement model extraction attacks and corresponding defense mechanisms.
📋 Steps:
-
Setup Target Model
Create a target model and API for extraction:
from flask import Flask, request, jsonify
import numpy as np
import torch
import torch.nn as nn


# Target model that the extraction attack will try to steal.
class TargetModel(nn.Module):
    def __init__(self):
        super(TargetModel, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# Initialize target model
target_model = TargetModel()
# Load pre-trained weights here

# Expose the model through a minimal Flask prediction API.
app = Flask(__name__)


@app.route('/predict', methods=['POST'])
def predict():
    data = request.json
    inputs = np.array(data['inputs'])
    # Convert to tensor and predict
    inputs_tensor = torch.FloatTensor(inputs)
    with torch.no_grad():
        outputs = target_model(inputs_tensor)
        predictions = torch.softmax(outputs, dim=1).numpy()
    return jsonify({'predictions': predictions.tolist()})


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)
Implement Basic Extraction
Create and test basic model extraction:
def test_basic_extraction():
    """Test basic model extraction attack."""
    # Attack targets the local Flask API from the previous step.
    extraction_attack = ModelExtractionAttack(
        target_api_url='http://localhost:5000',
        max_queries=5000
    )
    # Create substitute model
    substitute_model = SimpleSubstituteModel(input_size=784, num_classes=10)
    # Perform extraction
    extracted_model, stats = extraction_attack.extract_model(
        input_shape=(784,),
        substitute_model=substitute_model,
        synthesis_method='random',
        batch_size=100
    )
    print(f"Extraction completed:")
    print(f"Total queries: {stats['total_queries']}")
    print(f"Iterations: {stats['iterations']}")
    print(f"Final accuracy: {stats['accuracy_history'][-1]:.4f}")
    return extracted_model, stats


# Run basic extraction test
extracted_model, basic_stats = test_basic_extraction()
Implement Active Learning Extraction
Compare active learning with basic extraction:
def test_active_learning_extraction():
    """Test active learning-based extraction."""
    # Same target and query budget as the basic attack for a fair comparison.
    active_extraction = ActiveLearningExtraction(
        target_api_url='http://localhost:5000',
        max_queries=5000
    )
    # Create substitute model
    substitute_model = SimpleSubstituteModel(input_size=784, num_classes=10)
    # Perform active learning extraction
    extracted_model, stats = active_extraction.active_extraction(
        input_shape=(784,),
        substitute_model=substitute_model,
        initial_samples=500,
        batch_size=50,
        num_iterations=20
    )
    print(f"Active learning extraction completed:")
    print(f"Total queries: {stats['total_queries']}")
    print(f"Iterations: {stats['iterations']}")
    print(f"Final uncertainty: {stats['uncertainty_history'][-1]:.4f}")
    return extracted_model, stats


# Run active learning extraction test
active_model, active_stats = test_active_learning_extraction()
Implement Defense Mechanisms
Create and test defense mechanisms:
class DefendedAPI:
    """Prediction API wrapper that stacks all three defense mechanisms."""

    def __init__(self, target_model):
        """
        Initialize defended API with multiple defense mechanisms.

        Args:
            target_model: The torch model being protected
        """
        self.target_model = target_model
        self.query_limiter = QueryLimitingDefense(max_queries_per_ip=100)
        self.output_perturber = OutputPerturbationDefense(noise_std=0.1)
        self.input_validator = InputValidationDefense()

    def predict(self, inputs, client_ip):
        """
        Make prediction with defense mechanisms.

        Args:
            inputs: Input data
            client_ip: Client IP address

        Returns:
            predictions: Defended predictions, or a dict with an "error" key
        """
        # Check query limits
        if not self.query_limiter.check_query_limit(client_ip):
            return {"error": "Query limit exceeded"}
        # Validate input
        valid, reason = self.input_validator.validate_input(inputs)
        if not valid:
            return {"error": f"Invalid input: {reason}"}
        # Get model predictions
        inputs_tensor = torch.FloatTensor(inputs)
        with torch.no_grad():
            outputs = self.target_model(inputs_tensor)
            predictions = torch.softmax(outputs, dim=1).numpy()
        # Perturb outputs
        indices, values = self.output_perturber.perturb_predictions(predictions)
        return {
            "predictions": values.tolist(),
            "indices": indices.tolist()
        }


def test_defense_effectiveness():
    """
    Test effectiveness of defense mechanisms against extraction.

    Returns:
        stats: Extraction statistics, or None if the attack was blocked
            before any statistics were produced.
    """
    # Create defended API
    defended_api = DefendedAPI(target_model)
    # Attempt extraction against defended API
    extraction_attack = ModelExtractionAttack(
        target_api_url='http://localhost:5000',  # Would be defended API
        max_queries=1000
    )
    substitute_model = SimpleSubstituteModel(input_size=784, num_classes=10)
    # BUG FIX: stats was unbound (NameError on return) whenever extract_model
    # raised before assigning it; default to None instead.
    stats = None
    try:
        extracted_model, stats = extraction_attack.extract_model(
            input_shape=(784,),
            substitute_model=substitute_model,
            batch_size=50
        )
        print(f"Extraction against defended API:")
        print(f"Total queries: {stats['total_queries']}")
        print(f"Success: {stats['total_queries'] > 0}")
    except Exception as e:
        print(f"Extraction blocked by defenses: {e}")
    return stats


# Test defense effectiveness
defense_stats = test_defense_effectiveness()
📝 Deliverables:
- Working model extraction attack implementation
- Active learning extraction with uncertainty sampling
- Multiple defense mechanisms (rate limiting, output perturbation, input validation)
- Comparison of extraction effectiveness with and without defenses
- Analysis of query efficiency (basic vs. active learning)
- Recommendations for API security best practices