"""Vector quantizer. Reference: https://github.com/CompVis/taming-transformers/blob/master/taming/modules/vqvae/quantize.py https://github.com/google-research/magvit/blob/main/videogvt/models/vqvae.py https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/distributions/distributions.py https://github.com/lyndonzheng/CVQ-VAE/blob/main/quantise.py """ from typing import Mapping, Text, Tuple import torch from einops import rearrange from accelerate.utils.operations import gather from torch.cuda.amp import autocast class VectorQuantizer(torch.nn.Module): def __init__(self, codebook_size: int = 1024, token_size: int = 256, commitment_cost: float = 0.25, use_l2_norm: bool = False, clustering_vq: bool = False ): super().__init__() self.codebook_size = codebook_size self.token_size = token_size self.commitment_cost = commitment_cost self.embedding = torch.nn.Embedding(codebook_size, token_size) self.embedding.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size) self.use_l2_norm = use_l2_norm self.clustering_vq = clustering_vq if clustering_vq: self.decay = 0.99 self.register_buffer("embed_prob", torch.zeros(self.codebook_size)) # Ensure quantization is performed using f32 @autocast(enabled=False) def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: z = z.float() z = rearrange(z, 'b c h w -> b h w c').contiguous() z_flattened = rearrange(z, 'b h w c -> (b h w) c') unnormed_z_flattened = z_flattened if self.use_l2_norm: z_flattened = torch.nn.functional.normalize(z_flattened, dim=-1) embedding = torch.nn.functional.normalize(self.embedding.weight, dim=-1) else: embedding = self.embedding.weight d = torch.sum(z_flattened**2, dim=1, keepdim=True) + \ torch.sum(embedding**2, dim=1) - 2 * \ torch.einsum('bd,dn->bn', z_flattened, embedding.T) min_encoding_indices = torch.argmin(d, dim=1) # num_ele z_quantized = self.get_codebook_entry(min_encoding_indices).view(z.shape) if self.use_l2_norm: z = torch.nn.functional.normalize(z, dim=-1) # compute loss for embedding commitment_loss = self.commitment_cost * torch.mean((z_quantized.detach() - z) **2) codebook_loss = torch.mean((z_quantized - z.detach()) **2) if self.clustering_vq and self.training: with torch.no_grad(): # Gather distance matrix from all GPUs. encoding_indices = gather(min_encoding_indices) if len(min_encoding_indices.shape) != 1: raise ValueError(f"min_encoding_indices in a wrong shape, {min_encoding_indices.shape}") # Compute and update the usage of each entry in the codebook. encodings = torch.zeros(encoding_indices.shape[0], self.codebook_size, device=z.device) encodings.scatter_(1, encoding_indices.unsqueeze(1), 1) avg_probs = torch.mean(encodings, dim=0) self.embed_prob.mul_(self.decay).add_(avg_probs, alpha=1-self.decay) # Closest sampling to update the codebook. all_d = gather(d) all_unnormed_z_flattened = gather(unnormed_z_flattened).detach() if all_d.shape[0] != all_unnormed_z_flattened.shape[0]: raise ValueError( "all_d and all_unnormed_z_flattened have different length" + f"{all_d.shape}, {all_unnormed_z_flattened.shape}") indices = torch.argmin(all_d, dim=0) random_feat = all_unnormed_z_flattened[indices] # Decay parameter based on the average usage. 
                decay = torch.exp(-(self.embed_prob * self.codebook_size * 10) /
                                  (1 - self.decay) - 1e-3).unsqueeze(1).repeat(1, self.token_size)
                # Push rarely used codebook entries toward the closest encoder features.
                self.embedding.weight.data = self.embedding.weight.data * (1 - decay) + random_feat * decay

        loss = commitment_loss + codebook_loss

        # Preserve gradients with the straight-through estimator.
        z_quantized = z + (z_quantized - z).detach()

        # Reshape back to match original input shape.
        z_quantized = rearrange(z_quantized, 'b h w c -> b c h w').contiguous()

        result_dict = dict(
            quantizer_loss=loss,
            commitment_loss=commitment_loss,
            codebook_loss=codebook_loss,
            min_encoding_indices=min_encoding_indices.view(
                z_quantized.shape[0], z_quantized.shape[2], z_quantized.shape[3])
        )

        return z_quantized, result_dict

    @autocast(enabled=False)
    def get_codebook_entry(self, indices):
        if len(indices.shape) == 1:
            # Integer indices: direct lookup in the codebook.
            z_quantized = self.embedding(indices.long())
        elif len(indices.shape) == 2:
            # One-hot (or soft) encodings: weighted combination of codebook entries.
            z_quantized = torch.einsum('bd,dn->bn', indices, self.embedding.weight)
        else:
            raise NotImplementedError
        if self.use_l2_norm:
            z_quantized = torch.nn.functional.normalize(z_quantized, dim=-1)
        return z_quantized


class DiagonalGaussianDistribution(object):
    @autocast(enabled=False)
    def __init__(self, parameters, deterministic=False):
        """Initializes a Gaussian distribution instance given the parameters.

        Args:
            parameters (torch.Tensor): The parameters for the Gaussian distribution. It is
                expected to be in shape [B, 2 * C, *], where B is batch size and C is the
                embedding dimension. The first C channels are used for the mean and the last
                C channels for the logvar of the Gaussian distribution.
            deterministic (bool): Whether to use deterministic sampling. When it is true,
                the sampling result is purely based on the mean (i.e., std = 0).
        """
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters.float(), 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    @autocast(enabled=False)
    def sample(self):
        x = self.mean.float() + self.std.float() * torch.randn(self.mean.shape).to(device=self.parameters.device)
        return x

    @autocast(enabled=False)
    def mode(self):
        return self.mean

    @autocast(enabled=False)
    def kl(self):
        if self.deterministic:
            return torch.Tensor([0.])
        else:
            return 0.5 * torch.sum(torch.pow(self.mean.float(), 2)
                                   + self.var.float() - 1.0 - self.logvar.float(),
                                   dim=[1, 2])
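
# Minimal usage sketch (an illustrative addition, not part of the reference
# implementations above). It assumes a hypothetical encoder output of shape
# [B, token_size, H, W] and default hyperparameters, and simply exercises both
# classes as a quick smoke test.
if __name__ == "__main__":
    quantizer = VectorQuantizer(codebook_size=1024, token_size=256, use_l2_norm=True)
    z = torch.randn(2, 256, 8, 8)  # toy latent: [B, token_size, H, W]
    z_quantized, result_dict = quantizer(z)
    print(z_quantized.shape)                          # torch.Size([2, 256, 8, 8])
    print(result_dict["min_encoding_indices"].shape)  # torch.Size([2, 8, 8])

    # DiagonalGaussianDistribution expects [B, 2 * C, *]: the first C channels
    # are the mean, the last C channels the log-variance.
    params = torch.randn(2, 2 * 16, 8, 8)
    posterior = DiagonalGaussianDistribution(params)
    print(posterior.sample().shape)                   # torch.Size([2, 16, 8, 8])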