Source code for ablator.modules.loggers.tensor

import logging
import multiprocessing
import struct
import threading
import time
import typing as ty
from pathlib import Path
from typing import Union

import numpy as np
import pandas as pd
from omegaconf import OmegaConf
from tensorboardX import SummaryWriter
from tensorboardX.event_file_writer import EventsWriter
from tensorboardX.proto import event_pb2
from tensorboardX.record_writer import masked_crc32c

from ablator.config.main import ConfigBase
from ablator.config.utils import flatten_nested_dict
from ablator.modules.loggers import LoggerBase

# Only let tensorboardX log records of level ERROR or above through.
logging.getLogger("tensorboardX").setLevel(logging.ERROR)


class RecordWriter:
    """
    An extension to `tensorboardX.record_writer.RecordWriter` that removes
    support for remote writing.

    Records are appended to a local file; the file is opened and closed on
    every call to ``write``, so no handle is kept between writes.

    Parameters
    ----------
    path : Path
        The path of where to write the records.

    Attributes
    ----------
    path : Path
        The path of where the records are appended.
    """

    def __init__(self, path: Path):
        # These attributes are not read anywhere in this class; they are kept
        # so the object exposes the same attributes as before.
        self._name_to_tf_name: dict[str, str] = {}
        self._tf_names: set[str] = set()
        self.path = path
        self._writer = None

    def write(self, data):
        """
        Append one framed record to the file at ``self.path``.

        The record layout is: 8-byte length of ``data``, 4-byte masked CRC of
        the length field, the payload itself, 4-byte masked CRC of the
        payload.

        Parameters
        ----------
        data : bytes
            The serialized payload of the record.
        """
        # "<" pins little-endian, fixed-size fields so the framing does not
        # depend on the byte order or alignment rules of the host.
        header = struct.pack("<Q", len(data))
        record = b"".join(
            (
                header,
                struct.pack("<I", masked_crc32c(header)),
                data,
                struct.pack("<I", masked_crc32c(data)),
            )
        )
        # A single write call appends the whole record in one piece.
        with open(self.path, "ab") as writer:
            writer.write(record)

    def flush(self):
        """Do nothing; ``write`` closes the file after every record."""

    def close(self):
        """Do nothing; no file handle is kept open between writes."""
# Monkey-patching for faster writes that work with a mount.
# pylint: disable=super-init-not-called,no-member
class MyEventsWriter(EventsWriter):
    """
    Events writer that stores its records through the local ``RecordWriter``.

    Events files have a name of the form
    '/some/file/path/events.out.tfevents.[timestamp].[hostname]'

    Parameters
    ----------
    filename
        The path of the events file that records are appended to.
    """

    def __init__(self, filename):
        # NOTE: the parent initializer is deliberately not called; every
        # attribute is set up here instead.

        # Build the initial event carrying the wall time and version marker.
        version_event = event_pb2.Event()
        version_event.wall_time = time.time()
        version_event.file_version = "brain.Event:2"

        self._lock = threading.Lock()
        self._num_outstanding_events = 0
        self._file_name = filename
        self._py_recordio_writer = RecordWriter(filename)
        self._event = version_event

        # Kept as the last step: the inherited ``write_event`` presumably
        # relies on the attributes assigned above -- confirm against
        # tensorboardX before reordering.
        self.write_event(version_event)
