| | |
| | |
| | """ |
| | Bidirectional Llama model for cross-encoder reranking. |
| | |
| | Modifies LlamaModel to use bidirectional (non-causal) attention so each token |
| | attends to all others — required for cross-encoder scoring of query-document pairs. |
| | |
| | Provides three classes: |
| | - LlamaBidirectionalConfig: Adds pooling and temperature to LlamaConfig. |
| | - LlamaBidirectionalModel: LlamaModel with causal masking replaced by |
| | bidirectional masking. Overrides forward() to support transformers >=4.44. |
| | - LlamaBidirectionalForSequenceClassification: Pools hidden states and |
| | projects to a relevance score via a linear head. |
| | |
| | Transformers version compatibility (>=4.44 including 5.0+): |
| | The forward() implementation handles these API changes at import time via |
| | inspect.signature() on LlamaDecoderLayer and DynamicCache: |
| | |
| | < 4.53: _update_causal_mask exists on LlamaModel (not used here). |
| | 4.53+: Masking moved to masking_utils; requires full forward() override. |
| | < 4.54: Decoder layer returns a tuple. |
| | 4.54+: Decoder layer returns a tensor. |
| | < 4.56: Cache kwarg is ``past_key_value`` (singular). |
| | 4.56+: Cache kwarg is ``past_key_values`` (plural); DynamicCache accepts config. |
| | 5.0+: Native ``create_bidirectional_mask`` in masking_utils. |
| | """ |
| |
|
| | import inspect |
| | from typing import Optional, Union, Tuple, List |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss |
| | from transformers.modeling_outputs import SequenceClassifierOutputWithPast |
| | from transformers.cache_utils import Cache, DynamicCache |
| | from transformers.modeling_outputs import BaseModelOutputWithPast |
| | from transformers.models.llama.configuration_llama import LlamaConfig |
| | from transformers.models.llama.modeling_llama import ( |
| | LlamaDecoderLayer, |
| | LlamaModel, |
| | LlamaPreTrainedModel, |
| | ) |
| | from transformers.utils import logging |
| |
|
logger = logging.get_logger(__name__)


# transformers >=5.0 ships a native bidirectional-mask builder in
# masking_utils (see module docstring); older releases only provide the
# generic 4D padding-mask expansion helper. Resolve which one exists once,
# at import time, so forward() can branch on a cheap module-level flag.
try:
    from transformers.masking_utils import create_bidirectional_mask

    _HAS_NATIVE_BIDIRECTIONAL_MASK = True
except ImportError:
    from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

    _HAS_NATIVE_BIDIRECTIONAL_MASK = False


# Probe the installed transformers API by signature inspection rather than by
# version string, so forks/backports that carry the new signatures also work.
_decoder_forward_params = inspect.signature(LlamaDecoderLayer.forward).parameters
_dynamic_cache_init_params = inspect.signature(DynamicCache.__init__).parameters

# True when the decoder layer takes the plural ``past_key_values`` kwarg
# (transformers >=4.56 per the module docstring); older layers use the
# singular ``past_key_value``.
_USE_PLURAL_CACHE_PARAM = "past_key_values" in _decoder_forward_params

# True when DynamicCache.__init__ accepts a ``config`` argument
# (transformers >=4.56 per the module docstring).
_DYNAMIC_CACHE_ACCEPTS_CONFIG = "config" in _dynamic_cache_init_params
| |
|
| |
|
class LlamaBidirectionalConfig(LlamaConfig):
    """LlamaConfig extended with pooling and temperature settings for the
    bidirectional (cross-encoder) model variants in this module."""

    model_type = "llama_bidirec"

    def __init__(
        self, pooling: str = "avg", temperature: float = 1.0, **kwargs
    ) -> None:
        """
        Build a bidirectional Llama configuration.

        Args:
            pooling: Strategy used to collapse token hidden states into a
                single vector ("avg", "cls", "last", etc.)
            temperature: Scaling factor applied to the pooled logits
            **kwargs: Remaining options forwarded unchanged to LlamaConfig
        """
        # Record the extra fields before the base __init__ runs so the base
        # class sees them when it processes/serializes the configuration.
        self.pooling = pooling
        self.temperature = temperature
        super().__init__(**kwargs)
