# Source code for circust.visualization.pipeline_summary

"""
circust/visualization/pipeline_summary.py
==========================================
Figura resumen de extremo a extremo del pipeline CIRCUST (Etapas 1-2).

  plot_pipeline_summary(cpca_result, order_result, title, figsize)
      Figura única multipanel con cuatro vistas diagnósticas clave:

        A. Dispersión PC1 vs PC2 (etapa CPCA)   — muestra estructura circular
        B. Heatmap de residuos (etapa outlier)   — muestra calidad de muestras
        C. Diagrama circular de picos (ordenación) — muestra programa temporal
        D. Gráfico de barras R² (ordenación)     — muestra calidad del ajuste

      Diseñada para informes de tesis y presentaciones.

  plot_variance_explained(cpca)
      Gráfico de barras tipo scree de la varianza explicada por las
      primeras componentes principales, con línea acumulada superpuesta.

  plot_expression_overview(expr_ordered, core_genes, circular_scale, n_top)
      Heatmap de expresión normalizada para los n genes más rítmicos,
      ordenados por fase circular. Los genes core se resaltan.
      Proporciona una visión global de la señal circadiana a nivel genómico.

Todas las funciones devuelven un matplotlib Figure. Ninguna llama a plt.show().
"""
from typing import Optional

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.gridspec as gridspec
from matplotlib.figure import Figure
from matplotlib.lines import Line2D
import pandas as pd

from circust.cpca import CPCAResult
from circust.synchronizer import SynchronizationResult


# ---------------------------------------------------------------------------
# Shared constants
# ---------------------------------------------------------------------------
# Hex colours shared by the plotting helpers in this module.
# NOTE(review): _FMM_COLOUR / _COS_COLOUR are not referenced by the functions
# visible in this file section — presumably used for model-fit curves; confirm.
_FMM_COLOUR, _COS_COLOUR = "#E41A1C", "#377EB8"
# Colours for genes listed in ``day_genes`` and for all remaining genes.
_DAY_COLOUR, _NIGHT_COLOUR = "#FDB863", "#5E4FA2"
# Dedicated colours for the two genes that are singled out by name.
_ARNTL_COLOUR, _DBP_COLOUR = "#E41A1C", "#377EB8"

# Eight angular tick labels for the polar peak plot. They are applied, in
# order, to eight evenly spaced angles starting at 0, so "CT12" sits at angle 0.
_HOUR_LABELS_8 = [f"CT{hour}" for hour in (12, 15, 18, 21, 0, 3, 6, 9)]


# ═══════════════════════════════════════════════════════════════════════════
# Plot 1 — Pipeline summary (4-panel composite)
# ═══════════════════════════════════════════════════════════════════════════

