from typing import Dict, Optional, Union
import numpy as np
import torch
from scipy.sparse import issparse
import scanpy as sc
from scanpy.get import _get_obs_rep, _set_obs_rep
from anndata import AnnData
from scgpt import logger
class Preprocessor:
    """
    Prepare data into training, valid and test split. Normalize raw expression
    values, binning or using other transform into the preset model input format.
    """

    def __init__(
        self,
        use_key: Optional[str] = None,
        filter_gene_by_counts: Union[int, bool] = False,
        filter_cell_by_counts: Union[int, bool] = False,
        normalize_total: Union[float, bool] = 1e4,
        result_normed_key: Optional[str] = "X_normed",
        log1p: bool = False,
        result_log1p_key: str = "X_log1p",
        subset_hvg: Union[int, bool] = False,
        hvg_use_key: Optional[str] = None,
        hvg_flavor: str = "seurat_v3",
        binning: Optional[int] = None,
        result_binned_key: str = "X_binned",
    ):
        r"""
        Set up the preprocessor, use the args to config the workflow steps.

        Args:
            use_key (:class:`str`, optional):
                The key of :class:`~anndata.AnnData` to use for preprocessing.
            filter_gene_by_counts (:class:`int` or :class:`bool`, default: ``False``):
                Whether to filter genes by counts, if :class:`int`, filter genes with counts.
            filter_cell_by_counts (:class:`int` or :class:`bool`, default: ``False``):
                Whether to filter cells by counts, if :class:`int`, filter cells with counts.
            normalize_total (:class:`float` or :class:`bool`, default: ``1e4``):
                Whether to normalize the total counts of each cell to a specific value.
            result_normed_key (:class:`str`, default: ``"X_normed"``):
                The key of :class:`~anndata.AnnData` to store the normalized data. If
                :class:`None`, will use normed data to replace the :attr:`use_key`.
            log1p (:class:`bool`, default: ``False``):
                Whether to apply log1p transform to the normalized data.
            result_log1p_key (:class:`str`, default: ``"X_log1p"``):
                The key of :class:`~anndata.AnnData` to store the log1p transformed data.
            subset_hvg (:class:`int` or :class:`bool`, default: ``False``):
                Whether to subset highly variable genes.
            hvg_use_key (:class:`str`, optional):
                The key of :class:`~anndata.AnnData` to use for calculating highly variable
                genes. If :class:`None`, will use :attr:`adata.X`.
            hvg_flavor (:class:`str`, default: ``"seurat_v3"``):
                The flavor of highly variable genes selection. See
                :func:`scanpy.pp.highly_variable_genes` for more details.
            binning (:class:`int`, optional):
                Whether to bin the data into discrete values of number of bins provided.
            result_binned_key (:class:`str`, default: ``"X_binned"``):
                The key of :class:`~anndata.AnnData` to store the binned data.
        """
        self.use_key = use_key
        self.filter_gene_by_counts = filter_gene_by_counts
        self.filter_cell_by_counts = filter_cell_by_counts
        self.normalize_total = normalize_total
        self.result_normed_key = result_normed_key
        self.log1p = log1p
        self.result_log1p_key = result_log1p_key
        self.subset_hvg = subset_hvg
        self.hvg_use_key = hvg_use_key
        self.hvg_flavor = hvg_flavor
        self.binning = binning
        self.result_binned_key = result_binned_key

    def __call__(self, adata: AnnData, batch_key: Optional[str] = None) -> Dict:
        """
        Run the configured preprocessing steps in place on ``adata``. The config
        controls the different input value wrapping, including categorical
        binned style, fixed-sum normalized counts, log1p fixed-sum normalized
        counts, etc.

        NOTE(review): the annotation says ``Dict`` but no value is returned;
        callers should not rely on a return value.

        Args:
            adata (:class:`AnnData`):
                The :class:`AnnData` object to preprocess.
            batch_key (:class:`str`, optional):
                The key of :class:`AnnData.obs` to use for batch information. This arg
                is used in the highly variable gene selection step.
        """
        key_to_process = self.use_key
        # preliminary checks, will use later
        if key_to_process == "X":
            key_to_process = None  # the following scanpy apis use arg None to use X
        is_logged = self.check_logged(adata, obs_key=key_to_process)

        # step 1: filter genes
        if self.filter_gene_by_counts:
            logger.info("Filtering genes by counts ...")
            sc.pp.filter_genes(
                adata,
                min_counts=self.filter_gene_by_counts
                if isinstance(self.filter_gene_by_counts, int)
                else None,
            )

        # step 2: filter cells
        # NOTE: must be a truthiness check, not isinstance -- ``bool`` is a
        # subclass of ``int`` in Python, so ``isinstance(False, int)`` is True
        # and the previous isinstance test ran this step even when
        # filter_cell_by_counts was False.
        if self.filter_cell_by_counts:
            logger.info("Filtering cells by counts ...")
            sc.pp.filter_cells(
                adata,
                min_counts=self.filter_cell_by_counts
                if isinstance(self.filter_cell_by_counts, int)
                else None,
            )

        # step 3: normalize total
        if self.normalize_total:
            logger.info("Normalizing total counts ...")
            normed_ = sc.pp.normalize_total(
                adata,
                target_sum=self.normalize_total
                if isinstance(self.normalize_total, float)
                else None,
                layer=key_to_process,
                inplace=False,
            )["X"]
            key_to_process = self.result_normed_key or key_to_process
            _set_obs_rep(adata, normed_, layer=key_to_process)

        # step 4: log1p
        if self.log1p:
            logger.info("Log1p transforming ...")
            if is_logged:
                logger.warning(
                    "The input data seems to be already log1p transformed. "
                    "Set `log1p=False` to avoid double log1p transform."
                )
            if self.result_log1p_key:
                # copy the current layer so log1p is applied to the new layer,
                # leaving the normalized layer untouched
                _set_obs_rep(
                    adata,
                    _get_obs_rep(adata, layer=key_to_process),
                    layer=self.result_log1p_key,
                )
                key_to_process = self.result_log1p_key
            sc.pp.log1p(adata, layer=key_to_process)

        # step 5: subset hvg
        if self.subset_hvg:
            logger.info("Subsetting highly variable genes ...")
            if batch_key is None:
                logger.warning(
                    "No batch_key is provided, will use all cells for HVG selection."
                )
            sc.pp.highly_variable_genes(
                adata,
                layer=self.hvg_use_key,
                n_top_genes=self.subset_hvg
                if isinstance(self.subset_hvg, int)
                else None,
                batch_key=batch_key,
                flavor=self.hvg_flavor,
                subset=True,
            )

        # step 6: binning
        if self.binning:
            logger.info("Binning data ...")
            if not isinstance(self.binning, int):
                raise ValueError(
                    "Binning arg must be an integer, but got {}.".format(self.binning)
                )
            n_bins = self.binning  # NOTE: the first bin is always a special for zero
            binned_rows = []
            bin_edges = []
            layer_data = _get_obs_rep(adata, layer=key_to_process)
            layer_data = layer_data.A if issparse(layer_data) else layer_data
            for row in layer_data:
                if row.max() == 0:
                    # an all-zero row has no non-zero values to compute
                    # quantiles from (np.quantile on an empty array raises);
                    # assign the whole row to the zero bin instead
                    logger.warning(
                        "The input data contains all zero rows. Please make "
                        "sure this is expected."
                    )
                    binned_rows.append(np.zeros_like(row, dtype=np.int64))
                    bin_edges.append(np.zeros(n_bins))
                    continue
                non_zero_ids = row.nonzero()
                non_zero_row = row[non_zero_ids]
                bins = np.quantile(non_zero_row, np.linspace(0, 1, n_bins - 1))
                # NOTE: do not deduplicate bins (e.g. np.sort(np.unique(bins)))
                # since this would make each category have a different relative
                # meaning across datasets
                non_zero_digits = _digitize(non_zero_row, bins)
                assert non_zero_digits.min() >= 1
                assert non_zero_digits.max() <= n_bins - 1
                binned_row = np.zeros_like(row, dtype=np.int64)
                binned_row[non_zero_ids] = non_zero_digits
                binned_rows.append(binned_row)
                bin_edges.append(np.concatenate([[0], bins]))
            adata.layers[self.result_binned_key] = np.stack(binned_rows)
            adata.obsm["bin_edges"] = np.stack(bin_edges)

    def check_logged(self, adata: AnnData, obs_key: Optional[str] = None) -> bool:
        """
        Heuristically check if the data is already log1p transformed: values
        must be non-negative, have max <= 30, and contain at least one
        non-zero value below 1.

        Args:
            adata (:class:`AnnData`):
                The :class:`AnnData` object to preprocess.
            obs_key (:class:`str`, optional):
                The layer key of :class:`AnnData` to check. If :class:`None`,
                checks :attr:`adata.X`.
        """
        data = _get_obs_rep(adata, layer=obs_key)
        max_, min_ = data.max(), data.min()
        if max_ > 30:
            return False
        if min_ < 0:
            return False
        if max_ == 0:
            # all-zero data: nothing to inspect, and ``data[data > 0].min()``
            # below would raise on an empty selection
            return False

        non_zero_min = data[data > 0].min()
        if non_zero_min >= 1:
            return False

        return True
