| | from typing import Optional |
| | import wandb |
| | import numpy as np |
| | import torch |
| |
|
| | import matplotlib.pyplot as plt |
| | import cv2 |
| | import matplotlib.pyplot as plt |
| | from tqdm import trange, tqdm |
| | import matplotlib.animation as animation |
| | from pathlib import Path |
| |
|
# Raise matplotlib's own logging threshold so only warnings and errors are emitted.
plt.set_loglevel("warning")
| |
|
| | from torchmetrics.functional import mean_squared_error, peak_signal_noise_ratio |
| | from torchmetrics.functional import ( |
| | structural_similarity_index_measure, |
| | universal_image_quality_index, |
| | ) |
| | from algorithms.common.metrics import ( |
| | FrechetVideoDistance, |
| | LearnedPerceptualImagePatchSimilarity, |
| | FrechetInceptionDistance, |
| | ) |
| |
|
| |
|
| | |
def log_video(
    observation_hat,
    observation_gt=None,
    step=0,
    namespace="train",
    prefix="video",
    context_frames=0,
    color=(255, 0, 0),
    logger=None,
):
    """
    Log a batch of videos to wandb, one entry per batch element.

    Values are clipped to [0, 1] and scaled to uint8 before logging, so inputs are
    expected to already be in [0, 1]; anything outside that range (e.g. the negative
    half of a [-1, 1] tensor) is clipped away.

    :param observation_hat: predicted observation tensor of shape (frame, batch, channel, height, width)
    :param observation_gt: optional ground-truth tensor of the same shape; when given it is
        concatenated below the prediction (along the height axis)
    :param step: an int logged alongside each video as ``trainer/global_step``
    :param namespace: a string specifying the name space this video falls under, e.g. train, val
    :param prefix: a string prefix for the video name; entries are ``{namespace}/{prefix}_{i}``
    :param context_frames: how many leading frames in observation_hat are ground truth given as
        context; those frames get a 1-pixel border drawn in ``color``
    :param color: per-channel border color for the context frames, with components in [0, 255]
    :param logger: optional object exposing ``log(dict)``; the global wandb module is used if not given
    """
    if not logger:
        logger = wandb

    if context_frames > 0:
        # Draw on a copy so the caller's tensor is never modified in place.
        observation_hat = observation_hat.detach().clone()
        num_channels = observation_hat.shape[2]
        for channel_idx, value in enumerate(color[:num_channels]):
            # Frames are in [0, 1] here, while ``color`` is given in [0, 255].
            value = value / 255.0
            observation_hat[:context_frames, :, channel_idx, [0, -1], :] = value
            observation_hat[:context_frames, :, channel_idx, :, [0, -1]] = value

    if observation_gt is not None:
        # Stack prediction above ground truth (dim -2 is height).
        video = torch.cat([observation_hat, observation_gt], -2)
    else:
        video = observation_hat
    video = video.detach().cpu().numpy()
    # (frame, batch, C, H, W) -> (batch, frame, C, H, W) so video[i] is one
    # (time, channel, height, width) clip, the layout wandb.Video accepts.
    video = np.transpose(np.clip(video, a_min=0.0, a_max=1.0) * 255, (1, 0, 2, 3, 4)).astype(np.uint8)

    for i, clip in enumerate(video):
        logger.log(
            {
                f"{namespace}/{prefix}_{i}": wandb.Video(clip, fps=5),
                "trainer/global_step": step,
            }
        )
