import squidpy as sq
import matplotlib.pyplot as plt
import pandas as pd
import anndata
from spac.utils import (
check_annotation,
check_table,
check_distances,
check_label,
)
import numpy as np
from scipy.spatial import KDTree
from scipy.spatial import distance_matrix
from sklearn.preprocessing import LabelEncoder
from functools import partial
import logging
[docs]
def spatial_interaction(
adata,
annotation,
analysis_method,
stratify_by=None,
ax=None,
return_matrix=False,
seed=None,
coord_type=None,
n_rings=1,
n_neighs=6,
radius=None,
cmap="seismic",
**kwargs):
"""
Perform spatial analysis on the selected annotation in the dataset.
Current analysis methods are provided in squidpy:
Neighborhood Enrichment,
Cluster Interaction Matrix
Parameters
----------
adata : anndata.AnnData
The AnnData object.
annotation : str
The column name of the annotation (e.g., phenotypes) to analyze in the
provided dataset.
analysis_method : str
The analysis method to use, currently available:
"Neighborhood Enrichment" and "Cluster Interaction Matrix".
stratify_by : str or list of strs
The annotation[s] to stratify the dataset when generating
interaction plots. If single annotation is passed, the dataset
will be stratified by the unique labels in the annotation column.
If n (n>=2) annotations are passed, the function will be stratified
based on existing combination of labels in the passed annotations.
ax: matplotlib.axes.Axes, default None
The matplotlib Axes to display the image. This option is only
available when stratify is None.
return_matrix: boolean, default False
If true, the fucntion will return a list of two dictionaries,
the first contains axes and the second containing computed matrix.
Note that for Neighborhood Encrichment, the matrix will be a tuple
with the z-score and the enrichment count.
For Cluster Interaction Matrix, it will returns the
interaction matrix.
If False, the function will return only the axes dictionary.
seed: int, default None
Random seed for reproducibility, used in Neighborhood Enrichment
Analysis.
coord_type : str, optional
Type of coordinate system used in sq.gr.spatial_neighbors.
Should be either 'grid' (Visium Data) or 'generic' (Others).
Default is None, decided by the squidy pacakge. If spatial_key
is in anndata.uns the coord_type would be 'grid', otherwise
general.
n_rings : int, default 1
Number of rings of neighbors for grid data.
Only used when coord_type = 'grid' (Visium)
n_neights : int, optional
Default is 6.
Depending on the ``coord_type``:
- 'grid' (Visium) - number of neighboring tiles.
- 'generic' - number of neighborhoods for non-grid data.
radius : float, optional
Default is None.
Only available when coord_type = 'generic'. Depending on the type:
- :class:`float` - compute the graph based on neighborhood radius.
- :class:`tuple` - prune the final graph to only contain
edges in interval `[min(radius), max(radius)]`.
cmap : str, default 'seismic'
The colormap to use for the plot. The 'seismic' color map
consist of three color regions: red for positive, blue for negative,
and the white at the center. This color map effectively represents
the nature of the spatial interaction analysis results,
where positive values indicate clustering and
negative values indicate seperation. For more color maps, please visit
https://matplotlib.org/stable/tutorials/colors/colormaps.html
**kwargs
Keyword arguments for matplotlib.pyplot.text()
Returns
-------
dict
A dictionary containing the results of the spatial interaction
analysis. The keys of the dictionary depend on the parameters
passed to the function:
Ax : dict or matplotlib.axes.Axes
If `stratify_by` is not used, returns a single
matplotlib.axes.Axes object. If `stratify_by` is used,
returns a dictionary of Axes objects, with keys
representing the stratification groups.
Matrix : dict, optional
Contains processed DataFrames of computed matrices with row and
column labels applied. If `stratify_by` is used, the keys represent
the stratification groups. For example:
- `results['Matrix']['GroupA']` for a specific stratification group.
- If `stratify_by` is not used, the table is accessible via
`results['Matrix']['annotation']`.
"""
# List all available methods
available_methods = [
"Neighborhood Enrichment",
"Cluster Interaction Matrix"
]
available_methods_str = ",".join(available_methods)
# pacakge each methods into a function to allow
# centralized control and improve flexibility
def _Neighborhood_Enrichment_Analysis(
adata,
categorical_annotation,
ax,
return_matrix=False,
title=None,
seed=None,
**kwargs
):
"""
Perform Neighborhood Enrichment analysis.
Parameters
----------
adata : anndata.AnnData
The AnnData object.
categorical_annotation : str
Name of the annotation column to analyze.
ax : matplotlib.axes.Axes
Axes to plot the enrichment results.
return_matrix : bool
If True, returns the enrichment matrix.
title : str, optional
Title of the plot.
seed : int, optional
Random seed for reproducibility.
**kwargs : dict
Additional keyword arguments for the plot.
Returns
-------
ax or [ax, matrix]
The plot axes or axes and enrichment matrix.
"""
# Calculate Neighborhood_Enrichment
if return_matrix:
matrix = sq.gr.nhood_enrichment(
adata,
copy=True,
seed=seed,
cluster_key=categorical_annotation
)
sq.gr.nhood_enrichment(
adata,
seed=seed,
cluster_key=categorical_annotation
)
else:
sq.gr.nhood_enrichment(
adata,
seed=seed,
cluster_key=categorical_annotation
)
# Plot Neighborhood_Enrichment
sq.pl.nhood_enrichment(
adata,
cluster_key=categorical_annotation,
title=title,
ax=ax,
**kwargs
)
if return_matrix:
return [ax, matrix]
else:
return ax
def _Cluster_Interaction_Matrix_Analysis(
adata,
categorical_annotation,
ax,
return_matrix=False,
title=None,
**kwargs
):
"""
Perform Cluster Interaction Matrix analysis.
Parameters
----------
adata : anndata.AnnData
The AnnData object.
categorical_annotation : str
Name of the annotation column to analyze.
ax : matplotlib.axes.Axes
Axes to plot the interaction matrix.
return_matrix : bool
If True, returns the interaction matrix.
title : str, optional
Title of the plot.
**kwargs : dict
Additional keyword arguments for the plot.
Returns
-------
ax or [ax, matrix]
The plot axes or axes and interaction matrix.
"""
# Calculate Cluster_Interaction_Matrix
if return_matrix:
matrix = sq.gr.interaction_matrix(
adata,
cluster_key=categorical_annotation,
copy=True
)
sq.gr.interaction_matrix(
adata,
cluster_key=categorical_annotation
)
else:
sq.gr.interaction_matrix(
adata,
cluster_key=categorical_annotation
)
sq.pl.interaction_matrix(
adata,
title=title,
cluster_key=categorical_annotation,
ax=ax,
**kwargs
)
if return_matrix:
return [ax, matrix]
else:
return ax
# Perfrom the actual analysis, first call sq.gr.spatial_neighbors
# to calculate neighboring graph, then do different analysis.
def _perform_analysis(
adata,
analysis_method,
categorical_annotation,
ax,
coord_type,
n_rings,
n_neighs,
radius,
return_matrix=False,
title=None,
seed=None,
**kwargs
):
"""
Perform the specified spatial analysis method.
Parameters
----------
Same as parent function.
Returns
-------
ax or [ax, matrix]
The plot axes or axes and matrix results.
"""
sq.gr.spatial_neighbors(
adata,
coord_type=coord_type,
n_rings=n_rings,
n_neighs=n_neighs,
radius=radius
)
if analysis_method == "Neighborhood Enrichment":
ax = _Neighborhood_Enrichment_Analysis(
adata,
categorical_annotation,
ax,
return_matrix,
title,
seed,
**kwargs)
elif analysis_method == "Cluster Interaction Matrix":
ax = _Cluster_Interaction_Matrix_Analysis(
adata,
categorical_annotation,
ax,
return_matrix,
title,
**kwargs)
return ax
def _get_labels(
fig,
unique_annotations,
verbose=False
):
"""
Extract row and column labels from plot axes.
Parameters
----------
fig : matplotlib.figure.Figure
The figure containing the plots.
unique_annotations : list
List of unique annotation labels.
verbose : bool, default False
If True, print debugging information.
Returns
-------
list
List of row labels.
"""
row_labels = []
# Iterate over all axes to check if any contain the row labels
for i, ax in enumerate(fig.axes):
if verbose:
print(f"Inspecting axis {i}...")
# Try to extract labels from the y-axis of each axis
yticklabels = [tick.get_text() for tick in ax.get_yticklabels()]
xticklabels = [tick.get_text() for tick in ax.get_xticklabels()]
if yticklabels and xticklabels:
raise ValueError(
"Both x- and y-axis labels found on axis. "
"Unable to determine row labels."
)
elif yticklabels and not xticklabels:
if set(yticklabels) <= set(unique_annotations):
if verbose:
print(f"Row labels found on axis {i}: {yticklabels}")
row_labels = yticklabels[::-1]
# Try extracting other possible labels (x-axis, title, etc.)
elif xticklabels and not yticklabels:
if set(xticklabels) <= set(unique_annotations):
if verbose:
print(
f"Column labels found on axis {i}: {xticklabels}"
)
row_labels = xticklabels[::-1]
else:
if verbose:
print(
"No labels found on axis. Unable to determine row labels."
)
return row_labels
# Use to process the output from different
# spatial analysis method in squidpy
def _process_matrixs(
matrixs,
row_labels,
plot_label=None
):
"""
Process the output matrices for saving and visualization.
This function organizes matrices produced during the spatial analysis,
adding appropriate labels and creating a dictionary for easy access.
The processed matrices are returned in a labeled format suitable
for saving or further analysis.
Parameters
----------
matrixs : dict
Dictionary of raw matrices generated from spatial analyses.
row_labels : list
List of row and column labels for the matrices, used to annotate
the resulting DataFrames.
plot_label : str, optional
Additional label to append to the matrix file names, useful for
distinguishing between different stratification groups or analysis
scenarios.
Returns
-------
dict
A dictionary of labeled matrices, where the keys are the file names
(including the annotation and analysis method) and the values are
pandas DataFrames representing the matrices.
Notes
-----
For "Cluster Interaction Matrix", a single matrix is processed.
For "Neighborhood Enrichment", multiple matrices (e.g., z-score and
enrichment counts) are processed separately.
"""
return_dict = {}
if analysis_method == "Cluster Interaction Matrix":
if not isinstance(matrixs, pd.DataFrame):
matrix = pd.DataFrame(matrixs)
if len(row_labels) > 0:
matrix.index = row_labels
matrix.columns = row_labels
if plot_label is None:
file_name = f"{annotation}_{output_file_cat_list[0]}" + \
"_interaction_matrix.csv"
else:
file_name = f"{annotation}_{output_file_cat_list[0]}" + \
f"_{plot_label}_interaction_matrix.csv"
return_dict[file_name] = matrix
elif analysis_method == "Neighborhood Enrichment":
for i, matrix in enumerate(matrixs):
# Convert each 2D array to a DataFrame
if not isinstance(matrix, pd.DataFrame):
matrix = pd.DataFrame(matrix)
if len(row_labels) > 0:
matrix.index = row_labels
matrix.columns = row_labels
if plot_label is None:
file_name = f"{annotation}_{output_file_cat_list[i]}" + \
"_interaction_matrix.csv"
else:
file_name = f"{annotation}_{output_file_cat_list[i]}" + \
f"_{plot_label}" + \
"_interaction_matrix.csv"
return_dict[file_name] = matrix
return return_dict
# -----------------------------------------------
# Error Check Section
# -----------------------------------------------
if not isinstance(adata, anndata.AnnData):
error_text = "Input data is not an AnnData object. " + \
f"Got {str(type(adata))}"
raise ValueError(error_text)
check_annotation(
adata,
annotations=annotation,
parameter_name="annotation",
should_exist=True)
# Check if stratify_by is list or list of str
check_annotation(
adata,
annotations=stratify_by,
parameter_name="stratify_by",
should_exist=True)
if not isinstance(analysis_method, str):
error_text = "The analysis methods must be a string."
raise ValueError(error_text)
else:
if analysis_method not in available_methods:
error_text = f"Method {analysis_method}" + \
" is not supported currently. " + \
f"Available methods are: {available_methods_str}"
raise ValueError(error_text)
if ax is not None:
if not isinstance(ax, plt.Axes):
error_text = "Invalid 'ax' argument. " + \
"Expected an instance of matplotlib.axes.Axes. " + \
f"Got {str(type(ax))}"
raise ValueError(error_text)
else:
fig, ax = plt.subplots()
# Operational Section
# -----------------------------------------------
# Create a categorical column data for plotting\
# This is to avoid modifying the original annotation to comply with
# the squidpy function requirements
categorical_annotation = annotation + "_plot"
adata.obs[categorical_annotation] = pd.Categorical(
adata.obs[annotation])
if stratify_by:
if isinstance(stratify_by, list):
adata.obs['_spac_utils_concat_obs'] = \
adata.obs[stratify_by].astype(str).agg('_'.join, axis=1)
else:
adata.obs['_spac_utils_concat_obs'] = \
adata.obs[stratify_by]
# Partial function for the _perform_analysis function
# to allow for uniform parameter passing for both stratified
# and non-stratified analysis
perform_analysis_prefilled = partial(
_perform_analysis,
analysis_method=analysis_method,
categorical_annotation=categorical_annotation,
coord_type=coord_type,
n_rings=n_rings,
n_neighs=n_neighs,
radius=radius,
return_matrix=return_matrix,
seed=seed,
cmap=cmap
)
# Compute a connectivity matrix from spatial coordinates
if stratify_by:
ax_dictionary = {}
matrix_dictionary = {}
unique_values = adata.obs['_spac_utils_concat_obs'].unique()
for subset_key in unique_values:
# Subset the original AnnData object based on the unique value
subset_adata = adata[
adata.obs['_spac_utils_concat_obs'] == subset_key
].copy()
fig, ax = plt.subplots()
image_title = f"Group: {subset_key}"
ax = perform_analysis_prefilled(
adata=subset_adata,
ax=ax,
title=image_title,
**kwargs
)
if return_matrix:
ax_dictionary[subset_key] = ax[0]
matrix_dictionary[subset_key] = ax[1]
else:
ax_dictionary[subset_key] = ax
del subset_adata
if return_matrix:
results = {
"Ax": ax_dictionary,
"Matrix": matrix_dictionary
}
else:
results = {"Ax": ax_dictionary}
else:
ax = perform_analysis_prefilled(
adata=adata,
ax=ax,
**kwargs
)
if return_matrix:
results = {
"Ax": ax[0],
"Matrix": ax[1]
}
else:
results = {"Ax": ax}
# Adding post processing methods for updating images and retrieve
# Acquire the annotation labels and acquire
# the column names from axes and matrixs
if return_matrix:
output_file_cat_grosarry = {
"Neighborhood Enrichment": ["z_score", "enrichment_counts"],
"Cluster Interaction Matrix": ["interaction_counts"]
}
output_file_cat_list = output_file_cat_grosarry[analysis_method]
unique_annotations = list(adata.obs[annotation].unique())
_matrixs = results['Matrix']
_axs = results['Ax']
def _processes_function_return(
matrix,
ax,
plot_label=None
):
fig = ax.get_figure()
row_labels = _get_labels(
fig,
unique_annotations
)
result_dict = _process_matrixs(
matrix,
row_labels,
plot_label
)
return result_dict
table_results = {}
if stratify_by:
for key in _axs:
_ax = _axs[key]
_matrix = _matrixs[key]
result_dict = _processes_function_return(
_matrix,
_ax,
key
)
table_results[key] = result_dict
else:
table_results['annotation'] = _processes_function_return(
_matrixs,
_axs
)
results['Matrix'] = table_results
# Clean up the temporary columns
adata.obs.drop(categorical_annotation, axis=1, inplace=True)
if stratify_by:
adata.obs.drop('_spac_utils_concat_obs', axis=1, inplace=True)
return results
[docs]
def ripley_l(
adata,
annotation,
phenotypes,
distances,
regions=None,
spatial_key="spatial",
n_simulations=1,
area=None,
seed=42
):
"""
Calculate Ripley's L statistic for spatial data in `adata`.
Ripley's L statistic is a spatial point pattern analysis metric that
quantifies clustering or regularity in point patterns across various
distances. This function calculates the statistic for each region in
`adata` (if provided) or for all cells if regions are not specified.
Parameters
----------
adata : AnnData
The annotated data matrix containing the spatial coordinates and
cell phenotypes.
annotation : str
The key in `adata.obs` representing the annotation for cell phenotypes.
phenotypes : list of str
A list containing two phenotypes for which the Ripley L statistic
will be calculated. If the two phenotypes are the same, the
calculation is done for the same type; if different, it considers
interactions between the two.
distances : array-like
An array of distances at which to calculate Ripley's L statistic.
The values must be positive and incremental.
regions : str or None, optional
The key in `adata.obs` representing regions for stratifying the
data. If `None`, all cells will be treated as one region.
spatial_key : str, optional
The key in `adata.obsm` representing the spatial coordinates.
Default is `"spatial"`.
n_simulations : int, optional
Number of simulations to perform for significance testing. Default is 100.
area : float or None, optional
The area of the spatial region of interest. If `None`, the area
will be inferred from the data. Default is `None`.
seed : int, optional
Random seed for simulation reproducibility. Default is 42.
Returns
-------
pd.DataFrame
A DataFrame containing the Ripley's L results for each region or
the entire dataset if `regions` is `None`. The DataFrame includes
the following columns:
- `region`: The region label or 'all' if no regions are specified.
- `center_phenotype`: The first phenotype in `phenotypes`.
- `neighbor_phenotype`: The second phenotype in `phenotypes`.
- `ripley_l`: The Ripley's L statistic calculated for the region.
- `config`: A dictionary with configuration settings used for the calculation.
Notes
-----
Ripley's L is an adjusted version of Ripley's K that corrects for the
inherent increase in point-to-point distances as the distance grows.
This statistic is used to evaluate spatial clustering or dispersion
of points (cells) in biological datasets.
The function uses pre-defined distances and performs simulations to
assess the significance of observed patterns. The results are stored
in the `.uns` attribute of `adata` under the key `'ripley_l'`, or in
a new DataFrame if no prior results exist.
Examples
--------
Calculate Ripley's L for two phenotypes in a single region dataset:
>>> result = ripley_l(adata, annotation='cell_type', phenotypes=['A', 'B'], distances=np.linspace(0, 500, 100))
Calculate Ripley's L for multiple regions in `adata`:
>>> result = ripley_l(adata, annotation='cell_type', phenotypes=['A', 'B'], distances=np.linspace(0, 500, 100), regions='region_key')
"""
# Check that distances is array-like with incremental positive values
check_distances(distances)
# Check that annotation and phenotypes exist in adata.obs
check_annotation(adata, annotations=[annotation], should_exist=True)
check_label(adata, annotation, phenotypes)
# Convert annotations to categorical
adata.obs[annotation] = pd.Categorical(adata.obs[annotation])
if regions is not None:
check_annotation(adata, annotations=[regions], should_exist=True)
# Import ripley function from the spac library
from spac._ripley import ripley
from functools import partial
# Partial function for Ripley calculation
ripley_func = partial(
ripley,
cluster_key=annotation,
mode='L',
spatial_key=spatial_key,
phenotypes=phenotypes,
support=distances,
n_simulations=n_simulations,
seed=seed,
area=area,
copy=True
)
# Check if adata already has ripley_l results,
# else initialize a result DataFrame
if 'ripley_l' in adata.uns.keys():
results = adata.uns['ripley_l']
else:
results = None
# Function to process Ripley L calculation for a region
def process_region(adata_region, region_label):
# Calculate number of cells in the region for the phenotypes
# n_cells = get_ncells(adata_region, annotation, phenotypes)
region_cells = adata_region.n_obs
# Log the region and cell info
print(
f'Processing region:"{region_label}".\n'
f'Cells in region:"{region_cells}"'
)
cell_counts = adata_region.obs[annotation].value_counts()
if region_cells < 3:
message = (
f'WARNING, not enough cells in region "{region_label}". '
f'Number of cells "{region_cells}". '
'Skipping Ripley L.'
)
print(message)
ripley_result = None
elif not phenotypes[0] in cell_counts.index:
message = (
f'WARNING, phenotype "{phenotypes[0]}" '
f'not found in region "{region_label}", skipping Ripley L.'
)
print(message)
ripley_result = None
elif not phenotypes[1] in cell_counts.index:
message = (
f'WARNING, phenotype "{phenotypes[1]}" '
f'not found in region "{region_label}", skipping Ripley L.'
)
print(message)
ripley_result = None
else:
# Calculate Ripley's L statistic using the partial function
ripley_result = ripley_func(adata=adata_region)
message = "Ripley's L successfully calculated."
# Create a result entry for the region
new_result = {
'region': region_label,
'center_phenotype': phenotypes[0],
'neighbor_phenotype': phenotypes[1],
'ripley_l': ripley_result,
'region_cells': region_cells,
'message': message,
'n_simulations': n_simulations,
'seed': seed
}
return new_result
def append_results(results, new_result):
# Convert the new result to a DataFrame
new_df = pd.DataFrame([new_result])
# Check if the dataframe exists and concatenate
if results is not None:
results = pd.concat([results, new_df], ignore_index=True)
else:
results = new_df
return results
# If regions are provided, process each region,
# else process all cells as a single region
if regions is not None:
for region in adata.obs[regions].unique():
adata_region = adata[adata.obs[regions] == region]
new_result = process_region(adata_region, region)
results = append_results(results, new_result)
else:
new_result = process_region(adata, 'all')
results = append_results(results, new_result)
# Save the results in the AnnData object
adata.uns['ripley_l'] = results
return results
[docs]
def neighborhood_profile(
adata,
phenotypes,
distances,
regions=None,
spatial_key="spatial",
normalize=None,
associated_table_name="neighborhood_profile"
):
"""
Calculate the neighborhood profile for every cell in all slides in an analysis
and update the input AnnData object in place.
Parameters
----------
adata : AnnData
The AnnData object containing the spatial coordinates and phenotypes.
phenotypes : str
The name of the column in adata.obs that contains the phenotypes.
distances : list
The list of increasing distances for the neighborhood profile.
spatial_key : str, optional
The key in adata.obs that contains the spatial coordinates. Default is
'spatial'.
normalize : str or None, optional
If 'total_cells', normalize the neighborhood profile based on the
total number of cells in each bin.
If 'bin_area', normalize the neighborhood profile based on the area
of every bin. Default is None.
associated_table_name : str, optional
The name of the column in adata.obsm that will contain the
neighborhood profile. Default is 'neighborhood_profile'.
regions : str or None, optional
The name of the column in adata.obs that contains the regions.
If None, all cells in adata will be used. Default is None.
Returns
-------
None
The function modifies the input AnnData object in place, adding a new
column containing the neighborhood profile to adata.obsm.
Notes
-----
The input AnnData object 'adata' is modified in place. The function adds a
new column containing the neighborhood profile to adata.obsm, named by the
parameter 'associated_table_name'. The associated_table_name is a 3D array of
shape (n_cells, n_phenotypes, n_bins) where n_cells is the number of cells
in the all slides, n_phenotypes is the number of unique phenotypes, and
n_bins is the number of bins in the distances list.
A dictionary is added to adata.uns[associated_table_name] with the two keys
"bins" and "labels". "labels" will store all the values in the phenotype
annotation.
"""
# Check that distances is array like with incremental positive values
check_distances(distances)
# Check that phenotypes is adata.obs
check_annotation(
adata,
annotations=[phenotypes],
should_exist=True)
# Check that phenotypes is adata.obs
if regions is not None:
check_annotation(
adata,
annotations=[regions],
should_exist=True)
check_table(
adata=adata,
tables=spatial_key,
should_exist=True,
associated_table=True
)
# Check the values of normalize
if normalize is not None and normalize not in ['total_cells', 'bin_area']:
raise ValueError((f'normalize must be "total_cells", "bin_area"'
f' or None. Got "{normalize}"'))
# Check that the associated_table_name does not exist.
# Raise a warning othewise
check_table(
adata=adata,
tables=associated_table_name,
should_exist=False,
associated_table=True,
warning=True
)
logger = logging.getLogger()
# Convert the phenotypes to integers using label encoder
labels = adata.obs[phenotypes].values
le = LabelEncoder().fit(labels)
n_phenotypes = len(le.classes_)
# Create a place holder for the neighborhood profile
all_cells_profiles = np.zeros(
(adata.n_obs, n_phenotypes, len(distances)-1))
# If regions is None, use all cells in adata
if regions is not None:
# Calculate the neighborhood profile for every slide
for i, region in enumerate(adata.obs[regions].unique()):
adata_region = adata[adata.obs[regions] == region]
logger.info(f"Processing region:{region} \
n_cells:{len(adata_region)}")
positions = adata_region.obsm[spatial_key]
labels_id = le.transform(adata_region.obs[phenotypes].values)
region_profiles = _neighborhood_profile_core(
positions,
labels_id,
n_phenotypes,
distances,
normalize
)
# Updated profiles of the cells of the current slide
all_cells_profiles[adata.obs[regions] == region] = (
region_profiles
)
else:
logger.info(("Processing all cells as a single region."
f" n_cells:{len(adata)}"))
positions = adata.obsm[spatial_key]
labels_id = le.transform(labels)
all_cells_profiles = _neighborhood_profile_core(
positions,
labels_id,
n_phenotypes,
distances,
normalize
)
# Add the neighborhood profile to the AnnData object
adata.obsm[associated_table_name] = all_cells_profiles
# Store the bins and the lables in uns
summary = {"bins": distances, "labels": le.classes_}
if associated_table_name in adata.uns:
logger.warning(f"The analysis already contains the \
unstructured value:{associated_table_name}. \
It will be overwriten")
adata.uns[associated_table_name] = summary
[docs]
def _neighborhood_profile_core(
coord,
phenotypes,
n_phenotypes,
distances_bins,
normalize=None
):
"""
Calculate the neighborhood profile for every cell in a region.
Parameters
----------
coord : numpy.ndarray
The coordinates of the cells in the region. Should be a 2D array of
shape (n_cells, 2) representing x, y coordinates.
phenotypes : numpy.ndarray
The phenotypes of the cells in the region.
n_phenotypes : int
The number of unique phenotypes in the region.
distances_bins : list
The bins defining the distance ranges for the neighborhood profile.
normalize : str or None, optional
If 'total_cells', normalize the neighborhood profile based on the
total number of cells in each bin.
If 'bin_area', normalize the neighborhood profile based on the area
of every bin.
Returns
-------
numpy.ndarray
A 3D array containing the neighborhood profile for every cell in the
region. The dimensions are (n_cells, n_phenotypes, n_intervals).
Notes
-----
- The function calculates the neighborhood profile for each cell, which
represents the distribution of neighboring cells' phenotypes within
different distance intervals.
- The 'distances_bins' parameter should be a list defining the bins for
the distance intervals. It is assumed that the bins are incremental,
starting from 0.
"""
# TODO Check that distances bins is incremental
max_distance = distances_bins[-1]
kdtree = KDTree(coord)
# indexes is a list of neighbors coordinate for every
# cell
indexes = kdtree.query_ball_tree(kdtree, r=max_distance)
# Create phenotype bins to include the integer equivalent of
# every phenotype to use the histogram2d function instead of
# a for loop over every phenotype
phenotype_bins = np.arange(-0.5, n_phenotypes + 0.5, 1)
n_intervals = len(distances_bins) - 1
neighborhood_profile = []
for i, neighbors in enumerate(indexes):
# Query_ball_tree will include the point itself
neighbors.remove(i)
# To potentially save on calculating the histogram
if len(neighbors) == 0:
neighborhood_profile.append(np.zeros((n_phenotypes, n_intervals)))
else:
neighbor_coords = coord[neighbors]
dist_matrix = distance_matrix(coord[i:i+1], neighbor_coords)[0]
neighbors_phenotypes = phenotypes[neighbors]
# Returns a 2D histogram of size n_phenotypes * n_intervals
histograms_array, _ , _ = np.histogram2d(neighbors_phenotypes,
dist_matrix,
bins=[
phenotype_bins,
distances_bins
])
neighborhood_profile.append(histograms_array)
neighborhood_array = np.stack(neighborhood_profile)
if normalize == "total_cells":
bins_sum = neighborhood_array.sum(axis=1)
bins_sum[bins_sum == 0] = 1
neighborhood_array = neighborhood_array / bins_sum[:, np.newaxis, :]
elif normalize == "bin_area":
circles_areas = np.pi * np.array(distances_bins)**2
bins_areas = np.diff(circles_areas)
neighborhood_array = neighborhood_array / bins_areas[np.newaxis, np.newaxis, :]
return neighborhood_array
[docs]
def calculate_nearest_neighbor(
adata,
annotation,
spatial_associated_table='spatial',
imageid=None,
label='spatial_distance',
verbose=True
):
"""
Computes the shortest distance from each cell to the nearest cell of
each phenotype (via scimap.tl.spatial_distance) and stores the
resulting DataFrame in `adata.obsm[label]`.
Parameters
----------
adata : anndata.AnnData
Annotated data matrix with spatial information.
annotation : str
Column name in `adata.obs` containing cell annotationsi (i.e.
phenotypes).
spatial_associated_table : str, optional
Key in `adata.obsm` where spatial coordinates are stored. Default is
'spatial'.
imageid : str, optional
The column in `adata.obs` specifying image IDs. If None,
a dummy image column is created temporarily. Spatial distances are
computed across the entire dataseti as if it's one image.
label : str, optional
The key under which results are stored in `adata.obsm`. Default is
'spatial_distance'.
verbose : bool, optional
If True, prints progress messages. Default is True.
Returns
-------
None
Modifies `adata` in place by storing a DataFrame of
spatial distances in `adata.obsm[label]`.
Example
-------
For a dataset with two cells (CellA, CellB) both of the same phenotype
"type1", the output might look like:
>>> adata.obsm['spatial_distance']
type1
CellA 0.0
CellB 0.0
For a dataset with two phenotypes "type1" and "type2", the output might
look like:
>>> adata.obsm['spatial_distance']
type1 type2
CellA 0.00 1.414214
CellB 1.414214 0.00
Input:
adata.obs:
cell_type imageid
type1 image1
type1 image1
type2 image1
adata.obsm['spatial']:
[[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]
Output stored in adata.obsm['spatial_distance']:
type1 type2
0 0.0 1.414
1 1.414 0.0
2 2.236 1.0
Raises
------
ValueError
If `spatial_associated_table` is not found in `adata.obsm`.
If spatial coordinates are missing or invalid.
"""
import scimap as sm
# Input validation for annotation
check_annotation(adata, annotations=annotation)
# Validate and extract spatial coordinates
check_table(
adata,
tables=spatial_associated_table,
associated_table=True,
should_exist=True
)
coords = adata.obsm[spatial_associated_table]
if coords.shape[1] < 2:
raise ValueError(
"The input data must include coordinates with at least "
"two dimensions, such as X and Y positions."
)
# Check for missing coordinates
if np.isnan(coords).any():
missing_cells = np.where(np.isnan(coords).any(axis=1))[0]
raise ValueError(
f"Missing values found in spatial coordinates for cells "
f"at indices: {missing_cells}."
)
if verbose:
print("Preparing data for spatial distance calculation...")
# Add coordinates to adata.obs temporarily
adata.obs['_x_coord'] = coords[:, 0]
adata.obs['_y_coord'] = coords[:, 1]
use_z = False
if coords.shape[1] > 2:
adata.obs['_z_coord'] = coords[:, 2]
use_z = True
# Handle imageid logic
dummy_column_created = False
original_imageid = imageid
if imageid is None:
dummy_column_created = True
imageid = '_dummy_imageid'
adata.obs[imageid] = 'dummy_image' # Treat all cells as one 'image'
sm.tl.spatial_distance(
adata=adata,
x_coordinate='_x_coord',
y_coordinate='_y_coord',
z_coordinate=('_z_coord' if use_z else None),
phenotype=annotation,
imageid=imageid,
verbose=verbose,
label=label
)
# The scimap function stores the result in adata.uns[label].
# Need to align it to adata.obs_names before placing into .obsm.
result_df = adata.uns.pop(label) # remove from uns and capture
# Reindex the results to match adata.obs_names.
result_df = result_df.reindex(adata.obs_names)
adata.obsm[label] = result_df
# Remove temporary coordinates from adata.obs
drop_cols = ['_x_coord', '_y_coord']
if use_z:
drop_cols.append('_z_coord')
adata.obs.drop(columns=drop_cols, inplace=True, errors='ignore')
# Remove dummy column if it was created
if dummy_column_created:
adata.obs.drop(columns=[imageid], inplace=True, errors='ignore')
imageid = original_imageid # restore the original state
if verbose:
print(f"Spatial distances stored in adata.obsm['{label}']")
print("Preview of the distance DataFrame:\n", adata.obsm[label].head())