Spaces:
Running
on
Zero
Running
on
Zero
| """This file contains a class to evaluate the reconstruction results. | |
| Copyright (2024) Bytedance Ltd. and/or its affiliates | |
| Licensed under the Apache License, Version 2.0 (the "License"); | |
| you may not use this file except in compliance with the License. | |
| You may obtain a copy of the License at | |
| http://www.apache.org/licenses/LICENSE-2.0 | |
| Unless required by applicable law or agreed to in writing, software | |
| distributed under the License is distributed on an "AS IS" BASIS, | |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| See the License for the specific language governing permissions and | |
| limitations under the License. | |
| """ | |
| import warnings | |
| from typing import Sequence, Optional, Mapping, Text | |
| import numpy as np | |
| from scipy import linalg | |
| import torch | |
| import torch.nn.functional as F | |
| from .inception import get_inception_model | |
def get_covariance(sigma: torch.Tensor, total: torch.Tensor, num_examples: int) -> torch.Tensor:
    """Computes the unbiased sample covariance from running sums.

    The covariance is recovered from accumulated statistics as
    ``(sigma - outer(total, total) / n) / (n - 1)``.

    Args:
        sigma: A torch.Tensor, sum of outer products of input features.
        total: A torch.Tensor, sum of all input features.
        num_examples: An integer, number of examples that were accumulated.

    Returns:
        A torch.Tensor with the same shape as `sigma`. An all-zero matrix is
        returned when fewer than two examples were accumulated, because the
        unbiased estimator divides by ``num_examples - 1`` and would otherwise
        produce NaN/Inf entries.
    """
    # Sample covariance is undefined for 0 or 1 examples; avoid dividing by zero.
    if num_examples <= 1:
        return torch.zeros_like(sigma)

    # n * outer(mean, mean), expressed through the raw sum of features.
    mean_outer = torch.outer(total, total) / num_examples
    return (sigma - mean_outer) / (num_examples - 1)
