Spaces:
Runtime error
Runtime error
add option to return intermediates as tensors
Browse files- climategan/trainer.py +18 -13
climategan/trainer.py
CHANGED
|
@@ -334,19 +334,24 @@ class Trainer:
|
|
| 334 |
output_data["smog"] = smog
|
| 335 |
|
| 336 |
if return_intermediates:
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
|
| 351 |
return output_data
|
| 352 |
|
|
|
|
| 334 |
output_data["smog"] = smog
|
| 335 |
|
| 336 |
if return_intermediates:
|
| 337 |
+
if numpy:
|
| 338 |
+
output_data["mask"] = (
|
| 339 |
+
((mask > bin_value) * 255).cpu().numpy().astype(np.uint8)
|
| 340 |
+
)
|
| 341 |
+
output_data["depth"] = (
|
| 342 |
+
normalize_tensor(depth).cpu().squeeze(1).numpy().astype(np.uint8) * 255
|
| 343 |
+
)
|
| 344 |
+
output_data["segmentation"] = (
|
| 345 |
+
decode_segmap_merged_labels(segmentation, "r", False)
|
| 346 |
+
.cpu()
|
| 347 |
+
.permute(0, 2, 3, 1)
|
| 348 |
+
.numpy()
|
| 349 |
+
.astype(np.uint8)
|
| 350 |
+
)
|
| 351 |
+
else:
|
| 352 |
+
output_data["mask"] = mask
|
| 353 |
+
output_data["depth"] = depth
|
| 354 |
+
output_data["segmentation"] = segmentation
|
| 355 |
|
| 356 |
return output_data
|
| 357 |
|