| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | """Image processor class for Magma.""" |
| |
|
| | from typing import List, Optional, Union |
| | import ast |
| | import numpy as np |
| | import torchvision |
| | from transformers.image_processing_utils import BaseImageProcessor, BatchFeature |
| | from transformers.image_transforms import ( |
| | convert_to_rgb, |
| | ) |
| | from transformers.image_utils import ( |
| | OPENAI_CLIP_MEAN, |
| | OPENAI_CLIP_STD, |
| | ImageInput, |
| | make_list_of_images, |
| | valid_images, |
| | ) |
| | from transformers.utils import TensorType, is_vision_available, logging |
| |
|
| | from transformers import AutoImageProcessor |
| |
|
| | logger = logging.get_logger(__name__) |
| |
|
| |
|
| | if is_vision_available(): |
| | from PIL import Image |
| |
|
| | import torch |
| | import torchvision |
| |
|
def select_best_resolution(original_size, possible_resolutions):
    """
    Pick the candidate resolution that best fits an image of ``original_size``.

    The winner maximizes the *effective* resolution — the pixel count retained
    after an aspect-preserving resize, capped at the original pixel count so
    upscaling earns no credit — and breaks ties by minimizing wasted area.

    Args:
        original_size (tuple): Original image size as (width, height).
        possible_resolutions (list): Candidate resolutions as
            [(width1, height1), (width2, height2), ...].

    Returns:
        tuple: The chosen (width, height) resolution (``None`` if the
        candidate list is empty).
    """
    orig_w, orig_h = original_size
    best_fit = None
    best_effective = 0
    best_wasted = float('inf')

    for cand_w, cand_h in possible_resolutions:
        # Aspect-preserving scale factor that fits the image inside the candidate.
        ratio = min(cand_w / orig_w, cand_h / orig_h)
        scaled_w = int(orig_w * ratio)
        scaled_h = int(orig_h * ratio)
        # Cap at the original pixel count: upscaling adds no real information.
        effective = min(scaled_w * scaled_h, orig_w * orig_h)
        wasted = cand_w * cand_h - effective

        improves = effective > best_effective
        ties_with_less_waste = (effective == best_effective and wasted < best_wasted)
        if improves or ties_with_less_waste:
            best_effective = effective
            best_wasted = wasted
            best_fit = (cand_w, cand_h)

    return best_fit
