# Orient-Anything / demo.py
# Last commit 43a369c ("init") by zhang-ziang; original file size 2.05 kB.
import torch
import torch.nn as nn
import torch.nn.functional as F
import argparse
import os
import glob
from transformers import AutoProcessor, AutoTokenizer, AutoImageProcessor
import pandas as pd
from paths import *
import numpy as np
from vision_tower import DINOv2_MLP
from PIL import Image
# Directory that holds the fine-tuned weight file, and the inference device.
save_path = './'
device = 'cpu'

# DINOv2 ('large') backbone + MLP head. out_dim = 360 + 180 + 60 angle bins
# (decoded by get_3angle) plus 2 trailing outputs that get_3angle never reads.
dino = DINOv2_MLP(
    dino_mode='large',
    in_dim=1024,
    out_dim=360 + 180 + 60 + 2,
    evaluate=True,
    mask_dino=False,
    frozen_back=False,
).to(device)

# NOTE(review): torch.load unpickles the checkpoint -- only load trusted files.
_weight_file = os.path.join(save_path, 'dino_weight.pt')
dino.load_state_dict(torch.load(_weight_file, map_location='cpu'))
# Inference mode (affects layers such as dropout / batch-norm, if present).
dino.eval()

# Image preprocessor loaded from the DINO_LARGE identifier (from `paths`),
# cached under the current directory.
val_preprocess = AutoImageProcessor.from_pretrained(DINO_LARGE, cache_dir='./')
def get_3angle(image_path):
    """Predict three discretised angles for the image at ``image_path``.

    The image is converted to RGB, preprocessed by the module-level
    ``val_preprocess``, and scored by the module-level ``dino`` model on
    ``device`` without gradient tracking. The first 600 output logits are
    read as three consecutive groups of 360, 180 and 60 bins; the argmax bin
    of each group is decoded into one value. Any further logits are ignored.

    Args:
        image_path: Path of an image file that PIL can open.

    Returns:
        torch.Tensor: shape ``(3,)`` tensor built with ``torch.zeros`` (CPU,
        default float dtype) holding
        ``[bin_a, bin_b - 90, bin_c - 30]``, i.e. values in ``[0, 359]``,
        ``[-90, 89]`` and ``[-30, 29]``. Presumably azimuth / polar /
        rotation in degrees (judging by the ax/pl/ro names) -- TODO confirm.
    """
    # Use a context manager so the file handle is released; Pillow may keep
    # it open after loading (e.g. for multi-frame formats). convert() loads
    # the pixels and returns an independent image, so it outlives the `with`.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')
    image_inputs = val_preprocess(images=image)
    # No return_tensors is requested, so pixel_values is array-like
    # (presumably a list of per-image arrays); stack it into one tensor.
    image_inputs['pixel_values'] = torch.from_numpy(
        np.array(image_inputs['pixel_values'])
    ).to(device)
    with torch.no_grad():
        dino_pred = dino(image_inputs)

    # Cumulative end offsets of the three bin groups in the model output.
    ax_end = 360
    pl_end = ax_end + 180
    ro_end = pl_end + 60
    gaus_ax_pred = torch.argmax(dino_pred[:, 0:ax_end], dim=-1)
    gaus_pl_pred = torch.argmax(dino_pred[:, ax_end:pl_end], dim=-1)
    gaus_ro_pred = torch.argmax(dino_pred[:, pl_end:ro_end], dim=-1)

    # Each argmax has one entry per batch item; assigning into a scalar slot
    # requires a batch of exactly one image, which is what is built above.
    angles = torch.zeros(3)
    angles[0] = gaus_ax_pred
    # Shift bin indices so these two values span [-90, 89] and [-30, 29].
    angles[1] = gaus_pl_pred - 90
    angles[2] = gaus_ro_pred - 30
    return angles
# Demo driver: predict angles for every image matched by the glob pattern,
# printing each file name and then all results stacked into one tensor.
with torch.no_grad():
    obj_angles = []
    # NOTE(review): absolute path from the original author's machine; replace
    # it with your own image path (a wildcard pattern also works here).
    img_paths = sorted(
        glob.glob('/home/aiops/wangzh/wangjialei/data_preprocess/meta/sa_10099.jpg')
    )
    for image_path in img_paths:
        image_name = os.path.basename(image_path)
        print(image_name)
        angles = get_3angle(image_path)
        obj_angles.append(angles)
    if obj_angles:
        obj_angles = torch.stack(obj_angles, dim=0)
        print('wild', obj_angles)
    else:
        # torch.stack raises on an empty sequence; report the miss instead.
        print('wild: no images matched the input pattern')