def plot_pipeline_summary(
    cpca_result: CPCAResult,
    order_result: SynchronizationResult,
    title: str = "",
    figsize: tuple[float, float] = (14, 10),
) -> Figure:
    """
    Build the four-panel composite summary of the CIRCUST pipeline.

    Layout::

        ┌───────────────┬──────────────┐
        │ A. PC scatter │ B. Residuals │
        ├───────────────┼──────────────┤
        │ C. Peaks      │ D. R² bars   │
        └───────────────┴──────────────┘

    Parameters
    ----------
    cpca_result : CPCAResult
        Output of ``CPCA.run()`` (CPCA plus outlier detection). Feeds panels
        A and B; ``samples_dropped`` also supplies the outlier count shown in
        the title of panel B.
    order_result : SynchronizationResult
        Output of CircularSynchronizer. Feeds panels C and D.
    title : str
        Optional prefix for the figure-level title; when non-empty it is
        followed by an em dash.
    figsize : tuple
        Figure size in inches (width, height).

    Returns
    -------
    matplotlib.figure.Figure
        The assembled figure. ``plt.show()`` is not called.
    """
    fig = plt.figure(figsize=figsize)
    grid = gridspec.GridSpec(2, 2, figure=fig, hspace=0.35, wspace=0.30)

    # Every panel title shares the same font styling; only the padding varies.
    panel_font = {"fontsize": 10, "fontweight": "bold"}

    # Panel A (top-left): PC1 vs PC2 scatter.
    scatter_ax = fig.add_subplot(grid[0, 0])
    _draw_pc_scatter(scatter_ax, cpca_result)
    scatter_ax.set_title("A. CPCA — PC1 vs PC2", pad=6, **panel_font)

    # Panel B (top-right): compact residual strip plot.
    residual_ax = fig.add_subplot(grid[0, 1])
    _draw_residual_compact(residual_ax, cpca_result)
    n_out = len(cpca_result.samples_dropped)
    residual_ax.set_title(
        f"B. Standardised FMM residuals ({n_out} outliers)",
        pad=6,
        **panel_font,
    )

    # Panel C (bottom-left): polar peak plot, so the axes must be polar.
    peaks_ax = fig.add_subplot(grid[1, 0], polar=True)
    _draw_polar_peaks(peaks_ax, order_result)
    peaks_ax.set_title("C. Core gene peak times", pad=12, **panel_font)

    # Panel D (bottom-right): R² bar chart.
    bars_ax = fig.add_subplot(grid[1, 1])
    _draw_r2_bars(bars_ax, order_result)
    bars_ax.set_title("D. FMM R\u00b2 per core gene", pad=6, **panel_font)

    # Figure-level title, optionally prefixed by the caller's title.
    suptitle = ""
    if title:
        suptitle = f"{title} \u2014 "
    fig.suptitle(
        f"{suptitle}CIRCUST Pipeline Summary (Stages 1\u20132)",
        fontsize=13,
        fontweight="bold",
        y=1.01,
    )
    return fig
# ---------------------------------------------------------------------------
# Panel helper functions (each draws into an existing Axes)
# ---------------------------------------------------------------------------


def _gene_colour(gene, day_genes) -> str:
    """Return the colour used to draw *gene* in the summary panels.

    ARNTL and DBP get their own dedicated colours; any other gene is coloured
    by whether it appears in *day_genes*.
    """
    if gene == "ARNTL":
        return _ARNTL_COLOUR
    if gene == "DBP":
        return _DBP_COLOUR
    return _DAY_COLOUR if gene in day_genes else _NIGHT_COLOUR


def _draw_pc_scatter(ax, cpca: "CPCAResult") -> None:
    """Draw the compact PC1-vs-PC2 scatter for the summary figure.

    Kept samples are drawn as grey dots and dropped samples as red crosses
    (with a legend entry giving their count). Two dashed reference circles of
    radius 0.10 and 0.15 are added, and the axes are made square and symmetric
    around the origin.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw into.
    cpca : CPCAResult
        Reads ``pc1_initial``/``pc2_initial`` (falling back to ``pc1``/``pc2``
        when empty), ``variance_explained`` and ``samples_dropped``.
    """
    # Prefer the initial projections so outliers appear at their original
    # position.
    pc1 = cpca.pc1_initial if len(cpca.pc1_initial) > 0 else cpca.pc1
    pc2 = cpca.pc2_initial if len(cpca.pc2_initial) > 0 else cpca.pc2
    var = cpca.variance_explained
    dropped = set(cpca.samples_dropped)

    # NOTE(review): membership is tested against positions 0..n-1, so
    # ``samples_dropped`` is assumed to hold integer positions — confirm.
    # The explicit bool dtype keeps ``~`` valid even when there are no samples.
    normal_mask = np.array(
        [i not in dropped for i in range(len(pc1))], dtype=bool
    )
    dropped_mask = ~normal_mask

    ax.scatter(
        pc1[normal_mask], pc2[normal_mask],
        s=10, color="#CCCCCC", edgecolors="#999999",
        linewidths=0.3, zorder=3,
    )
    if dropped_mask.any():
        ax.scatter(
            pc1[dropped_mask], pc2[dropped_mask],
            s=40, marker="x", color="#E41A1C", linewidths=1.4, zorder=5,
            label=f"{dropped_mask.sum()} outlier(s)",
        )
        ax.legend(fontsize=6, loc="upper right", framealpha=0.7)

    # Radial threshold circles (0.10 and 0.15).
    theta = np.linspace(0, 2 * np.pi, 200)
    for r, colour, lw in [(0.10, "#E41A1C", 1.2), (0.15, "#AAAAAA", 0.6)]:
        ax.plot(np.cos(theta) * r, np.sin(theta) * r, "--",
                color=colour, linewidth=lw, zorder=1)

    ax.axhline(0, color="#EEEEEE", linewidth=0.4, zorder=0)
    ax.axvline(0, color="#EEEEEE", linewidth=0.4, zorder=0)
    ax.set_aspect("equal")

    # Symmetric limits: largest absolute coordinate rounded to one decimal,
    # plus a small margin.
    maxi = round(max(np.abs(pc1).max(), np.abs(pc2).max()), 1) + 0.05
    ax.set_xlim(-maxi, maxi)
    ax.set_ylim(-maxi, maxi)
    ax.set_xlabel(f"PC1 ({var[0]*100:.1f}%)", fontsize=8)
    ax.set_ylabel(f"PC2 ({var[1]*100:.1f}%)", fontsize=8)
    ax.tick_params(labelsize=6)


