# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
import warnings
from typing import Optional, Tuple, Union
import torch
from pytorch3d.common.compat import meshgrid_ij
from pytorch3d.ops import padded_to_packed
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit.utils import HeterogeneousRayBundle, RayBundle
from torch.nn import functional as F
"""
This file defines three raysampling techniques:
- MultinomialRaysampler which can be used to sample rays from pixels of an image grid
- NDCMultinomialRaysampler which can be used to sample rays from pixels of an image grid,
which follows the pytorch3d convention for image grid coordinates
- MonteCarloRaysampler which randomly selects real-valued locations in the image plane
and emits rays from them
"""
[docs]
class MultinomialRaysampler(torch.nn.Module):
"""
Samples a fixed number of points along rays which are regularly distributed
in a batch of rectangular image grids. Points along each ray
have uniformly-spaced z-coordinates between a predefined
minimum and maximum depth.
The raysampler first generates a 3D coordinate grid of the following form::
/ min_x, min_y, max_depth -------------- / max_x, min_y, max_depth
/ /|
/ / | ^
/ min_depth min_depth / | |
min_x ----------------------------- max_x | | image
min_y min_y | | height
| | | |
| | | v
| | |
| | / max_x, max_y, ^
| | / max_depth /
min_x max_y / / n_pts_per_ray
max_y ----------------------------- max_x/ min_depth v
< --- image_width --- >
In order to generate ray points, `MultinomialRaysampler` takes each 3D point of
the grid (with coordinates `[x, y, depth]`) and unprojects it
with `cameras.unproject_points([x, y, depth])`, where `cameras` are an
additional input to the `forward` function.
Note that this is a generic implementation that can support any image grid
coordinate convention. For a raysampler which follows the PyTorch3D
coordinate conventions please refer to `NDCMultinomialRaysampler`.
As such, `NDCMultinomialRaysampler` is a special case of `MultinomialRaysampler`.
Attributes:
min_x: The leftmost x-coordinate of each ray's source pixel's center.
max_x: The rightmost x-coordinate of each ray's source pixel's center.
min_y: The topmost y-coordinate of each ray's source pixel's center.
max_y: The bottommost y-coordinate of each ray's source pixel's center.
"""
[docs]
def __init__(
self,
*,
min_x: float,
max_x: float,
min_y: float,
max_y: float,
image_width: int,
image_height: int,
n_pts_per_ray: int,
min_depth: float,
max_depth: float,
n_rays_per_image: Optional[int] = None,
n_rays_total: Optional[int] = None,
unit_directions: bool = False,
stratified_sampling: bool = False,
) -> None:
"""
Args:
min_x: The leftmost x-coordinate of each ray's source pixel's center.
max_x: The rightmost x-coordinate of each ray's source pixel's center.
min_y: The topmost y-coordinate of each ray's source pixel's center.
max_y: The bottommost y-coordinate of each ray's source pixel's center.
image_width: The horizontal size of the image grid.
image_height: The vertical size of the image grid.
n_pts_per_ray: The number of points sampled along each ray.
min_depth: The minimum depth of a ray-point.
max_depth: The maximum depth of a ray-point.
n_rays_per_image: If given, this amount of rays are sampled from the grid.
`n_rays_per_image` and `n_rays_total` cannot both be defined.
n_rays_total: How many rays in total to sample from the cameras provided. The result
is as if `n_rays_total_training` cameras were sampled with replacement from the
cameras provided and for every camera one ray was sampled. If set returns the
HeterogeneousRayBundle with batch_size=n_rays_total.
`n_rays_per_image` and `n_rays_total` cannot both be defined.
unit_directions: whether to normalize direction vectors in ray bundle.
stratified_sampling: if True, performs stratified random sampling
along the ray; otherwise takes ray points at deterministic offsets.
"""
super().__init__()
self._n_pts_per_ray = n_pts_per_ray
self._min_depth = min_depth
self._max_depth = max_depth
self._n_rays_per_image = n_rays_per_image
self._n_rays_total = n_rays_total
self._unit_directions = unit_directions
self._stratified_sampling = stratified_sampling
self.min_x, self.max_x = min_x, max_x
self.min_y, self.max_y = min_y, max_y
# get the initial grid of image xy coords
y, x = meshgrid_ij(
torch.linspace(min_y, max_y, image_height, dtype=torch.float32),
torch.linspace(min_x, max_x, image_width, dtype=torch.float32),
)
_xy_grid = torch.stack([x, y], dim=-1)
self.register_buffer("_xy_grid", _xy_grid, persistent=False)
[docs]
def forward(
self,
cameras: CamerasBase,
*,
mask: Optional[torch.Tensor] = None,
min_depth: Optional[float] = None,
max_depth: Optional[float] = None,
n_rays_per_image: Optional[int] = None,
n_pts_per_ray: Optional[int] = None,
stratified_sampling: Optional[bool] = None,
n_rays_total: Optional[int] = None,
**kwargs,
) -> Union[RayBundle, HeterogeneousRayBundle]:
"""
Args:
cameras: A batch of `batch_size` cameras from which the rays are emitted.
mask: if given, the rays are sampled from the mask. Should be of size
(batch_size, image_height, image_width).
min_depth: The minimum depth of a ray-point.
max_depth: The maximum depth of a ray-point.
n_rays_per_image: If given, this amount of rays are sampled from the grid.
`n_rays_per_image` and `n_rays_total` cannot both be defined.
n_pts_per_ray: The number of points sampled along each ray.
stratified_sampling: if set, overrides stratified_sampling provided
in __init__.
n_rays_total: How many rays in total to sample from the cameras provided. The result
is as if `n_rays_total_training` cameras were sampled with replacement from the
cameras provided and for every camera one ray was sampled. If set returns the
HeterogeneousRayBundle with batch_size=n_rays_total.
`n_rays_per_image` and `n_rays_total` cannot both be defined.
Returns:
A named tuple RayBundle or dataclass HeterogeneousRayBundle with the
following fields:
origins: A tensor of shape
`(batch_size, s1, s2, 3)`
denoting the locations of ray origins in the world coordinates.
directions: A tensor of shape
`(batch_size, s1, s2, 3)`
denoting the directions of each ray in the world coordinates.
lengths: A tensor of shape
`(batch_size, s1, s2, n_pts_per_ray)`
containing the z-coordinate (=depth) of each ray in world units.
xys: A tensor of shape
`(batch_size, s1, s2, 2)`
containing the 2D image coordinates of each ray or,
if mask is given, `(batch_size, n, 1, 2)`
Here `s1, s2` refer to spatial dimensions.
`(s1, s2)` refer to (highest priority first):
- `(1, 1)` if `n_rays_total` is provided, (batch_size=n_rays_total)
- `(n_rays_per_image, 1) if `n_rays_per_image` if provided,
- `(n, 1)` where n is the minimum cardinality of the mask
in the batch if `mask` is provided
- `(image_height, image_width)` if nothing from above is satisfied
`HeterogeneousRayBundle` has additional members:
- camera_ids: tensor of shape (M,), where `M` is the number of unique sampled
cameras. It represents unique ids of sampled cameras.
- camera_counts: tensor of shape (M,), where `M` is the number of unique sampled
cameras. Represents how many times each camera from `camera_ids` was sampled
`HeterogeneousRayBundle` is returned if `n_rays_total` is provided else `RayBundle`
is returned.
"""
n_rays_total = n_rays_total or self._n_rays_total
n_rays_per_image = n_rays_per_image or self._n_rays_per_image
if (n_rays_total is not None) and (n_rays_per_image is not None):
raise ValueError(
"`n_rays_total` and `n_rays_per_image` cannot both be defined."
)
if n_rays_total:
(
cameras,
mask,
camera_ids, # unique ids of sampled cameras
camera_counts, # number of times unique camera id was sampled
# `n_rays_per_image` is equal to the max number of times a simgle camera
# was sampled. We sample all cameras at `camera_ids` `n_rays_per_image` times
# and then discard the unneeded rays.
# pyre-ignore[9]
n_rays_per_image,
) = _sample_cameras_and_masks(n_rays_total, cameras, mask)
else:
# pyre-ignore[9]
camera_ids: torch.LongTensor = torch.arange(len(cameras), dtype=torch.long)
batch_size = cameras.R.shape[0]
device = cameras.device
# expand the (H, W, 2) grid batch_size-times to (B, H, W, 2)
xy_grid = self._xy_grid.to(device).expand(batch_size, -1, -1, -1)
if mask is not None and n_rays_per_image is None:
# if num rays not given, sample according to the smallest mask
n_rays_per_image = (
n_rays_per_image or mask.sum(dim=(1, 2)).min().int().item()
)
if n_rays_per_image is not None:
if mask is not None:
assert mask.shape == xy_grid.shape[:3]
weights = mask.reshape(batch_size, -1)
else:
# it is probably more efficient to use torch.randperm
# for uniform weights but it is unlikely given that randperm
# is not batched and does not support partial permutation
_, width, height, _ = xy_grid.shape
weights = xy_grid.new_ones(batch_size, width * height)
# pyre-fixme[6]: For 2nd param expected `int` but got `Union[bool,
# float, int]`.
rays_idx = _safe_multinomial(weights, n_rays_per_image)[..., None].expand(
-1, -1, 2
)
xy_grid = torch.gather(xy_grid.reshape(batch_size, -1, 2), 1, rays_idx)[
:, :, None
]
min_depth = min_depth if min_depth is not None else self._min_depth
max_depth = max_depth if max_depth is not None else self._max_depth
n_pts_per_ray = (
n_pts_per_ray if n_pts_per_ray is not None else self._n_pts_per_ray
)
stratified_sampling = (
stratified_sampling
if stratified_sampling is not None
else self._stratified_sampling
)
ray_bundle = _xy_to_ray_bundle(
cameras,
xy_grid,
min_depth,
max_depth,
n_pts_per_ray,
self._unit_directions,
stratified_sampling,
)
return (
# pyre-ignore[61]
_pack_ray_bundle(ray_bundle, camera_ids, camera_counts)
if n_rays_total
else ray_bundle
)
[docs]
class NDCMultinomialRaysampler(MultinomialRaysampler):
"""
Samples a fixed number of points along rays which are regularly distributed
in a batch of rectangular image grids. Points along each ray
have uniformly-spaced z-coordinates between a predefined minimum and maximum depth.
`NDCMultinomialRaysampler` follows the screen conventions of the `Meshes` and `Pointclouds`
renderers. I.e. the pixel coordinates are in [-1, 1]x[-u, u] or [-u, u]x[-1, 1]
where u > 1 is the aspect ratio of the image.
For the description of arguments, see the documentation to MultinomialRaysampler.
"""
def __init__(
self,
*,
image_width: int,
image_height: int,
n_pts_per_ray: int,
min_depth: float,
max_depth: float,
n_rays_per_image: Optional[int] = None,
n_rays_total: Optional[int] = None,
unit_directions: bool = False,
stratified_sampling: bool = False,
) -> None:
if image_width >= image_height:
range_x = image_width / image_height
range_y = 1.0
else:
range_x = 1.0
range_y = image_height / image_width
half_pix_width = range_x / image_width
half_pix_height = range_y / image_height
super().__init__(
min_x=range_x - half_pix_width,
max_x=-range_x + half_pix_width,
min_y=range_y - half_pix_height,
max_y=-range_y + half_pix_height,
image_width=image_width,
image_height=image_height,
n_pts_per_ray=n_pts_per_ray,
min_depth=min_depth,
max_depth=max_depth,
n_rays_per_image=n_rays_per_image,
n_rays_total=n_rays_total,
unit_directions=unit_directions,
stratified_sampling=stratified_sampling,
)
[docs]
class MonteCarloRaysampler(torch.nn.Module):
"""
Samples a fixed number of pixels within denoted xy bounds uniformly at random.
For each pixel, a fixed number of points is sampled along its ray at uniformly-spaced
z-coordinates such that the z-coordinates range between a predefined minimum
and maximum depth.
For practical purposes, this is similar to MultinomialRaysampler without a mask,
however sampling at real-valued locations bypassing replacement checks may be faster.
"""
[docs]
def __init__(
self,
min_x: float,
max_x: float,
min_y: float,
max_y: float,
n_rays_per_image: int,
n_pts_per_ray: int,
min_depth: float,
max_depth: float,
*,
n_rays_total: Optional[int] = None,
unit_directions: bool = False,
stratified_sampling: bool = False,
) -> None:
"""
Args:
min_x: The smallest x-coordinate of each ray's source pixel.
max_x: The largest x-coordinate of each ray's source pixel.
min_y: The smallest y-coordinate of each ray's source pixel.
max_y: The largest y-coordinate of each ray's source pixel.
n_rays_per_image: The number of rays randomly sampled in each camera.
`n_rays_per_image` and `n_rays_total` cannot both be defined.
n_pts_per_ray: The number of points sampled along each ray.
min_depth: The minimum depth of each ray-point.
max_depth: The maximum depth of each ray-point.
n_rays_total: How many rays in total to sample from the cameras provided. The result
is as if `n_rays_total_training` cameras were sampled with replacement from the
cameras provided and for every camera one ray was sampled. If set returns the
HeterogeneousRayBundle with batch_size=n_rays_total.
`n_rays_per_image` and `n_rays_total` cannot both be defined.
unit_directions: whether to normalize direction vectors in ray bundle.
stratified_sampling: if True, performs stratified sampling in n_pts_per_ray
bins for each ray; otherwise takes n_pts_per_ray deterministic points
on each ray with uniform offsets.
"""
super().__init__()
self._min_x = min_x
self._max_x = max_x
self._min_y = min_y
self._max_y = max_y
self._n_rays_per_image = n_rays_per_image
self._n_pts_per_ray = n_pts_per_ray
self._min_depth = min_depth
self._max_depth = max_depth
self._n_rays_total = n_rays_total
self._unit_directions = unit_directions
self._stratified_sampling = stratified_sampling
[docs]
def forward(
self,
cameras: CamerasBase,
*,
stratified_sampling: Optional[bool] = None,
**kwargs,
) -> Union[RayBundle, HeterogeneousRayBundle]:
"""
Args:
cameras: A batch of `batch_size` cameras from which the rays are emitted.
stratified_sampling: if set, overrides stratified_sampling provided
in __init__.
Returns:
A named tuple `RayBundle` or dataclass `HeterogeneousRayBundle` with the
following fields:
origins: A tensor of shape
`(batch_size, n_rays_per_image, 3)`
denoting the locations of ray origins in the world coordinates.
directions: A tensor of shape
`(batch_size, n_rays_per_image, 3)`
denoting the directions of each ray in the world coordinates.
lengths: A tensor of shape
`(batch_size, n_rays_per_image, n_pts_per_ray)`
containing the z-coordinate (=depth) of each ray in world units.
xys: A tensor of shape
`(batch_size, n_rays_per_image, 2)`
containing the 2D image coordinates of each ray.
If `n_rays_total` is provided `batch_size=n_rays_total`and
`n_rays_per_image=1` and `HeterogeneousRayBundle` is returned else `RayBundle`
is returned.
`HeterogeneousRayBundle` has additional members:
- camera_ids: tensor of shape (M,), where `M` is the number of unique sampled
cameras. It represents unique ids of sampled cameras.
- camera_counts: tensor of shape (M,), where `M` is the number of unique sampled
cameras. Represents how many times each camera from `camera_ids` was sampled
"""
if (
sum(x is not None for x in [self._n_rays_total, self._n_rays_per_image])
!= 1
):
raise ValueError(
"Exactly one of `self.n_rays_total` and `self.n_rays_per_image` "
"must be given."
)
if self._n_rays_total:
(
cameras,
_,
camera_ids,
camera_counts,
n_rays_per_image,
) = _sample_cameras_and_masks(self._n_rays_total, cameras, None)
else:
# pyre-ignore[9]
camera_ids: torch.LongTensor = torch.arange(len(cameras), dtype=torch.long)
n_rays_per_image = self._n_rays_per_image
batch_size = cameras.R.shape[0]
device = cameras.device
# get the initial grid of image xy coords
# of shape (batch_size, n_rays_per_image, 2)
rays_xy = torch.cat(
[
torch.rand(
size=(batch_size, n_rays_per_image, 1),
dtype=torch.float32,
device=device,
)
* (high - low)
+ low
for low, high in (
(self._min_x, self._max_x),
(self._min_y, self._max_y),
)
],
dim=2,
)
stratified_sampling = (
stratified_sampling
if stratified_sampling is not None
else self._stratified_sampling
)
ray_bundle = _xy_to_ray_bundle(
cameras,
rays_xy,
self._min_depth,
self._max_depth,
self._n_pts_per_ray,
self._unit_directions,
stratified_sampling,
)
return (
# pyre-ignore[61]
_pack_ray_bundle(ray_bundle, camera_ids, camera_counts)
if self._n_rays_total
else ray_bundle
)
# Settings for backwards compatibility
[docs]
def GridRaysampler(
min_x: float,
max_x: float,
min_y: float,
max_y: float,
image_width: int,
image_height: int,
n_pts_per_ray: int,
min_depth: float,
max_depth: float,
) -> "MultinomialRaysampler":
"""
GridRaysampler has been DEPRECATED. Use MultinomialRaysampler instead.
Preserving GridRaysampler for backward compatibility.
"""
warnings.warn(
"""GridRaysampler is deprecated,
Use MultinomialRaysampler instead.
GridRaysampler will be removed in future releases.""",
PendingDeprecationWarning,
)
return MultinomialRaysampler(
min_x=min_x,
max_x=max_x,
min_y=min_y,
max_y=max_y,
image_width=image_width,
image_height=image_height,
n_pts_per_ray=n_pts_per_ray,
min_depth=min_depth,
max_depth=max_depth,
)
# Settings for backwards compatibility
[docs]
def NDCGridRaysampler(
image_width: int,
image_height: int,
n_pts_per_ray: int,
min_depth: float,
max_depth: float,
) -> "NDCMultinomialRaysampler":
"""
NDCGridRaysampler has been DEPRECATED. Use NDCMultinomialRaysampler instead.
Preserving NDCGridRaysampler for backward compatibility.
"""
warnings.warn(
"""NDCGridRaysampler is deprecated,
Use NDCMultinomialRaysampler instead.
NDCGridRaysampler will be removed in future releases.""",
PendingDeprecationWarning,
)
return NDCMultinomialRaysampler(
image_width=image_width,
image_height=image_height,
n_pts_per_ray=n_pts_per_ray,
min_depth=min_depth,
max_depth=max_depth,
)
def _safe_multinomial(input: torch.Tensor, num_samples: int) -> torch.Tensor:
"""
Wrapper around torch.multinomial that attempts sampling without replacement
when possible, otherwise resorts to sampling with replacement.
Args:
input: tensor of shape [B, n] containing non-negative values;
rows are interpreted as unnormalized event probabilities
in categorical distributions.
num_samples: number of samples to take.
Returns:
LongTensor of shape [B, num_samples] containing
values from {0, ..., n - 1} where the elements [i, :] of row i make
(1) if there are num_samples or more non-zero values in input[i],
a random subset of the indices of those values, with
probabilities proportional to the values in input[i, :].
(2) if not, a random sample with replacement of the indices of
those values, with probabilities proportional to them.
This sample might not contain all the indices of the
non-zero values.
Behavior undetermined if there are no non-zero values in a whole row
or if there are negative values.
"""
try:
res = torch.multinomial(input, num_samples, replacement=False)
except RuntimeError:
# this is probably rare, so we don't mind sampling twice
res = torch.multinomial(input, num_samples, replacement=True)
no_repl = (input > 0.0).sum(dim=-1) >= num_samples
res[no_repl] = torch.multinomial(input[no_repl], num_samples, replacement=False)
return res
# in some versions of Pytorch, zero probabilty samples can be drawn without an error
# due to this bug: https://github.com/pytorch/pytorch/issues/50034. Handle this case:
repl = (input > 0.0).sum(dim=-1) < num_samples
if repl.any():
res[repl] = torch.multinomial(input[repl], num_samples, replacement=True)
return res
def _xy_to_ray_bundle(
cameras: CamerasBase,
xy_grid: torch.Tensor,
min_depth: float,
max_depth: float,
n_pts_per_ray: int,
unit_directions: bool,
stratified_sampling: bool = False,
) -> RayBundle:
"""
Extends the `xy_grid` input of shape `(batch_size, ..., 2)` to rays.
This adds to each xy location in the grid a vector of `n_pts_per_ray` depths
uniformly spaced between `min_depth` and `max_depth`.
The extended grid is then unprojected with `cameras` to yield
ray origins, directions and depths.
Args:
cameras: cameras object representing a batch of cameras.
xy_grid: torch.tensor grid of image xy coords.
min_depth: The minimum depth of each ray-point.
max_depth: The maximum depth of each ray-point.
n_pts_per_ray: The number of points sampled along each ray.
unit_directions: whether to normalize direction vectors in ray bundle.
stratified_sampling: if True, performs stratified sampling in n_pts_per_ray
bins for each ray; otherwise takes n_pts_per_ray deterministic points
on each ray with uniform offsets.
"""
batch_size = xy_grid.shape[0]
spatial_size = xy_grid.shape[1:-1]
n_rays_per_image = spatial_size.numel()
# ray z-coords
rays_zs = xy_grid.new_empty((0,))
if n_pts_per_ray > 0:
depths = torch.linspace(
min_depth,
max_depth,
n_pts_per_ray,
dtype=xy_grid.dtype,
device=xy_grid.device,
)
rays_zs = depths[None, None].expand(batch_size, n_rays_per_image, n_pts_per_ray)
if stratified_sampling:
rays_zs = _jiggle_within_stratas(rays_zs)
# make two sets of points at a constant depth=1 and 2
to_unproject = torch.cat(
(
xy_grid.view(batch_size, 1, n_rays_per_image, 2)
.expand(batch_size, 2, n_rays_per_image, 2)
.reshape(batch_size, n_rays_per_image * 2, 2),
torch.cat(
(
xy_grid.new_ones(batch_size, n_rays_per_image, 1),
2.0 * xy_grid.new_ones(batch_size, n_rays_per_image, 1),
),
dim=1,
),
),
dim=-1,
)
# unproject the points
unprojected = cameras.unproject_points(to_unproject, from_ndc=True)
# split the two planes back
rays_plane_1_world = unprojected[:, :n_rays_per_image]
rays_plane_2_world = unprojected[:, n_rays_per_image:]
# directions are the differences between the two planes of points
rays_directions_world = rays_plane_2_world - rays_plane_1_world
# origins are given by subtracting the ray directions from the first plane
rays_origins_world = rays_plane_1_world - rays_directions_world
if unit_directions:
rays_directions_world = F.normalize(rays_directions_world, dim=-1)
return RayBundle(
rays_origins_world.view(batch_size, *spatial_size, 3),
rays_directions_world.view(batch_size, *spatial_size, 3),
rays_zs.view(batch_size, *spatial_size, n_pts_per_ray),
xy_grid,
)
def _jiggle_within_stratas(bin_centers: torch.Tensor) -> torch.Tensor:
"""
Performs sampling of 1 point per bin given the bin centers.
More specifically, it replaces each point's value `z`
with a sample from a uniform random distribution on
`[z - delta_-, z + delta_+]`, where `delta_-` is half of the difference
between `z` and the previous point, and `delta_+` is half of the difference
between the next point and `z`. For the first and last items, the
corresponding boundary deltas are assumed zero.
Args:
`bin_centers`: The input points of size (..., N); the result is broadcast
along all but the last dimension (the rows). Each row should be
sorted in ascending order.
Returns:
a tensor of size (..., N) with the locations jiggled within stratas/bins.
"""
# Get intervals between bin centers.
mids = 0.5 * (bin_centers[..., 1:] + bin_centers[..., :-1])
upper = torch.cat((mids, bin_centers[..., -1:]), dim=-1)
lower = torch.cat((bin_centers[..., :1], mids), dim=-1)
# Samples in those intervals.
jiggled = lower + (upper - lower) * torch.rand_like(lower)
return jiggled
def _sample_cameras_and_masks(
n_samples: int, cameras: CamerasBase, mask: Optional[torch.Tensor] = None
) -> Tuple[
CamerasBase,
Optional[torch.Tensor],
torch.LongTensor,
torch.LongTensor,
torch.LongTensor,
]:
"""
Samples n_rays_total cameras and masks and returns them in a form
(camera_idx, count), where count represents number of times the same camera
has been sampled.
Args:
n_samples: how many camera and mask pairs to sample
cameras: A batch of `batch_size` cameras from which the rays are emitted.
mask: Optional. Should be of size (batch_size, image_height, image_width).
Returns:
tuple of a form (sampled_cameras, sampled_masks, unique_sampled_camera_ids,
number_of_times_each_sampled_camera_has_been_sampled,
max_number_of_times_camera_has_been_sampled,
)
"""
sampled_ids = torch.randint(
0,
len(cameras),
size=(n_samples,),
dtype=torch.long,
)
unique_ids, counts = torch.unique(sampled_ids, return_counts=True)
# pyre-ignore[7]
return (
cameras[unique_ids],
mask[unique_ids] if mask is not None else None,
unique_ids,
counts,
torch.max(counts),
)
# TODO: this function can be unified with ImplicitronRayBundle.get_padded_xys
def _pack_ray_bundle(
ray_bundle: RayBundle, camera_ids: torch.LongTensor, camera_counts: torch.LongTensor
) -> HeterogeneousRayBundle:
"""
Pack the raybundle from [n_cameras, max(rays_per_camera), ...] to
[total_num_rays, 1, ...]
Args:
ray_bundle: A ray_bundle to pack
camera_ids: Unique ids of cameras that were sampled
camera_counts: how many of which camera to pack, each count coresponds to
one 'row' of the ray_bundle and says how many rays wll be taken
from it and packed.
Returns:
HeterogeneousRayBundle where batch_size=sum(camera_counts) and n_rays_per_image=1
"""
# pyre-ignore[9]
camera_counts = camera_counts.to(ray_bundle.origins.device)
cumsum = torch.cumsum(camera_counts, dim=0, dtype=torch.long)
# pyre-ignore[9]
first_idxs: torch.LongTensor = torch.cat(
(camera_counts.new_zeros((1,), dtype=torch.long), cumsum[:-1])
)
num_inputs = int(camera_counts.sum())
return HeterogeneousRayBundle(
origins=padded_to_packed(ray_bundle.origins, first_idxs, num_inputs)[:, None],
directions=padded_to_packed(ray_bundle.directions, first_idxs, num_inputs)[
:, None
],
lengths=padded_to_packed(ray_bundle.lengths, first_idxs, num_inputs)[:, None],
xys=padded_to_packed(ray_bundle.xys, first_idxs, num_inputs)[:, None],
camera_ids=camera_ids,
camera_counts=camera_counts,
)