Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
ChannelLegendEntry,
CmapParams,
ColorbarSpec,
ColorLike,
GraphRenderParams,
ImageRenderParams,
LabelsRenderParams,
Expand Down Expand Up @@ -80,11 +81,6 @@
save_fig,
)

# replace with
# from spatialdata._types import ColorLike
# once https://github.com/scverse/spatialdata/pull/689/ is in a release
ColorLike = tuple[float, ...] | list[float] | str


@register_spatial_data_accessor("pl")
class PlotAccessor:
Expand Down
32 changes: 17 additions & 15 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,16 @@

_Normalize = Normalize | abc.Sequence[Normalize]

# Shared body of the "blending multiple cmaps" warning. Emitted both when the user
# supplies several cmaps and when a single cmap is broadcast across channels.
_MULTI_CMAP_BLENDING_WARNING = (
"You're blending multiple cmaps. "
"If the plot doesn't look like you expect, it might be because your "
"cmaps go from a given color to 'white', and not to 'transparent'. "
"Therefore, the 'white' of higher layers will overlay the lower layers. "
"Consider using 'palette' instead."
)


def _get_top_data_array(element: xr.DataArray | DataTree) -> xr.DataArray:
if isinstance(element, DataTree):
Expand Down Expand Up @@ -1136,9 +1146,9 @@ def _render_points(
# if the points are colored by values in X (or a different layer), add the values to obs
if col_for_color in matched_table.var_names:
if table_layer is None:
adata_obs[col_for_color] = matched_table[:, col_for_color].X.flatten().copy()
adata_obs[col_for_color] = matched_table[:, col_for_color].X.flatten()
else:
adata_obs[col_for_color] = matched_table[:, col_for_color].layers[table_layer].flatten().copy()
adata_obs[col_for_color] = matched_table[:, col_for_color].layers[table_layer].flatten()
adata = AnnData(
X=points[["x", "y"]].values,
obs=adata_obs,
Expand Down Expand Up @@ -1742,13 +1752,7 @@ def _render_images(
user_supplied_multi_cmaps = False

if user_supplied_multi_cmaps:
logger.warning(
"You're blending multiple cmaps. "
"If the plot doesn't look like you expect, it might be because your "
"cmaps go from a given color to 'white', and not to 'transparent'. "
"Therefore, the 'white' of higher layers will overlay the lower layers. "
"Consider using 'palette' instead."
)
logger.warning(_MULTI_CMAP_BLENDING_WARNING)

# Force nearest-neighbor at display time when the datashader reduction picked
# a non-mean aggregation; otherwise imshow's default interpolation would smear it.
Expand Down Expand Up @@ -1864,7 +1868,9 @@ def _render_images(
)
layers = {}
for ch_idx, ch in enumerate(channels):
layers[ch] = img.sel(c=ch).copy(deep=True).squeeze()
# No copy needed: this entry is only read (min/max) and then replaced
# by a fresh array (np.full or ch_norm(...)) below; img is never mutated.
layers[ch] = img.sel(c=ch).squeeze()
if isinstance(render_params.cmap_params, list):
ch_norm = render_params.cmap_params[ch_idx].norm
else:
Expand Down Expand Up @@ -1904,11 +1910,7 @@ def _render_images(
stacked = stacked[:, :, :3]
logger.warning(
"One cmap was given for multiple channels and is now used for each channel. "
"You're blending multiple cmaps. "
"If the plot doesn't look like you expect, it might be because your "
"cmaps go from a given color to 'white', and not to 'transparent'. "
"Therefore, the 'white' of higher layers will overlay the lower layers. "
"Consider using 'palette' instead."
+ _MULTI_CMAP_BLENDING_WARNING
)

_ax_show_and_transform(
Expand Down
3 changes: 2 additions & 1 deletion src/spatialdata_plot/pl/render_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
_DsReduction = Literal["sum", "mean", "any", "count", "std", "var", "max", "min"]
_ImageDsReduction = Literal["max", "min", "mean", "mode", "first", "last", "var", "std"]

# Canonical definition for the package; imported by basic.py and utils.py.
# replace with
# from spatialdata._types import ColorLike
# once https://github.com/scverse/spatialdata/pull/689/ is in a release
ColorLike = tuple[float, ...] | str
ColorLike = tuple[float, ...] | list[float] | str


# NOTE: defined here instead of utils to avoid circular import
Expand Down
64 changes: 13 additions & 51 deletions src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
import warnings
from collections import Counter, OrderedDict
from collections.abc import Callable, Iterable, Mapping, Sequence
from collections.abc import Callable, Mapping, Sequence
from copy import copy
from functools import partial
from pathlib import Path
Expand Down Expand Up @@ -78,6 +78,7 @@
CmapParams,
Color,
ColorbarSpec,
ColorLike,
FigParams,
GraphRenderParams,
ImageRenderParams,
Expand All @@ -93,11 +94,6 @@

to_hex = partial(colors.to_hex, keep_alpha=True)

# replace with
# from spatialdata._types import ColorLike
# once https://github.com/scverse/spatialdata/pull/689/ is in a release
ColorLike = tuple[float, ...] | list[float] | str

_GROUPS_IGNORED_WARNING = "Parameter 'groups' is ignored when 'color' is a literal color, not a column name."

_RENDER_CMD_TO_CS_FLAG: dict[str, str] = {
Expand Down Expand Up @@ -996,44 +992,6 @@ def _set_outline(
)


def _get_subplots(num_images: int, ncols: int = 4, width: int = 4, height: int = 3) -> plt.Figure | plt.Axes:
"""Set up the axs objects.

Parameters
----------
num_images
Number of images to plot. Must be greater than 1.
ncols
Number of columns in the subplot grid, by default 4
width
Width of each subplot, by default 4

Returns
-------
Union[plt.Figure, plt.Axes]
Matplotlib figure and axes object.
"""
if num_images < ncols:
nrows = 1
ncols = num_images
else:
nrows, reminder = divmod(num_images, ncols)

if nrows == 0:
nrows = 1
if reminder > 0:
nrows += 1

fig, axes = plt.subplots(nrows, ncols, figsize=(width * ncols, height * nrows))

if not isinstance(axes, Iterable):
axes = np.array([axes])

# get rid of the empty axes
_ = [ax.axis("off") for ax in axes.flatten()[num_images:]]
return fig, axes


def _get_colors_for_categorical_obs(
categories: Sequence[str | int],
palette: ListedColormap | str | list[str] | None = None,
Expand Down Expand Up @@ -1503,7 +1461,7 @@ def _map_color_seg(

if isinstance(color_vector.dtype, pd.CategoricalDtype):
# Case A: users wants to plot a categorical column
val_im: ArrayLike = map_array(seg.copy(), cell_id, color_vector.codes + 1)
val_im: ArrayLike = map_array(seg, cell_id, color_vector.codes + 1)
cols = colors.to_rgba_array(color_vector.categories)
elif pd.api.types.is_numeric_dtype(color_vector.dtype):
# Case B: user wants to plot a continous column
Expand All @@ -1515,7 +1473,7 @@ def _map_color_seg(
normed_color_vector[~np.isnan(normed_color_vector)]
)
cols = cmap_params.cmap(normed_color_vector)
val_im = map_array(seg.copy(), cell_id, cell_id)
val_im = map_array(seg, cell_id, cell_id)
else:
# Case C: User didn't specify any colors
if color_source_vector is not None and (
Expand All @@ -1524,12 +1482,12 @@ def _map_color_seg(
and set(color_vector) == {na_color.get_hex_with_alpha()}
and not na_color.color_modified_by_user()
):
val_im = map_array(seg.copy(), cell_id, cell_id)
val_im = map_array(seg, cell_id, cell_id)
RNG = default_rng(42)
cols = RNG.random((len(color_vector), 3))
else:
# Case D: User didn't specify a column to color by, but modified the na_color
val_im = map_array(seg.copy(), cell_id, cell_id)
val_im = map_array(seg, cell_id, cell_id)
first_value = color_vector.iloc[0] if isinstance(color_vector, pd.Series) else color_vector[0]
if _is_color_like(first_value):
# we have color-like values (e.g., hex or named colors)
Expand All @@ -1550,7 +1508,7 @@ def _map_color_seg(
if outline_color_source_vector is not None:
cat = pd.Categorical(outline_color_source_vector)
cat_codes = cat.codes
outline_val_im: ArrayLike = map_array(seg.copy(), cell_id, cat_codes + 1)
outline_val_im: ArrayLike = map_array(seg, cell_id, cat_codes + 1)
color_arr = np.asarray(outline_color_vector, dtype=object)
# Pick the first per-cell hex for each category in one vectorized pass
# (avoids `K × O(N)` Python loops on large label sets).
Expand All @@ -1572,7 +1530,7 @@ def _map_color_seg(
if finite.any():
normed[finite] = cmap_params.norm(normed[finite])
outline_cols = cmap_params.cmap(normed)
outline_val_im = map_array(seg.copy(), cell_id, cell_id)
outline_val_im = map_array(seg, cell_id, cell_id)
if seg_erosionpx is not None:
outline_val_im[
outline_val_im == erosion(outline_val_im, footprint_rectangle((seg_erosionpx, seg_erosionpx)))
Expand Down Expand Up @@ -1814,8 +1772,12 @@ def _to_hex_no_alpha(color_value: Any) -> str | None:
if col_to_colorby in adata.obs and hasattr(adata.obs[col_to_colorby], "cat")
else categories
)
# Map category -> index once (O(K)) instead of a per-category list scan
# (was O(K^2) via list.index). all_cats comes from pandas .categories,
# which is unique, so a plain dict comprehension is sufficient.
cat_to_idx: dict[Any, int] = {c: i for i, c in enumerate(all_cats)}
for category in categories:
idx = all_cats.index(category) if category in all_cats else None
idx = cat_to_idx.get(category)
if idx is not None and idx < len(hex_colors) and hex_colors[idx] is not None:
hex_color = hex_colors[idx]
assert hex_color is not None # type narrowing for mypy
Expand Down
25 changes: 1 addition & 24 deletions tests/pl/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
from spatialdata.models import PointsModel, ShapesModel, TableModel

import spatialdata_plot
from spatialdata_plot.pl.render_params import Color
from spatialdata_plot.pl.render_params import Color, ColorLike
from spatialdata_plot.pl.utils import (
_apply_cmap_alpha_to_datashader_result,
_datashader_map_aggregate_to_color,
_get_subplots,
_set_outline,
set_zero_in_cmap_to_transparent,
)
Expand All @@ -34,11 +33,6 @@
# the comp. function can be accessed as `self.compare(<your_filename>, tolerance=<your_tolerance>)`
# ".png" is appended to <your_filename>, no need to set it

# replace with
# from spatialdata._types import ColorLike
# once https://github.com/scverse/spatialdata/pull/689/ is in a release
ColorLike = tuple[float, ...] | str


class TestUtils(PlotTester, metaclass=PlotTesterMeta):
@pytest.mark.parametrize(
Expand Down Expand Up @@ -290,23 +284,6 @@ def test_plot_can_handle_rgba_color_specifications(sdata_blobs: SpatialData):
sdata_blobs.pl.render_shapes(element="blobs_circles", color="blue").pl.show()


@pytest.mark.parametrize(
"input_output",
[
(1, 4, 1, [True]),
(4, 4, 4, [True, True, True, True]),
(6, 4, 8, [True, True, True, True, True, True, False, False]), # 2 rows with 4 columns
],
)
def test_utils_get_subplots_produces_correct_axs_layout(input_output):
num_images, ncols, len_axs, axs_visible = input_output

_, axs = _get_subplots(num_images=num_images, ncols=ncols)

assert len_axs == len(axs.flatten())
assert axs_visible == [ax.axison for ax in axs.flatten()]


class TestMultiscaleToSpatialImage:
"""Regression tests for #589: multiscale resolution selection."""

Expand Down
Loading