Testing

Machine learning is very hard to test. Due to the nature of our models, we often have soft failures that are difficult to test against.

Writing software tests in science is already incredibly hard, so in this section we’ll touch on

  • some fairly simple tests we can implement to ensure the consistency of our input data
  • ways to avoid bad bugs in data loading procedures
  • some strategies to probe our models
In [1]:
import pandas as pd
penguins = pd.read_csv('../data/penguins_clean.csv')
penguins.head()
Out[1]:
Culmen Length (mm) Culmen Depth (mm) Flipper Length (mm) Sex Species
0 39.1 18.7 181.0 MALE Adelie Penguin (Pygoscelis adeliae)
1 39.5 17.4 186.0 FEMALE Adelie Penguin (Pygoscelis adeliae)
2 40.3 18.0 195.0 FEMALE Adelie Penguin (Pygoscelis adeliae)
3 36.7 19.3 193.0 FEMALE Adelie Penguin (Pygoscelis adeliae)
4 39.3 20.6 190.0 MALE Adelie Penguin (Pygoscelis adeliae)
In [2]:
from sklearn.model_selection import train_test_split
num_features = ["Culmen Length (mm)", "Culmen Depth (mm)", "Flipper Length (mm)"]
cat_features = ["Sex"]
features = num_features + cat_features
target = ["Species"]

X_train, X_test, y_train, y_test = train_test_split(
    penguins[features],
    penguins[target],
    stratify=penguins[target[0]],
    train_size=.7,
    random_state=42,
)
In [3]:
from sklearn.svm import SVC
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from joblib import load

# load the previously trained and saved model
clf = load("../model/svc.joblib")
clf.score(X_test, y_test)
Out[3]:
1.0

Deterministic Tests

When I work with neural networks and implement a new layer, method, or other fancy thing, I try to write a test for it. The Conv2D layer in Keras and PyTorch, for example, should always do the exact same thing when it convolves a kernel with an image.

Consider writing a small pytest test that takes a simple numpy array and tests against a known output.

You can check out the Keras test suite here and see an example of how they validate the input and output shapes.

Admittedly, this isn't always easy to do and can go beyond the needs of research scripts.
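
For illustration, here is a minimal sketch of such a deterministic test, assuming PyTorch and NumPy are installed (the test name and values are made up for this example). A 2x2 kernel of ones simply sums each 2x2 window of the input, so the expected output can be computed by hand:

# a sketch of a deterministic layer test, not part of this notebook
import numpy as np
import torch

def test_conv2d_known_output():
    # 3x3 single-channel input with a batch dimension: shape (1, 1, 3, 3)
    image = torch.tensor([[[[1., 2., 3.],
                            [4., 5., 6.],
                            [7., 8., 9.]]]])
    conv = torch.nn.Conv2d(1, 1, kernel_size=2, bias=False)
    # fix the kernel so the output is deterministic
    with torch.no_grad():
        conv.weight.copy_(torch.ones(1, 1, 2, 2))
    out = conv(image).detach().numpy()
    # each output value is the sum of a 2x2 window of the input
    expected = np.array([[[[12., 16.],
                           [24., 28.]]]])
    np.testing.assert_allclose(out, expected)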

Data Tests for Models

An even easier test is to essentially reuse the notebook from the Model Evaluation section and write a test function for it.

In [4]:
def test_penguins(clf):
    # Define data you definitely know the answer to
    test_data = pd.DataFrame([[34.6, 21.1, 198.0, "MALE"],
                              [46.1, 18.2, 178.0, "FEMALE"],
                              [52.5, 15.6, 221.0, "MALE"]], 
             columns=["Culmen Length (mm)", "Culmen Depth (mm)", "Flipper Length (mm)", "Sex"])
    # Define target to the data
    test_target = ['Adelie Penguin (Pygoscelis adeliae)',
                   'Chinstrap penguin (Pygoscelis antarctica)',
                   'Gentoo penguin (Pygoscelis papua)']
    # Assert the model should get these right.
    assert clf.score(test_data, test_target) == 1
In [5]:
test_penguins(clf)
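
In a real project, this check would typically live in its own test file so pytest can collect it automatically. A minimal sketch of what that could look like (the file name test_model.py and the fixture are my own additions; the model path is the one used above and would likely need adjusting to your project layout):

# hypothetical test_model.py -- a sketch, not part of this notebook
import pandas as pd
import pytest
from joblib import load

@pytest.fixture
def clf():
    # load the trained model saved earlier
    return load("../model/svc.joblib")

def test_penguins_known_sample(clf):
    # one sample we definitely know the answer to (taken from the cell above)
    test_data = pd.DataFrame(
        [[34.6, 21.1, 198.0, "MALE"]],
        columns=["Culmen Length (mm)", "Culmen Depth (mm)", "Flipper Length (mm)", "Sex"])
    assert clf.predict(test_data)[0] == 'Adelie Penguin (Pygoscelis adeliae)'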

Automated Testing of Docstring Examples

There is an even easier way to run simple tests. This can be useful when we write specific functions to pre-process our data. In the Model Sharing notebook, we looked into auto-generating docstrings.

We can upgrade our docstring and get free software tests out of it!

This is called doctest, and it is usually useful for keeping docstring examples up to date and writing quick unit tests for a function.

This makes future users (including yourself from the future) quite happy.

In [6]:
def shorten_class_name(df: pd.DataFrame) -> pd.DataFrame:
    """Shorten the class names of the penguins to the shortest version

    Parameters
    ----------
    df : pd.DataFrame
        Dataframe containing the Species column with penguins

    Returns
    -------
    pd.DataFrame
        Normalised dataframe with shortened names
    
    Examples
    --------
    >>> shorten_class_name(pd.DataFrame([[1,2,3,"Adelie Penguin (Pygoscelis adeliae)"]], columns=["1","2","3","Species"]))
       1  2  3 Species
    0  1  2  3  Adelie
    """
    df["Species"] = df.Species.str.split(r" [Pp]enguin", n=1, expand=True)[0]

    return df

import doctest
doctest.testmod()
Out[6]:
TestResults(failed=0, attempted=1)
In [7]:
shorten_class_name(penguins).head()
Out[7]:
Culmen Length (mm) Culmen Depth (mm) Flipper Length (mm) Sex Species
0 39.1 18.7 181.0 MALE Adelie
1 39.5 17.4 186.0 FEMALE Adelie
2 40.3 18.0 195.0 FEMALE Adelie
3 36.7 19.3 193.0 FEMALE Adelie
4 39.3 20.6 190.0 MALE Adelie

So these give us a nice usage example in the docstring, an expected output, and a first test case that is validated by our test suite.
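
As an aside, you don't have to test the whole module at once: doctest can also check a single function while you iterate on its docstring. A minimal sketch, using the function defined above:

import doctest
# run only the examples in shorten_class_name's docstring, with verbose output
doctest.run_docstring_examples(shorten_class_name, globals(), verbose=True)

If the function lives in a module, pytest can also collect these examples as part of your regular test suite via its --doctest-modules option.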

Input Data Validation

You want to validate that the data users provide matches what your model expects.

These tools are often used in production systems to check whether API calls and user inputs are formatted correctly.

Example tools include Pandera, which we use below, and Great Expectations.

In [8]:
import pandera as pa

# data to validate
X_train.describe()
Out[8]:
Culmen Length (mm) Culmen Depth (mm) Flipper Length (mm)
count 233.000000 233.000000 233.000000
mean 43.982403 17.228755 201.412017
std 5.537146 1.994191 13.929695
min 33.500000 13.100000 172.000000
25% 39.000000 15.700000 190.000000
50% 44.400000 17.300000 198.000000
75% 48.800000 18.800000 213.000000
max 59.600000 21.200000 231.000000
In [9]:
# define schema
schema = pa.DataFrameSchema({
    "Culmen Length (mm)": pa.Column(float, checks=[pa.Check.ge(30),
                                                   pa.Check.le(60)]),
    "Culmen Depth (mm)": pa.Column(float, checks=[pa.Check.ge(13),
                                                  pa.Check.le(22)]),
    "Flipper Length (mm)": pa.Column(float, checks=[pa.Check.ge(170),
                                                    pa.Check.le(235)]),
    "Sex": pa.Column(str, checks=pa.Check.isin(["MALE","FEMALE"])),
})

validated_test = schema(X_test)
---------------------------------------------------------------------------
SchemaError                               Traceback (most recent call last)
Input In [9], in <cell line: 12>()
      1 # define schema
      2 schema = pa.DataFrameSchema({
      3     "Culmen Length (mm)": pa.Column(float, checks=[pa.Check.ge(30),
      4                                                    pa.Check.le(60)]),
   (...)
      9     "Sex": pa.Column(str, checks=pa.Check.isin(["MALE","FEMALE"])),
     10 })
---> 12 validated_test = schema(X_test)

File C:\tools\Anaconda3\envs\euroscipy-2022-ml-repro\lib\site-packages\pandera\schemas.py:810, in DataFrameSchema.__call__(self, dataframe, head, tail, sample, random_state, lazy, inplace)
    782 def __call__(
    783     self,
    784     dataframe: pd.DataFrame,
   (...)
    790     inplace: bool = False,
    791 ):
    792     """Alias for :func:`DataFrameSchema.validate` method.
    793 
    794     :param pd.DataFrame dataframe: the dataframe to be validated.
   (...)
    808         otherwise creates a copy of the data.
    809     """
--> 810     return self.validate(
    811         dataframe, head, tail, sample, random_state, lazy, inplace
    812     )

File C:\tools\Anaconda3\envs\euroscipy-2022-ml-repro\lib\site-packages\pandera\schemas.py:518, in DataFrameSchema.validate(self, check_obj, head, tail, sample, random_state, lazy, inplace)
    505     check_obj = check_obj.map_partitions(
    506         self._validate,
    507         head=head,
   (...)
    513         meta=check_obj,
    514     )
    516     return check_obj.pandera.add_schema(self)
--> 518 return self._validate(
    519     check_obj=check_obj,
    520     head=head,
    521     tail=tail,
    522     sample=sample,
    523     random_state=random_state,
    524     lazy=lazy,
    525     inplace=inplace,
    526 )

File C:\tools\Anaconda3\envs\euroscipy-2022-ml-repro\lib\site-packages\pandera\schemas.py:716, in DataFrameSchema._validate(self, check_obj, head, tail, sample, random_state, lazy, inplace)
    714     check_results.append(check_utils.is_table(result))
    715 except errors.SchemaError as err:
--> 716     error_handler.collect_error("schema_component_check", err)
    717 except errors.SchemaErrors as err:
    718     for schema_error_dict in err.schema_errors:

File C:\tools\Anaconda3\envs\euroscipy-2022-ml-repro\lib\site-packages\pandera\error_handlers.py:32, in SchemaErrorHandler.collect_error(self, reason_code, schema_error, original_exc)
     26 """Collect schema error, raising exception if lazy is False.
     27 
     28 :param reason_code: string representing reason for error
     29 :param schema_error: ``SchemaError`` object.
     30 """
     31 if not self._lazy:
---> 32     raise schema_error from original_exc
     34 # delete data of validated object from SchemaError object to prevent
     35 # storing copies of the validated DataFrame/Series for every
     36 # SchemaError collected.
     37 del schema_error.data

File C:\tools\Anaconda3\envs\euroscipy-2022-ml-repro\lib\site-packages\pandera\schemas.py:708, in DataFrameSchema._validate(self, check_obj, head, tail, sample, random_state, lazy, inplace)
    706 for schema_component in schema_components:
    707     try:
--> 708         result = schema_component(
    709             df_to_validate,
    710             lazy=lazy,
    711             # don't make a copy of the data
    712             inplace=True,
    713         )
    714         check_results.append(check_utils.is_table(result))
    715     except errors.SchemaError as err:

File C:\tools\Anaconda3\envs\euroscipy-2022-ml-repro\lib\site-packages\pandera\schemas.py:2074, in SeriesSchemaBase.__call__(self, check_obj, head, tail, sample, random_state, lazy, inplace)
   2063 def __call__(
   2064     self,
   2065     check_obj: Union[pd.DataFrame, pd.Series],
   (...)
   2071     inplace: bool = False,
   2072 ) -> Union[pd.DataFrame, pd.Series]:
   2073     """Alias for ``validate`` method."""
-> 2074     return self.validate(
   2075         check_obj, head, tail, sample, random_state, lazy, inplace
   2076     )

File C:\tools\Anaconda3\envs\euroscipy-2022-ml-repro\lib\site-packages\pandera\schema_components.py:215, in Column.validate(self, check_obj, head, tail, sample, random_state, lazy, inplace)
    211             validate_column(
    212                 check_obj[column_name].iloc[:, [i]], column_name
    213             )
    214     else:
--> 215         validate_column(check_obj, column_name)
    217 return check_obj

File C:\tools\Anaconda3\envs\euroscipy-2022-ml-repro\lib\site-packages\pandera\schema_components.py:188, in Column.validate.<locals>.validate_column(check_obj, column_name)
    187 def validate_column(check_obj, column_name):
--> 188     super(Column, copy(self).set_name(column_name)).validate(
    189         check_obj,
    190         head,
    191         tail,
    192         sample,
    193         random_state,
    194         lazy,
    195         inplace=inplace,
    196     )

File C:\tools\Anaconda3\envs\euroscipy-2022-ml-repro\lib\site-packages\pandera\schemas.py:2032, in SeriesSchemaBase.validate(self, check_obj, head, tail, sample, random_state, lazy, inplace)
   2026     check_results.append(
   2027         _handle_check_results(
   2028             self, check_index, check, check_obj, *check_args
   2029         )
   2030     )
   2031 except errors.SchemaError as err:
-> 2032     error_handler.collect_error("dataframe_check", err)
   2033 except Exception as err:  # pylint: disable=broad-except
   2034     # catch other exceptions that may occur when executing the
   2035     # Check
   2036     err_msg = f'"{err.args[0]}"' if len(err.args) > 0 else ""

File C:\tools\Anaconda3\envs\euroscipy-2022-ml-repro\lib\site-packages\pandera\error_handlers.py:32, in SchemaErrorHandler.collect_error(self, reason_code, schema_error, original_exc)
     26 """Collect schema error, raising exception if lazy is False.
     27 
     28 :param reason_code: string representing reason for error
     29 :param schema_error: ``SchemaError`` object.
     30 """
     31 if not self._lazy:
---> 32     raise schema_error from original_exc
     34 # delete data of validated object from SchemaError object to prevent
     35 # storing copies of the validated DataFrame/Series for every
     36 # SchemaError collected.
     37 del schema_error.data

File C:\tools\Anaconda3\envs\euroscipy-2022-ml-repro\lib\site-packages\pandera\schemas.py:2027, in SeriesSchemaBase.validate(self, check_obj, head, tail, sample, random_state, lazy, inplace)
   2024 for check_index, check in enumerate(self.checks):
   2025     try:
   2026         check_results.append(
-> 2027             _handle_check_results(
   2028                 self, check_index, check, check_obj, *check_args
   2029             )
   2030         )
   2031     except errors.SchemaError as err:
   2032         error_handler.collect_error("dataframe_check", err)

File C:\tools\Anaconda3\envs\euroscipy-2022-ml-repro\lib\site-packages\pandera\schemas.py:2413, in _handle_check_results(schema, check_index, check, check_obj, *check_args)
   2411         warnings.warn(error_msg, UserWarning)
   2412         return True
-> 2413     raise errors.SchemaError(
   2414         schema,
   2415         check_obj,
   2416         error_msg,
   2417         failure_cases=failure_cases,
   2418         check=check,
   2419         check_index=check_index,
   2420         check_output=check_result.check_output,
   2421     )
   2422 return check_result.check_passed

SchemaError: <Schema Column(name=Sex, type=DataType(str))> failed element-wise validator 0:
<Check isin: isin({'MALE', 'FEMALE'})>
failure cases:
   index failure_case
0    259            .
In [10]:
X_test.Sex.unique()
Out[10]:
array(['FEMALE', 'MALE', '.'], dtype=object)
In [11]:
X_test.loc[259]
Out[11]:
Culmen Length (mm)      44.5
Culmen Depth (mm)       15.7
Flipper Length (mm)    217.0
Sex                        .
Name: 259, dtype: object

Can you fix the data to conform to the schema?

In [ ]:
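
One possible fix, as a sketch (not the only valid approach): keep only the rows whose Sex value is one of the categories the schema expects, align the targets accordingly, and re-run the validation.

# drop the row with the stray "." entry, keep y_test in sync, then validate again
X_test_fixed = X_test[X_test["Sex"].isin(["MALE", "FEMALE"])]
y_test_fixed = y_test.loc[X_test_fixed.index]
validated_test = schema(X_test_fixed)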