class pytorch3d.implicitron.models.generic_model.GenericModel(*args, **kwargs)

Bases: ImplicitronModelBase

GenericModel is a wrapper for the neural implicit rendering and reconstruction pipeline, which consists of the following sequence of 7 steps. Steps 2–4 are normally skipped in an overfitting scenario, since conditioning on source views does not add much information; otherwise they should all be present together.

(1) Ray Sampling
    Rays are sampled from an image grid based on the target view(s).

(2) Feature Extraction (optional)
    A feature extractor (e.g. a convolutional neural net) is used to extract image features from the source view(s).

(3) View Sampling (optional)
    Image features are sampled at the 2D projections of a set of 3D points along each of the sampled target rays from (1).

(4) Feature Aggregation (optional)
    Aggregate features and masks sampled from image view(s) in (3).

(5) Implicit Function Evaluation
    Evaluate the implicit function(s) at the sampled ray points, optionally passing in the aggregated image features from (4) and a global encoding from global_encoder.

(6) Rendering
    Render the image into the target cameras by raymarching along the sampled rays and aggregating the colors and densities output by the implicit function in (5).

(7) Loss Computation
    Compute losses based on the predicted target image(s).

When the optional steps (2)–(4) are skipped, the rays sampled in (1) feed directly into (5).

The forward function of GenericModel executes this sequence of steps. Currently, steps 1, 3, 4, 5 and 6 can be customized by implementing a subclass of the appropriate base class and adding the newly created module to the registry. Please see for more details on how to create and register a custom component.
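The way the optional steps switch on and off can be sketched in plain Python. This is only an illustration of the control flow described above, not the library's forward implementation: the function name `planned_steps` is hypothetical, and it assumes that step (2) is enabled by a non-None image_feature_extractor_class_type and that steps (3) and (4) are both handled by the view pooler when view_pooler_enabled is True.

```python
from typing import List, Optional


def planned_steps(
    image_feature_extractor_class_type: Optional[str],
    view_pooler_enabled: bool,
) -> List[str]:
    """Return the pipeline steps that would run (hypothetical helper)."""
    steps = ["(1) ray sampling"]
    # Step (2) is assumed to run only when a feature extractor is configured.
    if image_feature_extractor_class_type is not None:
        steps.append("(2) feature extraction")
    # Steps (3) and (4) are assumed to be performed together by the view pooler.
    if view_pooler_enabled:
        steps.append("(3) view sampling")
        steps.append("(4) feature aggregation")
    steps.append("(5) implicit function evaluation")
    steps.append("(6) rendering")
    steps.append("(7) loss computation")
    return steps


# Overfitting scenario: no source-view conditioning, so steps 2-4 are skipped.
overfit = planned_steps(None, False)
# Conditioned scenario: "SomeFeatureExtractor" is a placeholder name.
conditioned = planned_steps("SomeFeatureExtractor", True)
```

With the defaults listed on this page (no feature extractor, view pooler disabled), the sketch yields only steps 1, 5, 6 and 7.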

In the config .yaml files for experiments, the parameters below are contained in the model_factory_ImplicitronModelFactory_args.model_GenericModel_args node. As GenericModel derives from ReplaceableBase, the input arguments are parsed by the run_auto_creation function to initialize the necessary member modules. Please see implicitron_trainer/ for more details on this process.
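To make the nesting of that config node concrete, here is a minimal sketch that uses a plain Python dict as a stand-in for the experiment config. The helpers `set_nested` and `get_nested` are hypothetical and exist only for this illustration; the real trainer uses its own config machinery, which is not reproduced here. The parameter names and the node path are the ones given on this page.

```python
from typing import Any, Dict

# Node path quoted from the text above.
MODEL_ARGS_NODE = (
    "model_factory_ImplicitronModelFactory_args.model_GenericModel_args"
)


def set_nested(cfg: Dict[str, Any], dotted_key: str, value: Any) -> None:
    """Set cfg[a][b][c] = value for dotted_key 'a.b.c', creating dicts as needed."""
    *parents, leaf = dotted_key.split(".")
    node = cfg
    for name in parents:
        node = node.setdefault(name, {})
    node[leaf] = value


def get_nested(cfg: Dict[str, Any], dotted_key: str) -> Any:
    """Read the value stored under a dotted key."""
    node: Any = cfg
    for name in dotted_key.split("."):
        node = node[name]
    return node


# Build a stand-in config that overrides two GenericModel parameters.
cfg: Dict[str, Any] = {}
set_nested(cfg, f"{MODEL_ARGS_NODE}.render_image_width", 800)
set_nested(cfg, f"{MODEL_ARGS_NODE}.view_pooler_enabled", True)
```

In a .yaml file the same structure would appear as two nested mapping levels with the GenericModel parameters as the innermost keys.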

  • mask_images – Whether or not to mask the RGB image background given the foreground mask (the fg_probability argument of GenericModel.forward)

  • mask_depths – Whether or not to mask the depth image background given the foreground mask (the fg_probability argument of GenericModel.forward)

  • render_image_width – Width of the output image to render

  • render_image_height – Height of the output image to render

  • mask_threshold – If greater than 0.0, the foreground mask is thresholded by this value before being applied to the RGB/Depth images

  • output_rasterized_mc – If True, visualize the Monte-Carlo pixel renders by splatting onto an image grid. Default: False.

  • bg_color – RGB values for setting the background color of input image if mask_images=True. Defaults to (0.0, 0.0, 0.0). Each renderer has its own way to determine the background color of its output, unrelated to this.

  • num_passes – The specified implicit_function is initialized num_passes times and run sequentially.

  • chunk_size_grid – The total number of points which can be rendered per chunk. This is used to compute the number of rays used per chunk when the chunked version of the renderer is used (in order to fit rendering on all rays in memory)

  • render_features_dimensions – The number of output features to render. Defaults to 3, corresponding to RGB images.

  • n_train_target_views – The number of cameras to render into at training time; the first n_train_target_views in the batch are considered targets, the rest are sources.

  • sampling_mode_training – The sampling method to use during training. Must be a value from the RenderSamplingMode Enum.

  • sampling_mode_evaluation – Same as above but for evaluation.

  • global_encoder_class_type – The name of the class to use for global_encoder, which must be available in the registry, or None to disable the global encoder.

  • global_encoder – An instance of GlobalEncoder. This is used to generate an encoding of the image (referred to as the global_code) that can be used to model aspects of the scene such as multiple objects or morphing objects. It is up to the implicit function definition how to use it, but the most typical way is to broadcast and concatenate to the other inputs for the implicit function.

  • raysampler_class_type – The name of the raysampler class which is available in the global registry.

  • raysampler – An instance of RaySampler which is used to emit rays from the target view(s).

  • renderer_class_type – The name of the renderer class which is available in the global registry.

  • renderer – A renderer class which inherits from BaseRenderer. This is used to generate the images from the target view(s).

  • image_feature_extractor_class_type – If a str, constructs and enables the image_feature_extractor object of this type; None if not needed.

  • image_feature_extractor – A module for extracting features from an input image.

  • view_pooler_enabled – If True, constructs and enables the view_pooler object. This means features are sampled from the source image(s) at the projected 2D locations of the sampled 3D ray points from the target view(s), i.e. this activates step (3) above.

  • view_pooler – An instance of ViewPooler which is used for sampling of image-based features at the 2D projections of a set of 3D points and aggregating the sampled features.

  • implicit_function_class_type – The type of implicit function to use which is available in the global registry.

  • implicit_function – An instance of ImplicitFunctionBase. The actual implicit functions are initialised to be in self._implicit_functions.

  • view_metrics – An instance of ViewMetricsBase used to compute loss terms which are independent of the model’s parameters.

  • view_metrics_class_type – The type of view metrics to use, must be available in the global registry.

  • regularization_metrics – An instance of RegularizationMetricsBase used to compute regularization terms which can depend on the model’s parameters.

  • regularization_metrics_class_type – The type of regularization metrics to use, must be available in the global registry.

  • loss_weights – A dictionary with a {loss_name: weight} mapping; see documentation for ViewMetrics class for available loss functions.

  • log_vars – A list of variable names which should be logged. The names should correspond to a subset of the keys of the dict preds output by the forward function.
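The interaction of mask_images, mask_threshold and bg_color described in the list above can be illustrated for a single pixel. This is a sketch under stated assumptions rather than the library's code: `mask_pixel` is a hypothetical helper, the comparison is assumed to be a strict "greater than", and the masked result is assumed to be a blend of the pixel color and bg_color weighted by the (possibly thresholded) foreground mask.

```python
from typing import Tuple

Color = Tuple[float, float, float]


def mask_pixel(
    rgb: Color,
    fg_probability: float,
    mask_threshold: float = 0.5,
    bg_color: Color = (0.0, 0.0, 0.0),
) -> Color:
    """Mask one RGB pixel with its foreground probability (hypothetical helper)."""
    if mask_threshold > 0.0:
        # Assumption: thresholding binarizes the mask with a strict comparison.
        fg = 1.0 if fg_probability > mask_threshold else 0.0
    else:
        # Without thresholding, the soft probability is used as the blend weight.
        fg = fg_probability
    # Assumption: foreground keeps its color, background is replaced by bg_color.
    r, g, b = (c * fg + bg * (1.0 - fg) for c, bg in zip(rgb, bg_color))
    return (r, g, b)


foreground = mask_pixel((0.2, 0.4, 0.6), fg_probability=0.9)
background = mask_pixel((0.2, 0.4, 0.6), fg_probability=0.1, bg_color=(1.0, 1.0, 1.0))
```

Here the first pixel keeps its color because its foreground probability exceeds the threshold, while the second is replaced by the white background color.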

mask_images: bool = True
mask_depths: bool = True
render_image_width: int = 400
render_image_height: int = 400
mask_threshold: float = 0.5
output_rasterized_mc: bool = False
bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0)
num_passes: int = 1
chunk_size_grid: int = 4096
render_features_dimensions: int = 3
tqdm_trigger_threshold: int = 16
n_train_target_views: int = 1
sampling_mode_training: str = 'mask_sample'
sampling_mode_evaluation: str = 'full_grid'
global_encoder_class_type: str | None = None
global_encoder: GlobalEncoderBase | None
raysampler_class_type: str = 'AdaptiveRaySampler'
raysampler: RaySamplerBase
renderer_class_type: str = 'MultiPassEmissionAbsorptionRenderer'
renderer: BaseRenderer
image_feature_extractor: FeatureExtractorBase | None
image_feature_extractor_class_type: str | None = None
view_pooler_enabled: bool = False
view_pooler: ViewPooler | None
implicit_function_class_type: str = 'NeuralRadianceFieldImplicitFunction'
implicit_function: ImplicitFunctionBase
view_metrics: ViewMetricsBase
view_metrics_class_type: str = 'ViewMetrics'
regularization_metrics: RegularizationMetricsBase
regularization_metrics_class_type: str = 'RegularizationMetrics'
loss_weights: Dict[str, float] (default produced by a default_factory)
log_vars: List[str] (default produced by a default_factory)
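As a rough illustration of how chunk_size_grid (default 4096 points) could translate into a number of rays per chunk, the sketch below divides the point budget by the number of points sampled along each ray. The exact formula used by the library is not stated on this page, so the integer division, the `max(1, ...)` guard, and the names `rays_per_chunk`, `num_chunks` and `n_pts_per_ray` are all assumptions made for this example.

```python
import math


def rays_per_chunk(chunk_size_grid: int, n_pts_per_ray: int) -> int:
    """Rays that fit in one chunk given a per-chunk point budget (assumed formula)."""
    # Each ray contributes n_pts_per_ray points; always process at least one ray.
    return max(1, chunk_size_grid // n_pts_per_ray)


def num_chunks(n_rays: int, chunk_size_grid: int, n_pts_per_ray: int) -> int:
    """Number of chunks needed to cover all rays (assumed formula)."""
    return math.ceil(n_rays / rays_per_chunk(chunk_size_grid, n_pts_per_ray))


# A full 400 x 400 render grid with a hypothetical 64 points per ray.
per_chunk = rays_per_chunk(chunk_size_grid=4096, n_pts_per_ray=64)
chunks = num_chunks(n_rays=400 * 400, chunk_size_grid=4096, n_pts_per_ray=64)
```

Under these assumptions, raising chunk_size_grid trades memory for fewer, larger chunks.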
classmethod pre_expand() → None
forward(*, image_rgb: Tensor | None, camera: CamerasBase, fg_probability: Tensor | None = None, mask_crop: Tensor | None = None, depth_map: Tensor | None = None, sequence_name: List[str] | None = None, frame_timestamp: Tensor | None = None, evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION, **kwargs) → Dict[str, Any]
  • image_rgb – A tensor of shape (B, 3, H, W) containing a batch of RGB images; the first min(B, n_train_target_views) images are considered targets and are used to supervise the renders; the rest correspond to the source viewpoints from which features will be extracted.

  • camera – An instance of CamerasBase containing a batch of B cameras corresponding to the viewpoints of target images, from which the rays will be sampled, and source images, which will be used for intersecting with target rays.

  • fg_probability – A tensor of shape (B, 1, H, W) containing a batch of foreground masks.

  • mask_crop – A binary tensor of shape (B, 1, H, W) denoting valid regions in the input images (i.e. regions that do not correspond to, e.g., zero-padding). When the RaySampler’s sampling mode is set to “mask_sample”, rays will be sampled in the non-zero regions.

  • depth_map – A tensor of shape (B, 1, H, W) containing a batch of depth maps.

  • sequence_name – A list of B strings corresponding to the sequence names from which images image_rgb were extracted. They are used to match target frames with relevant source frames.

  • frame_timestamp – Optionally a tensor of shape (B,) containing a batch of frame timestamps.

  • evaluation_mode – One of EvaluationMode.TRAINING or EvaluationMode.EVALUATION, which determines the settings used for rendering.


Returns: A dictionary containing all outputs of the forward pass, including the rendered images, depths, masks, losses and other metrics.
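The target/source split described for image_rgb can be written out in index form. The helper name `split_targets_and_sources` is hypothetical; the rule itself (the first min(B, n_train_target_views) batch entries are targets, the remainder are sources) is the one stated in the parameter description above.

```python
from typing import List, Tuple


def split_targets_and_sources(
    batch_size: int, n_train_target_views: int = 1
) -> Tuple[List[int], List[int]]:
    """Return (target indices, source indices) for a batch of B views."""
    # The first min(B, n_train_target_views) views supervise the renders.
    n_targets = min(batch_size, n_train_target_views)
    targets = list(range(n_targets))
    # All remaining views act as sources for feature extraction.
    sources = list(range(n_targets, batch_size))
    return targets, sources


targets, sources = split_targets_and_sources(batch_size=5, n_train_target_views=1)
```

When the batch is smaller than n_train_target_views, every view becomes a target and the source list is empty.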

visualize(viz: Visdom | None, visdom_env_imgs: str, preds: Dict[str, Any], prefix: str) → None

Helper function to visualize the predictions generated in the forward pass.

  • viz – Visdom connection object.

  • visdom_env_imgs – Name of the Visdom environment for the images.

  • preds – Predictions dict, as returned by forward().

  • prefix – String prepended to the names of images.

classmethod raysampler_tweak_args(type, args: DictConfig) → None

We don’t expose certain fields of the raysampler because we want to set them from our own members.

classmethod renderer_tweak_args(type, args: DictConfig) → None

We don’t expose certain fields of the renderer because we want to set them based on other inputs.

create_implicit_function() → None

No-op called by run_auto_creation so that self.implicit_function does not get created. Instead, __post_init__ explicitly creates the implicit function(s), in wrappers, in self._implicit_functions.

classmethod implicit_function_tweak_args(type, args: DictConfig) → None

We don’t expose certain implicit_function fields because we want to set them based on other inputs.