# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
from typing import Callable, Optional
import torch
import torch.nn.functional as F
from pytorch3d.common.compat import prod
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.renderer import ray_bundle_to_ray_points
from pytorch3d.renderer.cameras import CamerasBase
[docs]
def broadcast_global_code(embeds: torch.Tensor, global_code: torch.Tensor):
"""
Expands the `global_code` of shape (minibatch, dim)
so that it can be appended to `embeds` of shape (minibatch, ..., dim2),
and appends to the last dimension of `embeds`.
"""
bs = embeds.shape[0]
global_code_broadcast = global_code.view(bs, *([1] * (embeds.ndim - 2)), -1).expand(
*embeds.shape[:-1],
global_code.shape[-1],
)
return torch.cat([embeds, global_code_broadcast], dim=-1)
[docs]
def create_embeddings_for_implicit_function(
xyz_world: torch.Tensor,
xyz_in_camera_coords: bool,
global_code: Optional[torch.Tensor],
camera: Optional[CamerasBase],
fun_viewpool: Optional[Callable],
xyz_embedding_function: Optional[Callable],
diag_cov: Optional[torch.Tensor] = None,
) -> torch.Tensor:
bs, *spatial_size, pts_per_ray, _ = xyz_world.shape
if xyz_in_camera_coords:
if camera is None:
raise ValueError("Camera must be given if xyz_in_camera_coords")
ray_points_for_embed = (
camera.get_world_to_view_transform()
.transform_points(xyz_world.view(bs, -1, 3))
.view(xyz_world.shape)
)
else:
ray_points_for_embed = xyz_world
if xyz_embedding_function is None:
embeds = torch.empty(
bs,
1,
prod(spatial_size),
pts_per_ray,
0,
)
else:
embeds = xyz_embedding_function(ray_points_for_embed, diag_cov=diag_cov)
embeds = embeds.reshape(
bs,
1,
prod(spatial_size),
pts_per_ray,
-1,
) # flatten spatial, add n_src dim
if fun_viewpool is not None:
# viewpooling
embeds_viewpooled = fun_viewpool(xyz_world.reshape(bs, -1, 3))
embed_shape = (
bs,
embeds_viewpooled.shape[1],
prod(spatial_size),
pts_per_ray,
-1,
)
embeds_viewpooled = embeds_viewpooled.reshape(*embed_shape)
if embeds is not None:
embeds = torch.cat([embeds.expand(*embed_shape), embeds_viewpooled], dim=-1)
else:
embeds = embeds_viewpooled
if global_code is not None:
# append the broadcasted global code to embeds
embeds = broadcast_global_code(embeds, global_code)
return embeds
[docs]
def interpolate_line(
points: torch.Tensor,
source: torch.Tensor,
**kwargs,
) -> torch.Tensor:
"""
Linearly interpolates values of source grids. The first dimension of points represents
number of points and the second coordinate, for example ([[x0], [x1], ...]). The first
dimension of argument source represents feature and ones after that the spatial
dimension.
Arguments:
points: shape (n_grids, n_points, 1),
source: tensor of shape (n_grids, features, width),
Returns:
interpolated tensor of shape (n_grids, n_points, features)
"""
# To enable sampling of the source using the torch.functional.grid_sample
# points need to have 2 coordinates.
expansion = points.new_zeros(points.shape)
points = torch.cat((points, expansion), dim=-1)
source = source[:, :, None, :]
points = points[:, :, None, :]
out = F.grid_sample(
grid=points,
input=source,
**kwargs,
)
return out[:, :, :, 0].permute(0, 2, 1)
[docs]
def interpolate_plane(
points: torch.Tensor,
source: torch.Tensor,
**kwargs,
) -> torch.Tensor:
"""
Bilinearly interpolates values of source grids. The first dimension of points represents
number of points and the second coordinates, for example ([[x0, y0], [x1, y1], ...]).
The first dimension of argument source represents feature and ones after that the
spatial dimension.
Arguments:
points: shape (n_grids, n_points, 2),
source: tensor of shape (n_grids, features, width, height),
Returns:
interpolated tensor of shape (n_grids, n_points, features)
"""
# permuting because torch.nn.functional.grid_sample works with
# (features, height, width) and not
# (features, width, height)
source = source.permute(0, 1, 3, 2)
points = points[:, :, None, :]
out = F.grid_sample(
grid=points,
input=source,
**kwargs,
)
return out[:, :, :, 0].permute(0, 2, 1)
[docs]
def interpolate_volume(
points: torch.Tensor, source: torch.Tensor, **kwargs
) -> torch.Tensor:
"""
Interpolates values of source grids. The first dimension of points represents
number of points and the second coordinates, for example
[[x0, y0, z0], [x1, y1, z1], ...]. The first dimension of a source represents features
and ones after that the spatial dimension.
Arguments:
points: shape (n_grids, n_points, 3),
source: tensor of shape (n_grids, features, width, height, depth),
Returns:
interpolated tensor of shape (n_grids, n_points, features)
"""
if "mode" in kwargs and kwargs["mode"] == "trilinear":
kwargs = kwargs.copy()
kwargs["mode"] = "bilinear"
# permuting because torch.nn.functional.grid_sample works with
# (features, depth, height, width) and not (features, width, height, depth)
source = source.permute(0, 1, 4, 3, 2)
grid = points[:, :, None, None, :]
out = F.grid_sample(
grid=grid,
input=source,
**kwargs,
)
return out[:, :, :, 0, 0].permute(0, 2, 1)
[docs]
def get_rays_points_world(
ray_bundle: Optional[ImplicitronRayBundle] = None,
rays_points_world: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Converts the ray_bundle to rays_points_world if rays_points_world is not defined
and raises error if both are defined.
Args:
ray_bundle: An ImplicitronRayBundle object or None
rays_points_world: A torch.Tensor representing ray points converted to
world coordinates
Returns:
A torch.Tensor representing ray points converted to world coordinates
of shape [minibatch x ... x pts_per_ray x 3].
"""
if rays_points_world is not None and ray_bundle is not None:
raise ValueError(
"Cannot define both rays_points_world and ray_bundle,"
+ " one has to be None."
)
if rays_points_world is not None:
return rays_points_world
if ray_bundle is not None:
# pyre-ignore[6]
return ray_bundle_to_ray_points(ray_bundle)
raise ValueError("ray_bundle and rays_points_world cannot both be None")