Source code for pytorch3d.implicitron.models.metrics

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe


import warnings
from typing import Any, Dict, Optional

import torch
from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle
from pytorch3d.implicitron.tools import metric_utils as utils
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
from pytorch3d.ops import padded_to_packed
from pytorch3d.renderer import utils as rend_utils

from .renderer.base import RendererOutput


class RegularizationMetricsBase(ReplaceableBase, torch.nn.Module):
    """
    Replaceable abstract base for regularization metrics.
    `forward()` method produces regularization metrics and (unlike ViewMetrics) can
    depend on the model's parameters.
    """

    def forward(
        self, model: Any, keys_prefix: str = "loss_", **kwargs
    ) -> Dict[str, Any]:
        """
        Calculates various regularization terms useful for supervising
        differentiable rendering pipelines.

        Args:
            model: A model instance. Useful, for example, to implement
                weights-based regularization.
            keys_prefix: A common prefix for all keys in the output dictionary
                containing all regularization metrics.

        Returns:
            A dictionary with the resulting regularization metrics. The items
            will have form `{metric_name_i: metric_value_i}` keyed by the names
            of the output metrics `metric_name_i` with their corresponding values
            `metric_value_i` represented as 0-dimensional float tensors.
        """
        raise NotImplementedError

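
# Illustrative only, not part of the original module: a hedged sketch of a custom
# regularization plugin, following the docstring above ("Useful, for example, to
# implement weights-based regularization"). The class name and the "weight_l2"
# metric key are assumptions invented for this example, and the sketch assumes
# `model` is a torch.nn.Module; registration mirrors the built-in classes below.
@registry.register
class ExampleWeightNormMetrics(RegularizationMetricsBase):
    def forward(
        self, model: Any, keys_prefix: str = "loss_", **kwargs
    ) -> Dict[str, Any]:
        # Sum of squared L2 norms of all model parameters (a classic weight-decay term).
        sq_norms = [(p**2).sum() for p in model.parameters()]
        weight_l2 = torch.stack(sq_norms).sum() if sq_norms else torch.zeros(())
        metrics = {"weight_l2": weight_l2}
        if keys_prefix is not None:
            metrics = {(keys_prefix + k): v for k, v in metrics.items()}
        return metrics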


class ViewMetricsBase(ReplaceableBase, torch.nn.Module):
    """
    Replaceable abstract base for model metrics.
    `forward()` method produces losses and other metrics.
    """

    def forward(
        self,
        raymarched: RendererOutput,
        ray_bundle: ImplicitronRayBundle,
        image_rgb: Optional[torch.Tensor] = None,
        depth_map: Optional[torch.Tensor] = None,
        fg_probability: Optional[torch.Tensor] = None,
        mask_crop: Optional[torch.Tensor] = None,
        keys_prefix: str = "loss_",
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Calculates various metrics and loss functions useful for supervising
        differentiable rendering pipelines. Any additional parameters can be passed
        in the `raymarched.aux` dictionary.

        Args:
            raymarched: Output of the renderer.
            ray_bundle: ImplicitronRayBundle object which was used to produce
                the raymarched object.
            image_rgb: A tensor of shape `(B, H, W, 3)` containing ground truth
                rgb values.
            depth_map: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth
                depth values.
            fg_probability: A tensor of shape `(B, Hm, Wm, 1)` containing ground
                truth foreground masks.
            mask_crop: A binary tensor denoting the valid region of the rendering
                crop; metrics are only evaluated inside this region.
            keys_prefix: A common prefix for all keys in the output dictionary
                containing all view metrics.

        Returns:
            A dictionary with the resulting view metrics. The items will have form
            `{metric_name_i: metric_value_i}` keyed by the names of the output
            metrics `metric_name_i` with their corresponding values
            `metric_value_i` represented as 0-dimensional float tensors.
        """
        raise NotImplementedError()


@registry.register
class RegularizationMetrics(RegularizationMetricsBase):
    def forward(
        self, model: Any, keys_prefix: str = "loss_", **kwargs
    ) -> Dict[str, Any]:
        """
        Calculates the autodecoder penalty, or returns an empty dict if the
        model's autodecoder is inactive.

        Args:
            model: A model instance.
            keys_prefix: A common prefix for all keys in the output dictionary
                containing all regularization metrics.

        Returns:
            A dictionary with the resulting regularization metrics. The items
            will have form `{metric_name_i: metric_value_i}` keyed by the names
            of the output metrics `metric_name_i` with their corresponding values
            `metric_value_i` represented as 0-dimensional float tensors.

            The calculated metric is:
                autodecoder_norm: Autodecoder weight norm regularization term.
        """
        metrics = {}
        if getattr(model, "sequence_autodecoder", None) is not None:
            ad_penalty = model.sequence_autodecoder.calculate_squared_encoding_norm()
            if ad_penalty is not None:
                metrics["autodecoder_norm"] = ad_penalty

        if keys_prefix is not None:
            metrics = {(keys_prefix + k): v for k, v in metrics.items()}

        return metrics

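
# Illustrative only, not part of the original module: a minimal usage sketch of
# RegularizationMetrics on a stand-in model. The stub classes are assumptions made
# up for this example; they only provide the `sequence_autodecoder` attribute and
# `calculate_squared_encoding_norm()` method that `forward()` above looks for, and
# the sketch assumes the configurable can be instantiated after `expand_args_fields`,
# as is usual for Implicitron configurables.
def _example_regularization_metrics() -> Dict[str, Any]:
    from pytorch3d.implicitron.tools.config import expand_args_fields

    class StubAutodecoder(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.codes = torch.nn.Embedding(4, 8)

        def calculate_squared_encoding_norm(self) -> torch.Tensor:
            # squared norm of the latent codes, averaged over entries
            return (self.codes.weight**2).mean()

    class StubModel(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.sequence_autodecoder = StubAutodecoder()

    expand_args_fields(RegularizationMetrics)
    metrics = RegularizationMetrics()(model=StubModel(), keys_prefix="loss_")
    # expected to contain a single 0-dim tensor under the key "loss_autodecoder_norm"
    return metrics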


@registry.register
class ViewMetrics(ViewMetricsBase):
    def forward(
        self,
        raymarched: RendererOutput,
        ray_bundle: ImplicitronRayBundle,
        image_rgb: Optional[torch.Tensor] = None,
        depth_map: Optional[torch.Tensor] = None,
        fg_probability: Optional[torch.Tensor] = None,
        mask_crop: Optional[torch.Tensor] = None,
        keys_prefix: str = "loss_",
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Calculates various differentiable metrics useful for supervising
        differentiable rendering pipelines.

        Args:
            raymarched.features: Predicted rgb or feature values.
            raymarched.depths: A tensor of shape `(B, ..., 1)` containing
                predicted depth values.
            raymarched.masks: A tensor of shape `(B, ..., 1)` containing
                predicted foreground masks.
            raymarched.aux["grad_theta"]: A tensor of shape `(B, ..., 3)` containing
                an evaluation of a gradient of a signed distance function w.r.t.
                input 3D coordinates used to compute the eikonal loss.
            raymarched.aux["density_grid"]: A tensor of shape `(B, Hg, Wg, Dg, 1)`
                containing a `Hg x Wg x Dg` voxel grid of density values.
            ray_bundle: ImplicitronRayBundle object which was used to produce
                the raymarched object.
            image_rgb: A tensor of shape `(B, H, W, 3)` containing ground truth
                rgb values.
            depth_map: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth
                depth values.
            fg_probability: A tensor of shape `(B, Hm, Wm, 1)` containing ground
                truth foreground masks.
            mask_crop: A binary tensor denoting the valid region of the rendering
                crop; metrics are only evaluated inside this region.
            keys_prefix: A common prefix for all keys in the output dictionary
                containing all view metrics.

        Returns:
            A dictionary `{metric_name_i: metric_value_i}` keyed by the names of
            the output metrics `metric_name_i` with their corresponding values
            `metric_value_i` represented as 0-dimensional float tensors.

            The calculated metrics are:
                rgb_huber: A robust huber loss between `image_pred` and `image`.
                rgb_mse: Mean squared error between `image_pred` and `image`.
                rgb_psnr: Peak signal-to-noise ratio between `image_pred` and `image`.
                rgb_psnr_fg: Peak signal-to-noise ratio between the foreground
                    region of `image_pred` and `image` as defined by `mask`.
                rgb_mse_fg: Mean squared error between the foreground region of
                    `image_pred` and `image` as defined by `mask`.
                mask_neg_iou: (1 - intersection-over-union) between `mask_pred`
                    and `mask`.
                mask_bce: Binary cross entropy between `mask_pred` and `mask`.
                mask_beta_prior: A loss enforcing strictly binary values of
                    `mask_pred`: `log(mask_pred) + log(1-mask_pred)`
                depth_abs: Mean per-pixel L1 distance between `depth_pred` and `depth`.
                depth_abs_fg: Mean per-pixel L1 distance between the foreground
                    region of `depth_pred` and `depth` as defined by `mask`.
                eikonal: Eikonal regularizer `(||grad_theta|| - 1)**2`.
                density_tv: The Total Variation regularizer of density values in
                    `density_grid` (sum of L1 distances of values of all
                    4-neighbouring cells).
                depth_neg_penalty: `min(depth_pred, 0)**2` penalizing negative
                    predicted depth values.
        """
        metrics = self._calculate_stage(
            raymarched,
            ray_bundle,
            image_rgb,
            depth_map,
            fg_probability,
            mask_crop,
            keys_prefix,
        )

        if raymarched.prev_stage:
            metrics.update(
                self(
                    raymarched.prev_stage,
                    ray_bundle,
                    image_rgb,
                    depth_map,
                    fg_probability,
                    mask_crop,
                    keys_prefix=(keys_prefix + "prev_stage_"),
                )
            )

        return metrics

    def _calculate_stage(
        self,
        raymarched: RendererOutput,
        ray_bundle: ImplicitronRayBundle,
        image_rgb: Optional[torch.Tensor] = None,
        depth_map: Optional[torch.Tensor] = None,
        fg_probability: Optional[torch.Tensor] = None,
        mask_crop: Optional[torch.Tensor] = None,
        keys_prefix: str = "loss_",
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Calculate metrics for the current stage.
        """

        # TODO: extract functions

        # reshape from B x ... x DIM to B x DIM x -1 x 1
        image_rgb_pred, fg_probability_pred, depth_map_pred = [
            _reshape_nongrid_var(x)
            for x in [raymarched.features, raymarched.masks, raymarched.depths]
        ]

        xys = ray_bundle.xys
        # If ray_bundle is packed, then we can sample images in the padded state to
        # lower memory requirements. Instead of having one image for every element
        # in the ray_bundle, we can then have one image per unique sampled camera.
        if ray_bundle.is_packed():
            xys, first_idxs, num_inputs = ray_bundle.get_padded_xys()

        # reshape the sampling grid as well
        # TODO: we can get rid of the singular dimension here and in _reshape_nongrid_var
        # now that we use rend_utils.ndc_grid_sample
        xys = xys.reshape(xys.shape[0], -1, 1, 2)

        # closures with the given xys
        def sample_full(tensor, mode):
            if tensor is None:
                return tensor
            return rend_utils.ndc_grid_sample(tensor, xys, mode=mode)

        def sample_packed(tensor, mode):
            if tensor is None:
                return tensor
            # select images that correspond to the sampled cameras of the packed ray_bundle
            tensor = tensor[ray_bundle.camera_ids]
            result = rend_utils.ndc_grid_sample(tensor, xys, mode=mode)
            return padded_to_packed(result, first_idxs, num_inputs, max_size_dim=2)[
                :, :, None
            ]  # the result is [n_rays_total_training, 3, 1, 1]

        sample = sample_packed if ray_bundle.is_packed() else sample_full

        # eval all results in this size
        image_rgb = sample(image_rgb, mode="bilinear")
        depth_map = sample(depth_map, mode="nearest")
        fg_probability = sample(fg_probability, mode="nearest")
        mask_crop = sample(mask_crop, mode="nearest")
        if mask_crop is None and image_rgb_pred is not None:
            mask_crop = torch.ones_like(image_rgb_pred[:, :1])
        if mask_crop is None and depth_map_pred is not None:
            mask_crop = torch.ones_like(depth_map_pred[:, :1])

        metrics = {}
        if image_rgb is not None and image_rgb_pred is not None:
            metrics.update(
                _rgb_metrics(
                    image_rgb,
                    image_rgb_pred,
                    fg_probability,
                    fg_probability_pred,
                    mask_crop,
                )
            )

        if fg_probability_pred is not None:
            metrics["mask_beta_prior"] = utils.beta_prior(fg_probability_pred)
        if fg_probability is not None and fg_probability_pred is not None:
            metrics["mask_neg_iou"] = utils.neg_iou_loss(
                fg_probability_pred, fg_probability, mask=mask_crop
            )
            metrics["mask_bce"] = utils.calc_bce(
                fg_probability_pred, fg_probability, mask=mask_crop
            )

        if depth_map is not None and depth_map_pred is not None:
            assert mask_crop is not None
            _, abs_ = utils.eval_depth(
                depth_map_pred, depth_map, get_best_scale=True, mask=mask_crop, crop=0
            )
            metrics["depth_abs"] = abs_.mean()

            if fg_probability is not None:
                mask = fg_probability * mask_crop
                _, abs_ = utils.eval_depth(
                    depth_map_pred, depth_map, get_best_scale=True, mask=mask, crop=0
                )
                metrics["depth_abs_fg"] = abs_.mean()

        # regularizers
        grad_theta = raymarched.aux.get("grad_theta")
        if grad_theta is not None:
            metrics["eikonal"] = _get_eikonal_loss(grad_theta)

        density_grid = raymarched.aux.get("density_grid")
        if density_grid is not None:
            metrics["density_tv"] = _get_grid_tv_loss(density_grid)

        if depth_map_pred is not None:
            metrics["depth_neg_penalty"] = _get_depth_neg_penalty_loss(depth_map_pred)

        if keys_prefix is not None:
            metrics = {(keys_prefix + k): v for k, v in metrics.items()}

        return metrics


def _rgb_metrics(images, images_pred, masks, masks_pred, masks_crop):
    assert masks_crop is not None
    if images.shape[1] != images_pred.shape[1]:
        raise ValueError(
            f"Network output's RGB images had {images_pred.shape[1]} "
            f"channels. {images.shape[1]} expected."
        )
    rgb_squared = ((images_pred - images) ** 2).mean(dim=1, keepdim=True)
    rgb_loss = utils.huber(rgb_squared, scaling=0.03)
    crop_mass = masks_crop.sum().clamp(1.0)
    results = {
        "rgb_huber": (rgb_loss * masks_crop).sum() / crop_mass,
        "rgb_mse": (rgb_squared * masks_crop).sum() / crop_mass,
        "rgb_psnr": utils.calc_psnr(images_pred, images, mask=masks_crop),
    }
    if masks is not None:
        masks = masks_crop * masks
        results["rgb_psnr_fg"] = utils.calc_psnr(images_pred, images, mask=masks)
        results["rgb_mse_fg"] = (rgb_squared * masks).sum() / masks.sum().clamp(1.0)
    return results


def _get_eikonal_loss(grad_theta):
    return ((grad_theta.norm(2, dim=1) - 1) ** 2).mean()


def _get_grid_tv_loss(grid, log_domain: bool = True, eps: float = 1e-5):
    if log_domain:
        if (grid <= -eps).any():
            warnings.warn("Grid has negative values; this will produce NaN loss")
        grid = torch.log(grid + eps)

    # this is an isotropic version, note that it ignores last rows/cols
    return torch.mean(
        utils.safe_sqrt(
            (grid[..., :-1, :-1, 1:] - grid[..., :-1, :-1, :-1]) ** 2
            + (grid[..., :-1, 1:, :-1] - grid[..., :-1, :-1, :-1]) ** 2
            + (grid[..., 1:, :-1, :-1] - grid[..., :-1, :-1, :-1]) ** 2,
            eps=1e-5,
        )
    )


def _get_depth_neg_penalty_loss(depth):
    neg_penalty = depth.clamp(min=None, max=0.0) ** 2
    return torch.mean(neg_penalty)


def _reshape_nongrid_var(x):
    if x is None:
        return None

    ba, *_, dim = x.shape
    return x.reshape(ba, -1, 1, dim).permute(0, 3, 1, 2).contiguous()
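

# Illustrative only, not part of the original module: a small, self-contained
# exercise of the regularizer helpers above on random tensors. The tensor sizes
# are arbitrary assumptions; the shapes are chosen so that each helper's slicing
# and reductions operate on the intended axes.
def _example_regularizer_helpers() -> Dict[str, torch.Tensor]:
    torch.manual_seed(0)
    grad_theta = torch.randn(1024, 3)  # SDF gradients w.r.t. 3D points, one row per point
    density_grid = torch.rand(2, 1, 8, 8, 8)  # densities, spatial dims as the last three axes
    depth_pred = torch.randn(2, 1, 16, 16)  # predicted depths, possibly negative

    # _reshape_nongrid_var flattens the ray dimensions and moves channels first:
    # (B, ..., dim) -> (B, dim, n_rays, 1)
    assert _reshape_nongrid_var(torch.randn(2, 100, 3)).shape == (2, 3, 100, 1)

    return {
        "eikonal": _get_eikonal_loss(grad_theta),  # mean of (||grad|| - 1)^2
        "density_tv": _get_grid_tv_loss(density_grid),  # isotropic TV in the log domain
        "depth_neg_penalty": _get_depth_neg_penalty_loss(depth_pred),  # mean of min(depth, 0)^2
    }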