def _draw_residual_compact(ax, result) -> None:
    """Draw the compact residual strip plot for the summary figure.

    One horizontal strip per row of ``result.std_residuals_fmm`` (first row at
    the top). Values with absolute value above 3 are drawn larger and in red;
    the rest are small grey dots. Dashed vertical guides mark ±3 and ±4.
    When ``std_residuals_fmm`` is ``None`` a placeholder message is shown
    instead.
    """
    std_res = result.std_residuals_fmm
    if std_res is None:
        ax.text(0.5, 0.5, "No residual data", ha="center", va="center",
                transform=ax.transAxes)
        return

    gene_names = list(std_res.index)
    n_genes = len(gene_names)

    for i, gene in enumerate(gene_names):
        row = std_res.loc[gene].values
        # Invert the row index so the first gene ends up at the top.
        y = n_genes - 1 - i
        magnitude = np.abs(row)
        normal = magnitude <= 3
        flagged = magnitude > 3
        ax.scatter(row[normal], np.full(normal.sum(), y),
                   s=4, color="#BBBBBB", alpha=0.4, zorder=2, linewidths=0)
        if flagged.any():
            ax.scatter(row[flagged], np.full(flagged.sum(), y),
                       s=12, color="#E41A1C", alpha=0.7, zorder=3,
                       linewidths=0)

    for thresh in [3, 4]:
        col = "#984EA3" if thresh == 3 else "#E41A1C"
        ax.axvline(thresh, color=col, linestyle="--", linewidth=0.6, alpha=0.6)
        ax.axvline(-thresh, color=col, linestyle="--", linewidth=0.6, alpha=0.6)

    ax.set_yticks(range(n_genes))
    # Labels are reversed to match the inverted y positions used above.
    ax.set_yticklabels(list(reversed(gene_names)), fontsize=6)
    ax.set_ylim(-0.5, n_genes - 0.5)
    ax.axvline(0, color="#EEEEEE", linewidth=0.4, zorder=0)
    ax.set_xlabel("Std. FMM residual", fontsize=8)
    ax.tick_params(labelsize=6)
    ax.spines[["top", "right"]].set_visible(False)


def _draw_polar_peaks(ax, result) -> None:
    """Draw the compact polar peak plot for the summary figure.

    *ax* must be a polar Axes. Angle zero is placed at the top and angles
    increase clockwise. Each entry of ``result.core_genes`` is drawn as a
    marker at radius 0.85 with a thin spoke from the centre and a text label
    at radius 1.05, using the matching entry of ``result.peak_times`` as its
    angle.
    """
    # NOTE(review): peak_times are used directly as theta, so they are assumed
    # to be angles in radians — confirm.
    ax.set_theta_zero_location("N")
    ax.set_theta_direction(-1)

    genes = result.core_genes
    peaks = result.peak_times

    for i, gene in enumerate(genes):
        theta = peaks[i]
        colour = _gene_colour(gene, result.day_genes)

        ax.plot(theta, 0.85, "o", color=colour, markersize=7, zorder=5)
        ax.plot([theta, theta], [0, 0.85], "-", color=colour,
                linewidth=0.5, alpha=0.4, zorder=2)
        ax.text(theta, 1.05, gene, fontsize=5.5, fontweight="bold",
                color=colour, ha="center", va="center", zorder=6)

    # Day/night shading: first half-turn in the day colour, second half-turn
    # in the night colour.
    day_t = np.linspace(0, np.pi, 100)
    night_t = np.linspace(np.pi, 2 * np.pi, 100)
    ax.fill_between(day_t, 0, 0.6, alpha=0.06, color=_DAY_COLOUR)
    ax.fill_between(night_t, 0, 0.6, alpha=0.06, color=_NIGHT_COLOUR)

    ax.set_ylim(0, 1.15)
    ax.set_yticks([])
    angles_8 = np.linspace(0, 2 * np.pi, 8, endpoint=False)
    ax.set_xticks(angles_8)
    ax.set_xticklabels(_HOUR_LABELS_8, fontsize=5.5, color="#666666")


