Source code for spac.visualization

import seaborn as sns
import seaborn
import pandas as pd
import numpy as np
import anndata
import scanpy as sc
import math
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
import plotly.figure_factory as ff
import plotly.colors as pc
from matplotlib.colors import ListedColormap, BoundaryNorm
from spac.utils import check_table, check_annotation
from spac.utils import check_feature, annotation_category_relations
from spac.utils import check_label
from functools import partial
from spac.utils import color_mapping, spell_out_special_characters
from spac.data_utils import select_values
import logging
import warnings
import re
import copy


# Configure logging
# NOTE(review): this runs at import time and configures the process-wide root
# logger (INFO level, timestamped format). logging.basicConfig does nothing
# if the root logger already has handlers, so an application that set up
# logging before importing this module keeps its own configuration.
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')


def visualize_2D_scatter(
    x, y, labels=None, point_size=None, theme=None,
    ax=None, annotate_centers=False,
    x_axis_title='Component 1', y_axis_title='Component 2', plot_title=None,
    color_representation=None, **kwargs
):
    """
    Visualize 2D data using plt.scatter.

    Parameters
    ----------
    x, y : array-like
        Coordinates of the data. Converted to numpy arrays internally.
    labels : array-like, optional
        Array of labels for the data points. Can be numerical or
        categorical (a categorical `pandas.Series` or `pandas.Categorical`).
    point_size : float, optional
        Size of the points. If None, it is set to 5000 / number of points.
    theme : str, optional
        Color theme used for continuous labels. One of 'fire', 'viridis',
        'inferno', 'blue', 'red', 'green', 'darkblue', 'darkred',
        'darkgreen'. 'viridis' is used when no theme is given.
    ax : matplotlib.axes.Axes, optional (default: None)
        Matplotlib axis object. If None, a new one is created.
    annotate_centers : bool, optional (default: False)
        Annotate the centers of clusters if labels are categorical.
    x_axis_title : str, optional
        Title for the x-axis.
    y_axis_title : str, optional
        Title for the y-axis.
    plot_title : str, optional
        Title for the plot. When given, it is used as-is.
    color_representation : str, optional
        Description of what the colors represent. Used as the legend title
        for categorical labels, and as the plot title for continuous labels
        when `plot_title` is None.
    **kwargs
        `fig_width`, `fig_height` and `fontsize` are consumed by this
        function. Every other keyword is passed to `ax.scatter` when the
        labels are continuous or absent.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure of the plot.
    ax : matplotlib.axes.Axes
        The axes of the plot.

    Raises
    ------
    ValueError
        If `x`/`y` are not array-like, lengths do not match, or the theme
        is not recognized.
    TypeError
        If the labels have a categorical dtype but are neither a
        `pandas.Series` nor a `pandas.Categorical`.
    """
    # Input validation
    if not hasattr(x, "__iter__") or not hasattr(y, "__iter__"):
        raise ValueError("x and y must be array-like.")
    if len(x) != len(y):
        raise ValueError("x and y must have the same length.")
    if labels is not None and len(labels) != len(x):
        raise ValueError("Labels length should match x and y length.")

    # Boolean-mask indexing below requires arrays, not plain lists
    x = np.asarray(x)
    y = np.asarray(y)

    # Define color themes
    themes = {
        'fire': plt.get_cmap('inferno'),
        'viridis': plt.get_cmap('viridis'),
        'inferno': plt.get_cmap('inferno'),
        'blue': plt.get_cmap('Blues'),
        'red': plt.get_cmap('Reds'),
        'green': plt.get_cmap('Greens'),
        'darkblue': ListedColormap(['#00008B']),
        'darkred': ListedColormap(['#8B0000']),
        'darkgreen': ListedColormap(['#006400'])
    }

    if theme and theme not in themes:
        error_msg = (
            f"Theme '{theme}' not recognized. Please use a valid theme."
        )
        raise ValueError(error_msg)
    cmap = themes.get(theme, plt.get_cmap('viridis'))

    # Determine point size; guard against division by zero on empty input
    num_points = len(x)
    if point_size is None:
        point_size = 5000 / max(num_points, 1)

    # These keys configure the figure, not the scatter call. Pop them so
    # they are not forwarded to ax.scatter, which would reject them.
    fig_width = kwargs.pop('fig_width', 10)
    fig_height = kwargs.pop('fig_height', 8)
    fontsize = kwargs.pop('fontsize', 12)

    if ax is None:
        fig, ax = plt.subplots(figsize=(fig_width, fig_height))
    else:
        fig = ax.figure

    title = plot_title

    if labels is None:
        ax.scatter(x, y, c='gray', s=point_size, **kwargs)

    elif isinstance(getattr(labels, 'dtype', None), pd.CategoricalDtype):
        # Determine how to access the categories based on the label type
        if isinstance(labels, pd.Series):
            unique_clusters = labels.cat.categories
        elif isinstance(labels, pd.Categorical):
            unique_clusters = labels.categories
        else:
            raise TypeError(
                "Expected labels to be of type Series[Categorical] or "
                "Categorical."
            )

        # Combine colors from multiple qualitative colormaps
        palette = (
            plt.get_cmap('tab20').colors
            + plt.get_cmap('tab20b').colors
            + plt.get_cmap('tab20c').colors
        )

        label_values = np.array(labels)
        for idx, cluster in enumerate(unique_clusters):
            mask = label_values == cluster
            ax.scatter(
                x[mask], y[mask],
                # Cycle so that more categories than colors still vary
                color=palette[idx % len(palette)],
                label=cluster,
                s=point_size
            )
            print(f"Cluster: {cluster}, Points: {np.sum(mask)}")

            # An unused category has no points and therefore no center
            if annotate_centers and mask.any():
                ax.text(
                    np.mean(x[mask]), np.mean(y[mask]), cluster,
                    fontsize=fontsize, ha='center', va='center'
                )

        legend_title = None
        if color_representation is not None:
            legend_title = f"Color represents: {color_representation}"
        ax.legend(
            loc='best',
            bbox_to_anchor=(1.25, 1),  # Adjusting position
            title=legend_title
        )

    else:
        # Labels are continuous
        scatter = ax.scatter(
            x, y, c=labels, cmap=cmap, s=point_size, **kwargs
        )
        # Attach the colorbar to the figure that owns `ax`
        fig.colorbar(scatter, ax=ax)
        if title is None and color_representation is not None:
            title = f"Color represents: {color_representation}"

    # Equal aspect ratio for the axes
    ax.set_aspect('equal', 'datalim')

    # Set axis labels
    ax.set_xlabel(x_axis_title)
    ax.set_ylabel(y_axis_title)

    # Set plot title
    if title is not None:
        ax.set_title(title)

    return fig, ax
def dimensionality_reduction_plot(
        adata,
        method=None,
        annotation=None,
        feature=None,
        layer=None,
        ax=None,
        associated_table=None,
        **kwargs):
    """
    Visualize scatter plot in PCA, t-SNE, UMAP, or associated table.

    Parameters
    ----------
    adata : anndata.AnnData
        The AnnData object with coordinates precomputed and stored in
        `adata.obsm` (e.g. 'X_tsne', 'X_umap', 'X_pca') or in the
        associated table.
    method : str, optional (default: None)
        Dimensionality reduction method to visualize.
        Choose from {'tsne', 'umap', 'pca'}. The first two columns of
        `adata.obsm["X_<method>"]` are plotted.
    annotation : str, optional
        The name of the column in `adata.obs` to use for coloring
        the scatter plot points based on cell annotations.
    feature : str, optional
        The name of the gene or feature in `adata.var_names` to use
        for coloring the scatter plot points based on feature expression.
    layer : str, optional
        The name of the data layer in `adata.layers` to use for
        visualization. If None, the main data matrix `adata.X` is used.
    ax : matplotlib.axes.Axes, optional (default: None)
        A matplotlib axes object to plot on.
        If not provided, a new figure and axes will be created.
    associated_table : str, optional (default: None)
        Name of the key in `obsm` that contains the numpy array. Takes
        precedence over `method`, for both the coordinates and the titles.
    **kwargs
        Parameters passed to visualize_2D_scatter function,
        including point_size. The keys `x_axis_title`, `y_axis_title`,
        `plot_title` and `color_representation` are set by this function
        and ignored if supplied.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The created figure for the plot.
    ax : matplotlib.axes.Axes
        The axes of the plot.

    Raises
    ------
    ValueError
        If both `annotation` and `feature` are given, the method is not
        supported, or the coordinates are missing or not two-dimensional.
    """
    # Check if both annotation and feature are specified, raise error if so
    if annotation and feature:
        raise ValueError(
            "Please specify either an annotation or a feature for coloring, "
            "not both.")

    # Use utility functions for input validation
    if layer:
        check_table(adata, tables=layer)
    if annotation:
        check_annotation(adata, annotations=annotation)
    if feature:
        check_feature(adata, features=[feature])

    # Resolve the obsm key and the names used for axis and plot titles
    if associated_table is None:
        method_names = {
            'tsne': ('t-SNE', 'TSNE'),
            'umap': ('UMAP', 'UMAP'),
            'pca': ('PCA', 'PCA'),
        }
        if method not in method_names:
            raise ValueError("Method should be one of {'tsne', 'umap', 'pca'}"
                             f'. Got:"{method}"')

        key = f'X_{method}'
        if key not in adata.obsm.keys():
            raise ValueError(
                f"{key} coordinates not found in adata.obsm. "
                f"Please run {method.upper()} before calling this function."
            )
        axis_name, title_name = method_names[method]
    else:
        check_table(
            adata=adata,
            tables=associated_table,
            should_exist=True,
            associated_table=True
        )

        associated_table_shape = adata.obsm[associated_table].shape
        # Also rejects 1-D tables, which have no second dimension to index
        if len(associated_table_shape) != 2 or associated_table_shape[1] != 2:
            raise ValueError(
                f'The associated table:"{associated_table}" does not have'
                f' two dimensions. It shape is:"{associated_table_shape}"'
            )
        key = associated_table
        axis_name = title_name = associated_table

    print(f'Running visualization using the coordinates: "{key}"')

    # Extract the 2D coordinates. Embeddings such as PCA usually hold more
    # than two components, so take the first two columns explicitly.
    coordinates = np.asarray(adata.obsm[key])
    if coordinates.ndim != 2 or coordinates.shape[1] < 2:
        raise ValueError(
            f'The coordinates in "{key}" must have at least two columns. '
            f'Got shape: "{coordinates.shape}"'
        )
    x, y = coordinates[:, 0], coordinates[:, 1]

    # Determine coloring scheme
    if annotation:
        color_values = adata.obs[annotation].astype('category').values
        color_representation = annotation
    elif feature:
        data_source = adata.layers[layer] if layer else adata.X
        column = data_source[:, adata.var_names == feature]
        # Sparse matrices must be made dense before flattening
        if hasattr(column, 'toarray'):
            column = column.toarray()
        color_values = np.asarray(column).ravel()
        color_representation = feature
    else:
        color_values = None
        color_representation = None

    # Set axis titles based on the coordinates and color representation
    x_axis_title = f'{axis_name} 1'
    y_axis_title = f'{axis_name} 2'
    if color_representation is None:
        plot_title = title_name
    else:
        plot_title = f'{title_name}-{color_representation}'

    # Remove conflicting keys from kwargs
    for reserved in ('x_axis_title', 'y_axis_title',
                     'plot_title', 'color_representation'):
        kwargs.pop(reserved, None)

    fig, ax = visualize_2D_scatter(
        x=x,
        y=y,
        ax=ax,
        labels=color_values,
        x_axis_title=x_axis_title,
        y_axis_title=y_axis_title,
        plot_title=plot_title,
        color_representation=color_representation,
        **kwargs
    )

    return fig, ax
