The example uses two separate encoders: one encodes the query and the other encodes the document(s). Both are initialized from the same pretrained checkpoint but keep separate weights.

## Training Setup with Contrastive Loss

Each encoder uses the final hidden state at the `[CLS]` token (position 0) as the text embedding. Training uses in-batch negatives: the *i*-th document is the positive for the *i*-th query, and every other document in the batch acts as a negative for it.

```python
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModel


class BiEncoder(nn.Module):
    """Dual encoder with separate query and document towers and one shared tokenizer."""

    def __init__(self, model_name="answerdotai/ModernBERT-base"):
        super().__init__()
        # Both towers start from the same checkpoint but are separate modules with their own weights.
        self.query_encoder = AutoModel.from_pretrained(model_name)
        self.doc_encoder = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def _encode(self, encoder, texts):
        """Return the hidden state at position 0 ([CLS]) for each text, shape (batch, hidden)."""
        inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
        # The tokenizer returns CPU tensors; move them to wherever the encoder's weights live.
        inputs = inputs.to(encoder.device)
        outputs = encoder(**inputs)
        return outputs.last_hidden_state[:, 0]

    def encode_query(self, queries):
        """Embed a list of query strings with the query tower."""
        return self._encode(self.query_encoder, queries)

    def encode_doc(self, docs):
        """Embed a list of document strings with the document tower."""
        return self._encode(self.doc_encoder, docs)

    def forward(self, queries, docs):
        """Return ``(query_emb, doc_emb)`` for aligned lists of queries and documents."""
        return self.encode_query(queries), self.encode_doc(docs)


def contrastive_loss(query_emb, doc_emb, margin=0.2, temperature=0.05):
    """InfoNCE loss over in-batch negatives, with an additive margin on the positives.

    Query ``i`` is paired with document ``i``; every other document acts as a
    negative for it. ``doc_emb`` may carry extra rows after the first
    ``len(query_emb)`` (e.g. hard negatives); those only ever serve as negatives.

    Args:
        query_emb: (num_queries, dim) query embeddings.
        doc_emb: (num_docs, dim) document embeddings, with num_docs >= num_queries.
        margin: subtracted from each positive pair's cosine similarity, so a
            positive must beat every negative by at least this much.
        temperature: the similarities are divided by this before the softmax.
            Cosine similarity lies in [-1, 1], so without a small temperature
            the softmax stays close to uniform and the gradients are weak.

    Returns:
        Scalar mean cross-entropy loss.
    """
    # Cosine similarity of every query against every document: (num_queries, num_docs).
    scores = nn.functional.normalize(query_emb, dim=-1) @ nn.functional.normalize(doc_emb, dim=-1).T
    num_queries, num_docs = scores.shape
    # Apply the margin only to the positive pairs (the leading diagonal).
    scores = scores - margin * torch.eye(num_queries, num_docs, device=scores.device, dtype=scores.dtype)
    # The target "class" for query i is document i; labels must be on the same device as the scores.
    labels = torch.arange(num_queries, device=scores.device)
    return nn.functional.cross_entropy(scores / temperature, labels)


# Example training loop (two pairs is only a smoke test; real training needs many batches)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = BiEncoder().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

queries = ["What is AI?", "Define machine learning"]
docs = ["Artificial intelligence is a branch of computer science.", "Machine learning is a subset of AI."]

model.train()
for epoch in range(3):
    optimizer.zero_grad()
    query_emb, doc_emb = model(queries, docs)
    loss = contrastive_loss(query_emb, doc_emb)
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

# Save the tokenizer
model.tokenizer.save_pretrained('modernbert_encoder_tokenizer')
# Save query encoder
model.query_encoder.save_pretrained('modernbert_query_encoder')
# Save doc encoder
model.doc_encoder.save_pretrained('modernbert_doc_encoder')
```

## Inference Setup with Cosine Similarity

```python
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

# Load encoders and tokenizer
query_encoder = AutoModel.from_pretrained('modernbert_query_encoder')
doc_encoder = AutoModel.from_pretrained('modernbert_doc_encoder')
tokenizer = AutoTokenizer.from_pretrained('modernbert_encoder_tokenizer')
# Make sure dropout is disabled so embeddings are deterministic.
query_encoder.eval()
doc_encoder.eval()


@torch.no_grad()  # inference only: don't build the autograd graph
def _encode(encoder, texts):
    """Return the hidden state at position 0 ([CLS]) for each text, shape (batch, hidden)."""
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(encoder.device)
    return encoder(**inputs).last_hidden_state[:, 0]


def encode_query(texts):
    """Embed a list of query strings with the query encoder."""
    return _encode(query_encoder, texts)


def encode_doc(texts):
    """Embed a list of document strings with the document encoder."""
    return _encode(doc_encoder, texts)


# Run inference
query_texts = ["What is AI?"]
doc_texts = ["Artificial intelligence is a branch of computer science."]
query_emb = encode_query(query_texts)
doc_emb = encode_doc(doc_texts)

# Cosine similarity for every (query, doc) pair: shape (num_queries, num_docs).
similarity = F.normalize(query_emb, dim=-1) @ F.normalize(doc_emb, dim=-1).T
print("Similarity:", similarity.tolist())
```