def _draw_r2_bars(ax, result) -> None:
    """Draw the compact horizontal R² bar chart for the summary figure.

    Bars are sorted by descending ``result.r2_fmm`` (largest at the top), are
    coloured per gene, and carry their value as a text label. A dashed guide
    marks R² = 0.5. When there are no R² values a placeholder message is
    shown instead.
    """
    genes = result.core_genes
    # Coerce to a float ndarray so positional indexing and ``.max()`` behave
    # the same for lists, arrays and label-indexed Series.
    r2 = np.asarray(result.r2_fmm, dtype=float)
    if r2.size == 0:
        # Nothing to rank; ``r2.max()`` below would raise on empty input.
        ax.text(0.5, 0.5, "No R\u00b2 data", ha="center", va="center",
                transform=ax.transAxes)
        return

    order = np.argsort(r2)[::-1]

    for rank, idx in enumerate(order):
        colour = _gene_colour(genes[idx], result.day_genes)
        ax.barh(rank, r2[idx], color=colour, edgecolor="white",
                linewidth=0.3, height=0.65, zorder=3)
        ax.text(r2[idx] + 0.008, rank, f"{r2[idx]:.3f}",
                va="center", fontsize=6, zorder=4)

    ax.set_yticks(range(len(genes)))
    ax.set_yticklabels([genes[i] for i in order], fontsize=7)
    # Rank 0 (highest R²) goes at the top.
    ax.invert_yaxis()
    ax.axvline(0.5, color="#999999", linestyle="--", linewidth=0.6, zorder=2)
    ax.set_xlim(0, min(1.0, r2.max() + 0.08))
    ax.set_xlabel("FMM R\u00b2", fontsize=8)
    ax.tick_params(labelsize=6)
    ax.spines[["top", "right"]].set_visible(False)


# ═══════════════════════════════════════════════════════════════════════════
# Plot 2 — Variance explained (scree plot)
# ═══════════════════════════════════════════════════════════════════════════