def tsne_plot(adata, color_column=None, ax=None, **kwargs):
    """
    Visualize scatter plot in tSNE basis.

    Parameters
    ----------
    adata : anndata.AnnData
        The AnnData object with t-SNE coordinates precomputed by
        the 'tsne' function and stored in 'adata.obsm["X_tsne"]'.
    color_column : str, optional
        The name used for coloring the scatter plot points. It must be a
        column of `adata.obs`, a feature in `adata.var_names`, or a column
        of `adata.var`.
    ax : matplotlib.axes.Axes, optional (default: None)
        A matplotlib axes object to plot on.
        If not provided, a new figure and axes will be created.
    **kwargs
        Parameters passed to scanpy.pl.tsne function.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The created figure for the plot.
    ax : matplotlib.axes.Axes
        The axes of the tsne plot.

    Raises
    ------
    ValueError
        If `adata` is not an AnnData object or has no 'X_tsne' coordinates.
    KeyError
        If `color_column` is not found in the AnnData object.
    """
    if not isinstance(adata, anndata.AnnData):
        raise ValueError("adata must be an AnnData object.")

    if 'X_tsne' not in adata.obsm:
        err_msg = ("adata.obsm does not contain 'X_tsne', "
                   "perform t-SNE transformation first.")
        raise ValueError(err_msg)

    # Validate before creating a figure so that a failed call does not
    # leave an empty figure behind. Feature names live in `var_names`
    # (the index of `adata.var`), so they must be accepted as well.
    if color_column and not (
        color_column in adata.obs.columns
        or color_column in adata.var_names
        or color_column in adata.var.columns
    ):
        err_msg = f"'{color_column}' not found in adata.obs or adata.var."
        raise KeyError(err_msg)

    # Create a new figure and axes if not provided
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()

    # Add color column to the kwargs for the scanpy plot
    if color_column:
        kwargs['color'] = color_column

    # Plot the t-SNE
    sc.pl.tsne(adata, ax=ax, **kwargs)

    return fig, ax
def _label_histogram_axis(axis, data_column, x_log_scale, y_log_scale, stat):
    """Apply the y-scale and the x/y axis labels of one histogram panel."""
    if y_log_scale:
        axis.set_yscale('log')

    axis.set_xlabel(f'log({data_column})' if x_log_scale else data_column)

    ylabel_map = {
        'count': 'Count',
        'frequency': 'Frequency',
        'density': 'Density',
        'probability': 'Probability'
    }
    ylabel = ylabel_map.get(stat, 'Count')
    if y_log_scale:
        ylabel = f'log({ylabel})'
    axis.set_ylabel(ylabel)


def histogram(adata, feature=None, annotation=None, layer=None,
              group_by=None, together=False, ax=None,
              x_log_scale=False, y_log_scale=False, **kwargs):
    """
    Plot the histogram of cells based on a specific feature from adata.X
    or annotation from adata.obs.

    Parameters
    ----------
    adata : anndata.AnnData
        The AnnData object.
    feature : str, optional
        Name of continuous feature from adata.X to plot its histogram.
        If neither `feature` nor `annotation` is given, the first feature
        in `adata.var_names` is used and a warning is issued.
    annotation : str, optional
        Name of the annotation from adata.obs to plot its histogram.
    layer : str, optional
        Name of the layer in adata.layers to plot its histogram.
    group_by : str, default None
        Annotation used to group the histogram.
    together : bool, default False
        If True, and if group_by != None, create one plot combining all
        groups. If False, create a separate histogram for each group, on
        a newly created figure.
        The appearance of combined histograms can be controlled using the
        `multiple` and `element` parameters in **kwargs.
        To control how the histograms are normalized (e.g., to divide the
        histogram by the number of elements in every group), use the `stat`
        parameter in **kwargs. For example, set `stat="probability"` to
        show the relative frequencies of each group.
    ax : matplotlib.axes.Axes, optional
        An existing Axes object to draw the plot onto. It is not used when
        `group_by` is set and `together` is False, because one panel per
        group is created on a new figure.
    x_log_scale : bool, default False
        If True, the data will be transformed using np.log1p before
        plotting, and the x-axis label will be adjusted accordingly. The
        transformation is skipped when the data is not numeric or contains
        negative values.
    y_log_scale : bool, default False
        If True, the y-axis will be set to log scale.
    **kwargs
        Additional keyword arguments passed to seaborn histplot function.
        Key arguments include:

        - `multiple`: Determines how the subsets of data are displayed
          on the same axes. Options include:

          * "layer": Draws each subset on top of the other
            without adjustments.
          * "dodge": Dodges bars for each subset side by side.
          * "stack": Stacks bars for each subset on top of each other.
          * "fill": Adjusts bar heights to fill the axes.

        - `element`: Determines the visual representation of the bins.
          Options include:

          * "bars": Displays the typical bar-style histogram (default).
          * "step": Creates a step line plot without bars.
          * "poly": Creates a polygon where the bottom edge represents
            the x-axis and the top edge the histogram's bins.

        - `log_scale`: Determines if the data should be plotted on
          a logarithmic scale.
        - `stat`: Determines the statistical transformation to use on
          the data for the histogram. Options include:

          * "count": Show the counts of observations in each bin.
          * "frequency": Show the number of observations divided
            by the bin width.
          * "density": Normalize such that the total area of the
            histogram equals 1.
          * "probability": Normalize such that each bar's height
            reflects the probability of observing that bin.

        - `bins`: Specification of hist bins.
          Can be a number (indicating the number of bins) or a list
          (indicating bin edges). For example, `bins=10` will create 10
          bins, while `bins=[0, 1, 2, 3]` will create bins
          [0,1), [1,2), [2,3]. If not provided, the binning will be
          determined automatically.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The created figure for the plot.
    axs : matplotlib.axes.Axes or list of Axes
        The Axes object(s) of the histogram plot(s). Returns a single Axes
        if only one plot is created, otherwise returns a list of Axes.

    Raises
    ------
    ValueError
        If both `feature` and `annotation` are given, or if `group_by`
        contains no group.
    """
    # Fail fast on contradictory input, before any data is assembled
    if feature and annotation:
        raise ValueError("Cannot pass both feature and annotation,"
                         " choose one.")

    # If no feature or annotation is specified, apply default behavior
    if feature is None and annotation is None:
        # Default to the first feature in adata.var_names
        feature = adata.var_names[0]
        warnings.warn(
            "No feature or annotation specified. "
            "Defaulting to the first feature: "
            f"'{feature}'.",
            UserWarning
        )

    # Use utility functions for input validation
    if layer:
        check_table(adata, tables=layer)
    if annotation:
        check_annotation(adata, annotations=annotation)
    if feature:
        check_feature(adata, features=feature)
    if group_by:
        check_annotation(adata, annotations=group_by)

    # Get the data from the requested layer, or from the main matrix
    matrix = adata.layers[layer] if layer else adata.X
    # Sparse matrices cannot be passed to the DataFrame constructor
    if hasattr(matrix, 'toarray'):
        matrix = matrix.toarray()
    df = pd.DataFrame(
        matrix, index=adata.obs.index, columns=adata.var_names
    )
    df = pd.concat([df, adata.obs], axis=1)

    data_column = feature if feature else annotation

    # Apply the log1p transformation only where it is well defined
    if x_log_scale:
        if not pd.api.types.is_numeric_dtype(df[data_column]):
            print(
                "The data is not numeric, disabling x_log_scale."
            )
            x_log_scale = False
        elif (df[data_column] < 0).any():
            print(
                "There are negative values in the data, disabling x_log_scale."
            )
            x_log_scale = False
        else:
            df[data_column] = np.log1p(df[data_column])

    # Prepare the data for plotting
    plot_data = df.dropna(subset=[data_column])

    axs = []

    if group_by and not together:
        groups = df[group_by].dropna().unique().tolist()
        n_groups = len(groups)
        if n_groups == 0:
            raise ValueError("There must be at least one group to create a"
                             " histogram.")

        # One panel per group; squeeze=False keeps the result an array
        # even when there is a single group.
        fig, ax_array = plt.subplots(
            n_groups, 1, figsize=(5, 5 * n_groups), squeeze=False
        )
        for group, ax_i in zip(groups, ax_array.ravel()):
            group_data = plot_data[plot_data[group_by] == group]
            sns.histplot(data=group_data, x=data_column, ax=ax_i, **kwargs)
            ax_i.set_title(group)
            axs.append(ax_i)
    else:
        if group_by:
            n_groups = len(df[group_by].dropna().unique())
            if n_groups == 0:
                raise ValueError("There must be at least one group to create"
                                 " a histogram.")

        # Create the single axes only when it is actually needed
        if ax is not None:
            fig = ax.get_figure()
        else:
            fig, ax = plt.subplots()

        if group_by:
            # Set default values if not provided in kwargs
            kwargs.setdefault("multiple", "stack")
            kwargs.setdefault("element", "bars")
            # Drop only rows missing the plotted or grouping value, not
            # rows with a missing value in an unrelated column.
            sns.histplot(
                data=plot_data.dropna(subset=[group_by]),
                x=data_column, hue=group_by, ax=ax, **kwargs
            )
        else:
            sns.histplot(data=plot_data, x=data_column, ax=ax, **kwargs)
        axs.append(ax)

    # Scale and label every panel consistently
    stat = kwargs.get('stat', 'count')
    for axis in axs:
        _label_histogram_axis(
            axis, data_column, x_log_scale, y_log_scale, stat
        )

    if len(axs) == 1:
        return fig, axs[0]
    return fig, axs
def heatmap(adata, column, layer=None, **kwargs):
    """
    Plot the heatmap of the mean feature of cells that belong to a `column`.

    Parameters
    ----------
    adata : anndata.AnnData
        The AnnData object.
    column : str
        Name of the member of adata.obs used to group the cells.
    layer : str, default None
        The name of the `adata` layer to use to calculate the mean feature.
    **kwargs:
        Parameters passed to seaborn heatmap function. They override the
        defaults set by this function (`annot`, `cmap`, `square`, `fmt`,
        `cbar_kws`, `linewidth`, `annot_kws`). An `ax` keyword, if given,
        is used instead of creating a new figure.

    Returns
    -------
    pandas.DataFrame
        A dataframe that has the labels as index and the mean feature for
        every marker as columns.
    matplotlib.figure.Figure
        The figure of the heatmap.
    matplotlib.axes.Axes
        The axes of the heatmap.
    """
    features = adata.to_df(layer=layer)
    labels = adata.obs[column]

    # `observed` is stated explicitly: the implicit default is deprecated
    # for categorical groupers, and False is the value it had so far.
    grouped = pd.concat([features, labels], axis=1).groupby(
        column, observed=False
    )
    mean_feature = grouped.mean()

    n_rows, n_cols = mean_feature.shape

    ax = kwargs.pop('ax', None)
    if ax is None:
        fig, ax = plt.subplots(figsize=(n_cols * 1.5, n_rows * 1.5))
    else:
        fig = ax.get_figure()

    # Defaults first, user keywords last, so that e.g. `cmap` can be
    # overridden instead of raising "multiple values for keyword argument".
    heatmap_kwargs = dict(
        annot=True,
        cmap="Blues",
        square=True,
        fmt=".1f",
        cbar_kws=dict(use_gridspec=False, location="top"),
        linewidth=.5,
        annot_kws={"fontsize": 10},
    )
    heatmap_kwargs.update(kwargs)

    seaborn.heatmap(mean_feature, ax=ax, **heatmap_kwargs)

    ax.tick_params(axis='both', labelsize=25)
    ax.set_ylabel(column, size=25)

    return mean_feature, fig, ax
