Spot Instances can cut your ML training costs by roughly 70% (AWS advertises discounts of up to 90% off on-demand prices) because they run on AWS’s spare EC2 capacity, which is sold at a steep discount.
Let’s see this in action. Imagine you’re training a large language model, a job whose on-demand bill can easily run into tens of thousands of dollars.
Here’s a typical on-demand setup for a PyTorch training job on p3.2xlarge instances:
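# On-demand training script (simplified)
import os

import torch
import torch.distributed as dist


def run_training(num_epochs: int = 10):
    # torchrun (or your own launcher) sets RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    # Pin this process to its local GPU before creating the NCCL process group
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    # ... model definition, data loading ...
    for epoch in range(num_epochs):
        # Training loop
        pass
    dist.destroy_process_group()


if __name__ == "__main__":
    run_training()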
To run this, you’d launch EC2 instances from an AMI preconfigured with CUDA, PyTorch, and your training code, and set the environment variables for distributed training (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT). For a cluster of 8 p3.2xlarge instances at approximately $3.06 per hour each (on-demand, us-east-1), a 24-hour training run would cost:
8 instances * $3.06/hour * 24 hours = $587.52
Now, let’s switch to Spot Instances. The same p3.2xlarge instances on Spot might cost around $0.92 per hour (Spot prices vary by Availability Zone and over time):
8 instances * $0.92/hour * 24 hours = $176.64
That’s a saving of $410.88, or roughly 70%, from changing nothing but the purchasing option.
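The arithmetic above is easy to rerun with your own prices. The small sketch below reproduces it. The helper names `training_cost` and `savings_fraction` are hypothetical. The hourly prices are the illustrative figures from the example, not live quotes.

```python
def training_cost(num_instances: int, hourly_price: float, hours: float) -> float:
    """Total cost in USD of running num_instances for the given number of hours."""
    return num_instances * hourly_price * hours


def savings_fraction(on_demand_cost: float, spot_cost: float) -> float:
    """Fraction of the on-demand bill saved by running on Spot instead."""
    return 1.0 - spot_cost / on_demand_cost


# Figures from the example above: 8 instances for 24 hours.
# The Spot price is illustrative; real Spot prices vary by AZ and over time.
on_demand = training_cost(8, 3.06, 24)
spot = training_cost(8, 0.92, 24)
print(f"on-demand ${on_demand:.2f}, spot ${spot:.2f}, saved {savings_fraction(on_demand, spot):.1%}")
# prints: on-demand $587.52, spot $176.64, saved 69.9%
```

This model leaves out the compute you lose to interruptions, meaning work redone since the last checkpoint. Frequent interruptions push the real saving below the headline figure.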
The core problem Spot Instances solve for ML training is the prohibitive cost of GPU compute. Training large models takes many GPU hours, and at on-demand prices that is affordable mainly for well-funded organizations. Spot Instances widen access to powerful hardware by selling capacity that would otherwise sit idle.
Internally, AWS keeps a pool of EC2 capacity, including GPU instances, that isn’t currently in use by on-demand customers. When you request a Spot Instance, you run on this spare capacity and pay the current Spot price for that instance type in that Availability Zone. Since AWS changed its Spot pricing model in 2017, there is no bidding. The Spot price moves gradually with long-term supply and demand, and you pay the price in effect, not your maximum. AWS can reclaim the instance with a two-minute warning whenever it needs the capacity back, or if the Spot price rises above a maximum price you set (by default, the on-demand price).
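The two-minute warning appears as a small JSON document in the instance metadata at the path `spot/instance-action`. That path returns HTTP 404 while no interruption is scheduled. The sketch below shows one way to turn that document into a countdown. The function name is hypothetical, and it assumes the `time` field uses the `YYYY-MM-DDTHH:MM:SSZ` UTC format seen in AWS's examples.

```python
import json
from datetime import datetime, timezone


def seconds_until_interruption(instance_action_json: str, now: datetime) -> float:
    """Seconds left before a scheduled Spot interruption (0 if already past).

    instance_action_json is the body served at the metadata path
    spot/instance-action, e.g. {"action": "terminate", "time": "2017-09-18T08:22:00Z"}.
    Assumes the "time" field is UTC in the format shown above.
    """
    notice = json.loads(instance_action_json)
    deadline = datetime.strptime(notice["time"], "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc)
    return max(0.0, (deadline - now).total_seconds())


notice = '{"action": "terminate", "time": "2017-09-18T08:22:00Z"}'
now = datetime(2017, 9, 18, 8, 20, 0, tzinfo=timezone.utc)
print(seconds_until_interruption(notice, now))  # 120.0
```

Two minutes is enough to write a checkpoint of moderate size to EFS or S3. It is not enough to finish a long epoch, which is why checkpoints should also be written periodically.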
The key levers you control are:
- Instance Type: Choose GPUs that match your model’s needs (e.g., p3.8xlarge, g4dn.12xlarge, p4d.24xlarge).
- Region/Availability Zone: Spot prices, and the amount of spare capacity, vary significantly between AZs.
- Maximum Spot Price: You can optionally set the highest price you’re willing to pay; if the Spot price rises above it, your instance is interrupted. If you leave it unset, it defaults to the on-demand price, which is what AWS recommends and what most ML workloads use. No maximum price protects you from interruptions caused by AWS reclaiming capacity.
- Interruption Handling: This is critical. Your training script must be able to checkpoint its progress frequently and resume from the last checkpoint upon interruption.
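To act on the Region/AZ lever, you can compare recent Spot prices across zones. The sketch below assumes records shaped like the `SpotPriceHistory` entries returned by EC2's DescribeSpotPriceHistory API (via boto3's `describe_spot_price_history` paginator). In those entries, `SpotPrice` is a string. The helper names and the `hours` window are hypothetical. Price is only one signal: a cheap zone with little spare capacity may interrupt you more often.

```python
from datetime import datetime, timedelta, timezone


def latest_price_per_az(records):
    """Most recent Spot price (as float) for each Availability Zone in records."""
    latest = {}
    for r in records:
        az = r["AvailabilityZone"]
        if az not in latest or r["Timestamp"] > latest[az]["Timestamp"]:
            latest[az] = r
    return {az: float(r["SpotPrice"]) for az, r in latest.items()}


def cheapest_az(records):
    """(zone, price) pair with the lowest latest price."""
    prices = latest_price_per_az(records)
    return min(prices.items(), key=lambda item: item[1])


def fetch_spot_history(instance_type, region, hours=6):
    # Requires boto3 and AWS credentials; imported lazily so the helpers above run without it.
    import boto3

    ec2 = boto3.client("ec2", region_name=region)
    paginator = ec2.get_paginator("describe_spot_price_history")
    records = []
    for page in paginator.paginate(
        InstanceTypes=[instance_type],
        ProductDescriptions=["Linux/UNIX"],
        StartTime=datetime.now(timezone.utc) - timedelta(hours=hours),
    ):
        records.extend(page["SpotPriceHistory"])
    return records


# Offline example with made-up prices (not real quotes):
t0 = datetime(2024, 1, 1, tzinfo=timezone.utc)
sample = [
    {"AvailabilityZone": "us-east-1a", "SpotPrice": "1.10", "Timestamp": t0},
    {"AvailabilityZone": "us-east-1a", "SpotPrice": "0.95", "Timestamp": t0 + timedelta(hours=1)},
    {"AvailabilityZone": "us-east-1b", "SpotPrice": "0.99", "Timestamp": t0 + timedelta(hours=1)},
]
print(cheapest_az(sample))  # ('us-east-1a', 0.95)
```

Against a real account you would call `cheapest_az(fetch_spot_history("p3.8xlarge", "us-east-1"))` instead of using the sample list.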
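Here’s how you’d launch Spot Instances with the AWS CLI’s request-spot-instances command. AWS now describes this as a legacy API and recommends run-instances with --instance-market-options for new work.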
# request-spot-instances expects UserData to be base64-encoded.
# setup.sh is a placeholder name for your bootstrap script.
cat > setup.sh <<'EOF'
#!/bin/bash
apt update && apt install -y python3-pip
pip3 install torch torchvision torchaudio
# ... rest of your setup script ...
EOF
USER_DATA=$(base64 -w0 setup.sh)  # GNU coreutils; on macOS use: base64 -i setup.sh

aws ec2 request-spot-instances \
  --instance-count 8 \
  --type "one-time" \
  --launch-specification '{
    "ImageId": "ami-0abcdef1234567890",
    "InstanceType": "p3.2xlarge",
    "KeyName": "my-key-pair",
    "Placement": {
      "AvailabilityZone": "us-east-1a"
    },
    "BlockDeviceMappings": [
      {
        "DeviceName": "/dev/sda1",
        "Ebs": {
          "VolumeSize": 200,
          "VolumeType": "gp3",
          "DeleteOnTermination": true
        }
      }
    ],
    "NetworkInterfaces": [
      {
        "DeviceIndex": 0,
        "SubnetId": "subnet-0123456789abcdef0",
        "Groups": ["sg-0fedcba9876543210"],
        "AssociatePublicIpAddress": true
      }
    ],
    "UserData": "'"$USER_DATA"'"
  }' \
  --tag-specifications 'ResourceType=spot-instances-request,Tags=[{Key=Name,Value=ML-Training-Spot}]' \
  --region us-east-1
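The RunInstances route that AWS recommends can also be driven from Python. The sketch below builds the keyword arguments for boto3's `run_instances` with `InstanceMarketOptions` set to one-time Spot. It mirrors the CLI launch specification. The helper names `spot_launch_params` and `launch` are hypothetical, and every ID is a placeholder. Unlike the legacy call, boto3 base64-encodes `UserData` for RunInstances itself.

```python
def spot_launch_params(image_id, instance_type, count, key_name,
                       subnet_id, security_group_ids, user_data):
    """Keyword arguments for EC2 RunInstances requesting one-time Spot capacity."""
    return {
        "ImageId": image_id,
        "InstanceType": instance_type,
        # MinCount == MaxCount: launch all instances or none.
        "MinCount": count,
        "MaxCount": count,
        "KeyName": key_name,
        "InstanceMarketOptions": {
            "MarketType": "spot",
            "SpotOptions": {
                "SpotInstanceType": "one-time",
                "InstanceInterruptionBehavior": "terminate",
            },
        },
        "NetworkInterfaces": [{
            "DeviceIndex": 0,
            "SubnetId": subnet_id,
            "Groups": security_group_ids,
            "AssociatePublicIpAddress": True,
        }],
        "BlockDeviceMappings": [{
            "DeviceName": "/dev/sda1",
            "Ebs": {"VolumeSize": 200, "VolumeType": "gp3", "DeleteOnTermination": True},
        }],
        # Plain text is fine here: boto3 base64-encodes UserData for RunInstances.
        "UserData": user_data,
        "TagSpecifications": [{
            "ResourceType": "instance",
            "Tags": [{"Key": "Name", "Value": "ML-Training-Spot"}],
        }],
    }


def launch(params, region="us-east-1"):
    import boto3  # needs AWS credentials; imported lazily so the builder runs offline

    ec2 = boto3.client("ec2", region_name=region)
    return ec2.run_instances(**params)


params = spot_launch_params(
    "ami-0abcdef1234567890", "p3.2xlarge", 8, "my-key-pair",
    "subnet-0123456789abcdef0", ["sg-0fedcba9876543210"],
    "#!/bin/bash\napt update && apt install -y python3-pip\n",
)
```

Calling `launch(params)` submits the request. If EC2 cannot find capacity for all eight instances at once, the call fails rather than giving you a partial cluster, which suits tightly coupled distributed training.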
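The UserData script would download your training code, install dependencies, and launch your distributed training job. Crucially, your training script must react to an interruption by saving its state to persistent storage (like S3 or EFS). AWS does not signal your process directly. Instead, the two-minute notice appears in the instance metadata (at spot/instance-action) and as an EventBridge event. When the instance is finally shut down, the operating system sends SIGTERM to running processes. A SIGTERM handler like the one below is therefore a useful last line of defense, but it gets very little time to run.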
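# Checkpointing in PyTorch
import os
import signal
import sys

import torch


def save_checkpoint(model, optimizer, epoch, filepath):
    # Write to a temporary file, then rename atomically, so an interruption
    # mid-write never leaves a corrupt checkpoint behind.
    # (In distributed training, usually only rank 0 saves.)
    tmp_path = filepath + ".tmp"
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, tmp_path)
    os.replace(tmp_path, filepath)
    print(f"Checkpoint saved to {filepath}")


def load_checkpoint(model, optimizer, filepath):
    """Restore state and return the epoch to resume from (0 if no checkpoint exists)."""
    if os.path.exists(filepath):
        checkpoint = torch.load(filepath, map_location="cpu")
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        print(f"Checkpoint loaded from {filepath}, resuming at epoch {epoch + 1}")
        return epoch + 1  # the saved epoch had already finished
    return 0


# Inside your training code (model, optimizer and num_epochs defined elsewhere):
checkpoint_path = "/mnt/efs/model_checkpoint.pth"  # Example path on shared storage
start_epoch = load_checkpoint(model, optimizer, checkpoint_path)
last_completed_epoch = start_epoch - 1


# Handle SIGTERM (sent by the OS at shutdown) for a final save.
# Register it before training starts so it is active for the whole run.
def handler(signum, frame):
    print("Termination signal received. Saving final checkpoint...")
    # Record the last *finished* epoch; the interrupted one is rerun on resume.
    save_checkpoint(model, optimizer, last_completed_epoch, checkpoint_path)
    sys.exit(0)


signal.signal(signal.SIGTERM, handler)

for epoch in range(start_epoch, num_epochs):
    # ... training ...
    last_completed_epoch = epoch
    if epoch % 5 == 0:  # Save every 5 epochs
        save_checkpoint(model, optimizer, epoch, checkpoint_path)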
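Because the real warning arrives through instance metadata rather than as a signal, a background poller gives the training loop nearly the full two minutes to react. The sketch below is one way to do it with only the standard library. It uses IMDSv2: PUT `/latest/api/token` for a session token, then GET `/latest/meta-data/spot/instance-action` with that token, which returns 404 until an interruption is scheduled. The function names and the 5-second interval are assumptions, not an AWS-provided API.

```python
import threading
import urllib.error
import urllib.request

IMDS = "http://169.254.169.254/latest"


def _imds_token(timeout=2):
    # IMDSv2 session token (TTL in seconds, maximum 21600).
    req = urllib.request.Request(
        f"{IMDS}/api/token",
        method="PUT",
        headers={"X-aws-ec2-metadata-token-ttl-seconds": "21600"},
    )
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return resp.read().decode()


def interruption_scheduled(timeout=2):
    """True once EC2 has published a Spot interruption notice for this instance."""
    req = urllib.request.Request(
        f"{IMDS}/meta-data/spot/instance-action",
        headers={"X-aws-ec2-metadata-token": _imds_token(timeout)},
    )
    try:
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            return resp.status == 200
    except urllib.error.HTTPError as err:
        if err.code == 404:  # no interruption scheduled yet
            return False
        raise


def start_interruption_watcher(stop_event, check=interruption_scheduled, interval=5.0):
    """Poll `check` in a daemon thread and set stop_event once a notice appears."""
    def loop():
        while not stop_event.is_set():
            try:
                if check():
                    stop_event.set()
                    return
            except OSError:
                pass  # metadata service briefly unreachable; retry on the next tick
            stop_event.wait(interval)

    thread = threading.Thread(target=loop, daemon=True)
    thread.start()
    return thread
```

In the training loop, create `interrupted = threading.Event()` and call `start_interruption_watcher(interrupted)` before the first epoch. Then check `interrupted.is_set()` between steps. When it fires, call `save_checkpoint` and exit cleanly, well before the OS-level SIGTERM arrives.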
The most surprising benefit of using Spot Instances for ML training, beyond the cost savings, is the ability to scale up your compute resources dramatically. Because the cost per hour is so low, you can afford to provision a much larger cluster of GPUs than you could with on-demand pricing. This allows for faster experimentation and the training of larger, more complex models that would otherwise be out of reach.
The immediate next challenge you’ll face is ensuring your training job is robust enough to handle frequent interruptions, which means mastering fault tolerance and distributed checkpointing.