Basic single and dual parameter dispatch
from nbdev.showdoc import *
from fastcore.test import *


def type_hints(f):
    "Type hints of `f` as given by `typing.get_type_hints`, or `{}` when `f` is not one of typing's allowed types"
    # Guard first: anything outside typing's allowed types yields an empty dict.
    if not isinstance(f, typing._allowed_types):
        return {}
    return typing.get_type_hints(f)



Same as typing.get_type_hints but returns {} if not allowed type



Get the return annotation of func

# Keys are types; the values are arbitrary and play no part in the check below (only the keys are sorted).
td = {int:1, numbers.Number:2, numbers.Integral:3}
# Sorting the keys with `cmp_instance` is expected to order them Number, Integral, int.
test_eq(sorted(td, key=cmp_instance), [numbers.Number, numbers.Integral, int])
# Each case rebinds `_f` and checks the pair that `_p2_anno` reports for it.
# No annotations at all (one or two parameters): both slots are expected to be `object`.
def _f(a): pass
test_eq(_p2_anno(_f), (object,object))
def _f(a, b): pass
test_eq(_p2_anno(_f), (object,object))
# A `None` annotation is expected to be reported as `NoneType`; the return annotation does not appear in the pair.
def _f(a:None, b)->str: pass
test_eq(_p2_anno(_f), (NoneType,object))
# An unannotated second parameter is expected to be reported as `object`.
def _f(a:str, b)->float: pass
test_eq(_p2_anno(_f), (str,object))
def _f(a:None, b:str)->float: pass
test_eq(_p2_anno(_f), (NoneType,str))
def _f(a:int, b:int)->float: pass
test_eq(_p2_anno(_f), (int,int))
# With a leading `self` parameter the expected pair is still that of `a` and `b`.
def _f(self, a:int, b:int): pass
test_eq(_p2_anno(_f), (int,int))
def _f(a:int, b:str)->float: pass
test_eq(_p2_anno(_f), (int,str))
# A callable that is not a plain function (`attrgetter`) is expected to give `(object,object)`.
test_eq(_p2_anno(attrgetter('foo')), (object,object))

The following class is the basis that allows us to do type dispatch with type annotations. It contains a dictionary type -> functions and ensures that the proper function is called when passed an object (depending on its type).

class TypeDispatch[source]

TypeDispatch(funcs=(), bases=())

Dictionary-like object; __getitem__ matches keys of types using issubclass

# One candidate function per annotation of interest; f_bll uses a tuple annotation (bool, list).
def f_col(x:typing.Collection): return x
def f_nin(x:numbers.Integral)->int:  return x+1
def f_ni2(x:int): return x
def f_bll(x:(bool,list)): return x
def f_num(x:numbers.Number): return x
# f_col is left out of the initial list so that it can be registered after construction below.
t = TypeDispatch([f_nin,f_ni2,f_num,f_bll,None])

t.add(f_ni2) #Should work even if we add the same function twice.
test_eq(t[int], f_ni2)
test_eq(t[np.int32], f_nin)
# Nothing registered so far accepts `str`, so the lookup is expected to give None.
test_eq(t[str], None)
test_eq(t[float], f_num)
test_eq(t[bool], f_bll)
test_eq(t[list], f_bll)
# Register f_col after the fact; only from this point on can `str` resolve to it.
t.add(f_col)
test_eq(t[str], f_col)
test_eq(t[np.int32], f_nin)
o = np.int32(1)
# Calling the dispatcher runs the matched function (f_nin: x+1); `returns` reports its return annotation.
test_eq(t(o), 2)
test_eq(t.returns(o), int)
assert t.first() is not None
(list,object) -> f_bll
(typing.Collection,object) -> f_col
(bool,object) -> f_bll
(int,object) -> f_ni2
(Integral,object) -> f_nin
(Number,object) -> f_num
(object,object) -> NoneType

If bases is set to a collection of TypeDispatch objects, then they are searched for matching functions if no match is found in this object.

def f_str(x:str): return x+'1'

# `t2` is given only f_str directly; the existing dispatcher `t` is supplied as its base.
t2 = TypeDispatch(f_str, bases=t)
# Types that f_str does not cover are expected to resolve to the functions registered in `t`.
test_eq(t2[int], f_ni2)
test_eq(t2[np.int32], f_nin)
test_eq(t2[float], f_num)
test_eq(t2[bool], f_bll)
# `str` is expected to resolve to t2's own f_str.
test_eq(t2[str], f_str)
test_eq(t2('a'), 'a1')
test_eq(t2[np.int32], f_nin)
# `o` is the np.int32 value created in the earlier cell; calls and `returns` are expected to go through the base as well.
test_eq(t2(o), 2)
test_eq(t2.returns(o), int)
# Method-style candidates: each takes `self` first and is dispatched on the type of `x`.
def m_nin(self, x:(str,numbers.Integral)): return str(x)+'1'
# Stores a marker on the instance instead of returning a value, so the checks below can
# confirm that the function ran with the owning instance bound as `self`.
def m_bll(self, x:bool): self.foo='a'
def m_num(self, x:numbers.Number): return x

