214 changes: 151 additions & 63 deletions README.md
@@ -2,85 +2,173 @@

[![Tests](https://github.com/YosefLab/csde/actions/workflows/test.yml/badge.svg)](https://github.com/YosefLab/csde/actions/workflows/test.yml)

`csde` (Corrected Spatial Differential Expression) is a Python package designed to **identify differentially expressed (DE) genes between spatially-resolved cell populations** (e.g., T-cells inside vs. outside a tumor).
Automated pipelines for spatial transcriptomics produce cell quantifications (cell-by-gene expression matrices and label assignments) that contain systematic errors, e.g., due to mis-segmentation of cell boundaries.
These errors can propagate into downstream analyses of differential expression, leading to false discoveries or missed signals.

Standard analysis relies on cell population assignments (e.g., "infiltrating" vs. "non-infiltrating") obtained automatically from clustering/ML that are often prone to errors. `csde` corrects for these inaccuracies by leveraging a small subset of validated "ground-truth" data, providing rigorous statistical guarantees for spatially-resolved DE analyses.
CSDE corrects for these errors by combining the large automated dataset with a small set of manually validated cells, using prediction-powered inference to recover unbiased estimates with valid confidence intervals.
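
Prediction-powered inference is easiest to see on a population mean. The sketch below is not CSDE's estimator, which targets per-gene log-fold changes under a count noise model. It only shows the general idea: take the estimate from the large automated dataset and shift it by the average error measured on the small validated set. The function name `ppi_mean` and its arguments are hypothetical, and the interval uses a plain normal approximation.

```python
import math
from statistics import fmean, variance


def ppi_mean(pred_unlabeled, pred_labeled, y_labeled, z=1.96):
    """Prediction-powered estimate of a mean with a normal-approximation CI.

    pred_unlabeled: automated values on the large, unvalidated set
    pred_labeled:   automated values on the small validated set
    y_labeled:      validated (ground-truth) values on that same small set
    """
    # The rectifier measures how far the automated values are from the truth.
    rectifier = [y - f for y, f in zip(y_labeled, pred_labeled)]
    estimate = fmean(pred_unlabeled) + fmean(rectifier)
    # Both sample means contribute to the uncertainty of the estimate.
    se = math.sqrt(
        variance(pred_unlabeled) / len(pred_unlabeled)
        + variance(rectifier) / len(rectifier)
    )
    return estimate, (estimate - z * se, estimate + z * se)
```

If the automated values are systematically too high by some offset, the rectifier term removes that offset, while the large dataset keeps the interval narrow.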

Refer to the preprint and the [project repository](https://github.com/YosefLab/csde) for more details.
The current codebase focuses on the comparison of a given cell type across two spatial regions.
It allows users to:
1. export per-cell annotation panels for a small subset of cells (e.g. 600)
2. manually validate the segmentation and type assignment for these cells
3. run the CSDE model to get corrected DE estimates for all genes

Refer to the [preprint](https://www.biorxiv.org/content/10.64898/2026.01.15.699786v1) for details on the method. Reproducibility code is available [here](https://github.com/PierreBoyeau/csde_experiments).

### Input requirements

The workflow takes a [SpatialData](https://spatialdata.scverse.org/) zarr as input.

Its `"table"` AnnData must contain:

- **raw expression counts** in `.X` or a named layer
- **the following `obs` columns:**

| obs column | content |
| --- | --- |
| `cell_type` (configurable) | cell-type label for each cell |
| `spatial_group` (configurable) | binary spatial region label (e.g. `0` = outside tumour, `1` = inside tumour) |
| `center_x`, `center_y` | cell centroid in microns |

The zarr must also expose the following SpatialData elements, used to render the per-cell annotation panels (Step 1):

| element | requirement |
| --- | --- |
| `images` | at least one image with a named **fluorescence channel** (e.g. `"DAPI"`, `"Cellbound2"`) |
| `shapes` | at least one element holding the **cell-boundary polygons** |
| `points` | at least one element holding **transcript locations**, with a `gene` column |

The cell-boundary `shapes` must carry a transformation to the `global` coordinate system: it converts the micron `center_x`/`center_y` centroids into the image's pixel space. This conversion assumes a pure scale-and-translation transform (as produced for MERSCOPE); transforms with rotation or shear are not handled.
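
The scale-and-translation assumption can be made concrete with a small sketch. It assumes the transform is available as a 3x3 affine matrix mapping micron coordinates to pixel coordinates. The function `micron_to_pixel` and the matrix values are hypothetical and are not part of the SpatialData or `csde` API.

```python
def micron_to_pixel(x_um, y_um, affine, tol=1e-9):
    """Map a centroid from microns to pixels with a scale-and-translation affine.

    affine is a 3x3 row-major matrix acting on (x, y, 1).
    Off-diagonal terms encode rotation or shear, which this sketch rejects.
    """
    (sx, shear_x, tx), (shear_y, sy, ty), _ = affine
    if abs(shear_x) > tol or abs(shear_y) > tol:
        raise ValueError("transform has rotation or shear; only scale + translation is supported")
    return sx * x_um + tx, sy * y_um + ty


# Hypothetical transform: 2 pixels per micron, shifted by (100, 50) pixels.
example_affine = [
    [2.0, 0.0, 100.0],
    [0.0, 2.0, 50.0],
    [0.0, 0.0, 1.0],
]
print(micron_to_pixel(10.0, 20.0, example_affine))
```

Rejecting non-zero off-diagonal terms up front makes an unsupported transform fail loudly instead of producing misplaced image crops.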

## Installation

```bash
pip install csde
pip install "csde[cuda12]" # GPU (CUDA 12)
pip install "csde[annotate]" # annotation UI (Step 2, requires streamlit)
pip install "csde[cuda12,annotate]" # both
```

By default, this installs JAX with CPU support. To enable GPU support (CUDA), install with the appropriate extra (e.g., for CUDA 12):
## Workflow overview

CSDE runs as three scripts executed in sequence, each consuming the previous one's output: `export.py` samples a small set of cells and renders an annotation panel for each, `annotate.py` lets you manually mark those cells as correct or incorrect, and `differential_expression.py` feeds those validated labels into the CSDE model to produce corrected DE estimates. All three share a single annotation directory.

```
SpatialData zarr
1. Export annotation panels ←─ scripts/export.py
(importance-sampled cells,
one image per cell)
2. Manual validation ←─ scripts/annotate.py
(annotator marks each cell
as correctly / incorrectly labelled)
3. Run CSDE ←─ scripts/differential_expression.py
(corrected DE estimates)
```

---

## Step 1 — Export annotation panels (`scripts/export.py`)

Before running the statistical model, a small subset of cells must be manually validated. `csde` provides tooling to generate the per-cell images needed for that step.

```bash
pip install "csde[cuda12]"
python scripts/export.py \
--sdata /path/to/region.zarr \
--out /path/to/annotation_dir \
--cell-type-key cell_type \
--cell-type-of-interest macrophages \
--target-proportion 0.4 \
--gene-colors scripts/gene_colors_file.json \
--image-channel Cellbound2 \
--n-cells 600 \
--layer counts
```

## Data Requirements

`csde` requires two `AnnData` objects containing gene expression counts. Typically, these are obtained by splitting your full dataset into two groups:

### 1. `adata_pred`: The dataset to analyze
This object contains the bulk of your cells (e.g., the majority of the tissue) where only standard (predicted) cell population assignments are available.

**Requirements:**
* A column in `.obs` (e.g., `"cell_population"`) containing cell population labels (e.g., "T-cell (infiltrating)" vs. "T-cell (non-infiltrating)"). These labels can be derived from heuristics (e.g., distance to tumor) and/or computational classifiers.

### 2. `adata_gt`: The correction set
This object contains a small subset of randomly sampled cells whose cell population assignments have been **validated** to serve as a ground truth. This set allows `csde` to estimate the error rate of the standard predictions.

**Requirements:**
* **Prediction column:** The same column name as in `adata_pred` (e.g., `"cell_population"`), containing the automated labels.
* **Validation column:** A **boolean** column in `.obs` (e.g., `"is_correct"`) indicating if the automated label matches the validation ground truth (see [How to construct `adata_gt`?](#how-to-construct-adata_gt)).

## Usage

```python
from csde import run_csde

results = run_csde(
    # `AnnData` datasets to analyze
    adata_pred=adata_pred,
    adata_gt=adata_gt,

    # Column containing the predicted labels (in BOTH datasets)
    pred_cell_pop_key="cell_population",

    # The two populations to compare
    cell_pop_a="T-cell (infiltrating)",  # Reference group
    cell_pop_b="T-cell (non-infiltrating)",  # Target group

    # Boolean column in adata_gt verifying the prediction
    gt_key="is_correct",

    # Optional: Use a specific layer for counts (default uses .X)
    layer_name="counts"
)

# Returns a DataFrame with log_fold_change, p_value, and adjusted p_value
print(results.head())
```
`--target-proportion` controls the fraction of cells of interest in the subsample. Cells of interest are upweighted accordingly (importance sampling); the unnormalized weight for each sampled cell is stored in `metadata.csv` for downstream use.
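
One way to build a subsample with a chosen proportion of cells of interest is stratified sampling with inverse-inclusion-probability weights. This is only a sketch of the idea. The sampling scheme actually used by `export.py` is not shown here, and the function name and signature below are hypothetical.

```python
import random


def sample_with_target_proportion(cell_types, cell_type_of_interest,
                                  target_proportion, n_cells, seed=0):
    """Return {cell_index: unnormalized_weight} for a stratified subsample."""
    rng = random.Random(seed)
    of_interest = [i for i, t in enumerate(cell_types) if t == cell_type_of_interest]
    others = [i for i, t in enumerate(cell_types) if t != cell_type_of_interest]

    # Number of cells drawn from each stratum, capped by what is available.
    n_interest = min(round(target_proportion * n_cells), len(of_interest))
    n_other = min(n_cells - n_interest, len(others))

    sampled = {}
    for pool, k in ((of_interest, n_interest), (others, n_other)):
        if k == 0:
            continue
        # Each sampled cell stands in for len(pool) / k cells of its stratum.
        weight = len(pool) / k
        for i in rng.sample(pool, k):
            sampled[i] = weight
    return sampled
```

Under this scheme the weights sum to the size of the full dataset, so weighted averages over the subsample estimate averages over all cells even though the cell type of interest is oversampled.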

`--layer` selects which expression matrix to read: the named `.layers` entry holding the raw counts (e.g. `counts`), or `.X` when omitted. The value is saved to `config.json` and reused throughout the workflow — the same layer feeds the top-gene panels here in Step 1 and the CSDE model in Step 3, so set it once at export time. **It must point at raw counts**, since the noise model (Poisson / negative binomial) assumes integer counts; pointing it at normalised or log-transformed values will produce invalid results.
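
A quick sanity check before exporting can catch a layer that holds normalised or log-transformed values. The helper below is a hypothetical sketch, not part of `csde`. It accepts any iterable of per-cell rows, so a sparse matrix would need to be densified or checked on its stored values instead.

```python
def looks_like_raw_counts(rows):
    """Return True if every value is a non-negative integer (possibly stored as float)."""
    return all(
        value >= 0 and float(value).is_integer()
        for row in rows
        for value in row
    )


print(looks_like_raw_counts([[0, 3], [12, 1]]))    # integer counts
print(looks_like_raw_counts([[0.0, 1.5], [2.0, 0.69]]))  # normalised values
```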

The script writes:

```
/path/to/annotation_dir/
├── images/
│ ├── cell_<id>.png # one panel per cell
│ └── ...
├── config.json # all export arguments (read by annotate.py)
├── metadata.csv # cell_id, cell_type, image_path, sampling_weight, center_x, center_y
└── annotations.json # {cell_id: true/false} — written by annotate.py
```

### Output Columns
The returned DataFrame is indexed by gene name and contains:
* `log_fold_change`: The estimated log-fold change of expression (Target vs. Reference). Positive values indicate upregulation in `cell_pop_b`.
* `p_value`: The raw p-value from the hypothesis test (two-sided).
* `p_value_adj`: The p-value adjusted for multiple testing (Benjamini-Hochberg FDR).
Each panel contains:
- **Left** — fluorescence image crop + cell boundaries + transcript dots for genes listed in `gene_colors`
- **Right** — top expressed genes (bar chart); genes in `gene_colors` use their assigned colour, others are grey

## How to construct `adata_gt`?
### Gene color file

A simple JSON mapping gene names to colours:

```json
{
"CD68": "#e41a1c",
"MRC1": "#377eb8",
"C1QA": "#4daf4a",
"FCGR3A": "#ff7f00"
}
```

---

## Step 2 — Manual validation (`scripts/annotate.py`)

For each exported image, an annotator decides whether the cell is **correct** — meaning it is both properly **segmented** and properly **labelled**. A cell should be rejected (marked incorrect) when either check fails:

- **Segmentation** — the cell boundary (left panel) is not consistent with the nuclei / membrane staining, e.g. it merges two cells or clips part of one.
- **Cell-type label** — the top expressed genes (right panel) include genes unlikely to be expressed by the cell type of interest, suggesting the automated label is wrong.

The result is one boolean `is_correct` label per cell, saved to `annotations.json` and keyed by `cell_id`; the validated cells form `adata_gt` in Step 3.
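
As a sketch of how the saved labels could be joined back to the exported cells, assume `annotations.json` maps each `cell_id` to `true`/`false` (the format written by `annotate.py`) and that the rows of `metadata.csv` have been read into dictionaries. The helper name `attach_annotations` is hypothetical, and how `differential_expression.py` actually performs this join is not shown here.

```python
def attach_annotations(metadata_rows, annotations):
    """Keep only annotated cells and add an `is_correct` boolean to each row.

    metadata_rows: list of dicts with at least a "cell_id" key
    annotations:   dict mapping cell_id (as a string) to True/False
    """
    validated = []
    for row in metadata_rows:
        cell_id = str(row["cell_id"])  # JSON keys are always strings
        if cell_id in annotations:
            validated.append({**row, "is_correct": bool(annotations[cell_id])})
    return validated


rows = [{"cell_id": "a"}, {"cell_id": "b"}, {"cell_id": "c"}]
labels = {"a": True, "c": False}
print(attach_annotations(rows, labels))
```

Cells without a label are dropped rather than treated as incorrect, so a partially annotated directory still yields a usable validated set.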

```bash
streamlit run scripts/annotate.py -- --dir /path/to/annotation_dir
```

The `--` is required: it tells Streamlit to pass everything after it to the script rather than interpreting it as Streamlit's own options.

VS Code Remote forwards the Streamlit port automatically. Open the URL printed in the terminal, then use:

- **`1`** — label as correct
- **`2`** — label as incorrect

Progress is saved after every keypress to `annotations.json`. Re-running the command resumes from where you left off. You can also start annotating while `export.py` is still running — the UI picks up newly exported cells automatically.

---

## Step 3 — Differential expression (`scripts/differential_expression.py`)

```bash
python scripts/differential_expression.py --dir /path/to/annotation_dir
```

Constructing `adata_gt` requires validating the cell population labels for a small subset of cells (e.g., random sample). This involves:
1. **Sampling**: Select a small random subset of cells from your dataset.
2. **Data Access**: Extract the relevant data for these cells: their gene expression profile, their spatial coordinates, and importantly, a **high-resolution image crop** of the cell (with segmentation boundaries if available) to assess morphology.
3. **Validation**: Visually inspect these data points to determine the true cell identity.
4. **Annotation**: Create the `is_correct` boolean column based on your assessment.
Reads all export settings from `config.json` and writes gene-level results to `<dir>/results.csv`.

These steps can be performed manually or using dedicated tools.
Our [experimental repository](https://github.com/YosefLab/csde/blob/main/csde_experiments)
provides an example of how these steps were performed for MERFISH data.
| option | default | description |
|---|---|---|
| `--dir` | *(required)* | annotation directory (output of steps 1 & 2) |
| `--out` | `<dir>/results.csv` | output CSV path |
| `--spatial-group-key` | `spatial_group` | obs column encoding the two spatial populations |
| `--n-cells-expressed-threshold` | `10` | min annotated cells expressing a gene for it to be tested |
| `--noise-model` | `poisson` | `poisson` or `nb` (negative binomial) |
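
The `--n-cells-expressed-threshold` filter can be pictured as counting, for each gene, the annotated cells with a non-zero count. The sketch below is an assumption about what such a filter looks like, not the script's code, and `genes_passing_threshold` is a hypothetical name.

```python
def genes_passing_threshold(counts, gene_names, threshold=10):
    """Return the genes expressed (count > 0) in at least `threshold` cells.

    counts: iterable of per-cell rows, one raw count per gene
    """
    n_expressing = [0] * len(gene_names)
    for row in counts:
        for j, value in enumerate(row):
            if value > 0:
                n_expressing[j] += 1
    return [gene for gene, n in zip(gene_names, n_expressing) if n >= threshold]


toy_counts = [[1, 0], [2, 0], [0, 3]]
print(genes_passing_threshold(toy_counts, ["CD68", "MRC1"], threshold=2))
```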

To streamline this process, for MERFISH or other spatial transcriptomics data, we recommend using **[SpatialData](https://spatialdata.scverse.org/)** to access the data and perform the manual validation.
### Output columns

| column | description |
|---|---|
| `log_fold_change` | estimated LFC (positive = upregulated in target population) |
| `p_value` | raw two-sided p-value |
| `p_value_adj` | Benjamini-Hochberg adjusted p-value |
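
For reference, the Benjamini-Hochberg adjustment behind `p_value_adj` can be written in a few lines. This is a generic sketch of the standard procedure, not the code used inside `csde`.

```python
def benjamini_hochberg(p_values):
    """Return Benjamini-Hochberg adjusted p-values in the original order."""
    m = len(p_values)
    order = sorted(range(m), key=lambda i: p_values[i])  # indices by ascending p
    adjusted = [0.0] * m
    running_min = 1.0
    # Walk from the largest p-value down, enforcing monotonicity.
    for rank in range(m, 0, -1):
        i = order[rank - 1]
        running_min = min(running_min, p_values[i] * m / rank)
        adjusted[i] = running_min
    return adjusted


print(benjamini_hochberg([0.01, 0.04, 0.03, 0.005]))
```

Genes whose adjusted value falls below the chosen level (for example 0.05) are reported as differentially expressed at that false discovery rate.
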
1 change: 1 addition & 0 deletions pyproject.toml
@@ -33,6 +33,7 @@ dev = [
"isort",
"flake8",
]
annotate = ["streamlit"]
cuda12 = ["jax[cuda12]"]
cuda13 = ["jax[cuda13]"]

153 changes: 153 additions & 0 deletions scripts/annotate.py
@@ -0,0 +1,153 @@
"""
Streamlit annotation UI for per-cell panels.

Usage
-----
    streamlit run scripts/annotate.py -- --dir /path/to/annotations/R2_macrophages
"""

import argparse
import json
from pathlib import Path

import pandas as pd
import streamlit as st
import streamlit.components.v1 as components


def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--dir", required=True, help="Annotation directory (output of export.py).")
    return p.parse_args()


def load_annotations(annotation_dir: Path) -> dict:
    ann_path = annotation_dir / "annotations.json"
    if ann_path.exists():
        with open(ann_path) as f:
            return json.load(f)
    return {}


def save_annotations(annotations: dict, annotation_dir: Path) -> None:
    with open(annotation_dir / "annotations.json", "w") as f:
        json.dump(annotations, f, indent=2)

def main():
    args = parse_args()
    annotation_dir = Path(args.dir)

    cell_type_of_interest = "cell of interest"
    config_path = annotation_dir / "config.json"
    if config_path.exists():
        with open(config_path) as f:
            config = json.load(f)
        cell_type_of_interest = config.get("cell_type_of_interest", cell_type_of_interest)

    st.set_page_config(layout="wide", page_title="Cell Annotator")
    st.title(f"Cell Annotation — {cell_type_of_interest}")

    metadata_path = annotation_dir / "metadata.csv"
    if not metadata_path.exists():
        st.warning("metadata.csv not found — waiting for export.py to write the first cell.")
        st.stop()

    metadata = pd.read_csv(metadata_path)
    metadata["cell_id"] = metadata["cell_id"].astype(str)
    n_total = len(metadata)
    if n_total == 0:
        # Header-only file: avoid dividing by zero in the progress bar below.
        st.warning("metadata.csv is empty — waiting for export.py to write the first cell.")
        st.stop()

    annotations = load_annotations(annotation_dir)
    n_done = len(annotations)

    # Clamp to 1.0: st.progress rejects values above 1 (e.g. stale annotation keys).
    st.progress(min(n_done / n_total, 1.0), text=f"{n_done} / {n_total} annotated")

    # Initialize navigation index to first unannotated cell
    if "current_idx" not in st.session_state:
        unannotated_mask = ~metadata["cell_id"].isin(annotations)
        first_unannotated = unannotated_mask.idxmax() if unannotated_mask.any() else 0
        st.session_state.current_idx = int(first_unannotated)

    idx = st.session_state.current_idx

    # Jump-to input (form so Enter submits without looping)
    jump_col, nav_col = st.columns([2, 3])
    with jump_col:
        with st.form("jump_form", clear_on_submit=True):
            fc1, fc2 = st.columns([4, 1])
            with fc1:
                jump_id = st.text_input("Jump to cell ID", placeholder="paste cell_id here", label_visibility="collapsed")
            with fc2:
                submitted = st.form_submit_button("Go")
        if submitted and jump_id.strip():
            matches = metadata.index[metadata["cell_id"] == jump_id.strip()].tolist()
            if matches:
                st.session_state.current_idx = matches[0]
                idx = matches[0]
            else:
                st.warning(f"Cell ID `{jump_id.strip()}` not found.")

    with nav_col:
        nav1, nav2, nav3 = st.columns([1, 3, 1])
        with nav1:
            if st.button("← Prev", use_container_width=True, disabled=(idx == 0)):
                st.session_state.current_idx = idx - 1
                st.rerun()
        with nav2:
            st.markdown(f"<div style='text-align:center; padding-top:6px'>{idx + 1} / {n_total}</div>", unsafe_allow_html=True)
        with nav3:
            if st.button("Next →", use_container_width=True, disabled=(idx == n_total - 1)):
                st.session_state.current_idx = idx + 1
                st.rerun()

    row = metadata.iloc[idx]
    cell_id = row["cell_id"]
    existing = annotations.get(cell_id)
    status = "✓ correct" if existing is True else ("✗ incorrect" if existing is False else "not annotated")

    st.subheader(f"Cell `{cell_id}` — predicted: **{row['cell_type']}** — {status}")
    st.image(str(row["image_path"]), use_container_width=True)

    def annotate(is_correct: bool) -> None:
        annotations[cell_id] = is_correct
        save_annotations(annotations, annotation_dir)
        # Advance to next unannotated after annotating
        remaining = metadata.index[~metadata["cell_id"].isin(annotations)]
        next_idx = next((i for i in remaining if i > idx), None)
        if next_idx is not None:
            st.session_state.current_idx = int(next_idx)
        elif len(remaining) > 0:
            # Wrap around to the first unannotated cell. `remaining.any()` would
            # test the truthiness of the index labels and miss a lone index 0.
            st.session_state.current_idx = int(remaining[0])

    col1, col2, _ = st.columns([1, 1, 4])
    with col1:
        if st.button("✓ Correct [1]", type="primary", use_container_width=True):
            annotate(True)
            st.rerun()
    with col2:
        if st.button("✗ Incorrect [2]", use_container_width=True):
            annotate(False)
            st.rerun()

    # Keyboard shortcuts: 1/2 annotate, ←/→ navigate
    components.html("""
    <script>
    (function() {
        const doc = window.parent.document;
        doc.addEventListener('keydown', function(e) {
            if (e.target.tagName === 'INPUT' || e.target.tagName === 'TEXTAREA') return;
            doc.querySelectorAll('button').forEach(function(btn) {
                const t = btn.textContent.trim();
                if (e.key === '1' && t.startsWith('✓')) btn.click();
                if (e.key === '2' && t.startsWith('✗')) btn.click();
                if (e.key === 'ArrowLeft' && t.startsWith('←')) btn.click();
                if (e.key === 'ArrowRight' && t.startsWith('Next')) btn.click();
            });
        }, true);
    })();
    </script>
    """, height=0)


if __name__ == "__main__":
    main()