From 67c229fcf40026e7f85fe8a960af5fe937699831 Mon Sep 17 00:00:00 2001 From: anon Date: Sun, 7 Jun 2026 01:00:11 +0200 Subject: [PATCH 1/2] Maintainability quick wins: dedup, redundant copies, O(K^2) lookup A small low-risk sweep from a maintainability audit of main. No public API or behavioral change; multi-channel image and labels renders verified byte-identical to main (RGBA buffers compared). - Single-source the `ColorLike` type alias. It was defined three times, and the copy in render_params.py had silently dropped `list[float]`. render_params.py is now the canonical definition; basic.py and utils.py import it. - Drop provably-dead copies: - `_map_color_seg`: `map_array(seg.copy(), ...)` -> `map_array(seg, ...)` (x6). skimage `map_array` never mutates its input. - `_render_images`: per-channel `img.sel(c=ch).copy(deep=True)` -> `.squeeze()`. The entry is only read (min/max) then replaced by a fresh array. - `_render_points`: `...flatten().copy()` -> `...flatten()`. `flatten()` already returns a fresh array. - `_extract_colors_from_table_uns`: replace the per-category `list.index`/`in` scan (O(K^2) for K categories) with a single `{category: first_index}` dict. - Hoist the duplicated "blending multiple cmaps" warning into a module constant (the two copies had already drifted by a sentence). - Remove `_get_subplots` (no callers in src; only a test exercised it) and its test. --- src/spatialdata_plot/pl/basic.py | 6 +-- src/spatialdata_plot/pl/render.py | 32 ++++++------ src/spatialdata_plot/pl/render_params.py | 3 +- src/spatialdata_plot/pl/utils.py | 65 +++++------------------- tests/pl/test_utils.py | 18 ------- 5 files changed, 34 insertions(+), 90 deletions(-) diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index 236eaf37..f864482c 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -46,6 +46,7 @@ ChannelLegendEntry, CmapParams, ColorbarSpec, + ColorLike, GraphRenderParams, ImageRenderParams, LabelsRenderParams, @@ -80,11 +81,6 @@ save_fig, ) -# replace with -# from spatialdata._types import ColorLike -# once https://github.com/scverse/spatialdata/pull/689/ is in a release -ColorLike = tuple[float, ...] | list[float] | str - @register_spatial_data_accessor("pl") class PlotAccessor: diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 430d5635..9ee64033 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -87,6 +87,16 @@ _Normalize = Normalize | abc.Sequence[Normalize] +# Shared body of the "blending multiple cmaps" warning. Emitted both when the user +# supplies several cmaps and when a single cmap is broadcast across channels. +_MULTI_CMAP_BLENDING_WARNING = ( + "You're blending multiple cmaps. " + "If the plot doesn't look like you expect, it might be because your " + "cmaps go from a given color to 'white', and not to 'transparent'. " + "Therefore, the 'white' of higher layers will overlay the lower layers. " + "Consider using 'palette' instead." +) + def _get_top_data_array(element: xr.DataArray | DataTree) -> xr.DataArray: if isinstance(element, DataTree): @@ -1136,9 +1146,9 @@ def _render_points( # if the points are colored by values in X (or a different layer), add the values to obs if col_for_color in matched_table.var_names: if table_layer is None: - adata_obs[col_for_color] = matched_table[:, col_for_color].X.flatten().copy() + adata_obs[col_for_color] = matched_table[:, col_for_color].X.flatten() else: - adata_obs[col_for_color] = matched_table[:, col_for_color].layers[table_layer].flatten().copy() + adata_obs[col_for_color] = matched_table[:, col_for_color].layers[table_layer].flatten() adata = AnnData( X=points[["x", "y"]].values, obs=adata_obs, @@ -1742,13 +1752,7 @@ def _render_images( user_supplied_multi_cmaps = False if user_supplied_multi_cmaps: - logger.warning( - "You're blending multiple cmaps. " - "If the plot doesn't look like you expect, it might be because your " - "cmaps go from a given color to 'white', and not to 'transparent'. " - "Therefore, the 'white' of higher layers will overlay the lower layers. " - "Consider using 'palette' instead." - ) + logger.warning(_MULTI_CMAP_BLENDING_WARNING) # Force nearest-neighbor at display time when the datashader reduction picked # a non-mean aggregation; otherwise imshow's default interpolation would smear it. @@ -1864,7 +1868,9 @@ def _render_images( ) layers = {} for ch_idx, ch in enumerate(channels): - layers[ch] = img.sel(c=ch).copy(deep=True).squeeze() + # No copy needed: this entry is only read (min/max) and then replaced + # by a fresh array (np.full or ch_norm(...)) below; img is never mutated. + layers[ch] = img.sel(c=ch).squeeze() if isinstance(render_params.cmap_params, list): ch_norm = render_params.cmap_params[ch_idx].norm else: @@ -1904,11 +1910,7 @@ def _render_images( stacked = stacked[:, :, :3] logger.warning( "One cmap was given for multiple channels and is now used for each channel. " - "You're blending multiple cmaps. " - "If the plot doesn't look like you expect, it might be because your " - "cmaps go from a given color to 'white', and not to 'transparent'. " - "Therefore, the 'white' of higher layers will overlay the lower layers. " - "Consider using 'palette' instead." + + _MULTI_CMAP_BLENDING_WARNING ) _ax_show_and_transform( diff --git a/src/spatialdata_plot/pl/render_params.py b/src/spatialdata_plot/pl/render_params.py index acf03e3a..b021c305 100644 --- a/src/spatialdata_plot/pl/render_params.py +++ b/src/spatialdata_plot/pl/render_params.py @@ -15,10 +15,11 @@ _DsReduction = Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] _ImageDsReduction = Literal["max", "min", "mean", "mode", "first", "last", "var", "std"] +# Canonical definition for the package; imported by basic.py and utils.py. # replace with # from spatialdata._types import ColorLike # once https://github.com/scverse/spatialdata/pull/689/ is in a release -ColorLike = tuple[float, ...] | str +ColorLike = tuple[float, ...] | list[float] | str # NOTE: defined here instead of utils to avoid circular import diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 4c96e051..fd69843f 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -4,7 +4,7 @@ import os import warnings from collections import Counter, OrderedDict -from collections.abc import Callable, Iterable, Mapping, Sequence +from collections.abc import Callable, Mapping, Sequence from copy import copy from functools import partial from pathlib import Path @@ -78,6 +78,7 @@ CmapParams, Color, ColorbarSpec, + ColorLike, FigParams, GraphRenderParams, ImageRenderParams, @@ -93,11 +94,6 @@ to_hex = partial(colors.to_hex, keep_alpha=True) -# replace with -# from spatialdata._types import ColorLike -# once https://github.com/scverse/spatialdata/pull/689/ is in a release -ColorLike = tuple[float, ...] | list[float] | str - _GROUPS_IGNORED_WARNING = "Parameter 'groups' is ignored when 'color' is a literal color, not a column name." _RENDER_CMD_TO_CS_FLAG: dict[str, str] = { @@ -996,44 +992,6 @@ def _set_outline( ) -def _get_subplots(num_images: int, ncols: int = 4, width: int = 4, height: int = 3) -> plt.Figure | plt.Axes: - """Set up the axs objects. - - Parameters - ---------- - num_images - Number of images to plot. Must be greater than 1. - ncols - Number of columns in the subplot grid, by default 4 - width - Width of each subplot, by default 4 - - Returns - ------- - Union[plt.Figure, plt.Axes] - Matplotlib figure and axes object. - """ - if num_images < ncols: - nrows = 1 - ncols = num_images - else: - nrows, reminder = divmod(num_images, ncols) - - if nrows == 0: - nrows = 1 - if reminder > 0: - nrows += 1 - - fig, axes = plt.subplots(nrows, ncols, figsize=(width * ncols, height * nrows)) - - if not isinstance(axes, Iterable): - axes = np.array([axes]) - - # get rid of the empty axes - _ = [ax.axis("off") for ax in axes.flatten()[num_images:]] - return fig, axes - - def _get_colors_for_categorical_obs( categories: Sequence[str | int], palette: ListedColormap | str | list[str] | None = None, @@ -1503,7 +1461,7 @@ def _map_color_seg( if isinstance(color_vector.dtype, pd.CategoricalDtype): # Case A: users wants to plot a categorical column - val_im: ArrayLike = map_array(seg.copy(), cell_id, color_vector.codes + 1) + val_im: ArrayLike = map_array(seg, cell_id, color_vector.codes + 1) cols = colors.to_rgba_array(color_vector.categories) elif pd.api.types.is_numeric_dtype(color_vector.dtype): # Case B: user wants to plot a continous column @@ -1515,7 +1473,7 @@ def _map_color_seg( normed_color_vector[~np.isnan(normed_color_vector)] ) cols = cmap_params.cmap(normed_color_vector) - val_im = map_array(seg.copy(), cell_id, cell_id) + val_im = map_array(seg, cell_id, cell_id) else: # Case C: User didn't specify any colors if color_source_vector is not None and ( @@ -1524,12 +1482,12 @@ def _map_color_seg( and set(color_vector) == {na_color.get_hex_with_alpha()} and not na_color.color_modified_by_user() ): - val_im = map_array(seg.copy(), cell_id, cell_id) + val_im = map_array(seg, cell_id, cell_id) RNG = default_rng(42) cols = RNG.random((len(color_vector), 3)) else: # Case D: User didn't specify a column to color by, but modified the na_color - val_im = map_array(seg.copy(), cell_id, cell_id) + val_im = map_array(seg, cell_id, cell_id) first_value = color_vector.iloc[0] if isinstance(color_vector, pd.Series) else color_vector[0] if _is_color_like(first_value): # we have color-like values (e.g., hex or named colors) @@ -1550,7 +1508,7 @@ def _map_color_seg( if outline_color_source_vector is not None: cat = pd.Categorical(outline_color_source_vector) cat_codes = cat.codes - outline_val_im: ArrayLike = map_array(seg.copy(), cell_id, cat_codes + 1) + outline_val_im: ArrayLike = map_array(seg, cell_id, cat_codes + 1) color_arr = np.asarray(outline_color_vector, dtype=object) # Pick the first per-cell hex for each category in one vectorized pass # (avoids `K × O(N)` Python loops on large label sets). @@ -1572,7 +1530,7 @@ def _map_color_seg( if finite.any(): normed[finite] = cmap_params.norm(normed[finite]) outline_cols = cmap_params.cmap(normed) - outline_val_im = map_array(seg.copy(), cell_id, cell_id) + outline_val_im = map_array(seg, cell_id, cell_id) if seg_erosionpx is not None: outline_val_im[ outline_val_im == erosion(outline_val_im, footprint_rectangle((seg_erosionpx, seg_erosionpx))) @@ -1814,8 +1772,13 @@ def _to_hex_no_alpha(color_value: Any) -> str | None: if col_to_colorby in adata.obs and hasattr(adata.obs[col_to_colorby], "cat") else categories ) + # Map each category to its first index in O(1) instead of a per-category + # list scan (was O(K^2) via list.index for K categories). + cat_to_idx: dict[Any, int] = {} + for i, c in enumerate(all_cats): + cat_to_idx.setdefault(c, i) for category in categories: - idx = all_cats.index(category) if category in all_cats else None + idx = cat_to_idx.get(category) if idx is not None and idx < len(hex_colors) and hex_colors[idx] is not None: hex_color = hex_colors[idx] assert hex_color is not None # type narrowing for mypy diff --git a/tests/pl/test_utils.py b/tests/pl/test_utils.py index ab279d45..e2113b8d 100644 --- a/tests/pl/test_utils.py +++ b/tests/pl/test_utils.py @@ -16,7 +16,6 @@ from spatialdata_plot.pl.utils import ( _apply_cmap_alpha_to_datashader_result, _datashader_map_aggregate_to_color, - _get_subplots, _set_outline, set_zero_in_cmap_to_transparent, ) @@ -290,23 +289,6 @@ def test_plot_can_handle_rgba_color_specifications(sdata_blobs: SpatialData): sdata_blobs.pl.render_shapes(element="blobs_circles", color="blue").pl.show() -@pytest.mark.parametrize( - "input_output", - [ - (1, 4, 1, [True]), - (4, 4, 4, [True, True, True, True]), - (6, 4, 8, [True, True, True, True, True, True, False, False]), # 2 rows with 4 columns - ], -) -def test_utils_get_subplots_produces_correct_axs_layout(input_output): - num_images, ncols, len_axs, axs_visible = input_output - - _, axs = _get_subplots(num_images=num_images, ncols=ncols) - - assert len_axs == len(axs.flatten()) - assert axs_visible == [ax.axison for ax in axs.flatten()] - - class TestMultiscaleToSpatialImage: """Regression tests for #589: multiscale resolution selection.""" From 4a16a34243f6bdcf34781c5cada32f94ff8d5e0c Mon Sep 17 00:00:00 2001 From: anon Date: Sun, 7 Jun 2026 01:12:39 +0200 Subject: [PATCH 2/2] Simplify: dict comprehension for category index; dedup test ColorLike Follow-up cleanup (no behavioral change): - `_extract_colors_from_table_uns`: the category->index map is built from pandas `.categories` (always unique), so a dict comprehension replaces the `setdefault` loop without changing first-occurrence semantics. - tests/pl/test_utils.py: import the canonical `ColorLike` from render_params instead of keeping a local copy that had drifted to the old narrow form (missing `list[float]`). --- src/spatialdata_plot/pl/utils.py | 9 ++++----- tests/pl/test_utils.py | 7 +------ 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index fd69843f..5ddefdb9 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -1772,11 +1772,10 @@ def _to_hex_no_alpha(color_value: Any) -> str | None: if col_to_colorby in adata.obs and hasattr(adata.obs[col_to_colorby], "cat") else categories ) - # Map each category to its first index in O(1) instead of a per-category - # list scan (was O(K^2) via list.index for K categories). - cat_to_idx: dict[Any, int] = {} - for i, c in enumerate(all_cats): - cat_to_idx.setdefault(c, i) + # Map category -> index once (O(K)) instead of a per-category list scan + # (was O(K^2) via list.index). all_cats comes from pandas .categories, + # which is unique, so a plain dict comprehension is sufficient. + cat_to_idx: dict[Any, int] = {c: i for i, c in enumerate(all_cats)} for category in categories: idx = cat_to_idx.get(category) if idx is not None and idx < len(hex_colors) and hex_colors[idx] is not None: diff --git a/tests/pl/test_utils.py b/tests/pl/test_utils.py index e2113b8d..4a7c8a5b 100644 --- a/tests/pl/test_utils.py +++ b/tests/pl/test_utils.py @@ -12,7 +12,7 @@ from spatialdata.models import PointsModel, ShapesModel, TableModel import spatialdata_plot -from spatialdata_plot.pl.render_params import Color +from spatialdata_plot.pl.render_params import Color, ColorLike from spatialdata_plot.pl.utils import ( _apply_cmap_alpha_to_datashader_result, _datashader_map_aggregate_to_color, @@ -33,11 +33,6 @@ # the comp. function can be accessed as `self.compare(, tolerance=)` # ".png" is appended to , no need to set it -# replace with -# from spatialdata._types import ColorLike -# once https://github.com/scverse/spatialdata/pull/689/ is in a release -ColorLike = tuple[float, ...] | str - class TestUtils(PlotTester, metaclass=PlotTesterMeta): @pytest.mark.parametrize(