Source code for pytorch3d.implicitron.models.view_pooler.feature_aggregator

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, Optional, Sequence, Tuple, Union

import torch
import torch.nn.functional as F
from pytorch3d.implicitron.models.view_pooler.view_sampler import (
    cameras_points_cartesian_product,
)
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
from pytorch3d.ops import wmean
from pytorch3d.renderer.cameras import CamerasBase


class ReductionFunction(Enum):
    AVG = "avg"  # simple average
    MAX = "max"  # maximum
    STD = "std"  # standard deviation
    STD_AVG = "std_avg"  # average of per-dimension standard deviations
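
# Illustrative sketch (hypothetical helper, not part of the library): how each
# `ReductionFunction` contributes to the aggregated channel count, mirroring
# `_get_reduction_function_output_dim` defined later in this module. AVG, MAX
# and STD each keep the full feature dimensionality, while STD_AVG collapses
# it to a single channel.
def _example_reduction_output_dims(feat_dim: int = 32) -> None:
    for rf in ReductionFunction:
        out_dim = 1 if rf == ReductionFunction.STD_AVG else feat_dim
        print(f"{rf.value}: {out_dim} channels")
    # AVG + STD on a 32-dim feature therefore yields 32 + 32 = 64 channels.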

class FeatureAggregatorBase(ABC, ReplaceableBase):
    """
    Base class for aggregating features.

    Typically, the aggregated features and their masks are output by `ViewSampler`
    which samples feature tensors extracted from a set of source images.

    Settings:
        exclude_target_view: If `True`/`False`, enables/disables pooling from the
            target view to itself.
        exclude_target_view_mask_features: If `True`, masks the features from the
            target view before aggregation.
        concatenate_output: If `True`, concatenates the aggregated features into a
            single tensor; otherwise returns a dictionary mapping feature names to
            tensors.
    """

    exclude_target_view: bool = True
    exclude_target_view_mask_features: bool = True
    concatenate_output: bool = True

    @abstractmethod
    def forward(
        self,
        feats_sampled: Dict[str, torch.Tensor],
        masks_sampled: torch.Tensor,
        camera: Optional[CamerasBase] = None,
        pts: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Args:
            feats_sampled: A `dict` of sampled feature tensors `{f_i: t_i}`,
                where each `t_i` is a tensor of shape
                `(minibatch, n_source_views, n_samples, dim_i)`.
            masks_sampled: A binary mask represented as a tensor of shape
                `(minibatch, n_source_views, n_samples, 1)` denoting valid
                sampled features.
            camera: A batch of `n_source_views` `CamerasBase` objects
                corresponding to the source view cameras.
            pts: A tensor of shape `(minibatch, n_samples, 3)` denoting the
                3D points whose 2D projections to source views were sampled in
                order to generate `feats_sampled` and `masks_sampled`.

        Returns:
            feats_aggregated: If `concatenate_output==True`, a tensor
                of shape `(minibatch, reduce_dim, n_samples, sum(dim_1, ... dim_N))`
                containing the concatenation of the aggregated features
                `feats_sampled`. `reduce_dim` depends on the specific feature
                aggregator implementation and typically equals 1 or
                `n_source_views`.
                If `concatenate_output==False`, the aggregator does not
                concatenate the aggregated features and returns a dictionary
                of per-feature aggregations `{f_i: t_i_aggregated}` instead.
                Each `t_i_aggregated` is of shape
                `(minibatch, reduce_dim, n_samples, aggr_dim_i)`.
        """
        raise NotImplementedError()

    @abstractmethod
    def get_aggregated_feature_dim(
        self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
    ):
        """
        Returns the final dimensionality of the output aggregated features.

        Args:
            feats_or_feats_dim: Either a `dict` of sampled features `{f_i: t_i}`
                corresponding to the `feats_sampled` argument of `forward`,
                or an `int` representing the sum of dimensionalities of each `t_i`.

        Returns:
            aggregated_feature_dim: The final dimensionality of the output
                aggregated features.
        """
        raise NotImplementedError()

    def has_aggregation(self) -> bool:
        """
        Specifies whether the aggregator reduces the output `reduce_dim`
        dimension to 1.

        Returns:
            has_aggregation: `True` if `reduce_dim==1`, else `False`.
        """
        return hasattr(self, "reduction_functions")
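
# Illustrative sketch (hypothetical helper, not part of the library): the tensor
# shapes expected by `FeatureAggregatorBase.forward`, built from random data.
# The sizes below are arbitrary example values.
def _example_forward_inputs():
    minibatch, n_source_views, n_samples = 2, 5, 1024
    feats_sampled = {
        "feat_coarse": torch.randn(minibatch, n_source_views, n_samples, 32),
        "feat_fine": torch.randn(minibatch, n_source_views, n_samples, 64),
    }
    # binary validity mask with a trailing singleton channel
    masks_sampled = (torch.rand(minibatch, n_source_views, n_samples, 1) > 0.5).float()
    return feats_sampled, masks_sampled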

@registry.register
class IdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
    """
    This aggregator does not perform any feature aggregation. Depending on the
    settings, it allows masking the target view features and concatenating the
    outputs.
    """

    def get_aggregated_feature_dim(
        self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
    ):
        return _get_reduction_aggregator_feature_dim(feats_or_feats_dim, [])

    def forward(
        self,
        feats_sampled: Dict[str, torch.Tensor],
        masks_sampled: torch.Tensor,
        camera: Optional[CamerasBase] = None,
        pts: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Args:
            feats_sampled: A `dict` of sampled feature tensors `{f_i: t_i}`,
                where each `t_i` is a tensor of shape
                `(minibatch, n_source_views, n_samples, dim_i)`.
            masks_sampled: A binary mask represented as a tensor of shape
                `(minibatch, n_source_views, n_samples, 1)` denoting valid
                sampled features.
            camera: A batch of `n_source_views` `CamerasBase` objects
                corresponding to the source view cameras.
            pts: A tensor of shape `(minibatch, n_samples, 3)` denoting the
                3D points whose 2D projections to source views were sampled in
                order to generate `feats_sampled` and `masks_sampled`.

        Returns:
            feats_aggregated: If `concatenate_output==True`, a tensor of shape
                `(minibatch, n_source_views, n_samples, sum(dim_1, ... dim_N))`.
                If `concatenate_output==False`, a dictionary `{f_i: t_i_aggregated}`
                with each `t_i_aggregated` of shape
                `(minibatch, n_source_views, n_samples, dim_i)`.
        """
        if self.exclude_target_view_mask_features:
            feats_sampled = _mask_target_view_features(feats_sampled)
        feats_aggregated = feats_sampled
        if self.concatenate_output:
            feats_aggregated = torch.cat(tuple(feats_aggregated.values()), dim=-1)
        return feats_aggregated
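
# Illustrative usage sketch (hypothetical helper, not part of the library). It
# assumes the standard implicitron pattern of expanding a ReplaceableBase class
# with `expand_args_fields` before instantiation; tensor sizes are arbitrary.
def _example_identity_aggregation():
    from pytorch3d.implicitron.tools.config import expand_args_fields

    expand_args_fields(IdentityFeatureAggregator)
    aggregator = IdentityFeatureAggregator()

    minibatch, n_source_views, n_samples = 2, 5, 128
    feats_sampled = {"f": torch.randn(minibatch, n_source_views, n_samples, 16)}
    masks_sampled = torch.ones(minibatch, n_source_views, n_samples, 1)

    out = aggregator(feats_sampled, masks_sampled)
    # With concatenate_output=True the identity aggregator keeps the
    # source-view dimension: (minibatch, n_source_views, n_samples, 16).
    print(out.shape)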

@registry.register
class ReductionFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
    """
    Aggregates using a set of predefined `reduction_functions` and concatenates
    the results of each aggregation function along the channel dimension.
    The reduction functions singularize the second dimension of the sampled
    features, which stacks the source views.

    Settings:
        reduction_functions: A list of `ReductionFunction`s that reduce the
            stack of source-view-specific features to a single feature.
    """

    reduction_functions: Tuple[ReductionFunction, ...] = (
        ReductionFunction.AVG,
        ReductionFunction.STD,
    )

    def get_aggregated_feature_dim(
        self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
    ):
        return _get_reduction_aggregator_feature_dim(
            feats_or_feats_dim, self.reduction_functions
        )

    def forward(
        self,
        feats_sampled: Dict[str, torch.Tensor],
        masks_sampled: torch.Tensor,
        camera: Optional[CamerasBase] = None,
        pts: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Args:
            feats_sampled: A `dict` of sampled feature tensors `{f_i: t_i}`,
                where each `t_i` is a tensor of shape
                `(minibatch, n_source_views, n_samples, dim_i)`.
            masks_sampled: A binary mask represented as a tensor of shape
                `(minibatch, n_source_views, n_samples, 1)` denoting valid
                sampled features.
            camera: A batch of `n_source_views` `CamerasBase` objects
                corresponding to the source view cameras.
            pts: A tensor of shape `(minibatch, n_samples, 3)` denoting the
                3D points whose 2D projections to source views were sampled in
                order to generate `feats_sampled` and `masks_sampled`.

        Returns:
            feats_aggregated: If `concatenate_output==True`, a tensor of shape
                `(minibatch, 1, n_samples, sum(dim_1, ... dim_N))`.
                If `concatenate_output==False`, a dictionary `{f_i: t_i_aggregated}`
                with each `t_i_aggregated` of shape
                `(minibatch, 1, n_samples, aggr_dim_i)`.
        """
        pts_batch, n_cameras = masks_sampled.shape[:2]
        if self.exclude_target_view_mask_features:
            feats_sampled = _mask_target_view_features(feats_sampled)
        sampling_mask = _get_view_sampling_mask(
            n_cameras,
            pts_batch,
            masks_sampled.device,
            self.exclude_target_view,
        )
        aggr_weights = masks_sampled[..., 0] * sampling_mask[..., None]
        feats_aggregated = {
            k: _avgmaxstd_reduction_function(
                f,
                aggr_weights,
                dim=1,
                reduction_functions=self.reduction_functions,
            )
            for k, f in feats_sampled.items()
        }
        if self.concatenate_output:
            feats_aggregated = torch.cat(tuple(feats_aggregated.values()), dim=-1)
        return feats_aggregated
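
# Illustrative sketch (hypothetical helper, not part of the library): with the
# default `reduction_functions=(AVG, STD)`, the aggregator collapses the
# source-view dimension to 1 and doubles the channel count. Construction assumes
# the `expand_args_fields` pattern as above; sizes are arbitrary.
def _example_reduction_aggregation():
    from pytorch3d.implicitron.tools.config import expand_args_fields

    expand_args_fields(ReductionFeatureAggregator)
    aggregator = ReductionFeatureAggregator()

    minibatch, n_source_views, n_samples, feat_dim = 2, 5, 128, 16
    feats_sampled = {"f": torch.randn(minibatch, n_source_views, n_samples, feat_dim)}
    masks_sampled = torch.ones(minibatch, n_source_views, n_samples, 1)

    print(aggregator.get_aggregated_feature_dim(feat_dim))  # 16 (AVG) + 16 (STD) = 32
    out = aggregator(feats_sampled, masks_sampled)
    print(out.shape)  # (minibatch, 1, n_samples, 32)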

@registry.register
class AngleWeightedReductionFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
    """
    Performs a weighted aggregation using a set of predefined
    `reduction_functions` and concatenates the results of each aggregation
    function along the channel dimension. The weights are proportional to the
    cosine of the angle between the target ray and the source ray::

        weight = (
            dot(target_ray, source_ray) * 0.5 + 0.5 + self.min_ray_angle_weight
        )**self.weight_by_ray_angle_gamma

    The reduction functions singularize the second dimension of the sampled
    features, which stacks the source views.

    Settings:
        reduction_functions: A list of `ReductionFunction`s that reduce the
            stack of source-view-specific features to a single feature.
        min_ray_angle_weight: The minimum possible aggregation weight before
            raising to the power of `self.weight_by_ray_angle_gamma`.
        weight_by_ray_angle_gamma: The exponent of the cosine of the ray angles
            used when calculating the angle-based aggregation weights.
    """

    reduction_functions: Tuple[ReductionFunction, ...] = (
        ReductionFunction.AVG,
        ReductionFunction.STD,
    )
    weight_by_ray_angle_gamma: float = 1.0
    min_ray_angle_weight: float = 0.1

    def get_aggregated_feature_dim(
        self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
    ):
        return _get_reduction_aggregator_feature_dim(
            feats_or_feats_dim, self.reduction_functions
        )

    def forward(
        self,
        feats_sampled: Dict[str, torch.Tensor],
        masks_sampled: torch.Tensor,
        camera: Optional[CamerasBase] = None,
        pts: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Args:
            feats_sampled: A `dict` of sampled feature tensors `{f_i: t_i}`,
                where each `t_i` is a tensor of shape
                `(minibatch, n_source_views, n_samples, dim_i)`.
            masks_sampled: A binary mask represented as a tensor of shape
                `(minibatch, n_source_views, n_samples, 1)` denoting valid
                sampled features.
            camera: A batch of `n_source_views` `CamerasBase` objects
                corresponding to the source view cameras.
            pts: A tensor of shape `(minibatch, n_samples, 3)` denoting the
                3D points whose 2D projections to source views were sampled in
                order to generate `feats_sampled` and `masks_sampled`.

        Returns:
            feats_aggregated: If `concatenate_output==True`, a tensor of shape
                `(minibatch, 1, n_samples, sum(dim_1, ... dim_N))`.
                If `concatenate_output==False`, a dictionary `{f_i: t_i_aggregated}`
                with each `t_i_aggregated` of shape
                `(minibatch, 1, n_samples, aggr_dim_i)`.
        """
        if camera is None:
            raise ValueError("camera cannot be None for angle weighted aggregation")
        if pts is None:
            raise ValueError("Points cannot be None for angle weighted aggregation")
        pts_batch, n_cameras = masks_sampled.shape[:2]
        if self.exclude_target_view_mask_features:
            feats_sampled = _mask_target_view_features(feats_sampled)
        view_sampling_mask = _get_view_sampling_mask(
            n_cameras,
            pts_batch,
            masks_sampled.device,
            self.exclude_target_view,
        )
        aggr_weights = _get_angular_reduction_weights(
            view_sampling_mask,
            masks_sampled,
            camera,
            pts,
            self.min_ray_angle_weight,
            self.weight_by_ray_angle_gamma,
        )
        assert torch.isfinite(aggr_weights).all()
        feats_aggregated = {
            k: _avgmaxstd_reduction_function(
                f,
                aggr_weights,
                dim=1,
                reduction_functions=self.reduction_functions,
            )
            for k, f in feats_sampled.items()
        }
        if self.concatenate_output:
            feats_aggregated = torch.cat(tuple(feats_aggregated.values()), dim=-1)
        return feats_aggregated
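
# Illustrative sketch (hypothetical helper, not part of the library): the
# angle-based weight applied before reduction, written out for a single pair of
# ray directions. Values are arbitrary; the module computes this in batch in
# `_get_ray_angle_weights`.
def _example_angle_weight(min_ray_angle_weight: float = 0.1, gamma: float = 1.0):
    target_ray = F.normalize(torch.tensor([0.0, 0.0, 1.0]), dim=-1)
    source_ray = F.normalize(torch.tensor([0.0, 1.0, 1.0]), dim=-1)
    dot = (target_ray * source_ray).sum()  # cosine of the angle between the rays
    weight = (dot * 0.5 + 0.5 + min_ray_angle_weight) ** gamma
    # parallel rays (dot=1) give 1.1, opposite rays (dot=-1) give 0.1 with defaults
    return weight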

@registry.register
class AngleWeightedIdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
    """
    This aggregator does not perform any feature aggregation. It only weights
    the features by weights proportional to the cosine of the angle between
    the target ray and the source ray::

        weight = (
            dot(target_ray, source_ray) * 0.5 + 0.5 + self.min_ray_angle_weight
        )**self.weight_by_ray_angle_gamma

    Settings:
        min_ray_angle_weight: The minimum possible aggregation weight before
            raising to the power of `self.weight_by_ray_angle_gamma`.
        weight_by_ray_angle_gamma: The exponent of the cosine of the ray angles
            used when calculating the angle-based aggregation weights.

    Additionally, the aggregator allows masking the target view features and
    concatenating the outputs.
    """

    weight_by_ray_angle_gamma: float = 1.0
    min_ray_angle_weight: float = 0.1

    def get_aggregated_feature_dim(
        self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
    ):
        return _get_reduction_aggregator_feature_dim(feats_or_feats_dim, [])

    def forward(
        self,
        feats_sampled: Dict[str, torch.Tensor],
        masks_sampled: torch.Tensor,
        camera: Optional[CamerasBase] = None,
        pts: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Args:
            feats_sampled: A `dict` of sampled feature tensors `{f_i: t_i}`,
                where each `t_i` is a tensor of shape
                `(minibatch, n_source_views, n_samples, dim_i)`.
            masks_sampled: A binary mask represented as a tensor of shape
                `(minibatch, n_source_views, n_samples, 1)` denoting valid
                sampled features.
            camera: A batch of `n_source_views` `CamerasBase` objects
                corresponding to the source view cameras.
            pts: A tensor of shape `(minibatch, n_samples, 3)` denoting the
                3D points whose 2D projections to source views were sampled in
                order to generate `feats_sampled` and `masks_sampled`.

        Returns:
            feats_aggregated: If `concatenate_output==True`, a tensor of shape
                `(minibatch, n_source_views, n_samples, sum(dim_1, ... dim_N))`.
                If `concatenate_output==False`, a dictionary `{f_i: t_i_aggregated}`
                with each `t_i_aggregated` of shape
                `(minibatch, n_source_views, n_samples, dim_i)`.
        """
        if camera is None:
            raise ValueError("camera cannot be None for angle weighted aggregation")
        if pts is None:
            raise ValueError("Points cannot be None for angle weighted aggregation")
        pts_batch, n_cameras = masks_sampled.shape[:2]
        if self.exclude_target_view_mask_features:
            feats_sampled = _mask_target_view_features(feats_sampled)
        view_sampling_mask = _get_view_sampling_mask(
            n_cameras,
            pts_batch,
            masks_sampled.device,
            self.exclude_target_view,
        )
        aggr_weights = _get_angular_reduction_weights(
            view_sampling_mask,
            masks_sampled,
            camera,
            pts,
            self.min_ray_angle_weight,
            self.weight_by_ray_angle_gamma,
        )
        feats_aggregated = {
            k: f * aggr_weights[..., None] for k, f in feats_sampled.items()
        }
        if self.concatenate_output:
            feats_aggregated = torch.cat(tuple(feats_aggregated.values()), dim=-1)
        return feats_aggregated
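
# Illustrative sketch (hypothetical helper, not part of the library):
# `has_aggregation` is True only for the reduction-based aggregators, which
# define `reduction_functions` and collapse the source-view dimension to 1;
# the identity aggregators return False. Construction assumes the
# `expand_args_fields` pattern as above.
def _example_has_aggregation():
    from pytorch3d.implicitron.tools.config import expand_args_fields

    for cls in (
        IdentityFeatureAggregator,
        ReductionFeatureAggregator,
        AngleWeightedReductionFeatureAggregator,
        AngleWeightedIdentityFeatureAggregator,
    ):
        expand_args_fields(cls)
        print(cls.__name__, cls().has_aggregation())
    # IdentityFeatureAggregator False
    # ReductionFeatureAggregator True
    # AngleWeightedReductionFeatureAggregator True
    # AngleWeightedIdentityFeatureAggregator False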

def _get_reduction_aggregator_feature_dim(
    feats_or_feats_dim: Union[Dict[str, torch.Tensor], int],
    reduction_functions: Sequence[ReductionFunction],
) -> int:
    if isinstance(feats_or_feats_dim, int):
        feat_dim = feats_or_feats_dim
    else:
        feat_dim = int(sum(f.shape[1] for f in feats_or_feats_dim.values()))
    if len(reduction_functions) == 0:
        return feat_dim
    return sum(
        _get_reduction_function_output_dim(
            reduction_function,
            feat_dim,
        )
        for reduction_function in reduction_functions
    )


def _get_reduction_function_output_dim(
    reduction_function: ReductionFunction,
    feat_dim: int,
) -> int:
    if reduction_function == ReductionFunction.STD_AVG:
        return 1
    else:
        return feat_dim


def _get_view_sampling_mask(
    n_cameras: int,
    pts_batch: int,
    device: Union[str, torch.device],
    exclude_target_view: bool,
):
    return (
        -torch.eye(n_cameras, device=device, dtype=torch.float32)
        * float(exclude_target_view)
        + 1.0
    )[:pts_batch]


def _mask_target_view_features(
    feats_sampled: Dict[str, torch.Tensor],
):
    # mask out the sampled features to be sure we dont use them
    # anywhere later
    one_feature_sampled = next(iter(feats_sampled.values()))
    pts_batch, n_cameras = one_feature_sampled.shape[:2]
    view_sampling_mask = _get_view_sampling_mask(
        n_cameras,
        pts_batch,
        one_feature_sampled.device,
        True,
    )
    view_sampling_mask = view_sampling_mask.view(
        pts_batch, n_cameras, *([1] * (one_feature_sampled.ndim - 2))
    )
    return {k: f * view_sampling_mask for k, f in feats_sampled.items()}


def _get_angular_reduction_weights(
    view_sampling_mask: torch.Tensor,
    masks_sampled: torch.Tensor,
    camera: CamerasBase,
    pts: torch.Tensor,
    min_ray_angle_weight: float,
    weight_by_ray_angle_gamma: float,
):
    aggr_weights = masks_sampled.clone()[..., 0]
    assert not any(v is None for v in [camera, pts])
    angle_weight = _get_ray_angle_weights(
        camera,
        pts,
        min_ray_angle_weight,
        weight_by_ray_angle_gamma,
    )
    assert torch.isfinite(angle_weight).all()
    # multiply the final aggr weights with ray angles
    view_sampling_mask = view_sampling_mask.view(
        *view_sampling_mask.shape[:2], *([1] * (aggr_weights.ndim - 2))
    )
    aggr_weights = (
        aggr_weights * angle_weight.reshape_as(aggr_weights) * view_sampling_mask
    )
    return aggr_weights


def _get_ray_dir_dot_prods(camera: CamerasBase, pts: torch.Tensor):
    n_cameras = camera.R.shape[0]
    pts_batch = pts.shape[0]

    camera_rep, pts_rep = cameras_points_cartesian_product(camera, pts)

    # does not produce nans randomly unlike get_camera_center() below
    cam_centers_rep = -torch.bmm(
        camera_rep.T[:, None],
        camera_rep.R.permute(0, 2, 1),
    ).reshape(-1, *([1] * (pts.ndim - 2)), 3)
    # cam_centers_rep = camera_rep.get_camera_center().reshape(
    #     -1, *([1]*(pts.ndim - 2)), 3
    # )

    ray_dirs = F.normalize(pts_rep - cam_centers_rep, dim=-1)
    # camera_rep = [                 pts_rep = [
    #     camera[0]                      pts[0],
    #     camera[0]                      pts[1],
    #     camera[0]                      ...,
    #     ...                            pts[batch_pts-1],
    #     camera[1]                      pts[0],
    #     camera[1]                      pts[1],
    #     camera[1]                      ...,
    #     ...                            pts[batch_pts-1],
    #     ...                            ...,
    #     camera[n_cameras-1]            pts[0],
    #     camera[n_cameras-1]            pts[1],
    #     camera[n_cameras-1]            ...,
    #     ...                            pts[batch_pts-1],
    # ]                              ]

    ray_dirs_reshape = ray_dirs.view(n_cameras, pts_batch, -1, 3)
    # [
    #     [pts_0 in cam_0, pts_1 in cam_0, ..., pts_m in cam_0],
    #     [pts_0 in cam_1, pts_1 in cam_1, ..., pts_m in cam_1],
    #     ...
    #     [pts_0 in cam_n, pts_1 in cam_n, ..., pts_m in cam_n],
    # ]

    ray_dirs_pts = torch.stack([ray_dirs_reshape[i, i] for i in range(pts_batch)])
    ray_dir_dot_prods = (ray_dirs_pts[None] * ray_dirs_reshape).sum(
        dim=-1
    )  # pts_batch x n_cameras x n_pts

    return ray_dir_dot_prods.transpose(0, 1)


def _get_ray_angle_weights(
    camera: CamerasBase,
    pts: torch.Tensor,
    min_ray_angle_weight: float,
    weight_by_ray_angle_gamma: float,
):
    ray_dir_dot_prods = _get_ray_dir_dot_prods(
        camera, pts
    )  # pts_batch x n_cameras x ... x 3
    angle_weight_01 = ray_dir_dot_prods * 0.5 + 0.5  # [-1, 1] to [0, 1]
    angle_weight = (
        angle_weight_01 + min_ray_angle_weight
    ) ** weight_by_ray_angle_gamma
    return angle_weight


def _avgmaxstd_reduction_function(
    x: torch.Tensor,
    w: torch.Tensor,
    reduction_functions: Sequence[ReductionFunction],
    dim: int = 1,
):
    """
    Args:
        x: Features to aggregate. Tensor of shape `(batch, n_views, ..., dim)`.
        w: Aggregation weights. Tensor of shape `(batch, n_views, ...,)`.
        dim: The dimension along which to aggregate.
        reduction_functions: The set of reduction functions.

    Returns:
        x_aggr: Aggregation of `x` to a tensor of shape
            `(batch, 1, ..., dim_aggregate)`.
    """
    pooled_features = []

    mu = None
    std = None

    if ReductionFunction.AVG in reduction_functions:
        # average pool
        mu = _avg_reduction_function(x, w, dim=dim)
        pooled_features.append(mu)

    if ReductionFunction.STD in reduction_functions:
        # standard-dev pool
        std = _std_reduction_function(x, w, dim=dim, mu=mu)
        pooled_features.append(std)

    if ReductionFunction.STD_AVG in reduction_functions:
        # average-of-standard-dev pool
        stdavg = _std_avg_reduction_function(x, w, dim=dim, mu=mu, std=std)
        pooled_features.append(stdavg)

    if ReductionFunction.MAX in reduction_functions:
        max_ = _max_reduction_function(x, w, dim=dim)
        pooled_features.append(max_)

    # cat all results along the feature dimension (the last dim)
    x_aggr = torch.cat(pooled_features, dim=-1)

    # zero out features that were all masked out
    # pyre-fixme[16]: `bool` has no attribute `type_as`.
    any_active = (w.max(dim=dim, keepdim=True).values > 1e-4).type_as(x_aggr)
    x_aggr = x_aggr * any_active[..., None]

    # some asserts to check that everything was done right
    assert torch.isfinite(x_aggr).all()
    assert x_aggr.shape[1] == 1

    return x_aggr


def _avg_reduction_function(
    x: torch.Tensor,
    w: torch.Tensor,
    dim: int = 1,
):
    mu = wmean(x, w, dim=dim, eps=1e-2)
    return mu


def _std_reduction_function(
    x: torch.Tensor,
    w: torch.Tensor,
    dim: int = 1,
    mu: Optional[torch.Tensor] = None,  # pre-computed mean
):
    if mu is None:
        mu = _avg_reduction_function(x, w, dim=dim)
    # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
    std = wmean((x - mu) ** 2, w, dim=dim, eps=1e-2).clamp(1e-4).sqrt()
    # FIXME: somehow this is extremely heavy in mem?
    return std


def _std_avg_reduction_function(
    x: torch.Tensor,
    w: torch.Tensor,
    dim: int = 1,
    mu: Optional[torch.Tensor] = None,  # pre-computed mean
    std: Optional[torch.Tensor] = None,  # pre-computed std
):
    if std is None:
        std = _std_reduction_function(x, w, dim=dim, mu=mu)
    stdmean = std.mean(dim=-1, keepdim=True)
    return stdmean


def _max_reduction_function(
    x: torch.Tensor,
    w: torch.Tensor,
    dim: int = 1,
    big_M_factor: float = 10.0,
):
    big_M = x.max(dim=dim, keepdim=True).values.abs() * big_M_factor
    max_ = (x * w - ((1 - w) * big_M)).max(dim=dim, keepdim=True).values
    return max_
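
# Illustrative sketch (hypothetical helper, not part of the library): the masked
# AVG/STD reduction on a tiny example. The weight tensor zeroes out the second
# source view, so it contributes neither to the weighted mean nor to the
# weighted standard deviation; shapes and values are arbitrary.
def _example_avgstd_reduction():
    # 1 batch element, 3 source views, 2 samples, 1 feature channel
    x = torch.tensor(
        [[[[1.0], [10.0]], [[3.0], [20.0]], [[2.0], [30.0]]]]
    )  # (1, 3, 2, 1)
    w = torch.tensor([[[1.0, 1.0], [0.0, 0.0], [1.0, 1.0]]])  # (1, 3, 2)
    out = _avgmaxstd_reduction_function(
        x,
        w,
        reduction_functions=(ReductionFunction.AVG, ReductionFunction.STD),
        dim=1,
    )
    # (1, 1, 2, 2): per sample, the weighted mean and std over the unmasked views,
    # e.g. sample 0 gives mean 1.5 and std 0.5.
    print(out.shape)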