rename diversity
function2-llx committed Dec 21, 2023
1 parent 71cb014 commit fa540c3
Showing 4 changed files with 21 additions and 21 deletions.
4 changes: 2 additions & 2 deletions conf/tokenizer/swin/loss.yaml
@@ -1,5 +1,5 @@
- quant_weight: 1.
- entropy_weight: 0.2
+ quant_weight: 1
+ entropy_weight: 0
rec_loss: l1
rec_weight: 1
perceptual_loss:
2 changes: 1 addition & 1 deletion pumit/tokenizer/loss.py
@@ -126,13 +126,13 @@ def forward_gen(
'rec_loss': rec_loss,
'perceptual_loss': perceptual_loss,
'quant_loss': vq_out.loss,
+ 'util_var': vq_out.util_var,
'vq_loss': vq_loss,
'gan_loss': gan_loss,
'gan_weight': gan_weight,
}
if vq_out.entropy is not None:
log_dict['entropy'] = vq_out.entropy
- log_dict['diversity'] = vq_out.diversity
return loss, log_dict

def disc_fix_rgb(self, x: torch.Tensor, not_rgb: torch.Tensor):
34 changes: 17 additions & 17 deletions pumit/tokenizer/vq.py
@@ -15,7 +15,7 @@ class VectorQuantizerOutput:
"""codebook index, can be a discrete one or soft one (categorical distribution over codebook)"""
loss: torch.Tensor
"""quantization loss (e.g., |z_q - e|, or prior distribution regularization)"""
- diversity: float
+ util_var: float
"""evaluate if the utilization of the codebook is "uniform" enough"""
logits: torch.Tensor | None = None
"""original logits over the codebook for probabilistic VQ"""
@@ -49,16 +49,16 @@ def _load_from_state_dict(self, state_dict: dict[str, torch.Tensor], prefix: str
state_dict[proj_weight_key] = weight
return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

- def cal_diversity(self, index: torch.Tensor):
- bias_correction = self.num_embeddings
- if index.ndim == 4:
- # discrete
- raise NotImplementedError
- else:
- # probabilistic
- p = einops.reduce(index, '... n_e -> n_e', 'mean')
- var = p.var(unbiased=False)
- return var * bias_correction
+ # def cal_diversity(self, index: torch.Tensor):
+ # bias_correction = self.num_embeddings
+ # if index.ndim == 4:
+ # # discrete
+ # raise NotImplementedError
+ # else:
+ # # probabilistic
+ # p = einops.reduce(index, '... n_e -> n_e', 'mean')
+ # var = p.var(unbiased=False)
+ # return var * bias_correction

def forward(self, z: torch.Tensor, fabric: Fabric | None = None) -> VectorQuantizerOutput:
"""
@@ -113,11 +113,11 @@ def __init__(self, num_embeddings: int, embedding_dim: int, in_channels: int | N
self.pdr_eps = pdr_eps

def get_pdr_loss(self, probs: torch.Tensor, fabric: Fabric | None):
"""prior distribution regularization"""
"""calculate prior distribution regularization and util_var"""
mean_probs = einops.reduce(probs, '... d -> d', reduction='mean')
if fabric is not None and fabric.world_size > 1:
mean_probs = fabric.all_reduce(mean_probs) - mean_probs.detach() + mean_probs
- return (mean_probs * (mean_probs + self.pdr_eps).log()).sum()
+ return (mean_probs * (mean_probs + self.pdr_eps).log()).sum(), mean_probs.var(unbiased=False)

def embed_index(self, index_probs: torch.Tensor):
z_q = einops.einsum(index_probs, self.embedding.weight, '... ne, ne d -> ... d')
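Aside: the regularizer returned in the hunk above is the negative entropy of the mean code-usage distribution, sum(p * log(p + eps)), which is smallest when usage is uniform, and the second return value is the usage variance reported as util_var; the fabric.all_reduce(...) - mean_probs.detach() + mean_probs pattern aggregates the statistics across processes while keeping only the local gradient. A single-process sketch under these assumptions (the eps value and function name are illustrative, not from the repository):

import torch

def pdr_loss_and_util_var(probs: torch.Tensor, eps: float = 1e-6) -> tuple[torch.Tensor, torch.Tensor]:
    # probs: (..., num_embeddings) soft assignments over the codebook
    mean_probs = probs.reshape(-1, probs.shape[-1]).mean(dim=0)
    # negative entropy of the mean usage distribution; minimized when usage is uniform
    loss = (mean_probs * (mean_probs + eps).log()).sum()
    # variance of mean usage, returned as `util_var` after this commit
    return loss, mean_probs.var(unbiased=False)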
@@ -151,13 +151,13 @@ def adjust_temperature(self, global_step: int, max_steps: int):

def forward(self, z: torch.Tensor, fabric: Fabric | None = None):
logits, probs, entropy = self.project_over_codebook(z)
- loss = self.get_pdr_loss(probs, fabric)
+ loss, util_var = self.get_pdr_loss(probs, fabric)
if self.training:
index_probs = nnf.gumbel_softmax(logits, self.temperature, self.hard_gumbel, dim=-1)
else:
index_probs = probs
z_q = self.embed_index(index_probs)
- return VectorQuantizerOutput(z_q, index_probs, loss, self.cal_diversity(index_probs), logits, entropy)
+ return VectorQuantizerOutput(z_q, index_probs, loss, util_var, logits, entropy)

class SoftVQ(ProbabilisticVQ):
def __init__(self, num_embeddings: int, embedding_dim: int, in_channels: int | None = None, prune: int | None = 3):
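Aside: in GumbelVQ.forward above, training-time assignments come from the Gumbel-softmax relaxation (optionally hard via straight-through), while evaluation uses the plain softmax probabilities; either way a soft index distribution is produced and then mixed with the codebook in embed_index. A minimal sketch of that selection step, with an assumed function name:

import torch
import torch.nn.functional as nnf

def select_index_probs(logits: torch.Tensor, temperature: float, hard: bool, training: bool) -> torch.Tensor:
    if training:
        # relaxed (or straight-through hard) samples; gradients flow back to the logits
        return nnf.gumbel_softmax(logits, tau=temperature, hard=hard, dim=-1)
    # deterministic soft assignment at evaluation time
    return logits.softmax(dim=-1)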
@@ -175,6 +175,6 @@ def forward(self, z: torch.Tensor, fabric: Fabric | None = None):
index_probs = torch.zeros_like(logits)
index_probs.scatter_(-1, top_indices, top_logits.softmax(dim=-1))
index_probs = index_probs + probs - probs.detach()
- loss = self.get_pdr_loss(index_probs, fabric)
+ loss, util_var = self.get_pdr_loss(index_probs, fabric)
z_q = self.embed_index(index_probs)
- return VectorQuantizerOutput(z_q, index_probs, loss, self.cal_diversity(index_probs), logits, entropy)
+ return VectorQuantizerOutput(z_q, index_probs, loss, util_var, logits, entropy)
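Aside: SoftVQ keeps only the top-prune logits (prune defaults to 3 in __init__) when building the forward assignment, and the trailing + probs - probs.detach() term makes the backward pass behave as if the full softmax had been used. A self-contained sketch of the pruning step, mirroring the diff but not taken from the repository verbatim:

import torch

def pruned_index_probs(logits: torch.Tensor, prune: int = 3) -> torch.Tensor:
    probs = logits.softmax(dim=-1)
    top_logits, top_indices = logits.topk(prune, dim=-1)
    pruned = torch.zeros_like(logits)
    # renormalize the kept logits and scatter them back into a full-size distribution
    pruned.scatter_(-1, top_indices, top_logits.softmax(dim=-1))
    # forward value: pruned distribution; gradient: as if `probs` were used
    return pruned + probs - probs.detach()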
2 changes: 1 addition & 1 deletion third-party/LuoLib
