Source code for pytorch3d.implicitron.models.implicit_function.idr_feature_field

# @lint-ignore-every LICENSELINT
# Adapted from https://github.com/lioryariv/idr/blob/main/code/model/
#              implicit_differentiable_renderer.py
# Copyright (c) 2020 Lior Yariv

# pyre-unsafe
import math
from typing import Optional, Tuple

import torch
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.implicitron.tools.config import registry
from pytorch3d.renderer.implicit import HarmonicEmbedding

from torch import nn

from .base import ImplicitFunctionBase
from .utils import get_rays_points_world


@registry.register
class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
    """
    Implicit function as used in http://github.com/lioryariv/idr.

    Members:
        d_in: dimension of the input point.
        n_harmonic_functions_xyz: If -1, do not embed the point. If >=0, use a
            harmonic embedding with this number of harmonic functions. (The
            harmonic embedding includes the input itself, so a value of 0 means
            the point is used but without any harmonic functions.)
        d_out and feature_vector_size: Sum of these is the output dimension.
            This implicit function thus returns a concatenation of `d_out`
            signed distance function values and `feature_vector_size` features
            (such as colors). When used in `GenericModel`, `feature_vector_size`
            is automatically set to `render_features_dimensions`.
        dims: list of hidden layer sizes.
        geometric_init: whether to use custom weight initialization in linear
            layers. If False, the PyTorch default (uniform sampling) is used.
        bias: if geometric_init=True, initial value for bias subtracted in the
            last layer.
        skip_in: List of indices of layers that receive as input the initial
            value concatenated with the output of the previous layers.
        weight_norm: whether to apply weight normalization to each layer.
        pooled_feature_dim: If view pooling is in use (provided as fun_viewpool
            to forward()) this must be its number of features. Otherwise this
            must be set to 0. (If used from GenericModel, this config value
            will be overridden automatically.)
        encoding_dim: If global coding is in use (provided as global_code to
            forward()) this must be its number of features. Otherwise this must
            be set to 0. (If used from GenericModel, this config value will be
            overridden automatically.)
    """

    feature_vector_size: int = 3
    d_in: int = 3
    d_out: int = 1
    dims: Tuple[int, ...] = (512, 512, 512, 512, 512, 512, 512, 512)
    geometric_init: bool = True
    bias: float = 1.0
    skip_in: Tuple[int, ...] = ()
    weight_norm: bool = True
    n_harmonic_functions_xyz: int = 0
    pooled_feature_dim: int = 0
    encoding_dim: int = 0

    def __post_init__(self):
        dims = [self.d_in] + list(self.dims) + [self.d_out + self.feature_vector_size]

        self.embed_fn = None
        if self.n_harmonic_functions_xyz >= 0:
            self.embed_fn = HarmonicEmbedding(
                self.n_harmonic_functions_xyz, append_input=True
            )
            dims[0] = self.embed_fn.get_output_dim()
        if self.pooled_feature_dim > 0:
            dims[0] += self.pooled_feature_dim
        if self.encoding_dim > 0:
            dims[0] += self.encoding_dim

        self.num_layers = len(dims)

        out_dim = 0
        layers = []
        for layer_idx in range(self.num_layers - 1):
            if layer_idx + 1 in self.skip_in:
                out_dim = dims[layer_idx + 1] - dims[0]
            else:
                out_dim = dims[layer_idx + 1]

            lin = nn.Linear(dims[layer_idx], out_dim)

            if self.geometric_init:
                if layer_idx == self.num_layers - 2:
                    torch.nn.init.normal_(
                        lin.weight,
                        mean=math.pi**0.5 / dims[layer_idx] ** 0.5,
                        std=0.0001,
                    )
                    torch.nn.init.constant_(lin.bias, -self.bias)
                elif self.n_harmonic_functions_xyz >= 0 and layer_idx == 0:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
                    torch.nn.init.normal_(
                        lin.weight[:, :3], 0.0, 2**0.5 / out_dim**0.5
                    )
                elif self.n_harmonic_functions_xyz >= 0 and layer_idx in self.skip_in:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, 2**0.5 / out_dim**0.5)
                    torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3) :], 0.0)
                else:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, 2**0.5 / out_dim**0.5)

            if self.weight_norm:
                lin = nn.utils.weight_norm(lin)

            layers.append(lin)

        self.linear_layers = torch.nn.ModuleList(layers)
        self.out_dim = out_dim
        self.softplus = nn.Softplus(beta=100)

    # pyre-fixme[14]: `forward` overrides method defined in `ImplicitFunctionBase`
    #  inconsistently.
    def forward(
        self,
        *,
        ray_bundle: Optional[ImplicitronRayBundle] = None,
        rays_points_world: Optional[torch.Tensor] = None,
        fun_viewpool=None,
        global_code=None,
        **kwargs,
    ):
        # this field only uses point locations
        # rays_points_world.shape = [minibatch x ... x pts_per_ray x 3]
        rays_points_world = get_rays_points_world(ray_bundle, rays_points_world)

        if rays_points_world.numel() == 0 or (
            self.embed_fn is None and fun_viewpool is None and global_code is None
        ):
            return torch.tensor(
                [], device=rays_points_world.device, dtype=rays_points_world.dtype
            ).view(0, self.out_dim)

        embeddings = []
        if self.embed_fn is not None:
            embeddings.append(self.embed_fn(rays_points_world))

        if fun_viewpool is not None:
            assert rays_points_world.ndim == 2
            pooled_feature = fun_viewpool(rays_points_world[None])
            # TODO: pooled features are 4D!
            embeddings.append(pooled_feature)

        if global_code is not None:
            assert global_code.shape[0] == 1  # TODO: generalize to batches!
            # This will require changing raytracer code
            # embedding = embedding[None].expand(global_code.shape[0], *embedding.shape)
            embeddings.append(
                global_code[0, None, :].expand(rays_points_world.shape[0], -1)
            )

        embedding = torch.cat(embeddings, dim=-1)
        x = embedding
        for layer_idx in range(self.num_layers - 1):
            if layer_idx in self.skip_in:
                x = torch.cat([x, embedding], dim=-1) / 2**0.5

            x = self.linear_layers[layer_idx](x)

            if layer_idx < self.num_layers - 2:
                x = self.softplus(x)

        return x
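
# A minimal usage sketch, assuming the field is constructed directly rather than
# through GenericModel. The config values and point count below are arbitrary
# illustrative choices, not defaults mandated by the module.
if __name__ == "__main__":
    from pytorch3d.implicitron.tools.config import expand_args_fields

    # Implicitron Configurable classes are expanded into dataclasses before
    # direct instantiation.
    expand_args_fields(IdrFeatureField)
    field = IdrFeatureField(n_harmonic_functions_xyz=6, feature_vector_size=3)

    # Query the field on raw world-space points; no ray bundle is needed here.
    points = torch.randn(1000, 3)
    out = field(rays_points_world=points)  # (1000, d_out + feature_vector_size)
    sdf, features = out[..., :1], out[..., 1:]
    print(sdf.shape, features.shape)  # (1000, 1) and (1000, 3)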