# Source code for the pytorch3d.ops.interp_face_attrs module.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch
from pytorch3d import _C
from torch.autograd import Function
from torch.autograd.function import once_differentiable

def interpolate_face_attributes(
    pix_to_face: torch.Tensor,
    barycentric_coords: torch.Tensor,
    face_attributes: torch.Tensor,
) -> torch.Tensor:
    """
    Interpolate arbitrary face attributes using the barycentric coordinates
    for each pixel in the rasterized output.

    Args:
        pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
            of the faces (in the packed representation) which overlap each
            pixel in the image. A value < 0 indicates that the pixel does not
            overlap any face and should be skipped.
        barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
            the barycentric coordinates of each pixel relative to the faces
            (in the packed representation) which overlap the pixel.
        face_attributes: packed attributes of shape (total_faces, 3, D),
            specifying the value of the attribute for each vertex in the face.

    Returns:
        pixel_vals: tensor of shape (N, H, W, K, D) giving the interpolated
        value of the face attribute for each pixel.

    Raises:
        ValueError: if faces do not have exactly three vertices, if the last
            dimension of barycentric_coords is not 3, or if pix_to_face does
            not have shape (N, H, W, K).
    """
    # Check shapes
    _, FV, D = face_attributes.shape
    if FV != 3:
        raise ValueError("Faces can only have three vertices; got %r" % FV)
    N, H, W, K, BV = barycentric_coords.shape
    if BV != 3:
        # Without this check a wrong size only fails later, deep inside a
        # broadcast or a view, with a much less helpful message.
        msg = "barycentric_coords must have shape (N, H, W, K, 3); got %r"
        raise ValueError(msg % (barycentric_coords.shape,))
    if pix_to_face.shape != (N, H, W, K):
        msg = "pix_to_face must have shape (batch_size, H, W, K); got %r"
        raise ValueError(msg % (pix_to_face.shape,))

    # On CPU use the python version
    # TODO: Implement a C++ version of this function
    if not pix_to_face.is_cuda:
        args = (pix_to_face, barycentric_coords, face_attributes)
        return interpolate_face_attributes_python(*args)

    # Otherwise flatten and call the custom autograd function. reshape (rather
    # than view) also accepts non-contiguous inputs, copying only when needed.
    pix_to_face = pix_to_face.reshape(-1)
    barycentric_coords = barycentric_coords.reshape(N * H * W * K, 3)
    args = (pix_to_face, barycentric_coords, face_attributes)
    out = _InterpFaceAttrs.apply(*args)
    # Use the known D instead of -1: a -1 dimension cannot be inferred when
    # the output has zero elements (e.g. N == 0).
    return out.view(N, H, W, K, D)
class _InterpFaceAttrs(Function):
    """
    Autograd wrapper around the compiled ``_C`` face-attribute interpolation
    kernels.

    ``interpolate_face_attributes`` routes only CUDA inputs here, after
    flattening ``pix_to_face`` to shape (N*H*W*K,) and ``barycentric_coords``
    to shape (N*H*W*K, 3).
    """

    @staticmethod
    def forward(ctx, pix_to_face, barycentric_coords, face_attrs):
        args = (pix_to_face, barycentric_coords, face_attrs)
        # The backward kernel receives all three inputs again.
        ctx.save_for_backward(*args)
        return _C.interp_face_attrs_forward(*args)

    @staticmethod
    @once_differentiable  # backward is not recorded; double backward unsupported
    def backward(ctx, grad_pix_attrs):
        args = ctx.saved_tensors
        args = args + (grad_pix_attrs,)
        grads = _C.interp_face_attrs_backward(*args)
        # pix_to_face holds integer face indices, so it gets no gradient.
        grad_pix_to_face = None
        grad_barycentric_coords = grads[0]
        grad_face_attrs = grads[1]
        return grad_pix_to_face, grad_barycentric_coords, grad_face_attrs


def interpolate_face_attributes_python(
    pix_to_face: torch.Tensor,
    barycentric_coords: torch.Tensor,
    face_attributes: torch.Tensor,
) -> torch.Tensor:
    """
    Pure PyTorch implementation of face attribute interpolation.

    Args:
        pix_to_face: LongTensor of shape (N, H, W, K) holding packed face
            indices; negative entries mark pixels not covered by any face.
            It may be non-contiguous; it is not modified.
        barycentric_coords: FloatTensor of shape (N, H, W, K, 3).
        face_attributes: tensor of shape (F, 3, D) with one attribute vector
            per face vertex.

    Returns:
        Tensor of shape (N, H, W, K, D): the barycentric-weighted sum of the
        three vertex attributes of each pixel's face, or 0 where pix_to_face
        is negative.
    """
    _, _, D = face_attributes.shape
    N, H, W, K, _ = barycentric_coords.shape
    P = N * H * W * K
    # Negative ids are not valid gather indices. Point them at face 0 and zero
    # their results afterwards. masked_fill returns a new tensor, so the
    # caller's pix_to_face is left untouched.
    mask = pix_to_face < 0
    safe_pix_to_face = pix_to_face.masked_fill(mask, 0)
    # reshape (not view): pix_to_face may be non-contiguous (e.g. permuted),
    # in which case a view of the flattened indices is impossible.
    idx = safe_pix_to_face.reshape(P, 1, 1).expand(P, 3, D)
    pixel_face_vals = face_attributes.gather(0, idx).reshape(N, H, W, K, 3, D)
    # Weight each vertex attribute by its barycentric coordinate and sum over
    # the three vertices.
    pixel_vals = (barycentric_coords[..., None] * pixel_face_vals).sum(dim=-2)
    # Replace values of empty pixels in the output.
    return pixel_vals.masked_fill(mask[..., None], 0)