diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index 236eaf37..f864482c 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -46,6 +46,7 @@ ChannelLegendEntry, CmapParams, ColorbarSpec, + ColorLike, GraphRenderParams, ImageRenderParams, LabelsRenderParams, @@ -80,11 +81,6 @@ save_fig, ) -# replace with -# from spatialdata._types import ColorLike -# once https://github.com/scverse/spatialdata/pull/689/ is in a release -ColorLike = tuple[float, ...] | list[float] | str - @register_spatial_data_accessor("pl") class PlotAccessor: diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 430d5635..9ee64033 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -87,6 +87,16 @@ _Normalize = Normalize | abc.Sequence[Normalize] +# Shared body of the "blending multiple cmaps" warning. Emitted both when the user +# supplies several cmaps and when a single cmap is broadcast across channels. +_MULTI_CMAP_BLENDING_WARNING = ( + "You're blending multiple cmaps. " + "If the plot doesn't look like you expect, it might be because your " + "cmaps go from a given color to 'white', and not to 'transparent'. " + "Therefore, the 'white' of higher layers will overlay the lower layers. " + "Consider using 'palette' instead." +) + def _get_top_data_array(element: xr.DataArray | DataTree) -> xr.DataArray: if isinstance(element, DataTree): @@ -1136,9 +1146,9 @@ def _render_points( # if the points are colored by values in X (or a different layer), add the values to obs if col_for_color in matched_table.var_names: if table_layer is None: - adata_obs[col_for_color] = matched_table[:, col_for_color].X.flatten().copy() + adata_obs[col_for_color] = matched_table[:, col_for_color].X.flatten() else: - adata_obs[col_for_color] = matched_table[:, col_for_color].layers[table_layer].flatten().copy() + adata_obs[col_for_color] = matched_table[:, col_for_color].layers[table_layer].flatten() adata = AnnData( X=points[["x", "y"]].values, obs=adata_obs, @@ -1742,13 +1752,7 @@ def _render_images( user_supplied_multi_cmaps = False if user_supplied_multi_cmaps: - logger.warning( - "You're blending multiple cmaps. " - "If the plot doesn't look like you expect, it might be because your " - "cmaps go from a given color to 'white', and not to 'transparent'. " - "Therefore, the 'white' of higher layers will overlay the lower layers. " - "Consider using 'palette' instead." - ) + logger.warning(_MULTI_CMAP_BLENDING_WARNING) # Force nearest-neighbor at display time when the datashader reduction picked # a non-mean aggregation; otherwise imshow's default interpolation would smear it. @@ -1864,7 +1868,9 @@ def _render_images( ) layers = {} for ch_idx, ch in enumerate(channels): - layers[ch] = img.sel(c=ch).copy(deep=True).squeeze() + # No copy needed: this entry is only read (min/max) and then replaced + # by a fresh array (np.full or ch_norm(...)) below; img is never mutated. + layers[ch] = img.sel(c=ch).squeeze() if isinstance(render_params.cmap_params, list): ch_norm = render_params.cmap_params[ch_idx].norm else: @@ -1904,11 +1910,7 @@ def _render_images( stacked = stacked[:, :, :3] logger.warning( "One cmap was given for multiple channels and is now used for each channel. " - "You're blending multiple cmaps. " - "If the plot doesn't look like you expect, it might be because your " - "cmaps go from a given color to 'white', and not to 'transparent'. " - "Therefore, the 'white' of higher layers will overlay the lower layers. " - "Consider using 'palette' instead." + + _MULTI_CMAP_BLENDING_WARNING ) _ax_show_and_transform( diff --git a/src/spatialdata_plot/pl/render_params.py b/src/spatialdata_plot/pl/render_params.py index acf03e3a..b021c305 100644 --- a/src/spatialdata_plot/pl/render_params.py +++ b/src/spatialdata_plot/pl/render_params.py @@ -15,10 +15,11 @@ _DsReduction = Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] _ImageDsReduction = Literal["max", "min", "mean", "mode", "first", "last", "var", "std"] +# Canonical definition for the package; imported by basic.py and utils.py. # replace with # from spatialdata._types import ColorLike # once https://github.com/scverse/spatialdata/pull/689/ is in a release -ColorLike = tuple[float, ...] | str +ColorLike = tuple[float, ...] | list[float] | str # NOTE: defined here instead of utils to avoid circular import diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 4c96e051..5ddefdb9 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -4,7 +4,7 @@ import os import warnings from collections import Counter, OrderedDict -from collections.abc import Callable, Iterable, Mapping, Sequence +from collections.abc import Callable, Mapping, Sequence from copy import copy from functools import partial from pathlib import Path @@ -78,6 +78,7 @@ CmapParams, Color, ColorbarSpec, + ColorLike, FigParams, GraphRenderParams, ImageRenderParams, @@ -93,11 +94,6 @@ to_hex = partial(colors.to_hex, keep_alpha=True) -# replace with -# from spatialdata._types import ColorLike -# once https://github.com/scverse/spatialdata/pull/689/ is in a release -ColorLike = tuple[float, ...] | list[float] | str - _GROUPS_IGNORED_WARNING = "Parameter 'groups' is ignored when 'color' is a literal color, not a column name." _RENDER_CMD_TO_CS_FLAG: dict[str, str] = { @@ -996,44 +992,6 @@ def _set_outline( ) -def _get_subplots(num_images: int, ncols: int = 4, width: int = 4, height: int = 3) -> plt.Figure | plt.Axes: - """Set up the axs objects. - - Parameters - ---------- - num_images - Number of images to plot. Must be greater than 1. - ncols - Number of columns in the subplot grid, by default 4 - width - Width of each subplot, by default 4 - - Returns - ------- - Union[plt.Figure, plt.Axes] - Matplotlib figure and axes object. - """ - if num_images < ncols: - nrows = 1 - ncols = num_images - else: - nrows, reminder = divmod(num_images, ncols) - - if nrows == 0: - nrows = 1 - if reminder > 0: - nrows += 1 - - fig, axes = plt.subplots(nrows, ncols, figsize=(width * ncols, height * nrows)) - - if not isinstance(axes, Iterable): - axes = np.array([axes]) - - # get rid of the empty axes - _ = [ax.axis("off") for ax in axes.flatten()[num_images:]] - return fig, axes - - def _get_colors_for_categorical_obs( categories: Sequence[str | int], palette: ListedColormap | str | list[str] | None = None, @@ -1503,7 +1461,7 @@ def _map_color_seg( if isinstance(color_vector.dtype, pd.CategoricalDtype): # Case A: users wants to plot a categorical column - val_im: ArrayLike = map_array(seg.copy(), cell_id, color_vector.codes + 1) + val_im: ArrayLike = map_array(seg, cell_id, color_vector.codes + 1) cols = colors.to_rgba_array(color_vector.categories) elif pd.api.types.is_numeric_dtype(color_vector.dtype): # Case B: user wants to plot a continous column @@ -1515,7 +1473,7 @@ def _map_color_seg( normed_color_vector[~np.isnan(normed_color_vector)] ) cols = cmap_params.cmap(normed_color_vector) - val_im = map_array(seg.copy(), cell_id, cell_id) + val_im = map_array(seg, cell_id, cell_id) else: # Case C: User didn't specify any colors if color_source_vector is not None and ( @@ -1524,12 +1482,12 @@ def _map_color_seg( and set(color_vector) == {na_color.get_hex_with_alpha()} and not na_color.color_modified_by_user() ): - val_im = map_array(seg.copy(), cell_id, cell_id) + val_im = map_array(seg, cell_id, cell_id) RNG = default_rng(42) cols = RNG.random((len(color_vector), 3)) else: # Case D: User didn't specify a column to color by, but modified the na_color - val_im = map_array(seg.copy(), cell_id, cell_id) + val_im = map_array(seg, cell_id, cell_id) first_value = color_vector.iloc[0] if isinstance(color_vector, pd.Series) else color_vector[0] if _is_color_like(first_value): # we have color-like values (e.g., hex or named colors) @@ -1550,7 +1508,7 @@ def _map_color_seg( if outline_color_source_vector is not None: cat = pd.Categorical(outline_color_source_vector) cat_codes = cat.codes - outline_val_im: ArrayLike = map_array(seg.copy(), cell_id, cat_codes + 1) + outline_val_im: ArrayLike = map_array(seg, cell_id, cat_codes + 1) color_arr = np.asarray(outline_color_vector, dtype=object) # Pick the first per-cell hex for each category in one vectorized pass # (avoids `K × O(N)` Python loops on large label sets). @@ -1572,7 +1530,7 @@ def _map_color_seg( if finite.any(): normed[finite] = cmap_params.norm(normed[finite]) outline_cols = cmap_params.cmap(normed) - outline_val_im = map_array(seg.copy(), cell_id, cell_id) + outline_val_im = map_array(seg, cell_id, cell_id) if seg_erosionpx is not None: outline_val_im[ outline_val_im == erosion(outline_val_im, footprint_rectangle((seg_erosionpx, seg_erosionpx))) @@ -1814,8 +1772,12 @@ def _to_hex_no_alpha(color_value: Any) -> str | None: if col_to_colorby in adata.obs and hasattr(adata.obs[col_to_colorby], "cat") else categories ) + # Map category -> index once (O(K)) instead of a per-category list scan + # (was O(K^2) via list.index). all_cats comes from pandas .categories, + # which is unique, so a plain dict comprehension is sufficient. + cat_to_idx: dict[Any, int] = {c: i for i, c in enumerate(all_cats)} for category in categories: - idx = all_cats.index(category) if category in all_cats else None + idx = cat_to_idx.get(category) if idx is not None and idx < len(hex_colors) and hex_colors[idx] is not None: hex_color = hex_colors[idx] assert hex_color is not None # type narrowing for mypy diff --git a/tests/pl/test_utils.py b/tests/pl/test_utils.py index ab279d45..4a7c8a5b 100644 --- a/tests/pl/test_utils.py +++ b/tests/pl/test_utils.py @@ -12,11 +12,10 @@ from spatialdata.models import PointsModel, ShapesModel, TableModel import spatialdata_plot -from spatialdata_plot.pl.render_params import Color +from spatialdata_plot.pl.render_params import Color, ColorLike from spatialdata_plot.pl.utils import ( _apply_cmap_alpha_to_datashader_result, _datashader_map_aggregate_to_color, - _get_subplots, _set_outline, set_zero_in_cmap_to_transparent, ) @@ -34,11 +33,6 @@ # the comp. function can be accessed as `self.compare(, tolerance=)` # ".png" is appended to , no need to set it -# replace with -# from spatialdata._types import ColorLike -# once https://github.com/scverse/spatialdata/pull/689/ is in a release -ColorLike = tuple[float, ...] | str - class TestUtils(PlotTester, metaclass=PlotTesterMeta): @pytest.mark.parametrize( @@ -290,23 +284,6 @@ def test_plot_can_handle_rgba_color_specifications(sdata_blobs: SpatialData): sdata_blobs.pl.render_shapes(element="blobs_circles", color="blue").pl.show() -@pytest.mark.parametrize( - "input_output", - [ - (1, 4, 1, [True]), - (4, 4, 4, [True, True, True, True]), - (6, 4, 8, [True, True, True, True, True, True, False, False]), # 2 rows with 4 columns - ], -) -def test_utils_get_subplots_produces_correct_axs_layout(input_output): - num_images, ncols, len_axs, axs_visible = input_output - - _, axs = _get_subplots(num_images=num_images, ncols=ncols) - - assert len_axs == len(axs.flatten()) - assert axs_visible == [ax.axison for ax in axs.flatten()] - - class TestMultiscaleToSpatialImage: """Regression tests for #589: multiscale resolution selection."""