Plugin Development Guide¶
This guide shows how to create custom plugins for Tracelet. Plugins are the primary way to extend Tracelet with new backends, framework integrations, and data collectors.
Understanding Tracelet's Architecture¶
Plugin vs Backend - Clarification¶
In Tracelet, backends are a specific type of plugin. This is a common source of confusion for contributors, so let's clarify:
- Plugin: Generic extensibility system that can be any type of component
- Backend: A plugin specifically of type PluginType.BACKEND, used for experiment tracking
# All backends are plugins, but not all plugins are backends
class MLflowBackend(BackendPlugin):  # This IS a plugin
    @classmethod
    def get_metadata(cls) -> PluginMetadata:
        return PluginMetadata(
            name="mlflow",
            type=PluginType.BACKEND,  # This makes it a backend-type plugin
            # ...
        )


class SystemCollector(PluginBase):  # This is also a plugin
    @classmethod
    def get_metadata(cls) -> PluginMetadata:
        return PluginMetadata(
            name="system",
            type=PluginType.COLLECTOR,  # This makes it a collector-type plugin
            # ...
        )
The Plugin Hierarchy¶
PluginBase (Abstract Base Class)
├── BackendPlugin (inherits from PluginBase)
│   ├── MLflowBackend
│   ├── WandBBackend
│   └── ClearMLBackend
├── CollectorPlugin (inherits from PluginBase)
│   ├── SystemCollector
│   └── GitCollector
├── FrameworkPlugin (inherits from PluginBase)
│   ├── PyTorchFramework
│   └── LightningFramework
└── ProcessorPlugin (inherits from PluginBase)
    └── CustomProcessors
Key Concepts¶
- Plugin System: The overall architecture for extensibility
- Plugin Types: Categories of plugins (BACKEND, COLLECTOR, FRAMEWORK, PROCESSOR)
- Plugin Manager: Discovers, loads, and manages plugin lifecycle
- Orchestrator: Routes data between sources (frameworks) and sinks (backends)
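To make the source/sink split concrete, here is a hedged sketch of the flow the orchestrator mediates: a framework-side source produces MetricData, and a backend-type plugin consumes it through handle_metric. The PrintingSink class and the direct call are illustrative stand-ins for the orchestrator's internal routing, not Tracelet's actual wiring.

from tracelet.core.orchestrator import MetricData, MetricType

class PrintingSink:
    """Stand-in for a backend-type plugin; real backends implement handle_metric."""
    def handle_metric(self, metric: MetricData):
        print(f"{metric.name} = {metric.value} @ step {metric.iteration}")

# A framework plugin (source) produces a metric...
metric = MetricData(name="train/loss", value=0.31, type=MetricType.SCALAR, iteration=42)

# ...and the orchestrator routes it to each registered backend (sink).
# Shown here as a direct call purely for illustration:
PrintingSink().handle_metric(metric)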
Plugin Types¶
Tracelet supports four types of plugins:
- Backend Plugins: Experiment tracking backends (MLflow, W&B, etc.)
- Framework Plugins: ML framework integrations (PyTorch, Lightning, etc.)
- Collector Plugins: Data collectors (system metrics, git info, etc.)
- Processor Plugins: Data processors and transformers
Creating a Backend Plugin¶
Backend plugins integrate Tracelet with experiment tracking platforms.
1. Basic Backend Plugin Structure¶
# tracelet/backends/neptune_backend.py
from typing import Any, Dict, Optional

from tracelet.core.plugins import BackendPlugin, PluginMetadata, PluginType
from tracelet.core.orchestrator import MetricData, MetricType
from tracelet.utils.imports import require


class NeptuneBackend(BackendPlugin):
    """Neptune.ai backend plugin for experiment tracking."""

    def __init__(self):
        super().__init__()
        self._run = None
        self._project = None

    @classmethod
    def get_metadata(cls) -> PluginMetadata:
        return PluginMetadata(
            name="neptune",
            version="1.0.0",
            type=PluginType.BACKEND,
            description="Neptune.ai experiment tracking backend",
            author="Your Name",
            dependencies=["neptune"],
            capabilities={"metrics", "parameters", "artifacts", "logging"},
        )

    def initialize(self, config: Dict[str, Any]):
        """Initialize Neptune backend with configuration."""
        # Use dynamic import for optional dependency
        neptune = require("neptune", "Neptune backend")

        self._config = config
        project_name = config.get("project", "workspace/project")
        api_token = config.get("api_token")

        # Initialize Neptune
        self._run = neptune.init_run(
            project=project_name,
            api_token=api_token,
            name=config.get("run_name"),
        )

    def start(self):
        """Start the backend."""
        self._active = True

    def stop(self):
        """Stop the backend and cleanup."""
        if self._run:
            self._run.stop()
        self._active = False

    def get_status(self) -> Dict[str, Any]:
        """Get backend status."""
        return {
            "active": self._active,
            "run_id": self._run["sys/id"].fetch() if self._run else None,
        }

    def handle_metric(self, metric: MetricData):
        """Handle incoming metric data."""
        if not self._run or not self._active:
            return

        if metric.type == MetricType.SCALAR:
            self._run[metric.name].append(metric.value, step=metric.iteration)
        elif metric.type == MetricType.PARAMETER:
            self._run[f"parameters/{metric.name}"] = metric.value
        elif metric.type == MetricType.ARTIFACT:
            self._run[f"artifacts/{metric.name}"].upload(metric.value)
2. Register the Backend Plugin¶
# tracelet/backends/__init__.py
from .neptune_backend import NeptuneBackend


def get_backend(name: str):
    """Get backend plugin by name."""
    backends = {
        "neptune": NeptuneBackend,
        # ... other backends
    }
    return backends.get(name)
3. Add Configuration Support¶
# Update tracelet/settings.py to include Neptune settings
from typing import Optional

from pydantic import BaseSettings


class TraceletSettings(BaseSettings):
    # ... existing settings

    # Neptune-specific settings
    neptune_project: Optional[str] = None
    neptune_api_token: Optional[str] = None
    neptune_mode: str = "async"

    class Config:
        env_prefix = "TRACELET_"
4. Write Tests¶
# tests/unit/backends/test_neptune_backend.py
import pytest
from unittest.mock import MagicMock, Mock, patch

from tracelet.backends.neptune_backend import NeptuneBackend
from tracelet.core.orchestrator import MetricData, MetricType


class TestNeptuneBackend:
    @patch('tracelet.backends.neptune_backend.require')
    def test_initialize(self, mock_require):
        """Test Neptune backend initialization."""
        mock_neptune = Mock()
        mock_require.return_value = mock_neptune

        backend = NeptuneBackend()
        config = {
            "project": "workspace/test-project",
            "api_token": "test-token",
        }
        backend.initialize(config)

        mock_neptune.init_run.assert_called_once_with(
            project="workspace/test-project",
            api_token="test-token",
            name=None,
        )

    def test_handle_scalar_metric(self):
        """Test handling scalar metrics."""
        backend = NeptuneBackend()
        backend._run = MagicMock()  # MagicMock supports item access (run["metric"])
        backend._active = True

        metric = MetricData(
            name="accuracy",
            value=0.95,
            type=MetricType.SCALAR,
            iteration=100,
        )
        backend.handle_metric(metric)

        backend._run["accuracy"].append.assert_called_once_with(0.95, step=100)
Creating a Framework Plugin¶
Framework plugins integrate Tracelet with ML frameworks to automatically capture metrics.
1. Basic Framework Plugin Structure¶
# tracelet/frameworks/jax_framework.py
from typing import Any, Dict

from tracelet.core.plugins import PluginBase, PluginMetadata, PluginType
from tracelet.core.orchestrator import MetricData, MetricType, MetricSource


class JAXFramework(PluginBase, MetricSource):
    """JAX framework integration plugin."""

    def __init__(self):
        super().__init__()
        self._experiment = None
        self._original_functions = {}
        self._patched = False

    @classmethod
    def get_metadata(cls) -> PluginMetadata:
        return PluginMetadata(
            name="jax",
            version="1.0.0",
            type=PluginType.FRAMEWORK,
            description="JAX framework integration",
            dependencies=["jax", "flax"],
            capabilities={"metric_capture", "parameter_logging"},
        )

    def initialize(self, config: Dict[str, Any]):
        """Initialize JAX framework integration."""
        self._config = config

    def start(self):
        """Start JAX integration by patching functions."""
        if not self._patched:
            self._patch_jax_logging()
            self._patched = True

    def stop(self):
        """Stop JAX integration and restore original functions."""
        if self._patched:
            self._restore_original_functions()
            self._patched = False

    def get_status(self) -> Dict[str, Any]:
        return {"patched": self._patched}

    def get_source_id(self) -> str:
        return "jax_framework"

    def emit_metric(self, metric: MetricData):
        """Emit metric to orchestrator."""
        if self._experiment:
            self._experiment.emit_metric(metric)

    def set_experiment(self, experiment):
        """Set the active experiment."""
        self._experiment = experiment

    def _patch_jax_logging(self):
        """Patch JAX/Flax logging functions."""
        try:
            import flax.training.train_state as train_state

            # Store original function
            self._original_functions['apply_gradients'] = train_state.TrainState.apply_gradients

            def wrapped_apply_gradients(train_state_self, **kwargs):
                # Call original function
                result = self._original_functions['apply_gradients'](train_state_self, **kwargs)

                # Capture metrics
                step = int(train_state_self.step) if hasattr(train_state_self, 'step') else None

                # Emit learning rate if available
                if hasattr(train_state_self, 'opt_state') and hasattr(train_state_self.tx, 'learning_rate'):
                    lr_metric = MetricData(
                        name="learning_rate",
                        value=float(train_state_self.tx.learning_rate),
                        type=MetricType.SCALAR,
                        iteration=step,
                        source=self.get_source_id(),
                    )
                    self.emit_metric(lr_metric)

                return result

            # Apply patch
            train_state.TrainState.apply_gradients = wrapped_apply_gradients
        except ImportError:
            # JAX/Flax not available
            pass

    def _restore_original_functions(self):
        """Restore original JAX functions."""
        try:
            import flax.training.train_state as train_state

            if 'apply_gradients' in self._original_functions:
                train_state.TrainState.apply_gradients = self._original_functions['apply_gradients']
        except ImportError:
            pass
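A rough sketch of wiring this framework plugin to a running experiment. It assumes the experiment object exposes emit_metric (which the plugin's own emit_metric method above relies on); the exact wiring may differ in your Tracelet version.

# Hedged usage sketch -- plugin wiring may differ in your Tracelet version.
from tracelet.core.experiment import Experiment

exp = Experiment(name="jax_run")
framework = JAXFramework()
framework.initialize({})
framework.set_experiment(exp)   # metrics emitted by the patch flow into exp
framework.start()               # patches flax.training.train_state

# ... run your JAX/Flax training loop here ...

framework.stop()                # restores the original apply_gradients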
2. Advanced Framework Integration¶
For more complex integrations, you can hook into training loops:
class AdvancedJAXFramework(JAXFramework):
    """Advanced JAX integration with training loop detection."""

    def _patch_jax_logging(self):
        """Enhanced patching with training loop detection."""
        super()._patch_jax_logging()

        # Patch common JAX training patterns
        self._patch_optax_optimizers()
        self._patch_flax_training()

    def _patch_optax_optimizers(self):
        """Patch Optax optimizers to capture optimization metrics."""
        try:
            import jax
            import jax.numpy as jnp
            import optax

            # Store original update function
            original_update = optax.GradientTransformation.update

            def wrapped_update(tx_self, updates, state, params=None):
                result = original_update(tx_self, updates, state, params)

                # Capture gradient norms
                if updates:
                    grad_norm = jnp.sqrt(
                        sum(jnp.sum(jnp.square(g)) for g in jax.tree_util.tree_leaves(updates))
                    )
                    norm_metric = MetricData(
                        name="gradient_norm",
                        value=float(grad_norm),
                        type=MetricType.SCALAR,
                        source=self.get_source_id(),
                    )
                    self.emit_metric(norm_metric)

                return result

            optax.GradientTransformation.update = wrapped_update
        except ImportError:
            pass

    def _patch_flax_training(self):
        """Hook additional Flax training utilities."""
        # Not shown in this guide -- patch further Flax entry points
        # in the same style as _patch_optax_optimizers above.
        pass
Creating a Collector Plugin¶
Collector plugins gather data from external sources (system metrics, git info, etc.).
1. Basic Collector Plugin¶
# tracelet/collectors/docker_collector.py
import logging
import time
from typing import Any, Dict, List, Optional

from tracelet.core.plugins import PluginBase, PluginMetadata, PluginType
from tracelet.core.orchestrator import MetricData, MetricType

logger = logging.getLogger(__name__)


class DockerCollector(PluginBase):
    """Collects Docker container metrics."""

    def __init__(self):
        super().__init__()
        self._docker_client = None
        self._container_id = None
        self._collection_interval = 30

    @classmethod
    def get_metadata(cls) -> PluginMetadata:
        return PluginMetadata(
            name="docker",
            version="1.0.0",
            type=PluginType.COLLECTOR,
            description="Docker container metrics collector",
            dependencies=["docker"],
            capabilities={"system_metrics", "resource_monitoring"},
        )

    def initialize(self, config: Dict[str, Any]):
        """Initialize Docker collector."""
        from docker import DockerClient

        self._config = config
        self._collection_interval = config.get("interval", 30)
        self._docker_client = DockerClient.from_env()

        # Auto-detect current container
        self._container_id = self._detect_current_container()

    def start(self):
        """Start the collector."""
        self._active = True

    def stop(self):
        """Stop the collector."""
        self._active = False

    def get_status(self) -> Dict[str, Any]:
        return {
            "active": self._active,
            "container_id": self._container_id,
            "interval": self._collection_interval,
        }

    def collect(self) -> List[MetricData]:
        """Collect Docker container metrics."""
        if not self._active or not self._container_id:
            return []

        try:
            container = self._docker_client.containers.get(self._container_id)
            stats = container.stats(stream=False)

            metrics = []
            timestamp = time.time()

            # CPU metrics
            cpu_percent = self._calculate_cpu_percent(stats)
            metrics.append(MetricData(
                name="docker/cpu_percent",
                value=cpu_percent,
                type=MetricType.SYSTEM,
                timestamp=timestamp,
                source="docker_collector",
            ))

            # Memory metrics
            memory_usage = stats['memory_stats']['usage']
            memory_limit = stats['memory_stats']['limit']
            memory_percent = (memory_usage / memory_limit) * 100
            metrics.append(MetricData(
                name="docker/memory_percent",
                value=memory_percent,
                type=MetricType.SYSTEM,
                timestamp=timestamp,
                source="docker_collector",
            ))

            return metrics
        except Exception as e:
            # Log the error and return an empty list so collection never crashes the run
            logger.error("Docker metric collection failed: %s", e)
            return []

    def _detect_current_container(self) -> Optional[str]:
        """Auto-detect current container ID."""
        try:
            with open('/proc/self/cgroup', 'r') as f:
                for line in f:
                    if 'docker' in line:
                        return line.split('/')[-1].strip()
        except FileNotFoundError:
            pass
        return None

    def _calculate_cpu_percent(self, stats: Dict) -> float:
        """Calculate CPU percentage from Docker stats."""
        cpu_delta = (
            stats['cpu_stats']['cpu_usage']['total_usage']
            - stats['precpu_stats']['cpu_usage']['total_usage']
        )
        system_delta = (
            stats['cpu_stats']['system_cpu_usage']
            - stats['precpu_stats']['system_cpu_usage']
        )
        if system_delta > 0:
            return (cpu_delta / system_delta) * len(stats['cpu_stats']['cpu_usage']['percpu_usage']) * 100
        return 0.0
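The collector exposes a pull-style collect() method; how it is scheduled is up to Tracelet's plugin manager. A hedged polling sketch for trying the collector on its own (the loop is illustrative, not Tracelet's actual scheduler):

# Illustrative polling loop -- in practice the plugin manager / orchestrator
# is expected to drive collection on its own schedule.
import time

collector = DockerCollector()
collector.initialize({"interval": 10})
collector.start()

for _ in range(3):
    for metric in collector.collect():
        print(metric.name, metric.value)
    time.sleep(10)

collector.stop()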
Plugin Registration and Discovery¶
1. Manual Registration¶
# Register plugins manually in your application
from tracelet.core.experiment import Experiment
from tracelet.backends.neptune_backend import NeptuneBackend
from tracelet.frameworks.jax_framework import JAXFramework
# Create experiment with custom plugins
exp = Experiment(name="custom_experiment")
# Add custom backend
neptune_backend = NeptuneBackend()
exp._plugin_manager.register_plugin(neptune_backend)
# Add framework integration
jax_framework = JAXFramework()
exp._plugin_manager.register_plugin(jax_framework)
exp.start()
2. Automatic Discovery¶
# Declare plugin entry points in pyproject.toml (setup.py users pass the same mapping via entry_points)
[project.entry-points."tracelet.plugins"]
neptune = "tracelet.backends.neptune_backend:NeptuneBackend"
jax = "tracelet.frameworks.jax_framework:JAXFramework"
docker = "tracelet.collectors.docker_collector:DockerCollector"
Testing Plugins¶
1. Unit Tests¶
# tests/unit/test_neptune_backend.py
import pytest
from unittest.mock import MagicMock, Mock, patch

from tracelet.backends.neptune_backend import NeptuneBackend


@pytest.fixture
def mock_neptune():
    with patch('tracelet.backends.neptune_backend.require') as mock_require:
        mock_neptune = Mock()
        mock_require.return_value = mock_neptune
        yield mock_neptune


class TestNeptuneBackend:
    def test_initialization(self, mock_neptune):
        backend = NeptuneBackend()
        config = {"project": "test/project", "api_token": "token"}

        backend.initialize(config)

        mock_neptune.init_run.assert_called_once()

    def test_metric_handling(self, mock_neptune):
        backend = NeptuneBackend()
        backend._run = MagicMock()  # MagicMock supports item access (run["metric"])
        backend._active = True

        from tracelet.core.orchestrator import MetricData, MetricType

        metric = MetricData("test_metric", 1.0, MetricType.SCALAR)
        backend.handle_metric(metric)

        backend._run["test_metric"].append.assert_called_once_with(1.0, step=None)
2. Integration Tests¶
# tests/integration/test_plugin_integration.py
import pytest

from tracelet import Experiment
from tracelet.backends.neptune_backend import NeptuneBackend


@pytest.mark.integration
def test_neptune_integration():
    """Test Neptune backend integration with real Neptune API."""
    # This test requires the NEPTUNE_API_TOKEN environment variable
    exp = Experiment(name="integration_test", backend=["neptune"])
    exp.start()

    # Log some metrics
    exp.log_metric("test_accuracy", 0.95, iteration=1)
    exp.log_params({"learning_rate": 0.001, "batch_size": 32})

    exp.stop()

    # Verify metrics were logged to Neptune
    # (Implementation depends on Neptune API for verification)
Best Practices¶
1. Error Handling¶
def handle_metric(self, metric: MetricData):
    """Handle metric with robust error handling."""
    try:
        # Validate metric
        if not self._validate_metric(metric):
            return

        # Process metric
        self._process_metric(metric)
    except Exception as e:
        # Log error but don't crash
        logger.error(f"Failed to handle metric {metric.name}: {e}")

        # Optionally, emit error metric
        error_metric = MetricData(
            name="tracelet/plugin_errors",
            value=1.0,
            type=MetricType.SYSTEM,
            metadata={"plugin": self.get_metadata().name, "error": str(e)},
        )
        self.emit_metric(error_metric)
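The _validate_metric and _process_metric helpers above are left to the plugin author. A hedged sketch of a minimal validator; the field names follow the MetricData usage elsewhere in this guide:

def _validate_metric(self, metric: MetricData) -> bool:
    """Reject metrics the backend cannot sensibly record."""
    if not metric.name:
        return False
    # Scalars must carry a numeric value
    if metric.type == MetricType.SCALAR and not isinstance(metric.value, (int, float)):
        return False
    return True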
2. Resource Management¶
class ResourceManagedPlugin(PluginBase):
    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()

    def stop(self):
        """Ensure all resources are cleaned up."""
        try:
            # Cleanup network connections
            if hasattr(self, '_client'):
                self._client.close()

            # Stop background threads
            if hasattr(self, '_threads'):
                for thread in self._threads:
                    thread.join(timeout=5.0)
        except Exception as e:
            logger.error(f"Error during cleanup: {e}")
        finally:
            self._active = False
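Implementing the context-manager protocol lets callers guarantee cleanup even when an exception is raised. A short usage sketch, assuming a hypothetical concrete subclass (MyManagedCollector) that implements the remaining plugin methods:

# MyManagedCollector is a hypothetical concrete subclass of ResourceManagedPlugin
with MyManagedCollector() as plugin:
    metrics = plugin.collect()
# plugin.stop() has already run here, even if collect() raised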
3. Configuration Validation¶
from typing import Any, Dict

from pydantic import BaseModel, validator


class NeptuneConfig(BaseModel):
    project: str
    api_token: str
    mode: str = "async"

    @validator('project')
    def validate_project_format(cls, v):
        if '/' not in v:
            raise ValueError('Project must be in format "workspace/project"')
        return v

    @validator('mode')
    def validate_mode(cls, v):
        if v not in ['async', 'sync', 'offline']:
            raise ValueError('Mode must be one of: async, sync, offline')
        return v


class NeptuneBackend(BackendPlugin):
    def initialize(self, config: Dict[str, Any]):
        # Validate configuration
        validated_config = NeptuneConfig(**config)

        # Use validated config
        self._setup_neptune(validated_config)
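Validation failures then surface as a pydantic ValidationError before any network call is made. A quick sketch:

from pydantic import ValidationError

try:
    NeptuneConfig(project="missing-slash", api_token="token", mode="async")
except ValidationError as err:
    print(err)  # reports: Project must be in format "workspace/project"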
Publishing Plugins¶
1. Package Structure¶
tracelet-neptune-plugin/
├── pyproject.toml
├── README.md
├── src/
│   └── tracelet_neptune/
│       ├── __init__.py
│       ├── backend.py
│       └── py.typed
├── tests/
│   ├── test_backend.py
│   └── test_integration.py
└── docs/
    └── usage.md
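The entry point declared in the next step references tracelet_neptune:NeptuneBackend, so the package's __init__.py needs to re-export the class. A minimal sketch, assuming the class lives in backend.py as in the layout above:

# src/tracelet_neptune/__init__.py
from .backend import NeptuneBackend

__all__ = ["NeptuneBackend"]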
2. Setup Configuration¶
# pyproject.toml
[project]
name = "tracelet-neptune"
version = "1.0.0"
description = "Neptune.ai backend plugin for Tracelet"
dependencies = ["tracelet>=0.1.0", "neptune>=1.0.0"]
[project.entry-points."tracelet.backends"]
neptune = "tracelet_neptune:NeptuneBackend"
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
3. Installation¶
# Install from PyPI
pip install tracelet-neptune
# Use in code
from tracelet import Experiment
exp = Experiment(name="test", backend=["neptune"])
This guide provides the foundation for creating robust, well-tested plugins that extend Tracelet's capabilities. Follow these patterns and best practices to ensure your plugins integrate smoothly with the Tracelet ecosystem.