Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 99 additions & 78 deletions src/spatialdata_plot/pl/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,92 @@ def _extract_color_column(
return values.reindex(element.index)


def _resolve_color_origins(
    value_to_plot: str | None,
    sdata: sd.SpatialData,
    element_name: list[str] | str | None,
    table_name: str | None,
) -> tuple[list[Any], bool]:
    """Find where the color key is stored and settle element-vs-table name clashes.

    The key is looked up via ``_locate_value``. If the caller passed an explicit
    ``table_name`` and one of the hits has origin ``"df"`` (the element's own
    dataframe), the ``"df"`` hits are discarded so the table wins.

    Returns
    -------
    A pair ``(origins, explicit_table_shadows_df)``: the remaining hits and a flag
    telling whether ``"df"`` hits were dropped in favour of the explicit table.

    Raises
    ------
    ValueError
        If more than one origin is left after the shadowing rule was applied.
    """
    located = _locate_value(
        value_key=value_to_plot,
        sdata=sdata,
        element_name=element_name,
        table_name=table_name,
    )

    # Only an explicitly named table may override a same-named element column.
    shadowed = False
    if table_name is not None:
        shadowed = any(entry.origin == "df" for entry in located)
    if shadowed:
        located = [entry for entry in located if entry.origin != "df"]

    if len(located) <= 1:
        return located, shadowed

    raise ValueError(
        f"Color key '{value_to_plot}' for element '{element_name}' was found in multiple locations: {located}. "
        "Please keep it in exactly one place (preferably on the points parquet for speed) to avoid ambiguity."
    )


def _fetch_color_source_vector(
sdata: sd.SpatialData,
element: SpatialElement | None,
element_name: list[str] | str | None,
value_to_plot: str | None,
table_name: str | None,
table_layer: str | None,
origins: list[Any],
explicit_table_shadows_df: bool,
preloaded_color_data: pd.Series | None,
) -> ArrayLike | pd.Series:
"""Read the raw color column, preferring a direct aligned read over a whole-table join."""
if preloaded_color_data is not None:
return preloaded_color_data
if (
isinstance(element, GeoDataFrame)
and isinstance(element_name, str)
and table_name is not None
and table_name in sdata.tables
and origins[0].origin in ("obs", "var")
):
# Fast path: read the single aligned column directly instead of joining/copying the
# whole annotating table (the join's out-of-order sparse row-gather dominates large renders).
return _extract_color_column(
sdata[table_name],
value_to_plot,
origin=origins[0].origin,
element=element,
element_name=element_name,
table_layer=table_layer,
)
if explicit_table_shadows_df:
# Pass the table as `element` so upstream `get_values` skips the
# element-column lookup and avoids the multi-origin error.
return get_values(
value_key=value_to_plot,
element=sdata[table_name],
element_name=element_name,
table_layer=table_layer,
)[value_to_plot]
return get_values(
value_key=value_to_plot,
sdata=sdata,
element_name=element_name,
table_name=table_name,
table_layer=table_layer,
)[value_to_plot]


def _resolve_color_table(value_from_element: bool, table_name: str | None, sdata: sd.SpatialData) -> str | None:
"""Pick which table supplies .uns colors: none if the value is element-local, else the named or sole table."""
if value_from_element:
return None
if table_name is not None:
if table_name in sdata.tables:
return table_name
logger.warning(f"Table '{table_name}' not found in `sdata.tables`. Falling back to default behavior.")
return None
table_keys = list(sdata.tables.keys())
if not table_keys:
return None
if len(table_keys) > 1:
logger.warning(f"No table name provided, using '{table_keys[0]}' as fallback for color mapping.")
return table_keys[0]


def _set_color_source_vec(
sdata: sd.SpatialData,
element: SpatialElement | None,
Expand All @@ -478,26 +564,7 @@ def _set_color_source_vec(
color = np.full(len(element), na_color.get_hex_with_alpha())
return color, color, False

# Figure out where to get the color from
origins = _locate_value(
value_key=value_to_plot,
sdata=sdata,
element_name=element_name,
table_name=table_name,
)

# When both the element's own dataframe and the chosen table contain a
# column with this name, an explicit `table_name=` resolves the ambiguity —
# keep only the table origin and skip the multi-origin error below.
explicit_table_shadows_df = table_name is not None and any(o.origin == "df" for o in origins)
if explicit_table_shadows_df:
origins = [o for o in origins if o.origin != "df"]

if len(origins) > 1:
raise ValueError(
f"Color key '{value_to_plot}' for element '{element_name}' was found in multiple locations: {origins}. "
"Please keep it in exactly one place (preferably on the points parquet for speed) to avoid ambiguity."
)
origins, explicit_table_shadows_df = _resolve_color_origins(value_to_plot, sdata, element_name, table_name)

if len(origins) == 1 and value_to_plot is not None:
if table_name is not None:
Expand All @@ -507,42 +574,17 @@ def _set_color_source_vec(
element_name=element_name,
table_name=table_name,
)
if preloaded_color_data is not None:
color_source_vector = preloaded_color_data
elif (
isinstance(element, GeoDataFrame)
and isinstance(element_name, str)
and table_name is not None
and table_name in sdata.tables
and origins[0].origin in ("obs", "var")
):
# Fast path: read the single aligned column directly instead of joining/copying the
# whole annotating table (the join's out-of-order sparse row-gather dominates large renders).
color_source_vector = _extract_color_column(
sdata[table_name],
value_to_plot,
origin=origins[0].origin,
element=element,
element_name=element_name,
table_layer=table_layer,
)
elif explicit_table_shadows_df:
# Pass the table as `element` so upstream `get_values` skips the
# element-column lookup and avoids the multi-origin error.
color_source_vector = get_values(
value_key=value_to_plot,
element=sdata[table_name],
element_name=element_name,
table_layer=table_layer,
)[value_to_plot]
else:
color_source_vector = get_values(
value_key=value_to_plot,
sdata=sdata,
element_name=element_name,
table_name=table_name,
table_layer=table_layer,
)[value_to_plot]
color_source_vector = _fetch_color_source_vector(
sdata=sdata,
element=element,
element_name=element_name,
value_to_plot=value_to_plot,
table_name=table_name,
table_layer=table_layer,
origins=origins,
explicit_table_shadows_df=explicit_table_shadows_df,
preloaded_color_data=preloaded_color_data,
)

color_series = (
color_source_vector if isinstance(color_source_vector, pd.Series) else pd.Series(color_source_vector)
Expand Down Expand Up @@ -588,29 +630,8 @@ def _set_color_source_vec(
processed = processed.reorder_categories(sorted(processed.categories))
color_source_vector = processed # convert, e.g., `pd.Series`

# When the value lives on the element's own DataFrame (origin="df"),
# there is no reason to look up a table for .uns colors.
value_from_element = origins[0].origin == "df"

# Use the provided table_name parameter, fall back to only one present
table_to_use: str | None
if value_from_element:
table_to_use = None
elif table_name is not None and table_name in sdata.tables:
table_to_use = table_name
elif table_name is not None and table_name not in sdata.tables:
logger.warning(f"Table '{table_name}' not found in `sdata.tables`. Falling back to default behavior.")
table_to_use = None
else:
table_keys = list(sdata.tables.keys())
if len(table_keys) == 1:
table_to_use = table_keys[0]
elif len(table_keys) > 1:
table_to_use = table_keys[0]
logger.warning(f"No table name provided, using '{table_to_use}' as fallback for color mapping.")
else:
table_to_use = None

table_to_use = _resolve_color_table(value_from_element, table_name, sdata)
adata_for_mapping = sdata[table_to_use] if table_to_use is not None else None

# Check if custom colors exist in the resolved table's .uns slot
Expand Down
87 changes: 43 additions & 44 deletions src/spatialdata_plot/pl/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,40 @@ def _validate_polygons(shapes: GeoDataFrame) -> GeoDataFrame:
return shapes


def _circle_to_hexagon(center: shapely.Point, radius: float) -> tuple[shapely.Polygon, None]:
    """Build a hexagon whose six vertices lie on the circle around ``center``.

    Vertices are placed at 30°, 90°, ..., 330° at distance ``radius`` from the
    center. The second tuple element is always ``None``.
    """
    corners = []
    for degrees in range(30, 390, 60):
        corners.append(
            (
                center.x + radius * math.cos(math.radians(degrees)),
                center.y + radius * math.sin(math.radians(degrees)),
            )
        )
    hexagon = shapely.Polygon(corners)
    return hexagon, None


def _circle_to_square(center: shapely.Point, radius: float) -> tuple[shapely.Polygon, None]:
    """Build a square whose four vertices lie on the circle around ``center``.

    Vertices are placed at 45°, 135°, 225° and 315° at distance ``radius`` from
    the center. The second tuple element is always ``None``.
    """
    corners = []
    for degrees in range(45, 360, 90):
        corners.append(
            (
                center.x + radius * math.cos(math.radians(degrees)),
                center.y + radius * math.sin(math.radians(degrees)),
            )
        )
    square = shapely.Polygon(corners)
    return square, None


def _circle_to_circle(center: shapely.Point, radius: float) -> tuple[shapely.Point, float]:
return center, radius


def _enclosing_circle(coords: np.ndarray) -> tuple[shapely.Point, float]:
    """Compute a circle around a point cloud from its convex hull.

    The center is the arithmetic mean of the convex-hull vertices and the radius is
    the largest distance from that center to any hull vertex.
    """
    hull = ConvexHull(coords)
    vertices = coords[hull.vertices]
    # Mean of the hull vertices only; interior points do not influence the center.
    midpoint = vertices.mean(axis=0)
    distances = np.linalg.norm(vertices - midpoint, axis=1)
    return shapely.Point(midpoint), float(distances.max())


def _convert_shapes(
shapes: GeoDataFrame,
target_shape: str,
Expand All @@ -282,67 +316,32 @@ def _convert_shapes(
# work on a copy with a clean positional index
shapes = shapes.reset_index(drop=True).copy()

def _circle_to_hexagon(center: shapely.Point, radius: float) -> tuple[shapely.Polygon, None]:
verts = [
(
center.x + radius * math.cos(math.radians(a)),
center.y + radius * math.sin(math.radians(a)),
)
for a in range(30, 390, 60)
]
return shapely.Polygon(verts), None

def _circle_to_square(center: shapely.Point, radius: float) -> tuple[shapely.Polygon, None]:
verts = [
(
center.x + radius * math.cos(math.radians(a)),
center.y + radius * math.sin(math.radians(a)),
)
for a in range(45, 360, 90)
]
return shapely.Polygon(verts), None

def _circle_to_circle(center: shapely.Point, radius: float) -> tuple[shapely.Point, float]:
return center, radius

def _polygon_to_circle(polygon: shapely.Polygon) -> tuple[shapely.Point, float]:
coords = np.array(polygon.exterior.coords)
hull_pts = coords[ConvexHull(coords).vertices]
center = np.mean(hull_pts, axis=0)
radius = float(np.max(np.linalg.norm(hull_pts - center, axis=1)))
center, radius = _enclosing_circle(np.array(polygon.exterior.coords))
nonlocal warn_shape_size
if 2 * radius > max_extent * warn_above_extent_fraction:
warn_shape_size = True
return shapely.Point(center), radius
return center, radius

def _polygon_to_hexagon(polygon: shapely.Polygon) -> tuple[shapely.Polygon, None]:
c, r = _polygon_to_circle(polygon)
return _circle_to_hexagon(c, r)
return _circle_to_hexagon(*_polygon_to_circle(polygon))

def _polygon_to_square(polygon: shapely.Polygon) -> tuple[shapely.Polygon, None]:
c, r = _polygon_to_circle(polygon)
return _circle_to_square(c, r)
return _circle_to_square(*_polygon_to_circle(polygon))

def _multipolygon_to_circle(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Point, float]:
pts = []
for poly in multipolygon.geoms:
pts.extend(poly.exterior.coords)
pts_array = np.array(pts)
hull_pts = pts_array[ConvexHull(pts_array).vertices]
center = np.mean(hull_pts, axis=0)
radius = float(np.max(np.linalg.norm(hull_pts - center, axis=1)))
coords = np.array([pt for poly in multipolygon.geoms for pt in poly.exterior.coords])
center, radius = _enclosing_circle(coords)
nonlocal warn_shape_size
if 2 * radius > max_extent * warn_above_extent_fraction:
warn_shape_size = True
return shapely.Point(center), radius
return center, radius

def _multipolygon_to_hexagon(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Polygon, None]:
c, r = _multipolygon_to_circle(multipolygon)
return _circle_to_hexagon(c, r)
return _circle_to_hexagon(*_multipolygon_to_circle(multipolygon))

def _multipolygon_to_square(multipolygon: shapely.MultiPolygon) -> tuple[shapely.Polygon, None]:
c, r = _multipolygon_to_circle(multipolygon)
return _circle_to_square(c, r)
return _circle_to_square(*_multipolygon_to_circle(multipolygon))

# choose conversion methods
conversion_methods: dict[str, Any]
Expand Down
Loading
Loading