Source code for singlet.dataset.plot

# vim: fdm=indent
# author:     Fabio Zanini
# date:       16/08/17
# content:    Dataset functions to plot gene expression and phenotypes
# Modules
import warnings
import numpy as np
import pandas as pd
import matplotlib as mpl
from matplotlib import cm
from ..config import config


try:
    import seaborn as sns
except (ImportError, RuntimeError):
    if 'seaborn_import' not in config['_once_warnings']:
        warnings.warn('Unable to import seaborn: plotting will not work')
        config['_once_warnings'].append('seaborn_import')
    sns = None

try:
    import matplotlib.pyplot as plt
except (ImportError, RuntimeError):
    if 'pyplot_import' not in config['_once_warnings']:
        warnings.warn('Unable to import matplotlib.pyplot: plotting will not work')
        config['_once_warnings'].append('pyplot_import')
    plt = None


# Classes / functions
class Plot:
    '''Plot gene expression and phenotype in single cells'''

    def __init__(self, dataset):
        '''Plot gene expression and phenotype in single cells

        Args:
            dataset (Dataset): the dataset to analyze.
        '''
        self.dataset = dataset

    @staticmethod
    def _update_properties(kwargs, defaults):
        '''Expand aliases in kwargs, then fill in missing defaults (in place).

        Args:
            kwargs (dict): plot properties supplied by the caller. It is
                modified in place.
            defaults (dict): fallback values, used only for keys that are
                absent from kwargs after alias expansion.
        '''
        Plot._sanitize_plot_properties(kwargs)
        for name, value in defaults.items():
            kwargs.setdefault(name, value)

    @staticmethod
    def _sanitize_plot_properties(kwargs):
        '''Replace short property aliases with their long names (in place).

        If both the alias and the long name are present, the value given
        under the alias wins.

        Args:
            kwargs (dict): plot properties, modified in place.
        '''
        # (alias, long name) pairs, processed in this order
        long_names = (
            ('lw', 'linewidth'),
            ('aa', 'antialiased'),
            ('c', 'color'),
            ('ls', 'linestyle'),
            ('mec', 'markeredgecolor'),
            ('mew', 'markeredgewidth'),
            ('mfc', 'markerfacecolor'),
            ('mfcalt', 'markerfacecoloralt'),
            ('ms', 'markersize'),
        )
        for short, full in long_names:
            if short in kwargs:
                kwargs[full] = kwargs.pop(short)
