Merge pull request #1 from dchichkov/tiled_siglip
Adds wrapper for siglip to allow tiling
HuiyingLi authored May 30, 2024
2 parents f0c0a35 + 52c9107 commit e6a3a53
Showing 3 changed files with 204 additions and 5 deletions.
4 changes: 4 additions & 0 deletions docs/source/multimodal/mllm/neva.rst
@@ -84,6 +84,8 @@ Vision Encoder Configuration within Multimodal
hidden_size: 1024
vision_select_layer: -2
class_token_length: 1
grid_width: 1
grid_height: 1
freeze: True
- ``from_pretrained``: Path or name of the pretrained vision encoder.
@@ -92,6 +94,8 @@ Vision Encoder Configuration within Multimodal
- ``hidden_size``: Dimensionality of the hidden layers.
- ``vision_select_layer``: Specifies which layer to select from the vision model.
- ``class_token_length``: Length of the classification token.
- ``grid_width``: Number of vision encoder tiles along the image width.
- ``grid_height``: Number of vision encoder tiles along the image height.
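
For example, a 2 x 2 tiling of each input image (an illustrative setting) is configured as:

  grid_width: 2
  grid_height: 2

Each image is then cropped into ``grid_width * grid_height`` tiles; every tile is encoded separately, and the features of corresponding tokens from all tiles are concatenated along the hidden dimension, so the effective hidden size grows by a factor of ``grid_width * grid_height``.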

Main Language Model Configuration
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
3 changes: 2 additions & 1 deletion nemo/collections/multimodal/data/neva/neva_dataset.py
@@ -293,7 +293,8 @@ def preprocess_multimodal(sources: dict, multimodal_cfg: dict, cur_token_len: in


def process_image(processor, image, image_aspect_ratio="square"):
-    if isinstance(processor, CLIPImageProcessor) or isinstance(processor, SiglipImageProcessor):
+    if isinstance(processor, CLIPImageProcessor) or isinstance(processor, SiglipImageProcessor) \
+            or isinstance(processor, TiledSiglipImageProcessor):
        # image processor from HF
        if image_aspect_ratio == 'keep':
            max_hw, min_hw = max(image.size), min(image.size)
202 changes: 198 additions & 4 deletions nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py
@@ -15,15 +15,18 @@
import os
from functools import partial
from itertools import chain
-from typing import Any, Optional
+from typing import Any, Dict, Optional, Tuple, Union

import torch
-import torch.nn.functional as F
+import torch.nn.functional as F, torch.nn as nn
from einops import rearrange, repeat
from omegaconf.dictconfig import DictConfig
from pkg_resources import packaging
from pytorch_lightning.trainer.trainer import Trainer
from transformers import CLIPVisionModel, CLIPImageProcessor, SiglipVisionModel, SiglipImageProcessor
from transformers.models.siglip.modeling_siglip import BaseModelOutputWithPooling
from transformers.models.siglip.processing_siglip import ImageInput


from nemo.collections.common.parts.utils import extend_instance
from nemo.collections.multimodal.data.neva.conversation import DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN
@@ -123,6 +126,180 @@ def freeze(self) -> None:
        self.frozen = True


class TiledSiglipVisionModel(nn.Module):
    def __init__(self, vision_model: SiglipVisionModel, grid_height: int, grid_width: int, vision_select_layer: int):
        super().__init__()
        self.vision_model = vision_model
        self.grid_h = grid_height
        self.grid_w = grid_width
        self.return_select_layer = vision_select_layer

    def forward(
        self,
        pixel_values,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = True,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:

        # Reshape input from (b, c, H, W) to (b * grid_h * grid_w, c, h, w), one entry per tile
        b, c, H, W = pixel_values.shape
        h, w = H // self.grid_h, W // self.grid_w

        vision_x = rearrange(pixel_values, "b c (gh h) (gw w) -> (b gh gw) c h w",
                             gh=self.grid_h, gw=self.grid_w, h=h, w=w)

        # Pass all tiles through the vision model as a single batch
        vision_x = self.vision_model(
            vision_x,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        # Select the requested hidden layer; features are (b * grid_h * grid_w, v, d)
        features = vision_x.hidden_states[self.return_select_layer]

        # Recover the batch and grid axes: (b, grid_h, grid_w, v, d)
        b_grid_h_grid_w, v, d = features.shape
        b = b_grid_h_grid_w // (self.grid_h * self.grid_w)
        features = rearrange(features, "(b gh gw) v d -> b gh gw v d",
                             b=b, gh=self.grid_h, gw=self.grid_w)

        # Concatenate the features of corresponding tokens from every grid_h x grid_w
        # tile along the channel dimension, giving (b, v, grid_h * grid_w * d)
        features = rearrange(features, "b gh gw v d -> b v (gh gw d)",
                             b=b, gh=self.grid_h, gw=self.grid_w)
        return features
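
The shape flow above can be sanity-checked in isolation. A minimal standalone sketch, assuming a 2 x 2 grid and a SigLIP-like backbone with 196 tokens and hidden size 768 per tile:

import torch
from einops import rearrange

b, gh, gw, v, d = 2, 2, 2, 196, 768        # assumed batch, grid, tokens per tile, hidden size
features = torch.randn(b * gh * gw, v, d)  # stand-in for hidden_states[vision_select_layer]
features = rearrange(features, "(b gh gw) v d -> b gh gw v d", b=b, gh=gh, gw=gw)
features = rearrange(features, "b gh gw v d -> b v (gh gw d)", b=b, gh=gh, gw=gw)
assert features.shape == (b, v, gh * gw * d)  # (2, 196, 3072): tokens unchanged, channels x4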




def calculate_tile_placement(image_size, tile_size, grid, max_upscale):
    """
    Calculate the scale and origin for a tile grid to cover an image.

    Parameters:
        image_size (tuple): The size of the image as (width, height).
        tile_size (tuple): The size of a single tile as (width, height).
        grid (tuple): The grid dimensions as (columns, rows).
        max_upscale (float): The maximum allowed upscale factor.

    Returns:
        tuple: A tuple containing the scale factor and the origin as (scale, (origin_x, origin_y)).
    """
    img_width, img_height = image_size
    tile_width, tile_height = tile_size
    grid_cols, grid_rows = grid

    # Calculate the effective size of the grid and the scale needed to cover the image;
    # the 1 / max_upscale lower bound caps how much the image may be upscaled
    effective_width, effective_height = tile_width * grid_cols, tile_height * grid_rows
    scale = max(img_width / effective_width, img_height / effective_height, 1 / max_upscale)

    # Calculate the total size of the scaled grid
    grid_img_width, grid_img_height = (tile_width * scale * grid_cols, tile_height * scale * grid_rows)

    # Calculate the origin to center the grid on the image
    origin_x, origin_y = (img_width - grid_img_width) / 2, (img_height - grid_img_height) / 2

    # Round the origin to the nearest integer
    origin_x, origin_y = round(origin_x), round(origin_y)

    return scale, (origin_x, origin_y)
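
A worked example, assuming a 1000 x 600 image, 384 x 384 tiles, a 2 x 2 grid, and max_upscale=2.0:

# The effective grid is 768 x 768, so scale = max(1000/768, 600/768, 1/2.0) = 1000/768 ~ 1.302.
# The scaled grid covers 1000 x 1000 pixels and is centered on the image.
scale, origin = calculate_tile_placement((1000, 600), (384, 384), (2, 2), max_upscale=2.0)
assert abs(scale - 1000 / 768) < 1e-9
assert origin == (0, -200)  # the grid overhangs the image by 200 px at the top and bottom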



def generate_tile_coordinates(origin, tile_size, grid, scale):
    """
    Generate a list of tile coordinates based on the given tile size, grid dimensions,
    scale factor, and origin shift.

    Parameters:
        origin (tuple): A tuple (origin_x, origin_y) representing the shift of the tile origin.
        tile_size (tuple): A tuple (tile_width, tile_height) representing the size of each tile.
        grid (tuple): A tuple (num_tiles_horiz, num_tiles_vert) representing the number of tiles horizontally and vertically.
        scale (float): A scale factor to apply to the size of each tile.

    Returns:
        list of tuples: A list of tuples where each tuple contains (left, upper, right, lower) coordinates of a tile.
    """
    tile_width, tile_height = tile_size
    num_tiles_horiz, num_tiles_vert = grid
    scale_factor = scale
    origin_x, origin_y = origin

    tile_coordinates = []
    for i in range(num_tiles_vert):
        for j in range(num_tiles_horiz):
            left = origin_x + j * (tile_width * scale_factor)
            upper = origin_y + i * (tile_height * scale_factor)
            right = left + tile_width * scale_factor
            lower = upper + tile_height * scale_factor
            tile_coordinates.append((left, upper, right, lower))
    return tile_coordinates
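
Continuing the worked example above, with origin (0, -200) and scale 1000/768 each tile spans 384 * scale = 500 px:

coords = generate_tile_coordinates((0, -200), (384, 384), (2, 2), 1000 / 768)
# Row-major boxes covering x in [0, 1000] and y in [-200, 800]; the vertical overhang
# is filled with black when the boxes are later cropped with PIL's Image.crop().
assert coords == [(0.0, -200.0, 500.0, 300.0), (500.0, -200.0, 1000.0, 300.0),
                  (0.0, 300.0, 500.0, 800.0), (500.0, 300.0, 1000.0, 800.0)]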


class TiledSiglipImageProcessor:
    r"""
    Wraps a SiglipImageProcessor to add grid processing functionality.

    Args:
        processor (SiglipImageProcessor):
            An instance of SiglipImageProcessor to wrap and add grid functionality.
        grid_width (`int`, defaults to `1`):
            Width of the grid to be applied to the image.
        grid_height (`int`, defaults to `1`):
            Height of the grid to be applied to the image.
        max_upscale (`float`, defaults to `2.0`):
            The maximum allowed upscale factor when fitting the tile grid to the image.
    """

    def __init__(
        self,
        processor: SiglipImageProcessor,
        grid_width: int = 1,
        grid_height: int = 1,
        max_upscale: float = 2.0,
    ) -> None:
        self.processor = processor
        self.grid_width = grid_width
        self.grid_height = grid_height
        self.max_upscale = max_upscale

    def preprocess(
        self,
        images: ImageInput,
        return_tensors: str = 'pt',
        **kwargs,
    ):
        """
        Crop the image into a grid of tiles, preprocess each tile with the wrapped
        SiglipImageProcessor, and reassemble the tiles into a single pixel-values
        tensor of shape (1, c, grid_height * tile_height, grid_width * tile_width).
        """
        if self.grid_width > 1 or self.grid_height > 1:
            if return_tensors != 'pt':
                raise ValueError("TiledSiglipImageProcessor currently only supports return_tensors='pt'")

            # only a single image is supported for now
            image = images

            grid_w, grid_h = self.grid_width, self.grid_height  # number of tiles in the grid
            tile_w, tile_h = self.processor.size['width'], self.processor.size['height']
            scale, origin = calculate_tile_placement(image.size, (tile_w, tile_h), (grid_w, grid_h), self.max_upscale)
            tile_coordinates = generate_tile_coordinates(origin, (tile_w, tile_h), (grid_w, grid_h), scale)
            images = [image.crop(cs) for cs in tile_coordinates]  # square tiles, resized by the processor below
            images = self.processor.preprocess(images, return_tensors='pt')['pixel_values']
            tensors = rearrange(images, '(gh gw) c th tw -> 1 c (gh th) (gw tw)',
                                gh=grid_h, gw=grid_w, c=3, th=tile_h, tw=tile_w)
            return tensors

        elif self.grid_width == 1 and self.grid_height == 1:
            return self.processor.preprocess(images, **kwargs)

        raise ValueError("Invalid grid dimensions in the TiledSiglipImageProcessor")


class NevaWordEmbeddingMixin(torch.nn.Module, adapter_mixins.AdapterModuleMixin):
    """
    A mixin class for integrating vision-based embeddings into language models.
@@ -142,7 +319,7 @@ def init_vision(
        use_im_start_end=False,
    ):
        self.vision_encoder = vision_encoder
-        self.from_hf = isinstance(vision_encoder, CLIPVisionModel) or isinstance(vision_encoder, SiglipVisionModel)
+        self.from_hf = isinstance(vision_encoder, CLIPVisionModel) or isinstance(vision_encoder, TiledSiglipVisionModel)
        self.from_open_clip = "open_clip" in str(vision_encoder.__module__)
        self.media_start_id = media_start_id
        self.media_end_id = media_end_id
@@ -181,6 +358,10 @@ def encode_vision_x(self, vision_x: torch.Tensor):
        with torch.no_grad():
            if self.from_hf:
                vision_x = self.vision_encoder(vision_x, output_hidden_states=True)

+                # TODO: hack, need to fix this - this is CLIP specific
+                if hasattr(vision_x, "hidden_states"):
+                    vision_x = vision_x.hidden_states[self.vision_select_layer]
-                vision_x = vision_x.hidden_states[self.vision_select_layer]
            else:
                self.vision_encoder.backbone.transformer.return_select_layer = self.vision_select_layer
@@ -313,16 +494,29 @@ def create_vision_encoder_and_processor(self, mm_cfg):
                )
            elif "siglip" in mm_cfg.vision_encoder.from_pretrained:
                vision_encoder = SiglipVisionModel.from_pretrained(
-                    mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16,
+                    mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16
                ).cuda()
                vision_encoder = vision_encoder.to(torch.bfloat16)
                if mm_cfg.vision_encoder.freeze:
                    for param in vision_encoder.parameters():
                        param.requires_grad = False
                    vision_encoder = vision_encoder.eval()

                # Wrap the encoder so tiled inputs are split, encoded per tile, and re-merged
                vision_encoder = TiledSiglipVisionModel(
                    vision_encoder,
                    grid_height=mm_cfg.vision_encoder.get("grid_height", 1),
                    grid_width=mm_cfg.vision_encoder.get("grid_width", 1),
                    vision_select_layer=mm_cfg.vision_encoder.get("vision_select_layer", -1),
                ).cuda()

                image_processor = SiglipImageProcessor.from_pretrained(
                    mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16
                )

                # Wrap the processor so each image is cropped into a grid of tiles
                image_processor = TiledSiglipImageProcessor(
                    image_processor,
                    grid_width=mm_cfg.vision_encoder.get("grid_width", 1),
                    grid_height=mm_cfg.vision_encoder.get("grid_height", 1),
                    max_upscale=mm_cfg.vision_encoder.get("max_upscale", 2.0),
                )
            else:
                raise ValueError("Currently only support CLIPVisionModel and SiglipVisionModel from Huggingface")
        elif mm_cfg.vision_encoder.get("from_open_clip", False):
