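PyTorch on Apple Silicon with the Metal Performance Shaders (MPS) backend isn’t just about running PyTorch faster on a Mac. It’s about sidestepping much of the CPU-GPU divide that has shaped machine learning workflows for years, because the GPU shares memory with the CPU instead of sitting across a PCIe bus.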
Let’s see it in action. Imagine you have a simple convolutional neural network and a batch of dummy data. Normally, you’d be carefully choosing between torch.device("cpu") and torch.device("cuda"). With MPS, it’s just one more option:
import torch
import torch.nn as nn
# Check if MPS is available
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using MPS device")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using CUDA device")
else:
    device = torch.device("cpu")
    print("Using CPU device")
# Define a simple model
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc = nn.Linear(16 * 16 * 16, 10)  # Assuming input image size 3x32x32
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = x.view(x.size(0), -1)  # Flatten
        x = self.fc(x)
        return x
model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Dummy data
batch_size = 64
channels = 3
height = 32
width = 32
dummy_input = torch.randn(batch_size, channels, height, width).to(device)
dummy_labels = torch.randint(0, 10, (batch_size,)).to(device)
# Training step
optimizer.zero_grad()
outputs = model(dummy_input)
loss = criterion(outputs, dummy_labels)
loss.backward()
optimizer.step()
print(f"Loss: {loss.item()}")
This code looks almost identical to a standard PyTorch training loop. The key is the .to(device) call. When device is "mps", PyTorch stores your tensors and runs your operations on the Apple Silicon chip’s integrated GPU through Apple’s Metal Performance Shaders (MPS) framework. There is no CUDA toolkit or NVIDIA driver to install and no NVIDIA hardware to buy: the MPS backend ships in the standard PyTorch builds for macOS (macOS 12.3 or later). Your Mac is the accelerator.
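One detail of .to(device) is worth seeing in isolation: nn.Module.to() moves parameters in place, while Tensor.to() returns a tensor you must keep. The sketch below assumes only a standard PyTorch install; it picks MPS when present and otherwise CPU (CUDA is left out for brevity), so it also runs on machines without Apple Silicon.

```python
import torch
import torch.nn as nn

# Use MPS when present, otherwise CPU (CUDA omitted to keep the sketch short).
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

model = nn.Linear(4, 2)
x = torch.randn(3, 4)

# nn.Module.to() moves parameters in place and returns the same module object.
returned = model.to(device)
assert returned is model

# Tensor.to() is not in place: keep the result, or x stays on the CPU.
x_dev = x.to(device)

# Parameters and input now share one device, so the forward pass works;
# feeding a CPU input to MPS parameters would raise an error instead.
param_devices = {p.device for p in model.parameters()}
out = model(x_dev)
print(f"params on {param_devices}, input on {x_dev.device}, output on {out.device}")
```

A common bug is writing x.to(device) on its own line and then passing the original x to the model; reassigning the result avoids it.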
The problem MPS solves is the friction of using integrated GPUs for deep learning. Historically, integrated graphics were an afterthought for serious ML workloads, mainly because they lacked both the mature software ecosystem and the raw power of discrete NVIDIA GPUs. Apple’s M-series chips represent a significant shift. They combine CPU, GPU, and Neural Engine cores on a single system-on-chip (SoC) with unified memory, offering a powerful, energy-efficient platform. MPS is the software bridge that lets PyTorch use the GPU portion of that chip (not the Neural Engine) without the usual hardware barriers.
Internally, PyTorch’s MPS backend maps PyTorch operations onto Apple’s Metal stack, largely through MPSGraph and Metal Performance Shaders kernels. Metal is Apple’s low-level graphics and compute API. When you move a tensor to the MPS device (.to("mps")), PyTorch still allocates a new Metal buffer and copies the data into it, but that buffer lives in the same physical unified memory the CPU uses, so the copy stays within RAM instead of crossing a PCIe bus to separate video memory. This unified memory architecture is a real performance win: it makes the host-device transfers that often bottleneck traditional CPU-GPU setups much cheaper, and it lets the GPU draw on a large share of system RAM. For common deep learning operations such as convolutions, matrix multiplications, and activations, MPS provides optimized kernels that run directly on the Apple Silicon GPU.
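Unified memory does not make .to("mps") a no-op: the result is a tensor backed by its own Metal buffer. The sketch below, assuming only a standard PyTorch install, makes that visible by modifying the device tensor in place and checking whether the CPU original changed. It runs anywhere, but it only shows a separate copy when MPS is actually present.

```python
import torch

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

x_cpu = torch.arange(6, dtype=torch.float32)
x_dev = x_cpu.to(device)  # on MPS: allocate a Metal buffer and copy into it
x_dev.add_(100)           # modify the device tensor in place

# On MPS the CPU tensor is untouched, because .to() made a separate copy even
# though both buffers sit in the same physical memory. On a CPU-only machine,
# .to("cpu") returns the very same tensor, so both names see the change.
independent = not torch.equal(x_cpu, x_dev.cpu())
print(f"device={device}, separate copy: {independent}")

# .numpy() only accepts CPU tensors, so bring results back with .cpu() first;
# .tolist() is used here to avoid a NumPy dependency.
print(x_dev.cpu().tolist())
```

Because the copy is cheap rather than free, it still pays to move data to the device once and keep it there across training steps.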
The primary lever you control is the device object. By checking torch.backends.mps.is_available() and, if it returns True, creating torch.device("mps"), you direct PyTorch to use the Metal backend for every tensor and module you move to that device. Falling through to CUDA and then CPU depending on what the hardware offers keeps the same script portable across machines. The torch.backends.mps module itself is small: besides is_available() it offers is_built(), which reports whether your PyTorch binary was compiled with MPS support. Runtime controls such as torch.mps.synchronize() and torch.mps.empty_cache() live in the separate torch.mps module, and in practice they are tuned far less often than CUDA’s memory-allocation settings.
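The MPS, then CUDA, then CPU fallthrough can be wrapped in a reusable function. Here pick_device is a hypothetical helper, not a PyTorch API, and the sketch assumes PyTorch 2.x for the torch.mps calls. It also uses is_built() to explain why MPS might be missing on a Mac.

```python
import torch

def pick_device() -> torch.device:
    """Hypothetical helper: prefer MPS, then CUDA, then CPU."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.backends.mps.is_built():
        # This PyTorch binary was compiled with MPS support, but the machine
        # can't use it (for example, macOS older than 12.3).
        print("MPS is built into this PyTorch install but not available here")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

device = pick_device()
print(f"Using {device}")

if device.type == "mps":
    # Runtime controls live in torch.mps, not torch.backends.mps.
    torch.mps.synchronize()  # wait for all queued GPU work to finish
    torch.mps.empty_cache()  # release cached GPU memory that is no longer in use
```

Keeping device selection in one function means the rest of a script only ever refers to device, which is what makes it portable.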
One aspect that often trips people up is that not every PyTorch operation is implemented for MPS. The common ones are, but custom layers, less common operators, or certain data types may not be; float64, for example, is not supported on MPS at all. By default, calling an unimplemented operator on an MPS tensor raises a NotImplementedError rather than silently degrading. If you set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1, PyTorch instead runs that specific operator on the CPU and emits a warning. Your model then runs, but every fallback moves data between devices, so performance can suffer noticeably. It’s worth monitoring performance and, if you suspect a bottleneck, checking the PyTorch MPS documentation or GitHub issues for the support status of specific operators.
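Two defensive habits follow from this. First, CPU fallback for unimplemented operators is opt-in through PYTORCH_ENABLE_MPS_FALLBACK=1; the sketch sets it before importing torch, assuming (as is commonly advised) that the flag is read when torch loads. Second, since MPS has no float64 support, double-precision tensors need a downcast before moving. The to_mps_safe function is a hypothetical helper written for this example, not part of PyTorch.

```python
import os

# Opt in to CPU fallback for unimplemented MPS operators. Assumption: the flag
# is read when torch loads, so set it before the import (or export it in your shell).
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import torch

def to_mps_safe(t: torch.Tensor, device: torch.device) -> torch.Tensor:
    """Hypothetical helper: MPS has no float64 support, so downcast
    double-precision tensors to float32 before moving them there."""
    if device.type == "mps" and t.dtype == torch.float64:
        t = t.to(torch.float32)
    return t.to(device)

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Tensors created from NumPy arrays are often float64 by default.
x = torch.randn(4, 4, dtype=torch.float64)
x_dev = to_mps_safe(x, device)
print(x_dev.dtype, x_dev.device)
```

The downcast trades precision for compatibility, which is usually acceptable for training but worth knowing about if your code relies on float64 accuracy.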
The next hurdle is understanding how to profile MPS performance to identify these potential CPU fallbacks.