# spat-data / mask_image.py
# Uploaded by chouss — "Uploading folder contents" (commit 6029b11, verified)
import json
import os
import random
from torch.utils.data import Dataset
from pycocotools.coco import COCO
from pycocotools import mask as maskUtils
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None
from tqdm import tqdm
from torchvision import transforms
from tqdm import tqdm
import pickle
import cv2
import torch
import numpy as np
import copy
from transformers import AutoProcessor
from nltk.corpus import wordnet
from bg_aug import get_bkgd
import jax
import random
# Per-channel RGB mean / std used by the CLIP-style normalisation stages below.
_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
# Side length, in pixels, of the square outputs produced by these pipelines.
_OUT_SIDE = 224

# PIL image / (H, W, C) ndarray -> float tensor, bicubic resize to 224x224,
# then CLIP mean/std normalisation.
clip_standard_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((_OUT_SIDE, _OUT_SIDE), interpolation=Image.BICUBIC),
        transforms.Normalize(_CLIP_MEAN, _CLIP_STD),
    ]
)

# Standalone ToTensor and CLIP Normalize stages (independent instances).
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize(_CLIP_MEAN, _CLIP_STD)

# Mask pipeline: tensor conversion, resize to 224x224 with the default
# interpolation, then normalisation with scalar mean 0.5 / std 0.26
# (presumably intended for single-channel masks — TODO confirm).
mask_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((_OUT_SIDE, _OUT_SIDE)),
        transforms.Normalize(0.5, 0.26),
    ]
)

# Augmentation: random 192x192 crop, then resize back up to 224x224.
crop_aug = transforms.Compose(
    [
        transforms.RandomCrop((_OUT_SIDE - 32, _OUT_SIDE - 32)),
        transforms.Resize((_OUT_SIDE, _OUT_SIDE)),
    ]
)
# Caption fragments that mention the white backdrop. They are deleted in
# exactly this order before any leftover "white background" is shortened.
_WHITE_BG_PHRASES = (
    ' with a white background',
    ' with white background',
    ' next to a white background',
    ' over a white background',
    ' is cut out of a white background',
    ' across a white background',
    ' on a white background',
    ' sticking out of a white background',
    ' in the middle of a white background',
    ' on white background',
    ' in a white background',
    ' and a white background',
    ' and white background',
    ' in front of a white background',
    ' on top of a white background',
    ' against a white background',
    'a white background with ',
    ' and has a white background',
)


def text_filter(text):
    """Strip white-background boilerplate from a caption and end it with a period.

    Every phrase in ``_WHITE_BG_PHRASES`` is removed (in tuple order). Any
    remaining ``'white background'`` is rewritten as ``'background'``. A
    ``'.'`` is then appended unconditionally, even if the caption already
    ends with one.

    Args:
        text: caption string.

    Returns:
        The cleaned caption.
    """
    cleaned = text
    for phrase in _WHITE_BG_PHRASES:
        cleaned = cleaned.replace(phrase, '')
    cleaned = cleaned.replace('white background', 'background')
    return cleaned + '.'
def crop(image: np.ndarray, bbox_xywh: np.ndarray, bi_mask: np.ndarray, scale=1.5):
    """Cut a square window around a bounding box out of an image and its mask.

    The box is grown along its shorter side to a square of side
    ``r = max(w, h)`` that keeps the box centre. It is then enlarged to side
    ``scale * r`` (``(scale - 1) / 2 * r`` extra on each side). The window is
    clipped to the image, so near a border the result may be smaller and
    non-square.

    Args:
        image: array indexed as (H, W, C).
        bbox_xywh: ``(x, y, w, h)`` — top-left corner and size in pixels.
            A width/height whose ``int()`` is below 1 is treated as 1.
        bi_mask: array indexed as (H, W), same height/width as ``image``.
        scale: window side relative to ``r``.

    Returns:
        ``(image_crop, mask_crop)`` — slices (views) of the inputs.
    """
    tl_x = int(bbox_xywh[0])
    tl_y = int(bbox_xywh[1])
    w = max(int(bbox_xywh[2]), 1)
    h = max(int(bbox_xywh[3]), 1)
    image_h, image_w = image.shape[:2]

    # Square the box around its centre so the aspect ratio is maintained.
    r = max(h, w)
    tl_x -= (r - w) / 2
    tl_y -= (r - h) / 2
    half_scale = (scale - 1.0) / 2

    # Slice ends are exclusive, so the upper bounds clamp to the full extent
    # (image_w / image_h). Clamping to extent - 1 would drop the last
    # column/row whenever the window reaches the border.
    w_l = max(int(tl_x - half_scale * r), 0)
    w_r = min(int(tl_x + (1 + half_scale) * r), image_w)
    h_t = max(int(tl_y - half_scale * r), 0)
    h_b = min(int(tl_y + (1 + half_scale) * r), image_h)
    return image[h_t:h_b, w_l:w_r, :], bi_mask[h_t:h_b, w_l:w_r]
