Skip to content

Interactive Plotting

WebGL-accelerated Plotly scatter plot with hover metadata (pipeline step 4b).

Note

Requires the `plotly` extra: `pip install 'multiscoresplot[interactive]'`

plot_embedding_interactive

plot_embedding_interactive(
    adata_or_coords: object,
    rgb: NDArray,
    *,
    basis: str | None = None,
    components: tuple[int, int] = (0, 1),
    scores: DataFrame | None = None,
    method: str | None = None,
    gene_set_names: list[str] | None = None,
    legend: bool = True,
    legend_loc: str = "lower right",
    legend_size: float = 0.3,
    legend_resolution: int = 128,
    colors: list[tuple[float, float, float]] | None = None,
    hover_columns: list[str] | None = None,
    point_size: float = 2,
    alpha: float = 1.0,
    width: int = 500,
    height: int = 450,
    title: str = "",
    show: bool = True,
) -> object | None

Interactive Plotly scatter plot of embedding coordinates coloured by RGB.

**Parameters:**

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `adata_or_coords` | `object` | An AnnData object (with `basis` in `.obsm`) or a raw `(n_cells, 2)` coordinate array. | *required* |
| `rgb` | `NDArray` | `(n_cells, 3)` RGB array from `blend_to_rgb` or `reduce_to_rgb`. | *required* |
| `basis` | `str \| None` | Embedding key (e.g. `"umap"`, `"pca"`). Required when `adata_or_coords` is AnnData. | `None` |
| `components` | `tuple[int, int]` | Which two components to plot (0-indexed). | `(0, 1)` |
| `scores` | `DataFrame \| None` | DataFrame with `score-*` columns. If `None` and `adata_or_coords` is AnnData, scores are auto-extracted from `adata.obs`. | `None` |
| `method` | `str \| None` | Reduction method (`"pca"`, `"nmf"`, etc.) used to derive RGB. Controls the channel labels in hover info. If `None` or `"direct"`, channels are labeled R/G/B. | `None` |
| `gene_set_names` | `list[str] \| None` | Human-readable labels for gene set scores in hover info. | `None` |
| `legend` | `bool` | Whether to add a colour-space legend overlay. | `True` |
| `legend_loc` | `str` | Position for the legend (`"lower right"`, `"lower left"`, `"upper right"`, `"upper left"`). | `'lower right'` |
| `legend_size` | `float` | Size of the legend as a fraction of the plot (0–1). | `0.3` |
| `legend_resolution` | `int` | Pixel resolution of the legend image. | `128` |
| `colors` | `list[tuple[float, float, float]] \| None` | Base colours for direct-mode legends. | `None` |
| `hover_columns` | `list[str] \| None` | Extra columns from `adata.obs` to include in hover info. | `None` |
| `point_size` | `float` | Scatter marker size. | `2` |
| `alpha` | `float` | Marker opacity. | `1.0` |
| `width` | `int` | Figure width in pixels. | `500` |
| `height` | `int` | Figure height in pixels. | `450` |
| `title` | `str` | Plot title. | `''` |
| `show` | `bool` | If `True`, call `fig.show()` and return `None`. If `False`, return the `plotly.graph_objects.Figure`. | `True` |

**Returns:**

| Type | Description |
| --- | --- |
| `Figure` or `None` | The figure when `show=False`; `None` when `show=True`. |

