Source code for pytorch3d.ops.sample_points_from_meshes

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


"""
This module implements utility functions for sampling points from
batches of meshes.
"""
import sys
from typing import Tuple, Union

import torch
from pytorch3d.ops.mesh_face_areas_normals import mesh_face_areas_normals
from pytorch3d.ops.packed_to_padded import packed_to_padded
from pytorch3d.renderer.mesh.rasterizer import Fragments as MeshFragments


def sample_points_from_meshes(
    meshes,
    num_samples: int = 10000,
    return_normals: bool = False,
    return_textures: bool = False,
) -> Union[
    torch.Tensor,
    Tuple[torch.Tensor, torch.Tensor],
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
]:
    """
    Convert a batch of meshes to a batch of pointclouds by uniformly
    sampling points on the surface of the mesh with probability
    proportional to the face area.

    Args:
        meshes: A Meshes object with a batch of N meshes.
        num_samples: Integer giving the number of point samples per mesh.
        return_normals: If True, return normals for the sampled points.
        return_textures: If True, return textures for the sampled points.

    Returns:
        ``samples`` alone if neither ``return_normals`` nor
        ``return_textures`` is set; otherwise a tuple that starts with
        ``samples`` and is followed by ``normals`` and/or ``textures``
        (in that order) for each flag that is True.

        - **samples**: FloatTensor of shape (N, num_samples, 3) giving the
          coordinates of sampled points for each mesh in the batch. For empty
          meshes the corresponding row in the samples array will be filled
          with 0.
        - **normals**: FloatTensor of shape (N, num_samples, 3) giving a
          normal vector to each sampled point. Only returned if
          return_normals is True. For empty meshes the corresponding row in
          the normals array will be filled with 0.
        - **textures**: FloatTensor of shape (N, num_samples, C) giving a
          C-dimensional texture vector to each sampled point. Only returned
          if return_textures is True. For empty meshes the corresponding row
          in the textures array will be filled with 0.

    Raises:
        ValueError: if the batch contains no non-empty mesh, if the packed
            vertices contain nan or inf, or if ``return_textures`` is True
            but ``meshes.textures`` is None.

    Note that in a future release, we will replace the tuple output
    with a `Pointclouds` datastructure, as follows

    .. code-block:: python

        Pointclouds(samples, normals=normals, features=textures)
    """
    if meshes.isempty():
        raise ValueError("Meshes are empty.")

    verts = meshes.verts_packed()
    if not torch.isfinite(verts).all():
        raise ValueError("Meshes contain nan or inf.")

    if return_textures and meshes.textures is None:
        raise ValueError("Meshes do not contain textures.")

    faces = meshes.faces_packed()
    mesh_to_face = meshes.mesh_to_faces_packed_first_idx()
    num_meshes = len(meshes)
    valid = meshes.valid  # Mask selecting the non empty meshes.
    # Plain int so it can be used both as a size and in comparisons below.
    num_valid_meshes = int(torch.sum(valid))

    # Initialize samples tensor with fill value 0 for empty meshes.
    samples = torch.zeros((num_meshes, num_samples, 3), device=meshes.device)

    # Only compute samples for non empty meshes.
    # Choosing which face to sample is a discrete draw, so it is done without
    # tracking gradients; the interpolation further down is left outside this
    # scope so that `samples` remains a function of `verts`.
    with torch.no_grad():
        areas, _ = mesh_face_areas_normals(verts, faces)  # Face areas can be zero.
        max_faces = meshes.num_faces_per_mesh().max().item()
        areas_padded = packed_to_padded(
            areas, mesh_to_face[valid], max_faces
        )  # (num_valid_meshes, F)

        # TODO (gkioxari) Confirm multinomial bug is not present with real data.
        sample_face_idxs = areas_padded.multinomial(
            num_samples, replacement=True
        )  # (num_valid_meshes, num_samples)
        # Shift per-mesh face indices into indices of the packed faces tensor.
        sample_face_idxs += mesh_to_face[valid].view(num_valid_meshes, 1)

    # Get the vertex coordinates of the sampled faces.
    face_verts = verts[faces]
    v0, v1, v2 = face_verts[:, 0], face_verts[:, 1], face_verts[:, 2]

    # Randomly generate barycentric coords.
    w0, w1, w2 = _rand_barycentric_coords(
        num_valid_meshes, num_samples, verts.dtype, verts.device
    )

    # Use the barycentric coords to get a point on each sampled face.
    a = v0[sample_face_idxs]  # (num_valid_meshes, num_samples, 3)
    b = v1[sample_face_idxs]
    c = v2[sample_face_idxs]
    samples[valid] = w0[:, :, None] * a + w1[:, :, None] * b + w2[:, :, None] * c

    if return_normals:
        # Initialize normals tensor with fill value 0 for empty meshes.
        # Normals for the sampled points are face normals computed from
        # the vertices of the face in which the sampled point lies.
        normals = torch.zeros((num_meshes, num_samples, 3), device=meshes.device)
        vert_normals = (v1 - v0).cross(v2 - v1, dim=1)
        # Clamp the norm so that zero-area faces do not divide by zero.
        vert_normals = vert_normals / vert_normals.norm(dim=1, p=2, keepdim=True).clamp(
            min=sys.float_info.epsilon
        )
        vert_normals = vert_normals[sample_face_idxs]
        normals[valid] = vert_normals

    if return_textures:
        # fragment data are of shape NxHxWxK. Here H=S, W=1 & K=1.
        # `sample_face_idxs` and the barycentric weights only have one row per
        # non empty mesh, so they are scattered into full-batch tensors before
        # building the fragments. Rows of empty meshes get face index -1 and
        # zero weights.
        # NOTE(review): assumes `sample_textures` accepts a negative
        # pix_to_face as "no face" -- confirm for the texture type in use.
        # Those rows are zeroed explicitly below in any case.
        pix_to_face = sample_face_idxs.new_full((num_meshes, num_samples), -1)
        pix_to_face[valid] = sample_face_idxs
        pix_to_face = pix_to_face.view(num_meshes, num_samples, 1, 1)  # NxSx1x1
        bary = w0.new_zeros((num_meshes, num_samples, 3))
        bary[valid] = torch.stack((w0, w1, w2), dim=2)
        bary = bary.unsqueeze(2).unsqueeze(2)  # NxSx1x1x3
        # zbuf and dists are not used in `sample_textures` so we initialize them with dummy
        dummy = torch.zeros(
            (num_meshes, num_samples, 1, 1), device=meshes.device, dtype=torch.float32
        )  # NxSx1x1
        fragments = MeshFragments(
            pix_to_face=pix_to_face, zbuf=dummy, bary_coords=bary, dists=dummy
        )
        textures = meshes.sample_textures(fragments)  # NxSx1x1xC
        textures = textures[:, :, 0, 0, :]  # NxSxC
        if num_valid_meshes < num_meshes:
            # Fill the rows of empty meshes with 0, as documented above.
            empty_rows = valid.logical_not().view(num_meshes, 1, 1)
            textures = textures.masked_fill(empty_rows, 0.0)

    # return
    # TODO(gkioxari) consider returning a Pointclouds instance [breaking]
    if return_normals and return_textures:
        # pyre-fixme[61]: `normals` may not be initialized here.
        # pyre-fixme[61]: `textures` may not be initialized here.
        return samples, normals, textures
    if return_normals:  # return_textures is False
        # pyre-fixme[61]: `normals` may not be initialized here.
        return samples, normals
    if return_textures:  # return_normals is False
        # pyre-fixme[61]: `textures` may not be initialized here.
        return samples, textures
    return samples


def _rand_barycentric_coords( size1, size2, dtype: torch.dtype, device: torch.device ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Helper function to generate random barycentric coordinates which are uniformly distributed over a triangle. Args: size1, size2: The number of coordinates generated will be size1*size2. Output tensors will each be of shape (size1, size2). dtype: Datatype to generate. device: A torch.device object on which the outputs will be allocated. Returns: w0, w1, w2: Tensors of shape (size1, size2) giving random barycentric coordinates """ uv = torch.rand(2, size1, size2, dtype=dtype, device=device) u, v = uv[0], uv[1] u_sqrt = u.sqrt() w0 = 1.0 - u_sqrt w1 = u_sqrt * (1.0 - v) w2 = u_sqrt * v # pyre-fixme[7]: Expected `Tuple[torch.Tensor, torch.Tensor, torch.Tensor]` but # got `Tuple[float, typing.Any, typing.Any]`. return w0, w1, w2