|
|
import json |
|
|
import os |
|
|
import random |
|
|
|
|
|
from torch.utils.data import Dataset |
|
|
from pycocotools.coco import COCO |
|
|
from pycocotools import mask as maskUtils |
|
|
|
|
|
from PIL import Image |
|
|
from PIL import ImageFile |
|
|
ImageFile.LOAD_TRUNCATED_IMAGES = True |
|
|
Image.MAX_IMAGE_PIXELS = None |
|
|
from tqdm import tqdm |
|
|
from torchvision import transforms |
|
|
from tqdm import tqdm |
|
|
import pickle |
|
|
import cv2 |
|
|
import torch |
|
|
import numpy as np |
|
|
import copy |
|
|
from transformers import AutoProcessor |
|
|
from nltk.corpus import wordnet |
|
|
from bg_aug import get_bkgd |
|
|
import jax |
|
|
import random |
|
|
|
|
|
# Per-channel RGB mean / std used for CLIP image normalisation; shared by the
# transforms below so the two copies cannot drift apart.
_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

# CLIP-style preprocessing: ToTensor -> bicubic resize to 224x224 -> CLIP
# mean/std normalisation.
clip_standard_transform = transforms.Compose([
    transforms.ToTensor(),
    # torchvision documents `interpolation` as an InterpolationMode enum;
    # passing a PIL integer constant is only a legacy compatibility path.
    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.Normalize(_CLIP_MEAN, _CLIP_STD),
])

to_tensor = transforms.ToTensor()

normalize = transforms.Normalize(_CLIP_MEAN, _CLIP_STD)

# Mask preprocessing: ToTensor -> resize to 224x224 (torchvision's default
# interpolation) -> normalise with mean 0.5 / std 0.26.
mask_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    transforms.Normalize(0.5, 0.26),
])

# Augmentation: random 192x192 crop, resized back up to 224x224.
crop_aug = transforms.Compose([
    transforms.RandomCrop((224 - 32, 224 - 32)),
    transforms.Resize((224, 224)),
])
|
|
|
|
|
def text_filter(text):
    """Remove "white background" phrasing from a caption and terminate it with '.'.

    A fixed list of phrases (e.g. " on a white background") is deleted, one
    phrase at a time and in the listed order (the order matters when phrases
    overlap). Any "white background" that survives is shortened to
    "background", and a period is appended.
    """
    phrases_to_drop = (
        ' with a white background',
        ' with white background',
        ' next to a white background',
        ' over a white background',
        ' is cut out of a white background',
        ' across a white background',
        ' on a white background',
        ' sticking out of a white background',
        ' in the middle of a white background',
        ' on white background',
        ' in a white background',
        ' and a white background',
        ' and white background',
        ' in front of a white background',
        ' on top of a white background',
        ' against a white background',
        'a white background with ',
        ' and has a white background',
    )
    cleaned = text
    for phrase in phrases_to_drop:
        cleaned = cleaned.replace(phrase, '')
    # Leftover mentions keep the noun but lose the colour.
    cleaned = cleaned.replace('white background', 'background')
    return f"{cleaned}."
|
|
|
|
|
def crop(image: np.ndarray, bbox_xywh: np.ndarray, bi_mask: np.ndarray, scale=1.5):
    """Crop a square window around a bounding box from an image and its mask.

    The box is first grown to a square of side ``r = max(w, h)`` around the
    same centre. It is then enlarged so the window side is ``scale * r``.
    The window is clipped to the image bounds, so near an edge the crop can
    be smaller than ``scale * r`` and non-square.

    Args:
        image: Array indexed as ``image[y, x, :]`` (height, width, channels).
        bbox_xywh: ``(x, y, w, h)`` with ``(x, y)`` the top-left corner.
            Non-positive ``w`` / ``h`` are treated as 1.
        bi_mask: 2-D array indexed as ``bi_mask[y, x]``, cropped with the
            same window as ``image``.
        scale: Window side relative to the squared-up box side.

    Returns:
        ``(image_crop, mask_crop)``: slices (views) of ``image`` and ``bi_mask``.
    """
    tl_x = int(bbox_xywh[0])
    tl_y = int(bbox_xywh[1])
    # Degenerate boxes are widened to 1 px so the window never collapses.
    w = max(int(bbox_xywh[2]), 1)
    h = max(int(bbox_xywh[3]), 1)

    image_h, image_w = image.shape[:2]

    # Square up the box around its centre.
    r = max(h, w)
    tl_x -= (r - w) / 2
    tl_y -= (r - h) / 2

    half_scale = (scale - 1.0) / 2

    # Clip the window to [0, size]. Slice stops are exclusive, so the upper
    # clip must be the full size. Clipping to size - 1 dropped the last
    # row/column whenever the window reached the image edge.
    w_l = max(int(tl_x - half_scale * r), 0)
    w_r = min(int(tl_x + (1 + half_scale) * r), image_w)
    h_t = max(int(tl_y - half_scale * r), 0)
    h_b = min(int(tl_y + (1 + half_scale) * r), image_h)

    return image[h_t:h_b, w_l:w_r, :], bi_mask[h_t:h_b, w_l:w_r]