| |
|
| |
|
class LlamaBidirectionalModel(LlamaModel):
    """
    LlamaModel modified to use bidirectional (non-causal) attention.

    In standard Llama, each token can only attend to previous tokens (causal attention).
    This model removes that restriction, allowing each token to attend to all tokens
    in the sequence, which is useful for embedding tasks.

    The key modifications are:
    1. Setting is_causal=False on all attention layers
    2. Using a bidirectional attention mask instead of causal mask
    """

    config_class = LlamaBidirectionalConfig

    def __init__(self, config: LlamaConfig) -> None:
        super().__init__(config)
        # Clear the causal flag on every attention module so attention
        # implementations that honor ``is_causal`` do not re-apply a causal
        # mask on top of the bidirectional mask built in forward().
        for layer in self.layers:
            layer.self_attn.is_causal = False

    def _create_bidirectional_mask(
        self,
        input_embeds: torch.Tensor,
        attention_mask: torch.Tensor | None,
    ) -> torch.Tensor | None:
        """
        Create bidirectional attention mask.

        Args:
            input_embeds: Input embeddings tensor of shape (batch_size, seq_len, hidden_size)
            attention_mask: Optional 2D attention mask of shape (batch_size, seq_len)
                where 1 indicates tokens to attend to and 0 indicates masked tokens

        Returns:
            4D attention mask suitable for the attention implementation, or None
            if no masking is needed
        """
        # No padding information: nothing to mask, which attention
        # implementations express as a None mask.
        if attention_mask is None:
            return None

        # transformers 5.0+ path: delegate to the library's own helper, which
        # builds the mask for the configured attention implementation.
        if _HAS_NATIVE_BIDIRECTIONAL_MASK:
            return create_bidirectional_mask(
                config=self.config,
                input_embeds=input_embeds,
                attention_mask=attention_mask,
            )

        # Fallback for older transformers.
        # flash_attention_2 consumes the raw 2D padding mask directly; pass it
        # through only when it actually masks something, since None lets the
        # kernel skip masking entirely.
        if getattr(self.config, "_attn_implementation", None) == "flash_attention_2":
            has_masked_tokens = (attention_mask == 0).any()
            return attention_mask if has_masked_tokens else None

        # Eager/SDPA path: expand the 2D padding mask to the 4D additive form
        # WITHOUT a causal component (this helper never adds one).
        return _prepare_4d_attention_mask(attention_mask, input_embeds.dtype)

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        cache_position: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        **kwargs,
    ) -> BaseModelOutputWithPast:
        """
        Forward pass with bidirectional attention.

        Args:
            input_ids: Input token IDs of shape (batch_size, seq_len)
            attention_mask: Attention mask of shape (batch_size, seq_len)
            position_ids: Position IDs for rotary embeddings
            past_key_values: Cached key/value states for incremental decoding
            inputs_embeds: Pre-computed input embeddings (alternative to input_ids)
            cache_position: Position indices for cache updates
            use_cache: Whether to return cached key/value states
            **kwargs: Additional arguments accepted for API compatibility;
                note they are NOT forwarded to the decoder layers below
                (only the explicitly built layer_kwargs are passed)

        Returns:
            BaseModelOutputWithPast containing last_hidden_state and past_key_values
        """
        # XOR guard: raises unless exactly one of input_ids / inputs_embeds
        # is provided (both-None and both-set each trip the condition).
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds"
            )

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Lazily create a cache when requested; newer DynamicCache takes the
        # model config (flag resolved at import time).
        if use_cache and past_key_values is None:
            if _DYNAMIC_CACHE_ACCEPTS_CONFIG:
                past_key_values = DynamicCache(config=self.config)
            else:
                past_key_values = DynamicCache()

        # Derive absolute token positions for this chunk, offset by however
        # many tokens the cache has already seen.
        if cache_position is None:
            past_seen_tokens = (
                past_key_values.get_seq_length() if past_key_values is not None else 0
            )
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )

        # Default position_ids to the cache positions, shaped (1, seq_len).
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        bidirectional_mask = self._create_bidirectional_mask(
            inputs_embeds, attention_mask
        )

        hidden_states = inputs_embeds
        # Rotary embeddings are computed once here and shared by all layers.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # Build the per-layer kwargs once; the cache kwarg name differs across
        # transformers versions (singular < 4.56, plural >= 4.56).
        layer_kwargs = {
            "attention_mask": bidirectional_mask,
            "position_ids": position_ids,
            "use_cache": use_cache,
            "cache_position": cache_position,
            "position_embeddings": position_embeddings,
        }
        if _USE_PLURAL_CACHE_PARAM:
            layer_kwargs["past_key_values"] = past_key_values
        else:
            layer_kwargs["past_key_value"] = past_key_values

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            layer_outputs = decoder_layer(hidden_states, **layer_kwargs)

            # < 4.54 the layer returns a tuple (hidden_states first);
            # 4.54+ it returns the tensor directly.
            if isinstance(layer_outputs, tuple):
                hidden_states = layer_outputs[0]
            else:
                hidden_states = layer_outputs

        hidden_states = self.norm(hidden_states)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )
