Spaces:
Runtime error
| import sys | |
| from argparse import ArgumentParser | |
| from pathlib import Path | |
| from comet_ml import Experiment | |
| import numpy as np | |
| import torch | |
| import yaml | |
| from PIL import Image | |
| from skimage.color import gray2rgb | |
| from skimage.io import imread | |
| from skimage.transform import resize | |
| from skimage.util import img_as_ubyte | |
| from tqdm import tqdm | |
| sys.path.append(str(Path(__file__).resolve().parent.parent)) | |
| import climategan | |
# Output directory of the ablation-v1 "ground" run. load_ground() reads the
# precomputed masks from its "eval-metrics/pred" sub-folder.
GROUND_MODEL = (
    "/miniscratch/_groups/ccai/experiments/runs/ablation-v1/out--ground"
)
def uint8(array):
    """Return ``array`` cast to ``np.uint8``.

    This is a plain dtype cast: values are neither rescaled nor clipped.
    """
    as_bytes = array.astype(dtype=np.uint8)
    return as_bytes
def crop_and_resize(image_path, label_path, target_size=640):
    """
    Resize an image and its label so that the smallest dimension is
    ``target_size`` (keeping the aspect ratio), then center-crop both to a
    ``target_size x target_size`` square without aspect ratio distortion.

    The image is resized with anti-aliasing. The label is resized with
    nearest-neighbour interpolation (``order=0``, no anti-aliasing) so no new
    label values are created. The resize target is computed from the *image*
    shape and applied to both arrays.

    Args:
        image_path (Path or str): Path to an image
        label_path (Path or str): Path to the image's associated label
        target_size (int, optional): Side of the square output. Defaults to 640.

    Returns:
        tuple((np.ndarray, np.ndarray)): (new image, new label), both uint8.
            Works for 2-D (grayscale) and 3-D (channels-last) arrays.
    """
    img = imread(image_path)
    lab = imread(label_path)

    if img.shape[:2] != lab.shape[:2]:
        # Not fatal: both arrays are resized to the size derived from the image.
        # Path() so that str paths (allowed by the docstring) also work here.
        print(
            "\nWARNING: shape mismatch: im -> {}, lab -> {}".format(
                Path(image_path).name, Path(label_path).name
            )
        )

    # resize keeping aspect ratio: smallest dim becomes target_size
    h, w = img.shape[:2]
    if h < w:
        size = (target_size, int(target_size * w / h))
    else:
        size = (int(target_size * h / w), target_size)

    r_img = uint8(resize(img, size, preserve_range=True, anti_aliasing=True))
    r_lab = uint8(
        resize(lab, size, preserve_range=True, anti_aliasing=False, order=0)
    )

    # crop in the center. Slicing only the first two axes keeps any trailing
    # channel axis, so this handles both (H, W) and (H, W, C) arrays.
    H, W = r_img.shape[:2]
    top = (H - target_size) // 2
    left = (W - target_size) // 2
    rc_img = r_img[top : top + target_size, left : left + target_size]
    rc_lab = r_lab[top : top + target_size, left : left + target_size]
    return rc_img, rc_lab
def load_ground(ground_output_path, ref_image_path, device="cuda"):
    """
    Load the precomputed "ground" mask matching a reference image.

    Looks for ``<ground_output_path>/eval-metrics/pred/<stem>.jpg`` or
    ``<stem>.png``, where ``<stem>`` is the reference image's file stem. The
    match is cropped/resized alongside the reference image with
    ``crop_and_resize`` and binarized (values > 0 become 1.0).

    Args:
        ground_output_path (Path or str): output directory of the ground run
        ref_image_path (Path or str): image whose ground mask should be loaded
        device (str or torch.device, optional): device the mask is moved to.
            Defaults to "cuda", i.e. the current CUDA device.

    Raises:
        ValueError: if no match, or more than one match, is found

    Returns:
        torch.Tensor: float32 mask with two leading singleton dims added.
            NOTE(review): assumes the prediction image is single-channel so the
            result is (1, 1, H, W); an RGB prediction would give 5 dims — confirm.
    """
    gop = Path(ground_output_path)
    rip = Path(ref_image_path)
    pred_dir = gop / "eval-metrics" / "pred"

    # Build the exact candidate paths instead of globbing: the stem comes from
    # a file name and may contain glob metacharacters such as "[", "*" or "?".
    candidates = [pred_dir / f"{rip.stem}{ext}" for ext in (".jpg", ".png")]
    ground_paths = [p for p in candidates if p.exists()]

    if len(ground_paths) == 0:
        raise ValueError(
            f"Could not find a ground match in {str(gop)} for image {str(rip)}"
        )
    elif len(ground_paths) > 1:
        raise ValueError(
            f"Found more than 1 ground match in {str(gop)} for image {str(rip)}:"
            + f" {list(map(str, ground_paths))}"
        )

    _, ground = crop_and_resize(rip, ground_paths[0])
    ground = (ground > 0).astype(np.float32)
    return torch.from_numpy(ground).unsqueeze(0).unsqueeze(0).to(device)
