# `contains` must appear in the raised exception's args
def _fail(): raise Exception("foobar")
test_fail(_fail, contains="foo")
# With the default exc=Exception, any Exception subclass is accepted
def _fail(): raise Exception()
test_fail(_fail)
# `exc` restricts which exception type counts as the expected failure
def _fail(): raise ValueError()
test_fail(_fail, exc=ValueError)
# Raising a type other than `exc` makes test_fail itself fail with AssertionError
test_fail(lambda: test_fail(_fail, exc=IndexError), exc=AssertionError)
Test
Simple test functions
We can check that code raises an exception when that’s expected (test_fail).
To test for equality or inequality across many kinds of objects, we define a simple function, test, that compares two objects with a given cmp operator.
test_fail
def test_fail(
f, msg:str='', contains:str='', exc:type=Exception, args:NoneType=None, kwargs:NoneType=None
):
Fails with msg unless f() raises an exception of type exc and (optionally) has contains in e.args
We can also pass args and kwargs to the function to check whether it fails with specific inputs.
# Raises only for the special input a == 5
def _fail_args(a):
    if a == 5:
        raise ValueError
# Supply the failing input positionally through `args`...
test_fail(_fail_args, args=(5,))
# ...or by keyword through `kwargs`
test_fail(_fail_args, kwargs=dict(a=5))
test
def test(
a, b, cmp, cname:NoneType=None
):
assert that cmp(a,b); display inputs and cname or cmp.__name__ if it fails
# test passes when cmp(a, b) holds...
test([1,2],[1,2], operator.eq)
# ...and raises when it does not
test_fail(lambda: test([1,2],[1], operator.eq))
# Any binary operator can serve as cmp
test([1,2],[1], operator.ne)
test_fail(lambda: test([1,2],[1,2], operator.ne))
all_equal
def all_equal(
a, b
):
Compares whether a and b are the same length and have the same contents
# all_equal can be passed to test as its cmp argument
test(['abc'], ['abc'], all_equal)
# Same length but different contents fails
test_fail(lambda: test(['abc'],['cab'], all_equal))
equals
def equals(
a, b
):
Compares a and b for equality; supports sublists, tensors and arrays too
# equals compares nested lists element by element
test([['abc'],['a']], [['abc'],['a']], equals)
test([['abc'],['a'],'b', [['x']]], [['abc'],['a'],'b', [['x']]], equals) # supports any depth and nested structure
nequals
def nequals(
a, b
):
Compares a and b for inequality
test(['abc'], ['ab' ], nequals)test_eq test_ne, etc…
Just use test_eq/test_ne to test for ==/!=. test_eq_type checks that things are equal and of the same type. We define them using test:
test_eq
def test_eq(
a, b
):
test that a==b
test_eq([1,2],[1,2])
# A lazy iterable (map object) is compared by its contents
test_eq([1,2],map(int,[1,2]))
test_eq(array([1,2]),array([1,2]))
test_eq(array([1,2]),array([1,2]))
# Nested arrays inside lists are supported
test_eq([array([1,2]),3],[array([1,2]),3])
# Dicts compare equal regardless of key order
test_eq(dict(a=1,b=2), dict(b=2,a=1))
# Failures report a message containing '=='
test_fail(lambda: test_eq([1,2], 1), contains="==")
test_fail(lambda: test_eq(None, np.array([1,2])), contains="==")
test_eq({'a', 'b', 'c'}, {'c', 'a', 'b'})

# pandas DataFrames and Series
df1 = pd.DataFrame(dict(a=[1,2],b=['a','b']))
df2 = pd.DataFrame(dict(a=[1,2],b=['a','b']))
df3 = pd.DataFrame(dict(a=[1,2],b=['a','c']))
test_eq(df1,df2)
test_eq(df1.a,df2.a)
test_fail(lambda: test_eq(df1,df3), contains='==')
class T(pd.Series): pass
test_eq(df1.iloc[0], T(df2.iloc[0])) # works with subclasses

# torch tensors: this first check was previously swallowed by the comment above
# and never executed. A float32 vs float64 dtype difference is accepted.
test_eq(torch.zeros(10), torch.zeros(10, dtype=torch.float64))
test_eq(torch.zeros(10), torch.ones(10)-1)
# Shapes must match: (10,) vs (1, 10) fails
test_fail(lambda:test_eq(torch.zeros(10), torch.ones(1, 10)), contains='==')
# Tensors can be compared with plain lists
test_eq(torch.zeros(3), [0,0,0])
test_eq_type
def test_eq_type(
a, b
):
test that a==b and are same type
test_eq_type(1,1)
# Equal values of different types (int vs float) fail
test_fail(lambda: test_eq_type(1,1.))
test_eq_type([1,1],[1,1])
# list vs tuple fails
test_fail(lambda: test_eq_type([1,1],(1,1)))
# Element types are checked too: int vs float inside the list fails
test_fail(lambda: test_eq_type([1,1],[1,1.]))
test_ne
def test_ne(
a, b
):
test that a!=b
# Different lengths
test_ne([1,2],[1])
# Same length, differing element
test_ne([1,2],[1,3])
test_ne(array([1,2]),array([1,1]))
test_ne(array([1,2]),array([1,1]))
test_ne([array([1,2]),3],[array([1,2])])
# Lists compared against arrays
test_ne([3,4],array([3]))
test_ne([3,4],array([3,5]))
# A dict is not equal to a list of its keys, in either argument order
test_ne(dict(a=1,b=2), ['a', 'b'])
test_ne(['a', 'b'], dict(a=1,b=2))
is_close
def is_close(
a, b, eps:float=1e-05
):
Is a within eps of b
test_close
def test_close(
a, b, eps:float=1e-05
):
test that a is within eps of b
test_close(1,1.001,eps=1e-2)
# A difference of 1e-3 exceeds the default eps of 1e-05
test_fail(lambda: test_close(1,1.001))
# Works on lists and numpy arrays as well as scalars
test_close([-0.001,1.001], [0.,1.], eps=1e-2)
test_close(np.array([-0.001,1.001]), np.array([0.,1.]), eps=1e-2)
test_close(array([-0.001,1.001]), array([0.,1.]), eps=1e-2)
test_is
def test_is(
a, b
):
test that a is b
# Equal but distinct objects are not identical
test_fail(lambda: test_is([1], [1]))
a = [1]
test_is(a, a)
b = [2]
test_fail(lambda: test_is(a, b))
test_shuffled
def test_shuffled(
a, b
):
test that a and b are shuffled versions of the same sequence of items
a = list(range(50))
b = copy(a)
random.shuffle(b)
test_shuffled(a,b)
# The same sequence in the same order is not considered shuffled
test_fail(lambda:test_shuffled(a,a))

# b is not a rearrangement of a
a = 'abc'
b = 'abcabc'
test_fail(lambda:test_shuffled(a,b))

# Items of mixed types are supported
a = ['a', 42, True]
b = [42, True, 'a']
test_shuffled(a,b)
test_stdout
def test_stdout(
f, exp, regex:bool=False
):
Test that f prints exp to stdout, optionally checking as regex
# The newline added by print is not part of exp
test_stdout(lambda: print('hi'), 'hi')
test_fail(lambda: test_stdout(lambda: print('hi'), 'ho'))
# A function that prints nothing matches the empty string
test_stdout(lambda: 1+1, '')
# With regex=True, exp is treated as a regular expression
test_stdout(lambda: print('hi there!'), r'^hi.*!$', regex=True)
test_warns
def test_warns(
f, show:bool=False
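):
Test that f emits a warning, failing with "No warnings raised" otherwise; with show=True, also print the captured warning(s) as category and message.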
test_warns(lambda: warnings.warn("Oh no!"))
test_fail(lambda: test_warns(lambda: 2+2), contains='No warnings raised')test_warns(lambda: warnings.warn("Oh no!"), show=True)<class 'UserWarning'>: Oh no!
# Load each test image and resize it to 128x128. The trailing bare `im`
# presumably exists to display the image in a notebook.
# NOTE(review): TEST_IMAGE / TEST_IMAGE_BW are defined outside this view.
im = Image.open(TEST_IMAGE).resize((128, 128))
im
im = Image.open(TEST_IMAGE_BW).resize((128, 128))
im
test_fig_exists
def test_fig_exists(
ax
):
Test there is a figure displayed in ax
# Draw the image into a fresh axes, then check that a figure is present in it
fig,ax = plt.subplots()
ax.imshow(array(im));
test_fig_exists(ax)
ExceptionExpected
def ExceptionExpected(
ex:type=Exception, regex:str=''
):
Context manager that tests if an exception is raised
def _tst_1(): assert False, "This is a test"
def _tst_2(): raise SyntaxError
with ExceptionExpected(): _tst_1()
with ExceptionExpected(ex=AssertionError, regex="This is a test"): _tst_1()
with ExceptionExpected(ex=SyntaxError): _tst_2()exception is an abbreviation for ExceptionExpected().
with exception: _tst_1()