| |
|
| |
|
def pool(
    last_hidden_states: torch.Tensor, attention_mask: torch.Tensor, pool_type: str
) -> torch.Tensor:
    """Collapse per-token hidden states into one embedding per sequence.

    Args:
        last_hidden_states: Hidden states of shape (batch, seq_len, hidden)
        attention_mask: 2D mask of shape (batch, seq_len); 1 = real token,
            0 = padding
        pool_type: One of "avg", "weighted_avg", "cls", or "last"

    Returns:
        Tensor of shape (batch, hidden) with the pooled embeddings.

    Raises:
        ValueError: If ``pool_type`` is not one of the supported strategies.
    """
    # Zero out padding positions so they cannot contribute to any reduction.
    keep = attention_mask[..., None].bool()
    hidden = last_hidden_states.masked_fill(~keep, 0.0)

    if pool_type == "avg":
        # Mean over real tokens only (padding contributes 0 to the sum).
        return hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

    if pool_type == "weighted_avg":
        # NOTE(review): this is a plain masked sum, not a weighted average —
        # presumably intentional for this naming convention; confirm upstream.
        return hidden.sum(dim=1)

    if pool_type == "cls":
        return hidden[:, 0]

    if pool_type == "last":
        # If every row's final position is a real token, any padding must be
        # on the left, so the last column is the last real token for all rows.
        all_rows_end_with_token = (
            attention_mask[:, -1].sum() == attention_mask.shape[0]
        )
        if all_rows_end_with_token:
            return hidden[:, -1]
        # Right padding: gather each row's final real-token position.
        final_positions = attention_mask.sum(dim=1) - 1
        row_index = torch.arange(hidden.shape[0], device=hidden.device)
        return hidden[row_index, final_positions]

    raise ValueError(f"pool_type {pool_type} not supported")
| |
|
| |
|
class LlamaBidirectionalForSequenceClassification(LlamaPreTrainedModel):
    """Sequence classifier (cross-encoder scorer) over LlamaBidirectionalModel.

    Runs the bidirectional backbone, pools the last hidden states according
    to ``config.pooling``, projects them through a bias-free linear head to
    ``config.num_labels`` logits, and scales by ``config.temperature``.
    """

    config_class = LlamaBidirectionalConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        # Bias-free projection from pooled hidden state to label logits.
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
        self.model = LlamaBidirectionalModel(config)

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        backbone_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )
        sequence_states = backbone_outputs[0]

        # Collapse token states to one vector per sequence, then project and
        # temperature-scale to get the final logits.
        pooled = pool(
            last_hidden_states=sequence_states,
            attention_mask=attention_mask,
            pool_type=self.config.pooling,
        )
        scores = self.score(pooled)
        scores = scores / self.config.temperature

        loss = None
        if labels is not None:
            labels = labels.to(scores.device)
            # Infer and persist the problem type on first use, mirroring the
            # standard HF sequence-classification heads.
            if self.config.problem_type is None:
                integer_labels = labels.dtype in (torch.long, torch.int)
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and integer_labels:
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                criterion = MSELoss()
                if self.num_labels == 1:
                    loss = criterion(scores.squeeze(), labels.squeeze())
                else:
                    loss = criterion(scores, labels)
            elif self.config.problem_type == "single_label_classification":
                criterion = CrossEntropyLoss()
                loss = criterion(scores.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                criterion = BCEWithLogitsLoss()
                loss = criterion(scores, labels)

        if not return_dict:
            tail = (scores,) + backbone_outputs[1:]
            return ((loss,) + tail) if loss is not None else tail

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=scores,
            past_key_values=backbone_outputs.past_key_values,
            hidden_states=backbone_outputs.hidden_states,
            attentions=backbone_outputs.attentions,
        )
| |
|