def plot_variance_explained(
    cpca_result: "CPCAResult",
    n_components: int = 10,
    title: str = "",
    figsize: tuple[float, float] = (6, 4),
) -> Figure:
    """
    Draw a scree-style bar chart of the variance explained per component.

    Shows the first ``n_components`` bars with the cumulative curve overlaid
    on a secondary axis. PC1 and PC2 are highlighted in red and annotated
    with their percentage. Dotted guides mark 10 % on the bar axis (labelled
    "PC2 min") and 40 % on the cumulative axis (labelled "PC1+PC2 min").

    Parameters
    ----------
    cpca_result : CPCAResult
        Its ``variance_explained`` is read as a sequence of fractions (values
        are multiplied by 100 for display). If it has fewer entries than
        ``n_components``, only the available ones are drawn.
    n_components : int
        Maximum number of components to show. Non-positive values yield an
        empty chart.
    title : str
        Optional prefix for the axes title; when non-empty it is followed by
        an em dash.
    figsize : tuple
        Figure size in inches.

    Returns
    -------
    matplotlib.figure.Figure
        The figure. ``plt.show()`` is not called.
    """
    # Coerce to a float ndarray: with a plain list, ``* 100`` would repeat the
    # list instead of scaling its values.
    var = np.asarray(cpca_result.variance_explained, dtype=float)
    # Clamp at zero so a negative request cannot slice from the end.
    n_show = max(0, min(n_components, len(var)))
    var_show = var[:n_show]
    cumvar = np.cumsum(var_show)
    positions = range(1, n_show + 1)

    fig, ax1 = plt.subplots(figsize=figsize)

    # Bars: the first two components are highlighted.
    colours = ["#E41A1C" if i < 2 else "#AAAAAA" for i in range(n_show)]
    ax1.bar(positions, var_show * 100, color=colours,
            edgecolor="white", linewidth=0.5, zorder=3)
    ax1.set_xlabel("Principal component", fontsize=9)
    ax1.set_ylabel("Variance explained (%)", fontsize=9, color="#333333")
    ax1.set_xticks(positions)
    ax1.tick_params(labelsize=7)
    ax1.spines[["top", "right"]].set_visible(False)

    # Cumulative curve on a secondary y-axis.
    ax2 = ax1.twinx()
    ax2.plot(positions, cumvar * 100, "o-", color="#377EB8",
             markersize=5, linewidth=1.2, zorder=4)
    ax2.set_ylabel("Cumulative (%)", fontsize=9, color="#377EB8")
    ax2.tick_params(labelsize=7, colors="#377EB8")
    ax2.spines["right"].set_color("#377EB8")
    ax2.set_ylim(0, 105)

    # Threshold guides.
    ax1.axhline(10, color="#E41A1C", linestyle=":", linewidth=0.7, alpha=0.5)
    ax1.text(n_show + 0.3, 10, "PC2 min (10%)", fontsize=6,
             color="#E41A1C", va="center")
    ax2.axhline(40, color="#377EB8", linestyle=":", linewidth=0.7, alpha=0.5)
    ax2.text(0.6, 42, "PC1+PC2 min (40%)", fontsize=6,
             color="#377EB8", va="bottom")

    # Annotate PC1 and PC2 with their percentage just above each bar.
    for i in range(min(2, n_show)):
        ax1.text(i + 1, var_show[i] * 100 + 1,
                 f"{var_show[i]*100:.1f}%", ha="center", fontsize=7,
                 fontweight="bold", color="#E41A1C")

    suptitle = f"{title} \u2014 " if title else ""
    ax1.set_title(
        f"{suptitle}Variance explained by principal components",
        fontsize=10,
        pad=8,
    )
    fig.tight_layout()
    return fig