| |
|
| |
|
def get_validation_metrics_for_videos(
    observation_hat,
    observation_gt,
    lpips_model: Optional[LearnedPerceptualImagePatchSimilarity] = None,
    fid_model: Optional[FrechetInceptionDistance] = None,
    fvd_model: Optional[FrechetVideoDistance] = None,
):
    """
    Compute reconstruction and perceptual metrics between predicted and ground-truth videos.

    Inputs are treated as lying in [-1, 1]: PSNR uses ``data_range=2.0`` and both tensors
    are clamped to [-1, 1] before LPIPS and FID are computed.

    :param observation_hat: predicted observation tensor of shape (frame, batch, channel, height, width)
    :param observation_gt: ground-truth observation tensor of the same shape; it is cast with
        ``type_as`` to match observation_hat
    :param lpips_model: optional LearnedPerceptualImagePatchSimilarity from algorithms.common.metrics;
        it is updated, computed and reset within this call
    :param fid_model: optional FrechetInceptionDistance from algorithms.common.metrics; it is updated
        with the ground truth as real and the prediction as fake, computed and reset within this call
    :param fvd_model: currently ignored -- this function does not compute FVD
    :return: a dict with
        - "frame_wise_psnr": 1-D tensor with one PSNR value per frame
        - "mse" and "psnr": computed over all frames and batch elements together
        - "lpips": a Python float, only when lpips_model is given
        - "fid": the result of ``fid_model.compute()``, only when fid_model is given
    """
    # Unpacking also enforces the 5-D (frame, batch, channel, height, width) layout.
    _, _, channel, height, width = observation_hat.shape
    output_dict = {}
    observation_gt = observation_gt.type_as(observation_hat)

    observation_hat = observation_hat.float()
    observation_gt = observation_gt.float()

    output_dict["frame_wise_psnr"] = torch.stack(
        [
            peak_signal_noise_ratio(observation_hat[f], observation_gt[f], data_range=2.0)
            for f in range(observation_hat.shape[0])
        ]
    )

    # Merge the frame and batch dims into one image batch. ``reshape`` (not ``view``)
    # so non-contiguous inputs, e.g. produced by a permute, do not raise.
    observation_hat = observation_hat.reshape(-1, channel, height, width)
    observation_gt = observation_gt.reshape(-1, channel, height, width)

    output_dict["mse"] = mean_squared_error(observation_hat, observation_gt)
    output_dict["psnr"] = peak_signal_noise_ratio(observation_hat, observation_gt, data_range=2.0)

    # Clamp to the nominal [-1, 1] range before the perceptual metrics.
    observation_hat = torch.clamp(observation_hat, -1.0, 1.0)
    observation_gt = torch.clamp(observation_gt, -1.0, 1.0)

    if lpips_model is not None:
        lpips_model.update(observation_hat, observation_gt)
        output_dict["lpips"] = lpips_model.compute().item()
        # Reset accumulated state so the next call starts fresh.
        lpips_model.reset()

    if fid_model is not None:
        # Map [-1, 1] -> [0, 255] and truncate to uint8.
        observation_hat_uint8 = ((observation_hat + 1.0) / 2 * 255).type(torch.uint8)
        observation_gt_uint8 = ((observation_gt + 1.0) / 2 * 255).type(torch.uint8)
        fid_model.update(observation_gt_uint8, real=True)
        fid_model.update(observation_hat_uint8, real=False)
        output_dict["fid"] = fid_model.compute()
        # Reset accumulated state so the next call starts fresh.
        fid_model.reset()

    return output_dict
| |
|
| |
|
def is_grid_env(env_id):
    """Return True when ``env_id`` names a grid-based env (contains "maze2d" or "diagonal2d")."""
    grid_markers = ("maze2d", "diagonal2d")
    return any(marker in env_id for marker in grid_markers)