[docs] def plot_coverage( self, features='total', kind='cumulative', ax=None, tight_layout=True, legend=False, **kwargs): '''Plot number of reads for each sample Args: features (list or string): Features to sum over. The string \ 'total' means all features including spikeins and other, \ 'mapped' means all features excluding spikeins and other, \ 'spikeins' means only spikeins, and 'other' means only \ 'other' features. kind (string): Kind of plot (default: cumulative distribution). ax (matplotlib.axes.Axes): The axes to plot into. If None \ (default), a new figure with one axes is created. ax must \ not strictly be a matplotlib class, but it must have \ common methods such as 'plot' and 'set'. tight_layout (bool or dict): Whether to call \ matplotlib.pyplot.tight_layout at the end of the \ plotting. If it is a dict, pass it unpacked to that \ function. legend (bool or dict): If True, call ax.legend(). If a dict, \ pass as **kwargs to ax.legend. **kwargs: named arguments passed to the plot function. Returns: matplotlib.axes.Axes with the axes contaiing the plot. 
''' if ax is None: new_axes = True fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(13, 8)) else: new_axes = False defaults = { 'linewidth': 2, 'color': 'darkgrey', } Plot._update_properties(kwargs, defaults) counts = self.dataset.counts if features == 'total': pass elif features == 'mapped': counts = counts.exclude_features(spikeins=True, other=True) elif features == 'spikeins': counts = counts.get_spikeins() elif features == 'other': counts = counts.get_other_features() else: counts = counts.loc[features] if kind == 'cumulative': x = counts.values.sum(axis=0) x.sort() y = 1.0 - np.linspace(0, 1, len(x)) ax.plot(x, y, **kwargs) ax_props = { 'ylim': (-0.05, 1.05), 'ylabel': 'Cumulative distribution'} else: raise ValueError('Plot kind not understood') if not counts._normalized: ax_props['xlabel'] = 'Number of reads' elif counts._normalized != 'custom': ax_props['xlabel'] = counts._normalized.capitalize().replace('_', ' ') if new_axes: xmin = 0.5 xmax = 1.05 * x.max() ax_props['xlim'] = (xmin, xmax) ax_props['xscale'] = 'log' ax.grid(True) ax.set(**ax_props) if legend: if np.isscalar(legend): ax.legend() else: ax.legend(**legend) if tight_layout: if isinstance(tight_layout, dict): plt.tight_layout(**tight_layout) else: plt.tight_layout() return ax
[docs] def scatter_statistics( self, features='mapped', x='mean', y='cv', ax=None, tight_layout=True, legend=False, grid=None, **kwargs): '''Scatter plot statistics of features. Args: features (list or string): List of features to plot. The string \ 'mapped' means everything excluding spikeins and other, \ 'all' means everything including spikeins and other. x (string): Statistics to plot on the x axis. y (string): Statistics to plot on the y axis. ax (matplotlib.axes.Axes): The axes to plot into. If None \ (default), a new figure with one axes is created. ax must \ not strictly be a matplotlib class, but it must have \ common methods such as 'plot' and 'set'. tight_layout (bool or dict): Whether to call \ matplotlib.pyplot.tight_layout at the end of the \ plotting. If it is a dict, pass it unpacked to that \ function. legend (bool or dict): If True, call ax.legend(). If a dict, \ pass as **kwargs to ax.legend. grid (bool or None): Whether to add a grid to the plot. None \ defaults to your existing settings. **kwargs: named arguments passed to the plot function. Returns: matplotlib.axes.Axes with the axes contaiing the plot. 
''' if ax is None: new_axes = True fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(13, 8)) else: new_axes = False defaults = { 's': 10, 'color': 'darkgrey', } Plot._update_properties(kwargs, defaults) counts = self.dataset.counts if features == 'total': if not counts._otherfeatures.isin(counts.index).all(): raise ValueError('Other features not found in counts') if not counts._spikeins.isin(counts.index).all(): raise ValueError('Spike-ins not found in counts') pass elif features == 'mapped': counts = counts.exclude_features( spikeins=True, other=True, errors='ignore') else: counts = counts.loc[features] stats = counts.get_statistics(metrics=(x, y)) ax_props = {'xlabel': x, 'ylabel': y} x = stats.loc[:, x] y = stats.loc[:, y] ax.scatter(x, y, **kwargs) if ax_props['xlabel'] == 'mean': xmin = 0.5 xmax = 1.05 * x.max() ax_props['xlim'] = (xmin, xmax) ax_props['xscale'] = 'log' elif ax_props['ylabel'] == 'mean': ymin = 0.5 ymax = 1.05 * y.max() ax_props['ylim'] = (ymin, ymax) ax_props['yscale'] = 'log' if ax_props['xlabel'] == 'cv': xmin = 0 xmax = 1.05 * x.max() ax_props['xlim'] = (xmin, xmax) elif ax_props['ylabel'] == 'cv': ymin = 0 ymax = 1.05 * y.max() ax_props['ylim'] = (ymin, ymax) if grid is not None: ax.grid(grid) ax.set(**ax_props) if legend: if np.isscalar(legend): ax.legend() else: ax.legend(**legend) if tight_layout: if isinstance(tight_layout, dict): plt.tight_layout(**tight_layout) else: plt.tight_layout()
    def gate_features_from_statistics(
            self,
            features='mapped',
            x='mean',
            y='cv',
            **kwargs):
        '''Select features for downstream analysis with a gate.

        Usage:
            Click with the left mouse button to set the vertices of a
            polygon. Double left-click closes the shape. Right click
            (or middle click) resets the plot.

        Args:
            features (list or string): List of features to plot. The string
                'mapped' means everything excluding spikeins and other,
                'total' means everything including spikeins and other.
            x (string): Statistics to plot on the x axis.
            y (string): Statistics to plot on the y axis.
            **kwargs: named arguments passed to the plot function.

        Returns:
            list of the features within the gate (the index labels of the
            statistics table), in the state it has when plt.show() returns.

        Raises:
            ValueError: if features is 'total' but spike-ins or other
                features are missing from the counts table.
        '''
        # Remember the interactive state; it is restored once a gate is closed
        is_interactive = mpl.is_interactive()
        plt.ioff()

        fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(13, 8))

        defaults = {
            's': 10,
            'color': 'darkgrey',
        }
        Plot._update_properties(kwargs, defaults)

        counts = self.dataset.counts
        if features == 'total':
            if not counts._otherfeatures.isin(counts.index).all():
                raise ValueError('Other features not found in counts')
            if not counts._spikeins.isin(counts.index).all():
                raise ValueError('Spike-ins not found in counts')
            pass
        elif features == 'mapped':
            counts = counts.exclude_features(
                spikeins=True, other=True, errors='ignore')
        else:
            counts = counts.loc[features]

        stats = counts.get_statistics(metrics=(x, y))
        ax_props = {'xlabel': x, 'ylabel': y}
        # From here on x and y hold the data columns, not the metric names
        x = stats.loc[:, x]
        y = stats.loc[:, y]

        ax.scatter(x, y, **kwargs)

        # The mean gets a log axis starting at 0.5 (x takes precedence)
        if ax_props['xlabel'] == 'mean':
            xmin = 0.5
            xmax = 1.05 * x.max()
            ax_props['xlim'] = (xmin, xmax)
            ax_props['xscale'] = 'log'
        elif ax_props['ylabel'] == 'mean':
            ymin = 0.5
            ymax = 1.05 * y.max()
            ax_props['ylim'] = (ymin, ymax)
            ax_props['yscale'] = 'log'

        # The coefficient of variation gets a linear axis starting at zero
        if ax_props['xlabel'] == 'cv':
            xmin = 0
            xmax = 1.05 * x.max()
            ax_props['xlim'] = (xmin, xmax)
        elif ax_props['ylabel'] == 'cv':
            ymin = 0
            ymax = 1.05 * y.max()
            ax_props['ylim'] = (ymin, ymax)

        ax.grid(True)

        ax.set(**ax_props)

        # event handling
        # Mutable state shared by the nested callbacks below:
        #   cids: connection ids of the press/release handlers
        #   polygon: gate vertices as dicts with 'x', 'y' and artist 'handle'
        #   selected: labels of the features inside the closed gate
        #   annotations: text artists labelling the selected features
        cids = {'press': None, 'release': None}
        polygon = []
        selected = []
        annotations = []

        def onpress(event):
            # Button 1 (left) adds a vertex, buttons 2 and 3 reset the gate
            if event.button == 1:
                return onpress_left(event)
            elif event.button in (2, 3):
                return onpress_right(event)

        def onpress_left(event):
            xp = event.xdata
            yp = event.ydata
            if len(polygon) == 0:
                # First vertex: nothing to connect to yet, draw a dot
                h = ax.scatter([xp], [yp], s=50, color='red')
                polygon.append({
                    'x': xp,
                    'y': yp,
                    'handle': h})
            else:
                if len(polygon) == 1:
                    # The dot of the first vertex is replaced by line segments
                    polygon[0]['handle'].remove()
                    polygon[0]['handle'] = None
                xp0 = polygon[-1]['x']
                yp0 = polygon[-1]['y']
                h = ax.plot([xp0, xp], [yp0, yp], lw=2, color='red')[0]
                polygon.append({
                    'x': xp,
                    'y': yp,
                    'handle': h})

            # NOTE(review): redraw assumed to run for both branches above;
            # the original indentation was lost - confirm.
            fig.canvas.draw()

            # The closing double click has already added its own vertex above
            if event.dblclick:
                return ondblclick_left(event)

        def ondblclick_left(event):
            from matplotlib import path

            # Close the polygon
            xp = polygon[0]['x']
            yp = polygon[0]['y']
            xp0 = polygon[-1]['x']
            yp0 = polygon[-1]['y']
            h = ax.plot([xp0, xp], [yp0, yp], lw=2, color='red')[0]
            # The closing segment is stored on the first vertex
            polygon[0]['handle'] = h
            fig.canvas.draw()

            xv = x.values.copy()
            yv = y.values.copy()
            iv = x.index.values

            # A polygon in linear and log is not the same
            # (so data and vertices are both log-transformed on log axes)
            xscale = ax.get_xscale()
            yscale = ax.get_yscale()
            if xscale == 'log':
                xv = np.log(xv)
            if yscale == 'log':
                yv = np.log(yv)

            pa = []
            for p in polygon:
                xp = p['x']
                yp = p['y']
                if xscale == 'log':
                    xp = np.log(xp)
                if yscale == 'log':
                    yp = np.log(yp)
                pa.append([xp, yp])
            pa = path.Path(pa)
            points = list(zip(xv, yv))
            # Positions of the points that fall inside the gate
            ind = pa.contains_points(points).nonzero()[0]
            for ix in ind:
                selected.append(iv[ix])

            # Annotate plot
            # (the string concatenation requires string index labels)
            for ix in ind:
                h = ax.text(
                    x.iloc[ix], y.iloc[ix], ' '+x.index[ix],
                    ha='left', va='bottom')
                annotations.append(h)
            fig.canvas.draw()

            # Let go of the code flow
            # NOTE(review): assumed to belong to this handler, i.e. interactive
            # mode is only restored after a gate has been closed - confirm.
            if is_interactive:
                plt.ion()

        def onpress_right(event):
            # Remove all gate artists and annotations from the axes ...
            for elem in polygon:
                h = elem['handle']
                if h is not None:
                    elem['handle'].remove()
            # ... and empty the shared lists in place, so the list object that
            # is returned to the caller stays the same
            for i in range(len(polygon)):
                del polygon[-1]
            for h in annotations:
                h.remove()
            for i in range(len(annotations)):
                del annotations[-1]
            for i in range(len(selected)):
                del selected[-1]
            fig.canvas.draw()

        def onrelease(event):
            # Placeholder: button releases are currently ignored
            pass

        def axes_enter(event):
            # Mouse clicks are only handled while the pointer is over the axes
            cids['press'] = fig.canvas.mpl_connect('button_press_event', onpress)
            cids['release'] = fig.canvas.mpl_connect('button_release_event', onrelease)

        def axes_leave(event):
            fig.canvas.mpl_disconnect(cids['press'])
            fig.canvas.mpl_disconnect(cids['release'])
            cids['press'] = None
            cids['release'] = None
            # NOTE(review): this redraw may originally have been at function
            # level rather than inside this handler; indentation was lost.
            fig.canvas.draw()

        fig.canvas.mpl_connect('axes_enter_event', axes_enter)
        fig.canvas.mpl_connect('axes_leave_event', axes_leave)

        plt.tight_layout()
        plt.show()

        return selected