# ═══════════════════════════════════════════════════════════════════════════ # Gráfico 3 — Heatmap de expresión global # ═══════════════════════════════════════════════════════════════════════════
def plot_expression_overview(
    expr_ordered: pd.DataFrame,
    core_genes: list[str],
    circular_scale: np.ndarray,
    n_top: int = 50,
    title: str = "",
    figsize: Optional[tuple[float, float]] = None,
) -> Figure:
    """
    Draw a heatmap of normalised expression for the highest-variance genes.

    Rows are genes and columns are samples in the order given by
    ``expr_ordered``. The colour scale is fixed to [-1, +1]. Core genes are
    always included, placed first, marked in a narrow side strip, and have
    their tick labels drawn in bold red. Within the core and non-core groups,
    rows are sorted by the column position of their maximum.

    Parameters
    ----------
    expr_ordered : pd.DataFrame
        Expression matrix already ordered by circular phase (e.g.
        ``SynchronizationResult.expr_ordered``), genes on the index.
    core_genes : list[str]
        Core clock gene symbols to highlight. Symbols absent from the index
        are ignored.
    circular_scale : np.ndarray
        Circular time axis; its values at the tick columns become the x tick
        labels, formatted to one decimal.
    n_top : int
        Number of genes to select by row variance. Core genes outside that
        selection are appended, so more rows may be shown. The title always
        reports ``n_top`` as given.
    title : str
        Optional prefix for the figure title; when non-empty it is followed
        by an em dash.
    figsize : tuple, optional
        Figure size in inches. Defaults to a width of 10 and a height that
        grows with the number of rows (minimum 4).

    Returns
    -------
    matplotlib.figure.Figure
        The figure. ``plt.show()`` is not called.
    """
    # Set lookups replace the repeated list scans over core_genes.
    core_set = set(core_genes)

    # Select the top genes by row variance.
    row_var = expr_ordered.var(axis=1)
    top_genes = row_var.nlargest(n_top).index.tolist()

    # Make sure every core gene present in the matrix is included once.
    selected = set(top_genes)
    for cg in core_genes:
        if cg in expr_ordered.index and cg not in selected:
            top_genes.append(cg)
            selected.add(cg)

    # Order: core genes first, then by peak position (argmax).
    def sort_key(gene):
        peak_pos = np.argmax(expr_ordered.loc[gene].values)
        return (0 if gene in core_set else 1, peak_pos)

    top_genes.sort(key=sort_key)

    mat = expr_ordered.loc[top_genes].values
    n_genes_show = len(top_genes)

    if figsize is None:
        figsize = (10, max(4, n_genes_show * 0.15))

    fig, (ax_bar, ax_heat) = plt.subplots(
        1, 2, figsize=figsize,
        gridspec_kw={"width_ratios": [0.03, 1], "wspace": 0.02},
    )

    # Side strip: 1.0 for core genes, 0.0 otherwise, mapped to two colours.
    core_mask = np.array([g in core_set for g in top_genes], dtype=bool)
    side_colours = np.where(core_mask, 1.0, 0.0)
    ax_bar.imshow(side_colours.reshape(-1, 1), aspect="auto",
                  cmap=mcolors.ListedColormap(["#F0F0F0", "#E41A1C"]),
                  interpolation="nearest")
    ax_bar.set_xticks([])
    ax_bar.set_yticks([])
    ax_bar.set_ylabel("Genes", fontsize=9)

    # Main heatmap.
    cmap = plt.cm.RdBu_r
    im = ax_heat.imshow(mat, aspect="auto", cmap=cmap,
                        vmin=-1, vmax=1, interpolation="nearest")

    # X axis: phase values at evenly spaced columns. The tick count is capped
    # at the number of columns so narrow matrices do not get duplicate ticks.
    n_samples = mat.shape[1]
    n_ticks = min(8, n_samples)
    tick_positions = np.linspace(0, n_samples - 1, n_ticks, dtype=int)
    tick_labels = [f"{circular_scale[p]:.1f}" for p in tick_positions]
    ax_heat.set_xticks(tick_positions)
    ax_heat.set_xticklabels(tick_labels, fontsize=6)
    ax_heat.set_xlabel("Circular phase (rad)", fontsize=9)

    # Y axis: label every gene when there are few, otherwise only core genes.
    if n_genes_show <= 30:
        ax_heat.set_yticks(range(n_genes_show))
        ax_heat.set_yticklabels(top_genes, fontsize=5)
        # Emphasise the labels that belong to core genes.
        for i, label in enumerate(ax_heat.get_yticklabels()):
            if top_genes[i] in core_set:
                label.set_color("#E41A1C")
                label.set_fontweight("bold")
    else:
        core_positions = [i for i, g in enumerate(top_genes)
                          if g in core_set]
        ax_heat.set_yticks(core_positions)
        ax_heat.set_yticklabels(
            [top_genes[i] for i in core_positions],
            fontsize=6, color="#E41A1C", fontweight="bold",
        )

    # Colour bar.
    cbar = fig.colorbar(im, ax=ax_heat, shrink=0.6, pad=0.02)
    cbar.set_label("Normalised expression", fontsize=8)
    cbar.ax.tick_params(labelsize=6)

    # Legend explaining the two side-strip colours.
    legend_elements = [
        Line2D([0], [0], marker="s", color="w", markerfacecolor="#E41A1C",
               markersize=8, label="Core clock gene"),
        Line2D([0], [0], marker="s", color="w", markerfacecolor="#F0F0F0",
               markeredgecolor="#CCCCCC", markersize=8, label="Other gene"),
    ]
    ax_heat.legend(handles=legend_elements, loc="upper right",
                   fontsize=6, framealpha=0.8)

    suptitle = f"{title} \u2014 " if title else ""
    fig.suptitle(
        f"{suptitle}Expression heatmap (top {n_top} genes, circular order)",
        fontsize=10,
        y=1.01,
    )
    fig.tight_layout()
    return fig