| |
|
| |
|
def get_maze_grid(env_id):
    """
    Return the interior of the maze layout selected by ``env_id``.

    The layout is chosen by substring: "large", "medium" or "umaze". If several
    match, the later one in that order wins. The outer wall (first/last row and
    first/last column) is stripped, so the result is a list of equal-length strings
    where "#" is a wall, "O" an open cell and "G" the goal.

    :param env_id: environment id containing "large", "medium" or "umaze"
    :return: list of row strings for the maze interior
    :raises ValueError: if ``env_id`` matches none of the known layouts
    """
    # Full layouts including the outer wall; rows are separated by a backslash.
    layouts = (
        ("large", "############\\#OOOO#OOOOO#\\#O##O#O#O#O#\\#OOOOOO#OOO#\\#O####O###O#\\#OO#O#OOOOO#\\##O#O#O#O###\\#OO#OOO#OGO#\\############"),
        ("medium", "########\\#OO##OO#\\#OO#OOO#\\##OOO###\\#OO#OOO#\\#O#OO#O#\\#OOO#OG#\\########"),
        ("umaze", "#####\\#GOO#\\###O#\\#OOO#\\#####"),
    )
    maze_string = None
    for key, layout in layouts:
        # No early break: the last matching key wins.
        if key in env_id:
            maze_string = layout
    if maze_string is None:
        raise ValueError(
            f"No maze layout known for env_id {env_id!r}; expected it to contain 'large', 'medium' or 'umaze'."
        )
    lines = maze_string.split("\\")
    # Drop the outer wall: first/last row, and first/last character of each remaining row.
    return [line[1:-1] for line in lines[1:-1]]
| |
|
| |
|
def get_random_start_goal(env_id, batch_size):
    """
    Sample random start cells and the goal cell for a maze environment.

    Cells come from get_maze_grid, which strips the outer wall, so every
    coordinate is shifted by +1 to index into the full maze including that wall.

    :param env_id: maze environment id understood by get_maze_grid
    :param batch_size: number of start/goal pairs to produce
    :return: (start, goal) integer arrays of (row, col) cells, one row per batch element.
        Starts are drawn uniformly, with replacement, from the open ("O") cells using the
        global np.random state. Every goal row holds the "G" cell's coordinates (the
        built-in layouts have exactly one).
    """
    codes = {"O": 0, "#": 1, "G": 2}
    grid = np.array([[codes[cell] for cell in row] for row in get_maze_grid(env_id)])

    open_rows, open_cols = np.nonzero(grid == 0)
    choice = np.random.randint(len(open_rows), size=batch_size)
    start = np.column_stack((open_rows[choice], open_cols[choice])) + 1

    goal_rows, goal_cols = np.nonzero(grid == 2)
    goal_cell = np.concatenate([goal_rows, goal_cols], -1)
    goal = np.repeat(goal_cell[None, :], batch_size, axis=0) + 1
    return start, goal
| |
|
| |
|
def plot_maze_layout(ax, maze_grid):
    """
    Clear ``ax`` and draw a maze background on it.

    :param ax: matplotlib Axes to draw on; it is cleared first
    :param maze_grid: sequence of equal-length row strings (as returned by get_maze_grid)
        where "#" marks a wall, or None for an environment without walls. With None,
        only the styling is applied and axis limits are left to matplotlib.
    """
    ax.clear()

    if maze_grid is not None:
        for i, row in enumerate(maze_grid):
            for j, cell in enumerate(row):
                if cell == "#":
                    # Row index i maps to x and column index j to y: cell (i, j)
                    # covers [i + 0.5, i + 1.5] x [j + 0.5, j + 1.5].
                    square = plt.Rectangle((i + 0.5, j + 0.5), 1, 1, edgecolor="black", facecolor="black")
                    ax.add_patch(square)

    ax.set_aspect("equal")
    ax.grid(True, color="white", linewidth=4)
    ax.set_axisbelow(True)
    for side in ("top", "right", "bottom", "left"):
        ax.spines[side].set_linewidth(4)
    ax.set_facecolor("lightgray")
    ax.tick_params(
        axis="both",
        which="both",
        bottom=False,
        top=False,
        left=False,
        right=False,
        labelbottom=False,
        labelleft=False,
    )
    # Ticks/limits depend on the grid size, so they are only set when a grid exists
    # (len(None) would raise a TypeError for non-grid environments).
    if maze_grid is not None and len(maze_grid) > 0:
        n_rows, n_cols = len(maze_grid), len(maze_grid[0])
        ax.set_xticks(np.arange(0.5, n_rows + 0.5))
        ax.set_yticks(np.arange(0.5, n_cols + 0.5))
        ax.set_xlim(0.5, n_rows + 0.5)
        ax.set_ylim(0.5, n_cols + 0.5)
    ax.grid(True, color="white", which="minor", linewidth=4)
