Type dispatch

Basic single and dual parameter dispatch
from nbdev.showdoc import *
from fastcore.test import *
from fastcore.nb_imports import *




 lenient_issubclass (cls, types)

If possible return whether cls is a subclass of types, otherwise return False.

assert not lenient_issubclass(typing.Collection, list)
assert lenient_issubclass(list, typing.Collection)
assert lenient_issubclass(typing.Collection, object)
assert lenient_issubclass(typing.List, typing.Collection)
assert not lenient_issubclass(typing.Collection, typing.List)
assert not lenient_issubclass(object, typing.Callable)
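To make the "otherwise return False" behavior concrete, here is a minimal sketch built on the standard issubclass. The name lenient_issubclass_sketch is hypothetical and this is not fastcore's implementation. It simply catches the TypeError that issubclass raises when its first argument is not a class, so it may not agree with every typing case above.

```python
import collections.abc

def lenient_issubclass_sketch(cls, types):
    """Hypothetical sketch: behave like issubclass, but return False instead of raising."""
    try:
        return issubclass(cls, types)
    except TypeError:
        # issubclass raises TypeError when `cls` is not a class (e.g. an instance like 3)
        return False

print(lenient_issubclass_sketch(bool, int))                         # True
print(lenient_issubclass_sketch(list, collections.abc.Collection))  # True
print(lenient_issubclass_sketch(3, int))                            # False, no TypeError
```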



 sorted_topologically (iterable, cmp=<built-in function lt>,

Return a new list containing all items from the iterable sorted topologically

td = [3, 1, 2, 5]
test_eq(sorted_topologically(td), [1, 2, 3, 5])
test_eq(sorted_topologically(td, reverse=True), [5, 3, 2, 1])
td = {int:1, numbers.Number:2, numbers.Integral:3}
test_eq(sorted_topologically(td, cmp=lenient_issubclass), [int, numbers.Integral, numbers.Number])
td = [numbers.Integral, tuple, list, int, dict]
td = sorted_topologically(td, cmp=lenient_issubclass)
assert td.index(int) < td.index(numbers.Integral)
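One way to obtain such an ordering is to repeatedly emit every item that no other remaining item must precede according to cmp. The sketch below, with the hypothetical name sorted_topologically_sketch, follows that approach and reproduces the results above for these inputs. It is an illustration under that assumption, not fastcore's code, and each pass is quadratic in the number of remaining items.

```python
import numbers
import operator

def sorted_topologically_sketch(iterable, cmp=operator.lt, reverse=False):
    """Hypothetical sketch: cmp(a, b) is True when a must come before b."""
    remaining = list(iterable)
    result = []
    while remaining:
        # An item is "ready" when no other remaining item must come before it
        ready = [x for i, x in enumerate(remaining)
                 if not any(cmp(y, x) for j, y in enumerate(remaining) if j != i)]
        if not ready:
            raise ValueError("cmp does not define a partial order (cycle found)")
        result.extend(ready)
        remaining = [x for x in remaining if x not in ready]
    return result[::-1] if reverse else result

print(sorted_topologically_sketch([3, 1, 2, 5]))                # [1, 2, 3, 5]
print(sorted_topologically_sketch([3, 1, 2, 5], reverse=True))  # [5, 3, 2, 1]
# With issubclass as cmp, subclasses are placed before their superclasses
order = sorted_topologically_sketch([int, numbers.Number, numbers.Integral], cmp=issubclass)
```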


Type dispatch, or multiple dispatch, allows you to change the way a function behaves based upon the input types it receives. This is a prominent feature in some programming languages like Julia. For example, here is a conceptual example of how multiple dispatch works in Julia, returning different values depending on the input types of x and y:

collide_with(x::Asteroid, y::Asteroid) = ... 
# deal with asteroid hitting asteroid

collide_with(x::Asteroid, y::Spaceship) = ... 
# deal with asteroid hitting spaceship

collide_with(x::Spaceship, y::Asteroid) = ... 
# deal with spaceship hitting asteroid

collide_with(x::Spaceship, y::Spaceship) = ... 
# deal with spaceship hitting spaceship

Type dispatch can be especially useful in data science, where you might allow different input types (e.g. NumPy arrays and pandas DataFrames) to a function that processes data. Type dispatch allows you to have a common API for functions that do similar tasks.

The TypeDispatch class allows us to achieve type dispatch in Python. It contains a dictionary that maps types from type annotations to functions, so that the appropriate function is called for the types of the inputs it receives.



 TypeDispatch (funcs=(), bases=())

Dictionary-like object; __getitem__ matches keys of types using issubclass
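TypeDispatch keys its entries on the annotations of the first two parameters, as the (type,object) output further down shows. A minimal sketch of extracting such a key with inspect.signature follows. dispatch_key and the sample functions are hypothetical, and the sketch ignores details a real implementation would need, such as skipping self in methods or resolving string annotations.

```python
import inspect
import numbers

def dispatch_key(f):
    """Hypothetical helper: the annotations of f's first two parameters, defaulting to object."""
    params = list(inspect.signature(f).parameters.values())[:2]
    anns = [object if p.annotation is inspect.Parameter.empty else p.annotation for p in params]
    # Pad single-argument functions so every key is a pair, e.g. (int, object)
    return tuple(anns + [object] * (2 - len(anns)))

def g_pair(x: int, y: float): return x + y
def g_one(x: numbers.Integral) -> int: return x + 1   # return annotation is not part of the key
def g_any(x, y): return x

print(dispatch_key(g_pair))  # (<class 'int'>, <class 'float'>)
print(dispatch_key(g_one))   # (<class 'numbers.Integral'>, <class 'object'>)
```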

To demonstrate how TypeDispatch works, we define a set of functions that accept a variety of input types, specified with different type annotations:

def f2(x:int, y:float): return x+y              #int and float for 2nd arg
def f_nin(x:numbers.Integral)->int:  return x+1 #integral numeric
def f_ni2(x:int): return x                      #integer
def f_bll(x:bool|list): return x              #bool or list
def f_num(x:numbers.Number): return x           #Number (root of numerics)

We can optionally initialize TypeDispatch with a list of functions we want to search. Printing an instance of TypeDispatch displays a convenient mapping of types -> functions:

t = TypeDispatch([f_nin,f_ni2,f_num,f_bll,None])
t
(bool,object) -> f_bll
(int,object) -> f_ni2
(Integral,object) -> f_nin
(Number,object) -> f_num
(list,object) -> f_bll
(object,object) -> NoneType

Note that only the first two arguments are used for TypeDispatch. If your function only contains one argument, the second parameter will be shown as object. If you pass None into TypeDispatch, then this will be displayed as (object, object) -> NoneType.

TypeDispatch is a dictionary-like object, which means that you can retrieve a function by the associated type annotation. For example, the statement:

t[float]

will return f_num, because f_num is the matching function whose type annotation, numbers.Number, is a superclass of float:

assert issubclass(float, numbers.Number)
test_eq(t[float], f_num)

The same is true for other types as well:

test_eq(t[np.int32], f_nin)
test_eq(t[bool], f_bll)
test_eq(t[list], f_bll)

If you try to get a type that doesn’t match, TypeDispatch will return None:

test_eq(t[str], None)
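The lookups above only need issubclass plus a rule for trying more specific keys first. The hypothetical SubclassDict below sketches that idea: it ranks each key by how many registered keys it is a subclass of, then returns the first match, or None. This is an assumption-based illustration, not how TypeDispatch is implemented internally.

```python
import numbers

class SubclassDict:
    """Hypothetical sketch: a dict keyed by types whose lookups use issubclass."""
    def __init__(self):
        self.d = {}

    def __setitem__(self, typ, value):
        self.d[typ] = value

    def __getitem__(self, typ):
        # A key that is a subclass of more registered keys is more specific, so try it first
        keys = sorted(self.d, key=lambda k: sum(issubclass(k, o) for o in self.d), reverse=True)
        for k in keys:
            if issubclass(typ, k):
                return self.d[k]
        return None  # no registered superclass, like t[str] above

lookup = SubclassDict()
lookup[numbers.Number] = 'f_num'
lookup[numbers.Integral] = 'f_nin'
lookup[int] = 'f_ni2'
lookup[bool] = 'f_bll'

print(lookup[bool])   # f_bll
print(lookup[float])  # f_num
print(lookup[str])    # None
```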



 TypeDispatch.add (f)

Add type t and function f

This method allows you to add an additional function to an existing TypeDispatch instance:

def f_col(x:typing.Collection): return x
t.add(f_col)
test_eq(t[str], f_col)
t
(bool,object) -> f_bll
(int,object) -> f_ni2
(Integral,object) -> f_nin
(Number,object) -> f_num
(list,object) -> f_bll
(typing.Collection,object) -> f_col
(object,object) -> NoneType

If you accidentally add the same function more than once, things will still work as expected:

t.add(f_ni2)
test_eq(t[int], f_ni2)

However, if you add a function that has a type collision that raises an ambiguity, this will automatically resolve to the latest function added:

def f_ni3(z:int): return z # collides with f_ni2 with same type annotations
t.add(f_ni3)
test_eq(t[int], f_ni3)
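Both behaviors above follow naturally if the functions are stored in a mapping keyed by their annotation pair, because re-adding a function, or adding a new one with an identical key, just overwrites the same entry. The sketch below assumes exactly that; registry, add and dispatch_key are hypothetical names, not fastcore's.

```python
import inspect

def dispatch_key(f):
    """Hypothetical helper: annotations of the first two parameters, defaulting to object."""
    params = list(inspect.signature(f).parameters.values())[:2]
    anns = [object if p.annotation is inspect.Parameter.empty else p.annotation for p in params]
    return tuple(anns + [object] * (2 - len(anns)))

registry = {}  # hypothetical: (type, type) -> function

def add(f):
    """Register f; a later function with the same key replaces the earlier one."""
    registry[dispatch_key(f)] = f

def h_int(x: int): return x
def h_int_again(z: int): return z

add(h_int)
add(h_int)          # adding the same function twice rewrites the same entry
print(len(registry))                          # 1
add(h_int_again)    # same key (int, object), so the latest function wins
print(registry[(int, object)] is h_int_again)  # True
```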

Using bases:

The argument bases can optionally accept a single TypeDispatch instance or a collection (e.g. a tuple or list) of TypeDispatch objects. This provides functionality similar to multiple inheritance.

These bases are searched for matching functions if there is no match among the instance's own functions:

def f_str(x:str): return x+'1'

t = TypeDispatch([f_nin,f_ni2,f_num,f_bll,None])
t2 = TypeDispatch(f_str, bases=t) # you can optionally supply a list of TypeDispatch objects for `bases`.
t2
(str,object) -> f_str
(bool,object) -> f_bll
(int,object) -> f_ni2
(Integral,object) -> f_nin
(Number,object) -> f_num
(list,object) -> f_bll
(object,object) -> NoneType
test_eq(t2[int], f_ni2)       # searches `t` b/c not found in `t2`
test_eq(t2[np.int32], f_nin)  # searches `t` b/c not found in `t2`
test_eq(t2[float], f_num)     # searches `t` b/c not found in `t2`
test_eq(t2[bool], f_bll)      # searches `t` b/c not found in `t2`
test_eq(t2[str], f_str)       # found in `t2`!
test_eq(t2('a'), 'a1')        # found in `t2`!, and uses __call__

o = np.int32(1)
test_eq(t2(o), 2)             # searches `t` b/c not found in `t2`, and uses __call__
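The search order for bases (own functions first, then each base in turn) can be sketched as follows. ChainedTable is a hypothetical class, and for brevity it matches by walking the class MRO, which, unlike TypeDispatch, ignores virtual subclasses registered with ABCs such as numbers.Integral.

```python
class ChainedTable:
    """Hypothetical sketch: look up a type in our own table, then in each base table."""
    def __init__(self, table, bases=()):
        self.table = dict(table)  # {type: function}
        # Accept a single base or a tuple/list of bases
        self.bases = tuple(bases) if isinstance(bases, (tuple, list)) else (bases,)

    def lookup(self, typ):
        # Walk the MRO so the most specific registered class wins
        for cls in typ.__mro__:
            if cls in self.table:
                return self.table[cls]
        # Nothing in our own table: fall back to the bases, in order
        for base in self.bases:
            f = base.lookup(typ)
            if f is not None:
                return f
        return None

base = ChainedTable({int: 'f_int', list: 'f_list'})
child = ChainedTable({str: 'f_str'}, bases=base)
print(child.lookup(str))    # f_str  (found in child)
print(child.lookup(bool))   # f_int  (bool subclasses int, found in base)
print(child.lookup(float))  # None
```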

Up To Two Arguments

TypeDispatch supports up to two arguments when searching for the appropriate function. The following functions f1 and f2 both have two parameters:

def f1(x:numbers.Integral, y): return x+1  #Integral is a numeric type
def f2(x:int, y:float): return x+y
t = TypeDispatch([f1,f2])
t
(int,float) -> f2
(Integral,object) -> f1

You can look up functions from a TypeDispatch instance with two parameters like this:

test_eq(t[np.int32], f1)
test_eq(t[int,float], f2)

Keep in mind that anything beyond the first two parameters is ignored, and any collisions are resolved in favor of the most recently added function. In the example below, f1 is ignored in favor of f2 because their first two parameters have identical type hints:

def f1(a:str, b:int, c:list): return a
def f2(a: str, b:int): return b
t = TypeDispatch([f1,f2])
test_eq(t[str, int], f2)
t
(str,int) -> f2


TypeDispatch matches types with functions according to whether the supplied class is the same class as, or a subclass of, the type annotation(s) of the associated functions.

Let’s consider an example where we try to retrieve the function corresponding to types of [np.int32, float].

In this scenario, f2 will not be matched. This is because the first type annotation of f2, int, is not a superclass (or the same class) of np.int32:

def f1(x:numbers.Integral, y): return x+1
def f2(x:int, y:float): return x+y
t = TypeDispatch([f1,f2])

assert not issubclass(np.int32, int)

Instead, f1 is a valid match, as its first argument is annotated with the type numbers.Integral, which np.int32 is a subclass of:

assert issubclass(np.int32, numbers.Integral)
test_eq(t[np.int32,float], f1)

In f1, the second parameter y is not annotated, which means TypeDispatch will match any call whose first argument is a subclass of numbers.Integral, as long as no more specific function (such as f2) matches:

assert issubclass(int, numbers.Integral) # int is a subclass of numbers.Integral
test_eq(t[int], f1)
test_eq(t[int,int], f1)

If no match is possible, None is returned:

test_eq(t[float,float], None)
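The matching rule for two parameters can be sketched as: a function matches when each argument type is a subclass of the corresponding annotation, and the most specific matching pair wins. lookup2 and most_specific_first are hypothetical helpers that rank pairs with a simple subclass-count heuristic; they illustrate the rule rather than reproduce TypeDispatch.

```python
import numbers

def most_specific_first(keys):
    """Order (type, type) keys so that more specific pairs are tried first."""
    def rank(k):
        # Count how many registered keys each position is a subclass of
        return (sum(issubclass(k[0], o[0]) for o in keys),
                sum(issubclass(k[1], o[1]) for o in keys))
    return sorted(keys, key=rank, reverse=True)

def lookup2(table, t1, t2=object):
    """Hypothetical sketch: return the function for the first matching (t1, t2) pair, else None."""
    for k1, k2 in most_specific_first(list(table)):
        if issubclass(t1, k1) and issubclass(t2, k2):
            return table[(k1, k2)]
    return None

def f1(x: numbers.Integral, y): return x + 1
def f2(x: int, y: float): return x + y
table = {(numbers.Integral, object): f1, (int, float): f2}

print(lookup2(table, int, float) is f2)  # True
print(lookup2(table, int, int) is f1)    # True: int is not a float, so f2 is skipped
print(lookup2(table, float, float))      # None
```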



 TypeDispatch.__call__ (*args, **kwargs)

Call self as a function.

TypeDispatch is also callable. When you call an instance of TypeDispatch, it will execute the relevant function:

def f_arr(x:np.ndarray): return x.sum()
def f_int(x:np.int32): return x+1
t = TypeDispatch([f_arr, f_int])

arr = np.array([5,4,3,2,1])
test_eq(t(arr), 15) # dispatches to f_arr

o = np.int32(1)
test_eq(t(o), 2) # dispatches to f_int
assert t.first() is not None

You can also call an instance of TypeDispatch when there are two parameters:

def f1(x:numbers.Integral, y): return x+1
def f2(x:int, y:float): return x+y
t = TypeDispatch([f1,f2])

test_eq(t(3,2.0), 5)
test_eq(t(3,2), 4)

When no match is found, a TypeDispatch instance becomes an identity function. This default behavior is leveraged by fastai for data transformations to provide a sensible default when a matching function cannot be found.

test_eq(t('a'), 'a')
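A callable wrapper with this identity fallback takes only a few lines. The sketch below, CallableDispatch, is hypothetical: it assumes its table is already ordered from most to least specific, dispatches on the types of the first two positional arguments, and returns the first argument unchanged when nothing matches.

```python
import numbers

class CallableDispatch:
    """Hypothetical sketch: dispatch on the types of the first two arguments, else return args[0]."""
    def __init__(self, table):
        # table: {(type, type): function}, assumed ordered most specific first
        self.table = table

    def __call__(self, *args):
        t1 = type(args[0])
        t2 = type(args[1]) if len(args) > 1 else object
        for (k1, k2), f in self.table.items():
            if issubclass(t1, k1) and issubclass(t2, k2):
                return f(*args)
        return args[0]  # no match: behave like the identity function

disp = CallableDispatch({
    (int, float): lambda x, y: x + y,
    (numbers.Integral, object): lambda x, y: x + 1,
})
print(disp(3, 2.0))  # 5.0
print(disp(3, 2))    # 4
print(disp('a'))     # a
```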



 TypeDispatch.returns (x)

Get the return type of annotation of x.

You can optionally pass an object to TypeDispatch.returns and get the return type annotation back:

def f1(x:int) -> np.ndarray: return np.array(x)
def f2(x:str) -> float: return List
def f3(x:float): return List # f3 has no return type annotation

t = TypeDispatch([f1, f2, f3])

test_eq(t.returns(1), np.ndarray)  # dispatched to f1
test_eq(t.returns('Hello'), float) # dispatched to f2
test_eq(t.returns(1.0), None)      # dispatched to f3

class _Test: pass
_test = _Test()
test_eq(t.returns(_test), None) # type `_Test` not found, so None returned
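Conceptually, returns combines a dispatch lookup with reading the matched function's return annotation. A minimal sketch, assuming lookup by walking type(x).__mro__ (so ABC registrations are not honored) and using the hypothetical helpers return_annotation and returns_sketch:

```python
import inspect

def return_annotation(f):
    """Return f's return annotation, or None when it has none."""
    ann = inspect.signature(f).return_annotation
    return None if ann is inspect.Signature.empty else ann

def returns_sketch(table, x):
    """Hypothetical sketch: find the function for type(x), then read its return annotation."""
    for cls in type(x).__mro__:
        if cls in table:
            return return_annotation(table[cls])
    return None  # no function registered for this type

def g1(x: int) -> list: return [x]
def g2(x: str) -> float: return 1.0
def g3(x: float): return x   # no return annotation
table = {int: g1, str: g2, float: g3}

print(returns_sketch(table, 1))         # <class 'list'>
print(returns_sketch(table, 1.0))       # None
print(returns_sketch(table, object()))  # None
```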

Using TypeDispatch With Methods

You can use TypeDispatch when defining methods as well:

def m_nin(self, x:str|numbers.Integral): return str(x)+'1'
def m_bll(self, x:bool): self.foo='a'
def m_num(self, x:numbers.Number): return x*2

t = TypeDispatch([m_nin,m_num,m_bll])
class A: f = t # set class attribute `f` equal to a TypeDispatch instance
a = A()
test_eq(a.f(1), '11')  #dispatch to m_nin
test_eq(a.f(1.), 2.)   #dispatch to m_num
test_is(a.f.inst, a)

a.f(False) # this dispatches to m_bll, which sets self.foo to 'a'
test_eq(a.foo, 'a')

As discussed in TypeDispatch.__call__, when there is not a match, TypeDispatch.__call__ becomes an identity function. In the example below, a tuple does not match any type annotation, so the tuple is returned unchanged:

test_eq(a.f(()), ())

We extend the previous example by using bases to add an additional method that supports tuples:

def m_tup(self, x:tuple): return x+(1,)
t2 = TypeDispatch(m_tup, bases=t)

class A2: f = t2
a2 = A2()
test_eq(a2.f(1), '11')
test_eq(a2.f(1.), 2.)
test_is(a2.f.inst, a2)
a2.f(False) # dispatches to m_bll (found via the base `t`), which sets a2.foo
test_eq(a2.foo, 'a')
test_eq(a2.f(()), (1,))
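Checks such as a.f.inst work because a dispatcher stored on a class can act as a descriptor, remembering the object it was accessed through. The hypothetical MethodDispatch below sketches that mechanism with __get__. It dispatches with isinstance on a single argument and only illustrates the descriptor protocol; it is not fastcore's implementation.

```python
class MethodDispatch:
    """Hypothetical sketch: a type dispatcher that binds to the instance it is accessed from."""
    def __init__(self, table, inst=None, owner=None):
        self.table = table  # {type: function(self, x)}, ordered most specific first
        self.inst, self.owner = inst, owner

    def __get__(self, inst, owner):
        # Descriptor protocol: `a.f` calls this with inst=a and owner=type(a)
        return MethodDispatch(self.table, inst, owner)

    def __call__(self, x):
        for typ, f in self.table.items():
            if isinstance(x, typ):
                return f(self.inst, x)
        return x  # identity when nothing matches

class SketchA:
    f = MethodDispatch({
        bool:  lambda self, x: setattr(self, 'foo', 'a'),  # bool first: False is also an int
        int:   lambda self, x: str(x) + '1',
        float: lambda self, x: x * 2,
    })

a_s = SketchA()
print(a_s.f(1))           # 11
print(a_s.f(1.))          # 2.0
print(a_s.f.inst is a_s)  # True
a_s.f(False)
print(a_s.foo)            # a
```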

Using TypeDispatch With Class Methods

You can use TypeDispatch when defining class methods too:

def m_nin(cls, x:str|numbers.Integral): return str(x)+'1'
def m_bll(cls, x:bool): cls.foo='a'
def m_num(cls, x:numbers.Number): return x*2

t = TypeDispatch([m_nin,m_num,m_bll])
class A: f = t # set class attribute `f` equal to a TypeDispatch

test_eq(A.f(1), '11')  #dispatch to m_nin
test_eq(A.f(1.), 2.)   #dispatch to m_num
test_is(A.f.owner, A)

A.f(False) # this dispatches to m_bll, which sets A.foo to 'a'
test_eq(A.foo, 'a')
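The A.f.owner check relies on standard Python behavior: when a descriptor is accessed on the class itself, __get__ receives inst=None and owner=A, so a dispatcher can pass the class instead of an instance. The tiny descriptor below (ShowGet, a hypothetical name) only demonstrates what Python passes to __get__.

```python
class ShowGet:
    """Tiny descriptor that reports what Python passes to __get__."""
    def __get__(self, inst, owner):
        return inst, owner

class Demo:
    g = ShowGet()

print(Demo.g == (None, Demo))  # True: accessed on the class, so inst is None
d = Demo()
print(d.g == (d, Demo))        # True: accessed on an instance
```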

typedispatch Decorator



 DispatchReg ()

A global registry for TypeDispatch objects keyed by function name

@typedispatch
def f_td_test(x, y): return f'{x}{y}'
@typedispatch
def f_td_test(x:numbers.Integral|int, y): return x+1
@typedispatch
def f_td_test(x:int, y:float): return x+y
@typedispatch
def f_td_test(x:int, y:int): return x*y

test_eq(f_td_test(3,2.0), 5)
assert issubclass(int, numbers.Integral)
test_eq(f_td_test(3,2), 6)

test_eq(f_td_test('a','b'), 'ab')
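A registry keyed by function name is enough to turn several same-named definitions into one dispatcher. The sketch below is hypothetical (NameRegistry, register and dispatch_key are not fastcore names), supports plain class annotations only (no unions such as numbers.Integral|int), and reproduces the results above for these four definitions.

```python
import inspect
import numbers
from collections import defaultdict

def dispatch_key(f):
    """Hypothetical helper: annotations of the first two parameters, defaulting to object."""
    params = list(inspect.signature(f).parameters.values())[:2]
    anns = [object if p.annotation is inspect.Parameter.empty else p.annotation for p in params]
    return tuple(anns + [object] * (2 - len(anns)))

class NameRegistry:
    """Hypothetical sketch: collect same-named functions into one dispatcher."""
    def __init__(self):
        self.tables = defaultdict(dict)  # name -> {(type, type): function}

    def __call__(self, f):
        table = self.tables[f.__name__]
        table[dispatch_key(f)] = f

        def rank(k):
            # More specific pairs (subclasses of more registered keys) rank higher
            return (sum(issubclass(k[0], o[0]) for o in table),
                    sum(issubclass(k[1], o[1]) for o in table))

        def dispatcher(*args):
            t1 = type(args[0])
            t2 = type(args[1]) if len(args) > 1 else object
            for k1, k2 in sorted(table, key=rank, reverse=True):
                if issubclass(t1, k1) and issubclass(t2, k2):
                    return table[(k1, k2)](*args)
            return args[0]  # identity fallback
        return dispatcher

register = NameRegistry()

@register
def td_sketch(x, y): return f'{x}{y}'
@register
def td_sketch(x: numbers.Integral, y): return x + 1
@register
def td_sketch(x: int, y: float): return x + y
@register
def td_sketch(x: int, y: int): return x * y

print(td_sketch(3, 2.0))    # 5.0
print(td_sketch(3, 2))      # 6
print(td_sketch('a', 'b'))  # ab
```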

Using typedispatch With other decorators

You can use typedispatch with the classmethod and staticmethod decorators:

class A:
    @typedispatch
    def f_td_test(self, x:numbers.Integral, y): return x+1
    @typedispatch
    @classmethod
    def f_td_test(cls, x:int, y:float): return x+y
    @typedispatch
    @staticmethod
    def f_td_test(x:int, y:int): return x*y
test_eq(A.f_td_test(3,2), 6)
test_eq(A.f_td_test(3,2.0), 5)
test_eq(A().f_td_test(3,'2.0'), 4)


Now that we can dispatch on types, let’s make it easier to cast objects to a different type.



 retain_meta (x, res, as_copy=False)

Call res.set_meta(x), if it exists



 default_set_meta (x, as_copy=False)

Copy over _meta from x to res, if it’s missing
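The copying rule described here can be sketched in a few lines. copy_meta_sketch is a hypothetical helper that assumes _meta is an ordinary attribute; it copies it only when the target has none, and makes a shallow copy when as_copy is true.

```python
import copy

def copy_meta_sketch(res, x, as_copy=False):
    """Hypothetical sketch: give `res` the `_meta` of `x` if `res` has none yet."""
    if hasattr(x, '_meta') and not hasattr(res, '_meta'):
        res._meta = copy.copy(x._meta) if as_copy else x._meta
    return res

class Obj: pass

src = Obj(); src._meta = {'a': 2}
shared = copy_meta_sketch(Obj(), src)                 # shares the same dict
copied = copy_meta_sketch(Obj(), src, as_copy=True)   # gets a shallow copy
kept = Obj(); kept._meta = {'b': 1}
copy_meta_sketch(kept, src)                           # existing _meta is left alone
print(shared._meta is src._meta, copied._meta is src._meta, kept._meta)  # True False {'b': 1}
```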

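 cast (x, typ)

(object,object) -> cast

Cast x to type typ (this may also change x in place). Because cast is defined with typedispatch, it is itself a TypeDispatch; the line above shows its default (object,object) entry.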

This works both for plain Python classes:…

mk_class('_T1', 'a')   # mk_class is a fastcore utility that creates a class.
class _T2(_T1): pass

t = _T1(a=1)
t2 = cast(t, _T2)        
assert t2 is t            # t2 refers to the same object as t
assert isinstance(t, _T2) # t also changed in-place
assert isinstance(t2, _T2)

test_eq_type(_T2(a=1), t2)

…as well as for arrays and tensors.

class _T1(ndarray): pass

t = array([1])
t2 = cast(t, _T1)
test_eq(array([1]), t2)
test_eq(_T1, type(t2))

To customize casting for other types, define a separate cast function with typedispatch for your type.
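For ordinary Python objects, the in-place behavior shown above can be sketched by reassigning __class__, falling back to constructing a new object when Python forbids that (as it does for immutable built-ins like tuple). cast_sketch is a hypothetical name; this is not fastcore's cast.

```python
def cast_sketch(x, typ):
    """Hypothetical sketch: change x's class in place when Python allows it, else build a new typ."""
    try:
        x.__class__ = typ   # works for ordinary instances with compatible layouts
        return x
    except TypeError:
        return typ(x)       # immutable built-ins such as tuple reject __class__ assignment

class Parent: pass
class Child(Parent): pass
p = Parent()
print(cast_sketch(p, Child) is p, isinstance(p, Child))  # True True

class TupSub(tuple): pass
r = cast_sketch((1, 2), TupSub)
print(type(r).__name__, r)  # TupSub (1, 2)
```

For NumPy arrays, ndarray.view(subclass) is the usual way to change the array subclass without copying data; the sketch does not cover that case.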



 retain_type (new, old=None, typ=None, as_copy=False)

Cast new to type of old or typ if it’s a superclass

class _T(tuple): pass
a = _T((1,2))
b = tuple((1,2))
c = retain_type(b, typ=_T)
test_eq_type(c, a)

If old has a _meta attribute, its content is passed when casting new to the type of old. In the example below, only the attribute a is kept, not other_attr, because other_attr is not in _meta:

class _A():
    set_meta = default_set_meta
    def __init__(self, t): self.t=t

class _B1(_A):
    def __init__(self, t, a=1):
        self._meta = {'a':a}
        self.other_attr = 'Hello' # will not be kept after casting.
x = _B1(1, a=2)
b = _A(1)
c = retain_type(b, old=x)
test_eq(c._meta, {'a': 2})
assert not getattr(c, 'other_attr', None)
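Ignoring _meta, the core rule of retain_type (convert only when the target type is a subclass of new's type) can be sketched as follows. retain_type_sketch is hypothetical and simply calls the target type's constructor.

```python
def retain_type_sketch(new, old=None, typ=None):
    """Hypothetical sketch: convert new to typ (or type(old)) only if that type subclasses type(new)."""
    if typ is None:
        if old is None:
            return new
        typ = type(old)
    if typ is not type(new) and issubclass(typ, type(new)):
        return typ(new)
    return new

class TupleSub(tuple): pass
c_s = retain_type_sketch((1, 2), typ=TupleSub)
print(type(c_s) is TupleSub, c_s == (1, 2))  # True True
print(retain_type_sketch(1, old='x'))        # 1 (str is not a subclass of int, so nothing changes)
```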



 retain_types (new, old=None, typs=None)

Cast each item of new to type of matching item in old if it’s a superclass

class T(tuple): pass

t1,t2 = retain_types((1,(1,(1,1))), (2,T((2,T((3,4))))))
test_eq_type(t1, 1)
test_eq_type(t2, T((1,T((1,1)))))

t1,t2 = retain_types((1,(1,(1,1))), typs = {tuple: [int, {T: [int, {T: [int,int]}]}]})
test_eq_type(t1, 1)
test_eq_type(t2, T((1,T((1,1)))))
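The first form, pairing each item of new with the matching item of old, can be sketched recursively for tuples. retain_types_sketch is hypothetical, handles only tuples, assumes each tuple type accepts a single iterable (so it would not work for namedtuples), and does not cover the typs form.

```python
def retain_types_sketch(new, old):
    """Hypothetical sketch: recursively give each item of new the type of the matching item of old."""
    if isinstance(new, tuple) and isinstance(old, tuple):
        # zip stops at the shorter tuple; a real implementation might check lengths
        items = tuple(retain_types_sketch(n, o) for n, o in zip(new, old))
        return type(old)(items)   # assumes type(old) accepts a single iterable, like tuple
    typ = type(old)
    return typ(new) if typ is not type(new) and issubclass(typ, type(new)) else new

class TS(tuple): pass
s1, s2 = retain_types_sketch((1, (1, (1, 1))), (2, TS((2, TS((3, 4))))))
print(s1, type(s2).__name__, type(s2[1]).__name__)  # 1 TS TS
```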



 explode_types (o)

Return the type of o, potentially in nested dictionaries for things that are listy

test_eq(explode_types((2,T((2,T((3,4)))))), {tuple: [int, {T: [int, {T: [int,int]}]}]})
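A recursive sketch of explode_types restricted to tuples and lists is shown below. explode_types_sketch is hypothetical; fastcore's notion of "listy" may cover more types.

```python
def explode_types_sketch(o):
    """Hypothetical sketch: map each tuple/list to {its type: [exploded items]}, else return type(o)."""
    if not isinstance(o, (tuple, list)):
        return type(o)
    return {type(o): [explode_types_sketch(x) for x in o]}

class TE(tuple): pass
res = explode_types_sketch((2, TE((2, TE((3, 4))))))
print(res == {tuple: [int, {TE: [int, {TE: [int, int]}]}]})  # True
```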