Source code for pytorch3d.ops.vert_align

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


import torch
import torch.nn.functional as F


[docs]def vert_align( feats, verts, return_packed: bool = False, interp_mode: str = "bilinear", padding_mode: str = "zeros", align_corners: bool = True, ) -> torch.Tensor: """ Sample vertex features from a feature map. This operation is called "perceptual feature pooling" in [1] or "vert align" in [2]. [1] Wang et al, "Pixel2Mesh: Generating 3D Mesh Models from Single RGB Images", ECCV 2018. [2] Gkioxari et al, "Mesh R-CNN", ICCV 2019 Args: feats: FloatTensor of shape (N, C, H, W) representing image features from which to sample or a list of features each with potentially different C, H or W dimensions. verts: FloatTensor of shape (N, V, 3) or an object (e.g. Meshes or Pointclouds) with `verts_padded' or `points_padded' as an attribute giving the (x, y, z) vertex positions for which to sample. (x, y) verts should be normalized such that (-1, -1) corresponds to top-left and (+1, +1) to bottom-right location in the input feature map. return_packed: (bool) Indicates whether to return packed features interp_mode: (str) Specifies how to interpolate features. ('bilinear' or 'nearest') padding_mode: (str) Specifies how to handle vertices outside of the [-1, 1] range. ('zeros', 'reflection', or 'border') align_corners (bool): Geometrically, we consider the pixels of the input as squares rather than points. If set to ``True``, the extrema (``-1`` and ``1``) are considered as referring to the center points of the input's corner pixels. If set to ``False``, they are instead considered as referring to the corner points of the input's corner pixels, making the sampling more resolution agnostic. Default: ``True`` Returns: feats_sampled: FloatTensor of shape (N, V, C) giving sampled features for each vertex. If feats is a list, we return concatenated features in axis=2 of shape (N, V, sum(C_n)) where C_n = feats[n].shape[1]. If return_packed = True, the features are transformed to a packed representation of shape (sum(V), C) """ if torch.is_tensor(verts): if verts.dim() != 3: raise ValueError("verts tensor should be 3 dimensional") grid = verts elif hasattr(verts, "verts_padded"): grid = verts.verts_padded() elif hasattr(verts, "points_padded"): grid = verts.points_padded() else: raise ValueError( "verts must be a tensor or have a " + "`points_padded' or`verts_padded` attribute." ) grid = grid[:, None, :, :2] # (N, 1, V, 2) if torch.is_tensor(feats): feats = [feats] for feat in feats: if feat.dim() != 4: raise ValueError("feats must have shape (N, C, H, W)") if grid.shape[0] != feat.shape[0]: raise ValueError("inconsistent batch dimension") feats_sampled = [] for feat in feats: feat_sampled = F.grid_sample( feat, grid, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners, ) # (N, C, 1, V) feat_sampled = feat_sampled.squeeze(dim=2).transpose(1, 2) # (N, V, C) feats_sampled.append(feat_sampled) feats_sampled = torch.cat(feats_sampled, dim=2) # (N, V, sum(C)) if return_packed: # flatten the first two dimensions: (N*V, C) feats_sampled = feats_sampled.view(-1, feats_sampled.shape[-1]) if hasattr(verts, "verts_padded_to_packed_idx"): idx = ( verts.verts_padded_to_packed_idx() .view(-1, 1) .expand(-1, feats_sampled.shape[-1]) ) # pyre-fixme[16]: `Tensor` has no attribute `gather`. feats_sampled = feats_sampled.gather(0, idx) # (sum(V), C) return feats_sampled