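"""Vision tower wrappers around SigLIP encoders, including multi-scale (S2) variants."""
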
import torch
import torch.nn as nn
from accelerate.hooks import add_hook_to_module
from s2wrapper import forward as multiscale_forward
from transformers import PretrainedConfig, PreTrainedModel, SiglipImageProcessor
from transformers.image_processing_utils import BaseImageProcessor
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.models.siglip import SiglipVisionModel


class VisionTower(nn.Module):
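    """Base wrapper around a pretrained vision encoder; selects features from one hidden layer."""
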
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        self.select_layer = getattr(args, "mm_vision_select_layer", -2)
        self.select_feature = getattr(args, "mm_vision_select_feature", "patch")

        self.cfg_only = None

    def feature_select(self, image_forward_outs):
        image_features = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature == "patch":
            # Drop the first token and keep only patch tokens (assumes the encoder
            # prepends a class token; note that SigLIP encoders have no class token).
            image_features = image_features[:, 1:]
        elif self.select_feature == "cls_patch":
            image_features = image_features
        else:
            raise ValueError(f"Unexpected select feature: {self.select_feature}")
        return image_features

    def _maybe_resize_pos_embeds(
        self,
        model: PreTrainedModel,
        image_processor: BaseImageProcessor,
        resolution: int = -1,
        interpolate_mode: str = "linear",
    ):
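        """Linearly resample the encoder's position embeddings so it accepts a higher input resolution."""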
        if resolution in [model.config.image_size, -1]:
            return
        print(
            "Resizing vision model's position embeddings to support higher vision resolution: "
            f"from {model.config.image_size} to {resolution} ..."
        )
        embeddings = model.vision_model.embeddings
        patch_size = embeddings.patch_size
        num_new_tokens = int((resolution // patch_size) ** 2)

        old_embeddings = embeddings.position_embedding
        match interpolate_mode:
            case "linear":
                if is_deepspeed_zero3_enabled():
                    try:
                        import deepspeed
                    except ImportError:
                        raise ImportError("DeepSpeed is not installed. Please install it with `pip install deepspeed`.")
                    # ZeRO-3 shards parameters across ranks; gather the weight to read its full shape.
                    with deepspeed.zero.GatheredParameters([old_embeddings.weight], modifier_rank=None):
                        old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
                else:
                    old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
                new_embeddings = nn.Embedding(
                    num_new_tokens,
                    old_embedding_dim,
                    dtype=old_embeddings.weight.dtype,
                    device=old_embeddings.weight.device,
                )
                # Map each new position index to a fractional position on the old embedding grid.
                mapped_indices = (
                    torch.arange(num_new_tokens).to(old_embeddings.weight.device)
                    / (num_new_tokens - 1)
                    * (old_num_tokens - 1)
                )
                floor_indices = torch.clamp(mapped_indices.floor().long(), min=0, max=old_num_tokens - 1)
                ceil_indices = torch.clamp(mapped_indices.ceil().long(), min=0, max=old_num_tokens - 1)
                # Linearly interpolate between the two neighboring old embeddings. The floor
                # weight is (1 - frac) rather than (ceil_indices - mapped_indices) so that
                # positions landing exactly on an old index (where floor == ceil) keep that
                # embedding instead of collapsing to zero.
                frac = (mapped_indices - floor_indices)[:, None]
                if is_deepspeed_zero3_enabled():
                    params = [old_embeddings.weight, new_embeddings.weight]
                    with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
                        interpolated_embeds = (
                            frac * old_embeddings.weight.data[ceil_indices, :]
                            + (1 - frac) * old_embeddings.weight.data[floor_indices, :]
                        )
                else:
                    interpolated_embeds = (
                        frac * old_embeddings.weight.data[ceil_indices, :]
                        + (1 - frac) * old_embeddings.weight.data[floor_indices, :]
                    )
                new_embeddings.weight.data = interpolated_embeds
            case _:
                raise NotImplementedError

        if hasattr(old_embeddings, "_hf_hook"):
            hook = old_embeddings._hf_hook
            add_hook_to_module(new_embeddings, hook)
        new_embeddings.requires_grad_(old_embeddings.weight.requires_grad)
        # Keep the model config and the image processor consistent with the new resolution.
        model.config.image_size = resolution
        if hasattr(image_processor, "crop_size"):
            # CLIP-style image processors expose `crop_size`.
            image_processor.crop_size = resolution
        else:
            # SigLIP-style image processors expose `size` instead of `crop_size`.
            assert hasattr(image_processor, "size")
            image_processor.size = {"height": resolution, "width": resolution}
        embeddings.position_embedding = new_embeddings
        embeddings.image_size = resolution
        embeddings.num_patches = embeddings.num_positions = num_new_tokens
        embeddings.position_ids = (
            torch.arange(embeddings.num_positions).expand((1, -1)).to(old_embeddings.weight.device)
        )

    def forward(self, images):
        if isinstance(images, list):
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                    output_hidden_states=True,
                )
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_tower(
                images.to(device=self.device, dtype=self.dtype),
                output_hidden_states=True,
            )
            image_features = self.feature_select(image_forward_outs).to(images.dtype)

        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        if self.is_loaded:
            return self.vision_tower.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2


class VisionTowerS2(VisionTower):
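    """Vision tower that runs the encoder at multiple scales (S2) and concatenates the features."""
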
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__(vision_tower, args, delay_load)

        # S2 multi-scale settings, parsed from a comma-separated list of image sizes.
        self.scales = list(map(int, args.s2_scales.split(",")))
        self.scales.sort()
        self.max_split_size = args.s2_max_split_size
        self.resize_output_to_scale_idx = getattr(args, "s2_resize_output_to_scale_idx", 0)

    def forward_feature(self, images):
        image_forward_outs = self.vision_tower(
            images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
        )
        image_features = self.feature_select(image_forward_outs).to(images.dtype)
        return image_features

    def forward(self, images):
        if isinstance(images, list):
            image_features = []
            for image in images:
                # s2wrapper runs the encoder at each scale and concatenates the features.
                image_feature = multiscale_forward(
                    self.forward_feature,
                    image.unsqueeze(0),
                    img_sizes=self.scales,
                    max_split_size=self.max_split_size,
                    resize_output_to_idx=self.resize_output_to_scale_idx,
                )
                image_features.append(image_feature)
        else:
            image_features = multiscale_forward(
                self.forward_feature,
                images,
                img_sizes=self.scales,
                max_split_size=self.max_split_size,
                resize_output_to_idx=self.resize_output_to_scale_idx,
            )

        return image_features

    @property
    def hidden_size(self):
        # Features from all S2 scales are concatenated along the channel dimension.
        return self.config.hidden_size * len(self.scales)


class VisionTowerDynamicS2(VisionTower):
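    """S2 variant for dynamically tiled inputs: the caller supplies pre-tiled image batches."""
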
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__(vision_tower, args, delay_load)

        self.scales = list(map(int, args.s2_scales.split(",")))
        self.scales.sort()
        self.max_split_size = args.s2_max_split_size
        self.resize_output_to_scale_idx = getattr(args, "s2_resize_output_to_scale_idx", 0)

    def forward_feature(self, images):
        image_forward_outs = self.vision_tower(
            images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
        )
        image_features = self.feature_select(image_forward_outs).to(images.dtype)
        return image_features

    def forward(self, images):
        # Dynamic S2 expects a single pre-tiled batch tensor rather than a list of images.
        assert not isinstance(images, list)
        image_features = self.forward_feature(images)

        return image_features

    @property
    def hidden_size(self):
        return self.config.hidden_size * len(self.scales)


class SiglipVisionTower(VisionTower):
    def __init__(self, model_name_or_path: str, config: PretrainedConfig) -> None:
        super().__init__(model_name_or_path, config)
        self.vision_tower = SiglipVisionModel.from_pretrained(
            model_name_or_path,
            attn_implementation=config._attn_implementation,
            # config.model_dtype is a string such as "torch.float16"; resolve it to a
            # torch dtype without eval().
            torch_dtype=getattr(torch, config.model_dtype.replace("torch.", "")),
        )
        self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
        self.is_loaded = True


class SiglipVisionTowerS2(VisionTowerS2):
    def __init__(self, model_name_or_path: str, config: PretrainedConfig) -> None:
        super().__init__(model_name_or_path, config)
        self.vision_tower = SiglipVisionModel.from_pretrained(
            model_name_or_path,
            attn_implementation=config._attn_implementation,
            torch_dtype=getattr(torch, config.model_dtype.replace("torch.", "")),
        )
        self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
        # Have the processor resize images to the largest S2 scale so no resolution is lost.
        self.image_processor.size["height"] = self.image_processor.size["width"] = self.scales[-1]
        self.is_loaded = True


class SiglipVisionTowerDynamicS2(VisionTowerDynamicS2):
    def __init__(self, model_name_or_path: str, config: PretrainedConfig) -> None:
        super().__init__(model_name_or_path, config)
        self.vision_tower = SiglipVisionModel.from_pretrained(
            model_name_or_path,
            attn_implementation=config._attn_implementation,
            torch_dtype=getattr(torch, config.model_dtype.replace("torch.", "")),
        )
        self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
        # Dynamic S2 tiles images externally, so the processor resizes to the smallest scale.
        self.image_processor.size["height"] = self.image_processor.size["width"] = self.scales[0]
        self.is_loaded = True
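

# Usage sketch (illustrative only; the checkpoint name and the `config` fields below are
# assumptions, not part of this module): build an S2 tower and encode a batch of images.
#
#     from types import SimpleNamespace
#
#     config = SimpleNamespace(
#         _attn_implementation="sdpa",
#         model_dtype="torch.float16",
#         s2_scales="384,768",
#         s2_max_split_size=384,
#     )
#     tower = SiglipVisionTowerS2("google/siglip-so400m-patch14-384", config)
#     images = torch.randn(2, 3, 768, 768)  # processor resizes to the largest S2 scale
#     feats = tower(images)  # roughly (2, num_patches, hidden_size * len(tower.scales))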