class pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor.ResNetFeatureExtractor(*args, **kwargs)

Bases: FeatureExtractorBase

Implements an image feature extractor. Depending on the settings, it can extract:

  • deep features: Features computed by a CNN ResNet backbone from torchvision (with or without pretrained weights).

  • masks: Segmentation masks.

  • images: Raw input RGB images.


name: Name of the ResNet backbone (from torchvision).

pretrained: If true, will load the pretrained weights.

stages: List of stages from which to extract features. Features from each stage are returned as key-value pairs in the forward function.

normalize_image: If set, will normalize the RGB values of the image based on the ResNet mean/std.

image_rescale: If not 1.0, this rescale factor will be used to resize the image.

first_max_pool: If set, a max pool layer is added after the first convolutional layer.

proj_dim: The number of output channels for the convolutional layers.

l2_norm: If set, L2 normalization is applied to the extracted features.

add_masks: If set, the masks will be saved in the output dictionary.

add_images: If set, the images will be saved in the output dictionary.

global_average_pool: If set, a global average pooling step is performed.

feature_rescale: If not 1.0, this rescale factor will be used to rescale the output features.
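
The arithmetic behind normalize_image, l2_norm and global_average_pool can be illustrated on tiny inputs. The following is a minimal pure-Python sketch, not the library's implementation: the helper names are hypothetical, and the mean/std constants are the ImageNet statistics commonly used with torchvision ResNets, which is assumed (not confirmed by this page) to be what "ResNet mean/std" refers to.

```python
import math

# ImageNet channel statistics commonly paired with torchvision ResNets.
# Assumption: these are the values meant by "ResNet mean/std".
RESNET_MEAN = (0.485, 0.456, 0.406)
RESNET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Standardize one RGB pixel (values in [0, 1]) channel by channel."""
    return tuple((c - m) / s for c, m, s in zip(rgb, RESNET_MEAN, RESNET_STD))


def l2_normalize(vec, eps=1e-12):
    """Scale a feature vector to unit Euclidean length."""
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / max(norm, eps) for v in vec]


def global_average_pool(feature_map):
    """Average an H x W grid of values (one channel) into a single number."""
    values = [v for row in feature_map for v in row]
    return sum(values) / len(values)


# A pixel equal to the mean colour maps to zeros after normalization.
pixel = normalize_pixel((0.485, 0.456, 0.406))
# A vector of length 5 is rescaled to length 1.
unit = l2_normalize([3.0, 4.0])
# A 2 x 2 map collapses to its mean value.
pooled = global_average_pool([[1.0, 2.0], [3.0, 4.0]])
```

In the extractor itself these operations would act on whole tensors (per channel for normalization and pooling, along the channel dimension for the L2 norm); the scalar versions above only show the per-element computation.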

name: str = 'resnet34'
pretrained: bool = True
stages: Tuple[int, ...] = (1, 2, 3, 4)
normalize_image: bool = True
image_rescale: float = 0.16
first_max_pool: bool = True
proj_dim: int = 32
l2_norm: bool = True
add_masks: bool = True
add_images: bool = True
global_average_pool: bool = False
feature_rescale: float = 1.0
get_feat_dims() → int
forward(imgs: Tensor | None, masks: Tensor | None = None, **kwargs) → Dict[Any, Tensor]
Parameters:

  • imgs – A batch of input images of shape (B, 3, H, W).

  • masks – A batch of input masks of shape (B, 3, H, W).

Returns:

A dict {f_i: t_i} keyed by predicted feature names f_i and their corresponding tensors t_i of shape (B, dim_i, H_i, W_i).