| |
|
def process_anyres_image(image, max_num_crops=None, base_width=768, base_height=768):
    """
    Process an image with variable resolution into fixed-size patches.

    The image is bilinearly resized to the best-fitting grid of
    ``base_width`` x ``base_height`` tiles (at most ``max_num_crops`` tiles in
    total), then unfolded into individual per-tile patches.

    Args:
        image (torch.Tensor): The input image to be processed, shape (C, H, W).
        max_num_crops (int): Maximum number of crops; must be provided.
        base_width (int, *optional*, defaults to 768): Tile width in pixels.
        base_height (int, *optional*, defaults to 768): Tile height in pixels.

    Returns:
        tuple: ``(patches, grid)`` where ``patches`` is a tensor of shape
        (grid_rows * grid_cols, C, base_height, base_width) and ``grid`` is
        (grid_rows, grid_cols).

    Raises:
        ValueError: If ``max_num_crops`` is None.
    """
    if max_num_crops is None:
        # Explicit error instead of `assert` — asserts are stripped under -O.
        raise ValueError("max_num_crops must be provided")

    # Enumerate every (cols, rows) grid with at most max_num_crops tiles,
    # expressed directly in pixels. (The former `ast.literal_eval` branch was
    # dead code: grid_pinpoints was always a list built right here.)
    possible_resolutions = [
        (i * base_width, j * base_height)
        for i in range(1, max_num_crops + 1)
        for j in range(1, max_num_crops // i + 1)
    ]

    # image is (C, H, W): shape[2] is width, shape[1] is height.
    best_resolution = select_best_resolution((image.shape[2], image.shape[1]), possible_resolutions)

    # Swap (width, height) -> (height, width) for torch's size convention.
    best_resolution = (best_resolution[1], best_resolution[0])
    best_resolution_grid = (best_resolution[0] // base_height, best_resolution[1] // base_width)

    # Resize holistically to the chosen grid resolution (adds a batch dim).
    image = torch.nn.functional.interpolate(image[None, :, :, :], size=best_resolution, mode='bilinear')

    # Unfold into (rows*cols, C, base_height, base_width) patches.
    patches = image.unfold(2, base_height, base_height).unfold(3, base_width, base_width)
    patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(
        best_resolution_grid[0] * best_resolution_grid[1], -1, base_height, base_width
    )
    return (patches, best_resolution_grid)
| |
|
def process_anyres_image_global(image, max_num_crops=None, base_width=768, base_height=768):
    """
    Process an image with variable resolution by a single global resize.

    Unlike :func:`process_anyres_image`, this does not unfold the result into
    tiles — it returns the whole image bilinearly resized to the best-fitting
    grid resolution (with a leading batch dimension added).

    Args:
        image (torch.Tensor): The input image to be processed, shape (C, H, W).
        max_num_crops (int): Maximum number of crops; must be provided.
        base_width (int, *optional*, defaults to 768): Tile width in pixels.
        base_height (int, *optional*, defaults to 768): Tile height in pixels.

    Returns:
        torch.Tensor: The resized image of shape (1, C, best_h, best_w).

    Raises:
        ValueError: If ``max_num_crops`` is None.
    """
    if max_num_crops is None:
        # Explicit error instead of `assert` — asserts are stripped under -O.
        raise ValueError("max_num_crops must be provided")

    # Enumerate every (cols, rows) grid with at most max_num_crops tiles,
    # expressed directly in pixels. (The former `ast.literal_eval` branch was
    # dead code: grid_pinpoints was always a list built right here.)
    possible_resolutions = [
        (i * base_width, j * base_height)
        for i in range(1, max_num_crops + 1)
        for j in range(1, max_num_crops // i + 1)
    ]

    # image is (C, H, W): shape[2] is width, shape[1] is height.
    best_resolution = select_best_resolution((image.shape[2], image.shape[1]), possible_resolutions)

    # Swap (width, height) -> (height, width) for torch's size convention.
    best_resolution = (best_resolution[1], best_resolution[0])

    # Resize holistically to the chosen resolution (adds a batch dim).
    image = torch.nn.functional.interpolate(image[None, :, :, :], size=best_resolution, mode='bilinear')
    return image
| |
|
class preprocessor():
    """Thin adapter exposing a torchvision-style transform pipeline through a
    HuggingFace-style ``preprocess`` interface (``crop_size`` / ``image_mean``
    attributes plus a ``preprocess`` method returning a pixel-value dict)."""

    def __init__(self, image_preprocessor, base_resolution=(256, 256)):
        # Underlying transform pipeline. Its last transform is assumed to be a
        # Normalize, whose mean is re-exposed here — TODO confirm with callers.
        self.image_preprocessor = image_preprocessor
        height, width = base_resolution
        self.crop_size = {'height': height, 'width': width}
        self.image_mean = image_preprocessor.transforms[-1].mean

    def preprocess(self, image, return_tensors='pt'):
        """Run the pipeline on ``image`` and return a batched pixel tensor.

        Note: ``return_tensors`` is accepted for interface compatibility but
        is not used by this implementation.
        """
        pixel_values = self.image_preprocessor(image).unsqueeze(0)
        return {'pixel_values': pixel_values}
| |
|
class MagmaImageProcessor(BaseImageProcessor):
    r"""
    Constructs a Magma image processor. Based on [`CLIPImageProcessor`] with incorporation of additional techniques
    for processing high resolution images as explained in the [InternLM-XComposer2-4KHD](https://arxiv.org/pdf/2404.06512)

    Args:
        anyres_strategy (`str`):
            strategy to cope with high-resolution images. one conventional way is multi-crop and many other works to accommodate clip-vit models.
            however, since we are using convnext, which is essentially a convnet, we can use arbitrary resolution images. as such, we use the global strategy by default,
            i.e., directly resize image holistically to a certain resolution.
        base_img_size (int, *optional*, defaults to 768):
            as convnext has 1/32 downsample rate, we use 768 as the base resolution so that the resulted feature map is 24x24.
        num_crops (int, *optional*, defaults to 1):
            number of effective crops when coping with images with higher resolution than 768x768. note that num_crops > 1 does not mean we are cropping the image.
        do_convert_rgb (bool, *optional*, defaults to True):
            whether to convert incoming images to RGB before normalization.
        image_mean (List[float], *optional*, defaults to OPENAI_CLIP_MEAN):
            per-channel mean used for normalization.
        image_std (List[float], *optional*, defaults to OPENAI_CLIP_STD):
            per-channel std used for normalization.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        anyres_strategy: str = 'global',
        base_img_size: int = 768,
        num_crops: int = 1,
        do_convert_rgb: bool = True,
        image_mean: List[float] = OPENAI_CLIP_MEAN,
        image_std: List[float] = OPENAI_CLIP_STD,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.base_img_size = base_img_size
        self.anyres_strategy = anyres_strategy
        self.num_crops = num_crops
        self.do_convert_rgb = do_convert_rgb
        self.image_mean = image_mean
        self.image_std = image_std

    def preprocess(
        self,
        images: Union[ImageInput, List[ImageInput]],
        do_pad: bool = False,
        do_convert_rgb: bool = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        num_crops: int = None,
    ):
        """
        Args:
            images (`ImageInput` or `List[ImageInput]`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            do_pad (`bool`, *optional*, defaults to `False`):
                Accepted for interface compatibility; not used by this implementation.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            num_crops (`int`, *optional*, defaults to `self.num_crops`):
                Per-call override for the maximum number of crops.
        """
        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        # BUGFIX: fall back to the processor-level setting when the caller does
        # not specify do_convert_rgb. Previously the bare truthiness test
        # treated the default `None` as False, so RGB conversion was silently
        # skipped unless the caller passed True explicitly — contradicting the
        # documented default of `self.do_convert_rgb`.
        if do_convert_rgb is None:
            do_convert_rgb = self.do_convert_rgb
        if do_convert_rgb:
            images = [convert_to_rgb(image) for image in images]

        # ToTensor scales pixel values to [0, 1]; Normalize applies CLIP mean/std.
        img_processor = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(self.image_mean, self.image_std),
        ])

        images = [img_processor(image) for image in images]
        # Remember the incoming precision so it can be restored at the end;
        # the anyres processing itself runs in float32.
        image_data_type = 'half' if images[0].type() == 'torch.HalfTensor' else 'float'
        images = [image.float() for image in images]

        # Tile each image into base_img_size crops; each entry is
        # (patches, (grid_rows, grid_cols)).
        image_patches = [
            process_anyres_image(
                image,
                self.num_crops if num_crops is None else num_crops,
                base_width=self.base_img_size,
                base_height=self.base_img_size,
            )
            for image in images
        ]
        pixel_values = torch.cat([image_patch[0] for image_patch in image_patches], dim=0)
        image_sizes = [image_patch[1] for image_patch in image_patches]

        if image_data_type == 'half':
            pixel_values = pixel_values.half()

        data = {
            "pixel_values": pixel_values,
            "image_sizes": image_sizes,
        }
        return BatchFeature(data=data, tensor_type=return_tensors)
| |
|
# Register the processor so AutoImageProcessor can resolve the
# "MagmaImageProcessor" key (e.g. via from_pretrained with a Magma config).
AutoImageProcessor.register("MagmaImageProcessor", MagmaImageProcessor)