| from abc import ABC, abstractmethod |
|
|
| import numpy as np |
| import torch as th |
| import torch.distributed as dist |
|
|
|
|
def create_named_schedule_sampler(name, diffusion):
    """
    Create a ScheduleSampler from a library of pre-defined samplers.

    :param name: the name of the sampler.
    :param diffusion: the diffusion object to sample for.
    :raises NotImplementedError: if ``name`` is not a known sampler.
    """
    if name == "uniform":
        # UniformSampler expects a step count; unwrap the diffusion object's
        # `num_timesteps` attribute if present, otherwise assume `diffusion`
        # is already a plain step count.
        # NOTE(review): assumes diffusion objects expose `num_timesteps` —
        # confirm against the diffusion class used by callers.
        return UniformSampler(getattr(diffusion, "num_timesteps", diffusion))
    else:
        raise NotImplementedError(f"unknown schedule sampler: {name}")
|
|
|
|
class ScheduleSampler(ABC):
    """
    A distribution over diffusion timesteps, used to lower the variance of
    the training objective.

    The base implementation performs unbiased importance sampling, so the
    expected value of the objective is unchanged. Subclasses may override
    sample() to reweight the resampled terms and deliberately alter the
    objective.
    """

    @abstractmethod
    def weights(self):
        """
        Return a numpy array with one positive weight per diffusion step.

        Normalization is not required; sample() normalizes internally.
        """

    def sample(self, batch_size, device):
        """
        Importance-sample timesteps for a batch.

        :param batch_size: the number of timesteps.
        :param device: the torch device to save to.
        :return: a tuple (timesteps, weights):
                 - timesteps: a tensor of timestep indices.
                 - weights: a tensor of weights to scale the resulting losses.
        """
        raw = self.weights()
        probs = raw / np.sum(raw)
        # Draw indices according to the normalized weight distribution.
        drawn = np.random.choice(len(probs), size=(batch_size,), p=probs)
        # Importance weight 1 / (N * p_i) keeps the loss estimate unbiased.
        inv = np.reciprocal(len(probs) * probs[drawn])
        timesteps = th.from_numpy(drawn).long().to(device)
        loss_weights = th.from_numpy(inv).float().to(device)
        return timesteps, loss_weights
|
|
|
|
class UniformSampler(ScheduleSampler):
    """
    A ScheduleSampler that weights every diffusion step equally.
    """

    def __init__(self, num_timesteps):
        """
        :param num_timesteps: the number of diffusion steps, either as a
            plain integer or as an object exposing a ``num_timesteps``
            attribute (e.g. the diffusion object handed to
            create_named_schedule_sampler).
        """
        # Backward-compatible: unwrap a diffusion-like object if one was
        # passed; a plain integer is used as-is.
        steps = getattr(num_timesteps, "num_timesteps", num_timesteps)
        self._weights = np.ones([steps])

    def weights(self):
        # Constant weights -> sample() draws timesteps uniformly.
        return self._weights
|
|