def parse_args():
    """
    Parse and print the command-line arguments of this script.

    Options:
        -y/--yaml (required): yaml file listing model run directories
        --disable_loading: recompute inferences even if saved ones exist
        -t/--tags: Comet.ml tags added to the experiment
        --tasks: ordered list of panels to show for each model

    Returns:
        argparse.Namespace: the parsed arguments
    """
    parser = ArgumentParser()
    # Required: the script opens this file unconditionally.
    parser.add_argument(
        "-y", "--yaml", required=True, help="Path to a list of models"
    )
    parser.add_argument(
        "--disable_loading",
        action="store_true",
        help="Disable loading of existing inferences",
    )
    parser.add_argument(
        "-t", "--tags", nargs="*", help="Comet.ml tags", default=[], type=str
    )
    parser.add_argument(
        "--tasks",
        nargs="*",
        help="Panels to display for each model, left to right "
        "(x: input, d: depth, s: segmentation, m: mask, "
        "mx: masked input, p: painted)",
        default=["x", "d", "s", "m", "mx", "p"],
        type=str,
    )
    args = parser.parse_args()
    print("Received args:")
    print(vars(args))
    return args
def load_images_and_labels(
    path="/miniscratch/_groups/ccai/data/omnigan/masker-test-set",
):
    """
    Load the masker test set from ``<path>/imgs`` and ``<path>/labels``.

    Images are sorted by file name. Labels are sorted by file name with
    "_labeled." replaced by ".", presumably so that label i matches image i.
    NOTE(review): no pairing check is done here — confirm the dataset layout.

    Returns:
        tuple: (prepared images, prepared labels, image paths, label paths)
    """
    root = Path(path)

    def _label_key(label_file):
        return label_file.name.replace("_labeled.", ".")

    image_files = sorted(
        climategan.utils.find_images(root / "imgs"), key=lambda f: f.name
    )
    label_files = sorted(
        climategan.utils.find_images(root / "labels"), key=_label_key
    )

    prepared_images = climategan.transforms.PrepareInference()(image_files)
    prepared_labels = climategan.transforms.PrepareInference(is_label=True)(
        label_files
    )
    return prepared_images, prepared_labels, image_files, label_files
def load_inferences(inf_path, im_paths):
    """
    Load previously saved per-image inference dicts, in the order of im_paths.

    Every image ``p`` in ``im_paths`` must have a matching
    ``<inf_path>/<p.stem>.pt`` file, and there must be no other ``.pt`` file.
    Loading is best-effort: on any problem the reason is printed and ``None``
    is returned so the caller can recompute the inferences instead.

    Args:
        inf_path (Path): directory containing the ``.pt`` files
        im_paths (list[Path]): image paths; the output follows this order

    Returns:
        list | None: loaded inferences aligned with ``im_paths``, or None
    """
    try:
        if not inf_path.exists():
            raise FileNotFoundError(f"{inf_path} does not exist")
        expected = sorted(p.stem for p in im_paths)
        found = sorted(p.stem for p in inf_path.glob("*.pt"))
        if expected != found:
            raise ValueError(
                f"Inference files in {inf_path} do not match the image stems"
            )
        # Load in im_paths order rather than glob order (which depends on the
        # filesystem): callers index the result positionally with im_paths.
        return [torch.load(str(inf_path / f"{p.stem}.pt")) for p in tqdm(im_paths)]
    except Exception as e:
        # deliberate best-effort: report and let the caller recompute
        print()
        print(e)
        print("Aborting Loading")
        print()
        return None