t = TypeDispatch([m_nin,m_num,m_bll])
# The dispatcher is used as a class attribute and called through an instance.
class A: f = t
a = A()
test_eq(a.f(1), '11')
test_eq(a.f(1.), 1.)
test_is(a.f.inst, a)
# A bool argument goes to m_bll, which sets `foo` on the instance.
a.f(False)
test_eq(a.foo, 'a')
# No candidate accepts a tuple here, so the argument is expected back unchanged.
test_eq(a.f(()), ())
def m_tup(self, x:tuple): return x+(1,)
# Second dispatcher: m_tup of its own, with `t` passed positionally as its base.
t2 = TypeDispatch(m_tup, t)
class A2: f = t2
a2 = A2()
test_eq(a2.f(1), '11')
test_eq(a2.f(1.), 1.)
test_is(a2.f.inst, a2)
# m_bll is found through the base and is expected to set `foo` on `a2`.
a2.f(False)
test_eq(a2.foo, 'a')
# This time a tuple is handled by m_tup.
test_eq(a2.f(()), (1,))
# f1 annotates only its first parameter; f2 annotates both.
def f1(x:numbers.Integral, y): return x+1
def f2(x:int, y:float): return x+y
t = TypeDispatch([f1,f2])

# Lookups by one type or by a pair of types.
test_eq(t[int], f1)
test_eq(t[int,int], f1)
test_eq(t[int,float], f2)
# Neither function takes a float first argument, so None is expected.
test_eq(t[float,float], None)
test_eq(t[np.int32,float], f1)
# Calls: (int, float) is expected to run f2 (3+2.0); (int, int) is expected to run f1 (3+1).
test_eq(t(3,2.0), 5)
test_eq(t(3,2), 4)
# With no matching function the call is expected to hand back its argument.
test_eq(t('a'), 'a')
(int,float) -> f2
(Integral,object) -> f1

typedispatch Decorator

class DispatchReg[source]


A global registry for TypeDispatch objects keyed by function name

# Every overload is registered with `@typedispatch`. Without the decorator each `def` would
# simply rebind the name, leaving only the last definition callable.
@typedispatch
def f_td_test(x, y): return f'{x}{y}'
@typedispatch
def f_td_test(x:numbers.Integral, y): return x+1
@typedispatch
def f_td_test(x:int, y:float): return x+y

# (int, float) -> x+y; (int, int) -> x+1; anything else -> the unannotated overload.
test_eq(f_td_test(3,2.0), 5)
test_eq(f_td_test(3,2), 4)
test_eq(f_td_test('a','b'), 'ab')


Now that we can dispatch on types, let's make it easier to cast objects to a different type.


retain_meta(x, res)

Call res.set_meta(x), if it exists



Copy over _meta from x to res, if it's missing

This works both for plain python classes:...

# `mk_class` is defined elsewhere; the lines below rely on it making a `_T1` that accepts `a` as a keyword.
mk_class('_T1', 'a')
class _T2(_T1): pass

t = _T1(a=1)
# Cast the `_T1` instance to the subclass `_T2`; the check that follows compares it with `_T2(a=1)`.
t2 = cast(t, _T2)
test_eq_type(_T2(a=1), t2)

...as well as for arrays and tensors.

# `_T1` is rebound here as a subclass of `ndarray`.
class _T1(ndarray): pass

t = array([1])
t2 = cast(t, _T1)
# The cast result is expected to compare equal to the original values and to report `_T1` as its type.
test_eq(array([1]), t2)
test_eq(_T1, type(t2))

To customize casting for other types, define a separate cast function with typedispatch for your type.


retain_type(new, old=None, typ=None)

Cast new to type of old or typ if it's a superclass

class _T(tuple): pass
a = _T((1,2))
b = tuple((1,2))
# With `typ=_T`, the plain tuple `b` is expected to come back equal to `a` and of type `_T`.
test_eq_type(retain_type(b, typ=_T), a)

If old has a _meta attribute, its content is passed when casting new to the type of old.

class _A():
    # `default_set_meta` is defined elsewhere in the project; it is exposed here under the name `set_meta`.
    set_meta = default_set_meta
    def __init__(self, t): self.t=t

class _B1(_A):
    def __init__(self, t, a=1):
        # Delegate to the base initializer so that `self.t` is set on `_B1` instances too.
        super().__init__(t)
        self._meta = {'a':a}
x = _B1(1, a=2)
b = _A(1)
# `b` has no `_meta` of its own; after `retain_type` it is expected to carry the one from `x`.
test_eq(retain_type(b, old=x)._meta, {'a': 2})
# NOTE(review): `L` is not defined in the visible code — presumably the project's list class; confirm.
a = {L: [int, tuple]}


retain_types(new, old=None, typs=None)

Cast each item of new to type of matching item in old if it's a superclass

class T(tuple): pass

# Types taken from `old`: the nested plain tuples in `new` are expected to become `T` wherever `old` has a `T`.
t1,t2 = retain_types((1,(1,(1,1))), (2,T((2,T((3,4))))))
test_eq_type(t1, 1)
test_eq_type(t2, T((1,T((1,1)))))

# Same expectation, with the target types spelled out as a nested `typs` mapping instead of an `old` object.
t1,t2 = retain_types((1,(1,(1,1))), typs = {tuple: [int, {T: [int, {T: [int,int]}]}]})
test_eq_type(t1, 1)
test_eq_type(t2, T((1,T((1,1)))))



Return the type of o, potentially in nested dictionaries for things that are listy

# Expected result: each tuple-like object maps its type to a list describing its items, nested for the inner `T` values.
test_eq(explode_types((2,T((2,T((3,4)))))), {tuple: [int, {T: [int, {T: [int,int]}]}]})