Fine-tuning a Named Entity Recognition (NER) model isn’t about teaching it new words; it’s about teaching it a new context for recognizing existing entities.
Let’s see this in action. Imagine we have a dataset of news articles and we want to extract PERSON, ORG, and LOCATION entities. We’ll start with a pre-trained model like bert-base-uncased and fine-tune it on our specific data.
Here’s a snippet of what our training data might look like, using the IOB2 format:
```json
[
  {"tokens": ["John", "Doe", "works", "at", "Google", "in", "New", "York", "."], "ner_tags": [1, 2, 0, 0, 3, 0, 5, 6, 0]},
  {"tokens": ["The", "United", "Nations", "held", "a", "meeting", "in", "Geneva", "."], "ner_tags": [0, 3, 4, 0, 0, 0, 0, 5, 0]}
]
```
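In this format, each integer in `ner_tags` is an index into the `label_list` defined in the code below:
- 0 (O): the token is outside any entity.
- 1 (B-PER) and 2 (I-PER): the beginning of a person's name and any tokens that continue it.
- 3 (B-ORG) and 4 (I-ORG): the same pair for organizations.
- 5 (B-LOC) and 6 (I-LOC): the same pair for locations.

In IOB2, every entity begins with a B- tag. A two-word entity like "New York" is therefore tagged B-LOC I-LOC (5, 6), and "United Nations" is tagged B-ORG I-ORG (3, 4).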
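The following dependency-free sketch makes the scheme concrete by turning one example's integer tags back into labeled entity spans. The helper name `extract_entities` is hypothetical. It assumes the same label order as `label_list` in the training code. It also treats a stray I- tag that does not continue the current span as the start of a new entity, which is a lenient choice and only one of several possible.

```python
from typing import List, Optional, Tuple

# Same order as label_list in the training code: index -> tag string
LABEL_LIST = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]


def extract_entities(tokens: List[str], tag_ids: List[int]) -> List[Tuple[str, str]]:
    """Decode IOB2 tag ids into (entity_type, text) spans (hypothetical helper)."""
    entities: List[Tuple[str, str]] = []
    current_type: Optional[str] = None
    current_tokens: List[str] = []

    def flush() -> None:
        # Close the span being built, if there is one
        nonlocal current_type, current_tokens
        if current_type is not None:
            entities.append((current_type, " ".join(current_tokens)))
        current_type, current_tokens = None, []

    for token, tag_id in zip(tokens, tag_ids):
        tag = LABEL_LIST[tag_id]
        if tag == "O":
            flush()
            continue
        prefix, entity_type = tag.split("-", 1)
        if prefix == "B" or entity_type != current_type:
            # A B- tag always starts a new entity; so does an I- tag whose
            # type does not continue the current span (lenient handling)
            flush()
            current_type = entity_type
        current_tokens.append(token)
    flush()
    return entities


tokens = ["John", "Doe", "works", "at", "Google", "in", "New", "York", "."]
print(extract_entities(tokens, [1, 2, 0, 0, 3, 0, 5, 6, 0]))
# [('PER', 'John Doe'), ('ORG', 'Google'), ('LOC', 'New York')]
```

The seqeval metric used in the training code below scores predictions at this span level. A predicted entity counts as correct only when both its type and its boundaries match the reference.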
We’ll use the Hugging Face `datasets` library to build and preprocess this data, `transformers` for the model and training loop, and `evaluate` (which relies on the `seqeval` package) for entity-level metrics.
```python
# Requires: pip install transformers datasets evaluate seqeval torch
import numpy as np
import torch
import evaluate
from datasets import Dataset, DatasetDict, load_dataset
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    Trainer,
    TrainingArguments,
)

# Load dataset (replace with your actual data loading).
# In a real scenario, you'd load from JSON, CSV, etc., for example:
# dataset_dict = load_dataset("json", data_files={"train": "train.json", "validation": "validation.json"})
# Dummy data for demonstration; tag ids index into label_list below
data = {
    "train": [
        {"tokens": ["John", "Doe", "works", "at", "Google", "in", "New", "York", "."], "ner_tags": [1, 2, 0, 0, 3, 0, 5, 6, 0]},
        {"tokens": ["The", "United", "Nations", "held", "a", "meeting", "in", "Geneva", "."], "ner_tags": [0, 3, 4, 0, 0, 0, 0, 5, 0]},
        {"tokens": ["Alice", "Smith", "visited", "Paris", "last", "week", "."], "ner_tags": [1, 2, 0, 5, 0, 0, 0]},
    ],
    "validation": [
        {"tokens": ["Bob", "Johnson", "is", "CEO", "of", "Microsoft", "."], "ner_tags": [1, 2, 0, 0, 0, 3, 0]},
    ],
}
dataset_dict = DatasetDict({split: Dataset.from_list(rows) for split, rows in data.items()})

# Define entity labels and their mapping
label_list = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]
label_to_id = {label: i for i, label in enumerate(label_list)}
id_to_label = {i: label for i, label in enumerate(label_list)}

# Load tokenizer and model; id2label/label2id are stored in the model config
model_checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForTokenClassification.from_pretrained(
    model_checkpoint,
    num_labels=len(label_list),
    id2label=id_to_label,
    label2id=label_to_id,
)

# Preprocessing: tokenize pre-split words and align word-level labels to sub-word tokens
def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)
    labels = []
    for i, word_labels in enumerate(examples["ner_tags"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            if word_idx is None:
                label_ids.append(-100)  # special tokens ([CLS], [SEP]); -100 is ignored by the loss
            elif word_idx != previous_word_idx:
                label_ids.append(word_labels[word_idx])  # first sub-word carries the word's label
            else:
                # Later sub-words of the same word are ignored. To label them instead,
                # reuse the word's label but turn B-X into I-X to keep valid IOB2.
                label_ids.append(-100)
            previous_word_idx = word_idx
        labels.append(label_ids)
    tokenized_inputs["labels"] = labels
    return tokenized_inputs

# Apply preprocessing
tokenized_datasets = dataset_dict.map(tokenize_and_align_labels, batched=True)

# Pads input_ids, attention_mask and labels per batch (labels are padded with -100)
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

# Training arguments
training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",  # named evaluation_strategy in transformers < 4.41
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    save_strategy="epoch",  # must match eval_strategy when load_best_model_at_end=True
    load_best_model_at_end=True,
    metric_for_best_model="f1",  # pick the checkpoint with the best entity-level F1
)

# Metrics computation (entity-level, via seqeval)
metric = evaluate.load("seqeval")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=2)
    # Drop positions labeled -100 (special tokens, later sub-words, padding)
    true_predictions = [
        [label_list[p] for p, l in zip(pred, lab) if l != -100]
        for pred, lab in zip(predictions, labels)
    ]
    true_labels = [
        [label_list[l] for p, l in zip(pred, lab) if l != -100]
        for pred, lab in zip(predictions, labels)
    ]
    results = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }

# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# Train the model
trainer.train()

# Example prediction
text = "Hugging Face is a company based in New York."
inputs = tokenizer(text, return_tensors="pt").to(model.device)  # the model may be on a GPU after training
model.eval()  # disable dropout for inference
with torch.no_grad():
    outputs = model(**inputs)
predictions = torch.argmax(outputs.logits, dim=2)
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
predicted_labels = [id_to_label[p.item()] for p in predictions[0]]
print(list(zip(tokens, predicted_labels)))
```
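The prediction above prints one tag per sub-word token, including [CLS] and [SEP]. A common way to get one tag per word is to keep only the prediction for each word's first sub-word, which mirrors how the labels were aligned during training. The sketch below assumes you already have the `word_ids` list (fast tokenizers return it from `inputs.word_ids()`) and the predicted tag strings. The function name `word_level_tags` and the sample values are illustrative, not real model output.

```python
from typing import List, Optional, Tuple


def word_level_tags(
    tokens: List[str],
    word_ids: List[Optional[int]],
    token_tags: List[str],
) -> List[Tuple[str, str]]:
    """Collapse sub-word predictions to one (word, tag) pair per word.

    Keeps the tag predicted for each word's first sub-word and skips special
    tokens (word id None). Word text is rebuilt by stripping the WordPiece
    "##" continuation prefix, which assumes a BERT-style tokenizer.
    """
    words: List[Tuple[str, str]] = []
    previous: Optional[int] = None
    for token, word_id, tag in zip(tokens, word_ids, token_tags):
        if word_id is None:
            continue  # [CLS], [SEP] and padding belong to no word
        if word_id != previous:
            words.append((token, tag))  # first sub-word: its tag becomes the word's tag
        else:
            piece = token[2:] if token.startswith("##") else token
            text, first_tag = words[-1]
            words[-1] = (text + piece, first_tag)  # later sub-words only extend the text
        previous = word_id
    return words


# Illustrative values only (not real model output)
tokens = ["[CLS]", "hugging", "##face", "moved", "to", "new", "york", "[SEP]"]
word_ids = [None, 0, 0, 1, 2, 3, 4, None]
token_tags = ["O", "B-ORG", "I-ORG", "O", "O", "B-LOC", "I-LOC", "O"]
print(word_level_tags(tokens, word_ids, token_tags))
# [('huggingface', 'B-ORG'), ('moved', 'O'), ('to', 'O'), ('new', 'B-LOC'), ('york', 'I-LOC')]
```

With the real example, you would pass `tokens`, `inputs.word_ids()`, and `predicted_labels` from the code above.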
The core problem this solves is that general-purpose language models, while understanding grammar and semantics, don’t inherently know what constitutes a "company name" versus a "city name" without specific examples. Fine-tuning adapts the model’s internal representations to recognize patterns and contexts that signify these specific entities in your domain. It’s like taking a brilliant student who knows all about history and giving them a specialized course on, say, 19th-century French literature – they learn to spot the nuances and references specific to that field.
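Internally, fine-tuning updates two things: the pre-trained encoder and a small classification layer that `AutoModelForTokenClassification` adds on top of it. That layer starts out randomly initialized, which is why Transformers warns that some weights are newly initialized when you load the model. It maps each token's final hidden state to one score per label. Take the word "Apple". A pre-trained model already produces different contextual vectors for "Apple" in "apple pie" and in "Apple announced record sales", but nothing yet maps those vectors to B-ORG rather than O. After fine-tuning on text where "Apple Inc." or "Apple announced…" appears, the classifier learns which features of the hidden state signal an organization. The encoder's weights also shift, so that cues from neighboring words tend to feed more strongly into the hidden state of "Apple" itself. These cues include following words like "Inc.", "announced", and "shares", or a preceding "the tech giant". The [CLS] token plays no direct role in the tag predictions. The alignment code above gives it the label -100, so it is excluded from the loss, and any change to its embedding is only a side effect of updating the shared encoder.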
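The classification layer described above is small. Here is a rough, framework-free sketch, assuming a hidden size of 768 as in bert-base and using random weights as stand-ins for trained ones. Token classification reduces to one linear projection applied independently at every position. The name `classify_tokens` is hypothetical. In Transformers, the corresponding layer for BERT is the model's `classifier` attribute, applied after dropout.

```python
import numpy as np

HIDDEN_SIZE = 768  # bert-base hidden size
NUM_LABELS = 7     # O, B-PER, I-PER, B-ORG, I-ORG, B-LOC, I-LOC

rng = np.random.default_rng(0)
# Random stand-ins: in a real model, W and b are learned during fine-tuning
W = rng.normal(scale=0.02, size=(HIDDEN_SIZE, NUM_LABELS))
b = np.zeros(NUM_LABELS)


def classify_tokens(hidden_states: np.ndarray) -> np.ndarray:
    """Map contextual hidden states (batch, seq_len, hidden) to label ids (batch, seq_len)."""
    logits = hidden_states @ W + b   # (batch, seq_len, num_labels): one score per label per token
    return logits.argmax(axis=-1)    # highest-scoring label at each position


# Pretend encoder output for 2 sentences of 5 tokens each
hidden = rng.normal(size=(2, 5, HIDDEN_SIZE))
print(classify_tokens(hidden).shape)  # (2, 5)
```

Because `W` and `b` are shared across positions, whatever distinguishes "Apple" the company from "apple" the fruit must already be present in the hidden state. This is why fine-tuning updates the encoder as well, rather than only this layer.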
The key levers you control are:
- Dataset Quality and Size: More diverse and accurately labeled data leads to better generalization. Annotating at least a few hundred examples per entity type is a good starting point.
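- Pre-trained Model Choice: Models like RoBERTa or ELECTRA might offer better performance than BERT for certain tasks, depending on their pre-training objectives. Casing matters as well. Capitalization is a strong cue for names, so a cased checkpoint such as bert-base-cased is often a better fit for NER than the uncased model used in this example. If you switch to RoBERTa, load its tokenizer with add_prefix_space=True, which it requires for pre-split input (is_split_into_words=True).
- Hyperparameters: Learning rate, batch size, and the number of epochs significantly affect convergence. Too high a learning rate can cause divergence, while too low a rate leads to slow training. For BERT-sized models, learning rates between 2e-5 and 5e-5 and 2 to 4 epochs (the ranges searched in the original BERT paper) are a reasonable place to start.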
- Labeling Scheme: Using a consistent and well-defined scheme (like IOB2) is crucial. Inconsistent labels are a major source of poor performance.
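- Tokenization Strategy: How you handle sub-word tokens (e.g., "HuggingFace" might be tokenized into "hugging" and "##face") is important. The `tokenize_and_align_labels` function above uses the tokenizer's `word_ids()` to map each sub-word token back to its source word. The first sub-word of each word receives the word's label, and special tokens such as [CLS] and [SEP] receive -100 so the loss ignores them. Later sub-words of the same word can either be ignored as well, which keeps training and evaluation at the word level, or be given the word's label. In the second case, a B- tag should become the matching I- tag.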
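You can compare the two sub-word strategies without loading a tokenizer. This sketch takes a `word_ids` list of the kind a fast tokenizer returns and either masks later sub-words with -100 or labels them too. When labeling them, it converts a B- label to the matching I- label so the sequence stays valid IOB2. The function name `align_labels` and the `label_all_tokens` flag are illustrative, not a library API, and the `word_ids` values are made up rather than produced by a real tokenizer.

```python
from typing import List, Optional

LABEL_LIST = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]
# For each B- id, the id of the matching I- tag (e.g. B-PER -> I-PER)
B_TO_I = {
    i: LABEL_LIST.index("I-" + tag[2:])
    for i, tag in enumerate(LABEL_LIST)
    if tag.startswith("B-")
}


def align_labels(
    word_ids: List[Optional[int]],
    word_labels: List[int],
    label_all_tokens: bool = False,
) -> List[int]:
    """Expand word-level label ids to sub-word positions (hypothetical helper)."""
    aligned: List[int] = []
    previous: Optional[int] = None
    for word_id in word_ids:
        if word_id is None:
            aligned.append(-100)  # special token: ignored by the loss
        elif word_id != previous:
            aligned.append(word_labels[word_id])  # first sub-word gets the word's label
        elif label_all_tokens:
            label = word_labels[word_id]
            aligned.append(B_TO_I.get(label, label))  # continuation: B-X becomes I-X
        else:
            aligned.append(-100)  # continuation sub-word: ignored
        previous = word_id
    return aligned


# Words ["HuggingFace", "is", "great"] -> sub-words [CLS] hugging ##face is great [SEP]
word_ids = [None, 0, 0, 1, 2, None]
word_labels = [3, 0, 0]  # HuggingFace = B-ORG
print(align_labels(word_ids, word_labels))                         # [-100, 3, -100, 0, 0, -100]
print(align_labels(word_ids, word_labels, label_all_tokens=True))  # [-100, 3, 4, 0, 0, -100]
```

Ignoring continuation pieces means each word contributes exactly one prediction to the loss and the metrics. Labeling them gives the model more training signal but weights long, heavily split words more.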
A common pitfall is assuming that simply increasing the dataset size will always improve performance. If the added data is noisy or introduces conflicting labels, it can actually degrade the model’s accuracy. The model is learning to associate specific contextual patterns with entity tags; if those patterns are inconsistent in the training data, the model becomes confused. For instance, if "New York" is sometimes labeled LOC and sometimes ORG (perhaps in a fictional context), the model will struggle to make a definitive prediction.
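One way to catch this kind of inconsistency before training is to count how each entity string is labeled across the dataset and flag strings that appear with more than one type. The sketch below is a minimal, hypothetical check over data in the `tokens`/`ner_tags` format used earlier, and the sample annotations are made up. It only catches exact, case-sensitive string matches. Some flagged strings may also be legitimately ambiguous, such as a city versus a sports team, so treat the output as a list to review rather than a set of errors to fix automatically.

```python
from collections import defaultdict
from typing import Dict, Iterator, List, Set, Tuple

# Same order as label_list in the training code
LABEL_LIST = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]


def iob2_spans(tokens: List[str], tag_ids: List[int]) -> Iterator[Tuple[str, str]]:
    """Yield (entity_type, text) per span; an I- tag that doesn't continue a span starts one."""
    current_type, current_tokens = None, []
    # A trailing sentinel "O" (id 0) closes the final span
    for token, tag_id in zip(tokens + [""], tag_ids + [0]):
        prefix, _, entity_type = LABEL_LIST[tag_id].partition("-")
        continues = prefix == "I" and entity_type == current_type
        if current_type is not None and not continues:
            yield current_type, " ".join(current_tokens)
            current_type, current_tokens = None, []
        if prefix == "O":
            continue
        if not continues:
            current_type = entity_type
        current_tokens.append(token)


def find_label_conflicts(examples: List[dict]) -> Dict[str, List[str]]:
    """Map each entity string labeled with more than one type to its sorted types."""
    seen: Dict[str, Set[str]] = defaultdict(set)
    for example in examples:
        for entity_type, text in iob2_spans(example["tokens"], example["ner_tags"]):
            seen[text].add(entity_type)
    return {text: sorted(types) for text, types in seen.items() if len(types) > 1}


# Hypothetical annotations: "New York" is a place in one sentence and a team (ORG) in another
examples = [
    {"tokens": ["He", "moved", "to", "New", "York", "."], "ner_tags": [0, 0, 0, 5, 6, 0]},
    {"tokens": ["New", "York", "won", "the", "game", "."], "ner_tags": [3, 4, 0, 0, 0, 0]},
    {"tokens": ["Alice", "visited", "Paris", "."], "ner_tags": [1, 0, 5, 0]},
]
print(find_label_conflicts(examples))  # {'New York': ['LOC', 'ORG']}
```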
The next step after fine-tuning is often deploying the model for inference, which involves understanding how to batch new text for efficient processing and how to interpret the model’s output probabilities.
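As a small preview of the probability side, here is a NumPy sketch (no model needed) that turns a batch of token-classification logits into per-token predictions and confidence scores while skipping padded positions. The logits, the mask, and the function name `token_confidences` are made up for illustration. With a real model, you would take `outputs.logits` and the `attention_mask` the tokenizer returns when called with `padding=True` on a list of texts.

```python
import numpy as np


def softmax(logits: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the max for numerical stability before exponentiating
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def token_confidences(logits: np.ndarray, attention_mask: np.ndarray):
    """Return (label_ids, confidences) per sequence, skipping padded positions.

    logits: (batch, seq_len, num_labels); attention_mask: (batch, seq_len) of 0/1.
    """
    probs = softmax(logits)            # probability over labels at each position
    label_ids = probs.argmax(axis=-1)  # most likely label per token
    confidence = probs.max(axis=-1)    # probability assigned to that label
    results = []
    for ids, conf, mask in zip(label_ids, confidence, attention_mask):
        keep = mask.astype(bool)       # drop padding positions
        results.append((ids[keep].tolist(), conf[keep].round(3).tolist()))
    return results


# Two "sentences" padded to length 3, with 3 labels; the second has one pad position
logits = np.array([
    [[2.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 1.0]],
    [[0.0, 0.0, 0.0], [5.0, 0.0, 0.0], [9.0, 9.0, 9.0]],
])
mask = np.array([[1, 1, 1], [1, 1, 0]])
for ids, conf in token_confidences(logits, mask):
    print(ids, conf)
```

A softmax score reflects the model's relative preference among labels and is not guaranteed to be a calibrated probability of being correct. Check any confidence threshold you build on it against held-out data.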