|
|
|
|
|
|
|
|
import torch |
|
|
import numpy as np |
|
|
from typing import List, Tuple, Union |
|
|
|
|
|
from pytorch3d.renderer import ( |
|
|
PerspectiveCameras, |
|
|
MeshRenderer, |
|
|
MeshRasterizer, |
|
|
SoftPhongShader, |
|
|
RasterizationSettings, |
|
|
PointLights, |
|
|
TexturesVertex |
|
|
) |
|
|
|
|
|
from pytorch3d.structures import Meshes |
|
|
from pytorch3d.renderer.camera_conversions import _cameras_from_opencv_projection |
|
|
|
|
|
def update_intrinsics_from_bbox(
    K_org: torch.Tensor, bbox: torch.Tensor
) -> Tuple[torch.Tensor, List[Tuple[int, int]]]:
    """
    Re-express intrinsic matrices relative to per-sample bounding boxes.

    Each 3x3 matrix of ``K_org`` is embedded into a 4x4 matrix with the layout

        [[K00, K01, cx', 0],
         [K10, K11, cy', 0],
         [K20, K21, 0,   1],
         [0,   0,   1,   0]]

    (the layout of the screen-space ``K`` accepted by PyTorch3D cameras).
    The principal point is first shifted into the box frame
    (``cx - left``, ``cy - top``) and then mirrored across the box
    (``w - (cx - left)``, ``h - (cy - top)``). The mirroring is presumably
    done to match PyTorch3D's +X-left / +Y-up screen convention; confirm
    against callers.

    Args:
        K_org (torch.Tensor): Original intrinsics of shape (B, 3, 3). Not modified.
        bbox (torch.Tensor): Boxes of shape (B, 4) in (left, top, right, bottom) order.

    Returns:
        K_new (torch.Tensor): The (B, 4, 4) matrices described above, on
            ``K_org``'s device and dtype.
        image_sizes (List[Tuple[int, int]]): One (height, width) pair of ints
            per box; each side is clamped to at least 1.
    """
    batch_size = K_org.shape[0]
    proj = torch.zeros((batch_size, 4, 4), device=K_org.device, dtype=K_org.dtype)
    proj[:, :3, :3] = K_org.clone()
    # Replace the homogeneous bottom of K with the [.., 0, 1] / [0, 0, 1, 0]
    # pattern of a 4x4 screen-space intrinsic matrix.
    proj[:, 2, 2] = 0
    proj[:, 2, -1] = 1
    proj[:, -1, 2] = 1

    sizes = []
    for i, (x0, y0, x1, y1) in enumerate(bbox):
        # Degenerate boxes still yield a 1-pixel image.
        box_h = max(y1 - y0, 1)
        box_w = max(x1 - x0, 1)
        # Move the principal point into box coordinates, then mirror it.
        proj[i, 0, 2] = box_w - (proj[i, 0, 2] - x0)
        proj[i, 1, 2] = box_h - (proj[i, 1, 2] - y0)
        sizes.append((int(box_h), int(box_w)))

    return proj, sizes
