Source code for pytorch3d.implicitron.tools.image_utils

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe


from typing import Sequence, Union

import torch


def mask_background(
    image_rgb: torch.Tensor,
    mask_fg: torch.Tensor,
    dim_color: int = 1,
    bg_color: Union[torch.Tensor, Sequence, str, float] = 0.0,
) -> torch.Tensor:
    """
    Mask the background of the input image tensor `image_rgb` with `bg_color`.
    The background regions are obtained from the binary foreground
    segmentation mask `mask_fg`.

    Args:
        image_rgb: A 4-dimensional image tensor with 3 color channels along
            dimension `dim_color`.
        mask_fg: Foreground mask; it is cast to the type of `image_rgb` and
            must broadcast against it. Entries equal to 1 keep the image,
            entries equal to 0 are replaced with `bg_color`.
        dim_color: Index of the color dimension of `image_rgb`.
        bg_color: The background color, given as either
            - a 3-element tensor,
            - a 3-element list or tuple of numbers,
            - a single number applied to all 3 channels, or
            - one of the strings "white" or "black".

    Returns:
        The blended image `mask_fg * image_rgb + (1 - mask_fg) * bg_color`.

    Raises:
        ValueError: If `bg_color` has an unsupported type or is an
            unrecognized color name.
    """
    # Shape that broadcasts the 3 color values along `dim_color` only.
    tgt_view = [1, 1, 1, 1]
    tgt_view[dim_color] = 3

    # obtain the background color tensor
    if isinstance(bg_color, torch.Tensor):
        # Use `tgt_view` (not a hard-coded channels-first shape) so that the
        # tensor branch honors `dim_color` like all other branches do;
        # `reshape` also tolerates non-contiguous inputs.
        bg_color_t = bg_color.reshape(tgt_view).to(image_rgb)
    elif isinstance(bg_color, (int, float, tuple, list)) and not isinstance(
        bg_color, bool
    ):
        if isinstance(bg_color, (int, float)):
            # A scalar is replicated over the 3 color channels.
            bg_color = [bg_color] * 3
        bg_color_t = torch.tensor(
            bg_color, device=image_rgb.device, dtype=image_rgb.dtype
        ).view(*tgt_view)
    elif isinstance(bg_color, str):
        if bg_color == "white":
            bg_color_t = image_rgb.new_ones(tgt_view)
        elif bg_color == "black":
            bg_color_t = image_rgb.new_zeros(tgt_view)
        else:
            raise ValueError(_invalid_color_error_msg(bg_color))
    else:
        raise ValueError(_invalid_color_error_msg(bg_color))

    # cast to the image_rgb's type
    mask_fg = mask_fg.type_as(image_rgb)

    # mask the bg
    image_masked = mask_fg * image_rgb + (1 - mask_fg) * bg_color_t
    return image_masked


def _invalid_color_error_msg(bg_color) -> str: return ( f"Invalid bg_color={bg_color}. Plese set bg_color to a 3-element" + " tensor. or a string (white | black), or a float." )