def _digitize(x: np.ndarray, bins: np.ndarray, side="one") -> np.ndarray:
"""
Digitize the data into bins. This method spreads data uniformly when bins
have same values.
Args:
x (:class:`np.ndarray`):
The data to digitize.
bins (:class:`np.ndarray`):
The bins to use for digitization, in increasing order.
side (:class:`str`, optional):
The side to use for digitization. If "one", the left side is used. If
"both", the left and right side are used. Default to "one".
Returns:
:class:`np.ndarray`:
The digitized data.
"""
assert x.ndim == 1 and bins.ndim == 1
left_digits = np.digitize(x, bins)
if side == "one":
return left_digits
right_difits = np.digitize(x, bins, right=True)
rands = np.random.rand(len(x)) # uniform random numbers
digits = rands * (right_difits - left_digits) + left_digits
digits = np.ceil(digits).astype(np.int64)
return digits
def binning(
    row: Union[np.ndarray, torch.Tensor], n_bins: int
) -> Union[np.ndarray, torch.Tensor]:
    """
    Bin a single row of expression values into ``n_bins`` quantile bins.

    Accepts either a numpy array or a torch tensor; the result is returned in
    the same container type as the input (numpy results are cast back to the
    input dtype, tensor results are int64).
    """
    original_dtype = row.dtype
    is_tensor = isinstance(row, torch.Tensor)
    values = row.cpu().numpy() if is_tensor else row
    # TODO: use torch.quantile and torch.bucketize
    if values.min() <= 0:
        # zeros (and any non-positive entries) are kept out of the quantile
        # computation; only the non-zero entries get binned
        nz_idx = values.nonzero()
        nz_vals = values[nz_idx]
        edges = np.quantile(nz_vals, np.linspace(0, 1, n_bins - 1))
        binned = np.zeros_like(values, dtype=np.int64)
        binned[nz_idx] = _digitize(nz_vals, edges)
    else:
        # strictly positive row: bin every entry directly
        edges = np.quantile(values, np.linspace(0, 1, n_bins - 1))
        binned = _digitize(values, edges)
    if is_tensor:
        return torch.from_numpy(binned)
    return binned.astype(original_dtype)