class VQGANEvaluator:
    """Accumulates reconstruction metrics over batches and reports them.

    Depending on the flags passed to the constructor, the evaluator tracks:
      * Inception Score of the fake (reconstructed) images,
      * rFID between real and fake images,
      * codebook usage (fraction of distinct indices seen),
      * codebook entropy (in bits, from index frequencies).

    Call `update` once per batch, `result` to read the metrics and
    `reset_metrics` to start a new evaluation run.
    """

    def __init__(
        self,
        device,
        enable_rfid: bool = True,
        enable_inception_score: bool = True,
        enable_codebook_usage_measure: bool = False,
        enable_codebook_entropy_measure: bool = False,
        num_codebook_entries: int = 1024
    ):
        """Initializes VQGAN Evaluator.

        Args:
            device: The device to use for evaluation.
            enable_rfid: A boolean, whether enabling rFID score.
            enable_inception_score: A boolean, whether enabling Inception Score.
            enable_codebook_usage_measure: A boolean, whether enabling codebook usage measure.
            enable_codebook_entropy_measure: A boolean, whether enabling codebook entropy measure.
            num_codebook_entries: An integer, the number of codebook entries.
        """
        self._device = device

        self._enable_rfid = enable_rfid
        self._enable_inception_score = enable_inception_score
        self._enable_codebook_usage_measure = enable_codebook_usage_measure
        self._enable_codebook_entropy_measure = enable_codebook_entropy_measure
        self._num_codebook_entries = num_codebook_entries

        # Variables related to Inception score and rFID.
        # The feature sizes stay 0 when both metrics are disabled, so the
        # accumulators created in `reset_metrics` are empty tensors.
        self._inception_model = None
        self._is_num_features = 0
        self._rfid_num_features = 0
        if self._enable_inception_score or self._enable_rfid:
            # Sizes of the '2048' feature vector and of the 'logits_unbiased'
            # output read in `update` (presumably fixed by the Inception
            # wrapper -- confirm against `get_inception_model`).
            self._rfid_num_features = 2048
            self._is_num_features = 1008
            self._inception_model = get_inception_model().to(self._device)
            self._inception_model.eval()
        # Small constants guarding log(0) (Inception Score) and the rFID
        # fallback trace computation.
        self._is_eps = 1e-16
        self._rfid_eps = 1e-6

        self.reset_metrics()

    def reset_metrics(self):
        """Resets all metrics."""
        self._num_examples = 0
        self._num_updates = 0

        # Inception Score accumulators: sum of p(y|x) and sum of p * log p.
        self._is_prob_total = torch.zeros(
            self._is_num_features, dtype=torch.float64, device=self._device
        )
        self._is_total_kl_d = torch.zeros(
            self._is_num_features, dtype=torch.float64, device=self._device
        )

        # rFID accumulators: sum of features and sum of feature outer products.
        self._rfid_real_sigma = torch.zeros(
            (self._rfid_num_features, self._rfid_num_features),
            dtype=torch.float64, device=self._device
        )
        self._rfid_real_total = torch.zeros(
            self._rfid_num_features, dtype=torch.float64, device=self._device
        )
        self._rfid_fake_sigma = torch.zeros(
            (self._rfid_num_features, self._rfid_num_features),
            dtype=torch.float64, device=self._device
        )
        self._rfid_fake_total = torch.zeros(
            self._rfid_num_features, dtype=torch.float64, device=self._device
        )

        # Codebook statistics.
        self._set_of_codebook_indices = set()
        self._codebook_frequencies = torch.zeros(
            (self._num_codebook_entries), dtype=torch.float64, device=self._device
        )

    # Metrics are never back-propagated through; disabling autograd keeps the
    # running sums from retaining computation graphs across updates.
    @torch.no_grad()
    def update(
        self,
        real_images: torch.Tensor,
        fake_images: torch.Tensor,
        codebook_indices: Optional[torch.Tensor] = None
    ):
        """Updates the metrics with the given images.

        Images are multiplied by 255 and cast to uint8 before being fed to the
        Inception model, so they are presumably expected in [0, 1] -- values
        are not clamped here (confirm against callers).

        Args:
            real_images: A torch.Tensor, the real images. Its first dimension
                is taken as the batch size.
            fake_images: A torch.Tensor, the fake images.
            codebook_indices: A torch.Tensor, the indices of the codebooks for
                each image. Required when a codebook measure is enabled.

        Raises:
            ValueError: If a codebook measure is enabled but
                `codebook_indices` is None.
            ValueError: If Inception features are needed and the real and fake
                batches have different batch sizes.
            ValueError: If the real and fake Inception features differ in shape.
        """
        uses_inception = self._enable_inception_score or self._enable_rfid
        uses_codebook = (
            self._enable_codebook_usage_measure or self._enable_codebook_entropy_measure
        )

        # Validate before touching any accumulator so that a failed call
        # leaves the evaluator state unchanged.
        if uses_codebook and codebook_indices is None:
            raise ValueError(
                "codebook_indices must be provided when codebook usage or "
                "entropy measures are enabled."
            )
        if uses_inception and real_images.shape[0] != fake_images.shape[0]:
            raise ValueError(
                "Real and fake images must have the same batch size, got "
                f"{real_images.shape[0]} and {fake_images.shape[0]}."
            )

        batch_size = real_images.shape[0]
        self._num_examples += batch_size
        self._num_updates += 1

        if uses_inception:
            # Quantize to uint8 as a real image.
            fake_inception_images = (fake_images * 255).to(torch.uint8)
            features_fake = self._inception_model(fake_inception_images)
            inception_logits_fake = features_fake["logits_unbiased"]
            inception_probabilities_fake = F.softmax(inception_logits_fake, dim=-1)

        if self._enable_inception_score:
            probabilities_sum = torch.sum(inception_probabilities_fake, 0, dtype=torch.float64)

            log_prob = torch.log(inception_probabilities_fake + self._is_eps)
            if log_prob.dtype != inception_probabilities_fake.dtype:
                log_prob = log_prob.to(inception_probabilities_fake)
            kl_sum = torch.sum(inception_probabilities_fake * log_prob, 0, dtype=torch.float64)

            self._is_prob_total += probabilities_sum
            self._is_total_kl_d += kl_sum

        if self._enable_rfid:
            real_inception_images = (real_images * 255).to(torch.uint8)
            features_real = self._inception_model(real_inception_images)
            real_feats = features_real['2048']
            fake_feats = features_fake['2048']
            if (real_feats.shape[0] != fake_feats.shape[0] or
                    real_feats.shape[1] != fake_feats.shape[1]):
                raise ValueError("Number of features should be equal for real and fake.")

            # Accumulate per-sample first and second moments in float64.
            for f_real, f_fake in zip(real_feats, fake_feats):
                self._rfid_real_total += f_real
                self._rfid_fake_total += f_fake

                self._rfid_real_sigma += torch.outer(f_real, f_real)
                self._rfid_fake_sigma += torch.outer(f_fake, f_fake)

        if self._enable_codebook_usage_measure:
            self._set_of_codebook_indices |= set(
                torch.unique(codebook_indices, sorted=False).tolist()
            )

        if self._enable_codebook_entropy_measure:
            entries, counts = torch.unique(codebook_indices, sorted=False, return_counts=True)
            self._codebook_frequencies.index_add_(0, entries.int(), counts.double())

    def result(self) -> Mapping[Text, torch.Tensor]:
        """Returns the evaluation result.

        Returns:
            A dict with one entry per enabled metric: "InceptionScore", "rFID"
            and "CodebookUsage" are Python floats; "CodebookEntropy" is a
            0-dim torch.Tensor (entropy in bits).

        Raises:
            ValueError: If no examples have been accumulated.
            ValueError: If the matrix square root used for rFID has a
                non-negligible imaginary component on its diagonal.
        """
        eval_score = {}

        if self._num_examples < 1:
            raise ValueError("No examples to evaluate.")

        if self._enable_inception_score:
            mean_probs = self._is_prob_total / self._num_examples
            log_mean_probs = torch.log(mean_probs + self._is_eps)
            if log_mean_probs.dtype != self._is_prob_total.dtype:
                log_mean_probs = log_mean_probs.to(self._is_prob_total)
            excess_entropy = self._is_prob_total * log_mean_probs
            avg_kl_d = torch.sum(self._is_total_kl_d - excess_entropy) / self._num_examples

            inception_score = torch.exp(avg_kl_d).item()
            eval_score["InceptionScore"] = inception_score

        if self._enable_rfid:
            mu_real = self._rfid_real_total / self._num_examples
            mu_fake = self._rfid_fake_total / self._num_examples
            sigma_real = get_covariance(
                self._rfid_real_sigma, self._rfid_real_total, self._num_examples
            )
            sigma_fake = get_covariance(
                self._rfid_fake_sigma, self._rfid_fake_total, self._num_examples
            )

            # The matrix square root is computed with SciPy on the CPU.
            mu_real, mu_fake = mu_real.cpu(), mu_fake.cpu()
            sigma_real, sigma_fake = sigma_real.cpu(), sigma_fake.cpu()

            diff = mu_real - mu_fake

            # Product might be almost singular.
            covmean, _ = linalg.sqrtm(sigma_real.mm(sigma_fake).numpy(), disp=False)
            # Numerical error might give slight imaginary component.
            if np.iscomplexobj(covmean):
                if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
                    m = np.max(np.abs(covmean.imag))
                    raise ValueError("Imaginary component {}".format(m))
                covmean = covmean.real

            tr_covmean = np.trace(covmean)

            # Fall back to a diagonal approximation of the trace when the
            # matrix square root produced non-finite entries.
            if not np.isfinite(covmean).all():
                tr_covmean = np.sum(np.sqrt((
                    (np.diag(sigma_real) * self._rfid_eps) * (np.diag(sigma_fake) * self._rfid_eps))
                    / (self._rfid_eps * self._rfid_eps)
                ))

            rfid = float(
                diff.dot(diff).item() + torch.trace(sigma_real) + torch.trace(sigma_fake)
                - 2 * tr_covmean
            )
            if not np.isfinite(rfid):
                warnings.warn("The product of covariance of train and test features is out of bounds.")

            eval_score["rFID"] = rfid

        if self._enable_codebook_usage_measure:
            usage = float(len(self._set_of_codebook_indices)) / self._num_codebook_entries
            eval_score["CodebookUsage"] = usage

        if self._enable_codebook_entropy_measure:
            probs = self._codebook_frequencies / self._codebook_frequencies.sum()
            # The small constant keeps log2 finite for unused entries, whose
            # contribution is then zeroed by the multiplication with probs.
            entropy = (-torch.log2(probs + 1e-8) * probs).sum()
            eval_score["CodebookEntropy"] = entropy

        return eval_score