def hierarchical_heatmap(adata, annotation, features=None, layer=None,
                         cluster_feature=False, cluster_annotations=False,
                         standard_scale=None, z_score="annotation",
                         swap_axes=False, rotate_label=False, **kwargs):
    """
    Generates a hierarchical clustering heatmap and dendrogram.
    By default, the dataset is assumed to have features as columns and
    annotations as rows. Cells are grouped by annotation (e.g., phenotype),
    and for each group, the average expression intensity of each feature
    (e.g., protein or marker) is computed. The heatmap is plotted using
    seaborn's clustermap.

    Parameters
    ----------
    adata : anndata.AnnData
        The AnnData object. The `annotation` column of `adata.obs` is
        converted to a categorical dtype in place if it is not one already.
    annotation : str
        Name of the annotation in adata.obs to group by and calculate mean
        intensity.
    features : list or None, optional
        List of feature names (e.g., markers) to be included in the
        visualization. If None, all features are used. Default is None.
    layer : str, optional
        The name of the `adata` layer to use to calculate the mean
        intensity. If not provided, uses the main matrix. Default is None.
    cluster_feature : bool, optional
        If True, perform hierarchical clustering on the feature axis.
        Default is False.
    cluster_annotations : bool, optional
        If True, perform hierarchical clustering on the annotations axis.
        Default is False.
    standard_scale : int or None, optional
        Whether to standard scale data (0: row-wise or 1: column-wise).
        Default is None.
    z_score : str, optional
        Specifies the axis for z-score normalization. Can be "feature" or
        "annotation". Any other value is passed to `sns.clustermap`
        unchanged. Default is "annotation".
    swap_axes : bool, optional
        If True, switches the axes of the heatmap, effectively transposing
        the dataset. By default (when False), annotations are on the
        vertical axis (rows) and features are on the horizontal axis
        (columns). When set to True, features will be on the vertical axis
        and annotations on the horizontal axis. Default is False.
    rotate_label : bool, optional
        If True, rotate x-axis labels by 45 degrees. Default is False.
    **kwargs:
        Additional parameters passed to `sns.clustermap` function or its
        underlying functions. Some essential parameters include:

        - `cmap` : colormap
          Colormap to use for the heatmap. It's an argument for the
          underlying `sns.heatmap()` used within `sns.clustermap()`.
          Examples include "viridis", "plasma", "coolwarm", etc.
        - `{row,col}_colors` : Lists or DataFrames
          Colors to use for annotating the rows/columns. Useful for
          visualizing additional categorical information alongside the
          main heatmap.
        - `{dendrogram,colors}_ratio` : tuple(float)
          Control the size proportions of the dendrogram and the color
          labels relative to the main heatmap.
        - `cbar_pos` : tuple(float) or None
          Specify the position and size of the colorbar in the figure. If
          set to None, no colorbar will be added.
        - `tree_kws` : dict
          Customize the appearance of the dendrogram tree. Passes
          additional keyword arguments to the underlying
          `matplotlib.collections.LineCollection`.
        - `method` : str
          The linkage algorithm to use for the hierarchical clustering.
          Defaults to 'centroid' in the function, but can be changed.
        - `metric` : str
          The distance metric to use for the hierarchy. Defaults to
          'euclidean' in the function.

    Returns
    -------
    mean_intensity : pandas.DataFrame
        A DataFrame containing the mean intensity of cells for each
        annotation.
    clustergrid : seaborn.matrix.ClusterGrid
        The seaborn ClusterGrid object representing the heatmap and
        dendrograms.
    dendrogram_data : dict
        A dictionary containing hierarchical clustering linkage data for
        both rows and columns ('row_linkage' and 'col_linkage'; None for
        an axis that was not clustered). These linkage matrices can be
        used to generate dendrograms with tools like scipy's dendrogram
        function.

    Raises
    ------
    ValueError
        If the annotation column contains NaN values.

    Examples
    --------
    import matplotlib.pyplot as plt
    import pandas as pd
    import anndata
    from spac.visualization import hierarchical_heatmap

    X = pd.DataFrame([[1, 2], [3, 4]], columns=['gene1', 'gene2'])
    annotation = pd.DataFrame(['type1', 'type2'], columns=['cell_type'])
    all_data = anndata.AnnData(X=X, obs=annotation)

    mean_intensity, clustergrid, dendrogram_data = hierarchical_heatmap(
        all_data,
        "cell_type",
        layer=None,
        z_score="annotation",
        swap_axes=True,
        cluster_feature=False,
        cluster_annotations=True
    )

    # To display a standalone dendrogram using the returned linkage matrix:
    import scipy.cluster.hierarchy as sch
    import numpy as np
    import matplotlib.pyplot as plt

    # Convert the linkage data to type double
    dendro_col_data = np.array(dendrogram_data['col_linkage'], dtype=np.double)

    # Ensure the linkage matrix has at least two dimensions and
    # more than one linkage
    if dendro_col_data.ndim == 2 and dendro_col_data.shape[0] > 1:
        fig, ax = plt.subplots(figsize=(10, 7))
        sch.dendrogram(dendro_col_data, ax=ax)
        plt.title('Standalone Col Dendrogram')
        plt.show()
    else:
        print("Insufficient data to plot a dendrogram.")
    """
    # Use utility functions to check inputs
    check_annotation(adata, annotations=annotation)
    if features:
        check_feature(adata, features=features)
    if layer:
        check_table(adata, tables=layer)

    # Raise an error if there are any NaN values in the annotation column
    if adata.obs[annotation].isna().any():
        raise ValueError("NaN values found in annotation column.")

    # Convert the observation column to categorical if it's not already
    if not isinstance(adata.obs[annotation].dtype, pd.CategoricalDtype):
        adata.obs[annotation] = adata.obs[annotation].astype('category')

    # Calculate mean intensity (rows: annotations, columns: features)
    intensities = adata.to_df(layer=layer) if layer else adata.to_df()
    labels = adata.obs[annotation]
    # `observed` is stated explicitly: the implicit default is deprecated
    # for categorical groupers, and False is the value it had so far.
    grouped = pd.concat([intensities, labels], axis=1).groupby(
        annotation, observed=False
    )
    mean_intensity = grouped.mean()

    # Subset the features while they are still the columns. Doing this
    # after the optional transpose would select rows, which are
    # annotations unless swap_axes is True.
    if features is not None and len(features) > 0:
        if isinstance(features, str):
            features = [features]
        mean_intensity = mean_intensity.loc[:, list(features)]

    # If swap_axes is True, transpose the mean_intensity
    if swap_axes:
        mean_intensity = mean_intensity.T

    # Map z_score based on user's input and the state of swap_axes
    if z_score == "annotation":
        z_score = 0 if not swap_axes else 1
    elif z_score == "feature":
        z_score = 1 if not swap_axes else 0

    # Determine clustering behavior based on swap_axes
    if swap_axes:
        row_cluster = cluster_feature  # Rows are features
        col_cluster = cluster_annotations  # Columns are annotations
    else:
        row_cluster = cluster_annotations  # Rows are annotations
        col_cluster = cluster_feature  # Columns are features

    # Defaults that the caller may override through kwargs
    kwargs.setdefault('method', 'centroid')
    kwargs.setdefault('metric', 'euclidean')

    # Use seaborn's clustermap for hierarchical clustering and
    # heatmap visualization.
    clustergrid = sns.clustermap(
        mean_intensity,
        standard_scale=standard_scale,
        z_score=z_score,
        row_cluster=row_cluster,
        col_cluster=col_cluster,
        **kwargs
    )

    # Rotate x-axis tick labels if rotate_label is True
    if rotate_label:
        plt.setp(clustergrid.ax_heatmap.get_xticklabels(), rotation=45)

    # Extract the dendrogram data for return
    dendro_row_data = None
    dendro_col_data = None

    if clustergrid.dendrogram_row is not None:
        dendro_row_data = clustergrid.dendrogram_row.linkage

    if clustergrid.dendrogram_col is not None:
        dendro_col_data = clustergrid.dendrogram_col.linkage

    # Define the dendrogram_data dictionary
    dendrogram_data = {
        'row_linkage': dendro_row_data,
        'col_linkage': dendro_col_data
    }

    return mean_intensity, clustergrid, dendrogram_data