class TensorboardLogger(LoggerBase):
    """
    A logger class for Tensorboard visualization.

    Every ``add_*`` / ``write_config`` call takes ``thread_lock``, forwards
    the value to the backend writer and flushes it right away.

    Parameters
    ----------
    summary_dir : Union[str, Path]
        The directory to store the Tensorboard summary files.

    Attributes
    ----------
    summary_dir : str
        The summary directory, stored as a POSIX-style string.
    backend_logger : SummaryWriter
        The tensorboardX ``SummaryWriter`` object used to log data.
    thread_lock : threading.Lock
        Lock held around every interaction with ``backend_logger``.
    """

    def __init__(self, summary_dir: Union[str, Path]):
        self.thread_lock = threading.Lock()
        self.summary_dir = Path(summary_dir).as_posix()
        self.backend_logger = SummaryWriter(
            log_dir=summary_dir, max_queue=2, flush_secs=2
        )
        # Replace the writer internals with ``MyEventsWriter`` and a small
        # multiprocessing queue. The statement order matters: the stock
        # writer is closed first, patched, and only then reopened.
        fw = self.backend_logger.file_writer.event_writer
        fw.close()
        filename = fw._ev_writer._file_name
        fw._ev_writer = MyEventsWriter(filename)
        fw._event_queue = multiprocessing.Queue(2)
        fw.reopen()
        super().__init__(heartbeat_interval=10)

    def add_image(
        self, k: str, v: np.ndarray, itr: int, dataformats: ty.Optional[str] = "CHW"
    ):
        """
        Add an image to the TensorBoard dashboard.

        Parameters
        ----------
        k : str
            The tag associated with the image.
        v : np.ndarray
            The image data.
        itr : int
            The iteration number.
        dataformats : ty.Optional[str]
            The format of the image data, by default ``"CHW"``.
        """
        with self.thread_lock:
            self.backend_logger.add_image(k, v, itr, dataformats=dataformats)
            self.backend_logger.flush()

    def add_table(self, k: str, v: pd.DataFrame, itr: int):
        """
        Add a table to the TensorBoard dashboard.

        The table is rendered to markdown and logged as text.

        Parameters
        ----------
        k : str
            The tag associated with the table.
        v : pd.DataFrame
            The table data.
        itr : int
            The iteration number.
        """
        with self.thread_lock:
            self.backend_logger.add_text(k, v.to_markdown(), itr)
            self.backend_logger.flush()

    def add_text(self, k: str, v: str, itr: int):
        """
        Add a text to the TensorBoard dashboard.

        Parameters
        ----------
        k : str
            The tag associated with the text.
        v : str
            The text data.
        itr : int
            The iteration number.
        """
        with self.thread_lock:
            self.backend_logger.add_text(k, v, itr)
            self.backend_logger.flush()

    def add_scalars(self, k: str, v: dict[str, float | int], itr: int):
        """
        Add multiple scalars to the TensorBoard dashboard.

        Each entry is logged as its own scalar under the tag ``f"{k}_{name}"``.
        A ``None`` value is logged as NaN, matching ``add_scalar``.

        Parameters
        ----------
        k : str
            The main tag associated with the scalars.
        v : dict[str, float | int]
            A dictionary of scalar tags and values.
        itr : int
            The iteration number.
        """
        # NOTE: ``self.backend_logger.add_scalars`` is deliberately not used
        # here because it was found to be buggy; entries are logged one by
        # one instead.
        with self.thread_lock:
            for _k, _v in v.items():
                value = np.nan if _v is None else _v
                self.backend_logger.add_scalar(f"{k}_{_k}", value, itr)
            self.backend_logger.flush()

    def add_scalar(self, k: str, v: float | int, itr: int):
        """
        Add a scalar to the TensorBoard dashboard.

        Parameters
        ----------
        k : str
            The tag associated with the scalar.
        v : float | int
            The scalar value. ``None`` is logged as NaN.
        itr : int
            The iteration number.
        """
        with self.thread_lock:
            value = np.nan if v is None else v
            self.backend_logger.add_scalar(k, value, itr)
            self.backend_logger.flush()

    def write_config(self, config: ConfigBase):
        """
        Write the configuration to the TensorBoard dashboard.

        The configuration is flattened, dumped as YAML and logged as text
        under the tag ``"config"`` at iteration 0.

        Parameters
        ----------
        config : ConfigBase
            The configuration object.
        """
        with self.thread_lock:
            hparams = flatten_nested_dict(config.to_dict())
            # Each newline is doubled before logging (presumably so the text
            # panel renders one entry per line -- confirm).
            run_config = OmegaConf.to_yaml(OmegaConf.create(hparams)).replace(
                "\n", "\n\n"
            )
            self.backend_logger.add_text("config", run_config, 0)
            self.backend_logger.flush()

    def heartbeat(self, timeout: int | None = None):
        """
        Flush the backend writer.

        Parameters
        ----------
        timeout : int | None
            Must be ``None``; any other value fails the assertion.
        """
        assert timeout is None
        self._sync()

    def _sync(self):
        """Flush the backend writer while holding ``thread_lock``."""
        with self.thread_lock:
            self.backend_logger.flush()