
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe


from typing import Any, Dict, List, Optional, Tuple

import torch
from pytorch3d.implicitron.dataset.utils import is_known_frame
from pytorch3d.implicitron.tools.config import registry
from pytorch3d.implicitron.tools.point_cloud_utils import (
    get_rgbd_point_cloud,
    render_point_cloud_pytorch3d,
)
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures import Pointclouds

from .base_model import ImplicitronModelBase, ImplicitronRender
from .renderer.base import EvaluationMode


@registry.register
class ModelDBIR(ImplicitronModelBase):
    """
    A simple depth-based image rendering model.

    Args:
        render_image_width: The width of the rendered rectangular images.
        render_image_height: The height of the rendered rectangular images.
        bg_color: The color of the background.
        max_points: Maximum number of points in the point cloud
            formed by unprojecting all source view depths.
            If more points are present, they are randomly subsampled
            to this number of points without replacement.
    """

    render_image_width: int = 256
    render_image_height: int = 256
    bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0)
    max_points: int = -1

    # pyre-fixme[14]: `forward` overrides method defined in `ImplicitronModelBase`
    #  inconsistently.
    def forward(
        self,
        *,  # force keyword-only arguments
        image_rgb: Optional[torch.Tensor],
        camera: CamerasBase,
        fg_probability: Optional[torch.Tensor],
        mask_crop: Optional[torch.Tensor],
        depth_map: Optional[torch.Tensor],
        sequence_name: Optional[List[str]],
        evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
        frame_type: List[str],
        **kwargs,
    ) -> Dict[str, Any]:  # TODO: return a namedtuple or dataclass
        """
        Given a set of input source cameras, images, and depth maps, unprojects
        all RGBD maps to a colored point cloud and renders it into the
        target views.

        Args:
            camera: A batch of `N` PyTorch3D cameras.
            image_rgb: A batch of `N` images of shape `(N, 3, H, W)`.
            depth_map: A batch of `N` depth maps of shape `(N, 1, H, W)`.
            fg_probability: A batch of `N` foreground probability maps
                of shape `(N, 1, H, W)`.
            frame_type: A list of `N` strings containing frame type indicators
                which specify target and source views.

        Returns:
            preds: A dict with the following fields:
                implicitron_render: The rendered colors, depth and mask
                    of the target views.
                point_cloud: The point cloud of the scene. Its renders are
                    stored in `implicitron_render`.
        """
        if image_rgb is None:
            raise ValueError("ModelDBIR needs image input")

        if fg_probability is None:
            raise ValueError("ModelDBIR needs foreground mask input")

        if depth_map is None:
            raise ValueError("ModelDBIR needs depth map input")

        # split the batch into source (known) and target (unseen) views
        is_known = is_known_frame(frame_type)
        is_known_idx = torch.where(is_known)[0]

        # binarize the soft foreground probabilities into a hard mask
        mask_fg = (fg_probability > 0.5).type_as(image_rgb)

        # unproject the masked RGBD maps of all source views into one
        # colored point cloud
        point_cloud = get_rgbd_point_cloud(
            # pyre-fixme[6]: For 1st param expected `Union[List[int], int,
            #  LongTensor]` but got `Tensor`.
            camera[is_known_idx],
            image_rgb[is_known_idx],
            depth_map[is_known_idx],
            mask_fg[is_known_idx],
        )

        # optionally subsample the point cloud to at most max_points points
        pcl_size = point_cloud.num_points_per_cloud().item()
        if (self.max_points > 0) and (pcl_size > self.max_points):
            # pyre-fixme[6]: For 1st param expected `int` but got `Union[bool,
            #  float, int]`.
            prm = torch.randperm(pcl_size)[: self.max_points]
            point_cloud = Pointclouds(
                point_cloud.points_padded()[:, prm, :],
                # pyre-fixme[16]: Optional type has no attribute `__getitem__`.
                features=point_cloud.features_padded()[:, prm, :],
            )

        is_target_idx = torch.where(~is_known)[0]

        depth_render, image_render, mask_render = [], [], []

        # render into target frames in a for loop to save memory
        for tgt_idx in is_target_idx:
            _image_render, _mask_render, _depth_render = render_point_cloud_pytorch3d(
                camera[int(tgt_idx)],
                point_cloud,
                render_size=(self.render_image_height, self.render_image_width),
                point_radius=1e-2,
                topk=10,
                bg_color=self.bg_color,
            )
            _image_render = _image_render.clamp(0.0, 1.0)
            # the mask is the set of pixels with opacity bigger than eps
            _mask_render = (_mask_render > 1e-4).float()

            depth_render.append(_depth_render)
            image_render.append(_image_render)
            mask_render.append(_mask_render)

        implicitron_render = ImplicitronRender(
            **{
                k: torch.cat(v, dim=0)
                for k, v in zip(
                    ["depth_render", "image_render", "mask_render"],
                    [depth_render, image_render, mask_render],
                )
            }
        )

        preds = {
            "implicitron_render": implicitron_render,
            "point_cloud": point_cloud,
        }

        return preds
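

# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original module). It
# drives ModelDBIR with synthetic RGB-D inputs: the camera poses, the constant
# depth of 1.0, and the all-ones foreground masks are assumptions invented for
# the demo, not values prescribed by the library. Only the keyword names and
# the "*_known"/"*_unseen" frame-type convention follow the code above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from pytorch3d.implicitron.tools.config import expand_args_fields
    from pytorch3d.renderer.cameras import PerspectiveCameras

    # Configurable classes must be expanded before direct instantiation.
    expand_args_fields(ModelDBIR)
    model = ModelDBIR(render_image_width=64, render_image_height=64)

    n, h, w = 3, 64, 64  # two source (known) views and one target view
    cameras = PerspectiveCameras(
        focal_length=torch.ones(n, 1),
        R=torch.eye(3)[None].expand(n, 3, 3),
        T=torch.zeros(n, 3),
    )
    preds = model(
        image_rgb=torch.rand(n, 3, h, w),  # synthetic source colors
        camera=cameras,
        fg_probability=torch.ones(n, 1, h, w),  # treat every pixel as foreground
        mask_crop=None,
        depth_map=torch.full((n, 1, h, w), 1.0),  # flat synthetic depth
        sequence_name=["demo_sequence"] * n,
        frame_type=["train_known", "train_known", "test_unseen"],
    )
    # The single unseen view is re-rendered from the unprojected point cloud.
    print(preds["implicitron_render"].image_render.shape)  # (1, 3, 64, 64)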