def threshold_heatmap(
    adata, feature_cutoffs, annotation, layer=None, swap_axes=False, **kwargs
):
    """
    Creates a heatmap for each feature, categorizing intensities into low,
    medium, and high based on provided cutoffs.

    Parameters
    ----------
    adata : anndata.AnnData
        AnnData object containing the feature intensities in .X attribute
        or specified layer. It is modified in place: the cutoffs are stored
        in `adata.uns['feature_cutoffs']`, the categorized intensities in
        `adata.layers['intensity']`, and `adata.obs[annotation]` is
        converted to a categorical dtype.
    feature_cutoffs : dict
        Dictionary with feature names as keys and tuples with two intensity
        cutoffs (low, high) as values. Values up to and including the low
        cutoff are "low" (0), values above it and up to and including the
        high cutoff are "medium" (1), and values above the high cutoff are
        "high" (2).
    annotation : str
        Column name in .obs DataFrame that contains the annotation
        used for grouping.
    layer : str, optional
        Layer name in adata.layers to use for intensities.
        If None, uses .X attribute.
    swap_axes : bool, optional
        If True, swaps the axes of the heatmap.
    **kwargs : keyword arguments
        Additional keyword arguments to pass to scanpy's heatmap function.

    Returns
    -------
    Dictionary of :class:`~matplotlib.axes.Axes`
        A dictionary contains the axes of figures generated in the scanpy
        heatmap function.
        Consistent Key: 'heatmap_ax'
        Potential Keys includes: 'groupby_ax', 'dendrogram_ax', and
        'gene_groups_ax'.

    Raises
    ------
    TypeError
        If `annotation` is not a string or `feature_cutoffs` is not a
        dictionary.
    ValueError
        If a cutoff entry is not a tuple of two elements, a cutoff is NaN,
        or a selected feature contains NaN values.
    """
    # Validate argument types first: the checks below call methods on them
    if not isinstance(annotation, str):
        err_type = type(annotation).__name__
        err_msg = (f'Annotation should be string. Got {err_type}.')
        raise TypeError(err_msg)

    if not isinstance(feature_cutoffs, dict):
        raise TypeError("feature_cutoffs should be a dictionary.")

    for key, value in feature_cutoffs.items():
        if not (isinstance(value, tuple) and len(value) == 2):
            raise ValueError(
                "Each value in feature_cutoffs should be a "
                "tuple of two elements."
            )
        if math.isnan(value[0]):
            raise ValueError(f"Low cutoff for {key} should not be NaN.")
        if math.isnan(value[1]):
            raise ValueError(f"High cutoff for {key} should not be NaN.")

    # Use utility functions for input validation
    check_table(adata, tables=layer)
    check_annotation(adata, annotations=annotation)
    if feature_cutoffs:
        check_feature(adata, features=list(feature_cutoffs.keys()))

    adata.uns['feature_cutoffs'] = feature_cutoffs

    # A layer must have the full (n_obs, n_vars) shape, so each feature is
    # written at its own position in `var_names`. Features without cutoffs
    # stay 0; they are not plotted.
    source = adata.layers[layer] if layer else adata.X
    var_positions = {name: pos for pos, name in enumerate(adata.var_names)}
    intensity = np.zeros((adata.n_obs, adata.n_vars), dtype=int)

    for feature, (low_cutoff, high_cutoff) in feature_cutoffs.items():
        position = var_positions[feature]
        column = source[:, position]
        # Sparse matrices must be made dense before flattening
        if hasattr(column, 'toarray'):
            column = column.toarray()
        feature_values = np.asarray(column).ravel()

        if pd.isna(feature_values).any():
            raise ValueError(
                f"Feature '{feature}' contains NaN values; "
                "cannot categorize intensities."
            )

        intensity[:, position] = np.where(
            feature_values > high_cutoff,
            2,
            np.where(feature_values > low_cutoff, 1, 0)
        )

    adata.layers["intensity"] = intensity

    adata.obs[annotation] = adata.obs[annotation].astype('category')

    color_map = {0: (0/255, 0/255, 139/255), 1: 'green', 2: 'yellow'}
    colors = [color_map[i] for i in range(3)]
    cmap = ListedColormap(colors)
    norm = BoundaryNorm([-0.5, 0.5, 1.5, 2.5], cmap.N)

    heatmap_plot = sc.pl.heatmap(
        adata,
        var_names=list(feature_cutoffs),
        groupby=annotation,
        use_raw=False,
        layer='intensity',
        cmap=cmap,
        norm=norm,
        show=False,  # Ensure the plot is not displayed but returned
        swap_axes=swap_axes,
        **kwargs
    )

    logging.debug("Keys of heatmap_plot: %s", heatmap_plot.keys())

    # Get the main heatmap axis from the available keys
    heatmap_ax = heatmap_plot.get('heatmap_ax')

    # If 'heatmap_ax' key does not exist, access the first axis available
    if heatmap_ax is None:
        heatmap_ax = next(iter(heatmap_plot.values()))
    logging.debug("Heatmap Axes: %s", heatmap_ax)

    # Find the colorbar associated with the heatmap
    cbar = None
    for child in heatmap_ax.get_children():
        if hasattr(child, 'colorbar'):
            cbar = child.colorbar
            break

    if cbar is None:
        # Still return the axes so callers always get the documented dict
        logging.warning("No colorbar found in the plot.")
        return heatmap_plot
    logging.debug("Colorbar: %s", cbar)

    new_ticks = [0, 1, 2]
    new_labels = ['Low', 'Medium', 'High']
    cbar.set_ticks(new_ticks)
    cbar.set_ticklabels(new_labels)

    # Place the colorbar directly to the right of the heatmap
    pos_heatmap = heatmap_ax.get_position()
    cbar.ax.set_position(
        [pos_heatmap.x1 + 0.02, pos_heatmap.y0, 0.02, pos_heatmap.height]
    )

    return heatmap_plot
def spatial_plot(
        adata,
        spot_size,
        alpha,
        vmin=-999,
        vmax=-999,
        annotation=None,
        feature=None,
        layer=None,
        ax=None,
        **kwargs
):
    """
    Generate the spatial plot of selected features

    Parameters
    ----------
    adata : anndata.AnnData
        The AnnData object containing target feature and spatial
        coordinates (in `adata.obsm['spatial']`). When a feature is
        plotted, its values are written to the column
        `adata.obs[feature + "spatial_plot"]`.
    spot_size : int
        The size of spot on the spatial plot.
    alpha : float
        The transparency of spots, range from 0 (invisible) to 1 (solid).
        The integers 0 and 1 are accepted as well.
    vmin : float or int
        The lower limit of the feature value for visualization. The
        default -999 means "use the minimum of the feature values".
    vmax : float or int
        The upper limit of the feature value for visualization. The
        default -999 means "use the maximum of the feature values".
    feature : str
        The feature to visualize on the spatial plot.
        Default None.
    annotation : str
        The annotation to visualize in the spatial plot.
        Can't be set with feature, default None.
    layer : str
        Name of the AnnData object layer that wants to be plotted.
        By default adata.X is plotted.
    ax : matplotlib.axes.Axes
        The matplotlib Axes containing the analysis plots.
        The returned ax is the passed ax or new ax created.
        Only works if plotting a single component.
    **kwargs
        Arguments to pass to matplotlib.pyplot.scatter()

    Returns
    -------
    Single or a list of class:`~matplotlib.axes.Axes`.

    Raises
    ------
    ValueError
        If any argument fails validation (wrong type, unknown layer,
        annotation or feature, missing spatial coordinates, both or
        neither of `annotation` and `feature` given, `alpha` out of range).
    """
    if adata is None:
        raise ValueError("The input dataset must not be None.")

    if not isinstance(adata, anndata.AnnData):
        err_msg_adata = "The 'adata' parameter must be an " + \
            f"instance of anndata.AnnData, got {str(type(adata))}."
        raise ValueError(err_msg_adata)

    if layer is not None and not isinstance(layer, str):
        raise ValueError(
            "The 'layer' parameter must be a string, "
            f"got {str(type(layer))}"
        )

    if layer is not None and layer not in adata.layers.keys():
        err_msg_layer_exist = f"Layer {layer} does not exists, " + \
            f"available layers are {str(adata.layers.keys())}"
        raise ValueError(err_msg_layer_exist)

    if feature is not None and not isinstance(feature, str):
        raise ValueError(
            "The 'feature' parameter must be a string, "
            f"got {str(type(feature))}"
        )

    if annotation is not None and not isinstance(annotation, str):
        raise ValueError(
            "The 'annotation' parameter must be a string, "
            f"got {str(type(annotation))}"
        )

    if annotation is not None and feature is not None:
        raise ValueError(
            "Both annotation and feature are passed, "
            "please provide sinle input."
        )

    if annotation is None and feature is None:
        raise ValueError(
            "Both annotation and feature are None, "
            "please provide single input."
        )

    if 'spatial' not in adata.obsm_keys():
        err_msg = "Spatial coordinates not found in the 'obsm' attribute."
        raise ValueError(err_msg)

    # Extract annotation name
    annotation_names = adata.obs.columns.tolist()
    annotation_names_str = ", ".join(annotation_names)

    if annotation is not None and annotation not in annotation_names:
        error_text = f'The annotation "{annotation}"' + \
            'not found in the dataset.' + \
            f" Existing annotations are: {annotation_names_str}"
        raise ValueError(error_text)

    # Extract feature name
    if layer is None:
        layer_process = adata.X
    else:
        layer_process = adata.layers[layer]

    feature_names = adata.var_names.tolist()

    if feature is not None and feature not in feature_names:
        error_text = f"Feature {feature} not found," + \
            " please check the sample metadata."
        raise ValueError(error_text)

    # NumPy integers are as valid a spot size as built-in ints
    if not isinstance(spot_size, (int, np.integer)):
        raise ValueError(
            "The 'spot_size' parameter must be an integer, "
            f"got {str(type(spot_size))}"
        )

    # Accept the integers 0 and 1 too, since the range is inclusive;
    # booleans remain rejected.
    real_types = (float, int, np.floating, np.integer)
    if isinstance(alpha, bool) or not isinstance(alpha, real_types):
        raise ValueError(
            "The 'alpha' parameter must be a float,"
            f"got {str(type(alpha))}"
        )

    if not (0 <= alpha <= 1):
        raise ValueError(
            "The 'alpha' parameter must be between "
            f"0 and 1 (inclusive), got {str(alpha)}"
        )

    if vmin != -999 and not isinstance(vmin, real_types):
        raise ValueError(
            "The 'vmin' parameter must be a float or an int, "
            f"got {str(type(vmin))}"
        )

    if vmax != -999 and not isinstance(vmax, real_types):
        raise ValueError(
            "The 'vmax' parameter must be a float or an int, "
            f"got {str(type(vmax))}"
        )

    if ax is not None and not isinstance(ax, plt.Axes):
        raise ValueError(
            "The 'ax' parameter must be an instance "
            f"of matplotlib.axes.Axes, got {str(type(ax))}"
        )

    if feature is not None:
        feature_index = feature_names.index(feature)
        feature_annotation = feature + "spatial_plot"

        feature_values = layer_process[:, feature_index]
        # Sparse matrices must be made dense before flattening
        if hasattr(feature_values, 'toarray'):
            feature_values = feature_values.toarray()
        feature_values = np.asarray(feature_values).ravel()

        # -999 is the "not provided" sentinel for both limits
        if vmin == -999:
            vmin = np.min(feature_values)
        if vmax == -999:
            vmax = np.max(feature_values)

        # The values are exposed as an obs column and colored by its name
        adata.obs[feature_annotation] = feature_values
        color_region = feature_annotation
    else:
        color_region = annotation
        vmin = None
        vmax = None

    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)

    ax = sc.pl.spatial(
        adata=adata,
        layer=layer,
        color=color_region,
        spot_size=spot_size,
        alpha=alpha,
        vmin=vmin,
        vmax=vmax,
        ax=ax,
        show=False,
        **kwargs)

    return ax
