
Visualization Library

Modular visualization components for reviewing tracking and behavior analysis results.

Components include:

  • Track + label overlay on video frames
  • Interactive video playback
  • Egocentric crop generation
  • Interaction crop generation (pair-level)
  • Global embedding colored scatter plots
  • Timeline plots

visualization_library

Visualization library for behavior datasets.

This library provides modular visualization components:

  • Data loading (tracks, labels, ground truth)
  • Overlay preparation and frame drawing
  • Video streaming with overlays
  • Interactive video playback
  • Egocentric crop generation

Example usage

from mosaic.behavior.visualization_library import playback

playback.play_video(dataset, group="hex", sequence="hex_3", ...)

from mosaic.behavior.visualization_library.egocentric_crop import EgocentricCrop

crop_feat = EgocentricCrop(params={"target_id": 0, "crop_size": (256, 256)})
dataset.run_feature(crop_feat, sequences=["hex_3"])

EgocentricCrop

EgocentricCrop(inputs: Inputs = Inputs(('tracks',)), params: dict[str, object] | None = None)

Generate egocentric (animal-centered) video crops.

Centers the view on a target individual (or all individuals if target_id=None), optionally rotating to align the animal's heading with the +x axis. Output can be a video file, individual frame images, or both.

Parameters

target_id : Any, optional
    ID of the individual to center on. If None, processes ALL individuals found in the tracks data, creating separate outputs for each.
center_mode : str or int
    How to compute the center point:
    - "default": mean of all pose points poseX0..N/poseY0..N (pixel coords)
    - "pose0" or 0: use poseX0/poseY0 (typically head/nose)
    - int: use a specific pose point index
crop_size : tuple of (int, int)
    Output crop dimensions as (width, height) in pixels.
rotate_to_heading : bool
    If True, rotate the crop so the animal's heading aligns with the +x axis.
heading_points : tuple of (int, int)
    (neck_idx, tail_idx) pose point indices for heading computation. Heading points FROM tail TO neck (the direction the animal is facing).
margin_factor : float
    Extra margin for rotation (1.5 = 50% larger pre-crop before rotation).
center_offset_px : float
    Pixel offset along the heading direction from the computed center (default 0). Positive shifts forward (toward the head). Useful for centering on specific body parts, e.g. 35 for body center in bees.
body_mask : bool
    If True, apply an elliptical body mask to isolate the focal individual.
body_mask_length_px : int
    Length (semi-major axis) of the body mask ellipse in pixels.
body_mask_width_px : int
    Width (semi-minor axis) of the body mask ellipse in pixels.
use_clahe : bool
    If True, apply CLAHE (Contrast Limited Adaptive Histogram Equalization) to improve contrast in crops.
clahe_clip_limit : float
    CLAHE clip limit parameter.
clahe_tile_grid_size : int
    CLAHE tile grid size (both dimensions).
grayscale : bool
    If True, convert output to single-channel grayscale.
transform_keypoints : bool
    If True, transform pose keypoint coordinates into crop space and include them in the metadata output as poseX_crop, poseY_crop.
output_mode : str
    Output format:
    - "video": single video file per individual
    - "frames": individual frame images per individual
    - "both": video + frames
output_fps : float, optional
    Output video FPS. If None, uses the source video FPS.
output_root : str, optional
    Override the output directory. If None, outputs to media/egocentric_crops/.
frame_format : str
    Format for frame images ("png" or "jpg").
background_color : int
    Padding color for out-of-bounds regions (0=black, 255=white).
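As a minimal sketch of the "default" center_mode, here is a NaN-safe mean over pose points; the compute_center helper is hypothetical and only follows the poseX0..N/poseY0..N naming above, it is not the library's implementation:

```python
import numpy as np
import pandas as pd

def compute_center(row: pd.Series, n_points: int) -> tuple:
    """Mean of all pose points, mirroring the documented "default" center_mode."""
    xs = np.array([row[f"poseX{i}"] for i in range(n_points)], dtype=float)
    ys = np.array([row[f"poseY{i}"] for i in range(n_points)], dtype=float)
    # nanmean keeps missing keypoints from corrupting the center
    return float(np.nanmean(xs)), float(np.nanmean(ys))

row = pd.Series({"poseX0": 10.0, "poseY0": 20.0,
                 "poseX1": 30.0, "poseY1": 40.0})
cx, cy = compute_center(row, n_points=2)
# cx == 20.0, cy == 30.0
```

center_offset_px would then shift this point along the unit heading vector (tail to neck) before cropping.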

Examples

Process a single individual:

crop = EgocentricCrop(params={"target_id": 0, "crop_size": (256, 256)})
dataset.run_feature(crop, sequences=["hex_3"])

Bee-style crop with body masking and CLAHE:

crop = EgocentricCrop(params={
    "crop_size": (192, 192),
    "center_offset_px": 35.0,
    "body_mask": True,
    "use_clahe": True,
    "grayscale": True,
    "angle_col": "ANGLE",
})
dataset.run_feature(crop, sequences=["hex_3"])

Source code in src/mosaic/behavior/visualization_library/egocentric_crop.py
def __init__(
    self,
    inputs: EgocentricCrop.Inputs = Inputs(("tracks",)),
    params: dict[str, object] | None = None,
):
    self.inputs = inputs
    self.params = self.Params.from_overrides(params)
    self._ds = None
    self._scope: Scope = Scope()
    self._run_root: Path | None = None
    self._clahe = None  # lazily constructed; reused across frames

    # Storage settings (for feature pipeline integration)
    self.storage_feature_name = self.name
    self.storage_use_input_suffix = False
    self.skip_existing_outputs = False

bind_dataset

bind_dataset(ds)

Called by Dataset.run_feature before any fit/transform.

Source code in src/mosaic/behavior/visualization_library/egocentric_crop.py
def bind_dataset(self, ds):
    """Called by Dataset.run_feature before any fit/transform."""
    self._ds = ds

set_scope

set_scope(scope: Scope) -> None

Receive scope constraints from run_feature.

Source code in src/mosaic/behavior/visualization_library/egocentric_crop.py
def set_scope(self, scope: Scope) -> None:
    """Receive scope constraints from run_feature."""
    self._scope = scope

transform

transform(df: DataFrame) -> pd.DataFrame

Process a single sequence's tracks to generate egocentric crop video/frames.

Parameters

df : pd.DataFrame
    Tracks DataFrame for a single sequence.

Returns

pd.DataFrame
    Metadata DataFrame with crop info per frame/id.

Source code in src/mosaic/behavior/visualization_library/egocentric_crop.py
def transform(self, df: pd.DataFrame) -> pd.DataFrame:
    """
    Process a single sequence's tracks to generate egocentric crop video/frames.

    Parameters
    ----------
    df : pd.DataFrame
        Tracks DataFrame for a single sequence

    Returns
    -------
    pd.DataFrame
        Metadata DataFrame with crop info per frame/id
    """
    if df is None or df.empty:
        return pd.DataFrame()

    p = self.params
    group = (
        str(df[COLUMNS.group_col].iloc[0])
        if COLUMNS.group_col in df.columns
        else ""
    )
    sequence = (
        str(df[COLUMNS.seq_col].iloc[0]) if COLUMNS.seq_col in df.columns else ""
    )

    # Resolve video paths (supports multi-video sequences)
    video_paths = self._ds.resolve_media_paths(group, sequence)

    # Determine which IDs to process
    if p.target_id is None:
        # Process all unique IDs
        unique_ids = df[COLUMNS.id_col].dropna().unique()
        all_metadata = []
        for uid in unique_ids:
            df_target = df[df[COLUMNS.id_col] == uid].copy()
            if df_target.empty:
                continue
            metadata = self._process_single_id(
                video_paths, df_target, group, sequence, uid
            )
            all_metadata.append(metadata)
        if all_metadata:
            return pd.concat(all_metadata, ignore_index=True)
        return pd.DataFrame()
    else:
        # Process single ID
        df_target = df[df[COLUMNS.id_col] == p.target_id].copy()
        if df_target.empty:
            raise ValueError(f"No data for target_id={p.target_id}")
        return self._process_single_id(
            video_paths, df_target, group, sequence, p.target_id
        )

InteractionCropPipeline

InteractionCropPipeline(inputs: Inputs = Inputs(('tracks', Result(feature='pair-interaction-filter'))), params: dict[str, object] | None = None)

Generate egocentric crop videos for detected interaction segments.

Reads interaction segments from an upstream pair-interaction-filter result and generates per-individual cropped videos for each segment. Optionally applies body masking, CLAHE, and grayscale conversion.

Inputs

This feature takes two inputs:

  1. Tracks (standard trajectory data with pose keypoints)
  2. A pair-interaction-filter result providing interaction segments

The pipeline iterates over the filter result's interaction segments (grouped by id_a, id_b, interaction_id) and extracts egocentric crops from the source video for each individual in the pair.

Output

Videos are written to <run_root>/ when run via the pipeline (run_id-tagged). Returns a metadata DataFrame with one row per generated clip:

  • group, sequence, id_a, id_b, target_id, interaction_id
  • start_frame, end_frame, n_frames
  • video_path (filename only, relative to run_root)
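The segment iteration described above (group by pair and interaction, then sort by start frame for sequential video reads) can be sketched in plain pandas; the IDs and frame numbers below are hypothetical:

```python
import pandas as pd

# Hypothetical pair-interaction-filter output: one row per frame in a segment.
seg_df = pd.DataFrame({
    "id_a": [1, 1, 2],
    "id_b": [2, 2, 3],
    "interaction_id": [0, 0, 1],
    "interaction_start": [100, 100, 50],
    "interaction_end": [140, 140, 90],
})

# Group by segment, then sort by start frame so the video reader
# only ever seeks forward.
segments = list(seg_df.groupby(["id_a", "id_b", "interaction_id"]))
segments.sort(key=lambda item: int(item[1]["interaction_start"].iloc[0]))

order = [key for key, _ in segments]
# first segment is (2, 3, 1) because it starts at frame 50
```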

Source code in src/mosaic/behavior/visualization_library/interaction_crop.py
def __init__(
    self,
    inputs: InteractionCropPipeline.Inputs = Inputs(
        ("tracks", Result(feature="pair-interaction-filter"))
    ),
    params: dict[str, object] | None = None,
):
    self.inputs = inputs
    self.params = self.Params.from_overrides(params)
    self._ds = None
    self._scope: Scope = Scope()
    self._run_root: Path | None = None
    self._clahe = None  # lazily constructed; reused across frames/segments

apply

apply(df: DataFrame) -> pd.DataFrame

Process merged tracks + interaction-filter DataFrame.

The pipeline merges both inputs on frame, so df contains track columns and filter columns (id_a, id_b, interaction_id, interaction_start, interaction_end).

Source code in src/mosaic/behavior/visualization_library/interaction_crop.py
def apply(self, df: pd.DataFrame) -> pd.DataFrame:
    """Process merged tracks + interaction-filter DataFrame.

    The pipeline merges both inputs on frame, so *df* contains
    track columns **and** filter columns (id_a, id_b,
    interaction_id, interaction_start, interaction_end).
    """
    if df.empty:
        return pd.DataFrame()

    from mosaic.media.video_io import MultiVideoReader

    cv2.setNumThreads(2)  # prevent OpenCV from saturating all cores

    p = self.params
    group = str(df[C.group_col].iloc[0]) if C.group_col in df.columns else ""
    sequence = str(df[C.seq_col].iloc[0]) if C.seq_col in df.columns else ""

    # Resolve video
    video_paths = self._ds.resolve_media_paths(group, sequence)

    # Group by interaction segment
    required = ["id_a", "id_b", "interaction_id", "interaction_start", "interaction_end"]
    for col in required:
        if col not in df.columns:
            raise ValueError(
                f"Missing column '{col}' — ensure pair-interaction-filter "
                f"output is provided as an input."
            )

    seg_groups = list(df.groupby(["id_a", "id_b", "interaction_id"]))

    # Sort segments by start frame for sequential video reading
    seg_groups.sort(key=lambda x: int(x[1]["interaction_start"].iloc[0]))

    # Open video reader once for the whole sequence
    reader = MultiVideoReader(video_paths)
    output_fps = p.output_fps or reader.fps

    clip_records = []
    try:
        for (id_a, id_b, seg_id), seg_df in seg_groups:
            start_frame = int(seg_df["interaction_start"].iloc[0])
            end_frame = int(seg_df["interaction_end"].iloc[0])

            # Determine which individuals to crop
            target_ids = [id_a]
            if p.crop_both_individuals:
                target_ids.append(id_b)

            for target_id in target_ids:
                record = self._crop_segment(
                    reader=reader,
                    output_fps=output_fps,
                    df_tracks=df,
                    target_id=target_id,
                    start_frame=start_frame,
                    end_frame=end_frame,
                    id_a=id_a,
                    id_b=id_b,
                    seg_id=seg_id,
                    group=group,
                    sequence=sequence,
                )
                if record is not None:
                    clip_records.append(record)
    finally:
        reader.close()

    if not clip_records:
        return pd.DataFrame()
    return pd.DataFrame(clip_records)

VizGlobalColored

VizGlobalColored(inputs: Inputs = Inputs(()), params: dict[str, object] | None = None)

Generic scatter plot visualization for any global embedding or feature columns.

Uses ResultColumn params for fully customizable x and y axes, e.g. t-SNE coordinates, PCA components, or speed vs. approach distance. Labels can come from any feature's parquet output (via ResultColumn) or from ground truth labels (via GroundTruthLabelsSource).
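A minimal stand-alone sketch of the kind of colored scatter this feature produces, assuming a hypothetical DataFrame with t-SNE coordinate columns and a label column (this does not use the ResultColumn machinery itself):

```python
import matplotlib
matplotlib.use("Agg")  # headless rendering
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical embedding output standing in for the ResultColumn-selected axes.
df = pd.DataFrame({
    "tsne_x": [0.1, 0.9, 0.4],
    "tsne_y": [0.2, 0.8, 0.5],
    "label": ["walk", "rest", "walk"],
})

fig, ax = plt.subplots()
for label, sub in df.groupby("label"):
    # one scatter call per label value gives per-label colors and a legend entry
    ax.scatter(sub["tsne_x"], sub["tsne_y"], label=str(label), s=10)
ax.legend()
n_groups = len(ax.collections)  # one PathCollection per label
```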

Source code in src/mosaic/behavior/visualization_library/viz_global_colored.py
def __init__(
    self,
    inputs: VizGlobalColored.Inputs = Inputs(()),
    params: dict[str, object] | None = None,
):
    self.inputs = inputs
    self.params = self.Params.from_overrides(params)
    self.storage_feature_name = self.name
    self.storage_use_input_suffix = True
    self._ds = None
    self._figs: list[tuple[str, Figure]] = []
    self._marker_written = False
    self._summary: dict[str, object] = {}
    self._scope: Scope = Scope()
    self._debug_arrays: dict[str, object] | None = None

TimelinePlot

TimelinePlot(inputs: Inputs = Inputs(()), params: dict[str, object] | None = None)

Visualize per-frame labels as horizontal colored-bar timelines.

Params

source : dict
    Feature reference: {"feature": "kpms-apply", "run_id": None, "pattern": "*.parquet"}
    Or ground-truth labels: {"source": "labels", "kind": "CalMS21"}
label_column : str or None
    Column containing the labels. Auto-detected if None.
label_columns : list[str] or None
    Combine multiple binary (0/1) columns into one composite label. The label value is the column name of the first active column, or 0 when none are active. Overrides label_column. Use with skip_labels=[0] to hide inactive frames.
skip_labels : list or None
    Label values to not draw (rendered as white/blank space). Example: [0] to hide "no event" frames.
symmetric_pairs : bool
    If True, treat (A,B) == (B,A) for pair-level data.
palette : str
    Seaborn palette name for label colors (default "tab20").
pair_palette : str
    Palette for asymmetric pair role shading (default "Paired").
figsize_width : float
    Width of each figure in inches.
row_height : float
    Height per timeline row in inches.
min_fig_height / max_fig_height : float
    Clamp figure height.
show_legend : bool
    Whether to add a legend. When there are many labels, the legend is placed outside the plot area.
title_template : str
    Format string for the plot title; {sequence} is replaced.
dpi : int
    Output resolution.
per_sequence : bool
    One PNG per sequence (True) or a single combined PNG (False).
missing_label_value : int
    Sentinel for unlabeled frames (rendered gray).
label_name_map : dict or None
    Optional {label_id: display_name} mapping for the legend.

Output

PNG file(s) in the run folder plus a single marker parquet row for indexing.
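The label_columns combination rule described in Params can be sketched in plain pandas; the composite_label helper and the groom/chase column names are hypothetical:

```python
import pandas as pd

def composite_label(df: pd.DataFrame, label_columns: list) -> pd.Series:
    """Per frame: name of the first active (==1) column, or 0 when none
    are active, matching the documented label_columns behavior."""
    def pick(row):
        for col in label_columns:
            if row[col] == 1:
                return col
        return 0
    return df.apply(pick, axis=1)

df = pd.DataFrame({"groom": [1, 0, 0], "chase": [0, 1, 0]})
labels = composite_label(df, ["groom", "chase"])
# -> ["groom", "chase", 0]; pair with skip_labels=[0] to hide inactive frames
```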

Source code in src/mosaic/behavior/visualization_library/viz_timeline.py
def __init__(
    self,
    inputs: TimelinePlot.Inputs = Inputs(()),
    params: dict[str, object] | None = None,
):
    self.inputs = inputs
    self.params = self.Params.from_overrides(params)

    self.storage_feature_name = self.name
    self.storage_use_input_suffix = True
    self._ds = None
    self._figs: list[tuple[str, Figure]] = []
    self._marker_written = False
    self._summary: dict = {}
    self._scope: Scope = Scope()

demo_load_visual_inputs

demo_load_visual_inputs(ds, group: str, sequence: str, features: Dict[str, Optional[str]])

Small wrapper to quickly inspect what load_tracks_and_labels returns.

Usage (notebook):

    tracks, labels = demo_load_visual_inputs(dataset, "G1", "S1",
                                             {"temporal-stack": None,
                                              "behavior-xgb-pred": "<run_id>"})

Source code in src/mosaic/behavior/visualization_library/data_loading.py
def demo_load_visual_inputs(
    ds, group: str, sequence: str, features: Dict[str, Optional[str]]
):
    """
    Small wrapper to quickly inspect what load_tracks_and_labels returns.
    Usage (notebook):
        tracks, labels = demo_load_visual_inputs(dataset, "G1", "S1",
                                                 {"temporal-stack": None,
                                                  "behavior-xgb-pred": "<run_id>"})
    """
    tracks, labels = load_tracks_and_labels(ds, group, sequence, features)
    print(f"Tracks shape: {tracks.shape}")
    for kind in ("per_id", "per_pair"):
        print(f"{kind}:")
        for feat, mapping in labels[kind].items():
            print(f"  {feat}: {len(mapping)} series")
    return tracks, labels

load_ground_truth_labels

load_ground_truth_labels(ds, label_kind: str, group: str, sequence: str) -> pd.DataFrame

Load per-frame ground-truth labels for a given kind/group/sequence.

Returns a DataFrame with columns: frame, label_id, label_name (if a mapping is provided in the npz). For the individual_pair_v1 format, also includes id1, id2 columns.
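The npz payload layout the loader expects can be sketched as follows; the key names follow the source below, while the frame/label values are hypothetical:

```python
import io
import numpy as np

# Hypothetical payload mirroring the expected npz layout.
buf = io.BytesIO()
np.savez(buf,
         frames=np.array([0, 1, 2]),
         labels=np.array([0, 1, 1]),
         label_ids=np.array([0, 1]),
         label_names=np.array(["none", "groom"]))
buf.seek(0)
payload = np.load(buf, allow_pickle=True)

# Map numeric label ids to display names, as the loader does.
id_to_name = {int(lid): str(name)
              for lid, name in zip(payload["label_ids"], payload["label_names"])}
names = [id_to_name.get(int(v), str(v)) for v in payload["labels"]]
# -> ["none", "groom", "groom"]
```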

Source code in src/mosaic/behavior/visualization_library/data_loading.py
def load_ground_truth_labels(
    ds,
    label_kind: str,
    group: str,
    sequence: str,
) -> pd.DataFrame:
    """
    Load per-frame ground-truth labels for a given kind/group/sequence.

    Returns a DataFrame with columns:
        frame, label_id, label_name (if mapping provided in the npz).
        For individual_pair_v1 format, also includes id1, id2 columns.
    """
    labels_root = Path(ds.get_root("labels")) / label_kind
    idx_path = labels_root / "index.csv"
    if not idx_path.exists():
        raise FileNotFoundError(
            f"Label index not found for kind='{label_kind}': {idx_path}"
        )
    df_idx = pd.read_csv(idx_path)
    if df_idx.empty:
        raise FileNotFoundError(f"No labels indexed for kind='{label_kind}'.")

    hits = df_idx[
        (df_idx["group"].astype(str) == str(group))
        & (df_idx["sequence"].astype(str) == str(sequence))
    ]
    if hits.empty:
        raise FileNotFoundError(
            f"No GT labels for kind='{label_kind}' group='{group}' sequence='{sequence}'."
        )

    path = ds.resolve_path(hits.iloc[0]["abs_path"])
    payload = np.load(path, allow_pickle=True)
    frames = payload["frames"]
    label_ids = payload["labels"]
    label_id_list = payload.get("label_ids")
    label_name_list = payload.get("label_names")
    id_to_name: dict[int, str] = {}
    if label_id_list is not None and label_name_list is not None:
        for lid, name in zip(label_id_list, label_name_list):
            id_to_name[int(lid)] = str(name)
    label_names = [id_to_name.get(int(val), str(val)) for val in label_ids]

    result = {
        "frame": frames.astype(int, copy=False),
        "label_id": label_ids.astype(int, copy=False),
        "label_name": label_names,
    }

    # Include individual_ids for pair-aware labels (individual_pair_v1 format)
    if "individual_ids" in payload.files:
        individual_ids = np.asarray(payload["individual_ids"])
        if individual_ids.ndim == 1:
            individual_ids = individual_ids.reshape(-1, 2)
        result["id1"] = individual_ids[:, 0].astype(int, copy=False)
        result["id2"] = individual_ids[:, 1].astype(int, copy=False)

    return pd.DataFrame(result)

load_tracks_and_labels

load_tracks_and_labels(ds, group: str, sequence: str, feature_runs: Dict[str, Optional[str]]) -> Tuple[pd.DataFrame, Dict[str, Any]]

Load a single sequence's tracks plus per-frame labels from feature/model runs.

Parameters

ds : Dataset
    Loaded Dataset instance.
group, sequence : str
    The scope to load.
feature_runs : dict[str, str | None]
    Mapping of feature/model storage names -> run_id. If run_id is None, the latest finished run is used.

Returns

tracks_df : pd.DataFrame
    Standard tracks for the requested (group, sequence).
labels : dict
    {
      "per_id": {feature_name: {id_value: Series}},
      "per_pair": {feature_name: {(id1, id2): Series}},
      "raw": {feature_name: DataFrame}  # full frame per feature for bespoke use
    }
    Series are indexed by frame and hold the chosen label column.
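A sketch of consuming the returned labels dict; the feature names, IDs, and label values below are hypothetical stand-ins shaped like the documented structure:

```python
import pandas as pd

# Hypothetical return value shaped like the documented labels dict.
labels = {
    "per_id": {"behavior-xgb-pred": {
        0: pd.Series(["walk", "rest"], index=pd.Index([10, 11], name="frame")),
    }},
    "per_pair": {"pair-interaction-filter": {
        # pair keys are sorted (id1, id2) tuples
        (0, 1): pd.Series(["chase"], index=pd.Index([10], name="frame")),
    }},
    "raw": {},
}

# Look up individual 0's predicted label at frame 11.
value = labels["per_id"]["behavior-xgb-pred"][0].loc[11]
# -> "rest"
```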

Source code in src/mosaic/behavior/visualization_library/data_loading.py
def load_tracks_and_labels(
    ds,
    group: str,
    sequence: str,
    feature_runs: Dict[str, Optional[str]],
) -> Tuple[pd.DataFrame, Dict[str, Any]]:
    """
    Load a single sequence's tracks plus per-frame labels from feature/model runs.

    Parameters
    ----------
    ds : Dataset
        Loaded Dataset instance.
    group, sequence : str
        The scope to load.
    feature_runs : dict[str, str | None]
        Mapping of feature/model storage names -> run_id.
        If run_id is None, the latest finished run is used.

    Returns
    -------
    tracks_df : pd.DataFrame
        Standard tracks for the requested (group, sequence).
    labels : dict
        {
          "per_id": {feature_name: {id_value: Series}},
          "per_pair": {feature_name: {(id1, id2): Series}},
          "raw": {feature_name: DataFrame}  # full frame per feature for bespoke use
        }
        Series are indexed by frame and hold the chosen label column.
    """
    tracks_df = None
    for _, _, df in yield_sequences(ds, groups=[group], sequences=[sequence]):
        tracks_df = df
        break
    if tracks_df is None:
        raise FileNotFoundError(
            f"No tracks found for group='{group}', sequence='{sequence}'."
        )

    per_id: dict[str, dict[Any, pd.Series]] = {}
    per_pair: dict[str, dict[Tuple[Any, Any], pd.Series]] = {}
    raw: dict[str, pd.DataFrame] = {}

    for feature_name, run_id in feature_runs.items():
        # Resolve run_id if not provided
        resolved_run_id = run_id
        if not resolved_run_id:
            resolved_run_id, _ = latest_feature_run_root(ds, feature_name)

        idx_path = feature_index_path(ds, feature_name)
        if not idx_path.exists():
            raise FileNotFoundError(
                f"Missing feature index for '{feature_name}': {idx_path}"
            )
        df_idx = pd.read_csv(idx_path)

        # Normalize NaNs/None to empty strings so blank/absent groups still match
        for col in ("sequence", "group"):
            if col in df_idx.columns:
                df_idx[col] = df_idx[col].fillna("").astype(str)

        df_idx = df_idx[df_idx["run_id"].astype(str) == str(resolved_run_id)]
        df_idx = df_idx[df_idx["sequence"].astype(str) == str(sequence)]
        if "group" in df_idx.columns:
            df_idx = df_idx[df_idx["group"].astype(str) == str(group)]

        if df_idx.empty:
            raise FileNotFoundError(
                f"No rows in feature index for '{feature_name}' run_id='{resolved_run_id}' "
                f"group='{group}' sequence='{sequence}'."
            )

        abs_path_raw = df_idx.iloc[0]["abs_path"]
        path = ds.resolve_path(abs_path_raw)
        df_feat = pd.read_parquet(path)
        raw[feature_name] = df_feat

        label_col = _pick_label_column(df_feat)
        if not label_col or "frame" not in df_feat.columns:
            continue  # nothing label-like to index

        df_norm = _normalize_identity_columns(df_feat)
        if "id1" in df_norm.columns:
            has_id1 = df_norm["id1"].notna()
            has_id2 = (
                df_norm["id2"].notna()
                if "id2" in df_norm.columns
                else pd.Series(False, index=df_norm.index)
            )

            # Per-pair rows (id1 + id2 present)
            pair_rows = df_norm[has_id1 & has_id2]
            if not pair_rows.empty:
                pairs = pair_rows[["id1", "id2"]].apply(
                    lambda row: tuple(sorted((int(row["id1"]), int(row["id2"])))),
                    axis=1,
                )
                pair_rows = pair_rows.assign(_pair=pairs)
                for pair, sub in pair_rows.groupby("_pair"):
                    series = sub.sort_values("frame").groupby("frame")[label_col].last()
                    per_pair.setdefault(feature_name, {})[pair] = series.sort_index()

            # Per-individual rows (id1 present, id2 missing)
            indiv_rows = df_norm[has_id1 & ~has_id2]
            if not indiv_rows.empty:
                for id_val, sub in indiv_rows.groupby("id1"):
                    id_key = int(id_val)
                    series = sub.sort_values("frame").groupby("frame")[label_col].last()
                    per_id.setdefault(feature_name, {})[id_key] = series.sort_index()

            # Global rows (no id1) stay under None
            global_rows = df_norm[~has_id1]
            if not global_rows.empty:
                series = (
                    global_rows.sort_values("frame").groupby("frame")[label_col].last()
                )
                per_id.setdefault(feature_name, {})[None] = series.sort_index()
        else:
            # No identity columns: global series
            series = df_norm.sort_values("frame").groupby("frame")[label_col].last()
            per_id.setdefault(feature_name, {})[None] = series.sort_index()

    labels = {"per_id": per_id, "per_pair": per_pair, "raw": raw}
    return tracks_df, labels

draw_frame

draw_frame(image: ndarray, frame_overlay: dict, id_colors: dict, show_labels: bool = True, point_radius: int = 8, bbox_thickness: int = 2, show_individual_bboxes: bool = True, scale: Tuple[float, float] = (1.0, 1.0), color_feature: Optional[str] = None, color_mode: Optional[str] = None, pair_box_feature: Optional[str] = None, pair_box_behaviors: Optional[Iterable[Any]] = None, hide_individual_bboxes_for_pair: bool = False) -> np.ndarray

Draw pose points, bounding boxes, and labels for a single frame.

Parameters

image : np.ndarray (H,W,3)
    Video frame in BGR order.
frame_overlay : dict
    Entry from overlay_data["per_frame"][frame].
id_colors : dict
    Mapping produced by prepare_overlay.

Source code in src/mosaic/behavior/visualization_library/overlay.py
def draw_frame(
    image: np.ndarray,
    frame_overlay: dict,
    id_colors: dict,
    show_labels: bool = True,
    point_radius: int = 8,
    bbox_thickness: int = 2,
    show_individual_bboxes: bool = True,
    scale: Tuple[float, float] = (1.0, 1.0),
    color_feature: Optional[str] = None,
    color_mode: Optional[str] = None,
    pair_box_feature: Optional[str] = None,
    pair_box_behaviors: Optional[Iterable[Any]] = None,
    hide_individual_bboxes_for_pair: bool = False,
) -> np.ndarray:
    """
    Draw pose points, bounding boxes, and labels for a single frame.

    Parameters
    ----------
    image : np.ndarray (H,W,3)
        Video frame in BGR order.
    frame_overlay : dict
        Entry from overlay_data["per_frame"][frame].
    id_colors : dict
        Mapping produced by prepare_overlay.
    """
    canvas = image.copy()
    sx, sy = scale
    ids = frame_overlay.get("ids", {})
    frame_color = frame_overlay.get("frame_color")
    render_layers = frame_overlay.get("render_layers") or {}

    # Layer-driven style overrides applied before drawing.
    for style in render_layers.get("id_styles", []) or []:
        if not isinstance(style, dict):
            continue
        raw_id = style.get("id")
        key = _resolve_id_key(ids, raw_id)
        if key is None:
            continue
        info = ids.get(key) or {}
        if "color" in style:
            info["color"] = _coerce_color(
                style.get("color"), info.get("color") or id_colors.get(key, (0, 255, 0))
            )
        lbl = style.get("label")
        if lbl is not None:
            label_key = style.get("label_key", "overlay")
            labels_map = info.setdefault("labels", {})
            labels_map[str(label_key)] = lbl
        ids[key] = info

    pair_boxes = []
    ids_in_pair_boxes = set()
    pair_labels_all = frame_overlay.get("pair_labels") or {}
    behavior_set = {
        str(v).strip().lower() for v in (pair_box_behaviors or []) if str(v).strip()
    }

    selected_pair_feature = pair_box_feature
    if selected_pair_feature is None and color_mode == "gt" and "gt" in pair_labels_all:
        selected_pair_feature = "gt"
    if selected_pair_feature:
        for feat_name in pair_labels_all.keys():
            if (
                str(feat_name).strip().lower()
                == str(selected_pair_feature).strip().lower()
            ):
                selected_pair_feature = feat_name
                break

    if selected_pair_feature and behavior_set:
        pair_map = pair_labels_all.get(selected_pair_feature, {})
        grouped = {}

        def _canon_pair(a, b):
            try:
                return tuple(sorted((a, b)))
            except TypeError:
                return tuple(sorted((a, b), key=lambda v: str(v)))

        for pair, val in pair_map.items():
            if not isinstance(pair, (tuple, list)) or len(pair) != 2:
                continue
            lbl_norm = str(val).strip().lower()
            if lbl_norm not in behavior_set:
                continue

            key_a = _resolve_id_key(ids, pair[0])
            key_b = _resolve_id_key(ids, pair[1])
            if key_a is None or key_b is None:
                continue
            info_a = ids.get(key_a) or {}
            info_b = ids.get(key_b) or {}
            if "bbox" not in info_a or "bbox" not in info_b:
                continue
            xa1, ya1, xa2, ya2 = info_a["bbox"]
            xb1, yb1, xb2, yb2 = info_b["bbox"]
            if not all(np.isfinite([xa1, ya1, xa2, ya2, xb1, yb1, xb2, yb2])):
                continue

            union = (
                min(float(xa1), float(xb1)),
                min(float(ya1), float(yb1)),
                max(float(xa2), float(xb2)),
                max(float(ya2), float(yb2)),
            )

            src = key_a
            dst = key_b
            canon = _canon_pair(src, dst)
            grouped.setdefault(canon, []).append(
                {
                    "src": src,
                    "dst": dst,
                    "label": val,
                    "label_norm": lbl_norm,
                    "bbox": union,
                    "color": _color_for_label(val),
                }
            )
            ids_in_pair_boxes.update({key_a, key_b})

        for _, entries in grouped.items():
            # Remove exact duplicates (same direction + same normalized label).
            unique_entries = {}
            for e in entries:
                k = (e["src"], e["dst"], e["label_norm"])
                if k not in unique_entries:
                    unique_entries[k] = e
            collapsed = list(unique_entries.values())
            if not collapsed:
                continue

            # If both directions carry the same behavior label, draw only once.
            label_norms = {e["label_norm"] for e in collapsed}
            if len(label_norms) == 1:
                e = collapsed[0]
                pair_boxes.append(
                    {
                        "pair": (e["src"], e["dst"]),
                        "label": e["label"],
                        "bbox": e["bbox"],
                        "color": e["color"],
                        "offset_px": 0,
                    }
                )
                continue

            # Asymmetric case: keep directional entries and offset them for visibility.
            collapsed.sort(
                key=lambda e: (str(e["src"]), str(e["dst"]), str(e["label"]))
            )
            for idx, e in enumerate(collapsed):
                pair_boxes.append(
                    {
                        "pair": (e["src"], e["dst"]),
                        "label": f"{e['src']}->{e['dst']}:{_format_label_text(e['label'])}",
                        "bbox": e["bbox"],
                        "color": e["color"],
                        "offset_px": idx * 3,
                    }
                )

    for id_val, info in ids.items():
        base_color = info.get("color")
        if base_color is None:
            base_color = id_colors.get(id_val, (0, 255, 0))
        color = tuple(int(c) for c in base_color)
        if show_individual_bboxes and "bbox" in info:
            x1, y1, x2, y2 = info["bbox"]
            if all(np.isfinite([x1, y1, x2, y2])) and not (
                hide_individual_bboxes_for_pair and id_val in ids_in_pair_boxes
            ):
                pt1 = (int(x1 * sx), int(y1 * sy))
                pt2 = (int(x2 * sx), int(y2 * sy))
                cv2.rectangle(canvas, pt1, pt2, color, bbox_thickness)
        if "pose" in info:
            for x, y in info["pose"]:
                if not np.isfinite(x) or not np.isfinite(y):
                    continue
                pt = (int(x * sx), int(y * sy))
                cv2.circle(canvas, pt, point_radius, color, -1, lineType=cv2.LINE_AA)
        if show_labels and (info.get("labels") or color_mode == "gt"):
            labels_map = info.get("labels") or {}
            dominant = None
            if color_feature and color_feature in labels_map:
                dominant = labels_map[color_feature]
            elif color_mode == "gt":
                dominant = labels_map.get("gt")
                if dominant is None:
                    global_label = frame_overlay.get("global_labels", {})
                    dominant = global_label.get("label_name") or global_label.get(
                        "label_id"
                    )
            label_text = None
            if dominant is not None:
                label_text = _format_label_text(dominant)
            elif labels_map:
                label_text = " | ".join(
                    f"{k}:{_format_label_text(v)}" for k, v in labels_map.items()
                )
            if not label_text:
                continue
            anchor = None
            if "bbox" in info:
                x1, y1, _, _ = info["bbox"]
                anchor = (x1, y1)
            if anchor is None:
                anchor = info.get("centroid")
            if anchor and all(np.isfinite(anchor)):
                pos = (int(anchor[0] * sx), int(anchor[1] * sy) - 4)
                cv2.putText(
                    canvas,
                    str(label_text),
                    pos,
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.45,
                    color,
                    1,
                    cv2.LINE_AA,
                )

    # Draw pair-level boxes after per-id overlays so they stay visible.
    for pb in pair_boxes:
        x1, y1, x2, y2 = pb["bbox"]
        color = tuple(int(c) for c in pb["color"])
        off = int(pb.get("offset_px", 0))
        pt1 = (int(x1 * sx) - off, int(y1 * sy) - off)
        pt2 = (int(x2 * sx) + off, int(y2 * sy) + off)
        cv2.rectangle(canvas, pt1, pt2, color, max(1, bbox_thickness + 1))
        if show_labels:
            lbl = _format_label_text(pb["label"])
            if lbl:
                pos = (pt1[0], max(12, pt1[1] - 6))
                cv2.putText(
                    canvas,
                    str(lbl),
                    pos,
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.5,
                    color,
                    2,
                    cv2.LINE_AA,
                )

    # Draw group-level outlines (union over member ids).
    for item in render_layers.get("group_outlines", []) or []:
        if not isinstance(item, dict):
            continue
        members = item.get("ids") or []
        rects = []
        for rid in members:
            key = _resolve_id_key(ids, rid)
            if key is None:
                continue
            info = ids.get(key) or {}
            bbox = info.get("bbox")
            if (
                isinstance(bbox, (list, tuple))
                and len(bbox) == 4
                and all(np.isfinite(bbox))
            ):
                rects.append(tuple(float(v) for v in bbox))
                continue
            anchor = _anchor_for_id_info(info)
            if anchor is not None:
                x, y = anchor
                r = float(item.get("fallback_radius", 20.0))
                rects.append((x - r, y - r, x + r, y + r))
        if not rects:
            continue
        x1 = min(r[0] for r in rects)
        y1 = min(r[1] for r in rects)
        x2 = max(r[2] for r in rects)
        y2 = max(r[3] for r in rects)

        color = _coerce_color(
            item.get("color"), _color_for_label(item.get("group_size", len(members)))
        )
        thickness = int(item.get("thickness", max(1, bbox_thickness + 1)))
        pt1 = (int(x1 * sx), int(y1 * sy))
        pt2 = (int(x2 * sx), int(y2 * sy))
        cv2.rectangle(canvas, pt1, pt2, color, thickness)
        if show_labels:
            lbl = _format_label_text(item.get("label"))
            if lbl:
                pos = (pt1[0], max(12, pt1[1] - 6))
                cv2.putText(
                    canvas,
                    str(lbl),
                    pos,
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.5,
                    color,
                    2,
                    cv2.LINE_AA,
                )

    # Draw optional line edges between ids.
    for edge in render_layers.get("edges", []) or []:
        if not isinstance(edge, dict):
            continue
        p1 = None
        p2 = None
        if "p1" in edge and "p2" in edge:
            p1 = edge.get("p1")
            p2 = edge.get("p2")
        else:
            id_a = edge.get("id1")
            id_b = edge.get("id2")
            key_a = _resolve_id_key(ids, id_a)
            key_b = _resolve_id_key(ids, id_b)
            if key_a is not None and key_b is not None:
                p1 = _anchor_for_id_info(ids.get(key_a) or {})
                p2 = _anchor_for_id_info(ids.get(key_b) or {})
        if p1 is None or p2 is None:
            continue
        if not all(np.isfinite([p1[0], p1[1], p2[0], p2[1]])):
            continue
        color = _coerce_color(edge.get("color"), (255, 255, 255))
        thickness = int(edge.get("thickness", 1))
        cv2.line(
            canvas,
            (int(p1[0] * sx), int(p1[1] * sy)),
            (int(p2[0] * sx), int(p2[1] * sy)),
            color,
            thickness,
            lineType=cv2.LINE_AA,
        )

    # Draw optional vectors (velocity arrows or any directional primitive).
    for vec in render_layers.get("vectors", []) or []:
        if not isinstance(vec, dict):
            continue
        origin = vec.get("origin")
        if origin is None and "id" in vec:
            key = _resolve_id_key(ids, vec.get("id"))
            if key is not None:
                origin = _anchor_for_id_info(ids.get(key) or {})
        if origin is None:
            continue
        dx = float(vec.get("dx", 0.0))
        dy = float(vec.get("dy", 0.0))
        scale_v = float(vec.get("scale", 1.0))
        ox, oy = float(origin[0]), float(origin[1])
        ex, ey = ox + dx * scale_v, oy + dy * scale_v
        if not all(np.isfinite([ox, oy, ex, ey])):
            continue
        color = _coerce_color(vec.get("color"), (255, 255, 255))
        thickness = int(vec.get("thickness", 1))
        tip_len = float(vec.get("tip_length", 0.2))
        cv2.arrowedLine(
            canvas,
            (int(ox * sx), int(oy * sy)),
            (int(ex * sx), int(ey * sy)),
            color,
            thickness,
            line_type=cv2.LINE_AA,
            tipLength=tip_len,
        )

    # Draw optional polygons (e.g., ROIs).
    for poly in render_layers.get("polygons", []) or []:
        if not isinstance(poly, dict):
            continue
        pts = poly.get("points") or []
        if len(pts) < 2:
            continue
        out_pts = []
        for p in pts:
            if not isinstance(p, (list, tuple)) or len(p) != 2:
                continue
            if not all(np.isfinite(p)):
                continue
            out_pts.append([int(float(p[0]) * sx), int(float(p[1]) * sy)])
        if len(out_pts) < 2:
            continue
        arr = np.asarray(out_pts, dtype=np.int32).reshape((-1, 1, 2))
        color = _coerce_color(poly.get("color"), (255, 255, 255))
        thickness = int(poly.get("thickness", 1))
        if poly.get("fill", False):
            alpha = float(poly.get("alpha", 0.2))
            tmp = canvas.copy()
            cv2.fillPoly(tmp, [arr], color)
            canvas = cv2.addWeighted(tmp, alpha, canvas, 1.0 - alpha, 0.0)
        cv2.polylines(
            canvas,
            [arr],
            bool(poly.get("closed", True)),
            color,
            thickness,
            lineType=cv2.LINE_AA,
        )

    # Draw optional free text labels.
    for txt in render_layers.get("texts", []) or []:
        if not isinstance(txt, dict):
            continue
        pos = txt.get("pos")
        if not isinstance(pos, (list, tuple)) or len(pos) != 2:
            continue
        if not all(np.isfinite(pos)):
            continue
        text = txt.get("text")
        if text is None:
            continue
        color = _coerce_color(txt.get("color"), (255, 255, 255))
        scale_txt = float(txt.get("font_scale", 0.5))
        thickness = int(txt.get("thickness", 1))
        cv2.putText(
            canvas,
            str(text),
            (int(float(pos[0]) * sx), int(float(pos[1]) * sy)),
            cv2.FONT_HERSHEY_SIMPLEX,
            scale_txt,
            color,
            thickness,
            cv2.LINE_AA,
        )

    global_labels = frame_overlay.get("global_labels")
    if global_labels and color_mode != "gt":
        text = ", ".join(f"{k}:{v}" for k, v in global_labels.items() if v is not None)
        if text:
            cv2.putText(
                canvas,
                text,
                (10, 20),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.6,
                (255, 255, 255),
                2,
                cv2.LINE_AA,
            )
    return canvas
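The drawing loop above consumes optional `render_layers` entries (`edges`, `vectors`, `polygons`, `texts`, `group_outlines`). A minimal sketch of the expected shapes, with key names taken from the drawing code above; all coordinate, id, and color values are invented for illustration:

```python
# Illustrative render_layers payload; key names mirror the drawing code above,
# every coordinate/color/id value is made up for the example.
render_layers = {
    "edges": [
        # Either explicit endpoints ...
        {"p1": (10.0, 20.0), "p2": (40.0, 60.0), "color": (0, 255, 0), "thickness": 1},
        # ... or a pair of track ids resolved to anchors at draw time.
        {"id1": 0, "id2": 1, "color": (255, 0, 0)},
    ],
    "vectors": [
        # Velocity arrow: origin (or an id's anchor) plus (dx, dy) scaled by "scale".
        {"id": 0, "dx": 3.0, "dy": -1.5, "scale": 5.0, "tip_length": 0.2},
    ],
    "polygons": [
        # ROI outline; set "fill" for a translucent overlay blended with "alpha".
        {"points": [(0, 0), (100, 0), (100, 80)], "closed": True,
         "fill": True, "alpha": 0.2, "color": (255, 255, 255)},
    ],
    "texts": [
        {"pos": (10, 40), "text": "arena A", "font_scale": 0.5, "thickness": 1},
    ],
    "group_outlines": [
        {"ids": [0, 1, 2], "label": "cluster", "fallback_radius": 20.0},
    ],
}

# Every entry is a plain dict; the drawing code skips malformed items
# rather than raising.
assert all(isinstance(e, dict) for layer in render_layers.values() for e in layer)
```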

prepare_overlay

prepare_overlay(tracks_df: DataFrame, labels: dict, gt_df: Optional[DataFrame] = None, kinds: Iterable[str] = ('pose', 'bbox'), color_by: Optional[str] = None, hide_unlabeled: bool = False) -> dict

Precompute lightweight per-frame overlay structures (pose keypoints, bounding boxes, labels).

Parameters

tracks_df : DataFrame
    Output of load_tracks_and_labels()[0].
labels : dict
    Output of load_tracks_and_labels()[1].
gt_df : DataFrame, optional
    Output of load_ground_truth_labels (used as global per-frame labels).
kinds : Iterable[str]
    Overlay primitives to compute ("pose", "bbox").
color_by : str, optional
    Feature name to color labels by, or "gt" for ground-truth coloring.
hide_unlabeled : bool
    If True, skip ids that lack labels.

Returns

dict with keys:
    frames: sorted list of frame numbers
    per_frame: {frame -> {"ids": {id -> info}, "global_labels": {...}}}
    id_colors: {id -> (B,G,R)}
    color_mode: normalized color_by value ("" if unset)
    color_feature: matched feature name used for coloring, or None

Source code in src/mosaic/behavior/visualization_library/overlay.py
def prepare_overlay(
    tracks_df: pd.DataFrame,
    labels: dict,
    gt_df: Optional[pd.DataFrame] = None,
    kinds: Iterable[str] = ("pose", "bbox"),
    color_by: Optional[str] = None,
    hide_unlabeled: bool = False,
) -> dict:
    """
    Precompute lightweight per-frame overlay structures (pose keypoints, bounding boxes, labels).

    Parameters
    ----------
    tracks_df : DataFrame
        Output of load_tracks_and_labels()[0].
    labels : dict
        Output of load_tracks_and_labels()[1].
    gt_df : DataFrame, optional
        Output of load_ground_truth_labels (used as global per-frame labels).
    kinds : Iterable[str]
        Overlay primitives to compute ("pose", "bbox").
    color_by : str, optional
        Feature name to color labels by, or "gt" for ground-truth coloring.
    hide_unlabeled : bool
        If True, skip ids that lack labels.

    Returns
    -------
    dict with keys:
        frames: sorted list of frame numbers
        per_frame: {frame -> {"ids": {id -> info}, "global_labels": {...}}}
        id_colors: {id -> (B,G,R)}
        color_mode: normalized color_by value ("" if unset)
        color_feature: matched feature name used for coloring, or None
    """
    if tracks_df.empty:
        raise ValueError("tracks_df is empty; cannot build overlay.")
    kinds = tuple(kinds)
    pose_pairs = pose_column_pairs(tracks_df.columns)

    # Precompute label sources for quick lookup
    per_id_labels = labels.get("per_id", {})
    per_pair_labels = labels.get("per_pair", {})

    gt_global_map, gt_pair_map = _build_gt_maps(
        gt_df if gt_df is not None else pd.DataFrame()
    )

    per_frame: dict[int, dict[str, Any]] = {}
    id_colors: dict[Any, Tuple[int, int, int]] = {}
    label_colors: dict[str, Tuple[int, int, int]] = {}
    color_mode = (color_by or "").strip().lower()
    color_feature = None
    if color_mode and color_mode != "gt":
        feature_names = list(
            dict.fromkeys([*per_id_labels.keys(), *per_pair_labels.keys()])
        )
        for feat in feature_names:
            if feat.lower() == color_mode:
                color_feature = feat
                break

    centroid_cols = [("X#wcentroid", "Y#wcentroid"), ("X", "Y")]

    grouped = tracks_df.groupby("frame", sort=True)
    for frame_val, frame_df in grouped:
        frame_int = int(frame_val)
        id_infos: dict[Any, dict[str, Any]] = {}
        global_labels = gt_global_map.get(frame_int, {})
        frame_pair_labels: dict[str, dict[tuple[int, int], Any]] = {}

        # Pair labels from model/feature outputs
        for feat_name, per_pair_map in per_pair_labels.items():
            frame_pairs: dict[tuple[int, int], Any] = {}
            for pair, series in per_pair_map.items():
                if not isinstance(pair, (tuple, list)) or len(pair) != 2:
                    continue
                val = _scalar_from_series(series.get(frame_int))
                if val is None or (isinstance(val, float) and np.isnan(val)):
                    continue
                try:
                    a = int(pair[0])
                    b = int(pair[1])
                except Exception:
                    continue
                frame_pairs[tuple(sorted((a, b)))] = val
            if frame_pairs:
                frame_pair_labels[feat_name] = frame_pairs

        # Pair labels from GT rows
        gt_pairs_here = gt_pair_map.get(frame_int, {})
        if gt_pairs_here:
            gt_pairs_out: dict[tuple[int, int], Any] = {}
            for pair, ent in gt_pairs_here.items():
                if not isinstance(pair, (tuple, list)) or len(pair) != 2:
                    continue
                val = ent.get("label_name") or ent.get("label_id")
                if val is None or (isinstance(val, float) and np.isnan(val)):
                    continue
                try:
                    a = int(pair[0])
                    b = int(pair[1])
                except Exception:
                    continue
                gt_pairs_out[tuple(sorted((a, b)))] = val
            if gt_pairs_out:
                frame_pair_labels["gt"] = gt_pairs_out

        frame_color = None
        if color_mode == "gt" and global_labels:
            label_val = global_labels.get("label_name") or global_labels.get("label_id")
            if label_val is not None:
                color_key = f"gt:{label_val}"
                if color_key not in label_colors:
                    label_colors[color_key] = _color_for_label(label_val)
                frame_color = label_colors[color_key]
        for _, row in frame_df.iterrows():
            id_val = row.get("id")
            if pd.isna(id_val):
                continue
            info: dict[str, Any] = {}
            centroid = _extract_centroid(row, centroid_cols)
            if "pose" in kinds and pose_pairs:
                pose_pts = _extract_pose_points(row, pose_pairs)
                if pose_pts:
                    info["pose"] = pose_pts
            if "bbox" in kinds and info.get("pose"):
                info["bbox"] = _compute_bbox(info["pose"])
            if centroid:
                info["centroid"] = centroid

            labels_for_id: dict[str, Any] = {}
            if color_mode == "gt":
                gt_entry = _collect_gt_for_id(gt_pair_map.get(frame_int, {}), id_val)
                if gt_entry is None and global_labels:
                    gt_entry = global_labels
                if gt_entry is not None:
                    gt_val = gt_entry.get("label_name") or gt_entry.get("label_id")
                    if gt_val is not None and not (
                        isinstance(gt_val, float) and np.isnan(gt_val)
                    ):
                        labels_for_id["gt"] = gt_val
            feature_names = list(
                dict.fromkeys([*per_id_labels.keys(), *per_pair_labels.keys()])
            )
            for feat_name in feature_names:
                per_id_map = per_id_labels.get(feat_name, {})
                series = (
                    _lookup_label_series(per_id_map, id_val) if per_id_map else None
                )
                if series is not None:
                    val = _scalar_from_series(series.get(frame_int))
                    if val is not None and not (
                        isinstance(val, float) and np.isnan(val)
                    ):
                        labels_for_id[feat_name] = val
                        continue
                per_pair_map = per_pair_labels.get(feat_name, {})
                if per_pair_map:
                    pair_val = _collect_pair_label_for_id(
                        per_pair_map, id_val, frame_int
                    )
                    if pair_val is not None:
                        labels_for_id[feat_name] = pair_val
            if labels_for_id:
                info["labels"] = labels_for_id

            if hide_unlabeled and not labels_for_id:
                continue
            if not info:
                continue
            color = None
            if color_feature:
                label_val = labels_for_id.get(color_feature)
                if label_val is not None:
                    color_key = f"{color_feature}:{label_val}"
                    if color_key not in label_colors:
                        label_colors[color_key] = _color_for_label(label_val)
                    color = label_colors[color_key]
            elif color_mode == "gt":
                gt_val = labels_for_id.get("gt")
                if gt_val is not None:
                    color_key = f"gt:{gt_val}"
                    if color_key not in label_colors:
                        label_colors[color_key] = _color_for_label(gt_val)
                    color = label_colors[color_key]
                elif frame_color is not None:
                    color = frame_color
            if color is None:
                color = id_colors.get(id_val)
                if color is None:
                    color = _color_for_id(id_val)
                    id_colors[id_val] = color
            info["color"] = color
            id_infos[id_val] = info

        if not id_infos and not global_labels and not frame_pair_labels:
            continue
        per_frame[frame_int] = {
            "ids": id_infos,
            "global_labels": global_labels,
            "frame_color": frame_color,
            "pair_labels": frame_pair_labels,
        }

    frames = sorted(per_frame.keys())
    return {
        "frames": frames,
        "per_frame": per_frame,
        "id_colors": id_colors,
        "color_mode": color_mode,
        "color_feature": color_feature,
    }
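The returned structure can be consumed directly, for example to count labeled ids per frame. Below is a hand-built overlay in the documented return shape; frame numbers, ids, and label values are invented for illustration:

```python
# Hand-built overlay mirroring the documented return shape of prepare_overlay;
# all frame numbers, ids, labels, and colors here are invented.
overlay = {
    "frames": [0, 1],
    "per_frame": {
        0: {"ids": {3: {"centroid": (12.0, 8.0), "labels": {"gt": "attack"},
                        "color": (0, 0, 255)}},
            "global_labels": {}, "frame_color": None, "pair_labels": {}},
        1: {"ids": {}, "global_labels": {"label_name": "rest"},
            "frame_color": None, "pair_labels": {}},
    },
    "id_colors": {3: (0, 0, 255)},
    "color_mode": "gt",
    "color_feature": None,
}

# Count ids that carry at least one label on each frame.
labeled_counts = {
    f: sum(1 for info in overlay["per_frame"][f]["ids"].values() if info.get("labels"))
    for f in overlay["frames"]
}
assert labeled_counts == {0: 1, 1: 0}
```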

build_overlay

build_overlay(ds, group: str, sequence: str, feature_runs: Dict[str, Optional[str]], label_kind: Optional[str] = 'behavior', color_by: Optional[str] = None, label_maps: Optional[Dict[str, dict]] = None, hide_unlabeled: bool = False, visualization_spec: Optional[dict] = None) -> Tuple[dict, Any, Dict[str, Any]]

Build a base overlay (and optional spec layers), returning overlay/tracks/labels.

Source code in src/mosaic/behavior/visualization_library/playback.py
def build_overlay(
    ds,
    group: str,
    sequence: str,
    feature_runs: Dict[str, Optional[str]],
    label_kind: Optional[str] = "behavior",
    color_by: Optional[str] = None,
    label_maps: Optional[Dict[str, dict]] = None,
    hide_unlabeled: bool = False,
    visualization_spec: Optional[dict] = None,
) -> Tuple[dict, Any, Dict[str, Any]]:
    """Build a base overlay (and optional spec layers), returning overlay/tracks/labels."""
    tracks_df, labels = load_tracks_and_labels(ds, group, sequence, feature_runs)

    if label_maps:
        for feat, mapping in label_maps.items():
            per_id = labels.get("per_id", {}).get(feat, {})
            for key, series in list(per_id.items()):
                per_id[key] = series.map(mapping).fillna(series)

    gt_df = None
    if label_kind:
        try:
            gt_df = load_ground_truth_labels(ds, label_kind, group, sequence)
        except FileNotFoundError as exc:
            print(f"[build_overlay] warning: {exc}")

    overlay = prepare_overlay(tracks_df, labels, gt_df=gt_df, color_by=color_by, hide_unlabeled=hide_unlabeled)
    if visualization_spec:
        apply_visualization_spec(overlay, tracks_df, labels, visualization_spec)
    return overlay, tracks_df, labels
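The `label_maps` remapping above uses the pandas idiom `series.map(mapping).fillna(series)`: known values are replaced by names while unmapped values pass through unchanged. A self-contained sketch with invented labels:

```python
import pandas as pd

# The remapping idiom used by build_overlay for label_maps: map known numeric
# labels to names, keep unmapped values as-is via fillna. Values are invented.
series = pd.Series([0, 1, 2, 0], index=[10, 11, 12, 13])  # frame -> label id
mapping = {0: "attack", 1: "investigation"}               # 2 deliberately unmapped

remapped = series.map(mapping).fillna(series)
assert list(remapped) == ["attack", "investigation", 2, "attack"]
```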

play_video

play_video(ds, group: str, sequence: str, feature_runs: Dict[str, Optional[str]], label_kind: Optional[str] = 'behavior', color_by: Optional[str] = None, label_maps: Optional[Dict[str, dict]] = None, hide_unlabeled: bool = False, overlay_data: Optional[dict] = None, start: int = 0, end: Optional[int] = None, downscale: float = 1.0, draw_options: Optional[Dict[str, Any]] = None, show_individual_bboxes: bool = True, pair_box_feature: Optional[str] = None, pair_box_behaviors: Optional[Iterable[Any]] = None, hide_individual_bboxes_for_pair: bool = False, output_path: Optional[Path | str] = None, show_window: bool = True, window_name: Optional[str] = None, visualization_spec: Optional[dict] = None) -> Optional[Path]

Stream a video with overlays; optionally save to disk.

Parameters

ds : Dataset
    Loaded Dataset instance.
group : str
    Group name.
sequence : str
    Sequence name.
feature_runs : dict[str, str | None]
    Mapping of feature/model storage names -> run_id.
label_kind : str, optional
    Kind of labels to load (default "behavior").
color_by : str, optional
    Feature name to color by, or "gt" for ground-truth.
label_maps : dict[str, dict], optional
    Optional mapping per feature to replace numeric labels with names, e.g.
    {"behavior-xgb-pred__from__...": {0: "attack", 1: "investigation", ...}}.
hide_unlabeled : bool
    If True, skip drawing ids that lack labels (after any filtering/mapping).
overlay_data : dict, optional
    Precomputed overlay from prepare_overlay(). If provided, skips rebuilding
    the overlay (useful when you want to pre-filter labels before playback).
start : int
    Starting frame index.
end : int, optional
    Ending frame index.
downscale : float
    Downscale factor (1.0 = no scaling).
draw_options : dict, optional
    Optional frame-drawing options. Allowed keys: "show_labels",
    "point_radius", "bbox_thickness". You can also store defaults in
    overlay_data["draw_options"].
show_individual_bboxes : bool
    If False, skip drawing per-id bounding boxes while keeping pose
    points/labels.
pair_box_feature : str, optional
    Pair-label feature to inspect when drawing union boxes.
pair_box_behaviors : iterable, optional
    Behavior values that should trigger pair-level boxes.
hide_individual_bboxes_for_pair : bool
    If True, do not draw per-id boxes for ids participating in selected pair
    boxes.
output_path : Path or str, optional
    If provided, saves video to this path.
show_window : bool
    If True, displays video in a window.
window_name : str, optional
    Name for the display window.
visualization_spec : dict, optional
    Optional spec with extra render layers and playback overrides.

Returns

Path or None
    Path to the saved video file if output_path was provided, else None.

Keyboard Controls
  • q or Esc: Quit
  • Space: Pause/resume
  • d: Step one frame (while paused)
  • s: Save current frame as PNG
Source code in src/mosaic/behavior/visualization_library/playback.py
def play_video(ds,
               group: str,
               sequence: str,
               feature_runs: Dict[str, Optional[str]],
               label_kind: Optional[str] = "behavior",
               color_by: Optional[str] = None,
               label_maps: Optional[Dict[str, dict]] = None,
               hide_unlabeled: bool = False,
               overlay_data: Optional[dict] = None,
               start: int = 0,
               end: Optional[int] = None,
               downscale: float = 1.0,
               draw_options: Optional[Dict[str, Any]] = None,
               show_individual_bboxes: bool = True,
               pair_box_feature: Optional[str] = None,
               pair_box_behaviors: Optional[Iterable[Any]] = None,
               hide_individual_bboxes_for_pair: bool = False,
               output_path: Optional[Path | str] = None,
               show_window: bool = True,
               window_name: Optional[str] = None,
               visualization_spec: Optional[dict] = None) -> Optional[Path]:
    """
    Stream a video with overlays; optionally save to disk.

    Parameters
    ----------
    ds : Dataset
        Loaded Dataset instance.
    group : str
        Group name.
    sequence : str
        Sequence name.
    feature_runs : dict[str, str | None]
        Mapping of feature/model storage names -> run_id.
    label_kind : str, optional
        Kind of labels to load (default "behavior").
    color_by : str, optional
        Feature name to color by, or "gt" for ground-truth.
    label_maps : dict[str, dict], optional
        Optional mapping per feature to replace numeric labels with names, e.g.
        {"behavior-xgb-pred__from__...": {0: "attack", 1: "investigation", ...}}.
    hide_unlabeled : bool
        If True, skip drawing ids that lack labels (after any filtering/mapping).
    overlay_data : dict, optional
        Precomputed overlay from prepare_overlay(). If provided, skips rebuilding overlay
        (useful when you want to pre-filter labels before playback).
    start : int
        Starting frame index.
    end : int, optional
        Ending frame index.
    downscale : float
        Downscale factor (1.0 = no scaling).
    draw_options : dict, optional
        Optional frame-drawing options. Allowed keys: "show_labels", "point_radius", "bbox_thickness".
        You can also store defaults in overlay_data["draw_options"].
    show_individual_bboxes : bool
        If False, skip drawing per-id bounding boxes while keeping pose points/labels.
    pair_box_feature : str, optional
        Pair-label feature to inspect when drawing union boxes.
    pair_box_behaviors : iterable, optional
        Behavior values that should trigger pair-level boxes.
    hide_individual_bboxes_for_pair : bool
        If True, do not draw per-id boxes for ids participating in selected pair boxes.
    output_path : Path or str, optional
        If provided, saves video to this path.
    show_window : bool
        If True, displays video in a window.
    window_name : str, optional
        Name for the display window.
    visualization_spec : dict, optional
        Optional spec with extra render layers and playback overrides.

    Returns
    -------
    Path or None
        Path to the saved video file if output_path was provided.

    Keyboard Controls
    -----------------
    - q or Esc: Quit
    - Space: Pause/resume
    - d: Step one frame (while paused)
    - s: Save current frame as PNG
    """
    overlay = overlay_data
    spec_playback = playback_kwargs_from_spec(visualization_spec) if visualization_spec else {}
    if overlay is None:
        overlay, _, _ = build_overlay(
            ds=ds,
            group=group,
            sequence=sequence,
            feature_runs=feature_runs,
            label_kind=label_kind,
            color_by=color_by,
            label_maps=label_maps,
            hide_unlabeled=hide_unlabeled,
            visualization_spec=visualization_spec,
        )
    elif label_maps:
        _remap_overlay_labels(overlay, label_maps)

    # Spec fills the pair-box settings only when the direct args are None;
    # the bbox visibility flags, when present in the spec, override the args.
    if pair_box_feature is None:
        pair_box_feature = spec_playback.get("pair_box_feature")
    if pair_box_behaviors is None:
        pair_box_behaviors = spec_playback.get("pair_box_behaviors")
    show_individual_bboxes = bool(spec_playback.get("show_individual_bboxes", show_individual_bboxes))
    hide_individual_bboxes_for_pair = bool(
        spec_playback.get("hide_individual_bboxes_for_pair", hide_individual_bboxes_for_pair)
    )

    video_paths = ds.resolve_media_paths(group, sequence)

    stream = render_stream(
        video_paths,
        overlay,
        start=start,
        end=end,
        downscale=downscale,
        draw_options=draw_options,
        show_individual_bboxes=show_individual_bboxes,
        pair_box_feature=pair_box_feature,
        pair_box_behaviors=pair_box_behaviors,
        hide_individual_bboxes_for_pair=hide_individual_bboxes_for_pair,
    )
    writer = None
    out_path = None
    if output_path:
        out_path = Path(output_path).expanduser()
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        frame_size = getattr(stream, "frame_size", (0, 0))
        writer = cv2.VideoWriter(str(out_path), fourcc, float(getattr(stream, "fps", 30.0)), frame_size)
        if not writer.isOpened():
            writer = None
            out_path = None
            print("[play_video] warning: failed to open VideoWriter; skipping output file.")

    win = window_name or f"{group}:{sequence}"
    try:
        stream_iter = iter(stream)
        paused = False
        step_once = False
        current = None
        frame_idx = None
        while True:
            if not paused or step_once or current is None:
                try:
                    frame_idx, current = next(stream_iter)
                except StopIteration:
                    break
            if writer:
                writer.write(current)
            if show_window:
                cv2.imshow(win, current)
                delay = 1 if not paused else 50
                key = cv2.waitKey(delay) & 0xFF
            else:
                key = -1

            if key == ord("q") or key == 27:
                break
            elif key == ord(" "):
                paused = not paused
                step_once = False
            elif key == ord("s") and frame_idx is not None:
                snap_path = Path(f"frame_{frame_idx}.png")
                cv2.imwrite(str(snap_path), current)
                print(f"[play_video] saved frame -> {snap_path}")
            elif key == ord("d"):
                paused = True
                step_once = True
            else:
                step_once = False
    finally:
        if hasattr(stream, "close"):
            stream.close()
        if writer:
            writer.release()
        if show_window:
            try:
                cv2.destroyWindow(win)
                cv2.waitKey(1)
            except cv2.error:
                pass
    return out_path

play_video_with_spec

play_video_with_spec(ds, group: str, sequence: str, feature_runs: Dict[str, Optional[str]], visualization_spec: dict, **kwargs: Any) -> Optional[Path]

Convenience wrapper: build overlay from tracks/labels + visualization_spec, then play/save.

Any explicit kwargs are forwarded to play_video and override spec playback defaults.
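
A minimal sketch of a spec and call (the `"bbox"` layer type, the `feature_runs` keys, and the group/sequence names are illustrative assumptions — consult `list_visual_adapters()` for real layer types):

```python
# Spec layout expected by normalize_visualization_spec: a "layers" list
# and a "playback" dict of playback defaults.
visualization_spec = {
    "layers": [
        {"type": "bbox", "enabled": True},  # "bbox" is a hypothetical adapter name
    ],
    "playback": {"downscale": 0.5},
}

# With a dataset in hand, playback would look like (not executed here):
# from mosaic.behavior.visualization_library import playback
# playback.play_video_with_spec(
#     ds, group="hex", sequence="hex_3",
#     feature_runs={"tracks": None},        # hypothetical run mapping
#     visualization_spec=visualization_spec,
#     downscale=0.5,  # explicit kwargs override spec playback defaults
# )
```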

Source code in src/mosaic/behavior/visualization_library/playback.py
def play_video_with_spec(
    ds,
    group: str,
    sequence: str,
    feature_runs: Dict[str, Optional[str]],
    visualization_spec: dict,
    **kwargs: Any,
) -> Optional[Path]:
    """
    Convenience wrapper: build overlay from tracks/labels + visualization_spec, then play/save.

    Any explicit kwargs are forwarded to play_video and override spec playback defaults.
    """
    overlay, _, _ = build_overlay(
        ds=ds,
        group=group,
        sequence=sequence,
        feature_runs=feature_runs,
        label_kind=kwargs.pop("label_kind", "behavior"),
        color_by=kwargs.pop("color_by", None),
        label_maps=kwargs.pop("label_maps", None),
        hide_unlabeled=kwargs.pop("hide_unlabeled", False),
        visualization_spec=visualization_spec,
    )
    return play_video(
        ds=ds,
        group=group,
        sequence=sequence,
        feature_runs=feature_runs,
        overlay_data=overlay,
        visualization_spec=visualization_spec,
        **kwargs,
    )

render_stream

render_stream(video_paths: Union[list[Path], Path, str], overlay_data: dict, start: int = 0, end: Optional[int] = None, downscale: float = 1.0, show_individual_bboxes: bool = True, pair_box_feature: Optional[str] = None, pair_box_behaviors: Optional[Iterable[Any]] = None, hide_individual_bboxes_for_pair: bool = False, draw_options: Optional[Dict[str, Any]] = None) -> _FrameStream

Return an iterable that yields (frame_index, frame_bgr_with_overlay).

Parameters

video_paths : list[Path], Path, or str
    Path(s) to the video file(s). For multi-video sequences, pass an ordered list of Paths. A single Path/str is also accepted.
overlay_data : dict
    Output from prepare_overlay()
start : int
    Starting frame index
end : int, optional
    Ending frame index (inclusive). If None, streams to end of video.
downscale : float
    Downscale factor (1.0 = no scaling, 0.5 = half size)
show_individual_bboxes : bool
    If False, skip drawing per-id bounding boxes while keeping pose points/labels.
pair_box_feature : str, optional
    Pair-label feature to inspect when drawing union boxes.
pair_box_behaviors : iterable, optional
    Behavior values that should trigger pair-level boxes.
hide_individual_bboxes_for_pair : bool
    If True, do not draw per-id boxes for ids participating in selected pair boxes.
draw_options : dict, optional
    Optional frame-drawing options. Allowed keys: "show_labels", "point_radius", "bbox_thickness".

Returns

_FrameStream
    Iterator yielding (frame_index, frame_bgr) tuples
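
As the source below shows, `draw_options` passed to `render_stream` are merged over any `draw_options` stored in `overlay_data`, with unknown keys silently dropped. A standalone sketch of that merge (the allowed-key set is taken from the docstring):

```python
# Keys accepted per the draw_options docstring.
ALLOWED = {"show_labels", "point_radius", "bbox_thickness"}

overlay_opts = {"show_labels": False, "color": "red"}  # "color" is not an allowed key
call_opts = {"point_radius": 4}

# Overlay-level options first, then call-site options override them.
merged = {}
merged.update({k: v for k, v in overlay_opts.items() if k in ALLOWED})
merged.update({k: v for k, v in call_opts.items() if k in ALLOWED})
# merged == {"show_labels": False, "point_radius": 4}; "color" was dropped

# Consuming the stream would then look like (not executed here):
# for frame_idx, frame in render_stream(video_paths, overlay_data,
#                                       draw_options=call_opts):
#     cv2.imshow("preview", frame)
```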

Source code in src/mosaic/behavior/visualization_library/video_stream.py
def render_stream(video_paths: Union[list[Path], Path, str],
                  overlay_data: dict,
                  start: int = 0,
                  end: Optional[int] = None,
                  downscale: float = 1.0,
                  show_individual_bboxes: bool = True,
                  pair_box_feature: Optional[str] = None,
                  pair_box_behaviors: Optional[Iterable[Any]] = None,
                  hide_individual_bboxes_for_pair: bool = False,
                  draw_options: Optional[Dict[str, Any]] = None) -> _FrameStream:
    """
    Return an iterable that yields (frame_index, frame_bgr_with_overlay).

    Parameters
    ----------
    video_paths : list[Path], Path, or str
        Path(s) to the video file(s). For multi-video sequences, pass an
        ordered list of Paths. A single Path/str is also accepted.
    overlay_data : dict
        Output from prepare_overlay()
    start : int
        Starting frame index
    end : int, optional
        Ending frame index (inclusive). If None, streams to end of video.
    downscale : float
        Downscale factor (1.0 = no scaling, 0.5 = half size)
    show_individual_bboxes : bool
        If False, skip drawing per-id bounding boxes while keeping pose points/labels.
    pair_box_feature : str, optional
        Pair-label feature to inspect when drawing union boxes.
    pair_box_behaviors : iterable, optional
        Behavior values that should trigger pair-level boxes.
    hide_individual_bboxes_for_pair : bool
        If True, do not draw per-id boxes for ids participating in selected pair boxes.
    draw_options : dict, optional
        Optional frame-drawing options. Allowed keys: "show_labels", "point_radius", "bbox_thickness".

    Returns
    -------
    _FrameStream
        Iterator yielding (frame_index, frame_bgr) tuples
    """
    from mosaic.media.video_io import MultiVideoReader

    reader = MultiVideoReader(video_paths)
    base_size = (reader.width, reader.height)
    fps = reader.fps
    scaled_size = _scaled_size(base_size, downscale)
    per_frame = overlay_data.get("per_frame", {})
    id_colors = overlay_data.get("id_colors", {})
    color_feature = overlay_data.get("color_feature")
    color_mode = overlay_data.get("color_mode")
    merged_draw_options = {}
    overlay_draw_options = overlay_data.get("draw_options")
    if isinstance(overlay_draw_options, dict):
        merged_draw_options.update({k: v for k, v in overlay_draw_options.items() if k in _ALLOWED_DRAW_OPTIONS})
    if isinstance(draw_options, dict):
        merged_draw_options.update({k: v for k, v in draw_options.items() if k in _ALLOWED_DRAW_OPTIONS})

    return _FrameStream(
        reader, fps, base_size, scaled_size, per_frame, id_colors,
        start, end, color_feature=color_feature, color_mode=color_mode,
        show_individual_bboxes=show_individual_bboxes,
        pair_box_feature=pair_box_feature,
        pair_box_behaviors=pair_box_behaviors,
        hide_individual_bboxes_for_pair=hide_individual_bboxes_for_pair,
        draw_options=merged_draw_options)

apply_visualization_spec

apply_visualization_spec(overlay: dict[str, Any], tracks_df: DataFrame, labels: dict[str, Any], spec: Optional[dict[str, Any]]) -> dict[str, Any]

Apply all enabled layers from a visualization spec into overlay in-place.
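
The dispatch rules in the source below can be illustrated with a stand-in adapter registry (adapter names here are hypothetical; real names come from `list_visual_adapters()`):

```python
# A fake registry: each adapter receives (ctx, layer), like in the source.
adapters = {
    "bbox": lambda ctx, layer: ctx.setdefault("applied", []).append(layer["type"]),
}

layers = [
    {"type": "bbox", "enabled": True},
    {"type": "bbox", "enabled": False},  # skipped: enabled is False
]

ctx = {}
for idx, layer in enumerate(layers):
    if layer.get("enabled", True) is False:
        continue  # disabled layers are never dispatched
    layer_type = str(layer.get("type") or "").strip().lower()
    adapter = adapters.get(layer_type)
    if adapter is None:
        raise KeyError(f"Unknown visualization layer type '{layer_type}'.")
    adapter(ctx, layer)

# ctx["applied"] == ["bbox"]: only the enabled layer ran
```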

Source code in src/mosaic/behavior/visualization_library/visual_spec.py
def apply_visualization_spec(
    overlay: dict[str, Any],
    tracks_df: pd.DataFrame,
    labels: dict[str, Any],
    spec: Optional[dict[str, Any]],
) -> dict[str, Any]:
    """Apply all enabled layers from a visualization spec into overlay in-place."""
    norm = normalize_visualization_spec(spec)
    if not norm["layers"]:
        return overlay

    ctx = {
        "overlay": overlay,
        "tracks_df": tracks_df,
        "labels": labels,
    }

    for idx, layer in enumerate(norm["layers"]):
        if not isinstance(layer, dict):
            raise TypeError(f"Layer #{idx} must be a dict.")
        if layer.get("enabled", True) is False:
            continue

        layer_type = str(layer.get("type") or "").strip().lower()
        if not layer_type:
            raise ValueError(f"Layer #{idx} is missing 'type'.")
        adapter = _VISUAL_ADAPTERS.get(layer_type)
        if adapter is None:
            available = ", ".join(list_visual_adapters()) or "<none>"
            raise KeyError(
                f"Unknown visualization layer type '{layer_type}'. "
                f"Registered adapters: {available}"
            )
        adapter(ctx, layer)

    return overlay

list_visual_adapters

list_visual_adapters() -> list[str]

List registered adapter names.

Source code in src/mosaic/behavior/visualization_library/visual_spec.py
def list_visual_adapters() -> list[str]:
    """List registered adapter names."""
    return sorted(_VISUAL_ADAPTERS.keys())

normalize_visualization_spec

normalize_visualization_spec(spec: Optional[dict[str, Any]]) -> dict[str, Any]

Normalize user-provided visualization spec.

Source code in src/mosaic/behavior/visualization_library/visual_spec.py
def normalize_visualization_spec(spec: Optional[dict[str, Any]]) -> dict[str, Any]:
    """Normalize user-provided visualization spec."""
    if spec is None:
        return {"layers": [], "playback": {}}
    if not isinstance(spec, dict):
        raise TypeError("visualization_spec must be a dict or None.")

    layers = spec.get("layers") or []
    if not isinstance(layers, list):
        raise TypeError("visualization_spec['layers'] must be a list.")
    playback = spec.get("playback") or {}
    if not isinstance(playback, dict):
        raise TypeError("visualization_spec['playback'] must be a dict.")
    return {"layers": layers, "playback": playback}