| |
|
| |
|
def plot_start_goal(ax, start_goal: Optional[tuple]):
    """
    Draw start and goal markers on ``ax``.

    The start is a white circle (radius 0.16, black edge) with a black dot (radius 0.08).
    The goal is the same white circle with a black five-pointed star of radius 0.08 inside.

    :param ax: matplotlib Axes to draw on
    :param start_goal: ``(start, goal)`` pair, each an (x, y) point; when None nothing is drawn
    """
    if start_goal is None:
        return

    def draw_star(center, radius, num_points=5, color="black"):
        # Alternate outer (radius) and inner (radius / 2) vertices. The
        # 5*pi/(2*num_points) offset puts the first tip straight up when num_points == 5.
        angles = np.linspace(0.0, 2 * np.pi, num_points, endpoint=False) + 5 * np.pi / (2 * num_points)
        inner_radius = radius / 2.0

        points = []
        for angle in angles:
            points.extend(
                [
                    center[0] + radius * np.cos(angle),
                    center[1] + radius * np.sin(angle),
                    center[0] + inner_radius * np.cos(angle + np.pi / num_points),
                    center[1] + inner_radius * np.sin(angle + np.pi / num_points),
                ]
            )

        star = plt.Polygon(np.array(points).reshape(-1, 2), color=color)
        ax.add_patch(star)

    start_x, start_y = start_goal[0]
    ax.add_patch(plt.Circle((start_x, start_y), 0.16, facecolor="white", edgecolor="black"))
    ax.add_patch(plt.Circle((start_x, start_y), 0.08, color="black"))

    goal_x, goal_y = start_goal[1]
    ax.add_patch(plt.Circle((goal_x, goal_y), 0.16, facecolor="white", edgecolor="black"))
    draw_star((goal_x, goal_y), radius=0.08)
| |
|
| |
|
def make_trajectory_images(env_id, trajectory, batch_size, start, goal, plot_end_points=True):
    """
    Render one RGBA image per batch element showing its trajectory over the env layout.

    :param env_id: environment id; for grid envs (see is_grid_env) the maze walls are drawn
    :param trajectory: array indexed as ``trajectory[t, batch_idx, dim]``; dims 0 and 1 are
        plotted as x and y, and points are colored by time index with the "Reds" colormap
    :param batch_size: number of batch elements to render (indices 0 .. batch_size - 1)
    :param start: per-batch start points, indexed by batch_idx
    :param goal: per-batch goal points, indexed by batch_idx
    :param plot_end_points: whether to draw start/goal markers
    :return: list of ``batch_size`` uint8 arrays of shape (height, width, 4)
    """
    # The layout does not depend on the batch element, so look it up once.
    maze_grid = get_maze_grid(env_id) if is_grid_env(env_id) else None

    images = []
    for batch_idx in range(batch_size):
        fig, ax = plt.subplots()
        try:
            plot_maze_layout(ax, maze_grid)
            ax.scatter(
                trajectory[:, batch_idx, 0],
                trajectory[:, batch_idx, 1],
                c=np.arange(len(trajectory)),
                cmap="Reds",
            )
            if plot_end_points:
                plot_start_goal(ax, (start[batch_idx], goal[batch_idx]))

            fig.tight_layout()
            fig.canvas.draw()
            img_shape = fig.canvas.get_width_height()[::-1] + (4,)
            # Copy so the image stays valid after the figure (and its buffer) is closed.
            img = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8).copy().reshape(img_shape)
            images.append(img)
        finally:
            # Close every figure created here so repeated calls do not accumulate open figures.
            plt.close(fig)
    return images