def boxplot(adata, annotation=None, second_annotation=None, layer=None,
            ax=None, features=None, log_scale=False, **kwargs):
    """
    Create a boxplot visualization of the features in the passed adata object.
    This function offers flexibility in how the boxplots are displayed,
    based on the arguments provided.

    Parameters
    ----------
    adata : anndata.AnnData
        The AnnData object.

    annotation : str, optional
        Annotation to determine if separate plots are needed for every label.

    second_annotation : str, optional
        Second annotation to further divide the data. Only used together
        with `annotation`; in that case only the first entry of `features`
        is plotted.

    layer : str, optional
        The name of the matrix layer to use. If not provided,
        uses the main data matrix adata.X.

    ax : matplotlib.axes.Axes, optional
        An existing Axes object to draw the plot onto, optional.

    features : list, optional
        List of feature names to be plotted.
        If not provided, all features will be plotted.

    log_scale : bool, optional
        If True, a log1p transform is applied to the plotted values.
        Disabled (with a printed notice) when the data contain negative
        values. Default is False.

    **kwargs
        Additional arguments to pass to seaborn.boxplot.
        Key arguments include:
        - `orient`: Determines the orientation of the plot.
        * "v": Vertical orientation (default). In this case, categorical data
           will be plotted on the x-axis, and the boxplots will be vertical.
        * "h": Horizontal orientation. Categorical data will be plotted on the
           y-axis, and the boxplots will be horizontal.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure containing the plot.
    ax : matplotlib.axes.Axes
        The axes the boxplot was drawn on.
    df : pandas.DataFrame
        The data that was plotted: the selected feature columns (log1p
        transformed when log scaling is active) plus the annotation
        columns, if any.

    Raises
    ------
    TypeError
        If `ax` is provided but is not a matplotlib Axes.

    Examples
    --------
    - Multiple features boxplot:
      boxplot(adata, features=['GeneA','GeneB'])

    - Boxplot grouped by a single annotation:
      boxplot(adata, features=['GeneA'], annotation='cell_type')

    - Boxplot for multiple features grouped by a single annotation:
      boxplot(adata, features=['GeneA', 'GeneB'], annotation='cell_type')

    - Nested grouping by two annotations:
      boxplot(adata, features=['GeneA'], annotation='cell_type',
              second_annotation='treatment')
    """

    # Use utility functions to check inputs
    print("Calculating Box Plot...")
    if layer:
        check_table(adata, tables=layer)
    if annotation:
        check_annotation(adata, annotations=annotation)
    if second_annotation:
        check_annotation(adata, annotations=second_annotation)
    if features:
        check_feature(adata, features=features)

    # Vertical orientation is the default; anything else is horizontal
    if 'orient' not in kwargs:
        kwargs['orient'] = 'v'
    v_orient = kwargs['orient'] == 'v'

    # Validate ax instance
    if ax and not isinstance(ax, plt.Axes):
        raise TypeError("Input 'ax' must be a matplotlib.axes.Axes object.")

    # Use the specified layer if provided
    if layer:
        data_matrix = adata.layers[layer]
    else:
        data_matrix = adata.X

    # Create a DataFrame from the data matrix with features as columns
    df = pd.DataFrame(data_matrix, columns=adata.var_names)

    # Add annotations to the DataFrame if provided
    if annotation:
        df[annotation] = adata.obs[annotation].values
    if second_annotation:
        df[second_annotation] = adata.obs[second_annotation].values

    # If features is None, set it to all available features
    if features is None:
        features = adata.var_names.tolist()

    df = df[
        features
        + ([annotation] if annotation else [])
        + ([second_annotation] if second_annotation else [])
    ]

    # Check for negative values
    if log_scale and (df[features] < 0).any().any():
        print(
            "There are negative values in this data, disabling the log scale."
        )
        log_scale = False

    # Apply log1p transformation if log_scale is True
    if log_scale:
        df[features] = np.log1p(df[features])

    # Create the plot
    if ax:
        fig = ax.get_figure()
    else:
        fig, ax = plt.subplots(figsize=(10, 5))

    # Plotting logic based on provided annotations
    if annotation and second_annotation:
        if v_orient:
            sns.boxplot(data=df, y=features[0], x=annotation,
                        hue=second_annotation, ax=ax, **kwargs)
        else:
            sns.boxplot(data=df, y=annotation, x=features[0],
                        hue=second_annotation, ax=ax, **kwargs)

        title_str = f"Nested Grouping by {annotation} and {second_annotation}"
        ax.set_title(title_str)

    elif annotation:
        if len(features) > 1:
            # Reshape the dataframe to long format for visualization
            melted_data = df.melt(id_vars=annotation)
            if v_orient:
                sns.boxplot(data=melted_data, x="variable", y="value",
                            hue=annotation, ax=ax, **kwargs)
            else:
                sns.boxplot(data=melted_data, x="value", y="variable",
                            hue=annotation, ax=ax, **kwargs)
            ax.set_title(f"Multiple Features Grouped by {annotation}")
        else:
            if v_orient:
                sns.boxplot(data=df, y=features[0], x=annotation,
                            ax=ax, **kwargs)
            else:
                sns.boxplot(data=df, x=features[0], y=annotation,
                            ax=ax, **kwargs)
            ax.set_title(f"Grouped by {annotation}")

    else:
        if len(features) > 1:
            if v_orient:
                sns.boxplot(data=df[features], ax=ax, **kwargs)
            else:
                # Long format is needed to place features on the y-axis.
                # `annotation` is falsy in this branch, so no hue grouping
                # is applied.
                melted_data = df.melt()
                sns.boxplot(data=melted_data, x="value", y="variable",
                            hue=annotation, ax=ax, **kwargs)
            ax.set_title("Multiple Features")
        else:
            if v_orient:
                sns.boxplot(y=df[features[0]], ax=ax, **kwargs)
                ax.set_xticks([0])  # Set a single tick for the single feature
                ax.set_xticklabels([features[0]])  # Set the label for the tick
            else:
                sns.boxplot(x=df[features[0]], ax=ax, **kwargs)
                ax.set_yticks([0])  # Set a single tick for the single feature
                ax.set_yticklabels([features[0]])  # Set the label for the tick
            ax.set_title("Single Boxplot")

    # Set x and y-axis labels
    intensity_label = 'log(Intensity)' if log_scale else 'Intensity'
    category_label = annotation if annotation else 'Feature'
    if v_orient:
        ax.set_xlabel(category_label)
        ax.set_ylabel(intensity_label)
    else:
        ax.set_xlabel(intensity_label)
        ax.set_ylabel(category_label)

    plt.xticks(rotation=90)
    plt.tight_layout()

    return fig, ax, df
def interative_spatial_plot(
    adata,
    annotations,
    dot_size=1.5,
    dot_transparancy=0.75,
    colorscale='viridis',
    figure_width=6,
    figure_height=4,
    figure_dpi=200,
    font_size=12,
    stratify_by=None,
    defined_color_map=None,
    **kwargs
):
    """
    Create an interactive scatter plot for
    spatial data using provided annotations.

    Parameters
    ----------
    adata : AnnData
        Annotated data matrix object,
        must have a .obsm attribute with 'spatial' key.
    annotations : list of str or str
        Column(s) in `adata.obs` that contain the annotations to plot.
        If a single string is provided, it will be converted to a list.
        The interactive plot will show all the labels in the annotation
        columns passed.
    dot_size : float, optional
        Size of the scatter dots in the plot. Default is 1.5.
    dot_transparancy : float, optional
        Transparancy level of the scatter dots. Default is 0.75.
    colorscale : str, optional
        Name of the color scale to use for the dots. Default is 'viridis'.
    figure_width : int, optional
        Width of the figure in inches. Default is 6.
    figure_height : int, optional
        Height of the figure in inches. Default is 4.
    figure_dpi : int, optional
        DPI (dots per inch) for the figure. Default is 200.
    font_size : int, optional
        Font size for text in the plot. Default is 12.
    stratify_by : str, optional
        Column in `adata.obs` to stratify the plot. One figure is generated
        per unique value of that column. Default is None.
    defined_color_map : str, optional
        Predefined color mapping stored in adata.uns for specific labels.
        Default is None, which will generate the color mapping automatically.
    **kwargs
        Additional keyword arguments for customization.

    Returns
    -------
    list of dict
        A list of dictionaries, each containing the following keys:
        - "image_name": str, the name of the generated image.
        - "image_object": Plotly Figure object.

    Raises
    ------
    TypeError
        If `defined_color_map` is given but is not a string.
    ValueError
        If `defined_color_map` is given but `adata.uns` is empty or does
        not contain that key.

    Notes
    -----
    This function is tailored for spatial single-cell data and expects the
    AnnData object to have spatial coordinates in its `.obsm` attribute under
    the 'spatial' key.
    """

    if not isinstance(annotations, list):
        annotations = [annotations]

    for annotation in annotations:
        check_annotation(
            adata,
            annotations=annotation
        )

    check_table(
        adata,
        tables='spatial',
        associated_table=True
    )

    if defined_color_map is not None:
        if not isinstance(defined_color_map, str):
            raise TypeError(
                'The "defined_color_map" should be a string ' + \
                f'getting {type(defined_color_map)}.'
            )
        uns_keys = list(adata.uns.keys())
        if len(uns_keys) == 0:
            raise ValueError(
                "No existing color map found, please" + \
                " make sure the Append Pin Color Rules " + \
                "template had been ran prior to the "+ \
                "current visualization node.")

        if defined_color_map not in uns_keys:
            raise ValueError(
                f'The given color map name: {defined_color_map} ' + \
                "is not found in current analysis, " + \
                f'available items are: {uns_keys}'
            )
        defined_color_map_dict = adata.uns[defined_color_map]
        print(
            f'Selected color mapping "{defined_color_map}":\n' + \
            f'{defined_color_map_dict}'
        )

    def main_figure_generation(
        adata,
        annotations=annotations,
        dot_size=dot_size,
        dot_transparancy=dot_transparancy,
        colorscale=colorscale,
        figure_width=figure_width,
        figure_height=figure_height,
        figure_dpi=figure_dpi,
        font_size=font_size,
        **kwargs
    ):
        """
        Create the core interactive plot for downstream processing.

        This function generates the main interactive plot using Plotly
        that contains the spatial scatter plot with annotations and
        image configuration.

        Parameters
        ----------
        adata : AnnData
            Annotated data matrix object,
            must have a .obsm attribute with 'spatial' key.
        annotations : list of str
            Column(s) in `adata.obs` that contain the annotations to plot.
        dot_size : float, optional
            Size of the scatter dots in the plot.
        dot_transparancy : float, optional
            Transparancy level of the scatter dots.
        colorscale : str, optional
            Name of the color scale to use for the dots.
        figure_width : int, optional
            Width of the figure in inches.
        figure_height : int, optional
            Height of the figure in inches.
        figure_dpi : int, optional
            DPI (dots per inch) for the figure.
        font_size : int, optional
            Font size for text in the plot.
        **kwargs
            Accepted and ignored; lets callers pass extra options through.

        Returns
        -------
        plotly.graph_objs._figure.Figure
        """
        spatial_coords = adata.obsm['spatial']

        extract_columns_raw = []

        for item in annotations:
            extract_columns_raw.append(adata.obs[item])

        extract_columns = []

        # The `extract_columns` list is needed for generating Plotly images
        # because it stores transformed annotation data. These annotations
        # are added as columns in the DataFrame (`df`) and are used as inputs
        # for the `color` and `hover_data` parameters in the Plotly scatter
        # plot. Each value is prefixed with its annotation name so the
        # annotation can be recovered from the trace later on.
        for i, item in enumerate(extract_columns_raw):
            extract_columns.append(
                [annotations[i] + "_" + str(value) for value in item]
            )

        xcoord = [coord[0] for coord in spatial_coords]
        ycoord = [coord[1] for coord in spatial_coords]

        data = {'X': xcoord, 'Y': ycoord}

        # Add the extract_columns data as columns in the dictionary
        for i, column in enumerate(extract_columns):
            column_name = annotations[i]
            data[column_name] = column

        # Create the DataFrame
        df = pd.DataFrame(data)

        # Pad the axis ranges by 10% around the coordinate extremes
        max_x_range = max(xcoord) * 1.1
        min_x_range = min(xcoord) * 0.9
        max_y_range = max(ycoord) * 1.1
        min_y_range = min(ycoord) * 0.9

        width_px = int(figure_width * figure_dpi)
        height_px = int(figure_height * figure_dpi)

        main_fig = px.scatter(
            df,
            x='X',
            y='Y',
            color=annotations[0],
            hover_data=[annotations[0]]
        )

        # If annotation is more than 1, we would first call px.scatter
        # to create plotly object, than append the data to main figure
        # with add_trace for a centralized view.
        if len(annotations) > 1:
            for obs in annotations[1:]:
                scatter_fig = px.scatter(
                    df,
                    x='X',
                    y='Y',
                    color=obs,
                    hover_data=[obs]
                )

                main_fig.add_traces(scatter_fig.data)

        # Reset the color attribute of the traces in combined_fig
        # This is necessary to ensure that the color attribute
        # does not interfere with subsequent plots
        for trace in main_fig.data:
            trace.marker.color = None

        main_fig.update_traces(
            mode='markers',
            marker=dict(
                size=dot_size,
                colorscale=colorscale,
                opacity=dot_transparancy
            ),
            hovertemplate="%{customdata[0]}<extra></extra>"
        )

        main_fig.update_layout(
            width=width_px,
            height=height_px,
            plot_bgcolor='white',
            font=dict(size=font_size),
            margin=dict(l=10, r=10, t=10, b=10),
            legend=dict(
                orientation='v',
                yanchor='middle',
                y=0.5,
                xanchor='right',
                x=1.15,
                title='',
                itemwidth=30,
                bgcolor="rgba(0, 0, 0, 0)",
                traceorder='normal',
                entrywidth=50
            ),
            xaxis=dict(
                range=[min_x_range, max_x_range],
                showgrid=False,
                showticklabels=False,
                title_standoff=5,
                constrain="domain"
            ),
            # The y range is given as [max, min], i.e. the axis is flipped
            yaxis=dict(
                range=[max_y_range, min_y_range],
                showgrid=False,
                scaleanchor="x",
                scaleratio=1,
                showticklabels=False,
                title_standoff=5,
                constrain="domain"
            ),
            shapes=[
                go.layout.Shape(
                    type="rect",
                    xref="x",
                    yref="y",
                    x0=min_x_range,
                    y0=min_y_range,
                    x1=max_x_range,
                    y1=max_y_range,
                    line=dict(color="black", width=1),
                    fillcolor="rgba(0,0,0,0)",
                )
            ]
        )
        return main_fig

    def generate_and_update_image(
            adata,
            title,
            color_mapping=None,
            **kwargs
    ):
        """
        This function generates the main figure with annotations and
        optional color mappings. It rebuilds the traces of the base figure
        so that every annotation gets a bold legend header followed by one
        trace per label.

        Parameters
        ----------
        adata : AnnData
            Annotated data matrix containing
            either the full dataset or a subset of the data.
        title : str
            Title for the plot.
        color_mapping : dict, optional
            Color mapping for specific labels. Default is None.
        **kwargs
            Forwarded to `main_figure_generation`. Callers pass
            `stratify_by` through here; it is not otherwise used.

        Returns
        -------
        dict
            A dictionary with "image_name" and "image_object" keys.
        """

        main_fig_parent = main_figure_generation(
            adata,
            annotations=annotations,
            dot_size=dot_size,
            dot_transparancy=dot_transparancy,
            colorscale=colorscale,
            figure_width=figure_width,
            figure_height=figure_height,
            figure_dpi=figure_dpi,
            font_size=font_size,
            **kwargs
        )

        # Create a copy of the figure for non-destructive updates
        main_fig_copy = copy.copy(main_fig_parent)
        data = main_fig_copy.data
        main_fig_parent.data = []

        # One legend name per annotation: "legend", "legend2", "legend3", ...
        legend_list = [
            f"legend{i+1}" if i > 0 else "legend"
            for i in range(len(annotations))
        ]
        previous_group = None

        # Process each trace in the figure for grouping and legends
        for item in range(len(data)):
            cat_label = data[item]['customdata'][0][0]
            cat_dataset = pd.DataFrame(
                {'X': data[item]['x'], 'Y': data[item]['y']}
            )

            # Assign the label to the appropriate legend group by stripping
            # the "<annotation>_" prefix added when the figure was built
            for i, legend_group in enumerate(annotations):
                if cat_label.startswith(legend_group):
                    cat_leg_group = f"<b>{legend_group}</b>"
                    cat_label = cat_label[len(legend_group) + 1:]
                    cat_group = legend_list[i]

            # Add a new legend entry if this group hasn't been encountered.
            # The entry is an invisible single-point trace that only serves
            # as a bold header in the legend.
            if previous_group is None or cat_group != previous_group:
                main_fig_parent.add_trace(go.Scattergl(
                    x=[data[item]['x'][0]],
                    y=[data[item]['y'][0]],
                    name=cat_leg_group,
                    mode="markers",
                    showlegend=True,
                    marker=dict(
                        color="white",
                        colorscale=None,
                        size=0,
                        opacity=0
                    )
                ))
                previous_group = cat_group

            # Add the category label to the dataset for grouping
            cat_dataset['label'] = cat_label
            main_fig_parent.add_trace(go.Scattergl(
                x=cat_dataset['X'],
                y=cat_dataset['Y'],
                name=cat_label,
                mode="markers",
                showlegend=True,
                marker=dict(
                    colorscale=colorscale,
                    size=dot_size,
                    opacity=dot_transparancy
                )
            ))

        # Re-add every trace, coloring those whose name has a mapped color
        if color_mapping is not None:
            main_fig_copy = copy.copy(main_fig_parent)
            data = main_fig_copy.data
            main_fig_parent.data = []
            for trace in data:
                trace_name = trace["name"]
                if trace_name in color_mapping.keys():
                    trace['marker']['color'] = color_mapping[trace_name]
                main_fig_parent.add_trace(trace)

        main_fig_parent.update_layout(
            title={
                'text': title,
                'font': {'size': font_size},
                'xanchor': 'center',
                'yanchor': 'top',
                'x': 0.5,
                'y': 0.99
            },
            legend={
                'x': 1.05,
                'y': 0.5,
                'xanchor': 'left',
                'yanchor': 'middle'
            },
            margin=dict(l=5, r=5, t=font_size*2, b=5)
        )

        return {
            "image_name": f"{spell_out_special_characters(title)}.html",
            "image_object": main_fig_parent
        }

    #####################
    ## Main Code Block ##
    #####################

    results = []
    if defined_color_map:
        color_dict = adata.uns[defined_color_map]
    else:
        unique_ann_labels = np.unique(adata.obs[annotations].values)
        color_dict = color_mapping(
            unique_ann_labels,
            color_map=colorscale,
            rgba_mode=False,
            return_dict=True
        )

    if stratify_by is not None:
        unique_stratification_values = adata.obs[stratify_by].unique()

        for strat_value in unique_stratification_values:
            condition = adata.obs[stratify_by] == strat_value
            title = f"Highlighting {stratify_by}: {strat_value}"
            indices = np.where(condition)[0]
            selected_spatial = adata.obsm['spatial'][indices]
            print(f"number of cells in the region: {len(selected_spatial)}")
            adata_subset = select_values(
                data=adata,
                annotation=stratify_by,
                values=strat_value
            )
            result = generate_and_update_image(
                adata=adata_subset,
                title=title,
                stratify_by=stratify_by,
                color_mapping=color_dict,
                **kwargs
            )
            results.append(result)
    else:
        title = "Interactive Spatial Plot"
        result = generate_and_update_image(
            adata=adata,
            title=title,
            stratify_by=None,
            color_mapping=color_dict,
            **kwargs
        )
        results.append(result)

    return results
