The example uses two parallel encoders: one encodes the query and the other encodes the document(s).
## Training Setup with Contrastive Loss
Both encoders use the final hidden state at the [CLS] token as the embedding.
```python
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModel
class BiEncoder(nn.Module):
    def __init__(self, model_name="answerdotai/ModernBERT-base"):
        super().__init__()
        # Two separately fine-tuned encoders that share one tokenizer
        self.query_encoder = AutoModel.from_pretrained(model_name)
        self.doc_encoder = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def encode_query(self, queries):
        inputs = self.tokenizer(queries, padding=True, truncation=True, return_tensors="pt")
        outputs = self.query_encoder(**inputs)
        return outputs.last_hidden_state[:, 0]  # [CLS] token embedding

    def encode_doc(self, docs):
        inputs = self.tokenizer(docs, padding=True, truncation=True, return_tensors="pt")
        outputs = self.doc_encoder(**inputs)
        return outputs.last_hidden_state[:, 0]  # [CLS] token embedding

    def forward(self, queries, docs):
        query_emb = self.encode_query(queries)
        doc_emb = self.encode_doc(docs)
        return query_emb, doc_emb
# In-batch contrastive (InfoNCE) loss: the i-th query's positive is the i-th document,
# and every other document in the batch serves as a negative.
def contrastive_loss(query_emb, doc_emb, temperature=0.05):
    # Pairwise cosine similarities, shape (batch_size, batch_size)
    scores = nn.functional.cosine_similarity(query_emb.unsqueeze(1), doc_emb.unsqueeze(0), dim=2)
    scores = scores / temperature  # temperature sharpens the softmax over candidates
    batch_size = query_emb.size(0)
    labels = torch.arange(batch_size)  # diagonal entries are the positive pairs
    loss_fn = nn.CrossEntropyLoss()
    loss = loss_fn(scores, labels)
    return loss
# Example training loop on a toy batch of query/document pairs
model = BiEncoder()
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

queries = ["What is AI?", "Define machine learning"]
docs = ["Artificial intelligence is a branch of computer science.", "Machine learning is a subset of AI."]

for epoch in range(3):
    model.train()
    optimizer.zero_grad()
    query_emb, doc_emb = model(queries, docs)
    loss = contrastive_loss(query_emb, doc_emb)
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")
# Save the tokenizer
model.tokenizer.save_pretrained('modernbert_encoder_tokenizer')
# Save query encoder
model.query_encoder.save_pretrained('modernbert_query_encoder')
# Save doc encoder
model.doc_encoder.save_pretrained('modernbert_doc_encoder')
```
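To make the in-batch negatives concrete, here is a minimal sketch (reusing the `model`, `queries`, and `docs` from the training example above) that prints the full query-document similarity matrix: the diagonal entries are the positive pairs the loss pushes up, and the off-diagonal entries are the negatives it pushes down.

```python
import torch

# Sketch: inspect the score matrix that contrastive_loss trains on.
# Assumes the BiEncoder model and toy queries/docs defined above.
model.eval()
with torch.no_grad():
    q_emb, d_emb = model(queries, docs)
    scores = torch.nn.functional.cosine_similarity(
        q_emb.unsqueeze(1), d_emb.unsqueeze(0), dim=2
    )

print(scores)                      # shape (2, 2): rows = queries, columns = documents
print(torch.arange(len(queries)))  # target labels: each query should match its own document
```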
## Inference Setup with Cosine Similarity
```python
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
# Load the saved encoders and tokenizer; put the encoders in eval mode for inference
query_encoder = AutoModel.from_pretrained('modernbert_query_encoder').eval()
doc_encoder = AutoModel.from_pretrained('modernbert_doc_encoder').eval()
tokenizer = AutoTokenizer.from_pretrained('modernbert_encoder_tokenizer')
@torch.no_grad()  # no gradients needed at inference time
def encode_query(texts):
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    outputs = query_encoder(**inputs)
    return outputs.last_hidden_state[:, 0]  # [CLS] token embedding

@torch.no_grad()
def encode_doc(texts):
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    outputs = doc_encoder(**inputs)
    return outputs.last_hidden_state[:, 0]  # [CLS] token embedding
# Run inference
query_texts = ["What is AI?"]
doc_texts = ["Artificial intelligence is a branch of computer science."]
query_emb = encode_query(query_texts)
doc_emb = encode_doc(doc_texts)
# Compute similarity (cosine)
similarity = F.cosine_similarity(query_emb, doc_emb)
print("Similarity:", similarity.item())
```
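In practice a query is usually scored against many candidate documents rather than a single one. The sketch below reuses the `encode_query`/`encode_doc` helpers from the inference setup; the candidate list is a made-up example, and the ranking is just a sort over the per-document cosine similarities.

```python
# Minimal ranking sketch, assuming encode_query / encode_doc and F from above.
# The candidate documents are illustrative placeholders.
candidates = [
    "Artificial intelligence is a branch of computer science.",
    "Machine learning is a subset of AI.",
    "Bread is made from flour, water, and yeast.",
]

query_emb = encode_query(["What is AI?"])  # shape (1, hidden_size)
cand_emb = encode_doc(candidates)          # shape (3, hidden_size)

# Cosine similarity between the single query and every candidate (broadcast over rows)
sims = F.cosine_similarity(query_emb, cand_emb)  # shape (3,)

# Print candidates from most to least similar
for score, doc in sorted(zip(sims.tolist(), candidates), reverse=True):
    print(f"{score:.3f}  {doc}")
```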