Spaces:
Runtime error
Runtime error
| """ | |
| Copyright (c) Meta Platforms, Inc. and affiliates. | |
| All rights reserved. | |
| This source code is licensed under the license found in the | |
| LICENSE file in the root directory of this source tree. | |
| """ | |
| import time | |
| import numpy as np | |
| import random | |
| import os | |
| import socket | |
| import typing as tp | |
| import torch | |
| import torch.distributed as dist | |
| from torch.nn.parallel import DistributedDataParallel as DDP | |
# Change this to reflect your cluster layout.
# The GPU for a given rank is (rank % GPUS_PER_NODE).
GPUS_PER_NODE = 8
# Retry budget for distributed setup.
# NOTE(review): not referenced anywhere in this file — presumably used by callers; confirm.
SETUP_RETRY_COUNT = 3
# Module-level CUDA device index consumed by dev(); overwritten by setup_dist().
used_device = 0
def setup(rank, world_size):
    """Join the default gloo process group as ``rank`` of ``world_size``.

    Rendezvous happens over localhost:12355, advertised through the
    MASTER_ADDR / MASTER_PORT environment variables.
    """
    os.environ["MASTER_PORT"] = "12355"
    os.environ["MASTER_ADDR"] = "localhost"
    # Blocks until all world_size ranks have called init_process_group.
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
def cleanup():
    """Tear down the default distributed process group created by setup()."""
    dist.destroy_process_group()
def setup_dist(device=0):
    """Record which local device this process should use.

    Stores ``device`` in the module-level ``used_device`` so that
    ``dev()`` can pick it up later. Returns immediately when a process
    group is already initialized; otherwise simply falls through (no
    group is created here).
    """
    global used_device
    used_device = device
    if dist.is_initialized():
        # Nothing further to do — the group already exists.
        return
def dev():
    """Return the torch.device this process should compute on.

    Prefers ``cuda:<used_device>`` when CUDA is available and a
    non-negative index was registered via setup_dist(); otherwise
    falls back to the CPU.
    """
    global used_device
    cuda_ok = torch.cuda.is_available() and used_device >= 0
    return torch.device(f"cuda:{used_device}") if cuda_ok else torch.device("cpu")
def load_state_dict(path, **kwargs):
    """Deserialize a checkpoint from ``path`` with ``torch.load``.

    Thin wrapper kept for API compatibility; keyword arguments
    (e.g. ``map_location``) are forwarded unchanged.
    """
    state = torch.load(path, **kwargs)
    return state
def sync_params(params):
    """Broadcast every tensor in ``params`` from rank 0 to all ranks.

    Used to guarantee each worker starts from identical weights.
    Gradients are not tracked for the copy.
    """
    with torch.no_grad():
        for tensor in params:
            dist.broadcast(tensor, 0)
| def _find_free_port(): | |
| try: | |
| s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) | |
| s.bind(("", 0)) | |
| s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) | |
| return s.getsockname()[1] | |
| finally: | |
| s.close() | |
def world_size():
    """Number of participating ranks; 1 when torch.distributed is not set up."""
    if not torch.distributed.is_initialized():
        return 1
    return torch.distributed.get_world_size()
def is_distributed():
    """True when this job runs with more than one rank."""
    return world_size() != 1
def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
    """All-reduce ``tensor`` in place across ranks with reduction ``op``.

    No-op (returns None) when the job is not distributed.
    """
    if not is_distributed():
        return None
    return torch.distributed.all_reduce(tensor, op)
| def _is_complex_or_float(tensor): | |
| return torch.is_floating_point(tensor) or torch.is_complex(tensor) | |
def _check_number_of_params(params: tp.List[torch.Tensor]):
    """Verify every worker holds the same number of params.

    Guards the subsequent all-reduce/broadcast against a deadlock: if
    workers disagree on the parameter count, the sum-reduced count will
    differ from ``len(params) * world_size()`` on at least one of them.
    """
    if not is_distributed() or not params:
        return
    count = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
    all_reduce(count)
    expected = len(params) * world_size()
    if count.item() != expected:
        raise RuntimeError(
            f"Mismatch in number of params: ours is {len(params)}, "
            "at least one worker has a different one."
        )
def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
    """Broadcast the given tensors from rank ``src`` to every worker.

    This can be used to ensure all workers start from the same model.
    Integer tensors are skipped; only float/complex tensors are synced.
    """
    if not is_distributed():
        return
    eligible = [t for t in tensors if _is_complex_or_float(t)]
    _check_number_of_params(eligible)
    # Fire off all broadcasts asynchronously, then wait on each handle.
    pending = [
        torch.distributed.broadcast(t.data, src=src, async_op=True)
        for t in eligible
    ]
    for handle in pending:
        handle.wait()
