# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""PyTorch BERT model with and without transformer engine layers.

This file is a modified version of the BERT model from the Hugging Face Transformers
library. It includes a custom BERT encoder that can be used with or without transformer
engine layers. The BERT encoder is a modified version of the encoder from the Hugging
Face Transformers library. It includes a custom BERT layer that can be used with or
without transformer engine layers.
"""

from typing import ClassVar, List, Optional, Tuple, Union

import torch
import transformer_engine.pytorch as te
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_attention_mask_for_sdpa,
    _prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    MaskedLMOutput,
)
from transformers.models.bert.configuration_bert import BertConfig
from transformers.models.bert.modeling_bert import (
    BertEmbeddings,
    BertLayer,
    BertOnlyMLMHead,
    BertPooler,
    BertPreTrainedModel,
)
from transformers.utils import logging


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "google-bert/bert-base-uncased"
_CONFIG_FOR_DOC = "BertConfig"

# String names accepted for `torch_dtype` in serialized configs, mapped to the
# corresponding torch dtype objects.
_STR_TO_TORCH_DTYPE = {
    "bfloat16": torch.bfloat16,
    "float16": torch.float16,
    "float32": torch.float32,
}


class TEBertConfig(BertConfig):
    """Configuration class for the TE BERT model.

    This class is a subclass of BertConfig, and it adds the following attributes:
    - torch_dtype: The dtype of the model parameters.
    - use_te_layers: Whether to use the TE layers.
    - micro_batch_size: The micro batch size for TE layers.
    """

    def __init__(self, **kwargs):
        """Initialize the TEBertConfig.

        Args:
            **kwargs: Additional keyword arguments to pass to BertConfig.

        Raises:
            ValueError: If `torch_dtype` is a string other than "bfloat16",
                "float16", or "float32".
        """
        super().__init__(**kwargs)
        # TODO(@jomitchell): Fix this in JIRA BIONEMO-2406
        torch_dtype = kwargs.get("torch_dtype", torch.bfloat16)
        # Convert string dtype to torch dtype if needed (configs loaded from
        # JSON carry the dtype as a string).
        if isinstance(torch_dtype, str):
            if torch_dtype not in _STR_TO_TORCH_DTYPE:
                raise ValueError(f"Unsupported dtype: {torch_dtype}")
            torch_dtype = _STR_TO_TORCH_DTYPE[torch_dtype]
        self.torch_dtype = torch_dtype
        self.use_te_layers = kwargs.get("use_te_layers", False)
        self.micro_batch_size = kwargs.get("micro_batch_size", None)
        self.fuse_qkv_params = kwargs.get("fuse_qkv_params", False)


class TEBertLayer(nn.Module):
    """Custom BERT layer using individual TE components for correct post-norm architecture.

    This builds a BERT-style post-norm layer using:
    - te.MultiheadAttention (with input_layernorm=False)
    - te.LayerNorm for post-attention normalization as layernorm
    - te.Linear for MLP layers (fc1, fc2) wrapped in layernorm_mlp module
    - te.LayerNorm for post-MLP normalization as layernorm_mlp.layer_norm

    Parameter naming matches convert.py expectations for weight loading from HF checkpoints.

    DIVERGENCE FROM TYPICAL TRANSFORMERLAYER:
    This implementation uses POST-norm architecture, which differs significantly from the
    typical TransformerLayer that uses PRE-norm.

    Geneformer/HF BERT (POST-norm, output_layernorm=True equivalent):
        Input -> Attention -> Dropout -> Residual Add -> LayerNorm
              -> MLP -> Dropout -> Residual Add -> LayerNorm -> Output

    Typical TransformerLayer (PRE-norm, output_layernorm=False default):
        Input -> [LayerNorm Attn inside MultiheadAttention] -> Dropout -> Residual Add
              -> [LayerNorm MLP inside LayerNormMLP] -> Dropout -> Residual Add -> Output

    Geneformer applies LayerNorm AFTER residual connections as explicit separate modules,
    whereas typical TransformerLayer applies LayerNorm Before operations via
    input_layernorm=True inside MultiheadAttention and LayerNormMLP modules.

    For more information, see:
    https://github.com/NVIDIA/TransformerEngine/blob/dd9433e7ad28c12f27da9770be54c9c584e85fa0/transformer_engine/pytorch/transformer.py#L822
    """

    def __init__(self, config, layer_number=None):
        """Initialize the TEBertLayer.

        Args:
            config: Configuration object containing model parameters.
            layer_number: Optional layer number for identification.

        Raises:
            ValueError: If `config.hidden_act` is not "relu" (the only activation
                supported by this Geneformer-style layer).
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.layer_number = layer_number
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention

        # Self-attention using TE MultiheadAttention
        self.self_attention = te.MultiheadAttention(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_gqa_groups=config.num_attention_heads,
            attention_dropout=config.attention_probs_dropout_prob,
            input_layernorm=False,  # No LayerNorm before attention
            attention_type="self",
            layer_number=layer_number,
            attn_mask_type="padding",
            params_dtype=config.torch_dtype,
            fuse_qkv_params=getattr(config, "fuse_qkv_params", False),
            window_size=(-1, -1),  # No sliding window attention
            qkv_format="bshd",  # BERT uses [batch, seq, head, dim]
        )

        # Post-attention TE LayerNorm
        self.layernorm = te.LayerNorm(
            normalized_shape=config.hidden_size,
            eps=config.layer_norm_eps,
            params_dtype=config.torch_dtype,
        )

        # MLP using TE Linear layers.  A bare nn.Module is used as a namespace so
        # that parameter names match convert.py's checkpoint-mapping expectations.
        self.layernorm_mlp = nn.Module()
        self.layernorm_mlp.fc1 = te.Linear(
            config.hidden_size,
            config.intermediate_size,
            bias=True,
            params_dtype=config.torch_dtype,
        )
        if config.hidden_act != "relu":
            raise ValueError(f"Geneformer requires hidden_act='relu', got '{config.hidden_act}'")
        self.layernorm_mlp.activation = nn.ReLU()
        self.layernorm_mlp.fc2 = te.Linear(
            config.intermediate_size,
            config.hidden_size,
            bias=True,
            params_dtype=config.torch_dtype,
        )

        # Post-MLP LayerNorm
        self.layernorm_mlp.layer_norm = te.LayerNorm(
            normalized_shape=config.hidden_size,
            eps=config.layer_norm_eps,
            params_dtype=config.torch_dtype,
        )

        # Dropout
        self.attention_dropout = nn.Dropout(config.hidden_dropout_prob)
        self.mlp_dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Forward pass through the TE BERT layer.

        Architecture:
            Input -> Self-Attention -> Dropout -> Residual Connection -> LayerNorm
                  -> MLP -> Dropout -> Residual Connection -> LayerNorm -> Output

        This architecture is the key divergence from typical TransformerLayer (with
        output_layernorm=False default) which uses PRE-norm.

        In PRE-norm TransformerLayer, LayerNorm is applied Before operations:
        - MultiheadAttention with input_layernorm=True applies LayerNorm internally
          before attention
        - LayerNormMLP applies LayerNorm internally before MLP
        - Residuals bypass these internal LayerNorms

        In Geneformer's POST-norm, LayerNorm is applied after residual connections as
        explicit separate modules, meaning the normalized output flows to the next layer.

        NOTE(review): `head_mask`, `encoder_hidden_states`, `encoder_attention_mask`,
        `past_key_value`, and `output_attentions` are accepted for signature
        compatibility with `BertLayer` but are not used by this implementation.

        Args:
            hidden_states: Input hidden states.
            attention_mask: Attention mask.
            head_mask: Head mask.
            encoder_hidden_states: Encoder hidden states.
            encoder_attention_mask: Encoder attention mask.
            past_key_value: Past key value.
            output_attentions: Whether to output attentions.

        Returns:
            Tuple of tensors containing the layer output.
        """
        # Attention mask handling for TE MultiheadAttention:
        # [batch, 1, 1, seq_len], True=masked, False=attend.
        te_attention_mask = None
        te_mask_type = "no_mask"
        if attention_mask is not None:
            # Check if there's actual padding (not all 1s for 2D or not all 0s for 4D)
            if attention_mask.dim() == 2:
                # Standard [batch, seq_len] where 1=attend, 0=masked
                has_padding = not torch.all(attention_mask == 1)
                if has_padding:
                    # Convert to TE format: [batch, 1, 1, seq_len], invert polarity
                    te_attention_mask = ~attention_mask.bool().unsqueeze(1).unsqueeze(1)
                    te_mask_type = "padding"
            elif attention_mask.dim() in [3, 4]:
                # Extended mask with -inf for masked positions
                has_masking = torch.any(
                    attention_mask < -10000.0
                )  # Check if it's not a trivial mask (all zeros/no masking)
                if has_masking:
                    # Extract padding mask and convert to TE format
                    if attention_mask.dim() == 4:
                        padding_mask = attention_mask[:, 0, 0, :]  # [batch, seq_len]
                    else:  # dim == 3
                        padding_mask = attention_mask[:, 0, :]  # [batch, seq_len]
                    # -inf to True (masked), 0 to False (attend)
                    # Then reshape to [batch, 1, 1, seq_len]
                    te_attention_mask = (padding_mask < -10000.0).unsqueeze(1).unsqueeze(1)
                    te_mask_type = "padding"

        # Self-Attention sub-layer
        attention_output = self.self_attention(
            hidden_states,
            attention_mask=te_attention_mask,
            attn_mask_type=te_mask_type,
        )

        # Residual connection + dropout + LayerNorm (POST-norm)
        attention_output = self.attention_dropout(attention_output)
        hidden_states = hidden_states + attention_output
        hidden_states = self.layernorm(hidden_states)

        # MLP sub-layer
        mlp_output = self.layernorm_mlp.fc1(hidden_states)
        mlp_output = self.layernorm_mlp.activation(mlp_output)
        mlp_output = self.layernorm_mlp.fc2(mlp_output)

        # Residual connection + dropout + LayerNorm (POST-norm)
        mlp_output = self.mlp_dropout(mlp_output)
        hidden_states = hidden_states + mlp_output
        hidden_states = self.layernorm_mlp.layer_norm(hidden_states)

        return (hidden_states,)


class BertEncoder(nn.Module):
    """Stack of BERT layers, built from either TE layers or HF `BertLayer`s.

    When `config.use_te_layers` is true the stack is made of `TEBertLayer`
    instances (1-indexed `layer_number`); otherwise it falls back to the stock
    Hugging Face `BertLayer`.
    """

    def __init__(self, config):
        """Initialize the encoder.

        Args:
            config: Configuration object; `use_te_layers` selects the layer type
                and `num_hidden_layers` sets the stack depth.
        """
        super().__init__()
        self.config = config
        if self.config.use_te_layers:
            self.layer = nn.ModuleList(
                [TEBertLayer(config, layer_number=i + 1) for i in range(config.num_hidden_layers)]
            )
        else:
            self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def _process_layer_outputs(
        self,
        layer_outputs,
        hidden_states,
        all_hidden_states,
        all_self_attentions,
        all_cross_attentions,
        output_hidden_states,
        output_attentions,
        use_cache,
        next_decoder_cache,
    ):
        """Process outputs from a single layer.

        Returns the updated `(hidden_states, all_hidden_states, all_self_attentions,
        all_cross_attentions, next_decoder_cache)` accumulators.
        """
        hidden_states = layer_outputs[0]

        if use_cache and next_decoder_cache is not None:
            next_decoder_cache = (*next_decoder_cache, layer_outputs[-1])

        if output_attentions and len(layer_outputs) > 1:
            # Accumulators are initialized to () whenever their flag is set, so
            # appending with tuple-unpacking covers the first iteration too.
            all_self_attentions = (*(all_self_attentions or ()), layer_outputs[1])
            if self.config.add_cross_attention and len(layer_outputs) > 2:
                all_cross_attentions = (*(all_cross_attentions or ()), layer_outputs[2])

        return hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions, next_decoder_cache

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        """Run the input through every layer of the stack.

        Args:
            hidden_states: Embedding output, `[batch, seq, hidden]`.
            attention_mask: Extended attention mask (already broadcast/inverted by
                the caller).
            head_mask: Per-layer head mask, indexed by layer.
            encoder_hidden_states: Cross-attention inputs (decoder mode only).
            encoder_attention_mask: Cross-attention mask (decoder mode only).
            past_key_values: Per-layer cached key/value states.
            use_cache: Whether to return key/value caches (incompatible with
                gradient checkpointing).
            output_attentions: Whether to collect per-layer attention weights.
            output_hidden_states: Whether to collect per-layer hidden states.
            return_dict: Whether to return a ModelOutput instead of a tuple.

        Returns:
            `BaseModelOutputWithPastAndCrossAttentions` or the equivalent tuple of
            non-None values.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        next_decoder_cache = () if use_cache else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                # () initialization above guarantees this accumulator is a tuple here.
                all_hidden_states = (*all_hidden_states, hidden_states)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                from torch.utils.checkpoint import checkpoint

                layer_outputs = checkpoint(
                    layer_module,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    use_reentrant=False,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )

            hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions, next_decoder_cache = (
                self._process_layer_outputs(
                    layer_outputs,
                    hidden_states,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                    output_hidden_states,
                    output_attentions,
                    use_cache,
                    next_decoder_cache,
                )
            )

        if output_hidden_states:
            all_hidden_states = (*all_hidden_states, hidden_states)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )


