# WeatherFlow / generate_reflow_pairs.py
# Uploaded by JacobLinCool via huggingface_hub (commit 4c7009f, verified)
import argparse
import os
import yaml
import torch
import torch.nn as nn
import numpy as np
import torchdiffeq
import utils
from diff2flow import VPDiffusionFlow, dict2namespace
import datasets
from tqdm import tqdm
def ode_inverse_solve(
    flow_model,
    x_data,
    x_cond,
    steps=100,
    method="dopri5",
    patch_size=64,
    atol=1e-5,
    rtol=1e-5,
):
    """Invert the probability-flow ODE: integrate from t=0 (data) to t=1 (noise).

    Args:
        flow_model: model exposing ``get_velocity(x, t, x_cond, patch_size=...)``
            for t in [0, 1].
        x_data: data tensor at t=0 (the state to be inverted).
        x_cond: conditioning tensor passed through to the velocity field.
        steps: number of integration intervals (``steps + 1`` time points).
        method: torchdiffeq solver name (e.g. "dopri5").
        patch_size: forwarded to ``get_velocity`` (used there for patching).
        atol, rtol: solver tolerances.

    Returns:
        The final state at t=1 (noise latent) with the same shape as ``x_data``.
    """

    def velocity(t, state):
        # flow_model.get_velocity expects t in [0, 1]; torchdiffeq sweeps 0 -> 1.
        return flow_model.get_velocity(state, t, x_cond, patch_size=patch_size)

    # Uniform time grid from data (t=0) to noise (t=1).
    time_grid = torch.linspace(0.0, 1.0, steps + 1, device=x_data.device)

    trajectory = torchdiffeq.odeint(
        velocity, x_data, time_grid, method=method, atol=atol, rtol=rtol
    )
    # The solver returns the state at every grid point; keep only t=1.
    return trajectory[-1]
def main():
    """CLI entry point: generate reflow training pairs from a trained flow model.

    Pipeline:
      1. Load the YAML config and a VPDiffusionFlow checkpoint.
      2. Iterate training patches from the configured dataset.
      3. Invert the flow ODE per batch: clean data (t=0) -> noise latent (t=1).
      4. Save (x_noise, x_data, x_cond) triplets to ``--output_dir`` as
         ``batch_<i>.pth`` files for later reflow training.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--resume", type=str, required=True)
    parser.add_argument("--data_dir", type=str, default=None)
    parser.add_argument("--dataset", type=str, default=None)
    parser.add_argument("--steps", type=int, default=100)
    parser.add_argument("--output_dir", type=str, default="reflow_data")
    parser.add_argument("--seed", type=int, default=61)
    parser.add_argument("--patch_size", type=int, default=64)
    parser.add_argument("--method", type=str, default="dopri5")
    parser.add_argument("--atol", type=float, default=1e-5)
    parser.add_argument("--rtol", type=float, default=1e-5)
    parser.add_argument(
        "--max_images",
        type=int,
        default=None,
        help="Max images to generate (for testing)",
    )
    args = parser.parse_args()

    # Load config; CLI flags override the dataset location/name from the YAML.
    with open(os.path.join("configs", args.config), "r") as f:
        config_dict = yaml.safe_load(f)
    config = dict2namespace(config_dict)
    if args.data_dir:
        config.data.data_dir = args.data_dir
    if args.dataset:
        config.data.dataset = args.dataset

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    config.device = device

    # Reproducibility — also seed CUDA RNGs when a GPU is used, otherwise
    # GPU-side sampling/ops would not be deterministic across runs.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    # Load Model
    print("Initializing VPDiffusionFlow...")
    flow = VPDiffusionFlow(args, config)
    flow.load_ckpt(args.resume)
    os.makedirs(args.output_dir, exist_ok=True)

    # Load Dataset
    print(f"Loading dataset {config.data.dataset}...")
    DATASET = datasets.__dict__[config.data.dataset](config)

    # Reflow pairs are generated from TRAINING patches so the reflow model
    # sees the same distribution (and patch size) as the original training
    # run; each patch's ODE can then be solved independently.
    # (A redundant full-image get_loaders() call, whose result was
    # immediately overwritten, has been removed here.)
    train_loader, _ = DATASET.get_loaders(parse_patches=True)

    print("Starting generation of reflow pairs...")
    count = 0
    for i, (x_batch, img_id) in enumerate(
        tqdm(train_loader, desc="Generating Reflow Pairs")
    ):
        # With parse_patches=True batches arrive as [B, N, 6, H, W];
        # flatten B and N so every patch is processed as its own sample.
        if x_batch.ndim == 5:
            x_batch = x_batch.flatten(start_dim=0, end_dim=1)
        input_img = x_batch[:, :3, :, :].to(device)  # condition (rainy)
        gt_img = x_batch[:, 3:, :, :].to(device)  # target (clean)

        # Map pixel values into the model's [-1, 1] range.
        x_cond = utils.sampling.data_transform(input_img)
        x_data = utils.sampling.data_transform(gt_img)

        # ODE inversion: x_data at t=0 -> x_noise at t=1. Patches are
        # already at (or below) patch_size, so no stitching is needed;
        # VPDiffusionFlow.get_velocity only patches inputs larger than
        # patch_size.
        with torch.no_grad():
            x_noise = ode_inverse_solve(
                flow,
                x_data,
                x_cond,
                steps=args.steps,
                method=args.method,
                patch_size=args.patch_size,
                atol=args.atol,
                rtol=args.rtol,
            )

        # Persist the triplet: x_noise is the reflow model's input (t=1),
        # x_data its regression target (t=0), x_cond the conditioning image.
        batch_data = {
            "x_noise": x_noise.cpu(),
            "x_data": x_data.cpu(),
            "x_cond": x_cond.cpu(),
        }
        save_path = os.path.join(args.output_dir, f"batch_{i}.pth")
        torch.save(batch_data, save_path)
        print(f"Saved batch {i} to {save_path}")

        # NOTE(review): this counts patches (post-flatten), not source
        # images, so --max_images is effectively a patch budget — confirm
        # intended semantics.
        count += input_img.shape[0]
        if args.max_images and count >= args.max_images:
            print(f"Reached max images {args.max_images}")
            break


if __name__ == "__main__":
    main()