MLflow custom flavors let you save and load any Python object, not just models from supported frameworks like scikit-learn or TensorFlow.
Let’s say you’ve built a custom data preprocessing pipeline using pandas and scikit-learn transformers, and you want to package this whole thing up to be deployed alongside your model. MLflow custom flavors are your ticket.
Here’s a simple example of a custom "preprocessor" flavor.
```python
from typing import Optional

import mlflow
import mlflow.pyfunc
import pandas as pd
from mlflow.models import infer_signature
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler


class PreprocessorFlavor(mlflow.pyfunc.PythonModel):
    """Wraps a fitted scikit-learn Pipeline so MLflow can log it as a pyfunc model."""

    def __init__(self, pipeline: Pipeline):
        self.pipeline = pipeline

    def predict(self, context, model_input, params=None):
        # pyfunc passes a context object (unused here) plus the input DataFrame
        transformed = self.pipeline.transform(model_input)
        return pd.DataFrame(transformed, columns=model_input.columns, index=model_input.index)

    @staticmethod
    def log_model(
        pipeline: Pipeline,
        artifact_path: str,
        input_example: pd.DataFrame,
        registered_model_name: Optional[str] = None,
        **kwargs,
    ):
        """Logs a fitted scikit-learn pipeline as a custom pyfunc model."""
        wrapper = PreprocessorFlavor(pipeline)
        # Infer the schema from real input and the pipeline's actual output
        signature = infer_signature(input_example, wrapper.predict(None, input_example))
        return mlflow.pyfunc.log_model(
            artifact_path=artifact_path,
            python_model=wrapper,
            registered_model_name=registered_model_name,
            input_example=input_example,  # Optional but good practice
            signature=signature,  # Optional but good practice
            pip_requirements=["scikit-learn==1.2.2", "pandas==2.0.3"],  # Crucial for reproducibility
            **kwargs,
        )

    @staticmethod
    def load_model(model_uri: str) -> mlflow.pyfunc.PyFuncModel:
        """Loads the logged pipeline as a generic PyFuncModel (not a PreprocessorFlavor)."""
        return mlflow.pyfunc.load_model(model_uri)


# Example usage:
if __name__ == "__main__":
    # 1. Build and fit your custom object (here, a scikit-learn pipeline)
    data = pd.DataFrame({
        "feature1": [1.0, 2.0, 3.0, 4.0, 5.0],
        "feature2": [5.0, 4.0, 3.0, 2.0, 1.0],
    })
    pipeline = Pipeline([
        ("scaler", StandardScaler()),
    ])
    pipeline.fit(data)  # Fit the pipeline to some dummy data

    # 2. Log the fitted pipeline through the wrapper's helper
    with mlflow.start_run() as run:
        PreprocessorFlavor.log_model(
            pipeline,
            artifact_path="custom_preprocessor",
            input_example=data.head(1),
        )
    run_id = run.info.run_id
    artifact_uri = f"runs:/{run_id}/custom_preprocessor"
    print(f"Model logged at: {artifact_uri}")

    # 3. Load the model back; this returns an mlflow.pyfunc.PyFuncModel
    loaded_preprocessor = PreprocessorFlavor.load_model(artifact_uri)

    # 4. Use it: PyFuncModel.predict takes only the input data
    sample_data = pd.DataFrame({"feature1": [6.0], "feature2": [0.0]})
    transformed_data = loaded_preprocessor.predict(sample_data)
    print("Transformed data:")
    print(transformed_data)
```
The core idea is to wrap your custom object (in this case, a sklearn.pipeline.Pipeline) in a class that subclasses mlflow.pyfunc.PythonModel. The one required method is predict, which MLflow calls as predict(context, model_input), with an optional params argument. MLflow's pyfunc.log_model is the workhorse here: it serializes your PythonModel instance and records the dependencies needed to run it.
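When you call mlflow.pyfunc.log_model, MLflow does two things that matter here:
- It serializes your PythonModel instance with cloudpickle, a pickle extension that can also capture classes defined in your own script.
- It writes the specified pip_requirements to a requirements.txt file (plus conda.yaml and python_env.yaml environment files) inside the model's artifact directory.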
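This combination is what makes it a "flavor." MLflow knows how to load this artifact because it also writes an MLmodel file that declares the python_function flavor and the module responsible for loading it. When you later call mlflow.pyfunc.load_model, MLflow unpickles your PythonModel instance, calls its load_context method, and returns it wrapped in a generic PyFuncModel. Note that load_model runs in your current Python environment: it does not install the recorded dependencies, and at most logs a warning if the installed versions differ. Building an environment from those files is the job of deployment tools such as mlflow models serve.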
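The input_example and signature arguments are not strictly required, but they are highly recommended. The signature records the model's expected input and output schema (column names and types), which pyfunc also uses to validate inputs at prediction time, and the input example stores sample data alongside the model. Both are invaluable for deployment tools and for anyone trying to understand how to call the model, so derive them from real data and the model's actual output rather than placeholder columns. The pip_requirements argument matters for reproducibility: if you omit it, MLflow tries to infer the requirements automatically, which can miss packages or record versions you did not intend. Pinning them explicitly makes the target environment predictable.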
The most surprising truth about custom flavors is that mlflow.pyfunc is itself the generic mechanism for packaging custom objects. You don't need to write a whole new flavor module, with its own save_model, log_model, and load_model functions, for every custom type you want to package; you just subclass PythonModel and use mlflow.pyfunc.log_model. MLflow's extensibility comes from its ability to serialize and deserialize arbitrary Python objects, provided they can be pickled with cloudpickle, expose a predict method, and declare their dependencies.
The PythonModel interface is intentionally simple: a predict method and an optional load_context method. The load_context method is called once when the model is loaded from storage, and its context.artifacts dictionary maps the names you passed through the artifacts argument of log_model to local file paths. That makes it the place to initialize things like database connections or to load auxiliary files that were logged alongside the model.
The key to making this work reliably across different environments is the pip_requirements argument. MLflow records these requirements in requirements.txt, conda.yaml, and python_env.yaml when the model is logged, and deployment tools use those files to build an environment (for example with virtualenv or Conda) when the model is served. Be as specific as possible here, including exact version numbers (e.g., scikit-learn==1.2.2, pandas==2.0.3) that match the versions you actually trained with. This helps ensure that the model runs with the same dependencies it was trained with.
The next step is often to integrate this custom flavor into an MLflow Pipeline for more complex workflows.