import copy import warnings from collections.abc import Mapping, Sequence from typing import Any, TypeVar, Union from torch.utils.data.datapipes.datapipe import MapDataPipe _T = TypeVar("_T") __all__ = ["SequenceWrapperMapDataPipe"] class SequenceWrapperMapDataPipe(MapDataPipe[_T]): r""" Wraps a sequence object into a MapDataPipe. Args: sequence: Sequence object to be wrapped into an MapDataPipe deepcopy: Option to deepcopy input sequence object .. note:: If ``deepcopy`` is set to False explicitly, users should ensure that data pipeline doesn't contain any in-place operations over the iterable instance, in order to prevent data inconsistency across iterations. Example: >>> # xdoctest: +SKIP >>> from torchdata.datapipes.map import SequenceWrapper >>> dp = SequenceWrapper(range(10)) >>> list(dp) [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] >>> dp = SequenceWrapper({"a": 100, "b": 200, "c": 300, "d": 400}) >>> dp["a"] 100 """ sequence: Union[Sequence[_T], Mapping[Any, _T]] def __init__( self, sequence: Union[Sequence[_T], Mapping[Any, _T]], deepcopy: bool = True ) -> None: if deepcopy: try: self.sequence = copy.deepcopy(sequence) except TypeError: warnings.warn( "The input sequence can not be deepcopied, " "please be aware of in-place modification would affect source data" ) self.sequence = sequence else: self.sequence = sequence def __getitem__(self, index: int) -> _T: return self.sequence[index] def __len__(self) -> int: return len(self.sequence)