"""
circust/top_genes.py
====================
Etapa 3: Determinacion de genes TOP ritmicos especificos del tejido.
Implementa la seleccion de genes TOP segun la Seccion 3 de CIRCUST:
1. Ajuste ORI: para cada gen, ajustar regresion isotonica circular
unimodal (PAVA) y calcular R2_ORI = 1 - MSE_unimodal / MSE_plano.
2. Filtro ORI: descartar genes con R2_ORI < 0.5 (ritmicidad no
parametrica insuficiente).
3. Filtro FMM: ajustar el modelo FMM a los supervivientes y
seleccionar aquellos que sean
(i) no puntiagudos (omega > 0.1),
(ii) altamente ritmicos (R2_FMM > 0.5),
(iii) picos (t_U) distribuidos en los 4 cuadrantes de [0, 2pi).
4. Genes semilla forzados: los 12 genes reloj centrales se
incluyen siempre; si alguno no paso los filtros, se fuerza su
inclusion.
Posicion en el pipeline
-----------------------
CircularSynchronizer -> **TopGeneSelector** -> RandomSelector
-> RobustEstimator
"""
from __future__ import annotations
import numpy as np
import pandas as pd
from dataclasses import dataclass, field
from math import pi
from circust.fitting.ori import ORIModel
from circust.fitting.fmm import FMMModel
from circust.fitting.cosinor import CosinorModel
# ═══════════════════════════════════════════════════════════════════════════
# Funciones auxiliares de estadistica circular
# ═══════════════════════════════════════════════════════════════════════════
[docs]
def assign_sectors(
peaks: np.ndarray,
reference: float,
n_sectors: int = 8,
) -> np.ndarray:
"""
Asigna cada angulo de pico a uno de *n_sectors* sectores equiespaciados
centrados en *reference*.
Los sectores se numeran 1 ... n_sectors.
Parameters
----------
peaks : array (n,) de angulos en [0, 2pi).
reference : float — mediana circular usada como origen.
n_sectors : int — por defecto 8.
Returns
-------
array int (n,) — etiquetas de sector en {1, 2, ..., n_sectors}.
"""
peaks = np.asarray(peaks, dtype=np.float64)
rotated = (peaks - reference) % (2.0 * pi)
sector_width = 2.0 * pi / n_sectors
sectors = n_sectors - np.floor(rotated / sector_width).astype(int)
sectors = np.clip(sectors, 1, n_sectors)
return sectors
[docs]
def sector_boundaries(reference: float, n_sectors: int = 8) -> np.ndarray:
"""
Devuelve los *n_sectors* puntos centrales de sector.
Parameters
----------
reference : float — mediana circular usada como origen.
n_sectors : int — por defecto 8.
Returns
-------
array (n_sectors,) — centros de cada sector en [0, 2pi).
"""
return np.array([
(reference - k * pi / (n_sectors / 2)) % (2.0 * pi)
for k in range(n_sectors)
])
# ═══════════════════════════════════════════════════════════════════════════
# Workers a nivel de modulo (serializables por joblib)
# ═══════════════════════════════════════════════════════════════════════════
def _fit_ori_one(
row: np.ndarray,
circ: np.ndarray,
) -> float:
"""
Ajusta ORI a un unico gen. Devuelve R2_ORI.
Funcion pura a nivel de modulo para que ``joblib`` pueda serializarla
hacia procesos worker.
"""
fr = ORIModel().fit(row, circ)
return fr.r2
def _fit_fmm_one(
row: np.ndarray,
circ: np.ndarray,
alpha_grid: int,
omega_grid: int,
num_reps: int,
) -> tuple[float, float, float, np.ndarray]:
"""
Ajusta FMM a un unico gen. Devuelve (peak_time, omega, r2, fitted).
Funcion pura a nivel de modulo para que ``joblib`` pueda serializarla
hacia procesos worker.
"""
model = FMMModel(
length_alpha_grid=alpha_grid,
length_omega_grid=omega_grid,
num_reps=num_reps,
)
fr = model.fit(row, circ)
return fr.peak_time, fr.params["omega"], fr.r2, fr.fitted
# ═══════════════════════════════════════════════════════════════════════════
# Dataclass de resultado
# ═══════════════════════════════════════════════════════════════════════════
[docs]
@dataclass
class TopGeneResult:
"""
Salida de :class:`TopGeneSelector`.
Los campos principales son compatibles por duck-typing con
``ReferenceSetResult``, de modo que ``RandomSelector`` puede
consumir este resultado directamente.
Atributos — pipeline
--------------------
gene_names : list[str]
Simbolos de los genes TOP seleccionados.
fmm_peaks : np.ndarray (n_top,)
Tiempos de pico FMM (t_U) por gen.
r2_par : np.ndarray (n_top,)
R2 parametrico (FMM) por gen.
sector_labels : np.ndarray int (n_top,)
Cuadrante asignado (1–4) por gen.
sector_centres : np.ndarray (n_sectors,)
Centros de los cuadrantes en [0, 2pi).
candidate_matrix : pd.DataFrame (n_top, n_muestras)
Matriz de expresion de los genes TOP (en el orden circular
preliminar). Equivale a [X_TOP] del paper.
cosinor_peaks : np.ndarray (n_top,)
Acrofases Cosinor de cada gen TOP.
circular_scale : np.ndarray (n_muestras,)
Eje temporal circular del ordenamiento preliminar.
added_genes : list[str]
Genes semilla que fueron forzados (no pasaron los filtros
por si solos). Se excluyen del pool de muestreo aleatorio
en ``RandomSelector``.
sample_order : np.ndarray int (n_muestras,)
Orden de muestras del sincronizado preliminar.
Atributos — diagnostico
-----------------------
omega : np.ndarray (n_top,)
Parametro omega de FMM por gen.
r2_ori : np.ndarray (n_top,)
R2 ORI (no parametrico) de cada gen TOP.
n_total_genes : int
Numero total de genes antes de cualquier filtro.
n_after_ori : int
Genes que superaron el filtro R2_ORI >= umbral.
n_after_fmm : int
Genes que superaron los filtros FMM (omega + R2_FMM)
*antes* de forzar genes semilla.
"""
# Pipeline (compatibles con ReferenceSetResult por duck-typing)
gene_names: list = field(default_factory=list)
fmm_peaks: np.ndarray = field(default_factory=lambda: np.array([]))
r2_par: np.ndarray = field(default_factory=lambda: np.array([]))
sector_labels: np.ndarray = field(default_factory=lambda: np.array([], dtype=int))
sector_centres: np.ndarray = field(default_factory=lambda: np.zeros(4))
candidate_matrix: pd.DataFrame = field(default_factory=pd.DataFrame)
cosinor_peaks: np.ndarray = field(default_factory=lambda: np.array([]))
circular_scale: np.ndarray = field(default_factory=lambda: np.array([]))
added_genes: list = field(default_factory=list)
sample_order: np.ndarray = field(default_factory=lambda: np.array([], dtype=int))
# Diagnostico
omega: np.ndarray = field(default_factory=lambda: np.array([]))
r2_ori: np.ndarray = field(default_factory=lambda: np.array([]))
n_total_genes: int = 0
n_after_ori: int = 0
n_after_fmm: int = 0
[docs]
def summary(self) -> str:
n = len(self.gene_names)
n_sec = len(self.sector_centres)
n_covered = len(np.unique(self.sector_labels)) if n > 0 else 0
lines = [
"=== Resumen de Genes TOP ===",
f" Genes totales : {self.n_total_genes}",
f" Tras filtro R2_ORI>=0.5 : {self.n_after_ori}",
f" Tras filtros FMM : {self.n_after_fmm}",
f" Genes semilla forzados : {self.added_genes or 'ninguno'}",
f" Genes TOP finales : {n}",
f" Cuadrantes cubiertos : {n_covered} / {n_sec}",
]
if n > 0:
for s in sorted(np.unique(self.sector_labels)):
cnt = int(np.sum(self.sector_labels == s))
lines.append(f" Cuadrante {s}: {cnt} gen(es)")
return "\n".join(lines)
# ═══════════════════════════════════════════════════════════════════════════
# TopGeneSelector
# ═══════════════════════════════════════════════════════════════════════════
[docs]
class TopGeneSelector:
"""
Selecciona los genes TOP ritmicos especificos del tejido.
Algoritmo (Seccion 3 del paper CIRCUST)
----------------------------------------
1. Ajustar ORI a cada gen y calcular R2_ORI.
2. Descartar genes con R2_ORI < ``r2_ori_threshold`` (0.5).
3. Ajustar FMM a los supervivientes.
4. Filtrar: omega > ``omega_min`` (0.1) Y R2_FMM > ``r2_fmm_threshold`` (0.5).
5. Asignar picos FMM a ``n_sectors`` cuadrantes y verificar cobertura.
6. Forzar inclusion de los genes core si no estan presentes.
7. Construir la matriz TOP ([X_TOP] en el paper).
Parameters
----------
r2_ori_threshold : float
Umbral minimo de R2 ORI para el pre-filtro no parametrico.
Paper: 0.5.
r2_fmm_threshold : float
Umbral minimo de R2 FMM para la ritmicidad parametrica.
Paper: 0.5.
omega_min : float
omega minimo de FMM; rechaza ajustes puntiagudos.
Paper: 0.1.
n_sectors : int
Numero de sectores circulares para verificar cobertura de picos.
Paper: 4 (cuadrantes).
fmm_length_alpha_grid, fmm_length_omega_grid, fmm_num_reps : int
Parametros de la rejilla de ajuste FMM.
n_jobs : int
Numero de procesos para el ajuste en paralelo (ORI y FMM).
``-1`` usa todos los nucleos disponibles.
verbose : bool
Imprimir mensajes de progreso.
"""
def __init__(
self,
r2_ori_threshold: float = 0.5,
r2_fmm_threshold: float = 0.5,
omega_min: float = 0.1,
n_sectors: int = 4,
fmm_length_alpha_grid: int = 48,
fmm_length_omega_grid: int = 24,
fmm_num_reps: int = 3,
n_jobs: int = -1,
verbose: bool = True,
) -> None:
self.r2_ori_threshold = r2_ori_threshold
self.r2_fmm_threshold = r2_fmm_threshold
self.omega_min = omega_min
self.n_sectors = n_sectors
self.fmm_length_alpha_grid = fmm_length_alpha_grid
self.fmm_length_omega_grid = fmm_length_omega_grid
self.fmm_num_reps = fmm_num_reps
self.n_jobs = n_jobs
self.verbose = verbose
# ------------------------------------------------------------------
# API publica
# ------------------------------------------------------------------
[docs]
def run(
self,
expr_norm: pd.DataFrame,
circular_scale: np.ndarray,
seed_genes: list[str],
sample_order: np.ndarray | None = None,
) -> TopGeneResult:
"""
Selecciona los genes TOP.
Internamente ajusta ORI (regresion isotonica circular unimodal)
a cada gen para calcular R2_ORI, aplica los filtros y construye
el conjunto TOP.
Parameters
----------
expr_norm : pd.DataFrame (n_genes, n_muestras)
Matriz de expresion normalizada completa, ya ordenada
por el tiempo circular preliminar.
circular_scale : array (n_muestras,)
Eje temporal circular del ordenamiento preliminar.
seed_genes : list[str]
Genes core del pipeline (los mismos seleccionados por
``CoreGeneSelector``). Se fuerzan en el conjunto TOP
si no superan los filtros.
sample_order : array int (n_muestras,) o None
Indices de muestra del ordenamiento preliminar.
Si es None se asume identidad.
Returns
-------
TopGeneResult
"""
seed_genes = list(seed_genes)
self._log("=== Etapa 3: Seleccion de Genes TOP ===")
gene_names = list(expr_norm.index)
circ = np.asarray(circular_scale, dtype=np.float64)
n_total = len(gene_names)
values = expr_norm.values.astype(np.float64)
if sample_order is None:
sample_order = np.arange(expr_norm.shape[1], dtype=int)
# ── 1. Ajuste ORI (regresion isotonica circular unimodal) ──────
self._log(f" Genes totales: {n_total}")
self._log(" Ajustando ORI (PAVA unimodal) ...")
r2_ori = self._fit_ori_batch(values, circ)
# ── 2. Filtro R2_ORI ────────────────────────────────────────────
self._log(
f" Filtro R2_ORI >= {self.r2_ori_threshold} ..."
)
mask_ori = r2_ori >= self.r2_ori_threshold
surv_genes = [g for g, m in zip(gene_names, mask_ori) if m]
surv_r2_ori = r2_ori[mask_ori]
n_after_ori = len(surv_genes)
self._log(f" Supervivientes R2_ORI: {n_after_ori}")
# ── 3. Ajustar FMM a los supervivientes (paralelo) ────────────
self._log(" Ajustando FMM ...")
surv_rows = np.array(
[expr_norm.loc[g].values.astype(np.float64) for g in surv_genes]
)
fmm_results = self._fit_fmm_batch(surv_rows, circ)
fmm_peaks_all = np.array([r[0] for r in fmm_results])
fmm_omega_all = np.array([r[1] for r in fmm_results])
fmm_r2_all = np.array([r[2] for r in fmm_results])
# ── 4. Filtro FMM: omega > min Y R2_FMM > umbral ──────────────
mask_fmm = (
(fmm_omega_all > self.omega_min)
& (fmm_r2_all > self.r2_fmm_threshold)
)
n_after_fmm = int(mask_fmm.sum())
self._log(
f" Filtro FMM (omega>{self.omega_min}, "
f"R2>{self.r2_fmm_threshold}): {n_after_fmm} supervivientes"
)
# Acumular genes aceptados
acc_names: list[str] = []
acc_r2_ori: list[float] = []
acc_r2_fmm: list[float] = []
acc_omega: list[float] = []
acc_peaks: list[float] = []
acc_rows: list[np.ndarray] = []
for i, gene in enumerate(surv_genes):
if mask_fmm[i]:
acc_names.append(gene)
acc_r2_ori.append(float(surv_r2_ori[i]))
acc_r2_fmm.append(float(fmm_r2_all[i]))
acc_omega.append(float(fmm_omega_all[i]))
acc_peaks.append(float(fmm_peaks_all[i]))
acc_rows.append(surv_rows[i])
# ── 5. Forzar genes core (semilla) ────────────────────────────
forced: list[str] = []
accepted_set = set(acc_names)
for seed in seed_genes:
if seed in accepted_set:
continue
if seed not in expr_norm.index:
self._log(
f" AVISO: gen semilla {seed} ausente de la matriz."
)
continue
row = expr_norm.loc[seed].values.astype(np.float64)
pk, om, r2, _ = _fit_fmm_one(
row, circ,
self.fmm_length_alpha_grid,
self.fmm_length_omega_grid,
self.fmm_num_reps,
)
idx_all = (
gene_names.index(seed) if seed in gene_names else -1
)
seed_r2_ori = float(r2_ori[idx_all]) if idx_all >= 0 else 0.0
acc_names.append(seed)
acc_r2_ori.append(seed_r2_ori)
acc_r2_fmm.append(float(r2))
acc_omega.append(float(om))
acc_peaks.append(float(pk))
acc_rows.append(row)
forced.append(seed)
if forced:
self._log(f" Genes semilla forzados: {forced}")
n_top = len(acc_names)
self._log(f" Genes TOP finales: {n_top}")
# ── 6. Acrofases Cosinor (para cuadrantes y RandomSelector) ────
cos_model = CosinorModel()
cos_peaks = np.empty(n_top, dtype=np.float64)
for i in range(n_top):
cr = cos_model.fit(acc_rows[i], circ)
cos_peaks[i] = cr.peak_time
# ── 7. Asignar cuadrantes ──────────────────────────────────────
peaks_arr = np.array(acc_peaks)
if n_top > 0:
peak_ref = circular_median(peaks_arr)
centres = sector_boundaries(peak_ref, self.n_sectors)
sectors = assign_sectors(
peaks_arr, peak_ref, self.n_sectors,
)
else:
centres = np.zeros(self.n_sectors)
sectors = np.array([], dtype=int)
covered = set(sectors.tolist()) if n_top > 0 else set()
expected = set(range(1, self.n_sectors + 1))
missing = expected - covered
if missing:
self._log(
f" AVISO: cuadrantes sin cobertura: {sorted(missing)}"
)
# ── 8. Construir matriz TOP ────────────────────────────────────
if n_top > 0:
top_matrix = pd.DataFrame(
np.vstack(acc_rows),
index=acc_names,
columns=expr_norm.columns,
)
else:
top_matrix = pd.DataFrame(
dtype=float, columns=expr_norm.columns,
)
result = TopGeneResult(
gene_names = acc_names,
fmm_peaks = peaks_arr,
r2_par = np.array(acc_r2_fmm),
sector_labels = sectors,
sector_centres = centres,
candidate_matrix = top_matrix,
cosinor_peaks = cos_peaks,
circular_scale = circ,
added_genes = forced,
sample_order = np.asarray(sample_order, dtype=int),
omega = np.array(acc_omega),
r2_ori = np.array(acc_r2_ori),
n_total_genes = n_total,
n_after_ori = n_after_ori,
n_after_fmm = n_after_fmm,
)
self._log(result.summary())
return result
# ------------------------------------------------------------------
# Ajuste ORI en lote (paralelo)
# ------------------------------------------------------------------
def _fit_ori_batch(
self,
rows: np.ndarray,
circ: np.ndarray,
) -> np.ndarray:
"""Ajusta ORI a multiples genes en paralelo. Devuelve R2_ORI (n_genes,)."""
n = rows.shape[0]
if self.n_jobs == 1 or n <= 8:
r2_vals = [
_fit_ori_one(rows[i], circ) for i in range(n)
]
else:
from joblib import Parallel, delayed
self._log(
f" Paralelizando {n} ajustes ORI en {self.n_jobs} jobs ..."
)
r2_vals = Parallel(n_jobs=self.n_jobs)(
delayed(_fit_ori_one)(rows[i], circ) for i in range(n)
)
return np.array(r2_vals, dtype=np.float64)
# ------------------------------------------------------------------
# Ajuste FMM en lote (paralelo)
# ------------------------------------------------------------------
def _fit_fmm_batch(
self,
rows: np.ndarray,
circ: np.ndarray,
) -> list[tuple[float, float, float, np.ndarray]]:
"""Ajusta FMM a multiples genes en paralelo con joblib."""
n = rows.shape[0]
ag = self.fmm_length_alpha_grid
og = self.fmm_length_omega_grid
nr = self.fmm_num_reps
if self.n_jobs == 1 or n <= 8:
return [
_fit_fmm_one(rows[i], circ, ag, og, nr)
for i in range(n)
]
from joblib import Parallel, delayed
self._log(
f" Paralelizando {n} ajustes FMM en {self.n_jobs} jobs ..."
)
return Parallel(n_jobs=self.n_jobs)(
delayed(_fit_fmm_one)(rows[i], circ, ag, og, nr)
for i in range(n)
)
# ------------------------------------------------------------------
# Utilidad
# ------------------------------------------------------------------
def _log(self, message: str) -> None:
if self.verbose:
print(message, flush=True)