grain.transforms module

Data transformation APIs.

List of Members

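class grain.transforms.Batch(batch_size: int, drop_remainder: bool = False, batch_fn: Callable[[Sequence[Any]], Any] | None = None)

Batches consecutive elements into groups of batch_size.

Parameters:
  • batch_size (int) – number of elements combined into each batch.

  • drop_remainder (bool) – if True, a final batch with fewer than batch_size elements is dropped; if False (the default), it is emitted as a smaller batch.

  • batch_fn (Callable[[Sequence[Any]], Any] | None) – optional function that combines a sequence of elements into a single batch; if None, a default batching function is used.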
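The following pure-Python sketch shows the batching semantics the three parameters describe. `batch_elements` is a hypothetical helper and not part of Grain. When `batch_fn` is None it returns plain lists as a stand-in; it does not reproduce Grain's actual default batching function.

```python
from __future__ import annotations

from typing import Any, Callable, Iterable, Iterator, Sequence


def batch_elements(
    elements: Iterable[Any],
    batch_size: int,
    drop_remainder: bool = False,
    batch_fn: Callable[[Sequence[Any]], Any] | None = None,
) -> Iterator[Any]:
  """Hypothetical helper mirroring the documented Batch parameters."""
  if batch_size <= 0:
    raise ValueError("batch_size must be positive")
  # Stand-in default (assumption): return each batch as a plain list.
  combine = batch_fn if batch_fn is not None else list
  buffer: list[Any] = []
  for element in elements:
    buffer.append(element)
    if len(buffer) == batch_size:
      yield combine(buffer)
      buffer = []  # fresh list, so batch_fn may safely keep the old one
  # A partial final batch is emitted only when drop_remainder is False.
  if buffer and not drop_remainder:
    yield combine(buffer)


print(list(batch_elements(range(7), 3)))                       # [[0, 1, 2], [3, 4, 5], [6]]
print(list(batch_elements(range(7), 3, drop_remainder=True)))  # [[0, 1, 2], [3, 4, 5]]
print(list(batch_elements(range(7), 3, batch_fn=sum)))         # [3, 12, 6]
```

With seven elements and batch_size=3, the last batch holds a single element. drop_remainder decides whether that short batch survives. batch_fn changes only how each group is combined, not how the groups are formed.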

class grain.transforms.Filter

Abstract base class for filter transformations for individual elements.

The pipeline will drop any element for which the filter function returns False.

Implementations should be threadsafe since they are often executed in parallel.
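Here is a minimal sketch of a concrete filter. It assumes the abstract method to override is `filter(self, element) -> bool`; check the signature in your Grain version. The class name `KeepEven` is hypothetical. The try/except stand-in base exists only so the sketch also runs without Grain installed. The list comprehension emulates outside any pipeline what the pipeline does with the filter.

```python
try:
  import grain
  FilterBase = grain.transforms.Filter
except (ImportError, AttributeError):
  # Assumption: plain stand-in base so the sketch runs without Grain.
  FilterBase = object


class KeepEven(FilterBase):
  """Keeps even integers. It holds no mutable state, so it is thread-safe."""

  def filter(self, element: int) -> bool:
    # Returning False means the pipeline drops this element.
    return element % 2 == 0


keep_even = KeepEven()
# Emulate the pipeline: keep only elements for which filter() is True.
print([x for x in range(6) if keep_even.filter(x)])  # [0, 2, 4]
```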

class grain.transforms.Map

Abstract base class for all 1:1 transformations of elements.

Implementations should be threadsafe since they are often executed in parallel.
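Here is a minimal sketch of a 1:1 map. It assumes the abstract method to override is `map(self, element)`. `Standardize` is a hypothetical example class, and the stand-in base is used only when Grain cannot be imported. The transform is thread-safe because `map` only reads the fields set in `__init__`.

```python
try:
  import grain
  MapBase = grain.transforms.Map
except (ImportError, AttributeError):
  # Assumption: plain stand-in base so the sketch runs without Grain.
  MapBase = object


class Standardize(MapBase):
  """Maps x to (x - mean) / std; fields are only read, never mutated."""

  def __init__(self, mean: float, std: float):
    self._mean = mean
    self._std = std

  def map(self, element: float) -> float:
    # Exactly one output per input element (a 1:1 transformation).
    return (element - self._mean) / self._std


standardize = Standardize(mean=10.0, std=2.0)
print([standardize.map(x) for x in (8.0, 10.0, 14.0)])  # [-1.0, 0.0, 2.0]
```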

class grain.transforms.MapWithIndex

Abstract base class for 1:1 transformations of elements and their index.

Implementations should be threadsafe since they are often executed in parallel.
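The sketch below assumes the abstract method is `map_with_index(self, index, element)`, with the index passed first; verify the argument order against your Grain version. `AttachId` is a hypothetical class that records each element's position. Enumerating a list stands in for the indices the pipeline would supply.

```python
try:
  import grain
  MapWithIndexBase = grain.transforms.MapWithIndex
except (ImportError, AttributeError):
  # Assumption: plain stand-in base so the sketch runs without Grain.
  MapWithIndexBase = object


class AttachId(MapWithIndexBase):
  """Pairs each element with its index; stateless, hence thread-safe."""

  def map_with_index(self, index: int, element: str) -> dict:
    # Output depends only on the arguments, so it is deterministic.
    return {"id": index, "text": element}


attach_id = AttachId()
print([attach_id.map_with_index(i, s) for i, s in enumerate(["a", "b"])])
# [{'id': 0, 'text': 'a'}, {'id': 1, 'text': 'b'}]
```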

class grain.transforms.RandomMap

Abstract base class for all random 1:1 transformations of elements.

Implementations should be threadsafe since they are often executed in parallel.
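Here is a minimal sketch of a random 1:1 transformation. It assumes the abstract method is `random_map(self, element, rng)`, where the pipeline supplies `rng` as a `numpy.random.Generator`. `RandomScale` is a hypothetical class. Drawing randomness only from `rng`, never from global state, means the same generator state always yields the same output. This is what makes the transform reproducible.

```python
import numpy as np

try:
  import grain
  RandomMapBase = grain.transforms.RandomMap
except (ImportError, AttributeError):
  # Assumption: plain stand-in base so the sketch runs without Grain.
  RandomMapBase = object


class RandomScale(RandomMapBase):
  """Multiplies each element by a factor drawn uniformly from [low, high)."""

  def __init__(self, low: float = 0.5, high: float = 1.5):
    self._low = low
    self._high = high

  def random_map(self, element: float, rng: np.random.Generator) -> float:
    # Use only the supplied generator so results are reproducible.
    return element * rng.uniform(self._low, self._high)


scale = RandomScale()
a = scale.random_map(10.0, np.random.default_rng(0))
b = scale.random_map(10.0, np.random.default_rng(0))
print(a == b)  # True: identical generator states give identical outputs
```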

grain.transforms.Transformation

alias of Batch | Map | RandomMap | TfRandomMap | Filter | FlatMap | MapWithIndex

grain.transforms.Transformations

alias of Sequence[Batch | Map | RandomMap | TfRandomMap | Filter | FlatMap | MapWithIndex]

class grain.transforms.DatasetSelectionMap

Map from index to (constituent dataset index, index within dataset).

This abstract base class defines the interface for mapping a single global index to a position in one of several underlying (constituent) datasets. It acts as a routing table for mixed or concatenated data pipelines.

Note that implementations must be stateless and picklable. They should also avoid randomness, so that results stay deterministic, because the map may be created and called concurrently in multiple processes.

Example

Implementing a simple concatenation map for two datasets of size 3 and 4:

import grain

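class ConcatMap(grain.transforms.DatasetSelectionMap):

  def __len__(self) -> int:
    return 7  # 3 + 4 elements in total

  def __getitem__(self, index: int) -> tuple[int, int]:
    # Reject negative indices too; otherwise -1 would map to (0, -1).
    if not 0 <= index < len(self):
      raise IndexError("Index out of range")
    if index < 3:
      return (0, index)  # first dataset
    return (1, index - 3)  # second dataset, shifted by its offset of 3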

cmap = ConcatMap()
assert len(cmap) == 7
assert cmap[3] == (1, 0)
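The hard-coded ConcatMap above generalizes to any number of datasets. The sketch below uses the hypothetical class name `ConcatSelectionMap`. It precomputes cumulative offsets once, then routes each global index with a binary search (`bisect`), so a lookup costs O(log k) for k datasets. Like the example above, it stores only immutable data and uses no randomness. The stand-in base is used only when Grain cannot be imported.

```python
import bisect
from itertools import accumulate
from typing import Sequence

try:
  import grain
  SelectionMapBase = grain.transforms.DatasetSelectionMap
except (ImportError, AttributeError):
  # Assumption: plain stand-in base so the sketch runs without Grain.
  SelectionMapBase = object


class ConcatSelectionMap(SelectionMapBase):
  """Concatenates datasets of arbitrary sizes, in order."""

  def __init__(self, sizes: Sequence[int]):
    # _offsets[i] is the global index where dataset i starts; the last entry
    # is the total length, e.g. sizes (3, 4) -> offsets (0, 3, 7).
    # Only an immutable tuple is stored, keeping instances picklable.
    self._offsets = (0, *accumulate(sizes))

  def __len__(self) -> int:
    return self._offsets[-1]

  def __getitem__(self, index: int) -> tuple[int, int]:
    if not 0 <= index < len(self):
      raise IndexError("Index out of range")
    # The rightmost offset <= index identifies the dataset; bisect_right
    # also skips empty datasets, whose offsets repeat.
    dataset_index = bisect.bisect_right(self._offsets, index) - 1
    return (dataset_index, index - self._offsets[dataset_index])


cmap = ConcatSelectionMap([3, 4])
print(len(cmap), cmap[3], cmap[6])  # 7 (1, 0) (1, 3)
```

With sizes [3, 4], this reproduces the ConcatMap example exactly, and the same class handles any list of sizes.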