def sankey_plot(
        adata: anndata.AnnData,
        source_annotation: str,
        target_annotation: str,
        source_color_map: str = "tab20",
        target_color_map: str = "tab20c",
        sankey_font: float = 12.0,
        prefix: bool = True
):
    """
    Build a Sankey diagram linking the labels of two annotations.

    Source nodes are colored with `source_color_map` and target nodes with
    `target_color_map`; every link takes the color of its source node.
    The color maps refer to matplotlib color maps. For more information
    on colormaps, see:
    https://matplotlib.org/stable/users/explain/colors/colormaps.html

    Parameters
    ----------
    adata : anndata.AnnData
        The annotated data matrix.
    source_annotation : str
        The source annotation to use for the Sankey plot.
    target_annotation : str
        The target annotation to use for the Sankey plot.
    source_color_map : str
        The color map to use for the source nodes. Default is tab20.
    target_color_map : str
        The color map to use for the target nodes. Default is tab20c.
    sankey_font : float, optional
        The font size to use for the Sankey plot. Defaults to 12.0.
        Node padding, node thickness and the figure margins are derived
        from it.
    prefix : bool, optional
        Passed to `annotation_category_relations`; described upstream as
        whether to prefix the target labels with the source labels.
        Defaults to True.

    Returns
    -------
    plotly.graph_objs._figure.Figure
        The generated Sankey plot.
    """

    relations = annotation_category_relations(
        adata=adata,
        source_annotation=source_annotation,
        target_annotation=target_annotation,
        prefix=prefix
    )

    # Nodes: all source labels first, followed by all target labels
    source_nodes = relations["source"].unique().tolist()
    target_nodes = relations["target"].unique().tolist()
    node_labels = source_nodes + target_nodes

    source_colors = color_mapping(source_nodes, source_color_map)
    target_colors = color_mapping(target_nodes, target_color_map)
    node_colors = source_colors + target_colors

    # Lookups from label to node position and from source label to color
    node_position = {
        label: position for position, label in enumerate(node_labels)
    }
    source_color_lookup = dict(zip(source_nodes, source_colors))

    # Links: one per row of the relations table
    link_sources = []
    link_targets = []
    link_values = []
    link_colors = []
    for _, relation in relations.iterrows():
        origin = relation['source']
        link_sources.append(node_position[origin])
        link_targets.append(node_position[relation['target']])
        link_values.append(relation['count'])
        link_colors.append(source_color_lookup[origin])

    node_spec = dict(
        pad=sankey_font * 1.05,
        thickness=sankey_font * 1.05,
        line=dict(color=None, width=0.1),
        label=node_labels,
        color=node_colors
    )
    link_spec = dict(
        arrowlen=15,
        source=link_sources,
        target=link_targets,
        value=link_values,
        color=link_colors
    )

    fig = go.Figure(go.Sankey(
        node=node_spec,
        link=link_spec,
        arrangement="snap",
        textfont=dict(
            color='black',
            size=sankey_font
        )
    ))

    # Percentages are attached as customdata so the hover text can show them
    sankey_trace = fig.data[0]
    sankey_trace.link.customdata = relations[
        ['percentage_source', 'percentage_target']
    ]
    sankey_trace.link.hovertemplate = (
        '%{source.label} to %{target.label}<br>'
        '%{customdata[0]}% to %{customdata[1]}%<br>'
        'Count: %{value}<extra></extra>'
    )

    # Title and margins
    fig.update_layout(
        title_text=(
            f'"{source_annotation}" to "{target_annotation}"<br>Sankey Diagram'
        ),
        title_x=0.5,
        title_font=dict(
            family='Arial, bold',
            size=sankey_font,
            color="black"
        ),
        margin=dict(
            l=10,
            r=10,
            t=sankey_font * 3,
            b=sankey_font
        )
    )

    return fig
