Spaces:
Running
on
Zero
Running
on
Zero
| """ This file contains some utils functions for visualization. | |
| Copyright (2024) Bytedance Ltd. and/or its affiliates | |
| Licensed under the Apache License, Version 2.0 (the "License"); | |
| you may not use this file except in compliance with the License. | |
| You may obtain a copy of the License at | |
| http://www.apache.org/licenses/LICENSE-2.0 | |
| Unless required by applicable law or agreed to in writing, software | |
| distributed under the License is distributed on an "AS IS" BASIS, | |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| See the License for the specific language governing permissions and | |
| limitations under the License. | |
| """ | |
| import torch | |
| import torchvision.transforms.functional as F | |
| from einops import rearrange | |
| def make_viz_from_samples( | |
| original_images, | |
| reconstructed_images | |
| ): | |
| """Generates visualization images from original images and reconstructed images. | |
| Args: | |
| original_images: A torch.Tensor, original images. | |
| reconstructed_images: A torch.Tensor, reconstructed images. | |
| Returns: | |
| A tuple containing two lists - images_for_saving and images_for_logging. | |
| """ | |
| reconstructed_images = torch.clamp(reconstructed_images, 0.0, 1.0) | |
| reconstructed_images = reconstructed_images * 255.0 | |
| reconstructed_images = reconstructed_images.cpu() | |
| original_images = torch.clamp(original_images, 0.0, 1.0) | |
| original_images *= 255.0 | |
| original_images = original_images.cpu() | |
| diff_img = torch.abs(original_images - reconstructed_images) | |
| to_stack = [original_images, reconstructed_images, diff_img] | |
| images_for_logging = rearrange( | |
| torch.stack(to_stack), | |
| "(l1 l2) b c h w -> b c (l1 h) (l2 w)", | |
| l1=1).byte() | |
| images_for_saving = [F.to_pil_image(image) for image in images_for_logging] | |
| return images_for_saving, images_for_logging | |
| def make_viz_from_samples_generation( | |
| generated_images, | |
| ): | |
| generated = torch.clamp(generated_images, 0.0, 1.0) * 255.0 | |
| images_for_logging = rearrange( | |
| generated, | |
| "(l1 l2) c h w -> c (l1 h) (l2 w)", | |
| l1=2) | |
| images_for_logging = images_for_logging.cpu().byte() | |
| images_for_saving = F.to_pil_image(images_for_logging) | |
| return images_for_saving, images_for_logging |