Source code for scgpt.data_sampler

from typing import Iterable, List, Sequence

import numpy as np
import torch
from torch.utils.data import Sampler, SubsetRandomSampler, BatchSampler


class SubsetSequentialSampler(Sampler):
    """Samples elements sequentially from a given list of indices, without
    replacement.

    Arguments:
        indices (sequence): a sequence of indices
    """

    def __init__(self, indices: Sequence[int]):
        self.indices = indices

    def __iter__(self) -> Iterable[int]:
        return iter(self.indices)

    def __len__(self) -> int:
        return len(self.indices)
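
A minimal usage sketch (the toy index list below is assumed, not from the source): the sampler simply yields the given indices in their original order, which is useful for pinning a deterministic evaluation order.

sampler = SubsetSequentialSampler([4, 2, 7])
assert list(sampler) == [4, 2, 7]  # indices come back in the given order
assert len(sampler) == 3
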
class SubsetsBatchSampler(Sampler[List[int]]):
    r"""Samples batches of indices from a list of subsets of indices. Each subset
    of indices represents a data subset and is sampled without replacement,
    either randomly or sequentially. In particular, each batch only contains
    indices from a single subset. This sampler is for the scenario where samples
    need to be drawn from multiple subsets separately.

    Arguments:
        subsets (List[Sequence[int]]): A list of subsets of indices.
        batch_size (int): Size of mini-batch.
        intra_subset_shuffle (bool): If ``True``, the sampler will shuffle the
            indices within each subset.
        inter_subset_shuffle (bool): If ``True``, the sampler will shuffle the
            order of subsets.
        drop_last (bool): If ``True``, the sampler will drop the last batch of a
            subset if its size would be less than ``batch_size``.
    """

    def __init__(
        self,
        subsets: List[Sequence[int]],
        batch_size: int,
        intra_subset_shuffle: bool = True,
        inter_subset_shuffle: bool = True,
        drop_last: bool = False,
    ):
        self.subsets = subsets
        self.batch_size = batch_size
        self.intra_subset_shuffle = intra_subset_shuffle
        self.inter_subset_shuffle = inter_subset_shuffle
        self.drop_last = drop_last

        if intra_subset_shuffle:
            self.subset_samplers = [SubsetRandomSampler(subset) for subset in subsets]
        else:
            self.subset_samplers = [
                SubsetSequentialSampler(subset) for subset in subsets
            ]

        self.batch_samplers = [
            BatchSampler(sampler, batch_size, drop_last)
            for sampler in self.subset_samplers
        ]

        if inter_subset_shuffle:
            # maintain a mapping from the global batch index to the batch
            # sampler that produces that batch
            _id_to_batch_sampler = []
            for i, batch_sampler in enumerate(self.batch_samplers):
                _id_to_batch_sampler.extend([i] * len(batch_sampler))
            self._id_to_batch_sampler = np.array(_id_to_batch_sampler)
            assert len(self._id_to_batch_sampler) == len(self)

    def __iter__(self) -> Iterable[List[int]]:
        if self.inter_subset_shuffle:
            # create fresh iterators on every pass so the sampler can be
            # iterated more than once (e.g. one pass per epoch) and the
            # intra-subset shuffling is redone each time
            batch_sampler_iterators = [
                iter(batch_sampler) for batch_sampler in self.batch_samplers
            ]
            # randomly interleave batches from the different batch samplers;
            # each id occurs exactly as many times as its sampler has batches
            random_idx = torch.randperm(len(self._id_to_batch_sampler))
            batch_sampler_ids = self._id_to_batch_sampler[random_idx]
            for batch_sampler_id in batch_sampler_ids:
                yield next(batch_sampler_iterators[batch_sampler_id])
        else:
            for batch_sampler in self.batch_samplers:
                yield from batch_sampler

    def __len__(self) -> int:
        return sum(len(batch_sampler) for batch_sampler in self.batch_samplers)
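
A minimal usage sketch (the toy dataset and subset split below are assumed, not part of the source): ``SubsetsBatchSampler`` plugs into a ``DataLoader`` through its ``batch_sampler`` argument, so every yielded mini-batch is drawn from exactly one subset.

if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset

    data = TensorDataset(torch.arange(10).float().unsqueeze(1))
    # e.g. two data subsets (such as two studies or donors) to sample separately
    subsets = [[0, 1, 2, 3], [4, 5, 6, 7, 8, 9]]
    sampler = SubsetsBatchSampler(subsets, batch_size=2)
    loader = DataLoader(data, batch_sampler=sampler)
    assert len(loader) == 5  # 2 batches from the first subset + 3 from the second
    for (batch,) in loader:
        # each batch mixes indices from a single subset only
        print(batch.squeeze(1))

Passing the sampler as ``batch_sampler`` (rather than ``sampler``) is required because it yields whole lists of indices per step; ``DataLoader`` then forbids the mutually exclusive ``batch_size``, ``shuffle``, and ``drop_last`` arguments.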