def fixseed(seed):
    """Seed python, numpy and torch RNGs and disable cudnn autotuning
    so runs are reproducible."""
    torch.backends.cudnn.benchmark = False
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
def prGreen(skk):
    """Print ``skk`` wrapped in ANSI green color codes."""
    print(f"\033[92m {skk}\033[00m")
def prRed(skk):
    """Print ``skk`` wrapped in ANSI red color codes."""
    print(f"\033[91m {skk}\033[00m")
def to_numpy(tensor):
    """Convert a torch tensor to a numpy array; numpy input passes through.

    Raises:
        ValueError: for any input that is neither a torch tensor nor
            a numpy object.
    """
    if torch.is_tensor(tensor):
        return tensor.cpu().numpy()
    if type(tensor).__module__ == "numpy":
        return tensor
    raise ValueError("Cannot convert {} to numpy array".format(type(tensor)))
def to_torch(ndarray):
    """Convert a numpy array to a torch tensor; tensor input passes through.

    Raises:
        ValueError: for any input that is neither numpy nor a torch tensor.
    """
    if torch.is_tensor(ndarray):
        return ndarray
    if type(ndarray).__module__ == "numpy":
        return torch.from_numpy(ndarray)
    raise ValueError("Cannot convert {} to torch tensor".format(type(ndarray)))
def cleanexit():
    """Terminate the process immediately with exit status 0.

    ``sys.exit`` raises SystemExit; it is caught so ``os._exit`` can end
    the process without running atexit handlers or flushing buffers.
    """
    import os
    import sys
    try:
        sys.exit(0)
    except SystemExit:
        os._exit(0)
def load_model_wo_clip(model, state_dict):
    """Load ``state_dict`` into ``model`` non-strictly, tolerating only
    missing CLIP weights.

    The "clip_model.*" parameters are expected to be absent from the
    checkpoint (loaded separately), so those are the only keys allowed
    to be missing.

    Raises:
        AssertionError: if the checkpoint contains keys the model lacks,
            or a non-CLIP model key is missing from the checkpoint.
    """
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    # Attach the offending keys to the assertions so a bad checkpoint
    # is diagnosable (the originals failed with no message at all).
    assert len(unexpected_keys) == 0, (
        f"unexpected keys in state_dict: {unexpected_keys}"
    )
    assert all(k.startswith("clip_model.") for k in missing_keys), (
        f"missing non-clip keys: "
        f"{[k for k in missing_keys if not k.startswith('clip_model.')]}"
    )
def freeze_joints(x, joints_to_freeze):
    """Hold selected joint *rotations* constant at their frame-0 value.

    Args:
        x: tensor shaped [bs, root+n_joints, joint_dim(6), seqlen].
        joints_to_freeze: indices along the joint axis to freeze.

    Returns:
        A detached copy of ``x`` where each frozen joint repeats its
        first-frame rotation across the whole sequence.
    """
    out = x.detach().clone()
    first_frame = out[:, joints_to_freeze, :, :1]
    # Broadcasts the single frame across the seqlen axis.
    out[:, joints_to_freeze, :, :] = first_frame
    return out
class TimerError(Exception):
    """Raised when a Timer is started twice or stopped while idle."""
class Timer:
    """Simple wall-clock stopwatch built on time.perf_counter."""

    def __init__(self):
        # perf_counter value captured at start(); None means idle.
        self._start_time = None

    def start(self):
        """Begin timing; raises TimerError if the timer is already running."""
        if self._start_time is not None:
            raise TimerError("Timer is running. Use .stop() to stop it")
        self._start_time = time.perf_counter()

    def stop(self, iter=None):
        """End timing and print the elapsed time.

        When ``iter`` (an iteration count) is given, also prints a rate:
        iterations/second if more than one iteration fit in the elapsed
        window, seconds/iteration otherwise. Raises TimerError when the
        timer was never started.
        """
        if self._start_time is None:
            raise TimerError("Timer is not running. Use .start() to start it")
        elapsed_time = time.perf_counter() - self._start_time
        self._start_time = None
        iter_msg = ""
        if iter is not None:
            if iter > elapsed_time:
                iter_msg = f"[iter/s: {iter / elapsed_time:0.4f}]"
            else:
                iter_msg = f"[s/iter: {elapsed_time / iter:0.4f}]"
        print(f"Elapsed time: {elapsed_time:0.4f} seconds {iter_msg}")