pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor
- class pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor.ResNetFeatureExtractor(*args, **kwargs)
Bases: FeatureExtractorBase
Implements an image feature extractor. Depending on the settings, it can extract:
- deep features: deep features computed by a CNN ResNet backbone from torchvision (with or without pretrained weights).
- masks: segmentation masks.
- images: raw input RGB images.
- Settings:
  - name: Name of the ResNet backbone (from torchvision).
  - pretrained: If true, the pretrained weights are loaded.
  - stages: The stages from which to extract features. The features from each stage are returned as key-value pairs by the forward function.
  - normalize_image: If set, the RGB values of the image are normalized using the ResNet mean/std.
  - image_rescale: If not 1.0, this rescale factor is used to resize the image.
  - first_max_pool: If set, a max-pool layer is added after the first convolutional layer.
  - proj_dim: The number of output channels of the convolutional layers.
  - l2_norm: If set, L2 normalization is applied to the extracted features.
  - add_masks: If set, the masks are saved in the output dictionary.
  - add_images: If set, the images are saved in the output dictionary.
  - global_average_pool: If set, a global average pooling step is performed.
  - feature_rescale: If not 1.0, this rescale factor is used to rescale the output features.
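The `l2_norm` and `feature_rescale` settings are simple per-pixel operations on the channel axis. The following plain-Python sketch (no PyTorch required) illustrates that arithmetic on a tiny feature map stored as nested lists. It assumes that L2 normalization runs over the channel dimension, as `torch.nn.functional.normalize(f, dim=1)` would on a `(B, C, H, W)` tensor; the helpers `l2_normalize_channels` and `rescale_features` are hypothetical and not part of PyTorch3D.

```python
import math
from typing import List

# A feature map stored as nested lists indexed [channel][row][col],
# standing in for one (dim, H, W) slice of a (B, dim, H, W) tensor.
FeatureMap = List[List[List[float]]]


def l2_normalize_channels(fmap: FeatureMap, eps: float = 1e-12) -> FeatureMap:
    """Scale the channel vector at every pixel to unit L2 norm (hypothetical helper)."""
    channels = len(fmap)
    height = len(fmap[0])
    width = len(fmap[0][0])
    out = [[[0.0] * width for _ in range(height)] for _ in range(channels)]
    for y in range(height):
        for x in range(width):
            norm = math.sqrt(sum(fmap[c][y][x] ** 2 for c in range(channels)))
            norm = max(norm, eps)  # avoid division by zero for all-zero pixels
            for c in range(channels):
                out[c][y][x] = fmap[c][y][x] / norm
    return out


def rescale_features(fmap: FeatureMap, factor: float) -> FeatureMap:
    """Multiply every feature value by a constant factor (hypothetical helper)."""
    return [[[v * factor for v in row] for row in channel] for channel in fmap]


# Two channels over a 1x2 image: pixel (0, 0) holds (3, 4), pixel (0, 1) holds (0, 2).
fmap = [[[3.0, 0.0]], [[4.0, 2.0]]]
normed = l2_normalize_channels(fmap)
print(normed)  # [[[0.6, 0.0]], [[0.8, 1.0]]]
scaled = rescale_features(normed, 0.5)
```

With `feature_rescale` left at its default of 1.0, the rescaling step is a no-op and the unit-norm features are returned unchanged.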
- name: str = 'resnet34'
- pretrained: bool = True
- stages: Tuple[int, ...] = (1, 2, 3, 4)
- normalize_image: bool = True
- image_rescale: float = 0.16
- first_max_pool: bool = True
- proj_dim: int = 32
- l2_norm: bool = True
- add_masks: bool = True
- add_images: bool = True
- global_average_pool: bool = False
- feature_rescale: float = 1.0
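With the defaults above, the spatial size (H_i, W_i) of each stage's features can be predicted from the input resolution. The sketch below assumes the standard torchvision ResNet layout (a 7x7 stride-2 stem convolution, an optional 3x3 stride-2 max pool controlled by `first_max_pool`, and stride-2 downsampling at the start of stages 2 to 4) and assumes that `image_rescale` resizes with floor rounding, as `torch.nn.functional.interpolate` does with `scale_factor`. The functions `conv_out` and `stage_spatial_sizes` are hypothetical.

```python
import math
from typing import Dict, Tuple


def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a convolution or pooling layer along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def stage_spatial_sizes(
    height: int,
    width: int,
    image_rescale: float = 0.16,
    first_max_pool: bool = True,
    stages: Tuple[int, ...] = (1, 2, 3, 4),
) -> Dict[int, Tuple[int, int]]:
    """Predict (H_i, W_i) for each requested stage (assumes torchvision ResNet strides)."""
    # Resizing by a scale factor floors the scaled size.
    h = int(math.floor(height * image_rescale))
    w = int(math.floor(width * image_rescale))
    # Stem: 7x7 convolution with stride 2 and padding 3.
    h, w = conv_out(h, 7, 2, 3), conv_out(w, 7, 2, 3)
    if first_max_pool:
        # Optional 3x3 max pool with stride 2 and padding 1.
        h, w = conv_out(h, 3, 2, 1), conv_out(w, 3, 2, 1)
    sizes: Dict[int, Tuple[int, int]] = {}
    for stage in range(1, 5):
        if stage > 1:
            # Stages 2-4 downsample with a stride-2 3x3 convolution (padding 1).
            h, w = conv_out(h, 3, 2, 1), conv_out(w, 3, 2, 1)
        if stage in stages:
            sizes[stage] = (h, w)
    return sizes


print(stage_spatial_sizes(800, 800))
# {1: (32, 32), 2: (16, 16), 3: (8, 8), 4: (4, 4)}
```

For an 800x800 input, the default `image_rescale` of 0.16 yields a 128x128 image, and the four default stages then produce 32x32, 16x16, 8x8 and 4x4 feature maps. Disabling `first_max_pool` doubles every stage's resolution.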
- forward(imgs: Tensor | None, masks: Tensor | None = None, **kwargs) → Dict[Any, Tensor]
- Parameters:
imgs – A batch of input images of shape (B, 3, H, W).
masks – A batch of input masks of shape (B, 3, H, W).
- Returns:
out_feats – A dict {f_i: t_i} mapping each predicted feature name f_i to its corresponding tensor t_i of shape (B, dim_i, H_i, W_i).
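The returned dictionary can be summarized by the shapes of its entries. The sketch below builds such a `{name: (B, dim_i, H_i, W_i)}` dict from the settings. It assumes each selected stage is projected to `proj_dim` channels, and it takes the shapes of the stored mask and image tensors as inputs rather than deriving them. The key names (`stage_1`, `mask`, `image`) and both helper functions are hypothetical; the actual feature names are chosen by the library.

```python
from typing import Dict, Optional, Tuple

Shape = Tuple[int, int, int, int]


def predicted_output_shapes(
    batch: int,
    proj_dim: int,
    stage_sizes: Dict[int, Tuple[int, int]],
    mask_shape: Optional[Shape] = None,
    image_shape: Optional[Shape] = None,
    add_masks: bool = True,
    add_images: bool = True,
) -> Dict[str, Shape]:
    """Build a {name: (B, dim_i, H_i, W_i)} dict mirroring out_feats (hypothetical)."""
    out: Dict[str, Shape] = {}
    for stage, (h, w) in sorted(stage_sizes.items()):
        # Hypothetical key names; the library chooses its own.
        out[f"stage_{stage}"] = (batch, proj_dim, h, w)
    if add_masks:
        if mask_shape is None:
            raise ValueError("add_masks is set but no mask shape was given")
        out["mask"] = mask_shape
    if add_images:
        if image_shape is None:
            raise ValueError("add_images is set but no image shape was given")
        out["image"] = image_shape
    # Every entry must share the batch dimension B.
    for name, shape in out.items():
        if shape[0] != batch:
            raise ValueError(f"{name} has batch size {shape[0]}, expected {batch}")
    return out


def total_channels(shapes: Dict[str, Shape]) -> int:
    """Sum of dim_i over all entries."""
    return sum(shape[1] for shape in shapes.values())


shapes = predicted_output_shapes(
    batch=2,
    proj_dim=32,
    stage_sizes={1: (32, 32), 2: (16, 16), 3: (8, 8), 4: (4, 4)},
    image_shape=(2, 3, 128, 128),
    add_masks=False,
)
print(total_channels(shapes))  # 131
```

`total_channels` gives the channel count a consumer would see after concatenating all entries along the channel axis, which is one assumed way of combining them once they are resampled to a common resolution.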