| """All non-tensor utils | |
| """ | |
| import contextlib | |
| import datetime | |
| import json | |
| import os | |
| import re | |
| import shutil | |
| import subprocess | |
| import time | |
| import traceback | |
| from os.path import expandvars | |
| from pathlib import Path | |
| from typing import Any, List, Optional, Union | |
| from uuid import uuid4 | |
| import numpy as np | |
| import torch | |
| import yaml | |
| from addict import Dict | |
| from comet_ml import Experiment | |
| comet_kwargs = { | |
| "auto_metric_logging": False, | |
| "parse_args": True, | |
| "log_env_gpu": True, | |
| "log_env_cpu": True, | |
| "display_summary_level": 0, | |
| } | |
| IMG_EXTENSIONS = set( | |
| [".jpg", ".JPG", ".jpeg", ".JPEG", ".png", ".PNG", ".ppm", ".PPM", ".bmp", ".BMP"] | |
| ) | |
| def resolve(path): | |
| """ | |
| fully resolve a path: | |
| resolve env vars ($HOME etc.) -> expand user (~) -> make absolute | |
| Returns: | |
| pathlib.Path: resolved absolute path | |
| """ | |
| return Path(expandvars(str(path))).expanduser().resolve() | |
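# Illustrative usage sketch (not part of the original module): `resolve` accepts
# strings or Path objects containing "~" or environment variables. The
# "$HOME/outputs/run_1" value below is only an example input.
def _example_resolve():
    p = resolve("$HOME/outputs/run_1")
    assert p.is_absolute()
    return p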
def copy_run_files(opts: Dict) -> None:
    """
    Copy the opts's sbatch_file and exp_file to output_path

    Args:
        opts (addict.Dict): options
    """
    if opts.sbatch_file:
        p = resolve(opts.sbatch_file)
        if p.exists():
            o = resolve(opts.output_path)
            if o.exists():
                shutil.copyfile(p, o / p.name)
    if opts.exp_file:
        p = resolve(opts.exp_file)
        if p.exists():
            o = resolve(opts.output_path)
            if o.exists():
                shutil.copyfile(p, o / p.name)


def merge(
    source: Union[dict, Dict], destination: Union[dict, Dict]
) -> Union[dict, Dict]:
    """
    Recursively update `destination` with the values in `source`.

    run me with nosetests --with-doctest file.py

    >>> a = { 'first' : { 'all_rows' : { 'pass' : 'dog', 'number' : '1' } } }
    >>> b = { 'first' : { 'all_rows' : { 'fail' : 'cat', 'number' : '5' } } }
    >>> merge(b, a) == {
    ...     'first' : {
    ...         'all_rows' : {
    ...             'pass' : 'dog',
    ...             'fail' : 'cat',
    ...             'number' : '5',
    ...         }
    ...     }
    ... }
    True
    """
    for key, value in source.items():
        try:
            if isinstance(value, dict):
                # get node or create one
                node = destination.setdefault(key, {})
                merge(value, node)
            else:
                if isinstance(destination, dict):
                    destination[key] = value
                else:
                    destination = {key: value}
        except TypeError as e:
            print(traceback.format_exc())
            print(">>>", source)
            print(">>>", destination)
            print(">>>", key)
            print(">>>", value)
            raise Exception(e)
    return destination
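# Illustrative usage sketch (not part of the original module): `merge` mutates and
# returns `destination`, with `source` values taking precedence on conflicts.
def _example_merge():
    defaults = {"gen": {"lr": 1e-3, "depth": 4}}
    overrides = {"gen": {"lr": 5e-4}}
    merged = merge(overrides, defaults)
    assert merged == {"gen": {"lr": 5e-4, "depth": 4}}
    return merged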
def load_opts(
    path: Optional[Union[str, Path]] = None,
    default: Optional[Union[str, Path, dict, Dict]] = None,
    commandline_opts: Optional[Union[Dict, dict]] = None,
) -> Dict:
    """Loads a configuration Dict from 2 files:
    1. default file with shared values across runs and users
    2. an overriding file with run- and user-specific values

    Args:
        path (pathlib.Path): where to find the overriding configuration
        default (pathlib.Path, optional): Where to find the default opts.
            Defaults to None. In which case it is assumed to be a default config
            which needs processing such as setting default values for lambdas and
            gen fields

    Returns:
        addict.Dict: options dictionary, with overwritten default values
    """
    if path is None and default is None:
        path = (
            resolve(Path(__file__)).parent.parent
            / "shared"
            / "trainer"
            / "defaults.yaml"
        )

    if path:
        path = resolve(path)

    if default is None:
        default_opts = {}
    else:
        if isinstance(default, (str, Path)):
            with open(default, "r") as f:
                default_opts = yaml.safe_load(f)
        else:
            default_opts = dict(default)

    if path is None:
        overriding_opts = {}
    else:
        with open(path, "r") as f:
            overriding_opts = yaml.safe_load(f) or {}

    opts = Dict(merge(overriding_opts, default_opts))

    if commandline_opts is not None and isinstance(commandline_opts, dict):
        opts = Dict(merge(commandline_opts, opts))

    if opts.train.kitti.pretrained:
        assert "kitti" in opts.data.files.train
        assert "kitti" in opts.data.files.val
        assert opts.train.kitti.epochs > 0

    opts.domains = []
    if "m" in opts.tasks or "s" in opts.tasks or "d" in opts.tasks:
        opts.domains.extend(["r", "s"])
    if "p" in opts.tasks:
        opts.domains.append("rf")
    if opts.train.kitti.pretrain:
        opts.domains.append("kitti")

    opts.domains = list(set(opts.domains))

    if "s" in opts.tasks:
        if opts.gen.encoder.architecture != opts.gen.s.architecture:
            print(
                "WARNING: segmentation encoder and decoder architectures do not match"
            )
            print(
                "Encoder: {} <> Decoder: {}".format(
                    opts.gen.encoder.architecture, opts.gen.s.architecture
                )
            )

    if opts.gen.m.use_spade:
        if "d" not in opts.tasks or "s" not in opts.tasks:
            raise ValueError(
                "opts.gen.m.use_spade is True so tasks MUST include "
                + "both d and s, but received {}".format(opts.tasks)
            )
        if opts.gen.d.classify.enable:
            raise ValueError(
                "opts.gen.m.use_spade is True but using D as a classifier"
                + " which is a non-implemented combination"
            )

    if opts.gen.s.depth_feat_fusion is True or opts.gen.s.depth_dada_fusion is True:
        opts.gen.s.use_dada = True

    events_path = (
        resolve(Path(__file__)).parent.parent / "shared" / "trainer" / "events.yaml"
    )
    if events_path.exists():
        with events_path.open("r") as f:
            events_dict = yaml.safe_load(f)
        events_dict = Dict(events_dict)
        opts.events = events_dict

    return set_data_paths(opts)
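# Illustrative usage sketch (not part of the original module, not executed here):
# a typical call merging a user config over the shared defaults, then overriding
# a value from the command line. Both file paths below are hypothetical.
def _example_load_opts():
    opts = load_opts(
        path="config/trainer/my_run.yaml",
        default="shared/trainer/defaults.yaml",
        commandline_opts={"output_path": "$HOME/outputs/my_run"},
    )
    return opts.output_path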
def set_data_paths(opts: Dict) -> Dict:
    """Update the data files paths in data.files.train and data.files.val
    from data.files.base

    Args:
        opts (addict.Dict): options

    Returns:
        addict.Dict: updated options
    """
    for mode in ["train", "val"]:
        for domain in opts.data.files[mode]:
            if opts.data.files.base and not opts.data.files[mode][domain].startswith(
                "/"
            ):
                opts.data.files[mode][domain] = str(
                    Path(opts.data.files.base) / opts.data.files[mode][domain]
                )
            assert Path(
                opts.data.files[mode][domain]
            ).exists(), "Cannot find {}".format(str(opts.data.files[mode][domain]))

    return opts


def load_test_opts(test_file_path: str = "config/trainer/local_tests.yaml") -> Dict:
    """Returns the special opts set up for local tests

    Args:
        test_file_path (str, optional): Path of the test config, relative to the
            repository root. Defaults to "config/trainer/local_tests.yaml".

    Returns:
        addict.Dict: Opts loaded from defaults.yaml and updated from test_file_path
    """
    return load_opts(
        Path(__file__).parent.parent / f"{test_file_path}",
        default=Path(__file__).parent.parent / "shared/trainer/defaults.yaml",
    )


def get_git_revision_hash() -> str:
    """Get the git hash of the commit the code is run from

    Returns:
        str: git hash
    """
    try:
        return subprocess.check_output(["git", "rev-parse", "HEAD"]).decode().strip()
    except Exception as e:
        return str(e)


def get_git_branch() -> str:
    """Get current git branch name

    Returns:
        str: git branch name
    """
    try:
        return (
            subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"])
            .decode()
            .strip()
        )
    except Exception as e:
        return str(e)


def kill_job(id: Union[int, str]) -> None:
    """Cancel a SLURM job from its id with `scancel`"""
    subprocess.check_output(["scancel", str(id)])


def write_hash(path: Union[str, Path]) -> None:
    """Write the current git hash to a file"""
    hash_code = get_git_revision_hash()
    with open(path, "w") as f:
        f.write(hash_code)


def shortuid():
    """Return the first block of a random uuid4 string"""
    return str(uuid4()).split("-")[0]


def datenowshort():
    """
    Short string for the current date and time, without the year:

    >>> a = str(datetime.datetime.now())
    >>> print(a)
    '2021-02-25 11:34:50.188072'
    >>> print(a[5:].split(".")[0].replace(" ", "_"))
    '02-25_11:34:50'

    Returns:
        str: month-day_h:m:s
    """
    return str(datetime.datetime.now())[5:].split(".")[0].replace(" ", "_")


def get_increased_path(path: Union[str, Path], use_date: bool = False) -> Path:
    """Returns a path which does not collide with an existing one: if `path`
    exists, a suffix is appended until the resulting path is free.

    If `use_date` is False the suffix is a short random uuid (`shortuid()`),
    otherwise it is the current date and time (`datenowshort()`).

    get_increased_path("test").mkdir() creates `test/`
    then
    get_increased_path("test").mkdir() creates `test--<suffix>/`
    etc.

    Args:
        path (str or pathlib.Path): the file/directory which may already exist
            and would need to be increased

    Returns:
        pathlib.Path: increased path
    """
    fp = resolve(path)
    if not fp.exists():
        return fp

    if fp.is_file():
        if not use_date:
            while fp.exists():
                fp = fp.parent / f"{fp.stem}--{shortuid()}{fp.suffix}"
            return fp
        else:
            while fp.exists():
                time.sleep(0.5)
                fp = fp.parent / f"{fp.stem}--{datenowshort()}{fp.suffix}"
            return fp

    if not use_date:
        while fp.exists():
            fp = fp.parent / f"{fp.name}--{shortuid()}"
        return fp
    else:
        while fp.exists():
            time.sleep(0.5)
            fp = fp.parent / f"{fp.name}--{datenowshort()}"
        return fp

    # Previous "dir (i)" increment scheme, kept for reference:
    # vals = []
    # for n in fp.parent.glob("{}*".format(fp.stem)):
    #     if re.match(r".+\(\d+\)", str(n.name)) is not None:
    #         name = str(n.name)
    #         start = name.index("(")
    #         end = name.index(")")
    #         vals.append(int(name[start + 1 : end]))
    # if vals:
    #     ext = " ({})".format(max(vals) + 1)
    # elif fp.exists():
    #     ext = " (1)"
    # else:
    #     ext = ""
    # return fp.parent / (fp.stem + ext + fp.suffix)
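# Illustrative usage sketch (not part of the original module): pick a fresh output
# directory without overwriting a previous run. "outputs/run" is a hypothetical path.
def _example_get_increased_path():
    out = get_increased_path("outputs/run")  # "outputs/run--<uid>" if it already exists
    out.mkdir(parents=True, exist_ok=True)
    return out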
def env_to_path(path: str) -> str:
    """Transforms an environment variable mentioned in a path string
    into its actual value. E.g. $HOME/clouds -> /home/vsch/clouds

    Args:
        path (str): path potentially containing the env variable
    """
    path_elements = path.split("/")
    new_path = []
    for el in path_elements:
        if "$" in el:
            new_path.append(os.environ[el.replace("$", "")])
        else:
            new_path.append(el)

    return "/".join(new_path)
def flatten_opts(opts: Dict) -> dict:
    """Flattens a multi-level addict.Dict or native dictionary into a single
    level native dict with string keys representing the keys sequence to reach
    a value in the original argument.

    d = addict.Dict()
    d.a.b.c = 2
    d.a.b.d = 3
    d.a.e = 4
    d.f = 5
    flatten_opts(d)
    >>> {
        "a.b.c": 2,
        "a.b.d": 3,
        "a.e": 4,
        "f": 5,
    }

    Args:
        opts (addict.Dict or dict): addict dictionary to flatten

    Returns:
        dict: flattened dictionary
    """
    values_list = []

    def p(d, prefix="", vals=[]):
        for k, v in d.items():
            if isinstance(v, (Dict, dict)):
                p(v, prefix + k + ".", vals)
            elif isinstance(v, list):
                if v and isinstance(v[0], (Dict, dict)):
                    for i, m in enumerate(v):
                        p(m, prefix + k + "." + str(i) + ".", vals)
                else:
                    vals.append((prefix + k, str(v)))
            else:
                if isinstance(v, Path):
                    v = str(v)
                vals.append((prefix + k, v))

    p(opts, vals=values_list)
    return dict(values_list)
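# Illustrative usage sketch (not part of the original module): flatten nested
# options into dotted keys, e.g. before logging hyper-parameters to comet.ml.
def _example_flatten_opts():
    nested = Dict(
        {"gen": {"encoder": {"architecture": "deeplabv3"}}, "tasks": ["m", "s", "d"]}
    )
    flat = flatten_opts(nested)
    assert flat["gen.encoder.architecture"] == "deeplabv3"
    assert flat["tasks"] == "['m', 's', 'd']"  # lists of scalars are stringified
    return flat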
def get_comet_rest_api_key(
    path_to_config_file: Optional[Union[str, Path]] = None
) -> str:
    """Gets a comet.ml rest_api_key in the following order:

    * config file specified as argument
    * environment variable
    * .comet.config file in the current working directory
    * .comet.config file in your home

    config files must have a line like `rest_api_key=<some api key>`

    Args:
        path_to_config_file (str or pathlib.Path, optional): config_file to use.
            Defaults to None.

    Raises:
        ValueError: can't find a file
        ValueError: can't find the key in a file

    Returns:
        str: your comet rest_api_key
    """
    if "COMET_REST_API_KEY" in os.environ and path_to_config_file is None:
        return os.environ["COMET_REST_API_KEY"]
    if path_to_config_file is not None:
        p = resolve(path_to_config_file)
    else:
        p = Path() / ".comet.config"
        if not p.exists():
            p = Path.home() / ".comet.config"
            if not p.exists():
                raise ValueError("Unable to find your COMET_REST_API_KEY")
    with p.open("r") as f:
        for keys in f:
            if "rest_api_key" in keys:
                return keys.strip().split("=")[-1].strip()
    raise ValueError("Unable to find your COMET_REST_API_KEY in {}".format(str(p)))


def get_files(dirName: str) -> list:
    """Recursively list all the files under dirName"""
    files = sorted(os.listdir(dirName))
    all_files = list()
    for entry in files:
        fullPath = os.path.join(dirName, entry)
        if os.path.isdir(fullPath):
            all_files = all_files + get_files(fullPath)
        else:
            all_files.append(fullPath)

    return all_files


def make_json_file(
    tasks: List[str],
    addresses: List[str],  # for windows users, use "\\" instead of "/"
    json_names: List[str] = ["train_jsonfile.json", "val_jsonfile.json"],
    splitter: str = "/",
    pourcentage_val: float = 0.15,
) -> None:
    """
    How to use it?
    e.g.
    make_json_file(['x','m','d'], [
        '/network/tmp1/ccai/data/munit_dataset/trainA_size_1200/',
        '/network/tmp1/ccai/data/munit_dataset/seg_trainA_size_1200/',
        '/network/tmp1/ccai/data/munit_dataset/trainA_megadepth_resized/'
    ], ["train_r.json", "val_r.json"])

    Args:
        tasks (list): the list of image types like 'x', 'm', 'd', etc.
        addresses (list): the list of directories corresponding to the
            image types mentioned in tasks
        json_names (list): names for the json files, train being first
            (e.g.: ["train_r.json", "val_r.json"])
        splitter (str, optional): The path separator for the current OS.
            Defaults to '/'.
        pourcentage_val: percentage of files to go in the validation set
    """
    assert len(tasks) == len(addresses), "keys and addresses must have the same length!"

    files = [get_files(addresses[j]) for j in range(len(tasks))]
    n_files_val = int(pourcentage_val * len(files[0]))
    n_files_train = len(files[0]) - n_files_val
    filenames = [files[0][:n_files_train], files[0][-n_files_val:]]

    file_address_map = {
        tasks[j]: {
            ".".join(file.split(splitter)[-1].split(".")[:-1]): file
            for file in files[j]
        }
        for j in range(len(tasks))
    }
    # The keys of file_address_map are the tasks, like 'x', 'm', 'd'...
    # The values of file_address_map are dictionaries whose keys are the
    # filenames without extension and whose values are the paths to the files
    # e.g. file_address_map =
    # {'x': {'A': 'path/to/trainA_size_1200/A.png', ...},
    #  'm': {'A': 'path/to/seg_trainA_size_1200/A.jpg', ...},
    #  'd': {'A': 'path/to/trainA_megadepth_resized/A.bmp', ...},
    #  ...}

    for i, json_name in enumerate(json_names):
        dicts = []
        for j in range(len(filenames[i])):
            file = filenames[i][j]
            filename = file.split(splitter)[-1]  # the filename with its extension
            filename_ = ".".join(
                filename.split(".")[:-1]
            )  # the filename without extension
            tmp_dict = {}
            for k in range(len(tasks)):
                tmp_dict[tasks[k]] = file_address_map[tasks[k]][filename_]
            dicts.append(tmp_dict)
        with open(json_name, "w", encoding="utf-8") as outfile:
            json.dump(dicts, outfile, ensure_ascii=False)


def append_task_to_json(
    path_to_json: Union[str, Path],
    path_to_new_json: Union[str, Path],
    path_to_new_images_dir: Union[str, Path],
    new_task_name: str,
):
    """Add all files for a task to an existing json file by creating a new json file
    in the specified path.
    Assumes that the files for the new task have exactly the same names as the ones
    for the other tasks

    Args:
        path_to_json: complete path to the json file to modify
        path_to_new_json: complete path to the new json file to be created
        path_to_new_images_dir: complete path of the directory where to find the
            images for the new task
        new_task_name: name of the new task

    e.g.:
        append_task_to_json(
            "/network/tmp1/ccai/data/climategan/seg/train_r.json",
            "/network/tmp1/ccai/data/climategan/seg/train_r_new.json",
            "/network/tmp1/ccai/data/munit_dataset/trainA_seg_HRNet/unity_labels",
            "s",
        )
    """
    ims_list = None
    if path_to_json:
        path_to_json = Path(path_to_json).resolve()
        with open(path_to_json, "r") as f:
            ims_list = json.load(f)

    files = get_files(path_to_new_images_dir)

    if ims_list is None:
        raise ValueError(f"Could not find the list in {path_to_json}")

    new_ims_list = [None] * len(ims_list)
    for i, im_dict in enumerate(ims_list):
        new_ims_list[i] = {}
        for task, path in im_dict.items():
            new_ims_list[i][task] = path

    for i, im_dict in enumerate(ims_list):
        for task, path in im_dict.items():
            file_name = os.path.splitext(path)[0]  # removes extension
            file_name = file_name.rsplit("/", 1)[-1]  # keep only the file name
            file_found = False
            for file_path in files:
                if file_name in file_path:
                    file_found = True
                    new_ims_list[i][new_task_name] = file_path
                    break
            if file_found:
                break
        else:
            print("Error! File ", file_name, "not found in directory!")
            return

    with open(path_to_new_json, "w", encoding="utf-8") as f:
        json.dump(new_ims_list, f, ensure_ascii=False)


def sum_dict(dict1: Union[dict, Dict], dict2: Union[Dict, dict]) -> Union[dict, Dict]:
    """Add dict2 into dict1"""
    for k, v in dict2.items():
        if not isinstance(v, dict):
            dict1[k] += v
        else:
            sum_dict(dict1[k], dict2[k])
    return dict1


def div_dict(dict1: Union[dict, Dict], div_by: float) -> dict:
    """Divide elements of dict1 by div_by"""
    for k, v in dict1.items():
        if not isinstance(v, dict):
            dict1[k] /= div_by
        else:
            div_dict(dict1[k], div_by)
    return dict1
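# Illustrative usage sketch (not part of the original module): accumulate nested
# loss dicts over several steps with sum_dict, then average them in place with
# div_dict. The loss names below are arbitrary.
def _example_average_losses():
    totals = {"G": {"p": 0.0, "m": 0.0}, "D": 0.0}
    steps = [
        {"G": {"p": 1.0, "m": 2.0}, "D": 3.0},
        {"G": {"p": 3.0, "m": 2.0}, "D": 1.0},
    ]
    for step_losses in steps:
        sum_dict(totals, step_losses)
    div_dict(totals, len(steps))
    assert totals == {"G": {"p": 2.0, "m": 2.0}, "D": 2.0}
    return totals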
def comet_id_from_url(url: str) -> Optional[str]:
    """
    Get a comet experiment's id from its url:
    https://www.comet.ml/vict0rsch/climategan/2a1a4a96afe848218c58ac4e47c5375f
    -> 2a1a4a96afe848218c58ac4e47c5375f

    Args:
        url (str): comet exp url

    Returns:
        str: comet exp id
    """
    try:
        ids = url.split("/")
        ids = [i for i in ids if i]
        return ids[-1]
    except Exception:
        return None


@contextlib.contextmanager
def temp_np_seed(seed: Optional[int]) -> None:
    """
    Set a temporary numpy seed, restoring the previous RNG state on exit:

    with temp_np_seed(123):
        np.random.permutation(3)

    Args:
        seed (int): temporary numpy seed
    """
    state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        np.random.set_state(state)
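# Illustrative usage sketch (not part of the original module): the permutation is
# reproducible inside the context, and the global numpy RNG state is untouched after.
def _example_temp_np_seed():
    with temp_np_seed(123):
        fixed = np.random.permutation(10)[:3]
    return fixed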
def get_display_indices(opts: Dict, domain: str, length: int) -> list:
    """
    Compute the indices of images to use for comet logging:

    if opts.comet.display_size is an int:
        return a fixed random permutation of range(length), truncated to that size
        (for the "rf" domain the size is increased to opts.train.fid.n_images when
        that value is larger)
    if opts.comet.display_size is a list:
        return that list
    otherwise return []

    Args:
        opts (addict.Dict): options
        domain (str): domain for those indices
        length (int): length of dataset for the permutation

    Returns:
        list(int): The indices to display
    """
    if domain == "rf":
        dsize = max([opts.comet.display_size, opts.train.fid.get("n_images", 0)])
    else:
        dsize = opts.comet.display_size
    if dsize > length:
        print(
            f"Warning: dataset is smaller ({length} images) "
            + f"than required display indices ({dsize})."
            + f" Selecting {length} images."
        )

    display_indices = []

    assert isinstance(dsize, (int, list)), "Unknown display size {}".format(dsize)

    if isinstance(dsize, int):
        assert dsize >= 0, "Display size cannot be < 0"
        with temp_np_seed(123):
            display_indices = list(np.random.permutation(length)[:dsize])
    elif isinstance(dsize, list):
        display_indices = dsize

    if not display_indices:
        print("Warning: no display indices (utils.get_display_indices)")

    return display_indices


def get_latest_path(path: Union[str, Path]) -> Path:
    """
    Get the file/dir with the largest increment i as `file (i).ext`

    Args:
        path (str or pathlib.Path): base pattern

    Returns:
        Path: path found
    """
    p = Path(path).resolve()
    s = p.stem
    e = p.suffix

    files = list(p.parent.glob(f"{s}*(*){e}"))
    indices = list(map(lambda f: f.name, files))
    indices = list(map(lambda x: re.findall(r"\((.*?)\)", x)[-1], indices))
    indices = list(map(int, indices))

    if not indices:
        f = p
    else:
        f = files[np.argmax(indices)]

    return f


def get_existing_jobID(output_path: Path) -> str:
    """
    If the opts in output_path have a jobID, return it. Else, return None

    Args:
        output_path (pathlib.Path | str): where to look

    Returns:
        str | None: jobID
    """
    op = Path(output_path)
    if not op.exists():
        return

    opts_path = get_latest_path(op / "opts.yaml")

    if not opts_path.exists():
        return

    with opts_path.open("r") as f:
        opts = yaml.safe_load(f)

    jobID = opts.get("jobID", None)

    return jobID
def find_existing_training(opts: Dict) -> Optional[Path]:
    """
    Looks in all directories like output_path.parent.glob(output_path.name*)
    and compares the logged slurm job id with the current opts.jobID

    If a match is found, the training should automatically continue in the
    matching output directory

    If no match is found, this is a new job and it should have a new output path

    Args:
        opts (Dict): trainer's options

    Returns:
        Optional[Path]: a path if a matching jobID is found, None otherwise
    """
    if opts.jobID is None:
        print("WARNING: current JOBID is None")
        return

    print("---------- Current job id:", opts.jobID)

    path = Path(opts.output_path).resolve()
    parent = path.parent
    name = path.name

    try:
        similar_dirs = [p.resolve() for p in parent.glob(f"{name}*") if p.is_dir()]

        for sd in similar_dirs:
            candidate_jobID = get_existing_jobID(sd)
            if candidate_jobID is not None and str(opts.jobID) == str(candidate_jobID):
                print(f"Found matching job id in {sd}\n")
                return sd
        print("Did not find a matching job id in \n {}\n".format(str(similar_dirs)))
    except Exception as e:
        print("ERROR: Could not resume (find_existing_training)", e)


def pprint(*args: List[Any]):
    """
    Prints *args within a box of "=" characters
    """
    txt = " ".join(map(str, args))
    col = "====="
    space = " "
    head_size = 2
    header = "\n".join(["=" * (len(txt) + 2 * (len(col) + len(space)))] * head_size)
    empty = "{}{}{}{}{}".format(col, space, " " * (len(txt)), space, col)

    print()
    print(header)
    print(empty)
    print("{}{}{}{}{}".format(col, space, txt, space, col))
    print(empty)
    print(header)
    print()


def get_existing_comet_id(path: str) -> Optional[str]:
    """
    Returns the id of the existing comet experiment stored in path

    Args:
        path (str): Output path where to look for the comet exp

    Returns:
        Optional[str]: comet exp's ID if any was found
    """
    comet_previous_path = get_latest_path(Path(path) / "comet_url.txt")
    if comet_previous_path.exists():
        with comet_previous_path.open("r") as f:
            url = f.read().strip()
            return comet_id_from_url(url)


def get_latest_opts(path):
    """
    Get the latest opts dumped in path if they look like *opts*.yaml
    and were increased as

    opts.yaml < opts (1).yaml < opts (2).yaml etc.

    Args:
        path (str or pathlib.Path): where to look for opts

    Raises:
        ValueError: If no match for *opts*.yaml is found

    Returns:
        addict.Dict: loaded opts
    """
    path = Path(path)
    opts = get_latest_path(path / "opts.yaml")
    assert opts.exists()
    with opts.open("r") as f:
        opts = Dict(yaml.safe_load(f))

    events_path = Path(__file__).parent.parent / "shared" / "trainer" / "events.yaml"
    if events_path.exists():
        with events_path.open("r") as f:
            events_dict = yaml.safe_load(f)
        events_dict = Dict(events_dict)
        opts.events = events_dict

    return opts
def text_to_array(text, width=640, height=40):
    """
    Creates a numpy array of shape height x width x 3 with
    text written on it using PIL

    Args:
        text (str): text to write
        width (int, optional): Width of the resulting array. Defaults to 640.
        height (int, optional): Height of the resulting array. Defaults to 40.

    Returns:
        np.ndarray: Centered text
    """
    from PIL import Image, ImageDraw, ImageFont

    img = Image.new("RGB", (width, height), (255, 255, 255))
    try:
        font = ImageFont.truetype("UnBatang.ttf", 25)
    except OSError:
        font = ImageFont.load_default()

    d = ImageDraw.Draw(img)
    text_width, text_height = d.textsize(text)
    h = height // 2 - 3 * text_height // 2
    w = width // 2 - text_width
    d.text((w, h), text, font=font, fill=(30, 30, 30))
    return np.array(img)


def all_texts_to_array(texts, width=640, height=40):
    """
    Creates an array of texts, each of height and width specified
    by the args, concatenated along their width dimension

    Args:
        texts (list(str)): List of texts to concatenate
        width (int, optional): Individual text's width. Defaults to 640.
        height (int, optional): Individual text's height. Defaults to 40.

    Returns:
        list: len(texts) text arrays with dims height x width x 3
    """
    return [text_to_array(text, width, height) for text in texts]


class Timer:
    """Context manager timing a block of code: with cuda events when a GPU is
    available (or when `cuda` is forced to True), with time.perf_counter() otherwise.
    """

    def __init__(self, name="", store=None, precision=3, ignore=False, cuda=None):
        self.name = name
        self.store = store
        self.precision = precision
        self.ignore = ignore
        self.cuda = cuda if cuda is not None else torch.cuda.is_available()

        if self.cuda:
            self._start_event = torch.cuda.Event(enable_timing=True)
            self._end_event = torch.cuda.Event(enable_timing=True)

    def format(self, n):
        return f"{n:.{self.precision}f}"

    def __enter__(self):
        """Start a new timer as a context manager"""
        if self.cuda:
            self._start_event.record()
        else:
            self._start_time = time.perf_counter()
        return self

    def __exit__(self, *exc_info):
        """Stop the context manager timer"""
        if self.ignore:
            return

        if self.cuda:
            self._end_event.record()
            torch.cuda.synchronize()
            new_time = self._start_event.elapsed_time(self._end_event) / 1000
        else:
            t = time.perf_counter()
            new_time = t - self._start_time

        if self.store is not None:
            assert isinstance(self.store, list)
            self.store.append(new_time)
        if self.name:
            print(f"[{self.name}] Elapsed time: {self.format(new_time)}")
def get_loader_output_shape_from_opts(opts):
    """Infer, per task, the (height, width) output size of the data loader
    from the last `resize` transform in opts.data.transforms"""
    transforms = opts.data.transforms

    t = None
    for t in transforms[::-1]:
        if t.name == "resize":
            break
    assert t is not None

    if isinstance(t.new_size, Dict):
        return {
            task: (
                t.new_size.get(task, t.new_size.default),
                t.new_size.get(task, t.new_size.default),
            )
            for task in opts.tasks + ["x"]
        }
    assert isinstance(t.new_size, int)
    new_size = (t.new_size, t.new_size)
    return {task: new_size for task in opts.tasks + ["x"]}


def find_target_size(opts, task):
    """Find the target size for a task from the last transform in opts,
    falling back to the `default` entry when new_size is a per-task dict"""
    target_size = None
    if isinstance(opts.data.transforms[-1].new_size, int):
        target_size = opts.data.transforms[-1].new_size
    else:
        if task in opts.data.transforms[-1].new_size:
            target_size = opts.data.transforms[-1].new_size[task]
        else:
            assert "default" in opts.data.transforms[-1].new_size
            target_size = opts.data.transforms[-1].new_size["default"]

    return target_size


def to_128(im, w_target=-1):
    """Compute the largest (height, width) multiples of 128 not exceeding the
    target width, approximately preserving the image's aspect ratio"""
    h, w = im.shape[:2]
    aspect_ratio = h / w
    if w_target < 0:
        w_target = w

    nw = int(w_target / 128) * 128
    nh = int(nw * aspect_ratio / 128) * 128

    return nh, nw
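# Illustrative usage sketch (not part of the original module): a 1080x1920 image
# maps to 1024x1920 (both multiples of 128), keeping roughly the same aspect ratio.
def _example_to_128():
    im = np.zeros((1080, 1920, 3), dtype=np.uint8)
    nh, nw = to_128(im)
    assert (nh % 128, nw % 128) == (0, 0)
    return nh, nw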
def is_image_file(filename):
    """Check that a file's name points to a known image format"""
    if isinstance(filename, Path):
        return filename.suffix in IMG_EXTENSIONS

    return Path(filename).suffix in IMG_EXTENSIONS


def find_images(path, recursive=False):
    """
    Get a list of all images contained in a directory:

    - path.glob("*") if not recursive
    - path.glob("**/*") if recursive
    """
    p = Path(path)
    assert p.exists()
    assert p.is_dir()
    pattern = "*"
    if recursive:
        pattern += "*/*"

    return [i for i in p.glob(pattern) if i.is_file() and is_image_file(i)]


def cols():
    """Return the terminal's width in columns, defaulting to 50 if unavailable"""
    try:
        col = os.get_terminal_size().columns
    except Exception:
        col = 50
    return col


def upload_images_to_exp(
    path, exp=None, project_name="climategan-eval", sleep=-1, verbose=0
):
    """Log all images found in `path` to a comet.ml Experiment, creating one in
    `project_name` if `exp` is None. Returns the experiment."""
    ims = find_images(path)
    end = None
    c = cols()
    if verbose == 1:
        end = "\r"
    if verbose > 1:
        end = "\n"
    if exp is None:
        exp = Experiment(project_name=project_name)
    for im in ims:
        exp.log_image(str(im))
        if verbose > 0:
            if verbose == 1:
                print(" " * (c - 1), end="\r", flush=True)
            print(str(im), end=end, flush=True)
        if sleep > 0:
            time.sleep(sleep)
    return exp