class BertModel(BertPreTrainedModel):
    """BERT model for encoding and decoding.

    The model can behave as an encoder (with only self-attention) as well as a decoder,
    in which case a layer of cross-attention is added between the self-attention layers,
    following the architecture described in
    [Attention is all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani,
    Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez,
    Lukasz Kaiser and Illia Polosukhin.

    To behave as an decoder the model needs to be initialized with the `is_decoder`
    argument of the configuration set to `True`. To be used in a Seq2Seq model, the
    model needs to initialized with both `is_decoder` argument and `add_cross_attention`
    set to `True`; an `encoder_hidden_states` is then expected as an input to the
    forward pass.
    """

    config_class = TEBertConfig

    # TODO(@jomitchell) Can start swapping layers here for TE layers.
    _no_split_modules: ClassVar[List[str]] = ["BertEmbeddings", "BertLayer", "TEBertLayer"]

    def __init__(self, config, add_pooling_layer=True):
        """Initialize the BertModel.

        Args:
            config: Configuration object containing model parameters.
            add_pooling_layer: Whether to add a pooling layer on top of the encoder.
        """
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)

        self.pooler = BertPooler(config) if add_pooling_layer else None

        self.attn_implementation = config._attn_implementation
        self.position_embedding_type = config.position_embedding_type

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Get the input embeddings."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Set the input embeddings."""
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """Prunes heads of the model.

        heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        See base class PreTrainedModel.
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def _validate_and_prepare_inputs(
        self,
        input_ids,
        inputs_embeds,
        attention_mask,
        token_type_ids,
        position_ids,
        past_key_values,
    ):
        """Validate inputs and prepare basic input data.

        Returns:
            Tuple of `(input_shape, batch_size, seq_length, device,
            past_key_values_length, token_type_ids, embedding_output,
            attention_mask)` with defaults filled in for missing
            `token_type_ids`/`attention_mask`.

        Raises:
            ValueError: If both or neither of `input_ids`/`inputs_embeds` are given.
        """
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )

        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)

        return (
            input_shape,
            batch_size,
            seq_length,
            device,
            past_key_values_length,
            token_type_ids,
            embedding_output,
            attention_mask,
        )

    def _prepare_attention_masks(
        self,
        attention_mask,
        input_shape,
        embedding_output,
        past_key_values_length,
        seq_length,
        device,
        head_mask,
        output_attentions,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        """Prepare attention masks for the forward pass.

        Returns:
            Tuple of `(extended_attention_mask, encoder_extended_attention_mask)`;
            the latter is None unless the model is a decoder with encoder inputs.
        """
        use_sdpa_attention_masks = (
            self.attn_implementation == "sdpa"
            and self.position_embedding_type == "absolute"
            and head_mask is None
            and not output_attentions
        )

        # Expand the attention mask
        if use_sdpa_attention_masks and attention_mask.dim() == 2:
            # Expand the attention mask for SDPA.
            # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
            if self.config.is_decoder:
                extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                    attention_mask,
                    input_shape,
                    embedding_output,
                    past_key_values_length,
                )
            else:
                extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                    attention_mask, embedding_output.dtype, tgt_len=seq_length
                )
        else:
            # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
            # ourselves in which case we just need to make it broadcastable to all heads.
            extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)

            if use_sdpa_attention_masks and encoder_attention_mask.dim() == 2:
                # Expand the attention mask for SDPA.
                # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
                encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                    encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length
                )
            else:
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        return extended_attention_mask, encoder_extended_attention_mask

    def _prepare_inputs_and_masks(
        self,
        input_ids,
        inputs_embeds,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        past_key_values,
        encoder_hidden_states,
        encoder_attention_mask,
        output_attentions,
        output_hidden_states,
        use_cache,
        return_dict,
    ):
        """Prepare inputs and attention masks for the forward pass."""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        (
            input_shape,
            batch_size,
            seq_length,
            device,
            past_key_values_length,
            token_type_ids,
            embedding_output,
            attention_mask,
        ) = self._validate_and_prepare_inputs(
            input_ids,
            inputs_embeds,
            attention_mask,
            token_type_ids,
            position_ids,
            past_key_values,
        )

        extended_attention_mask, encoder_extended_attention_mask = self._prepare_attention_masks(
            attention_mask,
            input_shape,
            embedding_output,
            past_key_values_length,
            seq_length,
            device,
            head_mask,
            output_attentions,
            encoder_hidden_states,
            encoder_attention_mask,
        )

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        processed_head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        return (
            embedding_output,
            extended_attention_mask,
            processed_head_mask,
            encoder_extended_attention_mask,
            use_cache,
            return_dict,
        )

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        r"""Forward pass of the BertModel.

        Args:
            input_ids (`torch.Tensor`, *optional*): Input token IDs.
            attention_mask (`torch.Tensor`, *optional*): Attention mask.
            token_type_ids (`torch.Tensor`, *optional*): Token type IDs.
            position_ids (`torch.Tensor`, *optional*): Position IDs.
            head_mask (`torch.Tensor`, *optional*): Head mask.
            inputs_embeds (`torch.Tensor`, *optional*): Input embeddings.
            encoder_hidden_states  (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
                the model is configured as a decoder.
            encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, target_length)`, *optional*):
                Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
                the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
            past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
                Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding
                (see `past_key_values`).
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
                `past_key_values`).
            output_attentions (`bool`, *optional*): Whether to output attentions.
            output_hidden_states (`bool`, *optional*): Whether to output hidden states.
            return_dict (`bool`, *optional*): Whether to return a ModelOutput instead of a tuple.
        """
        (
            embedding_output,
            extended_attention_mask,
            processed_head_mask,
            encoder_extended_attention_mask,
            use_cache,
            return_dict,
        ) = self._prepare_inputs_and_masks(
            input_ids,
            inputs_embeds,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            past_key_values,
            encoder_hidden_states,
            encoder_attention_mask,
            output_attentions,
            output_hidden_states,
            use_cache,
            return_dict,
        )

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=processed_head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output, *encoder_outputs[1:])

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )


class BertForMaskedLM(BertPreTrainedModel):
    """BERT model for masked language modeling."""

    config_class = TEBertConfig

    _tied_weights_keys: ClassVar[List[str]] = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]

    def __init__(self, config):
        """Initialize the BertForMaskedLM.

        Args:
            config: Configuration object containing model parameters.
        """
        super().__init__(config)

        if config.is_decoder:
            logger.warning(
                "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        """Get the output embeddings."""
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        """Set the output embeddings."""
        self.cls.predictions.decoder = new_embeddings
        self.cls.predictions.bias = new_embeddings.bias

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        r"""Forward pass for masked language modeling.

        Labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (prediction_scores, *outputs[2:])
            return (masked_lm_loss, *output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
        """Prepare inputs for generation.

        NOTE(review): `attention_mask` must not be None here — the concatenation
        below assumes a tensor. Callers (GenerationMixin) always supply one.
        """
        input_shape = input_ids.shape
        effective_batch_size = input_shape[0]

        # add a dummy token
        if self.config.pad_token_id is None:
            raise ValueError("The PAD token should be defined for generation")

        attention_mask = torch.cat(
            [attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))],
            dim=-1,
        )
        dummy_token = torch.full(
            (effective_batch_size, 1),
            self.config.pad_token_id,
            dtype=torch.long,
            device=input_ids.device,
        )
        input_ids = torch.cat([input_ids, dummy_token], dim=1)

        return {"input_ids": input_ids, "attention_mask": attention_mask}

    @classmethod
    def can_generate(cls) -> bool:
        """Legacy correction: BertForMaskedLM can't call `generate()` from `GenerationMixin`.

        Even though it has a `prepare_inputs_for_generation` method.
        """
        return False


__all__ = [
    "BertForMaskedLM",
    "BertLayer",
    "BertModel",
    "TEBertLayer",
]