Definition of Transform and Pipeline
from __future__ import annotations
from nbdev.showdoc import *
from fastcore.test import *
from fastcore.nb_imports import *

The classes here provide functionality for creating a composition of partially reversible functions. By “partially reversible” we mean that a transform can be decoded, creating a form suitable for display. This is not necessarily identical to the original form (e.g. a transform that changes a byte tensor to a float tensor does not recreate a byte tensor when decoded, since that may lose precision, and a float tensor can be displayed already).

Classes are also provided for composing transforms and mapping them over collections. Pipeline is a transform that composes several Transforms, knowing how to decode them or show an encoded item.



 Transform (enc=None, dec=None, split_idx=None, order=None)

Delegates (__call__,decode,setup) to (encodes,decodes,setups) if split_idx matches

A Transform is the main building block of the fastai data pipelines. In the most general terms, a transform can be any function you want to apply to your data; however, the Transform class provides several mechanisms that make the process of building them easy and flexible.

The main Transform features:

  • Type dispatch - Type annotations are used to determine whether a transform should be applied to a given argument. You can also provide several implementations, and the transform chooses the one to run based on the type. This is useful, for example, when running both independent and dependent variables through the pipeline, where some transforms only make sense for one and not the other. Another use case is a transform that handles several data formats. Note that if a transform takes multiple arguments, only the type of the first one is used for dispatch.
  • Handling of tuples - When a tuple (or a subclass of tuple) of data is passed to a transform, the transform is applied to each element separately. You can opt out of this behavior by passing a list or an L, since only tuples get this treatment. An alternative is to use ItemTransform, defined below, which always takes the input as a whole.
  • Reversibility - A transform can be made reversible by implementing the decodes method. This is mainly used to turn something like a category encoded as a number back into a human-readable label for display. Like the regular call, the decode method is applied to each element of a tuple separately.
  • Type propagation - Whenever possible, a transform tries to return data of the same type it received. This is mainly used to maintain the semantics of types like ArrayImage, which is a thin wrapper around NumPy’s ndarray. You can opt out of this behavior by adding a ->None return type annotation.
  • Preprocessing - The setup method can be used to perform any one-time calculations to be later used by the transform, for example generating a vocabulary to encode categorical data.
  • Filtering based on the dataset type - By setting the split_idx flag, you can make the transform apply only to a specific DataSource subset, such as the training set but not the validation set.
  • Ordering - You can set the order attribute, which Pipeline uses when it needs to merge two lists of transforms.
  • Appending new behavior with decorators - You can easily extend an existing Transform by creating encodes or decodes methods for new data types. You can put those new methods outside the original transform definition and decorate them with the class you want them patched into. This lets fastai users add their own behavior, and lets multiple modules contribute to the same transform.

Defining a Transform

There are a few ways to create a transform, with different ratios of simplicity to flexibility:

  • Extending the Transform class - Use inheritance to implement the methods you want.
  • Passing methods to the constructor - Instantiate the Transform class and pass your functions as enc and dec arguments.
  • @Transform decorator - Turn any function into a Transform by just adding a decorator. This is very straightforward if all you need is a single encodes implementation.
  • Passing a function to fastai APIs - Same as above, but when passing a function to other transform-aware classes like Pipeline or TfmdDS, you don’t even need a decorator. Your function will get converted to a Transform automatically.

A simple way to create a Transform is to pass a function to the constructor. In the below example, we pass an anonymous function that does integer division by 2:

f = Transform(lambda o:o//2)

If you call this transform, it will apply the transformation:

test_eq_type(f(2), 1)

Another way to define a Transform is to extend the Transform class:

class A(Transform): pass

However, to enable your transform to do something, you have to define an encodes method. Note that we can use the class name as a decorator to add this method to the original class.

@A
def encodes(self, x): return x+1

f1 = A()
test_eq(f1(1), 2) # calling f1(1) dispatches to A.encodes

In addition to adding an encodes method, we can also add a decodes method. This enables you to call the decode method (without an s). For more information about the purpose of decodes, see the discussion about Reversibility in the above section.

Just like with encodes, you can add a decodes method to the original class by using the class name as a decorator:

class B(A): pass

@B
def decodes(self, x): return x-1

f2 = B()
test_eq(f2.decode(2), 1)

test_eq(f2(1), 2) # uses the encodes method inherited from A

If you do not define an encodes or decodes method the original value will be returned:

class _Tst(Transform): pass 

f3 = _Tst() # no encodes or decodes methods have been defined
test_eq_type(f3.decode(2.0), 2.0)
test_eq_type(f3(2), 2)

Transforms can be created from class methods too:

class A:
    @classmethod
    def create(cls, x:int): return x+1
test_eq(Transform(A.create)(1), 2)

Defining Transforms With A Decorator

Transform can be used as a decorator to turn a function into a Transform.

@Transform
def f(x): return x//2
test_eq_type(f(2), 1)
test_eq_type(f.decode(2.0), 2.0)

@Transform
def f(x): return x*2
test_eq_type(f(2), 4)
test_eq_type(f.decode(2.0), 2.0)
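Conceptually, the decorator form wraps the function in an object whose encode step is the function and whose decode step defaults to returning its input unchanged. The minimal sketch below illustrates that idea in plain Python; FuncTransform is a hypothetical name, not part of fastcore, and it ignores type dispatch, casting and tuple handling.

```python
class FuncTransform:
    "Hypothetical sketch: wrap a function as a transform whose decode defaults to identity."
    def __init__(self, enc, dec=None):
        self.enc = enc
        # With no decoder supplied, decoding returns the value unchanged
        self.dec = dec if dec is not None else (lambda x: x)

    def __call__(self, x): return self.enc(x)
    def decode(self, x): return self.dec(x)

@FuncTransform
def halve(x): return x // 2

print(halve(2), halve.decode(2.0))  # 1 2.0
```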

Typed Dispatch and Transforms

We can also apply different transformations depending on the type of the input passed by using TypedDispatch. TypedDispatch automatically works with Transform when using type hints:

class A(Transform): pass

@A
def encodes(self, x:int): return x//2

@A
def encodes(self, x:float): return x+1

When we pass in an int, this calls the first encodes method:

f = A()
test_eq_type(f(3), 1)

When we pass in a float, this calls the second encodes method:

test_eq_type(f(2.), 3.)

When we pass in a type that is not specified in encodes, the original value is returned:

test_eq(f('a'), 'a')

If the type annotation is a union of types (written here as MyClass|float), then any type in the union will match:

class MyClass(int): pass

class A(Transform):
    def encodes(self, x:MyClass|float): return x/2
    def encodes(self, x:str|list): return str(x)+'_1'

f = A()

The below two examples match the first encodes, with a type of MyClass and float, respectively:

test_eq(f(MyClass(2)), 1.) # input is of type MyClass 
test_eq(f(6.0), 3.0) # input is of type float

The next two examples match the second encodes method, with a type of str and list, respectively:

test_eq(f('a'), 'a_1') # input is of type str
test_eq(f(['a','b','c']), "['a', 'b', 'c']_1") # input is of type list
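To make the dispatch rule concrete, here is a minimal pure-Python sketch of how an implementation might be picked from the type of the first argument. It is not fastcore's TypeDispatch: TinyDispatch is a hypothetical name, a tuple of types stands in for a union annotation, and unmatched inputs are returned unchanged, as Transform does.

```python
class TinyDispatch:
    "Hypothetical sketch: choose a function from the type of a single argument."
    def __init__(self):
        self.funcs = {}  # maps a registered type to its implementation

    def add(self, types, f):
        # A tuple of types stands in for a union annotation such as MyClass|float
        if not isinstance(types, tuple): types = (types,)
        for t in types: self.funcs[t] = f

    def __call__(self, x):
        # Walk the MRO so the most specific registered type is found first
        for t in type(x).__mro__:
            if t in self.funcs: return self.funcs[t](x)
        return x  # no implementation matched: return the input unchanged

class MyClass(int): pass

d = TinyDispatch()
d.add((MyClass, float), lambda x: x/2)
d.add((str, list), lambda x: str(x)+'_1')
print(d(MyClass(2)), d(6.0), d('a'), d(3))  # 1.0 3.0 a_1 3
```

Walking __mro__ does not see classes registered as virtual subclasses of an abstract base class (numbers.Integral, for instance, is not in int.__mro__), so a fuller version would also need isinstance checks.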

Casting Types With Transform

Without any intervention it is easy for operations to change types in Python. For example, FloatSubclass (defined below) becomes a float after performing multiplication:

class FloatSubclass(float): pass
test_eq_type(FloatSubclass(3.0) * 2, 6.0)

This behavior is often not desirable when performing transformations on data. Therefore, Transform will attempt to cast the output to be of the same type as the input by default. In the below example, the output will be cast to a FloatSubclass type to match the type of the input:

@Transform
def f(x): return x*2

test_eq_type(f(FloatSubclass(3.0)), FloatSubclass(6.0))

We can optionally turn off casting by annotating the transform function with a return type of None:

@Transform
def f(x)-> None: return x*2 # Same transform as above, but with a -> None annotation

test_eq_type(f(FloatSubclass(3.0)), 6.0)  # Casting is turned off because of -> None annotation

However, Transform will only cast the output back to the input type when the input type is a subclass of the output type. In the below example, the input is of type FloatSubclass, which is not a subclass of the output type str. Therefore, the output doesn’t get cast back to FloatSubclass and stays as type str:

@Transform
def f(x): return str(x)
test_eq_type(f(FloatSubclass(2.)), '2.0')
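The casting rule can be sketched as a small helper. retain_type_sketch is hypothetical and simplified, not fastcore's own implementation: it casts by calling the input type's constructor, which is enough for simple subclasses of builtins like float, whereas fastcore's helper covers more cases. A -> None annotation corresponds to skipping this step entirely.

```python
def retain_type_sketch(new, old):
    "Hypothetical sketch: cast `new` to type(old) when type(old) subclasses type(new)."
    # Cast only when the input type is a strict subclass of the output type
    if isinstance(old, type(new)) and type(old) is not type(new):
        return type(old)(new)
    return new

class FloatSubclass(float): pass

x = FloatSubclass(3.0)
print(type(retain_type_sketch(x*2, x)).__name__)          # FloatSubclass
print(repr(retain_type_sketch(str(FloatSubclass(2.)), x)))  # '2.0'
```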

Just like encodes, the decodes method will cast outputs to match the input type in the same way. In the below example, the output of decodes remains of type MySubclass:

class MySubclass(int): pass

def enc(x): return MySubclass(x+1)
def dec(x): return x-1

f = Transform(enc,dec)
t = f(1) # t is of type MySubclass
test_eq_type(f.decode(t), MySubclass(1)) # the output of decode is cast to MySubclass to match the input type.

Apply Transforms On Subsets With split_idx

You can apply transformations to subsets of data by specifying a split_idx property. If a transform has a split_idx then it’s only applied if the split_idx param matches. In the below example, we set split_idx equal to 1:

def enc(x): return x+1
def dec(x): return x-1
f = Transform(enc,dec)
f.split_idx = 1

The transformations are applied when a matching split_idx parameter is passed:

test_eq(f(1, split_idx=1),2)
test_eq(f.decode(2, split_idx=1),1)

On the other hand, transformations are ignored when the split_idx parameter does not match:

test_eq(f(1, split_idx=0), 1)
test_eq(f.decode(2, split_idx=0), 2)
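The filtering rule above can be sketched in plain Python. SplitFiltered is a hypothetical class, not fastcore's Transform: a transform with no split_idx always runs, and one with a split_idx runs only when the caller passes the same value.

```python
class SplitFiltered:
    "Hypothetical sketch of split_idx filtering."
    def __init__(self, enc, dec, split_idx=None):
        self.enc, self.dec, self.split_idx = enc, dec, split_idx

    def _skip(self, split_idx):
        # Skip only if this transform is restricted to a subset and the caller asks for another
        return self.split_idx is not None and split_idx != self.split_idx

    def __call__(self, x, split_idx=None):
        return x if self._skip(split_idx) else self.enc(x)

    def decode(self, x, split_idx=None):
        return x if self._skip(split_idx) else self.dec(x)

f = SplitFiltered(lambda x: x+1, lambda x: x-1, split_idx=1)
print(f(1, split_idx=1), f(1, split_idx=0), f.decode(2, split_idx=0))  # 2 1 2
```

Under this sketch's rule, calling f(1) with no split_idx also leaves the value unchanged, because None does not match 1.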

Transforms on Lists

Transform operates on lists as a whole, not element-wise:

class A(Transform):
    def encodes(self, x): return dict(x)
    def decodes(self, x): return list(x.items())
f = A()
_inp = [(1,2), (3,4)]
t = f(_inp)

test_eq(t, dict(_inp))
test_eq(f.decodes(t), _inp)

If you want a transform to operate on a list elementwise, you must implement this appropriately in the encodes and decodes methods:

class AL(Transform): pass

@AL
def encodes(self, x): return [x_+1 for x_ in x]

@AL
def decodes(self, x): return [x_-1 for x_ in x]

f = AL()
t = f([1,2])

test_eq(t, [2,3])
test_eq(f.decode(t), [1,2])

Transforms on Tuples

Unlike lists, Transform operates on tuples element-wise.

def neg_int(x): return -x
f = Transform(neg_int)

test_eq(f((1,2,3)), (-1,-2,-3))

Transforms will also apply TypedDispatch element-wise on tuples when an input type annotation is specified. In the below example, the values 1.0 and 3.0 are ignored because they are of type float, not int:

def neg_int(x:int): return -x
f = Transform(neg_int)

test_eq(f((1.0, 2, 3.0)), (1.0, -2, 3.0))

Another example of how Transform can use TypedDispatch with tuples is shown below:

class B(Transform): pass

@B
def encodes(self, x:int): return x+1

@B
def encodes(self, x:str): return x+'hello'

@B
def encodes(self, x): return str(x)+'!'

If the input is not an int or str, the third encodes method will apply:

b = B()
test_eq(b([1]), '[1]!') 
test_eq(b([1.0]), '[1.0]!')

However, if the input is a tuple, then the appropriate method will apply according to the type of each element in the tuple:

test_eq(b(('1',)), ('1hello',))
test_eq(b((1,2)), (2,3))
test_eq(b(('a',1.0)), ('ahello','1.0!'))

Dispatching over tuples works recursively, by the way:

class B(Transform):
    def encodes(self, x:int): return x+1
    def encodes(self, x:str): return x+'_hello'
    def decodes(self, x:int): return x-1
    def decodes(self, x:str): return x.replace('_hello', '')

f = B()
start = (1.,(2,'3'))
t = f(start)
test_eq_type(t, (1.,(3,'3_hello')))
test_eq(f.decode(t), start)
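The recursive, element-wise behavior on tuples can be sketched in a few lines of plain Python. apply_elementwise and enc are hypothetical helpers: the isinstance checks in enc stand in for the typed encodes methods of B above, and lists are deliberately treated as single items, matching the list behavior described earlier.

```python
def apply_elementwise(f, x):
    "Hypothetical sketch: apply f to every non-tuple leaf, recursing into tuples."
    if isinstance(x, tuple):
        # Rebuild with the input's tuple type (fine for plain tuples and simple subclasses;
        # namedtuples take positional arguments and would need type(x)(*...) instead)
        return type(x)(apply_elementwise(f, o) for o in x)
    return f(x)

def enc(x):
    # Stand-in for typed dispatch: only int and str are transformed
    if isinstance(x, int): return x+1
    if isinstance(x, str): return x+'_hello'
    return x

print(apply_elementwise(enc, (1., (2, '3'))))  # (1.0, (3, '3_hello'))
print(apply_elementwise(enc, [1, 2]))          # [1, 2] (a list is one item)
```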

Dispatching also works with the abstract base classes of the numbers module, like numbers.Integral:

@Transform
def f(x:numbers.Integral): return x+1

t = f((1,'1',1))
test_eq(t, (2, '1', 2))



 InplaceTransform (enc=None, dec=None, split_idx=None, order=None)

A Transform that modifies in-place and just returns whatever it’s passed

class A(InplaceTransform): pass

@A
def encodes(self, x:pd.Series): x.fillna(10, inplace=True)
f = A()

test_eq_type(f(pd.Series([1,2,None])),pd.Series([1,2,10],dtype=np.float64)) #fillna fills with floats.
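The in-place pattern boils down to running a function for its side effect and handing back the original input. The sketch below is hypothetical (inplace and fill_none are illustrative names, not fastcore APIs) and uses a plain list instead of a pandas Series so it runs on its own.

```python
def inplace(f):
    "Hypothetical sketch: run f for its side effect and return the original input."
    def _inner(x):
        f(x)      # f mutates x and may return None
        return x  # hand back the (now modified) input itself
    return _inner

def fill_none(xs):
    # Replace None entries in place; returns None, like pandas methods called with inplace=True
    for i, v in enumerate(xs):
        if v is None: xs[i] = 10

data = [1, 2, None]
out = inplace(fill_none)(data)
print(out, out is data)  # [1, 2, 10] True
```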



 DisplayedTransform (enc=None, dec=None, split_idx=None, order=None)

A transform with a __repr__ that shows its attrs

Transforms normally are represented by just their class name and a list of encodes and decodes implementations:

class A(Transform): encodes,decodes = noop,noop
f = A()
f
A:
encodes: (object,object) -> noop
decodes: (object,object) -> noop

A DisplayedTransform will in addition show the contents of all attributes listed in the comma-delimited string self.store_attrs:

class A(DisplayedTransform):
    encodes = noop
    def __init__(self, a, b=2):
        super().__init__()
        store_attr()

A(a=1, b=2)
A -- {'a': 1, 'b': 2}:
encodes: (object,object) -> noop
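A minimal, hypothetical sketch of such a __repr__, reading a comma-delimited store_attrs string as described above. ReprFromAttrs and Scaled are illustrative names; fastcore's DisplayedTransform builds its representation differently in detail.

```python
class ReprFromAttrs:
    "Hypothetical sketch: show the attributes named in store_attrs in the repr."
    store_attrs = ''  # comma-delimited attribute names, set by subclasses

    def __repr__(self):
        names = [n.strip() for n in self.store_attrs.split(',') if n.strip()]
        shown = {n: getattr(self, n) for n in names}
        return f"{type(self).__name__} -- {shown}"

class Scaled(ReprFromAttrs):
    store_attrs = 'a,b'
    def __init__(self, a, b=2): self.a, self.b = a, b

print(Scaled(1))  # Scaled -- {'a': 1, 'b': 2}
```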



 ItemTransform (enc=None, dec=None, split_idx=None, order=None)

A transform that always takes tuples as items

ItemTransform is the class to use to opt out of Transform’s default element-wise handling of tuples.

class AIT(ItemTransform): 
    def encodes(self, xy): x,y=xy; return (x+y,y)
    def decodes(self, xy): x,y=xy; return (x-y,y)
f = AIT()
test_eq(f((1,2)), (3,2))
test_eq(f.decode((3,2)), (1,2))

If you pass a special tuple subclass, the usual retain type behavior of Transform will keep it:

class _T(tuple): pass
x = _T((1,2))
test_eq_type(f(x), _T((3,2)))



 get_func (t, name, *args, **kwargs)

Get the t.name (potentially partial-ized with args and kwargs) or noop if not defined

This works for any kind of t supporting getattr, so a class or a module.

test_eq(get_func(operator, 'neg', 2)(), -2)
test_eq(get_func(operator.neg, '__call__')(2), -2)
test_eq(get_func(list, 'foobar')([2]), [2])
a = [2,1]
get_func(list, 'sort')(a)
test_eq(a, [1,2])
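A rough plain-Python equivalent of the lookup-then-partial behavior. get_func_sketch is hypothetical, and noop is redefined locally as a stand-in for fastcore's noop so the block runs on its own.

```python
import math
import operator
from functools import partial

def noop(x=None, *args, **kwargs):
    "Local stand-in for fastcore's noop: return the first argument unchanged."
    return x

def get_func_sketch(t, name, *args, **kwargs):
    "Hypothetical sketch: look up t.name, fall back to noop, partially apply any args."
    f = getattr(t, name, noop)
    return partial(f, *args, **kwargs) if (args or kwargs) else f

print(get_func_sketch(operator, 'neg', 2)())       # -2
print(get_func_sketch(list, 'foobar')([2]))        # [2]
print(get_func_sketch(math, 'sqrt') is math.sqrt)  # True
```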

Transforms are built with multiple dispatch: a given function can have several methods depending on the type of the object received. This is done directly with TypeDispatch and type annotations in Transform, but you can also use the following class.



 Func (name, *args, **kwargs)

Basic wrapper around a name with args and kwargs to call on a given type

You can call the Func object on any module name or type, even a list of types. It will return the corresponding function (with a default to noop if nothing is found) or list of functions.

test_eq(Func('sqrt')(math), math.sqrt)


 Sig (*args, **kwargs)

Sig is just syntactic sugar to create a Func object more easily, with the syntax Sig.name(*args, **kwargs).

f = Sig.sqrt()
test_eq(f(math), math.sqrt)



 compose_tfms (x, tfms, is_enc=True, reverse=False, **kwargs)

Apply all func_nm attribute of tfms on x, maybe in reverse order

def to_int  (x):   return Int(x)
def to_float(x):   return Float(x)
def double  (x):   return x*2
def half(x)->None: return x/2
def test_compose(a, b, *fs): test_eq_type(compose_tfms(a, tfms=map(Transform,fs)), b)

test_compose(1,   Int(1),   to_int)
test_compose(1,   Float(1), to_int,to_float)
test_compose(1,   Float(2), to_int,to_float,double)
test_compose(2.0, 2.0,      to_int,double,half)
class A(Transform):
    def encodes(self, x:float):  return Float(x+1)
    def decodes(self, x): return x-1
tfms = [A(), Transform(math.sqrt)]
t = compose_tfms(3., tfms=tfms)
test_eq_type(t, Float(2.))
test_eq(compose_tfms(t, tfms=tfms, is_enc=False), 1.)
test_eq(compose_tfms(4., tfms=tfms, reverse=True), 3.)
tfms = [A(), Transform(math.sqrt)]
test_eq(compose_tfms((9,3.), tfms=tfms), (3,2.))
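The core loop of compose_tfms can be sketched as follows. compose_sketch, AddOne and Double are hypothetical; the sketch calls encode/decode methods on simple objects and leaves out the extra keyword arguments in the compose_tfms signature. It shows why undoing a composition usually needs reverse=True together with is_enc=False: the last transform applied must be decoded first.

```python
def compose_sketch(x, tfms, is_enc=True, reverse=False):
    "Hypothetical sketch: apply each transform's encode (or decode) in turn, maybe reversed."
    tfms = list(tfms)
    if reverse: tfms = tfms[::-1]
    for t in tfms:
        x = t.encode(x) if is_enc else t.decode(x)
    return x

class AddOne:
    def encode(self, x): return x + 1
    def decode(self, x): return x - 1

class Double:
    def encode(self, x): return x * 2
    def decode(self, x): return x / 2

tfms = [AddOne(), Double()]
enc = compose_sketch(3, tfms)                                # (3+1)*2
dec = compose_sketch(enc, tfms, is_enc=False, reverse=True)  # 8/2 - 1
print(enc, dec)  # 8 3.0
```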



 mk_transform (f)

Convert function f to Transform if it isn’t already one



 gather_attrs (o, k, nm)

Used in __getattr__ to collect all attrs k from self.{nm}



 gather_attr_names (o, nm)

Used in __dir__ to collect all attr names from self.{nm}



 Pipeline (funcs=None, split_idx=None)

A pipeline of composed (for encode/decode) transforms, setup with types

Pipeline methods:

  • __call__ - Compose __call__ of all fs on o
  • decode - Compose decode of all fs on o
  • show - Show o, a single item from a tuple, decoding as needed
  • add - Add transforms ts
  • setup - Call each tfm’s setup in order

Pipeline is a wrapper for compose_tfms. You can pass instances of Transform or regular functions in funcs; the Pipeline will wrap them all in Transform (and instantiate them if needed) during initialization. It handles the transform setup by adding transforms one at a time and calling setup on each, applies them in order in __call__ (and in reverse order in decode), and can show an object by decoding through the transforms until it reaches an object that knows how to show itself.
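The show behavior can be sketched as a loop that decodes one transform at a time, starting from the last one, and stops as soon as the value has a show method. Shown, ToInt, Neg and show_sketch are hypothetical names, and the sketch handles a single item only, not tuples.

```python
class Shown(float):
    "Hypothetical type that knows how to show itself by printing."
    def show(self): print(float(self))

class ToInt:
    # Encoding drops the showable type; decoding restores it
    def encode(self, x): return int(x)
    def decode(self, x): return Shown(x)

class Neg:
    def encode(self, x): return -x
    def decode(self, x): return -x

def show_sketch(x, tfms):
    "Hypothetical sketch: decode last-to-first until x can show itself, then show it."
    for t in reversed(tfms):
        if hasattr(x, 'show'): break
        x = t.decode(x)
    if hasattr(x, 'show'): x.show()
    return x

tfms = [Neg(), ToInt()]
encoded = ToInt().encode(Neg().encode(2.0))  # -2, a plain int with no show method
show_sketch(encoded, tfms)                   # prints -2.0
```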

# Empty pipeline is noop
pipe = Pipeline()
test_eq(pipe(1), 1)
test_eq(pipe((1,)), (1,))
# Check pickle works
assert pickle.loads(pickle.dumps(pipe))
class IntFloatTfm(Transform):
    def encodes(self, x):  return Int(x)
    def decodes(self, x):  return Float(x)
    foo = 1  # class attribute, accessible below as pipe.foo

int_tfm = IntFloatTfm()

def neg(x): return -x
neg_tfm = Transform(neg, neg)
pipe = Pipeline([neg_tfm, int_tfm])

start = 2.0
t = pipe(start)
test_eq_type(t, Int(-2))
test_eq_type(pipe.decode(t), Float(start))
test_stdout(lambda: pipe.show(t), '-2')
pipe = Pipeline([neg_tfm, int_tfm])
t = pipe(start)
test_stdout(lambda: pipe.show(pipe((1.,2.))), '-1\n-2')
test_eq(pipe.foo, 1)
assert 'foo' in dir(pipe)
assert 'int_float_tfm' in dir(pipe)

You can add a single transform or multiple transforms ts using Pipeline.add. Transforms will be ordered by Transform.order.

pipe = Pipeline([neg_tfm, int_tfm])
class SqrtTfm(Transform):
    def encodes(self, x): 
        return x**(.5)
    def decodes(self, x): return x**2
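The ordering rule for add can be sketched with a stable sort on an order attribute. OrderedSketch, Neg and Sqrt are hypothetical; the default order of 0 used for transforms without the attribute is an assumption of this sketch.

```python
class OrderedSketch:
    "Hypothetical sketch of a pipeline that keeps its transforms sorted by `order`."
    def __init__(self, fs=()):
        self.fs = []
        self.add(list(fs))

    def add(self, ts):
        # Accept a single transform or a list of them
        if not isinstance(ts, (list, tuple)): ts = [ts]
        # sorted() is stable, so transforms with equal order keep their insertion order
        self.fs = sorted(self.fs + list(ts), key=lambda f: getattr(f, 'order', 0))

    def __call__(self, x):
        for f in self.fs: x = f(x)
        return x

class Neg:
    order = 0
    def __call__(self, x): return -x

class Sqrt:
    order = -1  # runs before transforms with order 0
    def __call__(self, x): return x ** 0.5

p = OrderedSketch([Neg()])
p.add(Sqrt())
print(p(4))  # sqrt first, then negate: -2.0
```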

Transforms are available as attributes named with the snake_case version of the names of their types. Attributes in transforms can be directly accessed as attributes of the pipeline.

test_eq(pipe.int_float_tfm, int_tfm)
test_eq(pipe.foo, 1)

pipe = Pipeline([int_tfm, int_tfm])
test_eq(pipe.int_float_tfm[0], int_tfm)
test_eq(pipe.foo, [1,1])
# Check opposite order
pipe = Pipeline([int_tfm,neg_tfm])
t = pipe(start)
test_eq(t, -2)
test_stdout(lambda: pipe.show(t), '-2')
class A(Transform):
    def encodes(self, x):  return int(x)
    def decodes(self, x):  return Float(x)

pipe = Pipeline([neg_tfm, A])
t = pipe(start)
test_eq_type(t, -2)
test_eq_type(pipe.decode(t), Float(start))
test_stdout(lambda: pipe.show(t), '-2.0')
s2 = (1,2)
pipe = Pipeline([neg_tfm, A])
t = pipe(s2)
test_eq_type(t, (-1,-2))
test_eq_type(pipe.decode(t), (Float(1.),Float(2.)))
test_stdout(lambda: pipe.show(t), '-1.0\n-2.0')
from PIL import Image
import matplotlib.pyplot as plt
from numpy import array, ndarray

class ArrayImage(ndarray):
    _show_args = {'cmap':'viridis'}
    def __new__(cls, x, *args, **kwargs):
        if isinstance(x,tuple): return super().__new__(cls, x, *args, **kwargs)
        if args or kwargs: raise RuntimeError('Unknown array init args')
        if not isinstance(x,ndarray): x = array(x)
        return x.view(cls)
    def show(self, ctx=None, figsize=None, **kwargs):
        if ctx is None: _,ctx = plt.subplots(figsize=figsize)
        ctx.imshow(self, **{**self._show_args, **kwargs})  # show this array, not a global image
        return ctx

im = Image.open(TEST_IMAGE)
im_t = ArrayImage(im)
def f1(x:ArrayImage): return -x
def f2(x): return Image.open(x).resize((128,128))
def f3(x:Image.Image): return(ArrayImage(array(x)))
pipe = Pipeline([f2,f3,f1])
t = pipe(TEST_IMAGE)
test_eq(type(t), ArrayImage)
test_eq(t, -array(f3(f2(TEST_IMAGE))))
pipe = Pipeline([f2,f3])
t = pipe(TEST_IMAGE)
ax = pipe.show(t)

#Check filtering is properly applied
add1 = B()
add1.split_idx = 1
pipe = Pipeline([neg_tfm, A(), add1])
test_eq(pipe(start), -2)
pipe.split_idx = 1
test_eq(pipe(start), -1)
pipe.split_idx = 0
test_eq(pipe(start), -2)
for t in [None, 0, 1]:
    pipe.split_idx = t
    test_eq(pipe.decode(pipe(start)), start)
    test_stdout(lambda: pipe.show(pipe(start)), "-2.0")
def neg(x): return -x
test_eq(type(mk_transform(neg)), Transform)
test_eq(type(mk_transform(math.sqrt)), Transform)
test_eq(type(mk_transform(lambda a:a*2)), Transform)
test_eq(type(mk_transform(Pipeline([neg]))), Pipeline)





 Pipeline.__call__ (o)

Call self as a function.



 Pipeline.decode (o, full=True)



 Pipeline.setup (items=None, train_setup=False)

During the setup, the Pipeline starts with no transform and adds them one at a time, so that during its setup, each transform gets the items processed up to its point and not after.