Domain Adaptation
Domain adaptation enables models trained on one domain (source) to perform well on a different but related domain (target). This is crucial when labeled data is scarce in the target domain but abundant in a related source domain.
Interactive Adaptation Simulator
The Domain Shift Problem
Distribution Mismatch
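When a model is deployed in a new domain, the data distribution can differ from the training distribution in several ways:
- Covariate Shift: P(X) changes but P(Y|X) remains the same
- Label Shift: P(Y) changes but P(X|Y) remains the same
- Concept Drift: P(Y|X) changes, so the same input maps to a different label distribution (often gradually over time)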
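Under covariate shift, a common correction is importance weighting: each source example is reweighted by w(x) = P_target(x) / P_source(x), so that the weighted source risk equals the target risk. The sketch below checks this identity on a toy discrete problem. The distributions, the classifier `h`, and all function names are illustrative assumptions, not part of any library.

```python
from fractions import Fraction as Fr

# Toy discrete input space; all numbers are made up for illustration.
p_source = {0: Fr(5, 10), 1: Fr(3, 10), 2: Fr(2, 10)}      # P_source(X)
p_target = {0: Fr(2, 10), 1: Fr(3, 10), 2: Fr(5, 10)}      # P_target(X)
p_y1_given_x = {0: Fr(1, 10), 1: Fr(3, 10), 2: Fr(6, 10)}  # P(Y=1|X), shared by both domains


def h(x):
    """A fixed classifier: predict 1 only for x == 2."""
    return 1 if x == 2 else 0


def error_at(x):
    """P(Y != h(x) | X = x) under the shared conditional."""
    p1 = p_y1_given_x[x]
    return 1 - p1 if h(x) == 1 else p1


def risk(p_x, weights=None):
    """Expected error under p_x, optionally reweighting each x."""
    weights = weights or {x: Fr(1) for x in p_x}
    return sum(p_x[x] * weights[x] * error_at(x) for x in p_x)


# Importance weights w(x) = P_target(x) / P_source(x)
w = {x: p_target[x] / p_source[x] for x in p_source}

source_risk = risk(p_source)
target_risk = risk(p_target)
weighted_source_risk = risk(p_source, w)

print(float(source_risk), float(target_risk), float(weighted_source_risk))
```

In practice the density ratio is unknown and has to be estimated, for example with a classifier that separates source inputs from target inputs.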
Real-World Examples
| Source Domain | Target Domain | Challenge |
|---|---|---|
| General Web Text | Medical Records | Specialized terminology |
| News Articles | Social Media | Informal language |
| English Reviews | Spanish Reviews | Language + culture |
| Synthetic Data | Real Sensors | Noise patterns |
Adaptation Strategies
1. Fine-Tuning
The simplest approach is to continue training on target data:

    from torch.optim import Adam

    # Note: model.fit / model.train_step, mix_batches, validation_loss and
    # should_stop are placeholders for your own training utilities.
    def fine_tune(model, source_data, target_data, target_val, config):
        # Pre-train on the source domain
        model.fit(source_data, epochs=config.source_epochs)

        # Fine-tune on the target domain with a 10x smaller learning rate
        optimizer = Adam(model.parameters(), lr=config.lr * 0.1)
        for epoch in range(config.target_epochs):
            # Optional: mix source and target data
            if config.mix_ratio > 0:
                batch = mix_batches(source_data, target_data, config.mix_ratio)
            else:
                batch = target_data.sample()
            model.train_step(batch, optimizer)

            # Early stopping on the target validation loss, not the training loss
            val_loss = validation_loss(model, target_val)
            if should_stop(val_loss, patience=5):
                break
        return model
2. Adapter Layers
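Adapters are small bottleneck layers trained on the target domain while the pre-trained weights stay frozen, which keeps adaptation parameter-efficient and limits forgetting:

    import torch.nn as nn

    class AdapterLayer(nn.Module):
        def __init__(self, hidden_size, adapter_size=64):
            super().__init__()
            self.down_project = nn.Linear(hidden_size, adapter_size)
            self.up_project = nn.Linear(adapter_size, hidden_size)
            self.activation = nn.ReLU()

        def forward(self, x):
            # Adapter path: project down, apply non-linearity, project back up
            h = self.up_project(self.activation(self.down_project(x)))
            # Residual connection keeps the original representation
            return x + h

    class AdapterBERT(nn.Module):
        """Assumes a Hugging Face style BertModel (.config, .encoder.layer)."""
        def __init__(self, bert_model, adapter_size=64):
            super().__init__()
            self.bert = bert_model
            # Freeze BERT parameters; only the adapters are trained
            for param in self.bert.parameters():
                param.requires_grad = False
            # One adapter per encoder layer, sized from the model config
            # (768 hidden units and 12 layers for BERT-base)
            cfg = bert_model.config
            self.adapters = nn.ModuleList([
                AdapterLayer(cfg.hidden_size, adapter_size)
                for _ in range(cfg.num_hidden_layers)
            ])
            # Simplification: apply each adapter to its layer's output
            for layer, adapter in zip(self.bert.encoder.layer, self.adapters):
                layer.register_forward_hook(self._make_hook(adapter))

        @staticmethod
        def _make_hook(adapter):
            def hook(module, inputs, output):
                # Layers may return a tuple whose first item is the hidden states
                if isinstance(output, tuple):
                    return (adapter(output[0]),) + output[1:]
                return adapter(output)
            return hook

        def forward(self, *args, **kwargs):
            return self.bert(*args, **kwargs)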
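To make "parameter-efficient" concrete, the following sketch counts the trainable parameters that the adapters above add, using the sizes from the code (hidden size 768, bottleneck 64, 12 layers). The helper names are illustrative, and the figure of roughly 110 million parameters for BERT-base is an approximation used only for the ratio.

```python
def linear_params(in_features, out_features):
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features


def adapter_params(hidden_size=768, adapter_size=64, num_layers=12):
    """Trainable parameters when one bottleneck adapter is added per layer."""
    per_adapter = (linear_params(hidden_size, adapter_size)     # down-projection
                   + linear_params(adapter_size, hidden_size))  # up-projection
    return per_adapter * num_layers


total = adapter_params()
print(total)  # 1189632
# Assumption: BERT-base has roughly 110 million parameters.
print(f"{total / 110e6:.2%} of the frozen backbone")  # 1.08% of the frozen backbone
```

Under these assumptions, about one percent of the backbone's parameter count is trained per target domain.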
3. Elastic Weight Consolidation (EWC)
EWC limits catastrophic forgetting by penalizing changes to parameters that were important for the source task:

    import torch.nn.functional as F

    class EWC:
        def __init__(self, model, source_data, lambda_ewc=0.4):
            self.model = model
            self.lambda_ewc = lambda_ewc
            # Diagonal Fisher information estimated on source data
            self.fisher = self.compute_fisher(source_data)
            # Store a detached copy of the optimal source parameters
            self.optimal_params = {
                name: param.detach().clone()
                for name, param in model.named_parameters()
            }

        def compute_fisher(self, data):
            """Estimate the importance of each parameter from squared gradients.

            Uses per-batch gradients, a cheap approximation of the per-example
            empirical Fisher. Batches are assumed to expose .input and .target.
            """
            fisher = {}
            self.model.eval()
            for batch in data:
                self.model.zero_grad()
                output = self.model(batch.input)
                loss = F.cross_entropy(output, batch.target)
                loss.backward()
                for name, param in self.model.named_parameters():
                    if param.grad is not None:
                        grad_sq = param.grad.detach() ** 2
                        fisher[name] = fisher.get(name, 0) + grad_sq
            # Normalize by the number of batches
            for name in fisher:
                fisher[name] /= len(data)
            return fisher

        def penalty(self):
            """EWC penalty term, to be added to the target-task loss."""
            loss = 0
            for name, param in self.model.named_parameters():
                if name in self.fisher:
                    diff = param - self.optimal_params[name]
                    loss += (self.fisher[name] * diff ** 2).sum()
            return self.lambda_ewc * loss
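A small worked example may help show what the penalty does. The sketch below evaluates the same expression as `penalty` above, lambda times the sum of F_i * (theta_i - theta*_i)^2, on two scalar parameters. The function name and all numbers are made up for illustration.

```python
def ewc_penalty(params, optimal_params, fisher, lambda_ewc=0.4):
    """lambda * sum_i F_i * (theta_i - theta_star_i)^2 over flat parameter lists."""
    return lambda_ewc * sum(
        f * (p - p_star) ** 2
        for p, p_star, f in zip(params, optimal_params, fisher)
    )


optimal = [1.0, -2.0]  # parameters after source training
fisher = [10.0, 0.1]   # the first parameter matters far more to the source task

# Moving the important parameter by 0.5 is penalized heavily ...
important_move = ewc_penalty([1.5, -2.0], optimal, fisher)
# ... while the same move on the unimportant parameter costs little.
unimportant_move = ewc_penalty([1.0, -1.5], optimal, fisher)

print(important_move, unimportant_move)
```

The two penalties differ by a factor of 100, the ratio of the Fisher values, so the optimizer is steered toward changing the parameters the source task cares about least.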
4. Domain-Adversarial Training (DANN)
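DANN learns domain-invariant features by training the feature extractor to fool a domain classifier through a gradient reversal layer:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class _ReverseGrad(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, alpha):
            ctx.alpha = alpha
            return x.view_as(x)

        @staticmethod
        def backward(ctx, grad_output):
            # Flip the gradient sign and scale it by alpha
            return -ctx.alpha * grad_output, None

    class GradientReversal(nn.Module):
        """Identity in the forward pass, reversed gradient in the backward pass."""
        def forward(self, x, alpha=1.0):
            return _ReverseGrad.apply(x, alpha)

    class DANN(nn.Module):
        def __init__(self, feature_extractor, task_classifier, domain_classifier):
            super().__init__()
            self.feature_extractor = feature_extractor
            self.task_classifier = task_classifier
            # domain_classifier must end in a sigmoid (binary_cross_entropy
            # expects probabilities)
            self.domain_classifier = domain_classifier
            self.gradient_reversal = GradientReversal()

        def forward(self, x, alpha=1.0):
            # Extract features
            features = self.feature_extractor(x)
            # Task prediction
            task_output = self.task_classifier(features)
            # Domain prediction with gradient reversal
            reversed_features = self.gradient_reversal(features, alpha)
            domain_output = self.domain_classifier(reversed_features)
            return task_output, domain_output

        def train_step(self, source_batch, target_batch, alpha=1.0):
            # Process source domain (domain label 0)
            src_task, src_domain = self(source_batch.x, alpha)
            task_loss = F.cross_entropy(src_task, source_batch.y)
            src_domain_loss = F.binary_cross_entropy(
                src_domain, torch.zeros_like(src_domain)
            )
            # Process target domain (domain label 1, no task labels)
            _, tgt_domain = self(target_batch.x, alpha)
            tgt_domain_loss = F.binary_cross_entropy(
                tgt_domain, torch.ones_like(tgt_domain)
            )
            # Combined loss
            return task_loss + src_domain_loss + tgt_domain_loss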
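The `alpha` argument of `forward` controls how strongly the reversed domain gradient affects the feature extractor, but the text does not say how to set it. One commonly used choice, following the schedule proposed in the original DANN paper, ramps it from 0 to 1 over training so that noisy domain gradients do not dominate early on. The function name and the `progress` convention below are assumptions made for this sketch.

```python
import math


def grl_alpha(progress, gamma=10.0):
    """Gradient-reversal strength for training progress in [0, 1]."""
    return 2.0 / (1.0 + math.exp(-gamma * progress)) - 1.0


total_steps = 1000
for step in (0, 250, 500, 1000):
    print(step, round(grl_alpha(step / total_steps), 4))
```

The value for the current step would then be passed as `alpha` on each forward pass.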
Advanced Techniques
1. Self-Training / Pseudo-Labeling
Self-training uses the model's own high-confidence predictions on unlabeled target data as labels:

    def self_training(model, source_data, target_data,
                      threshold=0.9, num_iterations=5):
        # Note: model.fit / model.predict are placeholders. predict is assumed
        # to return class probabilities of shape (batch, num_classes), and
        # source_data is assumed to be a list of {'x': ..., 'y': ...} batches.
        # Initial training on source
        model.fit(source_data)
        for iteration in range(num_iterations):
            # Generate pseudo-labels for target
            pseudo_labels = []
            for batch in target_data:
                probs = model.predict(batch)
                confidence, labels = probs.max(dim=1)
                # Only use high-confidence predictions
                mask = confidence > threshold
                if mask.any():
                    pseudo_labels.append({'x': batch[mask], 'y': labels[mask]})
            # Retrain with pseudo-labels
            model.fit(source_data + pseudo_labels)
            # Gradually decrease threshold
            threshold *= 0.95
        return model
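The confidence filter is the core of the method, so here it is in isolation on plain Python lists. The function name and the probability values are hypothetical; they stand in for softmax outputs on four unlabeled target examples.

```python
def select_pseudo_labels(probabilities, threshold=0.9):
    """Return (index, label) pairs whose top class probability exceeds threshold."""
    selected = []
    for i, probs in enumerate(probabilities):
        confidence = max(probs)
        if confidence > threshold:
            selected.append((i, probs.index(confidence)))
    return selected


# Hypothetical softmax outputs: four target examples, three classes
probs = [
    [0.95, 0.03, 0.02],
    [0.40, 0.35, 0.25],
    [0.05, 0.92, 0.03],
    [0.10, 0.10, 0.80],
]
print(select_pseudo_labels(probs, threshold=0.9))   # [(0, 0), (2, 1)]
print(select_pseudo_labels(probs, threshold=0.75))  # [(0, 0), (2, 1), (3, 2)]
```

Lowering the threshold admits more examples but also more label noise, which is why the threshold is usually decayed slowly.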
2. Maximum Mean Discrepancy (MMD)
MMD measures the distance between the source and target feature distributions in a kernel space; adding it to the training loss pulls the two distributions together:

    import torch

    def mmd_loss(source_features, target_features, kernel='rbf', gamma=1.0):
        """Biased estimate of squared MMD for domain alignment."""
        if kernel != 'rbf':
            raise ValueError("only the 'rbf' kernel is implemented")

        def rbf_kernel(x, y):
            # Pairwise squared distances, shape (len(x), len(y)); this also
            # works when the two batches have different sizes
            sq_dist = (
                x.pow(2).sum(dim=1, keepdim=True)
                - 2 * x @ y.t()
                + y.pow(2).sum(dim=1).unsqueeze(0)
            ).clamp_min(0)
            return torch.exp(-gamma * sq_dist)

        kxx = rbf_kernel(source_features, source_features)
        kyy = rbf_kernel(target_features, target_features)
        kxy = rbf_kernel(source_features, target_features)
        return kxx.mean() + kyy.mean() - 2 * kxy.mean()
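For intuition, the same estimator can be run on tiny one-dimensional samples without any tensor library. This is a sketch with made-up data and illustrative function names; it shows that the estimate is zero for identical samples and grows as one sample is shifted away from the other.

```python
import math


def rbf(a, b, gamma=1.0):
    """RBF kernel between two scalars."""
    return math.exp(-gamma * (a - b) ** 2)


def mean_kernel(xs, ys, gamma=1.0):
    """Average kernel value over all pairs (x, y)."""
    return sum(rbf(x, y, gamma) for x in xs for y in ys) / (len(xs) * len(ys))


def mmd2(xs, ys, gamma=1.0):
    """Biased estimate of squared MMD between two 1-D samples."""
    return (mean_kernel(xs, xs, gamma) + mean_kernel(ys, ys, gamma)
            - 2 * mean_kernel(xs, ys, gamma))


source = [0.0, 0.5, 1.0, 1.5]
same = mmd2(source, source)
small_shift = mmd2(source, [x + 0.5 for x in source])
large_shift = mmd2(source, [x + 3.0 for x in source])
print(same, small_shift, large_shift)
```

The result depends on the kernel bandwidth `gamma`, which in practice is often set from the median pairwise distance or replaced by a sum of kernels with several bandwidths.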
Evaluation Strategies
1. Target Domain Performance
The primary metric is accuracy on the target test set:

    def evaluate_adaptation(model, source_test, target_train, target_test):
        results = {
            'source_accuracy': model.evaluate(source_test),
            'target_accuracy': model.evaluate(target_test),
        }
        # Adaptation effectiveness: gain over a baseline trained only on
        # target data (train_from_scratch is a placeholder helper)
        baseline = train_from_scratch(target_train)
        baseline_accuracy = baseline.evaluate(target_test)
        results['adaptation_gap'] = results['target_accuracy'] - baseline_accuracy
        return results
2. Feature Alignment Metrics
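Feature alignment metrics measure how closely the source and target feature distributions match:

    import torch
    from scipy.stats import wasserstein_distance

    def compute_alignment_metrics(source_features, target_features):
        """Both inputs are 2-D tensors of shape (num_examples, feature_dim)."""
        metrics = {}
        # A-distance (proxy); compute_a_distance is a placeholder helper
        metrics['a_distance'] = compute_a_distance(
            source_features, target_features
        )
        # Correlation alignment: squared Frobenius distance between the
        # feature covariance matrices, scaled by 1 / (4 * d^2)
        d = source_features.shape[1]
        cs = torch.cov(source_features.t())
        ct = torch.cov(target_features.t())
        metrics['coral_loss'] = torch.norm(cs - ct, p='fro') ** 2 / (4 * d * d)
        # Earth Mover's Distance, averaged over feature dimensions because
        # scipy's wasserstein_distance only accepts 1-D samples
        src = source_features.detach().cpu().numpy()
        tgt = target_features.detach().cpu().numpy()
        metrics['emd'] = sum(
            wasserstein_distance(src[:, j], tgt[:, j]) for j in range(d)
        ) / d
        return metrics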
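The `compute_a_distance` helper is not defined in the text. One standard way to compute a proxy A-distance is to train a classifier to tell source features from target features and convert its held-out error rate with the formula 2 * (1 - 2 * error). The sketch below covers only that final conversion; the function name is an assumption and the domain classifier itself is left out.

```python
def proxy_a_distance(domain_classifier_error):
    """Proxy A-distance from a domain classifier's held-out error rate."""
    if not 0.0 <= domain_classifier_error <= 1.0:
        raise ValueError("error must be in [0, 1]")
    return 2.0 * (1.0 - 2.0 * domain_classifier_error)


print(proxy_a_distance(0.5))   # 0.0: the domains are indistinguishable
print(proxy_a_distance(0.05))  # 1.8: the domains are easy to tell apart
```

Values near 0 indicate well-aligned features, and values near 2 indicate that the domains are almost perfectly separable.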
Best Practices
1. Data Considerations
- Data Quality: Clean target data is crucial
- Data Quantity: Even small target datasets help
- Data Diversity: Cover target domain variations
2. Training Strategies
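Gradual unfreezing trains the top layers first and then progressively unfreezes lower layers:

    def gradual_unfreeze(model, target_data, epochs_per_stage=2):
        # train_epochs is a placeholder for your own training loop
        layers = list(model.children())
        # Start with every layer frozen
        for param in model.parameters():
            param.requires_grad = False
        # Unfreeze one more layer per stage, from top (output) to bottom (input)
        for layer in reversed(layers):
            for param in layer.parameters():
                param.requires_grad = True
            # Train for a few epochs before unfreezing the next layer
            train_epochs(model, target_data, epochs=epochs_per_stage)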
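The intended schedule is easiest to check without a model: at stage k, the top k layers should be trainable, and the first stage should already unfreeze the top layer. The helper below is a hypothetical illustration that only lists layer indices, with index 0 as the bottom (input) layer.

```python
def unfreeze_schedule(num_layers):
    """For each stage, the layer indices that are trainable (top layers first)."""
    return [list(range(num_layers - stage, num_layers))
            for stage in range(1, num_layers + 1)]


for stage, trainable in enumerate(unfreeze_schedule(4), start=1):
    print(f"stage {stage}: train layers {trainable}")
# stage 1: train layers [3]
# stage 2: train layers [2, 3]
# stage 3: train layers [1, 2, 3]
# stage 4: train layers [0, 1, 2, 3]
```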
3. Hyperparameter Guidelines
| Parameter | Recommended Range | Notes |
|---|---|---|
| Learning Rate | 1e-5 to 1e-4 | Lower than source training |
| Batch Size | 8-32 | Smaller for limited target data |
| Epochs | 3-10 | Avoid overfitting |
| Warmup Steps | 10% of total | Stabilize training |
| Mix Ratio | 0.1-0.3 | Source:Target ratio |
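As an illustration of the warmup guideline, the sketch below computes a learning rate that rises linearly over the first 10% of steps. The linear decay afterwards, the peak value of 5e-5 (inside the recommended range), and the function name are assumptions for this example, not prescriptions from the table.

```python
def lr_at_step(step, total_steps, peak_lr=5e-5, warmup_fraction=0.1):
    """Linear warmup to peak_lr, then linear decay to zero."""
    warmup_steps = max(1, int(total_steps * warmup_fraction))
    if step < warmup_steps:
        # Ramp up: reaches peak_lr on the last warmup step
        return peak_lr * (step + 1) / warmup_steps
    # Decay: falls from peak_lr to 0 over the remaining steps
    remaining = max(1, total_steps - warmup_steps)
    return peak_lr * max(0.0, (total_steps - step) / remaining)


total = 1000
for step in (0, 49, 99, 550, 999):
    print(step, f"{lr_at_step(step, total):.2e}")
```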
Common Pitfalls
1. Catastrophic Forgetting
The model loses source-domain knowledge while it is trained on the target domain.
Solutions:
- Use adapter layers
- Apply EWC or similar regularization
- Mix source and target data
- Lower learning rates
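Of the remedies above, mixing source and target data is simple to sketch: each training batch keeps a small share of source examples so the model keeps seeing the original task. Here `mix_ratio` is interpreted as the fraction of each batch drawn from the source domain, which is an assumption, and the function name and toy datasets are hypothetical.

```python
import random


def mixed_batch(source_data, target_data, batch_size=16, mix_ratio=0.2, rng=random):
    """Sample a batch in which about mix_ratio of the examples are source examples."""
    n_source = round(batch_size * mix_ratio)
    batch = (rng.sample(source_data, n_source)
             + rng.sample(target_data, batch_size - n_source))
    rng.shuffle(batch)
    return batch


source = [("src", i) for i in range(100)]
target = [("tgt", i) for i in range(40)]
batch = mixed_batch(source, target, batch_size=10, mix_ratio=0.2,
                    rng=random.Random(0))
n_src = sum(1 for domain, _ in batch if domain == "src")
print(n_src, "source examples out of", len(batch))  # 2 source examples out of 10
```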
2. Negative Transfer
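Training on the source domain makes target performance worse than it would be without it, usually because the two domains are too dissimilar.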
Solutions:
- Careful source selection
- Domain-adversarial training
- Start from general pre-trained models
3. Overfitting to Target
The model memorizes a small target dataset instead of generalizing.
Solutions:
- Strong regularization
- Data augmentation
- Early stopping
- Ensemble methods
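Early stopping is the cheapest of these remedies to add. A minimal sketch, assuming the validation loss is measured on held-out target data after each epoch; the class name, its parameters, and the loss values are illustrative.

```python
class EarlyStopping:
    """Stop when the validation loss has not improved for `patience` checks."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_checks = 0

    def step(self, val_loss):
        """Record a new validation loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience


stopper = EarlyStopping(patience=2)
stopped_at = None
for epoch, loss in enumerate([0.90, 0.70, 0.72, 0.71, 0.69]):
    if stopper.step(loss):
        stopped_at = epoch
        break
print(stopped_at, stopper.best)  # 3 0.7
```

In a real training loop the weights from the best epoch would also be saved and restored after stopping.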
Future Directions
Emerging Techniques
- Meta-Learning: Learn to adapt quickly
- Continuous Adaptation: Online learning in deployment
- Multi-Source Adaptation: Leverage multiple source domains
- Test-Time Adaptation: Adapt during inference
Research Challenges
- Unsupervised domain adaptation
- Open-set domain adaptation
- Domain generalization
- Federated domain adaptation
Conclusion
Domain adaptation is essential for deploying models in real-world scenarios where training and deployment distributions differ. The interactive visualization above demonstrates various adaptation techniques and their effects on model performance across domains.
Success in domain adaptation requires careful consideration of the domain gap, appropriate technique selection, and thorough evaluation. As models become more powerful, effective domain adaptation becomes increasingly critical for practical AI systems.