def relational_heatmap(
        adata: anndata.AnnData,
        source_annotation: str,
        target_annotation: str,
        color_map: str = "mint",
        **kwargs
):
    """
    Generates a relational heatmap from the given AnnData object.

    Each cell shows the `percentage_source` value reported by
    `annotation_category_relations` for a (source label, target label)
    pair. Source labels are the rows (y axis) and target labels are the
    columns (x axis).

    Parameters
    ----------
    adata : anndata.AnnData
        The annotated data matrix.
    source_annotation : str
        The source annotation to use for the relational heatmap.
    target_annotation : str
        The target annotation to use for the relational heatmap.
    color_map : str
        The color scale passed as `colorscale` to
        `plotly.figure_factory.create_annotated_heatmap`.
        Default is mint.
    **kwargs : dict, optional
        Additional keyword arguments. Recognized keys:
        - `font_size` (default 12.0): size of the cell annotations; also
          used to derive the top and bottom margins.
        - `prefix` (default True): forwarded to
          `annotation_category_relations`.

    Returns
    -------
    dict
        A dictionary containing:
        - "figure" (plotly.graph_objs._figure.Figure):
            The generated relational heatmap as a Plotly figure.
        - "file_name" (str):
            The name of the file where the relational matrix can be saved.
        - "data" (pandas.DataFrame):
            A relational matrix DataFrame with percentage values.
            Rows represent source annotations,
            columns represent target annotations,
            and an additional "total" column sums
            the percentages for each source.
    """
    # Default font size
    font_size = kwargs.get('font_size', 12.0)
    prefix = kwargs.get('prefix', True)

    # Get the relationship between source and target annotations
    label_relations = annotation_category_relations(
        adata=adata,
        source_annotation=source_annotation,
        target_annotation=target_annotation,
        prefix=prefix
    )

    # Pivot the data to create a matrix for the heatmap:
    # rows are source labels, columns are target labels
    heatmap_matrix = label_relations.pivot(
        index='source',
        columns='target',
        values='percentage_source'
    )
    heatmap_matrix = heatmap_matrix.fillna(0)

    x = list(heatmap_matrix.columns)
    y = list(heatmap_matrix.index)

    # NOTE(review): customdata below also carries 'percentage_source', so
    # the "Target" line of the hover text repeats the source percentage.
    # It is left as-is because the returned "data" matrix is built from
    # customdata - confirm the intended hover content.
    hover_template = 'Source: %{z}%<br>Target: %{customdata}%<extra></extra>'

    # Build z row by row from the relations table. Missing pairs are the
    # integer 0, which is what the annotated heatmap prints in those cells.
    z = []
    for y_item in y:
        row_values = []
        for x_item in x:
            z_data_point = label_relations[
                (label_relations['target'] == x_item)
                & (label_relations['source'] == y_item)
            ]['percentage_source']
            row_values.append(
                0 if len(z_data_point) == 0 else z_data_point.iloc[0]
            )
        z.append(row_values)

    # Create heatmap
    fig = ff.create_annotated_heatmap(
        z=z,
        colorscale=color_map,
        customdata=heatmap_matrix.values,
        hovertemplate=hover_template
    )

    # x ticks are the target labels and y ticks are the source labels, so
    # each axis is titled with the annotation it actually displays.
    fig.update_layout(
        overwrite=True,
        xaxis=dict(
            title=target_annotation,
            ticks="",
            dtick=1,
            side="top",
            gridcolor="rgb(0, 0, 0)",
            tickvals=list(range(len(x))),
            ticktext=x
        ),
        yaxis=dict(
            title=source_annotation,
            ticks="",
            dtick=1,
            ticksuffix="   ",
            tickvals=list(range(len(y))),
            ticktext=y
        ),
        margin=dict(
            l=5,
            r=5,
            t=font_size * 2,
            b=font_size * 2
        )
    )

    for i in range(len(fig.layout.annotations)):
        fig.layout.annotations[i].font.size = font_size

    fig.update_xaxes(
        side="bottom",
        tickangle=90
    )

    # Data output section: rebuild the matrix from what the figure holds
    data = fig.data[0]
    layout = fig.layout

    matrix = pd.DataFrame(data['customdata'])
    matrix.index = layout['yaxis']['ticktext']
    matrix.columns = layout['xaxis']['ticktext']
    matrix["total"] = matrix.sum(axis=1)
    matrix = matrix.fillna(0)

    file_name = f"{source_annotation}_to_{target_annotation}" + \
                "_relation_matrix.csv"

    return {"figure": fig, "file_name": file_name, "data": matrix}
def plot_ripley_l(
        adata,
        phenotypes,
        annotation=None,
        regions=None,
        sims=False,
        return_df=False,
        **kwargs):
    """
    Plot Ripley's L statistic for multiple bins and different regions
    for a given pair of phenotypes.

    Parameters
    ----------
    adata : AnnData
        AnnData object containing Ripley's L results in
        `adata.uns['ripley_l']`.
    phenotypes : tuple of str
        A tuple of two phenotypes: (center_phenotype, neighbor_phenotype).
    annotation : str, optional
        Not used by this function; accepted for interface compatibility.
    regions : list of str, optional
        A list of region labels to plot. If None, plot all available regions.
        Default is None.
    sims : bool, optional
        Whether to plot the simulation results. Default is False.
    return_df : bool, optional
        Whether to return the DataFrame containing the Ripley's L results.
    kwargs : dict, optional
        Additional keyword arguments to pass to `seaborn.lineplot`.

    Raises
    ------
    ValueError
        If the Ripley L results are not found in `adata.uns['ripley_l']`,
        if no results exist for the requested phenotype pair, or if none
        of the requested regions has results.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure containing the plot. Its axes can be reached through
        `fig.axes` for further customization.
    df : pandas.DataFrame, optional
        The DataFrame containing the Ripley's L results, returned as the
        second element of a tuple if `return_df` is True.

    Example
    -------
    >>> fig = plot_ripley_l(
    ...     adata,
    ...     phenotypes=('Phenotype1', 'Phenotype2'),
    ...     regions=['region1', 'region2'])
    >>> plt.show()
    """

    # Retrieve the results from adata.uns['ripley_l']
    ripley_results = adata.uns.get('ripley_l')

    if ripley_results is None:
        raise ValueError(
            "Ripley L results not found in the analsyis."
        )

    # Filter the results for the specific pair of phenotypes
    filtered_results = ripley_results[
        (ripley_results['center_phenotype'] == phenotypes[0]) &
        (ripley_results['neighbor_phenotype'] == phenotypes[1])
    ]

    if filtered_results.empty:
        # Generate all unique combinations of phenotype pairs
        unique_pairs = ripley_results[
            ['center_phenotype', 'neighbor_phenotype']].drop_duplicates()
        raise ValueError(
            "No Ripley L results found for the specified pair of phenotypes."
            f'\nCenter Phenotype: "{phenotypes[0]}"'
            f'\nNeighbor Phenotype: "{phenotypes[1]}"'
            f"\nExisiting unique pairs: {unique_pairs}"
        )

    # If specific regions are provided, filter them, otherwise plot all regions
    if regions is not None:
        filtered_results = filtered_results[
            filtered_results['region'].isin(regions)]

        # Check if the results are empty after subsetting the regions
        if filtered_results.empty:
            available_regions = ripley_results['region'].unique()
            raise ValueError(
                f"No data available for the specified regions: {regions}. "
                f"Available regions: {available_regions}."
            )

    # Create a figure and axes
    fig, ax = plt.subplots(figsize=(10, 10))

    plot_data = []

    # Plot Ripley's L for each region
    for _, row in filtered_results.iterrows():
        region = row['region']  # Region label

        # Regions without a result are reported and skipped
        if row['ripley_l'] is None:
            message = (
                f"Ripley L results not found for region: {region}"
                f"\n Message: {row['message']}"
            )
            logging.warning(
                message
            )
            print(message)
            continue

        n_center = row['ripley_l']['n_center']
        n_neighbors = row['ripley_l']['n_neighbor']
        n_cells = f"({n_center}, {n_neighbors})"
        area = row['ripley_l']['area']
        l_stat_data = row['ripley_l']['L_stat']

        # Plot the Ripley L statistic for the region
        sns.lineplot(
            data=l_stat_data,
            x='bins',
            y='stats',
            label=f'{region}: {n_cells}, {int(area)}',
            ax=ax,
            **kwargs)

        # Collect the plotted data only when the caller asked for it
        if return_df:
            for _, stat_row in l_stat_data.iterrows():
                plot_data.append({
                    'region': region,
                    'radius': stat_row['bins'],
                    'ripley(radius)': stat_row['stats'],
                    'region_area': area,
                    'n_center': n_center,
                    'n_neighbor': n_neighbors,
                })

        if sims:
            # Percentile interval over the simulation runs
            confidence_level = 95
            errorbar = ("pi", confidence_level)
            n_sims = row["n_simulations"]
            # Draw on the same axes as the observed curve instead of
            # relying on matplotlib's implicit "current axes".
            sns.lineplot(
                x="bins",
                y="stats",
                data=row["ripley_l"]["sims_stat"],
                errorbar=errorbar,
                label=f"Simulations({region}):{n_sims} runs",
                ax=ax,
                **kwargs
            )

    # Set labels, title, and grid
    ax.set_title(
        "Ripley's L Statistic for phenotypes:"
        f"({phenotypes[0]}, {phenotypes[1]})\n"
    )
    ax.legend(title='Regions:(center, neighbor), area', loc='upper left')
    ax.grid(True)

    # Set the axis labels
    ax.set_xlabel("Radii (pixels)")
    ax.set_ylabel("Ripley's L Statistic")

    if return_df:
        df = pd.DataFrame(plot_data)
        return fig, df

    return fig