def get_or_load_inferences(
    m_path, device, xs, is_ground, im_paths, ground_model, try_load=True
):
    """
    Return one inference dict per input image for the model at ``m_path``.

    When ``try_load`` is true, inferences previously saved under
    ``<m_path>/inferences`` are returned if ``load_inferences`` accepts them.

    Otherwise a trainer is resumed in inference mode. It is resumed from
    ``ground_model`` when ``is_ground`` is true, and from ``m_path`` otherwise.
    For every input:

    * non-ground: the outputs of ``trainer.G.decode`` are used;
    * ground: the mask "m" comes from ``load_ground(GROUND_MODEL, ...)``.

    Then:

    * "p" is computed with ``trainer.G.paint`` from ``mask > 0.5``;
    * the input is stored under "x";
    * every tensor is moved to CPU;
    * the dict is saved to ``<m_path>/inferences/<image stem>.pt``.

    Returns:
        list[dict]: dicts of CPU tensors. Computed results follow the order of
            ``xs``; loaded results come in the order ``load_inferences`` gives.
    """
    inference_dir = Path(m_path) / "inferences"

    if try_load:
        print("Trying to load existing inferences:")
        cached = load_inferences(inference_dir, im_paths)
        if cached is not None:
            print("Successfully loaded existing inferences")
            return cached

    checkpoint_path = ground_model if is_ground else m_path
    trainer = climategan.trainer.Trainer.resume_from_path(
        checkpoint_path,
        inference=True,
        new_exp=None,
        device=device,
    )

    inference_dir.mkdir(exist_ok=True)
    results = []
    for idx, sample in enumerate(tqdm(xs)):
        sample = sample.to(trainer.device)
        if is_ground:
            prediction = {"m": load_ground(GROUND_MODEL, im_paths[idx])}
        else:
            prediction = trainer.G.decode(x=sample)
        prediction["p"] = trainer.G.paint(prediction["m"] > 0.5, sample)
        prediction["x"] = sample

        on_cpu = {key: tensor.cpu() for key, tensor in prediction.items()}
        results.append(on_cpu)
        torch.save(on_cpu, inference_dir / f"{im_paths[idx].stem}.pt")
    print()
    return results
def numpify(outputs):
    """
    Convert inference dicts of torch tensors into dicts of uint8 numpy images.

    For each output dict:

    * "x" and "p": first batch element, CHW -> HWC, mapped with (v + 1) / 2
      (presumably from [-1, 1] to [0, 1] — TODO confirm the generator range);
    * "m": first batch element and first channel, thresholded at 0.5 to 0/1;
    * "s" (optional): decoded with ``decode_segmap_merged_labels``, divided by
      255, then first batch element, CHW -> HWC;
    * "d" (optional): ``normalize_tensor`` applied, then squeezed.

    Every resulting array is converted with ``img_as_ubyte``.

    Args:
        outputs (list[dict]): inference dicts, e.g. from get_or_load_inferences

    Returns:
        list[dict]: one dict of uint8 arrays per output (keys m, p, x [, s, d])
    """
    print("Numpifying...")

    def _chw_to_unit_hwc(tensor):
        # first batch element, channels-last, then (v + 1) / 2
        return (tensor[0].permute(1, 2, 0).numpy() + 1) / 2

    converted = []
    for output in tqdm(outputs):
        x_img = _chw_to_unit_hwc(output["x"])
        m_img = (output["m"][0, 0, :, :].numpy() > 0.5).astype(np.uint8)
        p_img = _chw_to_unit_hwc(output["p"])
        arrays = {"m": m_img, "p": p_img, "x": x_img}

        if "s" in output:
            seg = climategan.data.decode_segmap_merged_labels(
                output["s"], "r", False
            )
            arrays["s"] = (seg / 255.0)[0].permute(1, 2, 0).numpy()
        if "d" in output:
            normalized = climategan.tutils.normalize_tensor(output["d"])
            arrays["d"] = normalized.squeeze().numpy()

        converted.append({key: img_as_ubyte(val) for key, val in arrays.items()})
    return converted