|
|
|
|
|
class Renderer:
    """
    Renderer class using PyTorch3D for mesh rendering with Phong shading.

    The camera is a fixed OpenCV-style pinhole camera: identity rotation,
    zero translation, and intrinsics whose principal point is the image center.

    Attributes:
        width (int): Target image width.
        height (int): Target image height.
        focal_length (Union[float, Tuple[float, float]]): Camera focal length(s).
        device (torch.device): Device to run rendering on.
        R (torch.Tensor): Camera rotation, shape (1, 3, 3), identity.
        T (torch.Tensor): Camera translation, shape (1, 3), zeros.
        K (torch.Tensor): OpenCV intrinsic matrix, shape (1, 3, 3).
        bboxes (torch.Tensor): The full-image box [[0, 0, width, height]].
        K_full (torch.Tensor): 4x4 intrinsics from ``update_intrinsics_from_bbox``;
            not used by this class's own camera construction (which uses ``K``).
        image_sizes (List[Tuple[int, int]]): (height, width) for each box.
        renderer (MeshRenderer): PyTorch3D mesh renderer.
        cameras (PerspectiveCameras): Camera object.
        lights (PointLights): Lighting setup for rendering.
    """

    def __init__(
        self,
        width: int,
        height: int,
        focal_length: Union[float, Tuple[float, float]],
        device: torch.device,
        bin_size: int = 512,
        max_faces_per_bin: int = 200000,
    ):
        """
        Args:
            width: Output image width in pixels.
            height: Output image height in pixels.
            focal_length: One focal length used for both axes, or an (fx, fy) pair
                (list, tuple, or 1-D numpy array / torch tensor).
            device: Device holding the camera, lights and renderer.
            bin_size: Forwarded to ``RasterizationSettings.bin_size``.
            max_faces_per_bin: Forwarded to ``RasterizationSettings.max_faces_per_bin``.
        """
        self.width = width
        self.height = height
        self.focal_length = focal_length
        self.device = device

        # Sets R, T, K, bboxes, K_full, image_sizes and cameras.
        self._initialize_camera_params()

        # Lighting is dominated by the ambient term (0.75), with a weak diffuse
        # (0.25) and near-zero specular (0.02) contribution.
        self.lights = PointLights(
            device=device,
            location=((0.0, -1.5, -1.5),),
            ambient_color=((0.75, 0.75, 0.75),),
            diffuse_color=((0.25, 0.25, 0.25),),
            specular_color=((0.02, 0.02, 0.02),),
        )

        self._create_renderer(bin_size, max_faces_per_bin)

    def _create_renderer(self, bin_size: int, max_faces_per_bin: int):
        """
        Create the PyTorch3D MeshRenderer with rasterizer and shader.

        The raster size is the (height, width) of the first entry of
        ``self.image_sizes``, so ``_initialize_camera_params`` must run first.
        """
        self.renderer = MeshRenderer(
            rasterizer=MeshRasterizer(
                raster_settings=RasterizationSettings(
                    image_size=self.image_sizes[0],
                    blur_radius=1e-5,
                    bin_size=bin_size,
                    max_faces_per_bin=max_faces_per_bin,
                )
            ),
            shader=SoftPhongShader(
                device=self.device,
                lights=self.lights,
            ),
        )

    def _initialize_camera_params(self):
        """
        Initialize camera intrinsics and extrinsics.

        Sets ``R`` (identity), ``T`` (zeros), ``K`` (principal point at the
        image center), ``bboxes``, ``K_full``, ``image_sizes`` and ``cameras``.
        """
        self.R = torch.eye(3, device=self.device).unsqueeze(0)
        self.T = torch.zeros(1, 3, device=self.device)

        focal = self.focal_length
        # Accept (fx, fy) given as a list/tuple or as a 1-D array/tensor;
        # anything else is treated as a single focal length for both axes.
        is_pair = isinstance(focal, (list, tuple)) or (
            isinstance(focal, (np.ndarray, torch.Tensor)) and focal.ndim > 0
        )
        if is_pair:
            fx, fy = (float(v) for v in focal)
        else:
            fx = fy = float(focal)

        self.K = torch.tensor(
            [[fx, 0, self.width / 2],
             [0, fy, self.height / 2],
             [0, 0, 1]],
            device=self.device,
            dtype=torch.float32,
        ).unsqueeze(0)

        self.bboxes = torch.tensor([[0, 0, self.width, self.height]], dtype=torch.float32)
        self.K_full, self.image_sizes = update_intrinsics_from_bbox(self.K, self.bboxes)

        self.cameras = self._create_camera_from_cv()

    def _create_camera_from_cv(
        self,
        R: Union[torch.Tensor, None] = None,
        T: Union[torch.Tensor, None] = None,
        K: Union[torch.Tensor, None] = None,
        image_size: Union[torch.Tensor, None] = None,
    ) -> PerspectiveCameras:
        """
        Create a PyTorch3D camera from OpenCV-style intrinsics and extrinsics.

        Any argument left as None falls back to ``self.R``, ``self.T``,
        ``self.K`` or ``self.image_sizes`` respectively.

        Args:
            R: Rotation matrices, (N, 3, 3).
            T: Translation vectors, (N, 3).
            K: OpenCV 3x3 intrinsic matrices, (N, 3, 3).
            image_size: (N, 2) tensor of (height, width) image sizes.

        Returns:
            The PerspectiveCameras built by PyTorch3D's OpenCV conversion.
        """
        R = self.R if R is None else R
        T = self.T if T is None else T
        K = self.K if K is None else K
        if image_size is None:
            image_size = torch.tensor(self.image_sizes, device=self.device)

        return _cameras_from_opencv_projection(R, T, K, image_size)

    def render(
        self,
        verts_list: List[torch.Tensor],
        faces_list: List[torch.Tensor],
        colors_list: List[torch.Tensor],
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Render a batch of meshes into a single RGB image and mask.

        All meshes are merged into one mesh (each mesh's face indices are
        offset by the number of vertices preceding it) and rendered with the
        fixed camera and lights. The three lists are paired with ``zip``, so
        extra entries in a longer list are ignored.

        Args:
            verts_list (List[torch.Tensor]): Per-mesh vertex tensors, (V_i, 3).
            faces_list (List[torch.Tensor]): Per-mesh face index tensors, (F_i, 3),
                indexing that mesh's own vertices.
            colors_list (List[torch.Tensor]): Per-mesh per-vertex colors, (V_i, C);
                presumably in [0, 1], since the output is scaled by 255.

        Returns:
            rend (np.ndarray): Rendered RGB image as uint8 array.
            mask (np.ndarray): Boolean mask of pixels whose alpha is > 0.

        Raises:
            ValueError: If no (verts, faces, colors) triple is given.
        """
        parts = list(zip(verts_list, faces_list, colors_list))
        if not parts:
            raise ValueError("render() requires at least one (verts, faces, colors) triple")

        # The output is converted to NumPy, so gradients are never needed; without
        # no_grad, inputs that require grad make .numpy() raise.
        with torch.no_grad():
            all_verts, all_faces, all_colors = [], [], []
            vertex_offset = 0
            for verts, faces, colors in parts:
                # Move every part to the renderer's device so CPU inputs (or a mix
                # of devices) work with the camera, lights and rasterizer.
                verts = verts.to(self.device)
                all_verts.append(verts)
                all_colors.append(colors.to(self.device))
                # Shift face indices so they address the concatenated vertex array.
                all_faces.append(faces.to(self.device) + vertex_offset)
                vertex_offset += verts.shape[0]

            mesh = Meshes(
                verts=[torch.cat(all_verts, dim=0)],
                faces=[torch.cat(all_faces, dim=0)],
                textures=TexturesVertex(torch.cat(all_colors, dim=0).unsqueeze(0)),
            )
            images = self.renderer(mesh, cameras=self.cameras, lights=self.lights)

        rend = np.clip(images[0, ..., :3].cpu().numpy() * 255, 0, 255).astype(np.uint8)
        mask = images[0, ..., -1].cpu().numpy() > 0

        return rend, mask