def _prepare_spatial_distance_data(
    adata,
    annotation,
    stratify_by=None,
    spatial_distance='spatial_distance',
    distance_from=None,
    distance_to=None,
    log=False
):
    """
    Prepares a tidy DataFrame for nearest-neighbor (spatial distance) plotting.

    This function:
      1) Validates required parameters (annotation, distance_from).
      2) Retrieves the spatial distance matrix from
         `adata.obsm[spatial_distance]`.
      3) Merges annotation (and optional stratify column).
      4) Filters rows to the reference phenotype (`distance_from`).
      5) Subsets columns if `distance_to` is given;
         otherwise keeps all distances.
      6) Reshapes (melts) into long-form data:
         columns -> [cellid, group, distance].
      7) Applies optional log1p transform.

    The resulting DataFrame is suitable for plotting with tool like Seaborn.

    Parameters
    ----------
    adata : anndata.AnnData
        Annotated data matrix, containing distances in
        `adata.obsm[spatial_distance]`.
    annotation : str
        Column in `adata.obs` indicating cell phenotype or annotation.
    stratify_by : str, optional
        Column in `adata.obs` used to group/stratify data
        (e.g., image or sample).
    spatial_distance : str, optional
        Key in `adata.obsm` storing the distance DataFrame.
        Default 'spatial_distance'.
    distance_from : str
        Reference phenotype from which distances are measured. Required.
    distance_to : str or list of str, optional
        Target phenotype(s). If None, use all available phenotype distances.
        Repeated entries are ignored.
    log : bool, optional
        If True, applies np.log1p transform to the 'distance' column,
        which is renamed to 'log_distance'.

    Returns
    -------
    pd.DataFrame
        Tidy DataFrame with columns:
            - 'cellid': index of the cell from 'adata.obs'.
            - 'group': the target phenotype (column names of 'distance_map').
            - 'distance' (or 'log_distance' when `log` is True): the numeric
              distance value.
            - 'phenotype': the reference phenotype ('distance_from').
            - the `stratify_by` column, if provided.

    Raises
    ------
    ValueError
        If required parameters are missing, if phenotypes are not found in
        the columns of the distance matrix, if no cell carries the
        `distance_from` phenotype, or if the spatial distance matrix is not
        available in `adata.obsm`.

    Examples
    --------
    >>> df_long = _prepare_spatial_distance_data(
    ...     adata=my_adata,
    ...     annotation='cell_type',
    ...     stratify_by='sample_id',
    ...     spatial_distance='spatial_distance',
    ...     distance_from='Tumor',
    ...     distance_to=['Stroma', 'Immune'],
    ...     log=True
    ... )
    >>> df_long.head()
    """

    # Validate required parameters
    if distance_from is None:
        raise ValueError(
            "Please specify the 'distance_from' phenotype. This indicates "
            "the reference group from which distances are measured."
        )
    check_annotation(adata, annotations=annotation)

    # Convert distance_to to list if needed
    if distance_to is not None and isinstance(distance_to, str):
        distance_to = [distance_to]
    if distance_to:
        # Drop repeated targets (order preserved): duplicates would create
        # duplicate columns and break the category reordering below.
        distance_to = list(dict.fromkeys(distance_to))

    phenotypes_to_check = [distance_from] + (
        distance_to if distance_to else []
    )

    # Ensure distance_from and distance_to exist in adata.obs[annotation]
    check_label(
        adata,
        annotation=annotation,
        labels=phenotypes_to_check,
        should_exist=True
    )

    # Retrieve the spatial distance matrix from adata.obsm
    if spatial_distance not in adata.obsm:
        raise ValueError(
            f"'{spatial_distance}' does not exist in the provided dataset. "
            "Please run 'calculate_nearest_neighbor' first to compute and "
            "store spatial distance. "
            f"Available keys: {list(adata.obsm.keys())}"
        )
    distance_map = adata.obsm[spatial_distance].copy()

    # Verify requested phenotypes exist in the distance_map columns
    missing_cols = [
        p for p in phenotypes_to_check if p not in distance_map.columns
    ]
    if missing_cols:
        raise ValueError(
            f"Phenotypes {missing_cols} not found in columns of "
            f"'{spatial_distance}'. Columns present: "
            f"{list(distance_map.columns)}"
        )

    # Validate 'stratify_by' column if provided
    if stratify_by is not None:
        check_annotation(adata, annotations=stratify_by)

    # Build a meta DataFrame with phenotype & optional stratify column
    meta_data = pd.DataFrame({'phenotype': adata.obs[annotation]},
                             index=adata.obs.index)
    if stratify_by:
        meta_data[stratify_by] = adata.obs[stratify_by]

    # Merge metadata with distance_map and filter for 'distance_from'
    df_merged = meta_data.join(distance_map, how='left')
    df_merged = df_merged[df_merged['phenotype'] == distance_from]
    if df_merged.empty:
        raise ValueError(
            f"No cells found with phenotype == '{distance_from}'."
        )

    # Move the cell names into a column called 'cellid'. Naming the index
    # first makes this work whether or not adata.obs.index has a name
    # (reset_index only produces a column called 'index' for unnamed ones).
    df_merged = df_merged.rename_axis('cellid').reset_index()

    # Prepare the list of metadata columns
    meta_cols = ['phenotype']
    if stratify_by:
        meta_cols.append(stratify_by)
    id_cols = ['cellid'] + meta_cols

    # Determine distance columns
    if distance_to:
        distance_columns = list(distance_to)
    else:
        distance_columns = [
            c for c in df_merged.columns if c not in id_cols
        ]
    df_merged = df_merged[id_cols + distance_columns]

    # Melt the DataFrame from wide to long format
    df_long = df_merged.melt(
        id_vars=id_cols,
        var_name='group',
        value_name='distance'
    )

    # Convert columns to categorical for consistency
    for col in ['group', 'phenotype', stratify_by]:
        if col and col in df_long.columns:
            df_long[col] = df_long[col].astype(str).astype('category')

    # Reorder categories for 'group' if 'distance_to' is provided.
    # Categories are strings at this point, so the targets are cast too.
    if distance_to:
        df_long['group'] = df_long['group'].cat.reorder_categories(
            [str(p) for p in distance_to]
        )
        df_long.sort_values('group', inplace=True)

    # Ensure 'distance' is numeric and apply log transform if requested
    df_long['distance'] = pd.to_numeric(df_long['distance'], errors='coerce')
    if log:
        df_long['distance'] = np.log1p(df_long['distance'])
        df_long.rename(columns={'distance': 'log_distance'}, inplace=True)

    # Reorder columns dynamically based on the presence of 'log'
    distance_col = 'log_distance' if log else 'distance'
    final_cols = ['cellid', 'group', distance_col, 'phenotype']
    if stratify_by is not None:
        final_cols.append(stratify_by)
    df_long = df_long[final_cols]

    return df_long
[docs] def _plot_spatial_distance_dispatch( df_long, method, plot_type, stratify_by=None, facet_plot=False, **kwargs ): """ Decides the figure layout based on 'stratify_by' and 'facet_plot' and dispatches actual plotting calls. Logic: 1) If stratify_by and facet_plot => single figure with subplots (faceted) 2) If stratify_by and not facet_plot => multiple figures, one per group 3) If stratify_by is None => single figure (no subplots) This function calls seaborn figure-level functions (catplot or displot). Parameters ---------- df_long : pd.DataFrame Tidy DataFrame with columns ['cellid', 'group', 'distance', 'phenotype', 'stratify_by']. method : {'numeric', 'distribution'} Determines which seaborn function is used (catplot or displot). plot_type : str For method='numeric': 'box', 'violin', 'boxen', etc. For method='distribution': 'hist', 'kde', 'ecdf', etc. stratify_by : str or None Column name for grouping. If None, no grouping is done. facet_plot : bool If True, subplots in a single figure (faceted). If False, separate figures (one per group) or a single figure. **kwargs Additional seaborn plotting arguments (e.g., col_wrap=2). Returns ------- dict Dictionary with two keys: - "data": the DataFrame (df_long) - "fig": a Matplotlib Figure or a list of Figures Raises ------ ValueError If 'method' is invalid (not 'numeric' or 'distribution'). Examples -------- Called internally by 'visualize_nearest_neighbor'. Typically not used directly by end users. 
""" distance_col = kwargs.pop('distance_col', 'distance') hue_axis = kwargs.pop('hue_axis', None) if method not in ['numeric', 'distribution']: raise ValueError("`method` must be 'numeric' or 'distribution'.") # Set up the plotting function using partial if method == 'numeric': plot_func = partial( sns.catplot, data=None, x=distance_col, y='group', kind=plot_type ) else: # distribution plot_func = partial( sns.displot, data=None, x=distance_col, hue=hue_axis if hue_axis else None, kind=plot_type ) # Helper to plot a single figure or faceted figure def _make_figure(data, **kws): g = plot_func(data=data, **kws) if distance_col == 'log_distance': x_label = "Log(Nearest Neighbor Distance)" else: x_label = "Nearest Neighbor Distance" # Set axis label based on whether log transform was applied if hasattr(g, 'set_axis_labels'): g.set_axis_labels(x_label, None) else: # Fallback if 'set_axis_labels' is unavailable plt.xlabel(x_label) return g.fig figures = [] # Branching logic for figure creation if stratify_by and facet_plot: # Single figure with faceted subplots (col=stratify_by) fig = _make_figure(df_long, col=stratify_by, **kwargs) figures.append(fig) elif stratify_by and not facet_plot: # Multiple separate figures, one per unique value in stratify_by categories = df_long[stratify_by].unique() for cat in categories: subset = df_long[df_long[stratify_by] == cat] fig = _make_figure(subset, **kwargs) figures.append(fig) else: # Single figure (no subplots) fig = _make_figure(df_long, **kwargs) figures.append(fig) # Return dictionary: { 'data': DataFrame, 'fig': Figure(s) } result = {"data": df_long} if len(figures) == 1: result["fig"] = figures[0] else: result["fig"] = figures return result
def visualize_nearest_neighbor(
    adata,
    annotation,
    stratify_by=None,
    spatial_distance='spatial_distance',
    distance_from=None,
    distance_to=None,
    facet_plot=False,
    plot_type=None,
    log=False,
    method=None,
    **kwargs
):
    """
    Visualize nearest-neighbor (spatial distance) data between groups of
    cells as numeric or distribution plots.

    This user-facing function assembles the data by calling
    `_prepare_spatial_distance_data` and then creates plots through
    `_plot_spatial_distance_dispatch`.

    Plot arrangement logic:
      1) If stratify_by is not None and facet_plot=True => single figure
         with subplots (faceted).
      2) If stratify_by is not None and facet_plot=False => multiple
         separate figures, one per group.
      3) If stratify_by is None => a single figure with one plot.

    Parameters
    ----------
    adata : anndata.AnnData
        Annotated data matrix with distances in
        `adata.obsm[spatial_distance]`.
    annotation : str
        Column in `adata.obs` containing cell phenotypes or annotations.
    stratify_by : str, optional
        Column in `adata.obs` used to group or stratify data
        (e.g. imageid).
    spatial_distance : str, optional
        Key in `adata.obsm` storing the distance DataFrame.
        Default is 'spatial_distance'.
    distance_from : str
        Reference phenotype from which distances are measured. Required.
    distance_to : str or list of str, optional
        Target phenotype(s) to measure distance to. If None, uses all
        available phenotypes.
    facet_plot : bool, optional
        If True (and stratify_by is not None), subplots in a single
        figure. Else, multiple or single figure(s).
    plot_type : str, optional
        For method='numeric': 'box', 'violin', 'boxen', etc.
        For method='distribution': 'hist', 'kde', 'ecdf', etc.
        Defaults to 'boxen' for 'numeric' and 'kde' for 'distribution'.
    log : bool, optional
        If True, applies np.log1p transform to the distance values.
    method : {'numeric', 'distribution'}
        Determines the plotting style (catplot vs displot). Required.
    **kwargs : dict
        Additional arguments for seaborn figure-level functions.
        The special key 'hue_axis' names the column used as `hue` for
        method='distribution'. It defaults to 'group' so that each target
        phenotype gets its own distribution; pass `hue_axis=None` to pool
        all targets into a single distribution.

    Returns
    -------
    dict
        {
            "data": pd.DataFrame,  # Tidy DataFrame used for plotting
            "fig": Figure or list[Figure]  # Single or multiple figures
        }

    Raises
    ------
    ValueError
        If required parameters are missing or invalid.

    Examples
    --------
    >>> # Numeric box plot comparing Tumor distances to multiple targets
    >>> res = visualize_nearest_neighbor(
    ...     adata=my_adata,
    ...     annotation='cell_type',
    ...     stratify_by='sample_id',
    ...     spatial_distance='spatial_distance',
    ...     distance_from='Tumor',
    ...     distance_to=['Stroma', 'Immune'],
    ...     facet_plot=True,
    ...     plot_type='box',
    ...     method='numeric'
    ... )
    >>> df_long, fig = res["data"], res["fig"]

    >>> # Distribution plot (kde) for a single target, single figure
    >>> res2 = visualize_nearest_neighbor(
    ...     adata=my_adata,
    ...     annotation='cell_type',
    ...     distance_from='Tumor',
    ...     distance_to='Stroma',
    ...     method='distribution',
    ...     plot_type='kde'
    ... )
    >>> df_dist, fig2 = res2["data"], res2["fig"]
    """
    if distance_from is None:
        raise ValueError(
            "Please specify the 'distance_from' phenotype. It indicates "
            "the reference group from which distances are measured."
        )
    if method not in ['numeric', 'distribution']:
        raise ValueError(
            "Invalid 'method'. Please choose 'numeric' or 'distribution'."
        )

    df_long = _prepare_spatial_distance_data(
        adata=adata,
        annotation=annotation,
        stratify_by=stratify_by,
        spatial_distance=spatial_distance,
        distance_from=distance_from,
        distance_to=distance_to,
        log=log
    )

    # Determine plot_type if not provided
    if plot_type is None:
        plot_type = 'boxen' if method == 'numeric' else 'kde'

    # If log=True, the column name is 'log_distance', else 'distance'
    distance_col = 'log_distance' if log else 'distance'

    # Numeric plots separate the targets on the 'group' axis. Distribution
    # plots have no such axis, so without a hue the distances to every
    # target phenotype are pooled into one curve. Default the hue to
    # 'group' while letting an explicit caller choice (including None) win.
    if method == 'distribution':
        kwargs.setdefault('hue_axis', 'group')

    # Dispatch to the plot logic
    return _plot_spatial_distance_dispatch(
        df_long=df_long,
        method=method,
        plot_type=plot_type,
        stratify_by=stratify_by,
        facet_plot=facet_plot,
        distance_col=distance_col,
        **kwargs
    )