"""Vector quantizer.
Reference:
https://github.com/CompVis/taming-transformers/blob/master/taming/modules/vqvae/quantize.py
https://github.com/google-research/magvit/blob/main/videogvt/models/vqvae.py
https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/distributions/distributions.py
https://github.com/lyndonzheng/CVQ-VAE/blob/main/quantise.py
"""
from typing import Mapping, Text, Tuple
import torch
from einops import rearrange
from accelerate.utils.operations import gather
from torch.cuda.amp import autocast
class VectorQuantizer(torch.nn.Module):
    def __init__(self,
                 codebook_size: int = 1024,
                 token_size: int = 256,
                 commitment_cost: float = 0.25,
                 use_l2_norm: bool = False,
                 clustering_vq: bool = False
                 ):
        super().__init__()
        self.codebook_size = codebook_size
        self.token_size = token_size
        self.commitment_cost = commitment_cost

        self.embedding = torch.nn.Embedding(codebook_size, token_size)
        self.embedding.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size)
        self.use_l2_norm = use_l2_norm

        self.clustering_vq = clustering_vq
        if clustering_vq:
            self.decay = 0.99
            self.register_buffer("embed_prob", torch.zeros(self.codebook_size))
    # Ensure quantization is performed using f32
    @autocast(enabled=False)
    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        z = z.float()
        z = rearrange(z, 'b c h w -> b h w c').contiguous()
        z_flattened = rearrange(z, 'b h w c -> (b h w) c')
        unnormed_z_flattened = z_flattened

        if self.use_l2_norm:
            z_flattened = torch.nn.functional.normalize(z_flattened, dim=-1)
            embedding = torch.nn.functional.normalize(self.embedding.weight, dim=-1)
        else:
            embedding = self.embedding.weight

        # Squared Euclidean distance to every codebook entry: ||z||^2 + ||e||^2 - 2 * z.e
        d = torch.sum(z_flattened**2, dim=1, keepdim=True) + \
            torch.sum(embedding**2, dim=1) - 2 * \
            torch.einsum('bd,dn->bn', z_flattened, embedding.T)

        min_encoding_indices = torch.argmin(d, dim=1)  # num_ele
        z_quantized = self.get_codebook_entry(min_encoding_indices).view(z.shape)

        if self.use_l2_norm:
            z = torch.nn.functional.normalize(z, dim=-1)

        # compute loss for embedding
        commitment_loss = self.commitment_cost * torch.mean((z_quantized.detach() - z) ** 2)
        codebook_loss = torch.mean((z_quantized - z.detach()) ** 2)
        if self.clustering_vq and self.training:
            with torch.no_grad():
                # Gather distance matrix from all GPUs.
                encoding_indices = gather(min_encoding_indices)
                if len(min_encoding_indices.shape) != 1:
                    raise ValueError(f"min_encoding_indices in a wrong shape, {min_encoding_indices.shape}")
                # Compute and update the usage of each entry in the codebook.
                encodings = torch.zeros(encoding_indices.shape[0], self.codebook_size, device=z.device)
                encodings.scatter_(1, encoding_indices.unsqueeze(1), 1)
                avg_probs = torch.mean(encodings, dim=0)
                self.embed_prob.mul_(self.decay).add_(avg_probs, alpha=1 - self.decay)
                # Closest sampling to update the codebook.
                all_d = gather(d)
                all_unnormed_z_flattened = gather(unnormed_z_flattened).detach()
                if all_d.shape[0] != all_unnormed_z_flattened.shape[0]:
                    raise ValueError(
                        "all_d and all_unnormed_z_flattened have different lengths: " +
                        f"{all_d.shape}, {all_unnormed_z_flattened.shape}")
                # For each codebook entry, pick the closest (unnormalized) feature across all GPUs.
                indices = torch.argmin(all_d, dim=0)
                random_feat = all_unnormed_z_flattened[indices]
                # Decay parameter based on the average usage.
                decay = torch.exp(-(self.embed_prob * self.codebook_size * 10) /
                                  (1 - self.decay) - 1e-3).unsqueeze(1).repeat(1, self.token_size)
                self.embedding.weight.data = self.embedding.weight.data * (1 - decay) + random_feat * decay
        loss = commitment_loss + codebook_loss

        # preserve gradients
        z_quantized = z + (z_quantized - z).detach()

        # reshape back to match original input shape
        z_quantized = rearrange(z_quantized, 'b h w c -> b c h w').contiguous()

        result_dict = dict(
            quantizer_loss=loss,
            commitment_loss=commitment_loss,
            codebook_loss=codebook_loss,
            min_encoding_indices=min_encoding_indices.view(z_quantized.shape[0], z_quantized.shape[2], z_quantized.shape[3])
        )

        return z_quantized, result_dict
    @autocast(enabled=False)
    def get_codebook_entry(self, indices):
        indices = indices.long()
        if len(indices.shape) == 1:
            # Integer token indices: plain embedding lookup.
            z_quantized = self.embedding(indices)
        elif len(indices.shape) == 2:
            # One-hot encodings: matrix product with the codebook.
            z_quantized = torch.einsum('bd,dn->bn', indices, self.embedding.weight)
        else:
            raise NotImplementedError
        if self.use_l2_norm:
            z_quantized = torch.nn.functional.normalize(z_quantized, dim=-1)
        return z_quantized
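# Usage sketch (an added illustration, not part of the referenced implementations):
# the quantizer consumes a [B, token_size, H, W] feature map and returns the
# straight-through quantized map plus losses and token indices. The shapes and
# hyper-parameters below are assumptions for the example.
#
#   quantizer = VectorQuantizer(codebook_size=1024, token_size=256, use_l2_norm=True)
#   z = torch.randn(2, 256, 16, 16)                                  # encoder output
#   z_quantized, info = quantizer(z)                                 # [2, 256, 16, 16]
#   tokens = info["min_encoding_indices"]                            # [2, 16, 16] integer ids
#   embeddings = quantizer.get_codebook_entry(tokens.reshape(-1))    # [2*16*16, 256]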
class DiagonalGaussianDistribution(object):
    @autocast(enabled=False)
    def __init__(self, parameters, deterministic=False):
        """Initializes a Gaussian distribution instance given the parameters.

        Args:
            parameters (torch.Tensor): The parameters for the Gaussian distribution. It is expected
                to be in shape [B, 2 * C, *], where B is batch size and C is the embedding dimension.
                The first C channels are used for the mean and the last C for the logvar of the Gaussian distribution.
            deterministic (bool): Whether to use deterministic sampling. When it is true, the sampling result
                is purely based on the mean (i.e., std = 0).
        """
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters.float(), 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
    @autocast(enabled=False)
    def sample(self):
        x = self.mean.float() + self.std.float() * torch.randn(self.mean.shape).to(device=self.parameters.device)
        return x

    @autocast(enabled=False)
    def mode(self):
        return self.mean

    @autocast(enabled=False)
    def kl(self):
        if self.deterministic:
            return torch.Tensor([0.])
        else:
            return 0.5 * torch.sum(torch.pow(self.mean.float(), 2)
                                   + self.var.float() - 1.0 - self.logvar.float(),
                                   dim=[1, 2])
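# Minimal smoke check (an added illustration, not part of the referenced
# implementations): builds parameters shaped [B, 2 * C, L] for a 1D token
# latent and queries sample / mode / kl. All shapes here are assumptions.
if __name__ == "__main__":
    params = torch.randn(2, 2 * 16, 32)  # [B, 2 * C, L] -> mean and logvar each [2, 16, 32]
    posterior = DiagonalGaussianDistribution(params)
    print(posterior.sample().shape)  # torch.Size([2, 16, 32])
    print(posterior.mode().shape)    # torch.Size([2, 16, 32])
    print(posterior.kl().shape)      # torch.Size([2]); kl() sums over dims 1 and 2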