Spaces:
Runtime error
Runtime error
add intermediate outputs
Browse files- app.py +10 -0
- climategan/trainer.py +15 -3
- climategan_wrapper.py +2 -1
app.py
CHANGED
|
@@ -41,6 +41,8 @@ def predict(cg: ClimateGAN, api_key):
|
|
| 41 |
masked_input = output_dict["masked_input"]
|
| 42 |
wildfire = output_dict["wildfire"]
|
| 43 |
smog = output_dict["smog"]
|
|
|
|
|
|
|
| 44 |
|
| 45 |
climategan_flood = output_dict.get(
|
| 46 |
"climategan_flood",
|
|
@@ -62,6 +64,8 @@ def predict(cg: ClimateGAN, api_key):
|
|
| 62 |
return (
|
| 63 |
input_image,
|
| 64 |
masked_input,
|
|
|
|
|
|
|
| 65 |
climategan_flood,
|
| 66 |
stable_flood,
|
| 67 |
stable_copy_flood,
|
|
@@ -127,6 +131,12 @@ if __name__ == "__main__":
|
|
| 127 |
outputs.append(
|
| 128 |
gr.outputs.Image(type="numpy", label="Masked input image"),
|
| 129 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
with gr.Row():
|
| 131 |
outputs.append(
|
| 132 |
gr.outputs.Image(type="numpy", label="ClimateGAN-Flooded image"),
|
|
|
|
| 41 |
masked_input = output_dict["masked_input"]
|
| 42 |
wildfire = output_dict["wildfire"]
|
| 43 |
smog = output_dict["smog"]
|
| 44 |
+
depth = np.repeat(output_dict["depth"][..., None], 3, axis=-1)
|
| 45 |
+
segmentation = output_dict["segmentation"]
|
| 46 |
|
| 47 |
climategan_flood = output_dict.get(
|
| 48 |
"climategan_flood",
|
|
|
|
| 64 |
return (
|
| 65 |
input_image,
|
| 66 |
masked_input,
|
| 67 |
+
segmentation,
|
| 68 |
+
depth,
|
| 69 |
climategan_flood,
|
| 70 |
stable_flood,
|
| 71 |
stable_copy_flood,
|
|
|
|
| 131 |
outputs.append(
|
| 132 |
gr.outputs.Image(type="numpy", label="Masked input image"),
|
| 133 |
)
|
| 134 |
+
outputs.append(
|
| 135 |
+
gr.outputs.Image(type="numpy", label="Segmentation map"),
|
| 136 |
+
)
|
| 137 |
+
outputs.append(
|
| 138 |
+
gr.outputs.Image(type="numpy", label="Depth map"),
|
| 139 |
+
)
|
| 140 |
with gr.Row():
|
| 141 |
outputs.append(
|
| 142 |
gr.outputs.Image(type="numpy", label="ClimateGAN-Flooded image"),
|
climategan/trainer.py
CHANGED
|
@@ -22,7 +22,8 @@ from torch import autograd, sigmoid, softmax
|
|
| 22 |
from torch.cuda.amp import GradScaler, autocast
|
| 23 |
from tqdm import tqdm
|
| 24 |
|
| 25 |
-
from climategan.data import get_all_loaders
|
|
|
|
| 26 |
from climategan.discriminator import OmniDiscriminator, create_discriminator
|
| 27 |
from climategan.eval_metrics import accuracy, mIOU
|
| 28 |
from climategan.fid import compute_val_fid
|
|
@@ -38,6 +39,7 @@ from climategan.tutils import (
|
|
| 38 |
get_WGAN_gradient,
|
| 39 |
lrgb2srgb,
|
| 40 |
normalize,
|
|
|
|
| 41 |
print_num_parameters,
|
| 42 |
shuffle_batch_tuple,
|
| 43 |
srgb2lrgb,
|
|
@@ -226,7 +228,7 @@ class Trainer:
|
|
| 226 |
cloudy=True,
|
| 227 |
auto_resize_640=False,
|
| 228 |
ignore_event=set(),
|
| 229 |
-
|
| 230 |
):
|
| 231 |
"""
|
| 232 |
Create a dictionnary of events from a numpy or tensor,
|
|
@@ -331,10 +333,20 @@ class Trainer:
|
|
| 331 |
smog = (smog * 255).astype(np.uint8)
|
| 332 |
output_data["smog"] = smog
|
| 333 |
|
| 334 |
-
if
|
| 335 |
output_data["mask"] = (
|
| 336 |
((mask > bin_value) * 255).cpu().numpy().astype(np.uint8)
|
| 337 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 338 |
|
| 339 |
return output_data
|
| 340 |
|
|
|
|
| 22 |
from torch.cuda.amp import GradScaler, autocast
|
| 23 |
from tqdm import tqdm
|
| 24 |
|
| 25 |
+
from climategan.data import get_all_loaders, decode_segmap_merged_labels
|
| 26 |
+
|
| 27 |
from climategan.discriminator import OmniDiscriminator, create_discriminator
|
| 28 |
from climategan.eval_metrics import accuracy, mIOU
|
| 29 |
from climategan.fid import compute_val_fid
|
|
|
|
| 39 |
get_WGAN_gradient,
|
| 40 |
lrgb2srgb,
|
| 41 |
normalize,
|
| 42 |
+
normalize_tensor,
|
| 43 |
print_num_parameters,
|
| 44 |
shuffle_batch_tuple,
|
| 45 |
srgb2lrgb,
|
|
|
|
| 228 |
cloudy=True,
|
| 229 |
auto_resize_640=False,
|
| 230 |
ignore_event=set(),
|
| 231 |
+
return_intermediates=False,
|
| 232 |
):
|
| 233 |
"""
|
| 234 |
Create a dictionnary of events from a numpy or tensor,
|
|
|
|
| 333 |
smog = (smog * 255).astype(np.uint8)
|
| 334 |
output_data["smog"] = smog
|
| 335 |
|
| 336 |
+
if return_intermediates:
|
| 337 |
output_data["mask"] = (
|
| 338 |
((mask > bin_value) * 255).cpu().numpy().astype(np.uint8)
|
| 339 |
)
|
| 340 |
+
output_data["depth"] = (
|
| 341 |
+
normalize_tensor(depth).cpu().squeeze(1).numpy().astype(np.uint8) * 255
|
| 342 |
+
)
|
| 343 |
+
output_data["segmentation"] = (
|
| 344 |
+
decode_segmap_merged_labels(segmentation, "r", False)
|
| 345 |
+
.cpu()
|
| 346 |
+
.permute(0, 2, 3, 1)
|
| 347 |
+
.numpy()
|
| 348 |
+
.astype(np.uint8)
|
| 349 |
+
)
|
| 350 |
|
| 351 |
return output_data
|
| 352 |
|
climategan_wrapper.py
CHANGED
|
@@ -15,6 +15,7 @@ from skimage.transform import resize
|
|
| 15 |
|
| 16 |
from climategan.trainer import Trainer
|
| 17 |
|
|
|
|
| 18 |
CUDA = torch.cuda.is_available()
|
| 19 |
|
| 20 |
|
|
@@ -313,7 +314,7 @@ class ClimateGAN:
|
|
| 313 |
bin_value=0.5,
|
| 314 |
half=CUDA,
|
| 315 |
ignore_event=ignore_event,
|
| 316 |
-
|
| 317 |
)
|
| 318 |
|
| 319 |
outputs["input"] = uint8(images, True)
|
|
|
|
| 15 |
|
| 16 |
from climategan.trainer import Trainer
|
| 17 |
|
| 18 |
+
|
| 19 |
CUDA = torch.cuda.is_available()
|
| 20 |
|
| 21 |
|
|
|
|
| 314 |
bin_value=0.5,
|
| 315 |
half=CUDA,
|
| 316 |
ignore_event=ignore_event,
|
| 317 |
+
return_intermediates=True,
|
| 318 |
)
|
| 319 |
|
| 320 |
outputs["input"] = uint8(images, True)
|