|
|
|
|
|
def masked_crop(image: np.ndarray, bbox_xywh: np.ndarray, bi_mask: np.ndarray, crop_scale=1.0,
                masked_color=(255, 255, 255), pad=600):
    """Crop around a bounding box and paint everything outside the mask.

    The image and mask are padded by ``pad`` pixels on every side. This lets
    a crop window that extends past the original border keep its requested
    extent, up to ``pad`` pixels. The box is shifted by ``pad`` and passed to
    ``crop``. Every cropped pixel whose mask value is 0 (including the padded
    border) is set to ``masked_color``.

    Args:
        image: Array indexed as ``image[y, x, :]``; must be 3-D for the padding.
        bbox_xywh: ``(x, y, w, h)`` box in original image coordinates. It is
            copied, never modified.
        bi_mask: 2-D mask aligned with ``image``.
        crop_scale: Forwarded to ``crop`` as ``scale``.
        masked_color: Colour written to pixels outside the mask.
        pad: Border width added before cropping (default 600, as before).

    Returns:
        ``(cropped_image, cropped_mask)``. Both are cut from padded copies, so
        the caller's ``image`` and ``bi_mask`` are not modified.
    """
    image = np.pad(image, ((pad, pad), (pad, pad), (0, 0)), 'constant', constant_values=255)
    bi_mask = np.pad(bi_mask, ((pad, pad), (pad, pad)), "constant", constant_values=0)

    # Copy before shifting. Adding in place would silently move the caller's
    # box every time this function is called.
    bbox_xywh = np.array(bbox_xywh)
    bbox_xywh[:2] += pad

    cropped_image, cropped_mask = crop(image, bbox_xywh, bi_mask, crop_scale)
    cropped_image[np.nonzero(cropped_mask == 0)] = masked_color
    return cropped_image, cropped_mask
|
|
|
|
|
class ImageNet_Masked(Dataset):
    """Annotation-driven writer of mask-overlay visualisations.

    Although it subclasses ``Dataset``, ``__getitem__`` returns nothing. For
    one annotation it tints the masked pixels green and saves the result to
    disk under a ``visual_train_c`` tree (see ``_output_path``).

    Annotations are read positionally: ``ann[2]`` image path, ``ann[3]`` mask
    passed to ``pycocotools.mask.decode``, ``ann[6]`` caption text.
    """

    def __init__(self, ann_file="M_ImageNet_top_460k.json", masked_color=(255, 255, 255)):
        """Load the annotation list from ``ann_file`` and shuffle it in place.

        Args:
            ann_file: Path to a JSON file holding a list of annotations.
            masked_color: Stored on the instance; not read by any method of
                this class.
        """
        self.masked_color = masked_color
        # Use a context manager so the file handle is closed promptly.
        with open(ann_file, 'r') as f:
            self.anns_list = json.load(f)
        random.shuffle(self.anns_list)
        # The attributes below are not read by any method of this class;
        # they are kept for external callers.
        self.crop_scale = 1.5
        self.transform = clip_standard_transform
        self.res = 224
        self.blur = 10.0

    def __len__(self):
        """Return the number of annotations."""
        return len(self.anns_list)

    @staticmethod
    def _output_path(img_pth, text):
        """Return ``(output_dir, output_file)`` for the overlay of ``img_pth``.

        ``imagenet-21k/images`` in the path is replaced by ``visual_train_c``,
        and the file name is prefixed with ``f"{text}_"``.
        """
        out_dir, file_name = os.path.split(img_pth.replace("imagenet-21k/images", "visual_train_c"))
        return out_dir, out_dir + f"/{text}_" + file_name

    def __getitem__(self, index):
        """Render and save the mask overlay for annotation ``index``; returns None.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        # NOTE(review): presumably disabled to avoid OpenCV thread
        # oversubscription inside DataLoader workers — confirm.
        cv2.ocl.setUseOpenCL(False)
        cv2.setNumThreads(0)

        ann = self.anns_list[index]
        img_pth = ann[2]
        mask = ann[3]
        text = ann[6]

        image = cv2.imread(img_pth)
        # cv2.imread signals failure by returning None rather than raising.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {img_pth}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # NOTE(review): assumes ann[3] is a single RLE, so decode yields a 2-D mask.
        binary_mask = maskUtils.decode(mask)

        # The 4th "/"-separated path component, minus its first character,
        # is used as a WordNet noun offset (presumably a wnid such as
        # "n01234567"). This depends on the exact path layout — verify.
        cat_word = img_pth.split("/")[3]
        synset = wordnet.synset_from_pos_and_offset('n', int(cat_word[1:]))
        label = synset.lemmas()[0].name()

        # Every "." in the caption (not only a trailing one) becomes
        # ", probably <label>"; separators are then made filename-safe.
        # NOTE(review): long captions can exceed filesystem name limits.
        text = text.replace(".", f", probably {label}").replace(" ", "_").replace("/", "_").replace("\\", "_")

        # Blend masked pixels 50/50 with pure green.
        idx = np.nonzero(binary_mask == 1)
        image[idx] = (0.5 * image[idx] + 0.5 * np.array([0, 255, 0])).astype(np.uint8)

        out_dir, out_file = self._output_path(img_pth, text)
        os.makedirs(out_dir, exist_ok=True)
        Image.fromarray(image).save(out_file)
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: write one overlay image per annotation, in the
    # (shuffled) order held by the dataset.
    data = ImageNet_Masked()
    for i in tqdm(range(len(data))):
        # Indexed only for its side effect (saving a file); the result is None.
        data[i]
|
|
|