diff --git a/src/spatialdata_plot/pl/_color.py b/src/spatialdata_plot/pl/_color.py index 4db897ac..adefb606 100644 --- a/src/spatialdata_plot/pl/_color.py +++ b/src/spatialdata_plot/pl/_color.py @@ -458,6 +458,92 @@ def _extract_color_column( return values.reindex(element.index) +def _resolve_color_origins( + value_to_plot: str | None, + sdata: sd.SpatialData, + element_name: list[str] | str | None, + table_name: str | None, +) -> tuple[list[Any], bool]: + """Locate the color column and resolve df-vs-table shadowing; raise if it lives in >1 place.""" + origins = _locate_value(value_key=value_to_plot, sdata=sdata, element_name=element_name, table_name=table_name) + # An explicit `table_name=` disambiguates a column present in both the element df and the table. + explicit_table_shadows_df = table_name is not None and any(o.origin == "df" for o in origins) + if explicit_table_shadows_df: + origins = [o for o in origins if o.origin != "df"] + if len(origins) > 1: + raise ValueError( + f"Color key '{value_to_plot}' for element '{element_name}' was found in multiple locations: {origins}. " + "Please keep it in exactly one place (preferably on the points parquet for speed) to avoid ambiguity." + ) + return origins, explicit_table_shadows_df + + +def _fetch_color_source_vector( + sdata: sd.SpatialData, + element: SpatialElement | None, + element_name: list[str] | str | None, + value_to_plot: str | None, + table_name: str | None, + table_layer: str | None, + origins: list[Any], + explicit_table_shadows_df: bool, + preloaded_color_data: pd.Series | None, +) -> ArrayLike | pd.Series: + """Read the raw color column, preferring a direct aligned read over a whole-table join.""" + if preloaded_color_data is not None: + return preloaded_color_data + if ( + isinstance(element, GeoDataFrame) + and isinstance(element_name, str) + and table_name is not None + and table_name in sdata.tables + and origins[0].origin in ("obs", "var") + ): + # Fast path: read the single aligned column directly instead of joining/copying the + # whole annotating table (the join's out-of-order sparse row-gather dominates large renders). + return _extract_color_column( + sdata[table_name], + value_to_plot, + origin=origins[0].origin, + element=element, + element_name=element_name, + table_layer=table_layer, + ) + if explicit_table_shadows_df: + # Pass the table as `element` so upstream `get_values` skips the + # element-column lookup and avoids the multi-origin error. + return get_values( + value_key=value_to_plot, + element=sdata[table_name], + element_name=element_name, + table_layer=table_layer, + )[value_to_plot] + return get_values( + value_key=value_to_plot, + sdata=sdata, + element_name=element_name, + table_name=table_name, + table_layer=table_layer, + )[value_to_plot] + + +def _resolve_color_table(value_from_element: bool, table_name: str | None, sdata: sd.SpatialData) -> str | None: + """Pick which table supplies .uns colors: none if the value is element-local, else the named or sole table.""" + if value_from_element: + return None + if table_name is not None: + if table_name in sdata.tables: + return table_name + logger.warning(f"Table '{table_name}' not found in `sdata.tables`. Falling back to default behavior.") + return None + table_keys = list(sdata.tables.keys()) + if not table_keys: + return None + if len(table_keys) > 1: + logger.warning(f"No table name provided, using '{table_keys[0]}' as fallback for color mapping.") + return table_keys[0] + + def _set_color_source_vec( sdata: sd.SpatialData, element: SpatialElement | None, @@ -478,26 +564,7 @@ def _set_color_source_vec( color = np.full(len(element), na_color.get_hex_with_alpha()) return color, color, False - # Figure out where to get the color from - origins = _locate_value( - value_key=value_to_plot, - sdata=sdata, - element_name=element_name, - table_name=table_name, - ) - - # When both the element's own dataframe and the chosen table contain a - # column with this name, an explicit `table_name=` resolves the ambiguity — - # keep only the table origin and skip the multi-origin error below. - explicit_table_shadows_df = table_name is not None and any(o.origin == "df" for o in origins) - if explicit_table_shadows_df: - origins = [o for o in origins if o.origin != "df"] - - if len(origins) > 1: - raise ValueError( - f"Color key '{value_to_plot}' for element '{element_name}' was found in multiple locations: {origins}. " - "Please keep it in exactly one place (preferably on the points parquet for speed) to avoid ambiguity." - ) + origins, explicit_table_shadows_df = _resolve_color_origins(value_to_plot, sdata, element_name, table_name) if len(origins) == 1 and value_to_plot is not None: if table_name is not None: @@ -507,42 +574,17 @@ def _set_color_source_vec( element_name=element_name, table_name=table_name, ) - if preloaded_color_data is not None: - color_source_vector = preloaded_color_data - elif ( - isinstance(element, GeoDataFrame) - and isinstance(element_name, str) - and table_name is not None - and table_name in sdata.tables - and origins[0].origin in ("obs", "var") - ): - # Fast path: read the single aligned column directly instead of joining/copying the - # whole annotating table (the join's out-of-order sparse row-gather dominates large renders). - color_source_vector = _extract_color_column( - sdata[table_name], - value_to_plot, - origin=origins[0].origin, - element=element, - element_name=element_name, - table_layer=table_layer, - ) - elif explicit_table_shadows_df: - # Pass the table as `element` so upstream `get_values` skips the - # element-column lookup and avoids the multi-origin error. - color_source_vector = get_values( - value_key=value_to_plot, - element=sdata[table_name], - element_name=element_name, - table_layer=table_layer, - )[value_to_plot] - else: - color_source_vector = get_values( - value_key=value_to_plot, - sdata=sdata, - element_name=element_name, - table_name=table_name, - table_layer=table_layer, - )[value_to_plot] + color_source_vector = _fetch_color_source_vector( + sdata=sdata, + element=element, + element_name=element_name, + value_to_plot=value_to_plot, + table_name=table_name, + table_layer=table_layer, + origins=origins, + explicit_table_shadows_df=explicit_table_shadows_df, + preloaded_color_data=preloaded_color_data, + ) color_series = ( color_source_vector if isinstance(color_source_vector, pd.Series) else pd.Series(color_source_vector) @@ -588,29 +630,8 @@ def _set_color_source_vec( processed = processed.reorder_categories(sorted(processed.categories)) color_source_vector = processed # convert, e.g., `pd.Series` - # When the value lives on the element's own DataFrame (origin="df"), - # there is no reason to look up a table for .uns colors. value_from_element = origins[0].origin == "df" - - # Use the provided table_name parameter, fall back to only one present - table_to_use: str | None - if value_from_element: - table_to_use = None - elif table_name is not None and table_name in sdata.tables: - table_to_use = table_name - elif table_name is not None and table_name not in sdata.tables: - logger.warning(f"Table '{table_name}' not found in `sdata.tables`. Falling back to default behavior.") - table_to_use = None - else: - table_keys = list(sdata.tables.keys()) - if len(table_keys) == 1: - table_to_use = table_keys[0] - elif len(table_keys) > 1: - table_to_use = table_keys[0] - logger.warning(f"No table name provided, using '{table_to_use}' as fallback for color mapping.") - else: - table_to_use = None - + table_to_use = _resolve_color_table(value_from_element, table_name, sdata) adata_for_mapping = sdata[table_to_use] if table_to_use is not None else None # Check if custom colors exist in the resolved table's .uns slot diff --git a/src/spatialdata_plot/pl/_geometry.py b/src/spatialdata_plot/pl/_geometry.py index 756d31a7..eedb9dda 100644 --- a/src/spatialdata_plot/pl/_geometry.py +++ b/src/spatialdata_plot/pl/_geometry.py @@ -268,6 +268,40 @@ def _validate_polygons(shapes: GeoDataFrame) -> GeoDataFrame: return shapes +def _circle_to_hexagon(center: shapely.Point, radius: float) -> tuple[shapely.Polygon, None]: + verts = [ + ( + center.x + radius * math.cos(math.radians(a)), + center.y + radius * math.sin(math.radians(a)), + ) + for a in range(30, 390, 60) + ] + return shapely.Polygon(verts), None + + +def _circle_to_square(center: shapely.Point, radius: float) -> tuple[shapely.Polygon, None]: + verts = [ + ( + center.x + radius * math.cos(math.radians(a)), + center.y + radius * math.sin(math.radians(a)), + ) + for a in range(45, 360, 90) + ] + return shapely.Polygon(verts), None + + +def _circle_to_circle(center: shapely.Point, radius: float) -> tuple[shapely.Point, float]: + return center, radius + + +def _enclosing_circle(coords: np.ndarray) -> tuple[shapely.Point, float]: + """Enclosing circle from a point cloud: centroid of the convex hull and the max vertex distance.""" + hull_pts = coords[ConvexHull(coords).vertices] + center = np.mean(hull_pts, axis=0) + radius = float(np.max(np.linalg.norm(hull_pts - center, axis=1))) + return shapely.Point(center), radius + + def _convert_shapes( shapes: GeoDataFrame, target_shape: str, @@ -282,67 +316,32 @@ def _convert_shapes( # work on a copy with a clean positional index shapes = shapes.reset_index(drop=True).copy() - def _circle_to_hexagon(center: shapely.Point, radius: float) -> tuple[shapely.Polygon, None]: - verts = [ - ( - center.x + radius * math.cos(math.radians(a)), - center.y + radius * math.sin(math.radians(a)), - ) - for a in range(30, 390, 60) - ] - return shapely.Polygon(verts), None - - def _circle_to_square(center: shapely.Point, radius: float) -> tuple[shapely.Polygon, None]: - verts = [ - ( - center.x + radius * math.cos(math.radians(a)), - center.y + radius * math.sin(math.radians(a)), - ) - for a in range(45, 360, 90) - ] - return shapely.Polygon(verts), None - - def _circle_to_circle(center: shapely.Point, radius: float) -> tuple[shapely.Point, float]: - return center, radius - def _polygon_to_circle(polygon: shapely.Polygon) -> tuple[shapely.Point, float]: - coords = np.array(polygon.exterior.coords) - hull_pts = coords[ConvexHull(coords).vertices] - center = np.mean(hull_pts, axis=0) - radius = float(np.max(np.linalg.norm(hull_pts - center, axis=1))) + center, radius = _enclosing_circle(np.array(polygon.exterior.coords)) nonlocal warn_shape_size if 2 * radius > max_extent * warn_above_extent_fraction: warn_shape_size = True - return shapely.Point(center), radius + return center, radius def _polygon_to_hexagon(polygon: shapely.Polygon) -> tuple[shapely.Polygon, None]: - c, r = _polygon_to_circle(polygon) - return _circle_to_hexagon(c, r) + return _circle_to_hexagon(*_polygon_to_circle(polygon)) def _polygon_to_square(polygon: shapely.Polygon) -> tuple[shapely.Polygon, None]: - c, r = _polygon_to_circle(polygon) - return _circle_to_square(c, r) + return _circle_to_square(*_polygon_to_circle(polygon)) def _multipolygon_to_circle(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Point, float]: - pts = [] - for poly in multipolygon.geoms: - pts.extend(poly.exterior.coords) - pts_array = np.array(pts) - hull_pts = pts_array[ConvexHull(pts_array).vertices] - center = np.mean(hull_pts, axis=0) - radius = float(np.max(np.linalg.norm(hull_pts - center, axis=1))) + coords = np.array([pt for poly in multipolygon.geoms for pt in poly.exterior.coords]) + center, radius = _enclosing_circle(coords) nonlocal warn_shape_size if 2 * radius > max_extent * warn_above_extent_fraction: warn_shape_size = True - return shapely.Point(center), radius + return center, radius def _multipolygon_to_hexagon(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Polygon, None]: - c, r = _multipolygon_to_circle(multipolygon) - return _circle_to_hexagon(c, r) + return _circle_to_hexagon(*_multipolygon_to_circle(multipolygon)) def _multipolygon_to_square(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Polygon, None]: - c, r = _multipolygon_to_circle(multipolygon) - return _circle_to_square(c, r) + return _circle_to_square(*_multipolygon_to_circle(multipolygon)) # choose conversion methods conversion_methods: dict[str, Any] diff --git a/src/spatialdata_plot/pl/_validate.py b/src/spatialdata_plot/pl/_validate.py index ad28a613..61e8f639 100644 --- a/src/spatialdata_plot/pl/_validate.py +++ b/src/spatialdata_plot/pl/_validate.py @@ -368,7 +368,7 @@ def _validate_col_for_column_table( return col_for_color, table_name -def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[str, Any]: +def _check_colorbar(param_dict: dict[str, Any]) -> None: colorbar = param_dict.get("colorbar", "auto") if colorbar not in {True, False, None, "auto"}: raise TypeError("Parameter 'colorbar' must be one of True, False or 'auto'.") @@ -377,6 +377,8 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st if colorbar_params is not None and not isinstance(colorbar_params, dict): raise TypeError("Parameter 'colorbar_params' must be a dictionary or None.") + +def _check_element(param_dict: dict[str, Any], element_type: str) -> None: element = param_dict.get("element") if element is not None and not isinstance(element, str): raise ValueError( @@ -392,6 +394,8 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st elif element_type == "shapes": param_dict["element"] = [element] if element is not None else list(param_dict["sdata"].shapes.keys()) + +def _check_channel(param_dict: dict[str, Any]) -> None: channel = param_dict.get("channel") if channel is not None and not isinstance(channel, list | str | int): raise TypeError("Parameter 'channel' must be a string, an integer, or a list of strings or integers.") @@ -404,10 +408,14 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st elif "channel" in param_dict: param_dict["channel"] = [channel] if channel is not None else None + +def _check_contour_px_type(param_dict: dict[str, Any]) -> None: contour_px = param_dict.get("contour_px") if contour_px and not isinstance(contour_px, int): raise TypeError("Parameter 'contour_px' must be an integer.") + +def _check_color(param_dict: dict[str, Any], element_type: str) -> None: color = param_dict.get("color") if color and element_type in { "shapes", @@ -440,6 +448,8 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st elif "color" in param_dict and element_type != "images": param_dict["col_for_color"] = None + +def _check_outline(param_dict: dict[str, Any], element_type: str) -> None: outline_width = param_dict.get("outline_width") if outline_width: # outline_width only exists for shapes at the moment @@ -520,12 +530,17 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st else: param_dict["outline_color"] = Color(outline_color) + +def _check_contour_px_range(param_dict: dict[str, Any]) -> None: + contour_px = param_dict.get("contour_px") if contour_px is not None and contour_px < 2: raise ValueError( "Parameter 'contour_px' must be >= 2; values below 2 produce no visible outline " "(a 1x1 erosion is the identity transformation)." ) + +def _check_alpha(param_dict: dict[str, Any], element_type: str) -> None: alpha = param_dict.get("alpha") if alpha is not None: if not isinstance(alpha, float | int): @@ -536,6 +551,8 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st # set default alpha for points if not given by user explicitly or implicitly (as part of color) param_dict["alpha"] = 1.0 + +def _check_fill_alpha(param_dict: dict[str, Any], element_type: str) -> None: fill_alpha = param_dict.get("fill_alpha") if fill_alpha is not None: if not isinstance(fill_alpha, float | int): @@ -549,6 +566,8 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st # set default fill_alpha for labels if not given by user explicitly or implicitly (as part of color) param_dict["fill_alpha"] = 0.4 + +def _check_cmap_palette_groups(param_dict: dict[str, Any], element_type: str) -> None: cmap = param_dict.get("cmap") palette = param_dict.get("palette") if cmap is not None and palette is not None: @@ -597,10 +616,14 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st else: raise TypeError("Parameter 'cmap' must be a string, a Colormap, or a list of these types.") + +def _check_na_color(param_dict: dict[str, Any]) -> None: # validation happens within Color constructor (images don't use na_color) if "na_color" in param_dict: param_dict["na_color"] = Color(param_dict.get("na_color")) + +def _check_norm(param_dict: dict[str, Any], element_type: str) -> None: norm = param_dict.get("norm") if norm is not None: if element_type == "images": @@ -618,6 +641,8 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st if element_type == "graph" and not isinstance(norm, Normalize): raise TypeError("Parameter 'norm' must be a Normalize instance.") + +def _check_scale(param_dict: dict[str, Any], element_type: str) -> None: scale = param_dict.get("scale") if scale is not None: if element_type in {"images", "labels"} and not isinstance(scale, str): @@ -628,6 +653,8 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st if scale < 0: raise ValueError("Parameter 'scale' must be a positive number.") + +def _check_size(param_dict: dict[str, Any]) -> None: size = param_dict.get("size") if size: if not isinstance(size, float | int): @@ -635,6 +662,8 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st if size < 0: raise ValueError("Parameter 'size' must be a positive number.") + +def _check_shape(param_dict: dict[str, Any], element_type: str) -> None: shape = param_dict.get("shape") if element_type == "shapes" and shape is not None: valid_shapes = {"circle", "hex", "visium_hex", "square"} @@ -643,6 +672,8 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st if shape not in valid_shapes: raise ValueError(f"'{shape}' is not supported for 'shape', please choose from {valid_shapes}.") + +def _check_table(param_dict: dict[str, Any]) -> None: table_name = param_dict.get("table_name") table_layer = param_dict.get("table_layer") if table_name and not isinstance(param_dict["table_name"], str): @@ -690,10 +721,14 @@ def _ensure_table_and_layer_exist_in_sdata( _ensure_table_and_layer_exist_in_sdata(param_dict.get("sdata"), table_name, table_layer) + +def _check_method(param_dict: dict[str, Any]) -> None: method = param_dict.get("method") if method not in ["matplotlib", "datashader", None]: raise ValueError("If specified, parameter 'method' must be either 'matplotlib' or 'datashader'.") + +def _check_ds_reduction(param_dict: dict[str, Any]) -> None: valid_ds_reduction_methods = [ "sum", "mean", @@ -710,6 +745,8 @@ def _ensure_table_and_layer_exist_in_sdata( if ds_reduction and (ds_reduction not in valid_ds_reduction_methods): raise ValueError(f"Parameter 'ds_reduction' must be one of the following: {valid_ds_reduction_methods}.") + +def _check_graph_params(param_dict: dict[str, Any], element_type: str) -> None: if element_type == "graph": for key in ("connectivity_key",): val = param_dict.get(key) @@ -739,6 +776,28 @@ def _ensure_table_and_layer_exist_in_sdata( if val is not None and not isinstance(val, bool): raise TypeError(f"Parameter '{key}' must be a boolean.") + +def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[str, Any]: + # Call order is the contract: the first-raised error and in-place mutations must match the pre-split form. + _check_colorbar(param_dict) + _check_element(param_dict, element_type) + _check_channel(param_dict) + _check_contour_px_type(param_dict) + _check_color(param_dict, element_type) + _check_outline(param_dict, element_type) + _check_contour_px_range(param_dict) # must stay after outline (preserves first-error order) + _check_alpha(param_dict, element_type) + _check_fill_alpha(param_dict, element_type) + _check_cmap_palette_groups(param_dict, element_type) + _check_na_color(param_dict) + _check_norm(param_dict, element_type) + _check_scale(param_dict, element_type) + _check_size(param_dict) + _check_shape(param_dict, element_type) + _check_table(param_dict) + _check_method(param_dict) + _check_ds_reduction(param_dict) + _check_graph_params(param_dict, element_type) return param_dict diff --git a/tests/pl/test_utils.py b/tests/pl/test_utils.py index def6d94c..6efc5452 100644 --- a/tests/pl/test_utils.py +++ b/tests/pl/test_utils.py @@ -273,6 +273,21 @@ def test_extract_scalar_value(): assert _extract_scalar_value([], default=1.0) == 1.0 +def test_type_check_params_preserves_validation_order(): + """#716: decomposition must preserve call order, so the first error for a multiply-invalid input is + unchanged -- `color` is validated before the `contour_px < 2` range check. + """ + from spatialdata_plot.pl._validate import _type_check_params + + # color=5 raises in the color block, before the contour_px range check -> color error wins. + with pytest.raises(TypeError, match="Parameter 'color' must be a string or a tuple/list of floats."): + _type_check_params({"element": "x", "color": 5, "contour_px": 1, "sdata": None}, "labels") + + # a valid RGB tuple (no sdata-dependent collision check) lets the same input reach contour_px. + with pytest.raises(ValueError, match="Parameter 'contour_px' must be >= 2"): + _type_check_params({"element": "x", "color": (1.0, 0.0, 0.0), "contour_px": 1, "sdata": None}, "labels") + + def test_plot_can_handle_rgba_color_specifications(sdata_blobs: SpatialData): """Test handling of RGBA color specifications.""" # Test with RGBA tuple