Drift monitoring with MLflow isn’t about catching a model that has suddenly "gone bad"; it’s about detecting when the statistical distribution of incoming data has diverged significantly from the data the model was trained on.

Let’s see it in action. Imagine you’ve trained a model to predict customer churn.

from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
import pandas as pd
import numpy as np
import mlflow
import mlflow.sklearn
from mlflow.models import infer_signature

# Simulate training data (seeded so the example is reproducible)
np.random.seed(42)
data = {
    'feature1': np.random.rand(1000),
    'feature2': np.random.rand(1000) * 10,
    'feature3': np.random.randint(0, 5, 1000),
    'target': np.random.randint(0, 2, 1000)
}
df = pd.DataFrame(data)

X = df[['feature1', 'feature2', 'feature3']]
y = df['target']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train a simple model
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train)

# Log the model with MLflow
with mlflow.start_run(run_name="churn_model_v1"):
    # Infer model signature
    signature = infer_signature(X_train, model.predict(X_train))

    # Log the model
    model_uri = mlflow.sklearn.log_model(
        sk_model=model,
        artifact_path="model",
        signature=signature,
        input_example=X_train.head(3)
    ).model_uri

    print(f"Model logged to: {model_uri}")

# Simulate new incoming data with drift
new_data = {
    'feature1': np.random.rand(500) * 1.5,  # Drift: range widened from [0, 1) to [0, 1.5)
    'feature2': np.random.rand(500) * 10,  # No drift
    'feature3': np.random.randint(2, 7, 500)  # Drift: categories shifted from 0-4 to 2-6
}
df_new = pd.DataFrame(new_data)

# Next, compare df_new (production data) against X_train (the reference data
# for the logged model) using a drift library or custom statistical tests,
# and log the resulting drift metrics to MLflow.

Drift monitoring in an MLflow workflow, typically built on tools like evidently or on custom metrics, lets you establish a baseline of your model’s expected data characteristics and performance. When new data arrives, it is compared against this baseline, and significant statistical differences in feature or target distributions trigger alerts. The point isn’t that accuracy has already dropped; it’s that the input has changed, which tends to degrade accuracy over time if left unaddressed.

The core problem this solves is the "silent failure" of ML models. A model might continue to produce predictions, but if the underlying data generating process has shifted (e.g., customer behavior changes due to a new competitor, economic shifts), those predictions become increasingly irrelevant and potentially harmful. Drift monitoring acts as an early warning system, flagging that the model’s assumptions are no longer valid before performance metrics tank.

Internally, drift detection often relies on comparing probability distributions. Common methods include:

  • Kolmogorov-Smirnov (K-S) test: For continuous features, it tests whether two samples are drawn from the same distribution. You can log the test’s p-value as an MLflow metric.
  • Chi-squared test: For categorical features, a test of homogeneity on the category counts checks whether the reference and current samples share the same category frequencies.
  • Population Stability Index (PSI): A common metric in credit scoring, it measures how much a variable’s distribution has shifted between a baseline (training) and a current (production) population. Rules of thumb vary, but values below 0.1 are usually treated as stable and values above roughly 0.2 to 0.25 as significant drift.
  • Wasserstein Distance (Earth Mover’s Distance): A metric that measures the "cost" of transforming one distribution into another. It is expressed in the feature’s own units, so it reports how far the distribution moved rather than only whether it moved.
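
The four methods above can be sketched with NumPy and SciPy. This is a minimal illustration, not MLflow functionality: the helper names psi and chi2_drift are hypothetical, the PSI binning (ten quantile bins taken from the reference sample) is one common choice among several, and the data simply mirrors the feature1 and feature3 drift simulated earlier.

```python
import numpy as np
from scipy import stats


def psi(reference, current, bins=10, eps=1e-6):
    """Population Stability Index with quantile bins taken from the reference sample."""
    reference = np.asarray(reference, dtype=float)
    current = np.asarray(current, dtype=float)
    # Interior bin edges come from reference quantiles; the two outer bins are open-ended.
    edges = np.unique(np.quantile(reference, np.linspace(0, 1, bins + 1))[1:-1])
    n_bins = len(edges) + 1
    ref_counts = np.bincount(np.searchsorted(edges, reference, side="right"), minlength=n_bins)
    cur_counts = np.bincount(np.searchsorted(edges, current, side="right"), minlength=n_bins)
    # Clip to eps so empty bins do not produce log(0) or division by zero.
    ref_frac = np.clip(ref_counts / ref_counts.sum(), eps, None)
    cur_frac = np.clip(cur_counts / cur_counts.sum(), eps, None)
    return float(np.sum((cur_frac - ref_frac) * np.log(cur_frac / ref_frac)))


def chi2_drift(reference, current):
    """Chi-squared test of homogeneity on the category counts of two samples."""
    reference = np.asarray(reference)
    current = np.asarray(current)
    categories = np.union1d(reference, current)
    # 2 x K contingency table: one row per sample, one column per category.
    table = np.array([
        [np.sum(reference == c) for c in categories],
        [np.sum(current == c) for c in categories],
    ])
    chi2, p_value, _, _ = stats.chi2_contingency(table)
    return float(chi2), float(p_value)


# Synthetic data mirroring the simulated drift in feature1 and feature3.
rng = np.random.default_rng(0)
ref_f1 = rng.random(1000)
cur_f1 = rng.random(500) * 1.5
ref_f3 = rng.integers(0, 5, 1000)
cur_f3 = rng.integers(2, 7, 500)

ks_stat, ks_p = stats.ks_2samp(ref_f1, cur_f1)
w_dist = stats.wasserstein_distance(ref_f1, cur_f1)
psi_f1 = psi(ref_f1, cur_f1)
chi2_stat, chi2_p = chi2_drift(ref_f3, cur_f3)

print(f"K-S p-value (feature1):   {ks_p:.3g}")
print(f"Wasserstein (feature1):   {w_dist:.3f}")
print(f"PSI (feature1):           {psi_f1:.3f}")
print(f"Chi-squared p (feature3): {chi2_p:.3g}")
```

Note that the K-S and chi-squared tests answer "is there a difference?" and become very sensitive with large samples, while PSI and the Wasserstein distance describe the size of the shift, so it is often useful to track one of each.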

Logging a ModelSignature and an input_example, whether through autologging or manually, is crucial here. The ModelSignature defines the expected schema and data types of the model’s inputs and outputs, but not their distributions. When you set up drift monitoring, you associate a production dataset with a specific logged model version; the monitoring system then uses the training data (or a representative sample logged alongside the model) as the reference distribution.

The main levers you control are the thresholds for these statistical tests and metrics. For instance, you might set an alert to trigger if the PSI for feature1 exceeds 0.25, or if the K-S test p-value for feature2 drops below 0.05. You also choose which features to monitor and whether to monitor the target variable. MLflow itself doesn’t perform the drift calculations in its core library; instead, it is paired with specialized libraries (like evidently, alibi-detect, or custom Python scripts) that do the heavy lifting. MLflow provides the framework to log these drift metrics as MLflow metrics, associate them with model versions, and visualize them over time in the MLflow UI.
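
One way to wire such thresholds to metric logging is sketched below. The helper names check_drift and log_drift, the rule format, and the metric values are all hypothetical; the logger is injected as a plain callable so the sketch runs without a tracking server.

```python
from typing import Callable, Dict, List, Tuple

# A rule is (direction, threshold); direction is "above" or "below".
Rule = Tuple[str, float]


def check_drift(metrics: Dict[str, float], rules: Dict[str, Rule]) -> Dict[str, bool]:
    """Return {metric_name: True} for every metric that violates its rule."""
    alerts = {}
    for name, (direction, threshold) in rules.items():
        value = metrics[name]
        if direction == "above":
            alerts[name] = value > threshold
        elif direction == "below":
            alerts[name] = value < threshold
        else:
            raise ValueError(f"Unknown direction: {direction!r}")
    return alerts


def log_drift(
    metrics: Dict[str, float],
    alerts: Dict[str, bool],
    log_metric: Callable[[str, float], None],
) -> None:
    """Send each drift metric, plus a 0/1 alert flag, to the supplied logger."""
    for name, value in metrics.items():
        log_metric(name, float(value))
    for name, fired in alerts.items():
        log_metric(f"{name}_alert", 1.0 if fired else 0.0)


# Hypothetical values, as if produced by an earlier drift computation.
metrics = {"psi_feature1": 0.31, "ks_pvalue_feature2": 0.40}
rules = {
    "psi_feature1": ("above", 0.25),        # alert if PSI exceeds 0.25
    "ks_pvalue_feature2": ("below", 0.05),  # alert if p-value drops below 0.05
}
alerts = check_drift(metrics, rules)

# Stand-in logger that records calls in a list instead of contacting MLflow.
logged: List[Tuple[str, float]] = []
log_drift(metrics, alerts, lambda key, value: logged.append((key, value)))
print(alerts)
```

In a real pipeline you would pass mlflow.log_metric as the log_metric argument inside an active mlflow.start_run() block, so each monitoring run adds a point to the metric’s history in the MLflow UI.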

A critical detail often overlooked is that drift detection can be applied not just to input features but also to the model’s predictions and actual outcomes (if available). Monitoring prediction drift can indicate issues even before ground truth is available, while monitoring outcome drift (comparing predicted vs. actual distributions) requires a feedback loop but is the most direct measure of model performance degradation.
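
Prediction drift can be checked the same way as feature drift: score a reference batch and a current batch, then compare the two score distributions. The sketch below assumes a hypothetical prediction_drift helper and uses a toy logistic scoring function as a stand-in for a real model.

```python
import numpy as np
from scipy import stats


def prediction_drift(predict_scores, X_reference, X_current):
    """Compare the model's score distribution on reference vs. current inputs."""
    ref_scores = np.asarray(predict_scores(X_reference), dtype=float)
    cur_scores = np.asarray(predict_scores(X_current), dtype=float)
    ks_stat, p_value = stats.ks_2samp(ref_scores, cur_scores)
    return {
        "ks_statistic": float(ks_stat),
        "ks_pvalue": float(p_value),
        "mean_shift": float(cur_scores.mean() - ref_scores.mean()),
    }


def toy_scores(X):
    # Hypothetical stand-in for a real model's positive-class probability.
    return 1.0 / (1.0 + np.exp(-(2.0 * X[:, 0] - 1.0)))


rng = np.random.default_rng(1)
X_ref = rng.random((1000, 1))        # reference inputs in [0, 1)
X_cur = rng.random((500, 1)) * 1.5   # current inputs drifted to [0, 1.5)

report = prediction_drift(toy_scores, X_ref, X_cur)
print(report)
```

With the scikit-learn classifier trained earlier, you could pass lambda X: model.predict_proba(X)[:, 1] as predict_scores. No ground-truth labels are needed, which is why this check is available before outcomes arrive.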

The next step after setting up drift monitoring is often implementing automated retraining pipelines triggered by significant drift alerts.