def concat_npy_for_model(data, tasks):
    """
    Horizontally concatenate the requested panels of one model for one image.

    Args:
        data (dict): numpy images as produced by ``numpify``.
            "x", "m" and "p" are required. "d" and "s" are optional; when
            missing, the panel is a white image shaped like ``data["x"]``.
        tasks (list[str]): panels to include, left to right:
            "x" input, "m" mask as RGB, "mx" input multiplied by (1 - mask),
            "d" depth, "s" segmentation, "p" painted image.
            NOTE(review): "m"/"mx" assume the mask holds 0/1 values, as
            numpify produces — confirm for other callers.

    Raises:
        ValueError: if a required key is missing from ``data``, ``tasks`` is
            empty, or ``tasks`` contains an unknown name.

    Returns:
        np.ndarray: the panels concatenated along axis 1 (width)
    """
    for key in ("m", "x", "p"):
        if key not in data:
            raise ValueError(f"data is missing required key '{key}'")
    if not tasks:
        raise ValueError("At least one task is required")

    x = data["x"]
    out_shape = x.shape[:2]

    # Each panel is built lazily, only if requested.
    def _mask():
        return (gray2rgb(data["m"]) * 255).astype(np.uint8)

    def _masked():
        return (1 - gray2rgb(data["m"])) * x

    def _depth():
        if "d" not in data:
            return np.ones_like(x) * 255
        return img_as_ubyte(
            gray2rgb(resize(data["d"], out_shape, anti_aliasing=True, order=1))
        )

    def _seg():
        if "s" not in data:
            return np.ones_like(x) * 255
        # order=0: nearest neighbour, so no blended segmentation colors
        return img_as_ubyte(
            resize(data["s"], out_shape, anti_aliasing=False, order=0)
        )

    panels = {
        "x": lambda: x,
        "m": _mask,
        "mx": _masked,
        "d": _depth,
        "s": _seg,
        "p": lambda: data["p"],
    }

    unknown = [t for t in tasks if t not in panels]
    if unknown:
        raise ValueError(
            f"Unknown task(s) {unknown}; expected a subset of {list(panels)}"
        )

    return np.concatenate([panels[t]() for t in tasks], axis=1)
if __name__ == "__main__":
    # Script entry point:
    # 1. run (or reload) every model listed in --yaml on the masker test set;
    # 2. log one comparison image per test image to Comet, with one row per
    #    model and one panel per requested task.
    args = parse_args()

    with open(args.yaml, "r") as f:
        maskers = yaml.safe_load(f)
    # the yaml is either a plain list of run directories or {"models": [...]}
    if "models" in maskers:
        maskers = maskers["models"]

    load = not args.disable_loading
    tags = args.tags
    tasks = args.tasks

    # The "ground" entry's trainer (and painter) is resumed from ground_model
    # in get_or_load_inferences, so ground_model must be a non-ground run.
    # Test each path with the same "--ground" rule used to name models below.
    ground_model = next((m for m in maskers if "--ground" not in m), None)
    if ground_model is None:
        raise ValueError("Could not find a non-ground model to get a painter")

    device = torch.device("cuda:0")
    torch.set_grad_enabled(False)

    xs, ys, im_paths, lab_paths = load_images_and_labels()

    np_outs = {}
    names = []
    for m_path in maskers:
        opt_path = Path(m_path) / "opts.yaml"
        with opt_path.open("r") as f:
            opt = yaml.safe_load(f)

        if "--ground" in m_path:
            name = "ground"
        else:
            # readable model name: its comet tags minus bookkeeping tags
            name = ", ".join(
                t
                for t in sorted(opt["comet"]["tags"])
                if "branch" not in t and "ablation" not in t and "trash" not in t
            )
        # NOTE(review): two models with identical filtered tags share a name,
        # and the later one overwrites np_outs[name].
        names.append(name)
        is_ground = name == "ground"

        print("#" * 100)
        print("\n>>> Processing", name)
        print()

        outputs = get_or_load_inferences(
            m_path, device, xs, is_ground, im_paths, ground_model, load
        )
        nps = numpify(outputs)

        np_outs[name] = nps

    exp = Experiment(project_name="climategan-inferences", display_summary_level=0)
    exp.log_parameter("names", names)
    exp.add_tags(tags)

    for i in tqdm(range(len(xs))):
        # stack one row per model vertically for image i
        all_models_for_image = [
            concat_npy_for_model(np_outs[name][i], tasks) for name in names
        ]
        full_im = np.concatenate(all_models_for_image, axis=0)
        pil_im = Image.fromarray(full_im)
        exp.log_image(pil_im, name=im_paths[i].stem.replace(".", "_"), step=i)