# disentangled-image-editing-final-project/ContraCLIP/models/genforce/runners/base_gan_runner.py
# python3.7
"""Contains the base class for GAN runner."""

import os
import shutil

import numpy as np

import torch
import torch.distributed as dist

from metrics.inception import build_inception_model
from metrics.fid import extract_feature
from metrics.fid import compute_fid
from utils.visualizer import HtmlPageVisualizer
from utils.visualizer import postprocess_image
from utils.visualizer import save_image
from utils.visualizer import load_image

from .base_runner import BaseRunner

__all__ = ['BaseGANRunner']

class BaseGANRunner(BaseRunner):
    """Defines the base class for GAN runner."""

    def __init__(self, config, logger):
        super().__init__(config, logger)
        self.inception_model = None

    def moving_average_model(self, model, avg_model, beta=0.999):
        """Moving average model weights.

        This trick is commonly used in GAN training, where the weights of the
        generator are maintained as an exponential moving average over the
        entire training process.

        Args:
            model: The latest model used to update the averaged weights.
            avg_model: The averaged model weights.
            beta: Hyper-parameter used for moving average. (default: 0.999)
        """
        model_params = dict(self.get_module(model).named_parameters())
        avg_params = dict(self.get_module(avg_model).named_parameters())
        assert len(model_params) == len(avg_params)
        for param_name in avg_params:
            assert param_name in model_params
            avg_params[param_name].data = (
                avg_params[param_name].data * beta +
                model_params[param_name].data * (1 - beta))
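
    # A minimal sketch of the update rule above on plain tensors (`w` and
    # `w_avg` are hypothetical names used only for illustration):
    #
    #     w_avg = w_avg * beta + w * (1 - beta)
    #
    # With beta=0.999, the average forgets old weights with a half-life of
    # log(0.5) / log(0.999), i.e. roughly 693 updates, so the averaged
    # generator evolves much more smoothly than the raw one.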

    def build_models(self):
        super().build_models()

        assert 'generator' in self.models
        assert 'discriminator' in self.models
        self.z_space_dim = self.models['generator'].z_space_dim
        self.resolution = self.models['generator'].resolution
        self.G_kwargs_train = self.config.modules['generator'].get(
            'kwargs_train', dict())
        self.G_kwargs_val = self.config.modules['generator'].get(
            'kwargs_val', dict())
        self.D_kwargs_train = self.config.modules['discriminator'].get(
            'kwargs_train', dict())
        self.D_kwargs_val = self.config.modules['discriminator'].get(
            'kwargs_val', dict())
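
    # A hypothetical sketch of the config entries read above (the exact
    # schema is defined by the experiment configs, not by this class):
    #
    #     config.modules = dict(
    #         generator=dict(
    #             kwargs_train=dict(...),  # forward kwargs used in training
    #             kwargs_val=dict(...),    # forward kwargs used in validation
    #         ),
    #         discriminator=dict(...),
    #     )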

    def train_step(self, data, **train_kwargs):
        raise NotImplementedError('Should be implemented in derived class.')
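
    # Derived runners implement the actual optimization. A hypothetical
    # override would alternate discriminator and generator updates, e.g.:
    #
    #     def train_step(self, data, **train_kwargs):
    #         ...  # update discriminator on `data` and fake samples
    #         ...  # update generator, then call self.moving_average_model()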

    def val(self, **val_kwargs):
        self.synthesize(**val_kwargs)

    def synthesize(self,
                   num,
                   z=None,
                   html_name=None,
                   save_raw_synthesis=False):
        """Synthesizes images.

        Args:
            num: Number of images to synthesize.
            z: Latent codes used for generation. If not specified, this
                function will sample latent codes randomly. (default: None)
            html_name: Name of the output html page for visualization. If not
                specified, no visualization page will be saved. (default:
                None)
            save_raw_synthesis: Whether to save raw synthesis on the disk.
                (default: False)
        """
        if not html_name and not save_raw_synthesis:
            return

        self.set_mode('val')

        temp_dir = os.path.join(self.work_dir, 'synthesize_results')
        os.makedirs(temp_dir, exist_ok=True)

        if z is not None:
            assert isinstance(z, np.ndarray)
            assert z.ndim == 2 and z.shape[1] == self.z_space_dim
            num = min(num, z.shape[0])
            z = torch.from_numpy(z).type(torch.FloatTensor)
        if not num:
            return
        # TODO: Use same z during the entire training process.

        self.logger.init_pbar()
        task1 = self.logger.add_pbar_task('Synthesize', total=num)

        indices = list(range(self.rank, num, self.world_size))
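        # The slicing above shards samples across ranks round-robin, e.g.
        # with world_size=2 and num=5: rank 0 gets [0, 2, 4], rank 1 [1, 3].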
        for batch_idx in range(0, len(indices), self.val_batch_size):
            sub_indices = indices[batch_idx:batch_idx + self.val_batch_size]
            batch_size = len(sub_indices)
            if z is None:
                code = torch.randn(batch_size, self.z_space_dim).cuda()
            else:
                code = z[sub_indices].cuda()
            with torch.no_grad():
                if 'generator_smooth' in self.models:
                    G = self.models['generator_smooth']
                else:
                    G = self.models['generator']
                images = G(code, **self.G_kwargs_val)['image']
                images = postprocess_image(images.detach().cpu().numpy())
            for sub_idx, image in zip(sub_indices, images):
                save_image(os.path.join(temp_dir, f'{sub_idx:06d}.jpg'), image)
            self.logger.update_pbar(task1, batch_size * self.world_size)
        dist.barrier()
        if self.rank != 0:
            return

        if html_name:
            task2 = self.logger.add_pbar_task('Visualize', total=num)
            html = HtmlPageVisualizer(grid_size=num)
            for image_idx in range(num):
                image = load_image(
                    os.path.join(temp_dir, f'{image_idx:06d}.jpg'))
                row_idx, col_idx = divmod(image_idx, html.num_cols)
                html.set_cell(row_idx, col_idx, image=image,
                              text=f'Sample {image_idx:06d}')
                self.logger.update_pbar(task2, 1)
            html.save(os.path.join(self.work_dir, html_name))

        if not save_raw_synthesis:
            shutil.rmtree(temp_dir)

        self.logger.close_pbar()
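
    # Hypothetical usage from a concrete runner instance (the file name is
    # only an example):
    #
    #     runner.synthesize(num=16, html_name='samples.html',
    #                       save_raw_synthesis=True)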

    def fid(self,
            fid_num,
            z=None,
            ignore_cache=False,
            align_tf=True):
        """Computes the FID metric."""
        self.set_mode('val')

        if self.val_loader is None:
            self.build_dataset('val')
        fid_num = min(fid_num, len(self.val_loader.dataset))

        if self.inception_model is None:
            if align_tf:
                self.logger.info('Building inception model '
                                 '(aligned with TensorFlow) ...')
            else:
                self.logger.info('Building inception model '
                                 '(using torchvision) ...')
            self.inception_model = build_inception_model(align_tf).cuda()
            self.logger.info('Finished building inception model.')

        if z is not None:
            assert isinstance(z, np.ndarray)
            assert z.ndim == 2 and z.shape[1] == self.z_space_dim
            fid_num = min(fid_num, z.shape[0])
            z = torch.from_numpy(z).type(torch.FloatTensor)
        if not fid_num:
            return -1

        indices = list(range(self.rank, fid_num, self.world_size))

        self.logger.init_pbar()

        # Extract features from fake images.
        fake_feature_list = []
        task1 = self.logger.add_pbar_task('Fake', total=fid_num)
        for batch_idx in range(0, len(indices), self.val_batch_size):
            sub_indices = indices[batch_idx:batch_idx + self.val_batch_size]
            batch_size = len(sub_indices)
            if z is None:
                code = torch.randn(batch_size, self.z_space_dim).cuda()
            else:
                code = z[sub_indices].cuda()
            with torch.no_grad():
                if 'generator_smooth' in self.models:
                    G = self.models['generator_smooth']
                else:
                    G = self.models['generator']
                fake_images = G(code)['image']
                fake_feature_list.append(
                    extract_feature(self.inception_model, fake_images))
            self.logger.update_pbar(task1, batch_size * self.world_size)
        np.save(f'{self.work_dir}/fake_fid_features_{self.rank}.npy',
                np.concatenate(fake_feature_list, axis=0))

        # Extract features from real images if needed.
        cached_fid_file = f'{self.work_dir}/real_fid{fid_num}.npy'
        do_real_test = (not os.path.exists(cached_fid_file) or ignore_cache)
        if do_real_test:
            real_feature_list = []
            task2 = self.logger.add_pbar_task('Real', total=fid_num)
            for batch_idx in range(0, len(indices), self.val_batch_size):
                sub_indices = indices[batch_idx:batch_idx + self.val_batch_size]
                batch_size = len(sub_indices)
                data = next(self.val_loader)
                for key in data:
                    data[key] = data[key][:batch_size].cuda(
                        torch.cuda.current_device(), non_blocking=True)
                with torch.no_grad():
                    real_images = data['image']
                    real_feature_list.append(
                        extract_feature(self.inception_model, real_images))
                self.logger.update_pbar(task2, batch_size * self.world_size)
            np.save(f'{self.work_dir}/real_fid_features_{self.rank}.npy',
                    np.concatenate(real_feature_list, axis=0))

        dist.barrier()
        if self.rank != 0:
            return -1
        self.logger.close_pbar()

        # Collect fake features.
        fake_feature_list.clear()
        for rank in range(self.world_size):
            fake_feature_list.append(
                np.load(f'{self.work_dir}/fake_fid_features_{rank}.npy'))
            os.remove(f'{self.work_dir}/fake_fid_features_{rank}.npy')
        fake_features = np.concatenate(fake_feature_list, axis=0)
        assert fake_features.ndim == 2 and fake_features.shape[0] == fid_num
        feature_dim = fake_features.shape[1]
        pad = fid_num % self.world_size
        if pad:
            pad = self.world_size - pad
        fake_features = np.pad(fake_features, ((0, pad), (0, 0)))
        fake_features = fake_features.reshape(self.world_size, -1, feature_dim)
        fake_features = fake_features.transpose(1, 0, 2)
        fake_features = fake_features.reshape(-1, feature_dim)[:fid_num]
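        # The reshape/transpose above restores sample order: the per-rank
        # files concatenate as [rank 0's samples, rank 1's samples, ...],
        # while the round-robin split assigned sample i to rank
        # i % world_size. E.g. with world_size=2 and fid_num=4, the
        # concatenated order [0, 2, 1, 3] maps back to [0, 1, 2, 3].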

        # Collect (or load) real features.
        if do_real_test:
            real_feature_list.clear()
            for rank in range(self.world_size):
                real_feature_list.append(
                    np.load(f'{self.work_dir}/real_fid_features_{rank}.npy'))
                os.remove(f'{self.work_dir}/real_fid_features_{rank}.npy')
            real_features = np.concatenate(real_feature_list, axis=0)
            assert real_features.shape == (fid_num, feature_dim)
            real_features = np.pad(real_features, ((0, pad), (0, 0)))
            real_features = real_features.reshape(
                self.world_size, -1, feature_dim)
            real_features = real_features.transpose(1, 0, 2)
            real_features = real_features.reshape(-1, feature_dim)[:fid_num]
            np.save(cached_fid_file, real_features)
        else:
            real_features = np.load(cached_fid_file)
            assert real_features.shape == (fid_num, feature_dim)

        fid_value = compute_fid(fake_features, real_features)
        return fid_value
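
    # Hypothetical usage from a concrete runner instance:
    #
    #     fid_value = runner.fid(fid_num=50000, align_tf=True)
    #     if runner.rank == 0:
    #         runner.logger.info(f'FID: {fid_value:.3f}')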