| |
|
| |
|
def make_convergence_animation(
    env_id,
    plan_history,
    trajectory,
    start,
    goal,
    open_loop_horizon,
    namespace,
    interval=100,
    plot_end_points=True,
    batch_idx=0,
):
    """
    Save an mp4 showing how the plan of the first control step evolves.

    After selecting one batch element and pruning with prune_history, each animation
    frame draws one entry of ``plan_history[0]`` (in order) over the env layout.

    :param env_id: environment id; selects figure size and, for grid envs, the maze walls
    :param plan_history: nested list ``plan_history[control_step][i]`` of tensors whose
        dim 1 is the batch dim (each is indexed as ``pm[:, batch_idx]``)
    :param trajectory: executed trajectory with the batch on dim 1
    :param start: per-batch start points, indexed by batch_idx
    :param goal: per-batch goal points, indexed by batch_idx
    :param open_loop_horizon: steps executed per control step; forwarded to prune_history
    :param namespace: string embedded in the output filename
    :param interval: delay between frames (ms) for the FuncAnimation; the saved file always uses fps=5
    :param plot_end_points: whether to draw start/goal markers
    :param batch_idx: which batch element to animate
    :return: path of the written mp4 file, ``/tmp/{run id or env_id}_{namespace}_convergence.mp4``
    """
    start, goal = start[batch_idx], goal[batch_idx]
    trajectory = trajectory[:, batch_idx]
    plan_history = [[pm[:, batch_idx] for pm in pt] for pt in plan_history]
    # Only the pruned plans are needed here; the pruned trajectory is not drawn.
    _, plan_history = prune_history(plan_history, trajectory, goal, open_loop_horizon)
    first_step_plans = plan_history[0]

    maze_grid = get_maze_grid(env_id) if is_grid_env(env_id) else None

    fig, ax = plt.subplots()
    try:
        if "large" in env_id:
            fig.set_size_inches(3.5, 5)
        else:
            fig.set_size_inches(3, 3)
        ax.set_axis_off()
        fig.subplots_adjust(left=0, bottom=0, right=1, top=1)

        def update(frame):
            plot_maze_layout(ax, maze_grid)
            plan = first_step_plans[frame].numpy()
            ax.scatter(
                plan[:, 0],
                plan[:, 1],
                c=np.arange(len(plan))[::-1],
                cmap="Reds",
            )
            if plot_end_points:
                plot_start_goal(ax, (start, goal))

        frames = tqdm(range(len(first_step_plans)), desc="Making convergence animation")
        ani = animation.FuncAnimation(fig, update, frames=frames, interval=interval)
        prefix = wandb.run.id if wandb.run is not None else env_id
        filename = f"/tmp/{prefix}_{namespace}_convergence.mp4"
        ani.save(filename, writer="ffmpeg", fps=5)
    finally:
        # Release the figure once the video is written (or on failure).
        plt.close(fig)
    return filename
