# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
from typing import List, Optional, Tuple, Union
import torch
from pytorch3d.common.compat import meshgrid_ij
from pytorch3d.common.datatypes import Device, make_device
from pytorch3d.transforms import Scale, Transform3d
from . import utils as struct_utils
_Scalar = Union[int, float]
_Vector = Union[torch.Tensor, Tuple[_Scalar, ...], List[_Scalar]]
_ScalarOrVector = Union[_Scalar, _Vector]
_VoxelSize = _ScalarOrVector
_Translation = _Vector
_TensorBatch = Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
_ALL_CONTENT: slice = slice(0, None)
[docs]class Volumes:
"""
This class provides functions for working with batches of volumetric grids
of possibly varying spatial sizes.
VOLUME DENSITIES
The Volumes class can be either constructed from a 5D tensor of
`densities` of size `batch x density_dim x depth x height x width` or
from a list of differently-sized 4D tensors `[D_1, ..., D_batch]`,
where each `D_i` is of size `[density_dim x depth_i x height_i x width_i]`.
In case the `Volumes` object is initialized from the list of `densities`,
the list of tensors is internally converted to a single 5D tensor by
zero-padding the relevant dimensions. Both list and padded representations can be
accessed with the `Volumes.densities()` or `Volumes.densities_list()` getters.
The sizes of the individual volumes in the structure can be retrieved
with the `Volumes.get_grid_sizes()` getter.
The `Volumes` class is immutable. I.e. after generating a `Volumes` object,
one cannot change its properties, such as `self._densities` or `self._features`
anymore.
VOLUME FEATURES
While the `densities` field is intended to represent various measures of the
"density" of the volume cells (opacity, signed/unsigned distances
from the nearest surface, ...), one can additionally initialize the
object with the `features` argument. `features` are either a 5D tensor
of shape `batch x feature_dim x depth x height x width` or a list of
of differently-sized 4D tensors `[F_1, ..., F_batch]`,
where each `F_i` is of size `[feature_dim x depth_i x height_i x width_i]`.
`features` are intended to describe other properties of volume cells,
such as per-voxel 3D vectors of RGB colors that can be later used
for rendering the volume.
VOLUME COORDINATES
Additionally, using the `VolumeLocator` class the `Volumes` class keeps track
of the locations of the centers of the volume cells in the local volume
coordinates as well as in the world coordinates.
Local coordinates:
- Represent the locations of the volume cells in the local coordinate
frame of the volume.
- The center of the voxel indexed with `[·, ·, 0, 0, 0]` in the volume
has its 3D local coordinate set to `[-1, -1, -1]`, while the voxel
at index `[·, ·, depth_i-1, height_i-1, width_i-1]` has its
3D local coordinate set to `[1, 1, 1]`.
- The first/second/third coordinate of each of the 3D per-voxel
XYZ vector denotes the horizontal/vertical/depth-wise position
respectively. I.e the order of the coordinate dimensions in the
volume is reversed w.r.t. the order of the 3D coordinate vectors.
- The intermediate coordinates between `[-1, -1, -1]` and `[1, 1, 1]`.
are linearly interpolated over the spatial dimensions of the volume.
- Note that the convention is the same as for the 5D version of the
`torch.nn.functional.grid_sample` function called with
`align_corners==True`.
- Note that the local coordinate convention of `Volumes`
(+X = left to right, +Y = top to bottom, +Z = away from the user)
is *different* from the world coordinate convention of the
renderer for `Meshes` or `Pointclouds`
(+X = right to left, +Y = bottom to top, +Z = away from the user).
World coordinates:
- These define the locations of the centers of the volume cells
in the world coordinates.
- They are specified with the following mapping that converts
points `x_local` in the local coordinates to points `x_world`
in the world coordinates::
x_world = (
x_local * (volume_size - 1) * 0.5 * voxel_size
) - volume_translation,
here `voxel_size` specifies the size of each voxel of the volume,
and `volume_translation` is the 3D offset of the central voxel of
the volume w.r.t. the origin of the world coordinate frame.
Both `voxel_size` and `volume_translation` are specified in
the world coordinate units. `volume_size` is the spatial size of
the volume in form of a 3D vector `[width, height, depth]`.
- Given the above definition of `x_world`, one can derive the
inverse mapping from `x_world` to `x_local` as follows::
x_local = (
(x_world + volume_translation) / (0.5 * voxel_size)
) / (volume_size - 1)
- For a trivial volume with `volume_translation==[0, 0, 0]`
with `voxel_size=-1`, `x_world` would range
from -(volume_size-1)/2` to `+(volume_size-1)/2`.
Coordinate tensors that denote the locations of each of the volume cells in
local / world coordinates (with shape `(depth x height x width x 3)`)
can be retrieved by calling the `Volumes.get_coord_grid()` getter with the
appropriate `world_coordinates` argument.
Internally, the mapping between `x_local` and `x_world` is represented
as a `Transform3d` object `Volumes.VolumeLocator._local_to_world_transform`.
Users can access the relevant transformations with the
`Volumes.get_world_to_local_coords_transform()` and
`Volumes.get_local_to_world_coords_transform()`
functions.
Example coordinate conversion:
- For a "trivial" volume with `voxel_size = 1.`,
`volume_translation=[0., 0., 0.]`, and the spatial size of
`DxHxW = 5x5x5`, the point `x_world = (-2, 0, 2)` gets mapped
to `x_local=(-1, 0, 1)`.
- For a "trivial" volume `v` with `voxel_size = 1.`,
`volume_translation=[0., 0., 0.]`, the following holds:
torch.nn.functional.grid_sample(
v.densities(),
v.get_coord_grid(world_coordinates=False),
align_corners=True,
) == v.densities(),
i.e. sampling the volume at trivial local coordinates
(no scaling with `voxel_size`` or shift with `volume_translation`)
results in the same volume.
"""
[docs] def __init__(
self,
densities: _TensorBatch,
features: Optional[_TensorBatch] = None,
voxel_size: _VoxelSize = 1.0,
volume_translation: _Translation = (0.0, 0.0, 0.0),
) -> None:
"""
Args:
**densities**: Batch of input feature volume occupancies of shape
`(minibatch, density_dim, depth, height, width)`, or a list
of 4D tensors `[D_1, ..., D_minibatch]` where each `D_i` has
shape `(density_dim, depth_i, height_i, width_i)`.
Typically, each voxel contains a non-negative number
corresponding to its opaqueness.
**features**: Batch of input feature volumes of shape:
`(minibatch, feature_dim, depth, height, width)` or a list
of 4D tensors `[F_1, ..., F_minibatch]` where each `F_i` has
shape `(feature_dim, depth_i, height_i, width_i)`.
The field is optional and can be set to `None` in case features are
not required.
**voxel_size**: Denotes the size of each volume voxel in world units.
Has to be one of:
a) A scalar (square voxels)
b) 3-tuple or a 3-list of scalars
c) a Tensor of shape (3,)
d) a Tensor of shape (minibatch, 3)
e) a Tensor of shape (minibatch, 1)
f) a Tensor of shape (1,) (square voxels)
**volume_translation**: Denotes the 3D translation of the center
of the volume in world units. Has to be one of:
a) 3-tuple or a 3-list of scalars
b) a Tensor of shape (3,)
c) a Tensor of shape (minibatch, 3)
d) a Tensor of shape (1,) (square voxels)
"""
# handle densities
densities_, grid_sizes = self._convert_densities_features_to_tensor(
densities, "densities"
)
# take device from densities
self.device = densities_.device
# assign to the internal buffers
self._densities = densities_
# assign a coordinate transformation member
self.locator = VolumeLocator(
batch_size=len(self),
grid_sizes=grid_sizes,
voxel_size=voxel_size,
volume_translation=volume_translation,
device=self.device,
)
# handle features
self._features = None
if features is not None:
self._set_features(features)
def _convert_densities_features_to_tensor(
self, x: _TensorBatch, var_name: str
) -> Tuple[torch.Tensor, torch.LongTensor]:
"""
Handle the `densities` or `features` arguments to the constructor.
"""
if isinstance(x, (list, tuple)):
x_tensor = struct_utils.list_to_padded(x)
if any(x_.ndim != 4 for x_ in x):
raise ValueError(
f"`{var_name}` has to be a list of 4-dim tensors of shape: "
f"({var_name}_dim, height, width, depth)"
)
if any(x_.shape[0] != x[0].shape[0] for x_ in x):
raise ValueError(
f"Each entry in the list of `{var_name}` has to have the "
"same number of channels (first dimension in the tensor)."
)
x_shapes = torch.stack(
[
torch.tensor(
list(x_.shape[1:]), dtype=torch.long, device=x_tensor.device
)
for x_ in x
],
dim=0,
)
elif torch.is_tensor(x):
if x.ndim != 5:
raise ValueError(
f"`{var_name}` has to be a 5-dim tensor of shape: "
f"(minibatch, {var_name}_dim, height, width, depth)"
)
x_tensor = x
x_shapes = torch.tensor(
list(x.shape[2:]), dtype=torch.long, device=x.device
)[None].repeat(x.shape[0], 1)
else:
raise ValueError(
f"{var_name} must be either a list or a tensor with "
f"shape (batch_size, {var_name}_dim, H, W, D)."
)
# pyre-ignore[7]
return x_tensor, x_shapes
def __len__(self) -> int:
return self._densities.shape[0]
[docs] def __getitem__(
self,
index: Union[
int, List[int], Tuple[int], slice, torch.BoolTensor, torch.LongTensor
],
) -> "Volumes":
"""
Args:
index: Specifying the index of the volume to retrieve.
Can be an int, slice, list of ints or a boolean or a long tensor.
Returns:
Volumes object with selected volumes. The tensors are not cloned.
"""
if isinstance(index, int):
index = torch.LongTensor([index])
elif isinstance(index, (slice, list, tuple)):
pass
elif torch.is_tensor(index):
if index.dim() != 1 or index.dtype.is_floating_point:
raise IndexError(index)
else:
raise IndexError(index)
new = self.__class__(
# pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
features=self.features()[index] if self._features is not None else None,
densities=self.densities()[index],
)
# dont forget to update grid_sizes!
self.locator._copy_transform_and_sizes(new.locator, index=index)
return new
[docs] def features(self) -> Optional[torch.Tensor]:
"""
Returns the features of the volume.
Returns:
**features**: The tensor of volume features.
"""
return self._features
[docs] def densities(self) -> torch.Tensor:
"""
Returns the densities of the volume.
Returns:
**densities**: The tensor of volume densities.
"""
return self._densities
[docs] def densities_list(self) -> List[torch.Tensor]:
"""
Get the list representation of the densities.
Returns:
list of tensors of densities of shape (dim_i, D_i, H_i, W_i).
"""
return self._features_densities_list(self.densities())
[docs] def features_list(self) -> List[torch.Tensor]:
"""
Get the list representation of the features.
Returns:
list of tensors of features of shape (dim_i, D_i, H_i, W_i)
or `None` for feature-less volumes.
"""
features_ = self.features()
if features_ is None:
# No features provided so return None
# pyre-fixme[7]: Expected `List[torch.Tensor]` but got `None`.
return None
return self._features_densities_list(features_)
def _features_densities_list(self, x: torch.Tensor) -> List[torch.Tensor]:
"""
Retrieve the list representation of features/densities.
Args:
x: self.features() or self.densities()
Returns:
list of tensors of features/densities of shape (dim_i, D_i, H_i, W_i).
"""
x_dim = x.shape[1]
pad_sizes = torch.nn.functional.pad(
self.get_grid_sizes(), [1, 0], mode="constant", value=x_dim
)
x_list = struct_utils.padded_to_list(x, pad_sizes.tolist())
return x_list
[docs] def update_padded(
self, new_densities: torch.Tensor, new_features: Optional[torch.Tensor] = None
) -> "Volumes":
"""
Returns a Volumes structure with updated padded tensors and copies of
the auxiliary tensors `self._local_to_world_transform`,
`device` and `self._grid_sizes`. This function allows for an update of
densities (and features) without having to explicitly
convert it to the list representation for heterogeneous batches.
Args:
new_densities: FloatTensor of shape (N, dim_density, D, H, W)
new_features: (optional) FloatTensor of shape (N, dim_feature, D, H, W)
Returns:
Volumes with updated features and densities
"""
new = copy.copy(self)
new._set_densities(new_densities)
if new_features is None:
new._features = None
else:
new._set_features(new_features)
return new
def _set_features(self, features: _TensorBatch) -> None:
self._set_densities_features("features", features)
def _set_densities(self, densities: _TensorBatch) -> None:
self._set_densities_features("densities", densities)
def _set_densities_features(self, var_name: str, x: _TensorBatch) -> None:
x_tensor, grid_sizes = self._convert_densities_features_to_tensor(x, var_name)
if x_tensor.device != self.device:
raise ValueError(
f"`{var_name}` have to be on the same device as `self.densities`."
)
if len(x_tensor.shape) != 5:
raise ValueError(
f"{var_name} has to be a 5-dim tensor of shape: "
f"(minibatch, {var_name}_dim, height, width, depth)"
)
if not (
(self.get_grid_sizes().shape == grid_sizes.shape)
and torch.allclose(self.get_grid_sizes(), grid_sizes)
):
raise ValueError(
f"The size of every grid in `{var_name}` has to match the size of"
"the corresponding `densities` grid."
)
setattr(self, "_" + var_name, x_tensor)
[docs] def clone(self) -> "Volumes":
"""
Deep copy of Volumes object. All internal tensors are cloned
individually.
Returns:
new Volumes object.
"""
return copy.deepcopy(self)
[docs] def to(self, device: Device, copy: bool = False) -> "Volumes":
"""
Match the functionality of torch.Tensor.to()
If copy = True or the self Tensor is on a different device, the
returned tensor is a copy of self with the desired torch.device.
If copy = False and the self Tensor already has the correct torch.device,
then self is returned.
Args:
device: Device (as str or torch.device) for the new tensor.
copy: Boolean indicator whether or not to clone self. Default False.
Returns:
Volumes object.
"""
device_ = make_device(device)
if not copy and self.device == device_:
return self
other = self.clone()
if self.device == device_:
return other
other.device = device_
other._densities = self._densities.to(device_)
if self._features is not None:
# pyre-fixme[16]: `Optional` has no attribute `to`.
other._features = self.features().to(device_)
self.locator._copy_transform_and_sizes(other.locator, device=device_)
other.locator = other.locator.to(device, copy)
return other
[docs] def cpu(self) -> "Volumes":
return self.to("cpu")
[docs] def cuda(self) -> "Volumes":
return self.to("cuda")
[docs] def get_grid_sizes(self) -> torch.LongTensor:
"""
Returns the sizes of individual volumetric grids in the structure.
Returns:
**grid_sizes**: Tensor of spatial sizes of each of the volumes
of size (batchsize, 3), where i-th row holds (D_i, H_i, W_i).
"""
return self.locator.get_grid_sizes()
[docs] def world_to_local_coords(self, points_3d_world: torch.Tensor) -> torch.Tensor:
"""
Convert a batch of 3D point coordinates `points_3d_world` of shape
(minibatch, ..., dim) in the world coordinates to
the local coordinate frame of the volume. Local volume
coordinates are scaled s.t. the coordinates along one side of the volume
are in range [-1, 1].
Args:
**points_3d_world**: A tensor of shape `(minibatch, ..., 3)`
containing the 3D coordinates of a set of points that will
be converted from the local volume coordinates (ranging
within [-1, 1]) to the world coordinates
defined by the `self.center` and `self.voxel_size` parameters.
Returns:
**points_3d_local**: `points_3d_world` converted to the local
volume coordinates of shape `(minibatch, ..., 3)`.
"""
return self.locator.world_to_local_coords(points_3d_world)
[docs] def local_to_world_coords(self, points_3d_local: torch.Tensor) -> torch.Tensor:
"""
Convert a batch of 3D point coordinates `points_3d_local` of shape
(minibatch, ..., dim) in the local coordinate frame of the volume
to the world coordinates.
Args:
**points_3d_local**: A tensor of shape `(minibatch, ..., 3)`
containing the 3D coordinates of a set of points that will
be converted from the local volume coordinates (ranging
within [-1, 1]) to the world coordinates
defined by the `self.center` and `self.voxel_size` parameters.
Returns:
**points_3d_world**: `points_3d_local` converted to the world
coordinates of the volume of shape `(minibatch, ..., 3)`.
"""
return self.locator.local_to_world_coords(points_3d_local)
[docs] def get_coord_grid(self, world_coordinates: bool = True) -> torch.Tensor:
"""
Return the 3D coordinate grid of the volumetric grid
in local (`world_coordinates=False`) or world coordinates
(`world_coordinates=True`).
The grid records location of each center of the corresponding volume voxel.
Local coordinates are scaled s.t. the values along one side of the
volume are in range [-1, 1].
Args:
**world_coordinates**: if `True`, the method
returns the grid in the world coordinates,
otherwise, in local coordinates.
Returns:
**coordinate_grid**: The grid of coordinates of shape
`(minibatch, depth, height, width, 3)`, where `minibatch`,
`height`, `width` and `depth` are the batch size, height, width
and depth of the volume `features` or `densities`.
"""
return self.locator.get_coord_grid(world_coordinates)
class VolumeLocator:
"""
The `VolumeLocator` class keeps track of the locations of the
centers of the volume cells in the local volume coordinates as well as in
the world coordinates for a voxel grid structure in 3D.
Local coordinates:
- Represent the locations of the volume cells in the local coordinate
frame of the volume.
- The center of the voxel indexed with `[·, ·, 0, 0, 0]` in the volume
has its 3D local coordinate set to `[-1, -1, -1]`, while the voxel
at index `[·, ·, depth_i-1, height_i-1, width_i-1]` has its
3D local coordinate set to `[1, 1, 1]`.
- The first/second/third coordinate of each of the 3D per-voxel
XYZ vector denotes the horizontal/vertical/depth-wise position
respectively. I.e the order of the coordinate dimensions in the
volume is reversed w.r.t. the order of the 3D coordinate vectors.
- The intermediate coordinates between `[-1, -1, -1]` and `[1, 1, 1]`.
are linearly interpolated over the spatial dimensions of the volume.
- Note that the convention is the same as for the 5D version of the
`torch.nn.functional.grid_sample` function called with
`align_corners==True`.
- Note that the local coordinate convention of `VolumeLocator`
(+X = left to right, +Y = top to bottom, +Z = away from the user)
is *different* from the world coordinate convention of the
renderer for `Meshes` or `Pointclouds`
(+X = right to left, +Y = bottom to top, +Z = away from the user).
World coordinates:
- These define the locations of the centers of the volume cells
in the world coordinates.
- They are specified with the following mapping that converts
points `x_local` in the local coordinates to points `x_world`
in the world coordinates::
x_world = (
x_local * (volume_size - 1) * 0.5 * voxel_size
) - volume_translation,
here `voxel_size` specifies the size of each voxel of the volume,
and `volume_translation` is the 3D offset of the central voxel of
the volume w.r.t. the origin of the world coordinate frame.
Both `voxel_size` and `volume_translation` are specified in
the world coordinate units. `volume_size` is the spatial size of
the volume in form of a 3D vector `[width, height, depth]`.
- Given the above definition of `x_world`, one can derive the
inverse mapping from `x_world` to `x_local` as follows::
x_local = (
(x_world + volume_translation) / (0.5 * voxel_size)
) / (volume_size - 1)
- For a trivial volume with `volume_translation==[0, 0, 0]`
with `voxel_size=-1`, `x_world` would range
from -(volume_size-1)/2` to `+(volume_size-1)/2`.
Coordinate tensors that denote the locations of each of the volume cells in
local / world coordinates (with shape `(depth x height x width x 3)`)
can be retrieved by calling the `VolumeLocator.get_coord_grid()` getter with the
appropriate `world_coordinates` argument.
Internally, the mapping between `x_local` and `x_world` is represented
as a `Transform3d` object `VolumeLocator._local_to_world_transform`.
Users can access the relevant transformations with the
`VolumeLocator.get_world_to_local_coords_transform()` and
`VolumeLocator.get_local_to_world_coords_transform()`
functions.
Example coordinate conversion:
- For a "trivial" volume with `voxel_size = 1.`,
`volume_translation=[0., 0., 0.]`, and the spatial size of
`DxHxW = 5x5x5`, the point `x_world = (-2, 0, 2)` gets mapped
to `x_local=(-1, 0, 1)`.
- For a "trivial" volume `v` with `voxel_size = 1.`,
`volume_translation=[0., 0., 0.]`, the following holds::
torch.nn.functional.grid_sample(
v.densities(),
v.get_coord_grid(world_coordinates=False),
align_corners=True,
) == v.densities(),
i.e. sampling the volume at trivial local coordinates
(no scaling with `voxel_size`` or shift with `volume_translation`)
results in the same volume.
"""
def __init__(
self,
batch_size: int,
grid_sizes: Union[
torch.LongTensor, Tuple[int, int, int], List[torch.LongTensor]
],
device: torch.device,
voxel_size: _VoxelSize = 1.0,
volume_translation: _Translation = (0.0, 0.0, 0.0),
):
"""
**batch_size** : Batch size of the underlying grids
**grid_sizes** : Represents the resolutions of different grids in the batch. Can be
a) tuple of form (H, W, D)
b) list/tuple of length batch_size of lists/tuples of form (H, W, D)
c) torch.Tensor of shape (batch_size, H, W, D)
H, W, D are height, width, depth respectively. If `grid_sizes` is a tuple than
all the grids in the batch have the same resolution.
**voxel_size**: Denotes the size of each volume voxel in world units.
Has to be one of:
a) A scalar (square voxels)
b) 3-tuple or a 3-list of scalars
c) a Tensor of shape (3,)
d) a Tensor of shape (minibatch, 3)
e) a Tensor of shape (minibatch, 1)
f) a Tensor of shape (1,) (square voxels)
**volume_translation**: Denotes the 3D translation of the center
of the volume in world units. Has to be one of:
a) 3-tuple or a 3-list of scalars
b) a Tensor of shape (3,)
c) a Tensor of shape (minibatch, 3)
d) a Tensor of shape (1,) (square voxels)
"""
self.device = device
self._batch_size = batch_size
self._grid_sizes = self._convert_grid_sizes2tensor(grid_sizes)
self._resolution = tuple(torch.max(self._grid_sizes.cpu(), dim=0).values)
# set the local_to_world transform
self._set_local_to_world_transform(
voxel_size=voxel_size, volume_translation=volume_translation
)
def _convert_grid_sizes2tensor(
self, x: Union[torch.LongTensor, List[torch.LongTensor], Tuple[int, int, int]]
) -> torch.LongTensor:
"""
Handle the grid_sizes argument to the constructor.
"""
if isinstance(x, (list, tuple)):
if isinstance(x[0], (torch.LongTensor, list, tuple)):
if self._batch_size != len(x):
raise ValueError("x should have a batch size of 'batch_size'")
# pyre-ignore[6]
if any(len(x_) != 3 for x_ in x):
raise ValueError(
"`grid_sizes` has to be a list of 3-dim tensors of shape: "
"(height, width, depth)"
)
x_shapes = torch.stack(
[
torch.tensor(
# pyre-ignore[6]
list(x_),
dtype=torch.long,
device=self.device,
)
for x_ in x
],
dim=0,
)
elif isinstance(x[0], int):
x_shapes = torch.stack(
[
torch.tensor(list(x), dtype=torch.long, device=self.device)
for _ in range(self._batch_size)
],
dim=0,
)
else:
raise ValueError(
"`grid_sizes` can be a list/tuple of int or torch.Tensor not of "
+ "{type(x[0])}."
)
elif torch.is_tensor(x):
if x.ndim != 2:
raise ValueError(
"`grid_sizes` has to be a 2-dim tensor of shape: (minibatch, 3)"
)
x_shapes = x.to(self.device)
else:
raise ValueError(
"grid_sizes must be either a list of tensors with shape (H, W, D), tensor with"
"shape (batch_size, H, W, D) or a tuple of (H, W, D)."
)
# pyre-ignore[7]
return x_shapes
def _voxel_size_translation_to_transform(
self,
voxel_size: torch.Tensor,
volume_translation: torch.Tensor,
batch_size: int,
) -> Transform3d:
"""
Converts the `voxel_size` and `volume_translation` constructor arguments
to the internal `Transform3d` object `local_to_world_transform`.
"""
volume_size_zyx = self.get_grid_sizes().float()
volume_size_xyz = volume_size_zyx[:, [2, 1, 0]]
# x_local = (
# (x_world + volume_translation) / (0.5 * voxel_size)
# ) / (volume_size - 1)
# x_world = (
# x_local * (volume_size - 1) * 0.5 * voxel_size
# ) - volume_translation
local_to_world_transform = Scale(
(volume_size_xyz - 1) * voxel_size * 0.5, device=self.device
).translate(-volume_translation)
return local_to_world_transform
def get_coord_grid(self, world_coordinates: bool = True) -> torch.Tensor:
"""
Return the 3D coordinate grid of the volumetric grid
in local (`world_coordinates=False`) or world coordinates
(`world_coordinates=True`).
The grid records location of each center of the corresponding volume voxel.
Local coordinates are scaled s.t. the values along one side of the
volume are in range [-1, 1].
Args:
**world_coordinates**: if `True`, the method
returns the grid in the world coordinates,
otherwise, in local coordinates.
Returns:
**coordinate_grid**: The grid of coordinates of shape
`(minibatch, depth, height, width, 3)`, where `minibatch`,
`height`, `width` and `depth` are the batch size, height, width
and depth of the volume `features` or `densities`.
"""
# TODO(dnovotny): Implement caching of the coordinate grid.
return self._calculate_coordinate_grid(world_coordinates=world_coordinates)
def _calculate_coordinate_grid(
self, world_coordinates: bool = True
) -> torch.Tensor:
"""
Calculate the 3D coordinate grid of the volumetric grid either
in local (`world_coordinates=False`) or
world coordinates (`world_coordinates=True`) .
"""
ba, (de, he, wi) = self._batch_size, self._resolution
grid_sizes = self.get_grid_sizes()
# generate coordinate axes
vol_axes = [
torch.linspace(-1.0, 1.0, r, dtype=torch.float32, device=self.device)
for r in (de, he, wi)
]
# generate per-coord meshgrids
Z, Y, X = meshgrid_ij(vol_axes)
# stack the coord grids ... this order matches the coordinate convention
# of torch.nn.grid_sample
vol_coords_local = torch.stack((X, Y, Z), dim=3)[None].repeat(ba, 1, 1, 1, 1)
# get grid sizes relative to the maximal volume size
grid_sizes_relative = (
torch.tensor([[de, he, wi]], device=grid_sizes.device, dtype=torch.float32)
- 1
) / (grid_sizes - 1).float()
if (grid_sizes_relative != 1.0).any():
# if any of the relative sizes != 1.0, adjust the grid
grid_sizes_relative_reshape = grid_sizes_relative[:, [2, 1, 0]][
:, None, None, None
]
vol_coords_local *= grid_sizes_relative_reshape
vol_coords_local += grid_sizes_relative_reshape - 1
if world_coordinates:
vol_coords = self.local_to_world_coords(vol_coords_local)
else:
vol_coords = vol_coords_local
return vol_coords
def get_local_to_world_coords_transform(self) -> Transform3d:
"""
Return a Transform3d object that converts points in the
the local coordinate frame of the volume to world coordinates.
Local volume coordinates are scaled s.t. the coordinates along one
side of the volume are in range [-1, 1].
Returns:
**local_to_world_transform**: A Transform3d object converting
points from local coordinates to the world coordinates.
"""
return self._local_to_world_transform
def get_world_to_local_coords_transform(self) -> Transform3d:
"""
Return a Transform3d object that converts points in the
world coordinates to the local coordinate frame of the volume.
Local volume coordinates are scaled s.t. the coordinates along one
side of the volume are in range [-1, 1].
Returns:
**world_to_local_transform**: A Transform3d object converting
points from world coordinates to local coordinates.
"""
return self.get_local_to_world_coords_transform().inverse()
def world_to_local_coords(self, points_3d_world: torch.Tensor) -> torch.Tensor:
"""
Convert a batch of 3D point coordinates `points_3d_world` of shape
(minibatch, ..., dim) in the world coordinates to
the local coordinate frame of the volume. Local volume
coordinates are scaled s.t. the coordinates along one side of the volume
are in range [-1, 1].
Args:
**points_3d_world**: A tensor of shape `(minibatch, ..., 3)`
containing the 3D coordinates of a set of points that will
be converted from the local volume coordinates (ranging
within [-1, 1]) to the world coordinates
defined by the `self.center` and `self.voxel_size` parameters.
Returns:
**points_3d_local**: `points_3d_world` converted to the local
volume coordinates of shape `(minibatch, ..., 3)`.
"""
pts_shape = points_3d_world.shape
return (
self.get_world_to_local_coords_transform()
.transform_points(points_3d_world.view(pts_shape[0], -1, 3))
.view(pts_shape)
)
def local_to_world_coords(self, points_3d_local: torch.Tensor) -> torch.Tensor:
"""
Convert a batch of 3D point coordinates `points_3d_local` of shape
(minibatch, ..., dim) in the local coordinate frame of the volume
to the world coordinates.
Args:
**points_3d_local**: A tensor of shape `(minibatch, ..., 3)`
containing the 3D coordinates of a set of points that will
be converted from the local volume coordinates (ranging
within [-1, 1]) to the world coordinates
defined by the `self.center` and `self.voxel_size` parameters.
Returns:
**points_3d_world**: `points_3d_local` converted to the world
coordinates of the volume of shape `(minibatch, ..., 3)`.
"""
pts_shape = points_3d_local.shape
return (
self.get_local_to_world_coords_transform()
.transform_points(points_3d_local.view(pts_shape[0], -1, 3))
.view(pts_shape)
)
def get_grid_sizes(self) -> torch.LongTensor:
"""
Returns the sizes of individual volumetric grids in the structure.
Returns:
**grid_sizes**: Tensor of spatial sizes of each of the volumes
of size (batchsize, 3), where i-th row holds (D_i, H_i, W_i).
"""
return self._grid_sizes
def _set_local_to_world_transform(
self,
voxel_size: _VoxelSize = 1.0,
volume_translation: _Translation = (0.0, 0.0, 0.0),
):
"""
Sets the internal representation of the transformation between the
world and local volume coordinates by specifying
`voxel_size` and `volume_translation`
Args:
**voxel_size**: Denotes the size of input voxels. Has to be one of:
a) A scalar (square voxels)
b) 3-tuple or a 3-list of scalars
c) a Tensor of shape (3,)
d) a Tensor of shape (minibatch, 3)
e) a Tensor of shape (1,) (square voxels)
**volume_translation**: Denotes the 3D translation of the center
of the volume in world units. Has to be one of:
a) 3-tuple or a 3-list of scalars
b) a Tensor of shape (3,)
c) a Tensor of shape (minibatch, 3)
d) a Tensor of shape (1,) (square voxels)
"""
# handle voxel size and center
# here we force the tensors to lie on self.device
voxel_size = self._handle_voxel_size(voxel_size, len(self))
volume_translation = self._handle_volume_translation(
volume_translation, len(self)
)
self._local_to_world_transform = self._voxel_size_translation_to_transform(
voxel_size, volume_translation, len(self)
)
def _copy_transform_and_sizes(
self,
other: "VolumeLocator",
device: Optional[torch.device] = None,
index: Optional[
Union[int, List[int], Tuple[int], slice, torch.Tensor]
] = _ALL_CONTENT,
) -> None:
"""
Copies the local to world transform and grid sizes to other VolumeLocator object
and moves it to specified device. Operates in place on other.
Args:
other: VolumeLocator object to which to copy
device: torch.device on which to put the result, defatults to self.device
index: Specifies which parts to copy.
Can be an int, slice, list of ints or a boolean or a long tensor.
Defaults to all items (`:`).
"""
device = device if device is not None else self.device
other._grid_sizes = self._grid_sizes[index].to(device)
other._local_to_world_transform = self.get_local_to_world_coords_transform()[
# pyre-fixme[6]: For 1st param expected `Union[List[int], int, slice,
# BoolTensor, LongTensor]` but got `Union[None, List[int], Tuple[int],
# int, slice, Tensor]`.
index
].to(device)
def _handle_voxel_size(
self, voxel_size: _VoxelSize, batch_size: int
) -> torch.Tensor:
"""
Handle the `voxel_size` argument to the `VolumeLocator` constructor.
"""
err_msg = (
"voxel_size has to be either a 3-tuple of scalars, or a scalar, or"
" a torch.Tensor of shape (3,) or (1,) or (minibatch, 3) or (minibatch, 1)."
)
if isinstance(voxel_size, (float, int)):
# convert a scalar to a 3-element tensor
voxel_size = torch.full(
(1, 3), voxel_size, device=self.device, dtype=torch.float32
)
elif isinstance(voxel_size, torch.Tensor):
if voxel_size.numel() == 1:
# convert a single-element tensor to a 3-element one
voxel_size = voxel_size.view(-1).repeat(3)
elif len(voxel_size.shape) == 2 and (
voxel_size.shape[0] == batch_size and voxel_size.shape[1] == 1
):
voxel_size = voxel_size.repeat(1, 3)
return self._convert_volume_property_to_tensor(voxel_size, batch_size, err_msg)
def _handle_volume_translation(
self, translation: _Translation, batch_size: int
) -> torch.Tensor:
"""
Handle the `volume_translation` argument to the `VolumeLocator` constructor.
"""
err_msg = (
"`volume_translation` has to be either a 3-tuple of scalars, or"
" a Tensor of shape (1,3) or (minibatch, 3) or (3,)`."
)
return self._convert_volume_property_to_tensor(translation, batch_size, err_msg)
def __len__(self) -> int:
return self._batch_size
def _convert_volume_property_to_tensor(
self, x: _Vector, batch_size: int, err_msg: str
) -> torch.Tensor:
"""
Handle the `volume_translation` or `voxel_size` argument to
the VolumeLocator constructor.
Return a tensor of shape (N, 3) where N is the batch_size.
"""
if isinstance(x, (list, tuple)):
if len(x) != 3:
raise ValueError(err_msg)
x = torch.tensor(x, device=self.device, dtype=torch.float32)[None]
x = x.repeat((batch_size, 1))
elif isinstance(x, torch.Tensor):
ok = (
(x.shape[0] == 1 and x.shape[1] == 3)
or (x.shape[0] == 3 and len(x.shape) == 1)
or (x.shape[0] == batch_size and x.shape[1] == 3)
)
if not ok:
raise ValueError(err_msg)
if x.device != self.device:
x = x.to(self.device)
if x.shape[0] == 3 and len(x.shape) == 1:
x = x[None]
if x.shape[0] == 1:
x = x.repeat((batch_size, 1))
else:
raise ValueError(err_msg)
return x
def to(self, device: Device, copy: bool = False) -> "VolumeLocator":
"""
Match the functionality of torch.Tensor.to()
If copy = True or the self Tensor is on a different device, the
returned tensor is a copy of self with the desired torch.device.
If copy = False and the self Tensor already has the correct torch.device,
then self is returned.
Args:
device: Device (as str or torch.device) for the new tensor.
copy: Boolean indicator whether or not to clone self. Default False.
Returns:
VolumeLocator object.
"""
device_ = make_device(device)
if not copy and self.device == device_:
return self
other = self.clone()
if self.device == device_:
return other
other.device = device_
other._grid_sizes = self._grid_sizes.to(device_)
other._local_to_world_transform = self.get_local_to_world_coords_transform().to(
device
)
return other
def clone(self) -> "VolumeLocator":
"""
Deep copy of VoluVolumeLocatormes object. All internal tensors are cloned
individually.
Returns:
new VolumeLocator object.
"""
return copy.deepcopy(self)
def cpu(self) -> "VolumeLocator":
return self.to("cpu")
def cuda(self) -> "VolumeLocator":
return self.to("cuda")