import logging
from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING

import torch

if TYPE_CHECKING:
    from visdom import Visdom

logger = logging.getLogger(__name__)

def get_visdom_env(visdom_env: str, exp_dir: str) -> str:
    """
    Parse out visdom environment name from the input config.

    Args:
        visdom_env: Name of the visdom environment, could be empty string.
        exp_dir: Root experiment directory ("/"-separated path).

    Returns:
        visdom_env: The name of the visdom environment. If the given visdom_env
            is empty, return the name of the bottom directory in exp_dir
            (trailing "/" separators are ignored).
    """
    if visdom_env:
        return visdom_env
    # Strip trailing separators so "runs/exp/" yields "exp" rather than "".
    return exp_dir.rstrip("/").split("/")[-1]
# TODO: a proper singleton _viz_singleton = None
[docs] def get_visdom_connection( server: str = "http://localhost", port: int = 8097, ) -> Optional["Visdom"]: """ Obtain a connection to a visdom server if visdom is installed. Args: server: Server address. port: Server port. Returns: connection: The connection object. """ try: from visdom import Visdom except ImportError: logger.debug("Cannot load visdom") return None if server == "None": return None global _viz_singleton if _viz_singleton is None: _viz_singleton = Visdom(server=server, port=port) return _viz_singleton
def visualize_basics(
    viz: "Visdom",
    preds: Dict[str, Any],
    visdom_env_imgs: str,
    title: str = "",
    visualize_preds_keys: Tuple[str, ...] = (
        "image_rgb",
        "images_render",
        "fg_probability",
        "masks_render",
        "depths_render",
        "depth_map",
    ),
    store_history: bool = False,
) -> None:
    """
    Visualize basic outputs of a `GenericModel` to visdom.

    Each requested entry of `preds` is processed as follows:
    - It is moved to CPU.
    - Entries whose key starts with "depth" are divided by their per-sample
      95th percentile.
    - Single-channel entries are repeated to 3 channels.
    - Everything is bilinearly resized to a height of about 200 pixels. The
      target is about 600 pixels for "images_render" / "depths_render" when
      `visdom_env_imgs` contains "_eval".

    The results are concatenated along the height dimension into a single
    "all" window. If they cannot be concatenated (incompatible shapes), each
    entry is shown in its own window instead. Images are clamped to [0, 1]
    before display.

    Args:
        viz: The visdom object.
        preds: A dictionary containing `GenericModel` outputs.
        visdom_env_imgs: Target visdom environment name.
        title: The title of produced visdom window.
        visualize_preds_keys: The list of keys of `preds` for visualization.
        store_history: Store the history buffer in visdom windows.
    """
    imout = {}
    for k in visualize_preds_keys:
        if k not in preds or preds[k] is None:
            logger.info("cant show %s", k)
            continue
        v = preds[k].cpu().detach().clone()
        if k.startswith("depth"):
            # Divide by the per-sample 95th percentile, i.e. the smallest of
            # the top 5% values. Keep k >= 1 so tiny maps (< 20 values per
            # sample) don't produce an empty topk result.
            n_top = max(int(0.05 * (v.numel() // v.shape[0])), 1)
            # reshape (not view): clone() may keep non-contiguous strides.
            normfac = v.reshape(v.shape[0], -1).topk(k=n_top, dim=-1).values[:, -1]
            v = v / normfac[:, None, None, None].clamp(1e-4)
        if v.shape[1] == 1:
            # Grayscale -> 3 channels so it can be concatenated with RGB.
            v = v.repeat(1, 3, 1, 1)
        is_large = "_eval" in visdom_env_imgs and k in (
            "images_render",
            "depths_render",
        )
        target_height = 600.0 if is_large else 200.0
        v = torch.nn.functional.interpolate(
            v,
            scale_factor=target_height / v.shape[2],
            mode="bilinear",
        )
        imout[k] = v

    if not imout:
        # Nothing to show; torch.cat would fail on an empty list anyway.
        return

    # TODO: handle errors on the outside
    try:
        imout = {"all": torch.cat(list(imout.values()), dim=2)}
    except RuntimeError as e:
        # Incompatible shapes: fall back to one window per key.
        logger.warning("cant cat! %s", e.args)

    for k, v in imout.items():
        viz.images(
            v.clamp(0.0, 1.0),
            win=k,
            env=visdom_env_imgs,
            opts={"title": title + "_" + k, "store_history": store_history},
        )
def make_depth_image(
    depths: torch.Tensor,
    masks: torch.Tensor,
    max_quantile: float = 0.98,
    min_quantile: float = 0.02,
    min_out_depth: float = 0.1,
    max_out_depth: float = 0.9,
) -> torch.Tensor:
    """
    Convert a batch of depth maps to a grayscale image.

    Computing the range: for each sample, only depth values that are
    positive (> 1e-6) and inside the foreground mask (> 0.5) are considered.
    Their `min_quantile` / `max_quantile` quantiles give a robust range
    [q_min, q_max]. If a sample has at most one such value, a zero range is
    used instead.

    Mapping: depths are affinely mapped so that q_min -> `min_out_depth` and
    q_max -> `max_out_depth`, multiplied by the mask, and clamped to [0, 1].

    Args:
        depths: A tensor of shape `(B, 1, H, W)` containing a batch of depth maps.
            Non-contiguous tensors are accepted.
        masks: A tensor of shape `(B, 1, H, W)` containing a batch of foreground masks.
        max_quantile: The quantile of the input depth values which will
            be mapped to `max_out_depth`.
        min_quantile: The quantile of the input depth values which will
            be mapped to `min_out_depth`.
        min_out_depth: The output value assigned to the `min_quantile` depth.
        max_out_depth: The output value assigned to the `max_quantile` depth.

    Returns:
        depth_image: A tensor of shape `(B, 1, H, W)` a batch of grayscale
            depth images. An empty batch (B == 0) yields an empty output.
    """
    normfacs = []
    for d, m in zip(depths, masks):
        # reshape (not view) so non-contiguous inputs (e.g. transposed) work.
        d_flat = d.reshape(-1)
        ok = (d_flat > 1e-6) & (m.reshape(-1) > 0.5)
        if ok.sum() <= 1:
            # Too few valid pixels to estimate a range; fall back to [0, 0].
            logger.info("empty depth!")
            normfacs.append(torch.zeros(2).type_as(depths))
            continue
        dok = d_flat[ok]
        # Number of extreme values to step over at each end; at least one so
        # topk always returns a value.
        _maxk = max(int(round((1 - max_quantile) * dok.numel())), 1)
        _mink = max(int(round(min_quantile * dok.numel())), 1)
        normfac_max = dok.topk(k=_maxk, dim=-1).values[-1]
        normfac_min = dok.topk(k=_mink, dim=-1, largest=False).values[-1]
        normfacs.append(torch.stack([normfac_min, normfac_max]))

    # torch.stack([]) raises, so an empty batch gets an explicit (0, 2) range.
    ranges = torch.stack(normfacs) if normfacs else depths.new_zeros((0, 2))
    _min = ranges[:, 0].reshape(-1, 1, 1, 1)
    _max = ranges[:, 1].reshape(-1, 1, 1, 1)
    depths = (depths - _min) / (_max - _min).clamp(1e-4)
    depths = (
        (depths * (max_out_depth - min_out_depth) + min_out_depth) * masks.float()
    ).clamp(0.0, 1.0)
    return depths