Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import argparse | |
| import os | |
| import glob | |
| from transformers import AutoProcessor, AutoTokenizer, AutoImageProcessor | |
| import pandas as pd | |
| from paths import * | |
| import numpy as np | |
| from vision_tower import DINOv2_MLP | |
| from PIL import Image | |
# Module-level inference state: the orientation model and its image
# preprocessor are created once when this file is loaded and are read as
# globals by get_3angle().
save_path = './'  # directory that holds 'dino_weight.pt'
device = 'cpu'

dino = DINOv2_MLP(
    dino_mode='large',
    in_dim=1024,
    # Prediction head width: 360 + 180 + 60 + 2 logits. get_3angle() slices
    # the first three groups; the trailing 2 are not read anywhere in this file.
    out_dim=360 + 180 + 60 + 2,
    evaluate=True,
    mask_dino=False,
    frozen_back=False,
)
dino = dino.to(device)
dino.eval()

# The checkpoint tensors are always loaded onto CPU first, regardless of
# `device`; load_state_dict then copies them into the model's parameters.
_checkpoint_path = os.path.join(save_path, 'dino_weight.pt')
dino.load_state_dict(torch.load(_checkpoint_path, map_location='cpu'))

# Image preprocessor loaded from DINO_LARGE (presumably a model id or path
# supplied by `from paths import *`); './' is used as the download cache.
val_preprocess = AutoImageProcessor.from_pretrained(DINO_LARGE, cache_dir='./')
def get_3angle(image_path):
    """Predict the three orientation angles of the object in one image.

    Args:
        image_path: path of an image file openable by PIL; it is converted
            to RGB before preprocessing.

    Returns:
        A float tensor of shape (3,):
          [0] argmax bin of logits[:, 0:360]          (range 0..359)
          [1] argmax bin of logits[:, 360:540] - 90   (range -90..89)
          [2] argmax bin of logits[:, 540:600] - 30   (range -30..29)
        Presumably azimuth / polar / rotation in degrees; the bin-to-degree
        mapping is not visible in this file, so confirm it against the
        training code.
    """
    rgb_image = Image.open(image_path).convert('RGB')
    batch = val_preprocess(images=rgb_image)
    # Convert the processor output to an ndarray, then to a tensor on `device`.
    pixel_array = np.array(batch['pixel_values'])
    batch['pixel_values'] = torch.from_numpy(pixel_array).to(device)

    with torch.no_grad():
        logits = dino(batch)

    # (start, stop, offset) per prediction head; subtracting the offset
    # re-centres the winning bin index.
    heads = ((0, 360, 0), (360, 540, 90), (540, 600, 30))
    result = torch.zeros(3)
    for slot, (start, stop, offset) in enumerate(heads):
        # Assumes one image per call: the (1,)-shaped argmax result is
        # written into a single scalar slot.
        result[slot] = logits[:, start:stop].argmax(dim=-1) - offset
    return result
# Demo driver. It runs when this file is loaded (there is no __main__ guard):
# it predicts the orientation of every image matching DEMO_IMAGE_GLOB and
# collects one 3-angle row per image into `obj_angles`.
DEMO_IMAGE_GLOB = '/home/aiops/wangzh/wangjialei/data_preprocess/meta/sa_10099.jpg'

with torch.no_grad():
    obj_angles = []
    # glob returns matches in arbitrary order; sort for a deterministic output.
    img_paths = sorted(glob.glob(DEMO_IMAGE_GLOB))
    for image_path in img_paths:
        image_name = os.path.basename(image_path)
        print(image_name)
        angles = get_3angle(image_path)
        obj_angles.append(angles)
    if obj_angles:
        obj_angles = torch.stack(obj_angles, dim=0)
    else:
        # torch.stack raises on an empty list. The hard-coded demo path exists
        # only on the author's machine, so fall back to an empty (0, 3) result
        # instead of crashing the whole module at load time.
        print(f'no images matched {DEMO_IMAGE_GLOB!r}')
        obj_angles = torch.zeros((0, 3))
    print('wild', obj_angles)