The most surprising thing about safetensors is that it’s not just about security; it’s fundamentally a faster, more efficient way to serialize and deserialize tensors, making security a welcome, almost incidental, benefit.

Let’s see it in action. Imagine you’ve trained a small model or fine-tuned an existing one. You want to save its weights.

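from transformers import AutoModelForCausalLM
from safetensors.torch import save_file

# Load a small model for demonstration
model_id = "gpt2"
model = AutoModelForCausalLM.from_pretrained(model_id)

# Get the state dictionary. GPT-2 ties lm_head.weight to the input embedding,
# and save_file refuses tensors that share memory, so give every entry its own
# contiguous copy. (safetensors.torch.save_model handles tied weights for you.)
state_dict = {name: tensor.clone().contiguous() for name, tensor in model.state_dict().items()}

# Save the state dictionary using safetensors
save_file(state_dict, "my_model.safetensors")

print("Model weights saved to my_model.safetensors")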

Now, you want to load these weights back, perhaps into a fresh instance of the same model architecture.

from transformers import AutoConfig, AutoModelForCausalLM
from safetensors.torch import load_file

# Build the architecture only, with randomly initialized weights.
# (from_pretrained would already load the pretrained weights.)
model_id = "gpt2"
config = AutoConfig.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config)

# Load the weights from the safetensors file
weights = load_file("my_model.safetensors")

# Load the weights into the model
model.load_state_dict(weights)

print("Model weights loaded from my_model.safetensors")

This looks pretty standard, right? The magic isn’t in the save_file and load_file calls themselves, but in what they’re doing under the hood.

The problem safetensors solves is the inherent insecurity and inefficiency of Python’s pickle module, which torch.save uses to serialize model weights in PyTorch. pickle can execute arbitrary Python code during deserialization. This means a malicious actor could craft a pickle file that, when loaded, runs any code they want on your machine: data exfiltration, ransomware, or installing backdoors. This is a huge risk when downloading models from untrusted sources.
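To make the pickle risk concrete, here is a minimal sketch using only the standard library. The Payload class is a hypothetical stand-in for whatever an attacker might hide in a checkpoint, and the harmless os.getcwd takes the place of something like os.system.

```python
import os
import pickle


class Payload:
    """Hypothetical stand-in for a malicious object hidden in a checkpoint."""

    def __reduce__(self):
        # pickle records "call this callable with these arguments" and makes
        # that call at load time. os.getcwd is used because it is harmless;
        # an attacker would choose something far less friendly.
        return (os.getcwd, ())


blob = pickle.dumps(Payload())

# Unpickling calls os.getcwd() and hands back its result, not a Payload.
result = pickle.loads(blob)
print(type(result).__name__, result == os.getcwd())
```

The loader never asked to run os.getcwd; the file decided that. This is the behavior a data-only format rules out.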

safetensors, on the other hand, is designed with a simple, explicit format. A file starts with an 8-byte little-endian integer giving the length of the header, followed by a JSON header describing the tensors (their names, shapes, data types, and byte offsets within the data section), followed by the raw binary data of the tensors themselves. There’s no executable code. When you load a safetensors file, the library reads the JSON header, validates the tensor descriptions, and then exposes the binary tensor data at the recorded offsets. This is fast because there is no object graph to reconstruct: the bytes on disk are already in the tensor’s in-memory layout, so the file can be memory-mapped and read with little or no copying.

Here’s the mental model:

  1. Serialization (Saving): You have a Python dictionary where keys are tensor names (strings) and values are torch.Tensor objects. safetensors iterates through this dictionary. For each tensor, it records its metadata (name, dtype, shape, and the start and end byte offsets of its data) and appends its raw bytes to a growing buffer. The device is not recorded; you choose it at load time. Finally, it serializes the metadata into a JSON string and writes it in front of the buffer, preceded by an 8-byte little-endian integer holding the size of the JSON header. There is no magic string.
  2. Deserialization (Loading): The load_file function first reads the fixed 8-byte prefix to get the length of the JSON metadata. It then reads and parses the JSON to learn every tensor’s dtype, shape, and location within the file. Crucially, it then uses memory-mapping (or another efficient read) to expose the raw tensor data directly, and reinterprets those bytes as tensors according to the metadata.
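The two steps above can be sketched with nothing but the standard library. This is a toy that mimics the layout (length prefix, JSON header, raw bytes), not the library’s implementation: write_toy_file, read_toy_file, and the two-entry dtype table are hypothetical names made up for illustration, and details such as the optional metadata entry and alignment padding are left out.

```python
import json
import mmap
import os
import struct
import tempfile

# Only the two dtypes this sketch handles, mapped to struct format codes.
_FORMATS = {"F32": "f", "I64": "q"}


def write_toy_file(path, tensors):
    """tensors maps name -> (dtype string, shape, raw little-endian bytes)."""
    header, chunks, offset = {}, [], 0
    for name, (dtype, shape, raw) in tensors.items():
        header[name] = {
            "dtype": dtype,
            "shape": list(shape),
            "data_offsets": [offset, offset + len(raw)],
        }
        chunks.append(raw)
        offset += len(raw)
    header_bytes = json.dumps(header, separators=(",", ":")).encode("utf-8")
    with open(path, "wb") as f:
        f.write(struct.pack("<Q", len(header_bytes)))  # 8-byte length prefix
        f.write(header_bytes)                          # JSON header
        f.write(b"".join(chunks))                      # raw tensor bytes


def read_toy_file(path):
    """Return name -> (shape, flat list of values), decoded via a memory map."""
    out = {}
    with open(path, "rb") as f, mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
        (header_len,) = struct.unpack_from("<Q", mm, 0)
        header = json.loads(mm[8 : 8 + header_len].decode("utf-8"))
        data_start = 8 + header_len
        for name, info in header.items():
            begin, end = info["data_offsets"]
            fmt = _FORMATS[info["dtype"]]
            count = (end - begin) // struct.calcsize("<" + fmt)
            # Reinterpret the mapped bytes in place; nothing is executed.
            values = struct.unpack_from(f"<{count}{fmt}", mm, data_start + begin)
            out[name] = (info["shape"], list(values))
    return out


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy.safetensors")
    write_toy_file(
        path,
        {
            "weight": ("F32", (2, 2), struct.pack("<4f", 1.0, 2.0, 3.0, 4.0)),
            "step": ("I64", (1,), struct.pack("<q", 7)),
        },
    )
    loaded = read_toy_file(path)

print(loaded)
```

Note that the reader only ever does arithmetic on offsets and reinterprets bytes. There is no step at which file contents could be treated as instructions.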

The key advantage here is that the file format is strictly data-descriptive. There’s no embedded logic. This makes it inherently safe. You can be confident that loading a safetensors file will only give you tensor data, not a surprise Python script. The performance gains come from avoiding the pickle overhead and enabling direct memory mapping, which is especially beneficial for large models.

The one thing most people don’t realize is that safetensors isn’t tied to PyTorch; it works with NumPy arrays and JAX arrays too. The core safetensors library is framework-agnostic; modules such as safetensors.torch, safetensors.numpy, and safetensors.flax (for JAX) provide the integration for specific libraries. This means you can save a PyTorch model’s weights and load them into a JAX model without converting the file itself, as long as you use the appropriate safetensors module for loading. The architectures still have to match, and in practice so do the parameter names and weight layouts, which frameworks do not always agree on.

The next step is understanding how to handle larger models that might not fit into memory all at once, even with safetensors’ efficient loading.
