pytorch3d.implicitron.models.implicit_function.voxel_grid

voxel_grid

This file contains classes that implement voxel grids, both in full resolution and in factorized form. Two factorized forms are implemented: tensor rank decomposition, also known as CANDECOMP/PARAFAC (here CP), and Vector-Matrix (here VM) factorization from the TensoRF (https://arxiv.org/abs/2203.09517) paper.

In addition, the module VoxelGridModule implements a trainable instance of one of these classes.

class pytorch3d.implicitron.models.implicit_function.voxel_grid.VoxelGridValuesBase

Bases: object

class pytorch3d.implicitron.models.implicit_function.voxel_grid.VoxelGridBase(*args, **kwargs)

Bases: ReplaceableBase, Module

Base class for all the voxel grid variants, with added trilinear interpolation between voxels. For example, querying the point (0.333, 1, 3) returns 2/3*voxel[0, 1, 3] + 1/3*voxel[1, 1, 3].

Internally, voxel grids are indexed by (features, x, y, z). If the queried point is not inside the voxel grid, the returned vector is determined by padding.
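The interpolation example above can be reproduced with plain torch.nn.functional.grid_sample. This is a minimal sketch, not the library's implementation: the toy grid, the conversion from voxel indices to [-1, 1] coordinates, and the coordinate flip are assumptions made for this illustration.

```python
import torch
import torch.nn.functional as F

# Toy grid with 1 grid and 1 feature, laid out as (n_grids, features, x, y, z).
voxel = torch.arange(2 * 3 * 4, dtype=torch.float32).reshape(1, 1, 2, 3, 4)
sizes = torch.tensor(voxel.shape[2:], dtype=torch.float32)  # (x, y, z) sizes

# Query point in voxel-index units, one third of the way from x=0 to x=1.
index_point = torch.tensor([1.0 / 3.0, 1.0, 3.0])

# With align_corners=True, index 0 maps to -1 and index size-1 maps to +1.
local_point = 2.0 * index_point / (sizes - 1.0) - 1.0

# grid_sample uses its first coordinate for the last tensor axis, so
# (x, y, z) is flipped to (z, y, x) for a tensor stored as (..., x, y, z).
sample_grid = local_point.flip(-1).reshape(1, 1, 1, 1, 3)

# mode="bilinear" on a 5-D input performs trilinear interpolation.
sampled = F.grid_sample(
    voxel, sample_grid, mode="bilinear", padding_mode="zeros", align_corners=True
)

# The weighted sum quoted in the description above.
expected = 2.0 / 3.0 * voxel[0, 0, 0, 1, 3] + 1.0 / 3.0 * voxel[0, 0, 1, 1, 3]
```

Here sampled and expected agree up to floating point error, because the query lies exactly on the y and z voxel positions and only the x axis is interpolated.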

Members:
align_corners: parameter used in torch.nn.functional.grid_sample. For details see https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html. Default is True.

padding: padding mode for values outside the grid: ‘zeros’ | ‘border’ | ‘reflection’. Default is ‘zeros’.

mode: interpolation mode used to calculate output values: ‘bilinear’ | ‘nearest’ | ‘bicubic’ | ‘trilinear’. Default: ‘bilinear’. Note: mode=‘bicubic’ supports only FullResolutionVoxelGrid. When mode=‘bilinear’ and the input is 5-D, the interpolation mode used internally will actually be trilinear.

n_features: number of dimensions of the base feature vector. Determines how many features the grid returns.

resolution_changes: a dictionary where keys are change epochs and values are 3-tuples containing the x, y, z grid sizes, one per axis, for each of those epochs.

align_corners: bool = True
padding: str = 'zeros'
mode: str = 'bilinear'
n_features: int = 1
resolution_changes: Dict[int, Any] (default supplied by a default_factory)

evaluate_world(points: Tensor, grid_values: VoxelGridValuesBase, locator: VolumeLocator) → Tensor

Evaluates the voxel grid at points in the world coordinate frame. The interpolation type is determined by the mode member.

Parameters:
  • points (torch.Tensor) – tensor of query points of shape (n_grids, …, 3)

  • grid_values – an object of type Class.values_type whose members are tensors with shapes derived from the get_shapes() method

  • locator – a VolumeLocator object

Returns:

torch.Tensor – shape (n_grids, …, n_features)

evaluate_local(points: Tensor, grid_values: VoxelGridValuesBase) → Tensor

Evaluates the voxel grid at points in the local coordinate frame. The interpolation type is determined by the mode member.

Parameters:
  • points (torch.Tensor) – tensor of query points of shape (n_grids, …, 3), in normalized form (coordinates are in [-1, 1])

  • grid_values – an object of type self.values_type whose members are tensors with shapes derived from the get_shapes() method

Returns:

torch.Tensor – shape (n_grids, …, n_features)

get_shapes(epoch: int) → Dict[str, Tuple]

Using parameters from the __init__ method, this method returns the shapes of individual tensors needed to run the evaluate method.

Parameters:

epoch – If the shape varies during training, which training epoch’s shape to return.

Returns:

a dictionary of needed shapes. To use the evaluate_local and evaluate_world methods, replace the shapes in the dictionary with tensors of those shapes and add the first ‘batch’ dimension. If the required shape is (a, b) and you want to have g grids, then the tensor that replaces the shape should have the shape (g, a, b).
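The shape-to-tensor replacement described above might look like the following sketch. The dictionary key and sizes are hypothetical stand-ins for whatever get_shapes() returns for a concrete grid class.

```python
import torch

# Hypothetical result of get_shapes(epoch=0); the key name and sizes are made up.
shapes = {"voxel_grid": (4, 16, 16, 16)}  # (n_features, width, height, depth)
n_grids = 2

# Replace each shape with a tensor of that shape, prefixed by the grid dimension.
tensors = {name: torch.zeros(n_grids, *shape) for name, shape in shapes.items()}
```

A dictionary built this way could then be unpacked into the matching values class, for example FullResolutionVoxelGridValues(voxel_grid=...), assuming the keys line up with that class's fields.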

get_resolution(epoch: int) → List[int]

Returns the resolution which the grid should have at a specific epoch.

Parameters:

epoch – which epoch to use in the resolution calculation

Returns:

resolution at specific epoch
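As an illustration of how a resolution_changes dictionary could be read, here is a small sketch. The lookup rule (use the entry with the latest change epoch not after the queried epoch) is an assumption of this example, and resolution_at is a hypothetical helper, not part of the library.

```python
from typing import Dict, Tuple

Resolution = Tuple[int, int, int]


def resolution_at(resolution_changes: Dict[int, Resolution], epoch: int) -> Resolution:
    """Return the sizes from the latest change epoch that is <= epoch (assumed rule)."""
    past_epochs = [e for e in resolution_changes if e <= epoch]
    if not past_epochs:
        raise ValueError(f"no resolution defined at or before epoch {epoch}")
    return resolution_changes[max(past_epochs)]


# Hypothetical schedule: coarse at the start, refined at epochs 10 and 30.
schedule = {0: (64, 64, 64), 10: (128, 128, 128), 30: (256, 256, 256)}
mid_training = resolution_at(schedule, 15)
```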

static get_output_dim(args: DictConfig) → int

Given all the arguments of the grid’s __init__, returns output’s last dimension length.

In particular, if self.evaluate_world or self.evaluate_local are called with points of shape (n_grids, n_points, 3), their output will be of shape (n_grids, n_points, grid.get_output_dim()).

Parameters:

args – DictConfig which would be used to initialize the object

Returns:

output’s last dimension length

change_resolution(grid_values: VoxelGridValuesBase, *, epoch: int | None = None, grid_values_with_wanted_resolution: VoxelGridValuesBase | None = None, mode: str = 'linear', align_corners: bool = True, antialias: bool = False) → Tuple[VoxelGridValuesBase, bool]

Changes resolution of tensors in grid_values to match the grid_values_with_wanted_resolution or resolution on wanted epoch.

Parameters:
  • grid_values – instance of self.values_type which contains the voxel grid which will be interpolated to create the new grid

  • epoch – current training epoch, used to see if the grid needs regridding and to get the resolution of the new grid_values using self.resolution_changes

  • grid_values_with_wanted_resolution – VoxelGridValuesBase to whose resolution grid_values will be interpolated

  • align_corners – as for torch.nn.functional.interpolate

  • mode – as for torch.nn.functional.interpolate ‘nearest’ | ‘bicubic’ | ‘linear’ | ‘area’ | ‘nearest-exact’. Default: ‘linear’

  • antialias – as for torch.nn.functional.interpolate. Using the anti-alias option together with align_corners=False and mode=‘bicubic’, the interpolation result would match the Pillow result for a downsampling operation. Supported mode: ‘bicubic’

Returns:

tuple of
  • new voxel grid_values of desired resolution, of type self.values_type

  • True if regridding has happened.
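The regridding idea can be sketched on a dense tensor with torch.nn.functional.interpolate. This is only an illustration under stated assumptions: regrid is a hypothetical helper, it works on a raw (n_grids, n_features, x, y, z) tensor rather than a values_type object, and it uses mode="trilinear" directly; how the library maps its own mode='linear' argument is not shown here.

```python
import torch
import torch.nn.functional as F


def regrid(values: torch.Tensor, new_size):
    """Resample a dense (n_grids, n_features, x, y, z) tensor to new_size.

    Returns the tensor and a flag that is True only if resampling took place.
    """
    if tuple(values.shape[2:]) == tuple(new_size):
        return values, False
    resized = F.interpolate(
        values, size=tuple(new_size), mode="trilinear", align_corners=True
    )
    return resized, True


coarse = torch.rand(1, 4, 8, 8, 8)
fine, changed = regrid(coarse, (16, 16, 16))
same, same_changed = regrid(coarse, (8, 8, 8))
```

With align_corners=True the corner voxels of the coarse grid keep their values in the refined grid, which is why that setting is a natural fit for growing a grid during training.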

get_resolution_change_epochs() → Tuple[int, ...]

Returns epochs at which this grid should change resolution.

get_align_corners() → bool

Returns True if voxel grid uses align_corners=True

crop_world(min_point_world: Tensor, max_point_world: Tensor, grid_values: VoxelGridValuesBase, volume_locator: VolumeLocator) → VoxelGridValuesBase

Crops the voxel grid based on minimum and maximum occupied point in world coordinates. After cropping all 8 corner points are preserved in the voxel grid. This is achieved by preserving all the voxels needed to calculate the point.

       +--------B
      /        /|
     /        / |
    +--------+  |  <==== Bounding box represented by points A and B:
    |        |  |        - B has x, y and z coordinates bigger or equal
    |        |  +           to all other points of the object
    |        | /         - A has x, y and z coordinates smaller or equal
    |        |/             to all other points of the object
    A--------+

Parameters:
  • min_point_world – torch.Tensor of shape (3,). Has x, y and z coordinates smaller or equal to all other occupied points. Point A from the picture above.

  • max_point_world – torch.Tensor of shape (3,). Has x, y and z coordinates bigger or equal to all other occupied points. Point B from the picture above.

  • grid_values – instance of self.values_type which contains the voxel grid which will be cropped to create the new grid

  • volume_locator – VolumeLocator object used to convert world to local coordinates

Returns:

instance of self.values_type which has volume cropped to desired size.

crop_local(min_point_local: Tensor, max_point_local: Tensor, grid_values: VoxelGridValuesBase) → VoxelGridValuesBase

Crops the voxel grid based on minimum and maximum occupied point in local coordinates. After cropping both min and max point are preserved in the voxel grid. This is achieved by preserving all the voxels needed to calculate the point.

       +--------B
      /        /|
     /        / |
    +--------+  |  <==== Bounding box represented by points A and B:
    |        |  |        - B has x, y and z coordinates bigger or equal
    |        |  +           to all other points of the object
    |        | /         - A has x, y and z coordinates smaller or equal
    |        |/             to all other points of the object
    A--------+

Parameters:
  • min_point_local – torch.Tensor of shape (3,). Has x, y and z coordinates smaller or equal to all other occupied points. Point A from the picture above. All elements in [-1, 1].

  • max_point_local – torch.Tensor of shape (3,). Has x, y and z coordinates bigger or equal to all other occupied points. Point B from the picture above. All elements in [-1, 1].

  • grid_values – instance of self.values_type which contains the voxel grid which will be cropped to create the new grid

Returns:

instance of self.values_type which has volume cropped to desired size.
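One way to read "preserving all the voxels needed to calculate the point" is to round the crop bounds outward to whole voxel indices. The sketch below does this for a single axis. It assumes the align_corners=True convention, and crop_index_range is a hypothetical helper rather than the library's code.

```python
import math
from typing import Tuple


def crop_index_range(min_local: float, max_local: float, resolution: int) -> Tuple[int, int]:
    """Return inclusive (start, stop) voxel indices along one axis.

    Assumes align_corners=True, so -1 maps to index 0 and +1 to index resolution - 1.
    """

    def to_index(coordinate: float) -> float:
        return (coordinate + 1.0) / 2.0 * (resolution - 1)

    # Round outward so both neighbours of each bound stay inside the crop.
    start = max(0, math.floor(to_index(min_local)))
    stop = min(resolution - 1, math.ceil(to_index(max_local)))
    return start, stop


x_range = crop_index_range(-0.4, 0.3, resolution=9)
```

Applying the same computation to y and z gives the three index ranges that would be sliced out of a dense (features, x, y, z) grid.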

class pytorch3d.implicitron.models.implicit_function.voxel_grid.FullResolutionVoxelGridValues(voxel_grid: torch.Tensor)

Bases: VoxelGridValuesBase

voxel_grid: Tensor

class pytorch3d.implicitron.models.implicit_function.voxel_grid.FullResolutionVoxelGrid(*args, **kwargs)

Bases: VoxelGridBase

Full resolution voxel grid, equivalent to a 4D tensor of shape (features, width, height, depth), with linear interpolation between voxels.

values_type

alias of FullResolutionVoxelGridValues

evaluate_local(points: Tensor, grid_values: FullResolutionVoxelGridValues) → Tensor

Evaluates the voxel grid at points in the local coordinate frame. The interpolation type is determined by the mode member.

Parameters:
  • points (torch.Tensor) – tensor of query points of shape (…, 3), in normalized form (coordinates are in [-1, 1])

  • grid_values – an object of type values_type whose members are tensors with shapes derived from the get_shapes() method

Returns:

torch.Tensor – shape (n_grids, …, n_features)

get_shapes(epoch: int) → Dict[str, Tuple]

crop_local(min_point_local: Tensor, max_point_local: Tensor, grid_values: FullResolutionVoxelGridValues) → FullResolutionVoxelGridValues

class pytorch3d.implicitron.models.implicit_function.voxel_grid.CPFactorizedVoxelGridValues(vector_components_x: torch.Tensor, vector_components_y: torch.Tensor, vector_components_z: torch.Tensor, basis_matrix: torch.Tensor | None = None)

Bases: VoxelGridValuesBase

vector_components_x: Tensor
vector_components_y: Tensor
vector_components_z: Tensor
basis_matrix: Tensor | None = None

class pytorch3d.implicitron.models.implicit_function.voxel_grid.CPFactorizedVoxelGrid(*args, **kwargs)

Bases: VoxelGridBase

Canonical Polyadic (CP/CANDECOMP/PARAFAC) factorization factorizes the 3d grid into three sets of vectors (x, y, z). For n_components=n, the 3d grid is a sum of n terms, each the outer product (call it ⊗) of one vector of each type (x, y, z):

3d_grid = x0 ⊗ y0 ⊗ z0 + x1 ⊗ y1 ⊗ z1 + … + x(n-1) ⊗ y(n-1) ⊗ z(n-1)

These tensors are passed in an object of type CPFactorizedVoxelGridValues (here obj) as obj.vector_components_x, obj.vector_components_y, obj.vector_components_z. Their shapes are (n_components, r), where r is the relevant resolution.

Each element of this sum has an extra dimension, which gets matrix-multiplied by an appropriate “basis matrix” of shape (n_grids, n_components, n_features). This multiplication brings us to the desired “n_features” dimensionality. If basis_matrix=False, the elements of different components are summed together to create an (n_grids, n_components, 1) tensor. With some notation abuse, ignoring the interpolation operation, simplifying and denoting n_features as F, n_components as C and n_grids as G:

3d_grid = (x ⊗ y ⊗ z) @ basis # GWHDC x GCF -> GWHDF

The basis feature vectors are passed as obj.basis_matrix.
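The shape bookkeeping in the formula above can be checked with torch.einsum. This is a sketch under assumptions: it builds the dense grid explicitly and ignores interpolation, as the description does, and it gives the vector components a leading n_grids dimension so that G appears in every operand. All sizes are arbitrary.

```python
import torch

G, C, n_feat = 1, 3, 2   # n_grids, n_components, n_features
W, H, D = 4, 5, 6        # grid resolution along x, y, z

x = torch.rand(G, C, W)
y = torch.rand(G, C, H)
z = torch.rand(G, C, D)
basis = torch.rand(G, C, n_feat)

# One outer product per component, stacked on the last axis: (G, W, H, D, C).
components = torch.einsum("gcw,gch,gcd->gwhdc", x, y, z)

# With a basis matrix: GWHDC x GCF -> GWHDF.
grid_with_basis = torch.einsum("gwhdc,gcf->gwhdf", components, basis)

# Without a basis matrix: sum over the component axis, leaving one feature.
grid_summed = components.sum(dim=-1, keepdim=True)
```

Storing three sets of C vectors instead of a dense W x H x D grid is what makes the factorized form compact; the dense tensor here exists only to make the shapes visible.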

Members:

n_components: number of vector triplets; a higher number gives a better approximation.

basis_matrix: how to transform components. If basis_matrix=True, the result matrix of shape (n_grids, n_points_total, n_components) is batch matrix multiplied by the basis_matrix of shape (n_grids, n_components, n_features). If basis_matrix=False, the result tensor of shape (n_grids, n_points_total, n_components) is summed along the rows to get (n_grids, n_points_total, 1), which is then viewed to return to the starting shape (n_grids, …, 1).

values_type

alias of CPFactorizedVoxelGridValues

n_components: int = 24
basis_matrix: bool = True
evaluate_local(points: Tensor, grid_values: CPFactorizedVoxelGridValues) → Tensor

get_shapes(epoch: int) → Dict[str, Tuple[int, int]]

crop_local(min_point_local: Tensor, max_point_local: Tensor, grid_values: CPFactorizedVoxelGridValues) → CPFactorizedVoxelGridValues

class pytorch3d.implicitron.models.implicit_function.voxel_grid.VMFactorizedVoxelGridValues(vector_components_x: torch.Tensor, vector_components_y: torch.Tensor, vector_components_z: torch.Tensor, matrix_components_xy: torch.Tensor, matrix_components_yz: torch.Tensor, matrix_components_xz: torch.Tensor, basis_matrix: torch.Tensor | None = None)

Bases: VoxelGridValuesBase

vector_components_x: Tensor
vector_components_y: Tensor
vector_components_z: Tensor
matrix_components_xy: Tensor
matrix_components_yz: Tensor
matrix_components_xz: Tensor
basis_matrix: Tensor | None = None

class pytorch3d.implicitron.models.implicit_function.voxel_grid.VMFactorizedVoxelGrid(*args, **kwargs)

Bases: VoxelGridBase

Implementation of Vector-Matrix Factorization of a tensor from https://arxiv.org/abs/2203.09517.

Vector-Matrix Factorization factorizes the 3d grid into three matrices (xy, xz, yz) and three vectors (x, y, z). For n_components=1, the 3d grid is a sum of the outer products (call it ⊗) of each matrix with its complementary vector:

3d_grid = xy ⊗ z + xz ⊗ y + yz ⊗ x.

These tensors are passed in a VMFactorizedVoxelGridValues object (here obj) as obj.matrix_components_xy, obj.matrix_components_yz, obj.vector_components_y, etc.

Their shapes are (n_grids, n_components, r0, r1) for matrix_components and (n_grids, n_components, r2) for vector_components. Each of r0, r1 and r2 corresponds to one resolution in (width, height and depth).

Each element of this sum has an extra dimension, which gets matrix-multiplied by an appropriate “basis matrix” of shape (n_grids, n_components, n_features). This multiplication brings us to the desired “n_features” dimensionality. If basis_matrix=False, the elements of different components are summed together to create an (n_grids, n_components, 1) tensor. With some notation abuse, ignoring the interpolation operation, simplifying and denoting n_features as F, n_components as C (which can differ for each dimension) and n_grids as G:

3d_grid = concat((xy ⊗ z), (xz ⊗ y).permute(0, 2, 1), (yz ⊗ x).permute(2, 0, 1)) @ basis_matrix # GWHDC x GCF -> GWHDF
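The same shape check can be written for the Vector-Matrix form. The sketch below is an illustration only: it uses the same number of components for every plane, builds the dense grid explicitly, ignores interpolation, and lets einsum index strings play the role of the permute calls in the formula. All sizes are arbitrary.

```python
import torch

G, n_feat = 1, 2         # n_grids, n_features
C = 2                    # components per matrix-vector pair (assumed equal here)
W, H, D = 4, 5, 6        # grid resolution along x, y, z

xy, z = torch.rand(G, C, W, H), torch.rand(G, C, D)
xz, y = torch.rand(G, C, W, D), torch.rand(G, C, H)
yz, x = torch.rand(G, C, H, D), torch.rand(G, C, W)
basis_matrix = torch.rand(G, 3 * C, n_feat)

# Each matrix is combined with its complementary vector; the output indices
# put every term into the common (G, W, H, D, C) layout before concatenation.
components = torch.cat(
    [
        torch.einsum("gcwh,gcd->gwhdc", xy, z),
        torch.einsum("gcwd,gch->gwhdc", xz, y),
        torch.einsum("gchd,gcw->gwhdc", yz, x),
    ],
    dim=-1,
)

# GWHDC x GCF -> GWHDF, where C now counts the components of all three pairs.
grid = torch.einsum("gwhdc,gcf->gwhdf", components, basis_matrix)
```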

Members:
n_components: total number of matrix-vector pairs; this must be divisible by 3. Set this if you want to have equal representational power in all 3 directions. You must specify either n_components or distribution_of_components; you cannot specify both.

distribution_of_components: if you do not want equal representational power in all 3 directions, specify a tuple with the number of matrix-vector pairs for each coordinate, of the form (n_xy_planes, n_yz_planes, n_xz_planes). You must specify either n_components or distribution_of_components; you cannot specify both.

basis_matrix: how to transform components. If basis_matrix=True, the result matrix of shape (n_grids, n_points_total, n_components) is batch matrix multiplied by the basis_matrix of shape (n_grids, n_components, n_features). If basis_matrix=False, the result tensor of shape (n_grids, n_points_total, n_components) is summed along the rows to get (n_grids, n_points_total, 1), which is then viewed to return to the starting shape (n_grids, …, 1).

values_type

alias of VMFactorizedVoxelGridValues

n_components: int | None = None
distribution_of_components: Tuple[int, int, int] | None = None
basis_matrix: bool = True
evaluate_local(points: Tensor, grid_values: VMFactorizedVoxelGridValues) → Tensor

get_shapes(epoch: int) → Dict[str, Tuple]

crop_local(min_point_local: Tensor, max_point_local: Tensor, grid_values: VMFactorizedVoxelGridValues) → VMFactorizedVoxelGridValues

class pytorch3d.implicitron.models.implicit_function.voxel_grid.VoxelGridModule(*args, **kwargs)

Bases: Configurable, Module

A wrapper torch.nn.Module for the VoxelGrid classes, which contains the parameters that are needed to train the VoxelGrid classes. It can hold the parameters of the voxel grid either as PyTorch parameters or as registered buffers.

Members:
voxel_grid_class_type: The name of the class to use for voxel_grid, which must be available in the registry. Default FullResolutionVoxelGrid.

voxel_grid: An instance of VoxelGridBase. This is the object which this class wraps.

extents: 3-tuple of the form (width, height, depth), denotes the size of the grid in world units.

translation: 3-tuple of float. The center of the volume in world units as (x, y, z).

init_std: Parameters are initialized using the gaussian distribution with mean=init_mean and std=init_std. Default 0.1.

init_mean: Parameters are initialized using the gaussian distribution with mean=init_mean and std=init_std. Default 0.

hold_voxel_grid_as_parameters: if True, components of the underlying voxel grids will be saved as parameters and therefore be trainable. Default True.

param_groups: dictionary where keys are names of individual parameters or module members and values are the parameter group into which the parameter/member will be sorted. The “self” key is used to denote the parameter group at the module level. None of the possible keys, including the “self” key, has to be defined. By default all parameters are put into the “default” parameter group and have the learning rate defined in the optimizer. This can be overridden at the:

  • module level, with the “self” key: all the parameters and child module’s parameters will be put into that parameter group

  • member level, which is the same as if the param_groups in that member had key=“self” and value equal to that parameter group. This is useful if members do not have param_groups, for example torch.nn.Linear.

  • parameter level: the parameter with the same name as the key will be put into that parameter group.
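A param_groups dictionary following the description above might look like this sketch. The group names, the parameter names and the group_for helper are all hypothetical, and the precedence it applies (an exact parameter name wins over the “self” key, which wins over “default”) is an assumption of this example rather than documented behaviour.

```python
from typing import Dict

# Hypothetical configuration; the group names "grid_fast" and "basis_slow" are made up.
param_groups: Dict[str, str] = {
    "self": "grid_fast",           # module level: applies to every parameter below
    "basis_matrix": "basis_slow",  # parameter level: a more specific choice
}


def group_for(parameter_name: str, groups: Dict[str, str]) -> str:
    """Pick a group for one parameter: exact name first, then "self", then "default"."""
    if parameter_name in groups:
        return groups[parameter_name]
    return groups.get("self", "default")


basis_group = group_for("basis_matrix", param_groups)
other_group = group_for("vector_components_x", param_groups)
```

An optimizer setup could then assign a separate learning rate to each group name.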

voxel_grid_class_type: str = 'FullResolutionVoxelGrid'
voxel_grid: VoxelGridBase
extents: Tuple[float, float, float] = (2.0, 2.0, 2.0)
translation: Tuple[float, float, float] = (0.0, 0.0, 0.0)
init_std: float = 0.1
init_mean: float = 0
hold_voxel_grid_as_parameters: bool = True
param_groups: Dict[str, str] (default supplied by a default_factory)

forward(points: Tensor) → Tensor

Evaluates points in the world coordinate frame on the voxel_grid.

Parameters:

points (torch.Tensor) – tensor of query points of shape (…, 3)

Returns:

torch.Tensor of shape (…, n_features)

set_voxel_grid_parameters(params: VoxelGridValuesBase) → None

Sets the parameters of the underlying voxel grid.

Parameters:

params – parameters of type self.voxel_grid.values_type which will replace current parameters

static get_output_dim(args: DictConfig) → int

Utility to help predict the shape of the output of forward.

Parameters:

args – DictConfig which would be used to initialize the object

Returns:

int – the length of the last dimension of the output tensor

subscribe_to_epochs() → Tuple[Tuple[int, ...], Callable[[int], bool]]

Method which expresses interest in subscribing to optimization epoch updates.

Returns:

tuple of epochs on which to call a callable, and the callable to be called on a particular epoch. The callable returns True if a parameter change has happened, else False, and it must be supplied with one argument, epoch.
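A training loop could consume this return value roughly as follows. FakeGrid is a stand-in written for this sketch so that it runs without the library; only the shape of the return value (a tuple of epochs plus a callable taking the epoch and returning a bool) is taken from the description above.

```python
from typing import Callable, List, Tuple


class FakeGrid:
    """Stand-in exposing the documented return contract; not the real module."""

    def __init__(self) -> None:
        self.resolution = 64

    def subscribe_to_epochs(self) -> Tuple[Tuple[int, ...], Callable[[int], bool]]:
        return (10, 30), self._on_epoch

    def _on_epoch(self, epoch: int) -> bool:
        self.resolution *= 2  # pretend the grid was regridded
        return True


grid = FakeGrid()
epochs, callback = grid.subscribe_to_epochs()

parameter_changes: List[int] = []
for epoch in range(40):
    if epoch in epochs and callback(epoch):
        # Parameters changed; a real loop would likely refresh its optimizer here.
        parameter_changes.append(epoch)
```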

get_device() → device

Returns torch.device on which module parameters are located

crop_self(min_point: Tensor, max_point: Tensor) → None

Crops self to only represent points between min_point and max_point (inclusive).

Parameters:
  • min_point – torch.Tensor of shape (3,). Has x, y and z coordinates smaller or equal to all other occupied points.

  • max_point – torch.Tensor of shape (3,). Has x, y and z coordinates bigger or equal to all other occupied points.

Returns:

nothing

get_grid_points(epoch: int) → Tensor

Returns a grid of points that represent centers of voxels of the underlying voxel grid in world coordinates at a specific epoch.

Parameters:

epoch – underlying voxel grids change resolution depending on the epoch, this argument is used to determine the resolution of the voxel grid at that epoch.

Returns:

tensor of shape [xresolution, yresolution, zresolution, 3], where xresolution, yresolution, zresolution are resolutions of the underlying voxel grid
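A tensor of this shape can be produced with torch.meshgrid, as in the sketch below. The function grid_points_world is hypothetical, and the placement of the points (evenly spaced across the full extent on each axis, as in the align_corners=True convention, then scaled by extents and shifted by translation) is an assumption of this example.

```python
import torch


def grid_points_world(resolution, extents, translation):
    """Return voxel positions of shape (rx, ry, rz, 3) in world units.

    Assumes samples sit at evenly spaced positions spanning the full extent
    on each axis (the align_corners=True convention).
    """
    axes = [torch.linspace(-1.0, 1.0, r) for r in resolution]
    local = torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1)
    half_extents = torch.tensor(extents) / 2.0
    return local * half_extents + torch.tensor(translation)


points = grid_points_world((4, 5, 6), (2.0, 2.0, 2.0), (0.0, 0.0, 0.0))
```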