Update inference_pytorch.py
inference_pytorch.py (+20, -4)
@@ -31,6 +31,10 @@ class PyTorchInference:
         self.model_arch = model_arch
         self.device = device
 
+        # Hardcoded input size mapping - based on actual model definitions
+        self.input_size = self._get_input_size(model_arch)
+        print(f"Using input size: {self.input_size} for model {model_arch}")
+
         # Load checkpoint
         state_dict = self.load_checkpoint_state(checkpoint_path)
         state_dict = self.normalize_state_dict_keys(state_dict)
@@ -41,7 +45,7 @@ class PyTorchInference:
         if feature_dim is None:
             feature_dim = self.resolve_feature_dim(state_dict)
 
-        # Create model
+        # Create model - don't pass img_size, let the model use its default
         self.model = create_model(
             model_arch,
             pretrained=False,
@@ -54,9 +58,19 @@ class PyTorchInference:
         self.model.to(device)
         self.model.eval()
 
-        # Get transform
-        config = resolve_data_config({}, model=self.model)
+        # Get transform - override with our correct input size
+        # We manually set the input_size instead of relying on the model's config
+        config = resolve_data_config({'input_size': (3, self.input_size, self.input_size)}, model=self.model)
         self.transform = create_transform(**config)
+        print(f"Created transform with input size: {self.input_size}")
+
+    def _get_input_size(self, model_arch):
+        """Get input size based on model architecture - hardcoded to match actual model definitions"""
+        if model_arch == 'lsnet_xl_artist_448':
+            return 448
+        else:
+            # All other artist models use 224
+            return 224
 
     @staticmethod
     def load_checkpoint_state(checkpoint_path: str):
@@ -116,6 +130,7 @@ class PyTorchInference:
         """
         image = image.convert("RGB")
         tensor = self.transform(image)
+        print(f"Preprocessed image to tensor shape: {tensor.shape}")
         return tensor.unsqueeze(0)
 
     def predict(self, image, top_k=5, threshold=0.0):
@@ -134,7 +149,8 @@ class PyTorchInference:
 
         with torch.no_grad():
             input_tensor = input_tensor.to(self.device)
+            print(f"Running inference on tensor shape: {input_tensor.shape}")
             # Use return_features=False to get classification logits
             logits = self.model(input_tensor, return_features=False)
 
-        return logits.cpu().numpy()[0]
+        return logits.cpu().numpy()[0]
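For context, a minimal standalone sketch of the timm behavior this commit relies on: keys passed in resolve_data_config's first argument take precedence over the model's default data config, so the resulting transform resizes to the forced resolution instead of whatever the pretrained config declares. This sketch assumes only stock timm, with 'resnet18' as a stand-in architecture; the artist checkpoints and the lsnet_xl_artist_448 registry name from the diff are not assumed to be available in a vanilla install.

# Sketch of the resolve_data_config override used in the diff above.
# Assumptions: stock timm; 'resnet18' stands in for the artist models.
import timm
from PIL import Image
from timm.data import resolve_data_config, create_transform

model = timm.create_model('resnet18', pretrained=False)

# Empty override dict: everything falls back to the model's default
# config (the pre-change behavior of relying on the model's config).
default_cfg = resolve_data_config({}, model=model)
print(default_cfg['input_size'])  # (3, 224, 224) for resnet18

# Keys in the first argument win over the model config - this is how
# the commit forces 448x448 for lsnet_xl_artist_448.
forced_cfg = resolve_data_config({'input_size': (3, 448, 448)}, model=model)
transform = create_transform(**forced_cfg)

tensor = transform(Image.new('RGB', (640, 480)))
print(tensor.shape)  # torch.Size([3, 448, 448])

This is also why the hardcoded _get_input_size mapping works: whatever the checkpoint's architecture defaults to, the transform is pinned to the resolution the weights were trained at, and the added print calls trace the tensor shape through preprocessing and inference to confirm it.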