[docs] def plot_distributions( self, features, kind='violin', ax=None, tight_layout=True, legend=False, orientation='vertical', sort=False, bottom=0, grid=None, **kwargs): '''Plot distribution of spike-in controls Args: features (list or string): List of features to plot. If it is the \ string 'spikeins', plot all spikeins, if the string \ 'other', plot other features. kind (string): Kind of plot, one of 'violin' (default), 'box', \ 'swarm'. ax (matplotlib.axes.Axes): Axes to plot into. If None (default), \ create a new figure and axes. tight_layout (bool or dict): Whether to call \ matplotlib.pyplot.tight_layout at the end of the \ plotting. If it is a dict, pass it unpacked to that \ function. legend (bool or dict): If True, call ax.legend(). If a dict, \ pass as **kwargs to ax.legend. orientation (string): 'horizontal' or 'vertical'. sort (bool or string): True or 'ascending' sorts the features by \ median, 'descending' uses the reverse order. bottom (float or string): The value of zero-count features. If \ you are using a log axis, you may want to set this to \ 0.1 or any other small positive number. If a string, it \ must be 'pseudocount', then the CountsTable.pseudocount \ will be used. grid (bool or None): Whether to add a grid to the plot. None \ defaults to your existing settings. **kwargs: named arguments passed to the plot function. Return: matplotlib.axes.Axes: The axes with the plot. 
''' if ax is None: fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(18, 8)) counts = self.dataset.counts if features == 'spikeins': counts = counts.get_spikeins() elif features == 'other': counts = counts.get_other_features() else: counts = counts.loc[features] if sort: asc = sort != 'descending' ind = counts.median(axis=1).sort_values(ascending=asc).index counts = counts.loc[ind] if bottom == 'pseudocount': bottom = counts.pseudocount counts = np.maximum(counts, bottom) ax_props = {} if kind == 'violin': defaults = { 'scale': 'width', 'inner': 'stick', } Plot._update_properties(kwargs, defaults) sns.violinplot( data=counts.T, orient=orientation, ax=ax, **kwargs) elif kind == 'box': defaults = {} Plot._update_properties(kwargs, defaults) sns.boxplot( data=counts.T, orient=orientation, ax=ax, **kwargs) elif kind == 'swarm': defaults = {} Plot._update_properties(kwargs, defaults) sns.swarmplot( data=counts.T, orient=orientation, ax=ax, **kwargs) else: raise ValueError('Plot kind not understood') if orientation == 'vertical': ax_props['ylim'] = (0.9 * bottom, 1.1 * counts.values.max()) if not counts._normalized: ax_props['ylabel'] = 'Number of reads' elif counts._normalized != 'custom': ax_props['ylabel'] = counts._normalized.capitalize().replace('_', ' ') for label in ax.get_xmajorticklabels(): label.set_rotation(90) label.set_horizontalalignment("center") ax.grid(True, 'y') elif orientation == 'horizontal': ax_props['xlim'] = (0.9 * bottom, 1.1 * counts.values.max()) if not counts._normalized: ax_props['xlabel'] = 'Number of reads' elif counts._normalized != 'custom': ax_props['xlabel'] = counts._normalized.capitalize().replace('_', ' ') ax.grid(True, axis='x') ax.set(**ax_props) if grid is not None: ax.grid(grid) if legend: if np.isscalar(legend): ax.legend() else: ax.legend(**legend) if tight_layout: if isinstance(tight_layout, dict): plt.tight_layout(**tight_layout) else: plt.tight_layout() return ax
[docs] def scatter_reduced_samples( self, vectors_reduced, color_by=None, color_log=None, cmap='viridis', ax=None, tight_layout=True, **kwargs): '''Scatter samples after dimensionality reduction. Args: vectors_reduced (pandas.Dataframe): matrix of coordinates of the \ samples after dimensionality reduction. Rows are samples, \ columns (typically 2 or 3) are the component in the \ low-dimensional embedding. color_by (string or None): color sample dots by phenotype or \ expression of a certain feature. color_log (bool or None): use log of phenotype/expression in the \ colormap. Default None only logs expression, but not \ phenotypes. cmap (string or matplotlib colormap): color map to use for the \ sample dots. ax (matplotlib.axes.Axes): The axes to plot into. If None \ (default), a new figure with one axes is created. ax must \ not strictly be a matplotlib class, but it must have \ common methods such as 'plot' and 'set'. tight_layout (bool or dict): Whether to call \ matplotlib.pyplot.tight_layout at the end of the \ plotting. If it is a dict, pass it unpacked to that \ function. **kwargs: named arguments passed to the plot function. Returns: matplotlib.axes.Axes with the axes containing the plot. 
''' if ax is None: new_axes = True fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(13, 8)) else: new_axes = False defaults = { 's': 90, } Plot._update_properties(kwargs, defaults) if color_by is None: kwargs['color'] = 'darkgrey' else: if isinstance(cmap, str): cmap = cm.get_cmap(cmap) if color_by in self.dataset.samplesheet.columns: color_data = self.dataset.samplesheet.loc[:, color_by] if hasattr(color_data, 'cat'): is_numeric = False else: is_numeric = np.issubdtype(color_data.dtype, np.number) color_by_phenotype = True elif color_by in self.dataset.counts.index: color_data = self.dataset.counts.loc[color_by] is_numeric = True color_by_phenotype = False else: raise ValueError( 'The label '+color_by+' is neither a phenotype nor a feature') # Categorical columns get just a list of colors if (hasattr(color_data, 'cat')) or (not is_numeric): cd_unique = list(np.unique(color_data.values)) c_unique = cmap(np.linspace(0, 1, len(cd_unique))) c = c_unique[[cd_unique.index(x) for x in color_data.values]] # Non-categorical numeric types are more tricky: check for NaNs else: if np.isnan(color_data.values).any(): unmask = ~np.isnan(color_data.values) else: unmask = np.ones(len(color_data), bool) cd_min = color_data.values[unmask].min() cd_max = color_data.values[unmask].max() if color_log: if color_by_phenotype: pc = 0.1 * cd_min else: pc = self.dataset.counts.pseudocount color_data = np.log10(color_data + pc) cd_min = np.log10(cd_min + pc) cd_max = np.log10(cd_max + pc) cd_norm = (color_data.values - cd_min) / (cd_max - cd_min) c = np.zeros((len(color_data), 4), float) c[unmask] = cmap(cd_norm[unmask]) # Grey-ish semitransparency for NaNs c[~unmask] = [0.75] * 3 + [0.3] kwargs['c'] = c vectors_reduced.plot( x=vectors_reduced.columns[0], y=vectors_reduced.columns[1], kind='scatter', ax=ax, **kwargs) ax.grid(True) if tight_layout: if isinstance(tight_layout, dict): plt.tight_layout(**tight_layout) else: plt.tight_layout() return ax
[docs] def clustermap( self, cluster_samples=False, cluster_features=False, phenotypes_cluster_samples=(), phenotypes_cluster_features=(), annotate_samples=False, annotate_features=False, orientation='horizontal', colorbars=False, **kwargs): '''Samples versus features / phenotypes. Args: cluster_samples (bool or linkage): Whether to cluster samples and \ show the dendrogram. Can be either, False, True, or a \ linkage from scipy.cluster.hierarchy.linkage. cluster_features (bool or linkage): Whether to cluster features \ and show the dendrogram. Can be either, False, True, or a \ linkage from scipy.cluster.hierarchy.linkage. phenotypes_cluster_samples (iterable of strings): Phenotypes to \ add to the features for joint clustering of the samples. \ If the clustering has been \ precomputed including phenotypes and the linkage matrix \ is explicitely set as cluster_samples, the *same* \ phenotypes must be specified here, in the same order. phenotypes_cluster_features (iterable of strings): Phenotypes to \ add to the features for joint clustering of the features \ and phenotypes. If the clustering has been \ precomputed including phenotypes and the linkage matrix \ is explicitely set as cluster_features, the *same* \ phenotypes must be specified here, in the same order. annotate_samples (dict, or False): Whether and how to \ annotate the samples with separate colorbars. The \ dictionary must have phenotypes or features as keys. For \ qualitative phenotypes, the values can be palette names \ or palettes (with at least as many colors as there are \ categories). For quantitative phenotypes and features, \ they can be colormap names or colormaps. annotate_features (dict, or False): Whether and how to \ annotate the featues with separate colorbars. The \ dictionary must have features metadata as keys. For \ qualitative annotations, the values can be palette names \ or palettes (with at least as many colors as there are \ categories). 
For quantitative annotatoins, the values \ can be colormap names or colormaps. Keys must be columns \ of the Dataset.featuresheet, except for the key 'mean \ expression' which is interpreted to mean the average of \ the counts for that feature. orientation (string): Whether the samples are on the abscissa \ ('horizontal') or on the ordinate ('vertical'). tight_layout (bool or dict): Whether to call \ matplotlib.pyplot.tight_layout at the end of the \ plotting. If it is a dict, pass it unpacked to that \ function. colorbars (bool): Whether to add colorbars. One colorbar refers \ to the heatmap. Moreover, if annotations for samples or \ features are shown, a colorbar for each of them will be \ shown as well. **kwargs: named arguments passed to seaborn.clustermap. Returns: A seaborn ClusterGrid instance. ''' data = self.dataset.counts.copy() for pheno in phenotypes_cluster_features: data.loc[pheno] = self.dataset.samplesheet.loc[:, pheno] # FIXME: what to do with NaN? if cluster_samples is True: cluster_samples = self.dataset.cluster.hierarchical( axis='samples', phenotypes=phenotypes_cluster_samples, ) linkage_samples = cluster_samples['linkage'] elif cluster_samples is False: linkage_samples = None else: linkage_samples = cluster_samples if cluster_features is True: cluster_features = self.dataset.cluster.hierarchical( axis='features', phenotypes=phenotypes_cluster_features, ) linkage_features = cluster_features['linkage'] elif cluster_features is False: linkage_features = None else: linkage_features = cluster_features if annotate_samples: cbars_samples = [] col_samples = [] for key, val in annotate_samples.items(): if key in self.dataset.samplesheet.columns: color_data = self.dataset.samplesheet.loc[:, key] is_numeric = np.issubdtype(color_data.dtype, np.number) if (color_data.dtype.name == 'category') or (not is_numeric): cmap_type = 'qualitative' else: cmap_type = 'sequential' else: color_data = self.dataset.counts.loc[key] cmap_type = 'sequential' if 
isinstance(val, str): if cmap_type == 'qualitative': cd_unique = list(np.unique(color_data.values)) n_colors = len(cd_unique) palette = sns.color_palette(val, n_colors=n_colors) c = [palette[cd_unique.index(x)] for x in color_data.values] cbi = {'name': key, 'palette': palette, 'ticklabels': cd_unique, 'type': 'qualitative', 'n_colors': n_colors} else: cmap = cm.get_cmap(val) vmax = np.nanmax(color_data.values) vmin = np.nanmin(color_data.values) cval = (color_data.values - vmin) / (vmax - vmin) c = cmap(cval) cbi = {'name': key, 'cmap': cmap, 'vmin': vmin, 'vmax': vmax, 'type': 'sequential'} else: if cmap_type == 'qualitative': cd_unique = list(np.unique(color_data.values)) n_colors = len(cd_unique) if len(palette) < n_colors: raise ValueError( 'Palettes must have as many colors as there are categories') palette = val c = [palette[cd_unique.index(x)] for x in color_data.values] cbi = {'name': key, 'palette': palette[:n_colors], 'ticks': cd_unique, 'type': 'qualitative', 'n_colors': n_colors} else: cmap = val vmax = np.nanmax(color_data.values) vmin = np.nanmin(color_data.values) cval = (color_data.values - vmin) / (vmax - vmin) c = cmap(cval) cbi = {'name': key, 'cmap': cmap, 'vmin': vmin, 'vmax': vmax, 'type': 'sequential'} col_samples.append(c) cbars_samples.append(cbi) col_samples = pd.DataFrame( data=[list(a) for a in col_samples], columns=color_data.index, index=annotate_samples.keys()).T else: col_samples = None if annotate_features: cbars_features = [] col_features = [] for key, val in annotate_features.items(): if key == 'mean expression': color_data = self.dataset.counts.mean(axis=1) else: color_data = self.dataset.featuresheet.loc[:, key] is_numeric = np.issubdtype(color_data.dtype, np.number) if (color_data.dtype.name == 'category') or (not is_numeric): cmap_type = 'qualitative' else: cmap_type = 'sequential' if isinstance(val, str): if cmap_type == 'qualitative': cd_unique = list(np.unique(color_data.values)) n_colors = len(cd_unique) palette = 
sns.color_palette(val, n_colors=n_colors) c = [palette[cd_unique.index(x)] for x in color_data.values] cbi = {'name': key, 'palette': palette, 'ticklabels': cd_unique, 'type': 'qualitative', 'n_colors': n_colors} else: cmap = cm.get_cmap(val) vmax = np.nanmax(color_data.values) vmin = np.nanmin(color_data.values) cval = (color_data.values - vmin) / (vmax - vmin) c = cmap(cval) cbi = {'name': key, 'cmap': cmap, 'vmin': vmin, 'vmax': vmax, 'type': 'sequential'} else: if cmap_type == 'qualitative': cd_unique = list(np.unique(color_data.values)) n_colors = len(cd_unique) if len(palette) < n_colors: raise ValueError( 'Palettes must have as many colors as there are categories') palette = val c = [palette[cd_unique.index(x)] for x in color_data.values] cbi = {'name': key, 'palette': palette[:n_colors], 'ticks': cd_unique, 'type': 'qualitative', 'n_colors': n_colors} else: cmap = val vmax = np.nanmax(color_data.values) vmin = np.nanmin(color_data.values) cval = (color_data.values - vmin) / (vmax - vmin) c = cmap(cval) cbi = {'name': key, 'cmap': cmap, 'vmin': vmin, 'vmax': vmax, 'type': 'sequential'} col_features.append(c) cbars_features.append(cbi) col_features = pd.DataFrame( data=[list(a) for a in col_features], columns=color_data.index, index=annotate_features.keys()).T else: col_features = None if orientation == 'horizontal': row_linkage = linkage_features col_linkage = linkage_samples row_colors = col_features col_colors = col_samples elif orientation == 'vertical': data = data.T row_linkage = linkage_samples col_linkage = linkage_features row_colors = col_samples col_colors = col_features else: raise ValueError('Orientation must be "horizontal" or "vertical".') defaults = { 'yticklabels': True, 'xticklabels': True, 'row_colors': row_colors, 'col_colors': col_colors} if row_linkage is not None: defaults.update({ 'row_cluster': True, 'row_linkage': row_linkage}) else: defaults.update({'row_cluster': False}) if col_linkage is not None: defaults.update({ 'col_cluster': 
True, 'col_linkage': col_linkage}) else: defaults.update({'col_cluster': False}) Plot._update_properties(kwargs, defaults) g = sns.clustermap( data=data, **kwargs) ax = g.ax_heatmap for label in ax.get_xmajorticklabels(): label.set_rotation(90) label.set_horizontalalignment("center") for label in ax.get_ymajorticklabels(): label.set_rotation(0) label.set_verticalalignment("center") if colorbars: # The colorbar for the heatmap is shown anyway if col_samples is not None: n_cbars = len(cbars_samples) caxs = [] if orientation == 'horizontal': wcb = min(0.3, 0.4 / n_cbars) xcb = 0.98 - wcb * n_cbars - 0.05 * (n_cbars - 1) else: hcb = min(0.3, 0.4 / n_cbars) ycb = 0.98 - hcb for i, cbi in enumerate(cbars_samples): if orientation == 'horizontal': cax = g.fig.add_axes((xcb, 0.955, wcb, 0.025)) else: cax = g.fig.add_axes((0.01, ycb, 0.02, hcb)) caxs.append(cax) kw = {} if cbi['type'] == 'sequential': kw['norm'] = mpl.colors.Normalize( vmin=cbi['vmin'], vmax=cbi['vmax']) cb = mpl.colorbar.ColorbarBase( cax, cmap=cbi['cmap'], orientation=orientation, **kw) else: n_colors = cbi['n_colors'] bounds = [1.0 * i / n_colors for i in range(n_colors + 1)] ticks = [(2.0 * i + 1) / (n_colors * 2) for i in range(n_colors)] kw['norm'] = mpl.colors.Normalize(vmin=0, vmax=1) cmap = mpl.colors.ListedColormap(cbi['palette']) cb = mpl.colorbar.ColorbarBase( cax, cmap=cmap, boundaries=bounds, ticks=ticks, orientation=orientation, **kw) if orientation == 'horizontal': cb.ax.set_xticklabels([str(x) for x in cbi['ticklabels']]) else: cb.ax.set_yticklabels([str(x) for x in cbi['ticklabels']]) cb.set_label(cbi['name']) if orientation == 'horizontal': xcb += wcb + 0.05 else: ycb -= hcb + 0.05 if orientation == 'horizontal': g.ax_cbars_columns = caxs else: g.ax_cbars_rows = caxs if col_features is not None: n_cbars = len(cbars_features) caxs = [] if orientation == 'horizontal': orientation_cb = 'vertical' else: orientation_cb = 'horizontal' if orientation_cb == 'horizontal': wcb = min(0.3, 0.4 / 
n_cbars) xcb = 0.98 - wcb * n_cbars - 0.05 * (n_cbars - 1) else: hcb = min(0.3, 0.4 / n_cbars) ycb = 0.98 - hcb for i, cbi in enumerate(cbars_features): if orientation_cb == 'horizontal': cax = g.fig.add_axes((xcb, 0.955, wcb, 0.025)) else: cax = g.fig.add_axes((0.01, ycb, 0.02, hcb)) caxs.append(cax) kw = {} if cbi['type'] == 'sequential': kw['norm'] = mpl.colors.Normalize( vmin=cbi['vmin'], vmax=cbi['vmax']) cb = mpl.colorbar.ColorbarBase( cax, cmap=cbi['cmap'], orientation=orientation_cb, **kw) else: n_colors = cbi['n_colors'] bounds = [1.0 * i / n_colors for i in range(n_colors + 1)] ticks = [(2.0 * i + 1) / (n_colors * 2) for i in range(n_colors)] kw['norm'] = mpl.colors.Normalize(vmin=0, vmax=1) cmap = mpl.colors.ListedColormap(cbi['palette']) cb = mpl.colorbar.ColorbarBase( cax, cmap=cmap, boundaries=bounds, ticks=ticks, orientation=orientation_cb, **kw) if orientation_cb == 'horizontal': cb.ax.set_xticklabels([str(x) for x in cbi['ticklabels']]) else: cb.ax.set_yticklabels([str(x) for x in cbi['ticklabels']]) cb.set_label(cbi['name']) if orientation_cb == 'horizontal': xcb += wcb + 0.05 else: ycb -= hcb + 0.05 if orientation_cb == 'horizontal': g.ax_cbars_columns = caxs else: g.ax_cbars_rows = caxs else: # Remove colorbar g.fig.get_axes()[-1].remove() # TODO: reimplement some heuristic tight_layout return g