Source code in src/multiscoresplot/_interactive.py
def plot_embedding_interactive(
    adata_or_coords: object,
    rgb: NDArray,
    *,
    basis: str | None = None,
    components: tuple[int, int] = (0, 1),
    scores: DataFrame | None = None,
    method: str | None = None,
    gene_set_names: list[str] | None = None,
    # legend
    legend: bool = True,
    legend_loc: str = "lower right",
    legend_size: float = 0.30,
    legend_resolution: int = 128,
    colors: list[tuple[float, float, float]] | None = None,
    # hover / scatter
    hover_columns: list[str] | None = None,
    point_size: float = 2,
    alpha: float = 1.0,
    width: int = 500,
    height: int = 450,
    title: str = "",
    show: bool = True,
) -> object | None:
    """Interactive Plotly scatter plot of embedding coordinates coloured by RGB.

    Points are drawn with ``go.Scattergl`` (WebGL).  Each point's hover text
    lists, in order: its gene-set scores (3 decimals), its RGB / component
    channel values (2 decimals), and any requested ``adata.obs`` columns.

    Parameters
    ----------
    adata_or_coords
        An ``AnnData`` object (with *basis* in ``.obsm``) or a raw
        ``(n_cells, 2)`` coordinate array.
    rgb
        ``(n_cells, 3)`` RGB array from ``blend_to_rgb`` or ``reduce_to_rgb``.
    basis
        Embedding key (e.g. ``"umap"``, ``"pca"``).  Required when
        *adata_or_coords* is AnnData.
    components
        Which two components to plot (0-indexed).
    scores
        DataFrame with ``score-*`` columns.  If *None* and *adata_or_coords*
        is AnnData, scores are auto-extracted from ``adata.obs``.  Rows are
        matched to cells by position.
    method
        Reduction method (``"pca"``, ``"nmf"``, etc.) used to derive RGB.
        Controls the channel labels in hover info.  If *None* or ``"direct"``,
        channels are labeled R/G/B.
    gene_set_names
        Human-readable labels for gene set scores in hover info.  Used only
        when their count equals the number of ``score-*`` columns; otherwise
        the column-name suffixes are shown.
    legend
        Whether to add a colour-space legend overlay.
    legend_loc
        Position for the legend (``"lower right"``, ``"lower left"``,
        ``"upper right"``, ``"upper left"``).
    legend_size
        Size of the legend as a fraction of the plot (0-1).
    legend_resolution
        Pixel resolution of the legend image.
    colors
        Base colours for direct-mode legends.
    hover_columns
        Extra columns from ``adata.obs`` to include in hover info.  Float
        columns are shown with 3 decimals; all other columns (including
        integer and boolean ones) are shown verbatim.
    point_size
        Scatter marker size.
    alpha
        Marker opacity (written into each marker's ``rgba`` colour).
    width
        Figure width in pixels.
    height
        Figure height in pixels.
    title
        Plot title.
    show
        If *True*, call ``fig.show()`` and return *None*.  If *False*,
        return the ``plotly.graph_objects.Figure``.

    Returns
    -------
    Figure or None
        The figure when ``show=False``; *None* when ``show=True``.

    Raises
    ------
    ValueError
        If the score table's row count differs from the number of cells, if
        *hover_columns* is given together with raw coordinates, or if
        ``adata.obs`` has a different number of rows than the embedding.
    KeyError
        If any of *hover_columns* is missing from ``adata.obs``.
    """
    go = _ensure_plotly()

    coords, basis_label = _extract_coords(adata_or_coords, basis, components)
    n_cells = coords.shape[0]
    rgb = _validate_rgb(rgb, n_cells)

    # AnnData-like inputs are detected by duck typing on ``.obs``.
    has_obs = hasattr(adata_or_coords, "obs")

    # --- Build hover text ---
    # Hover info is assembled column-wise: every entry of ``hover_fields`` is a
    # list with one formatted line per cell.  The lines are joined per cell at
    # the end, which avoids per-cell Python append loops.
    hover_fields: list[list[str]] = []

    # 1. Gene set scores (an explicit *scores* frame takes precedence)
    score_df: DataFrame | None = scores
    if score_df is None and has_obs:
        obs = adata_or_coords.obs  # type: ignore[attr-defined]
        obs_score_cols = [c for c in obs.columns if c.startswith(SCORE_PREFIX)]
        if obs_score_cols:
            score_df = obs[obs_score_cols]

    if score_df is not None:
        score_cols = [c for c in score_df.columns if c.startswith(SCORE_PREFIX)]
        if score_cols:
            # Custom names are used only when they map one-to-one onto the
            # score columns; otherwise fall back to the column-name suffix.
            labels = (
                gene_set_names
                if gene_set_names is not None and len(gene_set_names) == len(score_cols)
                else [c[len(SCORE_PREFIX) :] for c in score_cols]
            )
            score_vals = score_df[score_cols].to_numpy(dtype=np.float64)
            # Rows are matched to cells by position, so a length mismatch would
            # either crash mid-loop or silently attach scores to the wrong cells.
            if score_vals.shape[0] != n_cells:
                raise ValueError(
                    f"scores has {score_vals.shape[0]} rows but the embedding has "
                    f"{n_cells} cells."
                )
            for j, label in enumerate(labels):
                hover_fields.append([f"{label}: {v:.3f}" for v in score_vals[:, j]])

    # 2. RGB channel values, labelled by the reduction's components when known
    if method is not None and method != "direct":
        channel_labels = get_component_labels(method)
    else:
        channel_labels = ["R", "G", "B"]

    for j, ch_label in enumerate(channel_labels):
        hover_fields.append([f"{ch_label}: {v:.2f}" for v in rgb[:, j]])

    # 3. Extra .obs columns
    if hover_columns is not None:
        if not has_obs:
            raise ValueError("hover_columns requires an AnnData object, not raw coordinates.")
        obs = adata_or_coords.obs  # type: ignore[attr-defined]
        missing = [c for c in hover_columns if c not in obs.columns]
        if missing:
            raise KeyError(f"Columns not found in adata.obs: {missing}")
        if len(obs) != n_cells:
            raise ValueError(
                f"adata.obs has {len(obs)} rows but the embedding has {n_cells} cells."
            )

        import pandas as _pd

        for col_name in hover_columns:
            col = obs[col_name]
            # Only true floats get fixed decimals; integers, booleans and
            # categories are shown verbatim (e.g. "n_genes: 1234", not "1234.000").
            if _pd.api.types.is_float_dtype(col):
                hover_fields.append([f"{col_name}: {val:.3f}" for val in col])
            else:
                hover_fields.append([f"{col_name}: {val}" for val in col])

    if hover_fields:
        hover_text = ["<br>".join(cell_lines) for cell_lines in zip(*hover_fields)]
    else:
        hover_text = [""] * n_cells

    # --- Build color strings ---
    def _to_byte(value: float) -> int:
        # Round (rather than truncate) like matplotlib's ``to_hex`` so static and
        # interactive plots agree; clip to guard against float drift past [0, 1].
        # ``round`` raises on NaN/inf, as ``int`` did before.
        return min(255, max(0, round(float(value) * 255)))

    marker_colors = [
        f"rgba({_to_byte(r)},{_to_byte(g)},{_to_byte(b)},{alpha})" for r, g, b in rgb
    ]

    # --- Axis labels ---
    if basis_label is not None:
        xaxis_title = f"{basis_label}{components[0] + 1}"
        yaxis_title = f"{basis_label}{components[1] + 1}"
    else:
        xaxis_title = ""
        yaxis_title = ""

    # --- Create figure ---
    fig = go.Figure(
        data=go.Scattergl(
            x=coords[:, 0],
            y=coords[:, 1],
            mode="markers",
            marker=dict(
                size=point_size,
                color=marker_colors,
            ),
            hovertext=hover_text,
            hoverinfo="text",
        ),
    )

    _axis_style = dict(
        showticklabels=False,
        ticks="",
        showline=True,
        linecolor="black",
        linewidth=1,
        mirror=True,
    )
    fig.update_layout(
        width=width,
        height=height,
        title=title,
        # scaleanchor keeps a 1:1 aspect ratio so the embedding is not distorted.
        xaxis=dict(title=xaxis_title, scaleanchor="y", **_axis_style),
        yaxis=dict(title=yaxis_title, **_axis_style),
        plot_bgcolor="white",
    )

    # Infer direct mode for legend when method is unset but gene_set_names has 2-3 entries
    legend_method = method
    if legend_method is None and gene_set_names is not None and len(gene_set_names) in (2, 3):
        legend_method = "direct"

    # Legend (skip when we can't determine the method)
    if (
        legend
        and legend_method is not None
        and (legend_method != "direct" or gene_set_names is not None)
    ):
        _add_plotly_legend(
            fig,
            method=legend_method,
            gene_set_names=gene_set_names,
            colors=colors,
            legend_loc=legend_loc,
            legend_size=legend_size,
            legend_resolution=legend_resolution,
        )

    if show:
        fig.show()
        return None

    return fig  # type: ignore[no-any-return]