# @lint-ignore-every LICENSELINT
# Adapted from RenderingNetwork from IDR
# https://github.com/lioryariv/idr/
# Copyright (c) 2020 Lior Yariv
# pyre-unsafe
import logging
from typing import List, Tuple
import torch
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.implicitron.tools.config import enable_get_default_args
from pytorch3d.renderer.implicit import HarmonicEmbedding
from torch import nn
# Module-level logger; RayNormalColoringNetwork uses it to note when pooled
# features are enabled. (The stray Sphinx "[docs]" artifact that followed this
# line referenced an undefined name and made the module fail on import.)
logger = logging.getLogger(__name__)
class RayNormalColoringNetwork(torch.nn.Module):
    """
    MLP colouring head adapted from IDR's RenderingNetwork.

    It concatenates per-point inputs (points, optionally view directions,
    optionally normals, feature vectors and optionally pooled features),
    passes them through linear layers with ReLU between them, and applies
    ``tanh`` to the output.

    Members:
        d_in and feature_vector_size: Their sum is the width of the network
            input before any harmonic embedding or pooled features are added.
            It must equal the sum of
            - 3 [for the points]
            - 3 unless mode=no_normal [for the normals]
            - 3 unless mode=no_view_dir [for view directions]
            - the feature size [number of channels in feature_vectors]
        d_out: dimension of output.
        mode: One of "idr", "no_view_dir" or "no_normal" to allow omitting
            part of the network input.
        dims: list of hidden layer sizes.
        weight_norm: whether to apply weight normalization to each layer.
        n_harmonic_functions_dir:
            If >0, use a harmonic embedding with this number of harmonic
            functions for the view direction; the first layer is widened to
            match. Otherwise view directions are fed without embedding.
            Has no effect on the network input when mode is `no_view_dir`.
        pooled_feature_dim: If a pooling function is in use (provided as
            pooling_fn to forward()) this must be its number of features.
            Otherwise this must be set to 0. (If used from GenericModel,
            this will be set automatically.)
    """

    def __init__(
        self,
        feature_vector_size: int = 3,
        mode: str = "idr",
        d_in: int = 9,
        d_out: int = 3,
        dims: Tuple[int, ...] = (512, 512, 512, 512),
        weight_norm: bool = True,
        n_harmonic_functions_dir: int = 0,
        pooled_feature_dim: int = 0,
    ) -> None:
        super().__init__()
        self.mode = mode
        self.output_dimensions = d_out
        # Layer widths: input, hidden layers, output.
        dims_full: List[int] = [d_in + feature_vector_size] + list(dims) + [d_out]

        self.embedview_fn = None
        if n_harmonic_functions_dir > 0:
            self.embedview_fn = HarmonicEmbedding(
                n_harmonic_functions_dir, append_input=True
            )
            if mode != "no_view_dir":
                # The embedding replaces the raw 3-D direction in the input.
                # In no_view_dir mode directions are never concatenated, so
                # widening the first layer would cause a shape mismatch.
                dims_full[0] += self.embedview_fn.get_output_dim() - 3

        if pooled_feature_dim > 0:
            logger.info("Pooled features in rendering network.")
            dims_full[0] += pooled_feature_dim

        # Kept for backward compatibility: number of entries in dims_full,
        # i.e. one more than the number of linear layers.
        self.num_layers = len(dims_full)

        layers = []
        for in_dim, out_dim in zip(dims_full[:-1], dims_full[1:]):
            lin = nn.Linear(in_dim, out_dim)
            if weight_norm:
                lin = nn.utils.weight_norm(lin)
            layers.append(lin)
        self.linear_layers = torch.nn.ModuleList(layers)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def _view_directions(self, ray_bundle: "ImplicitronRayBundle", masks):
        """Return the ray directions, masked and harmonically embedded as configured."""
        view_dirs = ray_bundle.directions
        if masks is not None:
            # In case of IDR, the other inputs arrive here already masked, so
            # select the same rays; this flattens all dims after the first.
            view_dirs = view_dirs.reshape(view_dirs.shape[0], -1, 3)[
                :, masks.reshape(-1)
            ]
        if self.embedview_fn is not None:
            view_dirs = self.embedview_fn(view_dirs)
        return view_dirs

    def forward(
        self,
        feature_vectors: torch.Tensor,
        points,
        normals,
        ray_bundle: "ImplicitronRayBundle",
        masks=None,
        pooling_fn=None,
    ) -> torch.Tensor:
        """
        Compute the network output for each input point.

        Args:
            feature_vectors: per-point features, concatenated on the last dim.
            points: per-point positions (3 channels on the last dim).
            normals: per-point normals (3 channels); unused, and may be None,
                in mode "no_normal".
            ray_bundle: provides ``directions`` used as view directions.
            masks: optional boolean mask; when given, the other inputs are
                expected to be already masked and directions are masked here.
            pooling_fn: optional callable applied to ``points[None]`` whose
                ``[0]`` result is appended to the input.

        Returns:
            Tensor whose last dim is ``d_out``, with values in [-1, 1].

        Raises:
            ValueError: if ``mode`` is not one of the supported modes.
        """
        if masks is not None and not masks.any():
            # Nothing selected: return zeros with the network's output width
            # (not the normals' width, which only matches when d_out == 3).
            ref = normals if normals is not None else points
            return ref.new_zeros((*ref.shape[:-1], self.output_dimensions))

        if self.mode == "idr":
            parts = [
                points,
                self._view_directions(ray_bundle, masks),
                normals,
                feature_vectors,
            ]
        elif self.mode == "no_view_dir":
            parts = [points, normals, feature_vectors]
        elif self.mode == "no_normal":
            parts = [points, self._view_directions(ray_bundle, masks), feature_vectors]
        else:
            raise ValueError(f"Unsupported rendering mode: {self.mode}")

        if pooling_fn is not None:
            parts.append(pooling_fn(points[None])[0])

        x = torch.cat(parts, dim=-1)
        last = len(self.linear_layers) - 1
        for layer_idx, layer in enumerate(self.linear_layers):
            x = layer(x)
            if layer_idx < last:
                x = self.relu(x)
        return self.tanh(x)
# Register the class with implicitron's config tooling
# (pytorch3d.implicitron.tools.config) so its constructor defaults can be
# introspected; presumably this is what lets get_default_args() be used on it
# (e.g. by GenericModel) — confirm against that module.
enable_get_default_args(RayNormalColoringNetwork)