DraconicDragon committed

Commit c4223ac · verified · 1 Parent(s): 9df5a22

Update inference_pytorch.py

Files changed (1):
  1. inference_pytorch.py +20 -4
inference_pytorch.py CHANGED

@@ -31,6 +31,10 @@ class PyTorchInference:
         self.model_arch = model_arch
         self.device = device
 
+        # Hardcoded input size mapping - based on actual model definitions
+        self.input_size = self._get_input_size(model_arch)
+        print(f"Using input size: {self.input_size} for model {model_arch}")
+
         # Load checkpoint
         state_dict = self.load_checkpoint_state(checkpoint_path)
         state_dict = self.normalize_state_dict_keys(state_dict)
 
@@ -41,7 +45,7 @@ class PyTorchInference:
         if feature_dim is None:
             feature_dim = self.resolve_feature_dim(state_dict)
 
-        # Create model
+        # Create model - don't pass img_size, let the model use its default
         self.model = create_model(
             model_arch,
             pretrained=False,
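For context on why img_size is no longer passed: timm-style factories resolve the expected input size from the model's own config when none is given. A minimal sketch of that behavior, using a registered timm architecture as a stand-in, since the lsnet_* artist models come from this repo's own create_model rather than the timm registry:

```python
import timm
from timm.data import resolve_data_config

# Sketch only: "resnet50" stands in for the repo's custom architectures.
model = timm.create_model("resnet50", pretrained=False, num_classes=10)

# With no overrides, the size is taken from the model's own config.
config = resolve_data_config({}, model=model)
print(config["input_size"])  # (3, 224, 224) for resnet50
```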
 
@@ -54,9 +58,19 @@ class PyTorchInference:
         self.model.to(device)
         self.model.eval()
 
-        # Get transform
-        config = resolve_data_config({}, model=self.model)
+        # Get transform - override with our correct input size
+        # We manually set the input_size instead of relying on the model's config
+        config = resolve_data_config({'input_size': (3, self.input_size, self.input_size)}, model=self.model)
         self.transform = create_transform(**config)
+        print(f"Created transform with input size: {self.input_size}")
+
+    def _get_input_size(self, model_arch):
+        """Get input size based on model architecture - hardcoded to match actual model definitions"""
+        if model_arch == 'lsnet_xl_artist_448':
+            return 448
+        else:
+            # All other artist models use 224
+            return 224
 
     @staticmethod
     def load_checkpoint_state(checkpoint_path: str):
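The override in this hunk works because resolve_data_config gives explicit keys in its first argument precedence over whatever the model config reports. A sketch of the same pattern, with a literal 448 standing in for self.input_size and model reused from the sketch above:

```python
from timm.data import resolve_data_config, create_transform

input_size = 448  # stand-in for self._get_input_size(model_arch)

# An explicit 'input_size' wins over the model's config, so the eval
# pipeline (Resize, CenterCrop, ToTensor, Normalize) is built for 448x448.
config = resolve_data_config({"input_size": (3, input_size, input_size)}, model=model)
transform = create_transform(**config)
```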
 
@@ -116,6 +130,7 @@ class PyTorchInference:
         """
         image = image.convert("RGB")
         tensor = self.transform(image)
+        print(f"Preprocessed image to tensor shape: {tensor.shape}")
         return tensor.unsqueeze(0)
 
     def predict(self, image, top_k=5, threshold=0.0):
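The new print makes the shape contract of preprocess visible: the transform yields a CHW tensor and unsqueeze(0) prepends the batch dimension. Illustrated with the 448 transform from the previous sketch:

```python
from PIL import Image

img = Image.new("RGB", (640, 480))  # synthetic stand-in image
tensor = transform(img)             # torch.Size([3, 448, 448])
batch = tensor.unsqueeze(0)         # torch.Size([1, 3, 448, 448])
```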
 
@@ -134,7 +149,8 @@ class PyTorchInference:
 
         with torch.no_grad():
             input_tensor = input_tensor.to(self.device)
+            print(f"Running inference on tensor shape: {input_tensor.shape}")
             # Use return_features=False to get classification logits
             logits = self.model(input_tensor, return_features=False)
 
-        return logits.cpu().numpy()[0]
+        return logits.cpu().numpy()[0]
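Taken together, the commit pins the preprocessing size end to end. A hypothetical usage sketch of the class after this change; the constructor arguments and file paths are assumptions inferred from the context lines above, not part of the diff:

```python
import torch
from PIL import Image

# Assumed constructor signature, inferred from the attributes set in __init__.
infer = PyTorchInference(
    model_arch="lsnet_xl_artist_448",  # _get_input_size -> 448
    checkpoint_path="checkpoint.pth",  # placeholder path
    device="cuda" if torch.cuda.is_available() else "cpu",
)

logits = infer.predict(Image.open("example.jpg"))  # 1-D numpy array of logits

# Post-processing is the caller's choice; sigmoid assumes a multi-label tagger.
probs = torch.sigmoid(torch.from_numpy(logits))
print(probs.topk(5))
```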