import os import sys import math import pickle import random import torch import numpy as np import requests from .utils import get_suppression_coefficient from io import BytesIO from typing import Union, List, Optional, Any, Dict, Tuple, Callable from dataclasses import dataclass from PIL import Image from transformers import ( AutoModel, AutoTokenizer, AutoModelForCausalLM, AutoConfig, PreTrainedModel ) from transformers.models.qwen2_5_vl.processing_qwen2_5_vl import Qwen2_5_VLProcessor from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLCausalLMOutputWithPast ) from transformers.models.qwen2.modeling_qwen2 import ( Qwen2ForCausalLM, Qwen2Config ) from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from transformers.generation import GenerationMixin from transformers.modeling_attn_mask_utils import AttentionMaskConverter from transformers.modeling_flash_attention_utils import FlashAttentionKwargs from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, QuestionAnsweringModelOutput, SequenceClassifierOutputWithPast, TokenClassifierOutput, ModelOutput, ) from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from transformers.processing_utils import Unpack from transformers.utils import ( # LossKwargs, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, can_return_tuple, logging, replace_return_docstrings, ) from transformers.utils.deprecation import deprecate_kwarg # from qwen_vl_utils import process_vision_info from .vlm_unitok import UniTok from torch import nn import torch.nn.functional as F from torch.nn import CrossEntropyLoss from transformers.models.qwen2.modeling_qwen2 import * class StyleGenerator(Qwen2ForCausalLM): def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, code_freq: Any = None, code_freq_threshold: Any = None, k: Any=None, **kwargs: Unpack[KwargsForCausalLM], ) -> CausalLMOutputWithPast: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, cache_position=cache_position, **kwargs, ) hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) loss = None if labels is not None: loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) coefficient = get_suppression_coefficient(code_freq, code_freq_threshold, k).to(logits.device) logits[0][0] = logits[0][0] * coefficient return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) _CONFIG_FOR_DOC = "Qwen2_5_VLConfig" @dataclass class Qwen2_5_VLCausalLMOutputWithPastQuant(ModelOutput): loss: Optional[torch.FloatTensor] = None logits: Optional[torch.FloatTensor] = None past_key_values: Optional[List[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None rope_deltas: Optional[torch.LongTensor] = None quant_info: Optional[Dict[str, Any]] = None class Qwen2_5_VLForConditionalGeneration_Quant(Qwen2_5_VLForConditionalGeneration): def forward( self, unitok: Optional[Any] = None, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, pixel_values: Optional[torch.Tensor] = None, pixel_values_videos: Optional[torch.FloatTensor] = None, image_grid_thw: Optional[torch.LongTensor] = None, video_grid_thw: Optional[torch.LongTensor] = None, rope_deltas: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, second_per_grid_ts: Optional[torch.Tensor] = None, codebook_id: Any = None, ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]: unitok_info = {} output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # if inputs_embeds is None: inputs_embeds = self.model.embed_tokens(input_ids) if pixel_values is not None: pixel_values = pixel_values.type(self.visual.dtype) if codebook_id == None: image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) b_n, dim = image_embeds.shape image_embeds = image_embeds.reshape(b_n//196,196,dim) with torch.amp.autocast(device_type='cuda', enabled=False): output = unitok(image_embeds) image_embeds_recon, unitok_info = output['img_rec'].squeeze(), output image_embeds = image_embeds_recon.reshape(b_n, dim) else: image_embeds = unitok.quantizer.idx_to_f(codebook_id.unsqueeze(0).to(self.visual.device)) image_embeds = unitok.post_quant_proj(image_embeds).squeeze() n_image_tokens = (input_ids == self.config.image_token_id).sum().item() n_image_features = image_embeds.shape[0] if n_image_tokens != n_image_features: raise ValueError( f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" ) mask = input_ids == self.config.image_token_id mask_unsqueezed = mask.unsqueeze(-1) mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) image_mask = mask_expanded.to(inputs_embeds.device) image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) image_embeds_for_hook = image_embeds.clone() inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) # if pixel_values_videos is not None: pixel_values_videos = pixel_values_videos.type(self.visual.dtype) video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) n_video_tokens = (input_ids == self.config.video_token_id).sum().item() n_video_features = video_embeds.shape[0] if n_video_tokens != n_video_features: raise ValueError( f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" ) mask = input_ids == self.config.video_token_id mask_unsqueezed = mask.unsqueeze(-1) mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) video_mask = mask_expanded.to(inputs_embeds.device) video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) if attention_mask is not None: attention_mask = attention_mask.to(inputs_embeds.device) # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): # calculate RoPE index once per generation in the pre-fill stage only if ( (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None or (past_key_values is None or past_key_values.get_seq_length() == 0) ): position_ids, rope_deltas = self.get_rope_index( input_ids, image_grid_thw, video_grid_thw, second_per_grid_ts, attention_mask, ) self.rope_deltas = rope_deltas # then use the prev pre-calculated rope-deltas to get the correct position ids else: batch_size, seq_length, _ = inputs_embeds.shape delta = ( (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) if cache_position is not None else 0 ) position_ids = torch.arange(seq_length, device=inputs_embeds.device) position_ids = position_ids.view(1, -1).expand(batch_size, -1) if cache_position is not None: # otherwise `deltas` is an int `0` delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) position_ids = position_ids.add(delta) position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) outputs = self.model( input_ids=None, position_ids=position_ids, attention_mask=attention_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, ) hidden_states = outputs[0] logits = self.lm_head(hidden_states) loss = None if labels is not None: # Upcast to float if we need to compute the loss to avoid potential precision issues logits = logits.float() # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) # Enable model parallelism shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output return Qwen2_5_VLCausalLMOutputWithPastQuant( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, rope_deltas=self.rope_deltas, quant_info = unitok_info, ) class Qwen2_5_VL_Quant(nn.Module): def __init__(self, unitok, qwen2_5_vl): super().__init__() self.unitok = unitok self.qwen = qwen2_5_vl self.dtype = self.qwen.dtype def forward(self, input_ids, attention_mask, pixel_values=None, image_grid_thw=None, output_hidden_states=None, codebook_id=None, ): output = self.qwen( unitok = self.unitok, input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values, image_grid_thw=image_grid_thw, output_hidden_states=output_hidden_states, codebook_id=codebook_id, ) return output