Update inference_onnx.py
inference_onnx.py CHANGED (+21 -7)
@@ -31,6 +31,10 @@ class ONNXInference:
         self.device = device
         self.use_openvino = False

+        # Hardcoded input size mapping - based on actual model definitions
+        self.input_size = self._get_input_size(model_arch)
+        print(f"Using input size: {self.input_size} for model {model_arch}")
+
         if device == "cuda":
             # Try CUDA first for GPU
             try:
@@ -47,7 +51,7 @@ class ONNXInference:
                 # Check if CUDA is actually being used
                 if self.execution_provider == "CUDAExecutionProvider":
                     print(f"Using ONNX Runtime with {self.execution_provider}")
-                    # Get transform
+                    # Get transform with correct input size
                     self.transform = self._get_transform()
                     return
                 else:
@@ -76,7 +80,7 @@ class ONNXInference:
                 print(f"OpenVINO initialization failed: {e}, falling back to ONNX Runtime CPU")
                 self._init_onnx_runtime_cpu(model_path)

-        # Get transform
+        # Get transform with correct input size
        self.transform = self._get_transform()

     def _init_onnx_runtime_cpu(self, model_path):
@@ -92,15 +96,24 @@ class ONNXInference:
         self.execution_provider = self.session.get_providers()[0]
         print(f"Using ONNX Runtime with {self.execution_provider}")

-
-
+    def _get_input_size(self, model_arch):
+        """Get input size based on model architecture - hardcoded to match actual model definitions"""
+        if model_arch == 'lsnet_xl_artist_448':
+            return 448
+        else:
+            # All other artist models use 224
+            return 224

     def _get_transform(self):
-        """Create preprocessing transform"""
+        """Create preprocessing transform with correct input size"""
+        # Create a dummy model to get the base config
         model = create_model(self.model_arch, pretrained=False)
         model.eval()
-        config = resolve_data_config({}, model=model)
+
+        # Override the input size with our hardcoded value
+        config = resolve_data_config({'input_size': (3, self.input_size, self.input_size)}, model=model)
         transform = create_transform(**config)
+        print(f"Created ONNX transform with input size: {self.input_size}")
         return transform

     def preprocess(self, image):
@@ -115,6 +128,7 @@ class ONNXInference:
         """
         image = image.convert("RGB")
         tensor = self.transform(image)
+        print(f"Preprocessed image to tensor shape: {tensor.shape}")
         return tensor.unsqueeze(0).cpu().numpy()

     def predict(self, image, top_k=5, threshold=0.0):
@@ -148,4 +162,4 @@ class ONNXInference:
 def softmax(x):
     """Compute softmax values for a set of scores."""
     e_x = np.exp(x - np.max(x))
-    return e_x / e_x.sum(axis=0)
+    return e_x / e_x.sum(axis=0)
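A note on the `_get_transform` change: timm's `resolve_data_config` gives keys in its first argument precedence over the defaults read from the model, so passing `input_size` there is enough to force the resize/crop pipeline to the hardcoded resolution. A minimal standalone sketch of that mechanism (the `resnet50` architecture name is only illustrative):

import timm
from timm.data import resolve_data_config, create_transform

# Instantiate the architecture only to read its default data config;
# no pretrained weights are needed for that.
model = timm.create_model("resnet50", pretrained=False)

# Keys in the first dict override the model's defaults, so this pins
# the preprocessing pipeline to 448x448 regardless of the resolution
# the architecture was registered with.
config = resolve_data_config({"input_size": (3, 448, 448)}, model=model)
transform = create_transform(**config)

print(config["input_size"])  # (3, 448, 448)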
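The max-subtraction in `softmax`, unchanged by this commit, is the standard numerical-stability trick: softmax(x) equals softmax(x - c) for any constant c, and shifting by max(x) keeps `np.exp` from overflowing on large logits. A quick illustration with made-up values:

import numpy as np

logits = np.array([1000.0, 1001.0, 1002.0])

# Naive form overflows: np.exp(1000.0) is inf, so the result is all nan.
naive = np.exp(logits) / np.exp(logits).sum()

# Shifted form computes exp([-2, -1, 0]), which stays finite and exact.
e_x = np.exp(logits - np.max(logits))
stable = e_x / e_x.sum(axis=0)
print(stable)  # approx [0.0900, 0.2447, 0.6652]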
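For reference, a hypothetical end-to-end call. The constructor signature is not shown in this diff, so the argument names below (`model_path`, `model_arch`, `device`) are assumptions inferred from how they are used above:

from PIL import Image

# Hypothetical usage; parameter names and order are assumed, not confirmed.
infer = ONNXInference(
    model_path="artist_model.onnx",    # assumed parameter name
    model_arch="lsnet_xl_artist_448",  # routes _get_input_size to 448
    device="cpu",                      # takes the OpenVINO / CPU fallback path
)

image = Image.open("example.jpg")
results = infer.predict(image, top_k=5, threshold=0.0)
print(results)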
|