def masked_crop(image: np.ndarray, bbox_xywh: np.ndarray, bi_mask: np.ndarray, crop_scale=1.0,
                masked_color=(255, 255, 255), pad=600):
    """Crop around a bounding box and paint every pixel outside the mask.

    Both arrays are padded by ``pad`` pixels on every side before cropping.
    The square window from ``crop`` therefore keeps its full shape unless it
    reaches more than ``pad`` pixels beyond the original image edge.

    Args:
        image: array indexed as (H, W, C).
        bbox_xywh: ``(x, y, w, h)`` in original-image pixels. It is not
            modified; the padding offset is applied to a copy.
        bi_mask: array indexed as (H, W). Pixels equal to 0 are treated as
            background.
        crop_scale: window side relative to the squared bbox (see ``crop``).
        masked_color: value assigned to background pixels of the crop.
        pad: border width added on each side before cropping.

    Returns:
        ``(cropped_image, cropped_mask)``. ``cropped_image`` is a fresh array
        (a slice of the padded copy), so the caller's ``image`` is untouched.
    """
    # Image padding uses 255, but padded pixels have mask 0 and are recoloured below anyway.
    image = np.pad(image, ((pad, pad), (pad, pad), (0, 0)), 'constant', constant_values=255)
    bi_mask = np.pad(bi_mask, ((pad, pad), (pad, pad)), 'constant', constant_values=0)
    # Shift a copy so the caller's bbox is not altered in place (this also accepts list input).
    bbox_xywh = np.array(bbox_xywh, copy=True)
    bbox_xywh[:2] += pad
    cropped_image, cropped_mask = crop(image, bbox_xywh, bi_mask, crop_scale)
    cropped_image[cropped_mask == 0] = masked_color
    return cropped_image, cropped_mask
class ImageNet_Masked(Dataset):
    """Write mask-overlay visualisations for ImageNet-21k annotation records.

    Although this subclasses ``Dataset``, ``__getitem__`` returns ``None``.
    Each call saves a copy of the image with its mask tinted green. The copy
    goes into a mirror tree where ``imagenet-21k/images`` is replaced by
    ``visual_train_c``, and the caption plus a WordNet lemma is prepended to
    the file name.
    """

    def __init__(self, ann_file="M_ImageNet_top_460k.json", masked_color=None):
        """Load and shuffle the annotation list.

        Args:
            ann_file: path of a JSON file containing a list of annotation records.
            masked_color: colour stored as ``self.masked_color``. Defaults to
                ``[255, 255, 255]``. Not read by ``__getitem__``.
        """
        # None sentinel instead of a mutable list default shared across instances.
        self.masked_color = [255, 255, 255] if masked_color is None else masked_color
        with open(ann_file, 'r', encoding='utf-8') as f:
            self.anns_list = json.load(f)
        random.shuffle(self.anns_list)
        # The attributes below are not read by __getitem__ in this file.
        self.crop_scale = 1.5
        self.transform = clip_standard_transform
        self.res = 224
        self.blur = 10.0

    def __len__(self):
        """Return the number of annotation records."""
        return len(self.anns_list)

    def __getitem__(self, index):
        """Save the mask-overlay visualisation for record ``index``; returns None.

        Record fields used, by position: 2 = image path, 3 = mask passed to
        ``maskUtils.decode`` (pycocotools RLE), 6 = caption.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image.
        """
        # Disable OpenCL and OpenCV's internal threading (0 = run sequentially).
        cv2.ocl.setUseOpenCL(False)
        cv2.setNumThreads(0)

        ann = self.anns_list[index]
        # TODO: change list to dict key.
        img_pth = ann[2]
        mask = ann[3]
        text = ann[6]

        image = cv2.imread(img_pth)
        if image is None:
            # cv2.imread reports failure by returning None rather than raising.
            raise FileNotFoundError(f"could not read image: {img_pth}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        binary_mask = maskUtils.decode(mask)

        # Path component 3 is assumed to be the ImageNet wnid directory
        # (e.g. "n01440764"); this depends on the exact path layout — TODO confirm.
        cat_word = img_pth.split("/")[3]
        synset = wordnet.synset_from_pos_and_offset('n', int(cat_word[1:]))
        lemma = synset.lemmas()[0].name()
        # Caption -> filename fragment: every "." becomes ", probably <lemma>";
        # spaces and path separators become "_".
        # NOTE(review): long captions may exceed filesystem name limits and make save() fail.
        text = text.replace(".", f", probably {lemma}").replace(" ", "_").replace("/", "_").replace("\\", "_")

        # Blend masked pixels 50/50 with pure green.
        fg = np.nonzero(binary_mask == 1)
        image[fg] = (0.5 * image[fg] + 0.5 * np.array([0, 255, 0])).astype(np.uint8)

        out_dir, out_name = os.path.split(img_pth.replace("imagenet-21k/images", "visual_train_c"))
        os.makedirs(out_dir, exist_ok=True)
        Image.fromarray(image).save(os.path.join(out_dir, f"{text}_{out_name}"))
if __name__ == "__main__":
    # Generate a visualisation for every annotation record.
    data = ImageNet_Masked()
    for i in tqdm(range(len(data))):
        data[i]  # called for its side effect: __getitem__ writes an image file