| |
|
| |
|
def prune_history(plan_history, trajectory, goal, open_loop_horizon):
    """
    Trim a trajectory and its plan history at the point where the goal is first reached.

    A point counts as "at the goal" when its (x, y), i.e. its first two dims, lies within
    Euclidean distance 0.2 of ``goal``.

    If the executed trajectory reaches the goal at index k:
      - the trajectory keeps its first ``k + open_loop_horizon + 1`` points, and
      - the plan history keeps its first ``k // open_loop_horizon + 2`` control steps.

    Then, for every remaining control step, the LAST plan of that step is checked. If it
    reaches the goal at index j, every plan in that step is cut to its first ``j + 1`` points.

    :param plan_history: sequence of control steps, each a sequence of tensors with
        ``.numpy()`` and (time, dim) layout
    :param trajectory: array of shape (time, dim) with x, y in the first two dims
    :param goal: (x, y) goal position
    :param open_loop_horizon: number of steps executed per control step
    :return: (pruned trajectory, pruned plan history as a list of lists)
    """
    goal_xy = np.array(goal)[None]

    hits = np.linalg.norm(trajectory[:, :2] - goal_xy, axis=-1) < 0.2
    if hits.any():
        first_hit = np.argmax(hits)
        trajectory = trajectory[: first_hit + open_loop_horizon + 1]
        plan_history = plan_history[: first_hit // open_loop_horizon + 2]

    pruned_plan_history = []
    for plans in plan_history:
        step_plans = list(plans)
        final_plan = step_plans[-1].numpy()
        plan_hits = np.linalg.norm(final_plan[:, :2] - goal_xy, axis=-1) < 0.2
        if plan_hits.any():
            keep = np.argmax(plan_hits) + 1
            step_plans = [p[:keep] for p in step_plans]
        pruned_plan_history.append(step_plans)
    return trajectory, pruned_plan_history
| |
|
| |
|
def make_mpc_animation(
    env_id,
    plan_history,
    trajectory,
    start,
    goal,
    open_loop_horizon,
    namespace,
    interval=100,
    plot_end_points=True,
    batch_idx=0,
):
    """
    Save an mp4 replaying MPC execution for one batch element.

    After selecting the batch element and pruning with prune_history, frames walk through
    the control steps in order. Within each control step they walk through every entry of
    that step's plan list. Each frame draws the executed trajectory so far ("Blues"
    colormap) and the current plan ("Reds" colormap) over the env layout.

    :param env_id: environment id; selects figure size and, for grid envs, the maze walls
    :param plan_history: nested list ``plan_history[control_step][i]`` of tensors whose
        dim 1 is the batch dim (each is indexed as ``pm[:, batch_idx]``)
    :param trajectory: executed trajectory with the batch on dim 1
    :param start: per-batch start points, indexed by batch_idx
    :param goal: per-batch goal points, indexed by batch_idx
    :param open_loop_horizon: steps executed per control step
    :param namespace: string embedded in the output filename
    :param interval: delay between frames (ms) for the FuncAnimation; the saved file always uses fps=5
    :param plot_end_points: whether to draw start/goal markers
    :param batch_idx: which batch element to animate
    :return: path of the written mp4 file, ``/tmp/{run id or env_id}_{namespace}_mpc.mp4``
    """
    start, goal = start[batch_idx], goal[batch_idx]
    trajectory = trajectory[:, batch_idx]
    plan_history = [[pm[:, batch_idx] for pm in pt] for pt in plan_history]
    trajectory, plan_history = prune_history(plan_history, trajectory, goal, open_loop_horizon)

    # Flat lookup: animation frame -> (control step, plan index within that step).
    # Control steps with no plans contribute no frames.
    frame_index = [
        (control_time_step, m)
        for control_time_step, plans in enumerate(plan_history)
        for m in range(len(plans))
    ]

    maze_grid = get_maze_grid(env_id) if is_grid_env(env_id) else None

    fig, ax = plt.subplots()
    try:
        if "large" in env_id:
            fig.set_size_inches(3.5, 5)
        else:
            fig.set_size_inches(3, 3)
        ax.set_axis_off()
        fig.subplots_adjust(left=0, bottom=0, right=1, top=1)
        trajectory_colors = np.linspace(0, 1, len(trajectory))

        def update(frame):
            control_time_step, m = frame_index[frame]
            # Trajectory points to show: the first point plus open_loop_horizon
            # points for every preceding control step.
            num_steps_taken = 1 + open_loop_horizon * control_time_step
            plot_maze_layout(ax, maze_grid)

            plan = plan_history[control_time_step][m].numpy()
            ax.scatter(
                trajectory[:num_steps_taken, 0],
                trajectory[:num_steps_taken, 1],
                c=trajectory_colors[:num_steps_taken],
                cmap="Blues",
            )
            ax.scatter(
                plan[:, 0],
                plan[:, 1],
                c=np.arange(len(plan))[::-1],
                cmap="Reds",
            )
            if plot_end_points:
                plot_start_goal(ax, (start, goal))

        frames = tqdm(range(len(frame_index)), desc="Making MPC animation")
        ani = animation.FuncAnimation(fig, update, frames=frames, interval=interval)
        prefix = wandb.run.id if wandb.run is not None else env_id
        filename = f"/tmp/{prefix}_{namespace}_mpc.mp4"
        ani.save(filename, writer="ffmpeg", fps=5)
    finally:
        # Release the figure once the video is written (or on failure).
        plt.close(fig)
    return filename
| |
|