Interpretability & Model Inspection

One way to probe the models we build is to test them against the established knowledge of domain experts. In this final section, we’ll explore how to build intuitions about our machine learning model and avoid pitfalls like spurious correlations. These methods for model interpretability increase our trust in models, and they can also serve as an additional level of reproducibility in our research and as a valuable research artefact that can be discussed in a publication.

This part of the tutorial will also discuss why the feature importance of tree-based methods can serve as a starting point but often shouldn’t be used as the sole source of truth for feature interpretation in applied research.

This section introduces tools like shap and discusses feature importance and manual inspection of models.

In [1]:
import pandas as pd
penguins = pd.read_csv('../data/penguins_clean.csv')
penguins.head()
Out[1]:
Culmen Length (mm) Culmen Depth (mm) Flipper Length (mm) Sex Species
0 39.1 18.7 181.0 MALE Adelie Penguin (Pygoscelis adeliae)
1 39.5 17.4 186.0 FEMALE Adelie Penguin (Pygoscelis adeliae)
2 40.3 18.0 195.0 FEMALE Adelie Penguin (Pygoscelis adeliae)
3 36.7 19.3 193.0 FEMALE Adelie Penguin (Pygoscelis adeliae)
4 39.3 20.6 190.0 MALE Adelie Penguin (Pygoscelis adeliae)
In [2]:
from sklearn.model_selection import train_test_split
num_features = ["Culmen Length (mm)", "Culmen Depth (mm)", "Flipper Length (mm)"]
cat_features = ["Sex"]
features = num_features + cat_features
target = ["Species"]

X_train, X_test, y_train, y_test = train_test_split(penguins[features], penguins[target[0]], stratify=penguins[target[0]], train_size=.7, random_state=42)
In [3]:
from sklearn.svm import SVC
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from joblib import dump, load

model = load("../model/svc.joblib")
model.score(X_test, y_test)
Out[3]:
1.0

Scikit-Learn Inspection

In [4]:
from sklearn.inspection import PartialDependenceDisplay, partial_dependence

pd_results = partial_dependence(model, X_train.sample(20), num_features)
pd_results
Out[4]:
{'average': array([[[[ 2.12937783,  2.13588149,  2.13431911, ...,  0.94401185,
            0.93124713,  0.90303554],
          [ 2.18109075,  2.19164012,  2.19269231, ...,  1.04926325,
            1.02141138,  0.92065486],
          [ 2.18485349,  2.19556478,  2.19675354, ...,  1.05945926,
            1.03227387,  0.92337057],
          ...,
          [ 2.22958618,  2.24154014,  2.24414834, ...,  2.19277505,
            2.18131887,  2.07452034],
          [ 2.22710438,  2.23927261,  2.24193387, ...,  2.19035734,
            2.17893451,  2.07365359],
          [ 2.21335823,  2.22660656,  2.22953382, ...,  2.17622751,
            2.16481618,  1.66537095]],
 
         [[ 2.12778226,  2.13371169,  2.13180218, ...,  0.93694443,
            0.92496161,  0.89955536],
          [ 2.18029755,  2.19063018,  2.19154947, ...,  1.04081841,
            1.01235447,  0.91629588],
          [ 2.1841219 ,  2.1946356 ,  2.19570302, ...,  1.05169382,
            1.02320692,  0.91892149],
          ...,
          [ 2.23019371,  2.24205983,  2.24463508, ...,  2.19277393,
            2.18121323,  2.0732167 ],
          [ 2.22775072,  2.23982688,  2.24245353, ...,  2.19039193,
            2.17886857,  2.07245779],
          [ 2.21416571,  2.22730511,  2.23019004, ...,  2.17636336,
            2.16485928,  1.66448574]],
 
         [[ 2.0933227 ,  2.08834199,  2.07963557, ...,  0.87451617,
            0.86922246,  0.46814362],
          [ 2.15747693,  2.16380738,  2.16215833, ...,  0.93049757,
            0.9131985 ,  0.87546876],
          [ 2.16227886,  2.16918736,  2.16792252, ...,  0.93990757,
            0.92042814,  0.87705015],
          ...,
          [ 2.22852453,  2.23978488,  2.24202219, ...,  2.17933384,
            2.16573029,  1.6407549 ],
          [ 2.22643811,  2.23785513,  2.24012504, ...,  2.17744381,
            2.16397628,  1.64185637],
          [ 2.21423305,  2.22647455,  2.22890919, ...,  2.16474195,
            2.15159038,  1.63970881]],
 
         ...,
 
         [[-0.13939651, -0.16352771, -0.17212705, ..., -0.19622255,
           -0.19416361, -0.1672551 ],
          [-0.15663238, -0.18096657, -0.18935406, ..., -0.21241968,
           -0.21058695, -0.18575215],
          [-0.15769632, -0.18203336, -0.19040508, ..., -0.21343239,
           -0.21161983, -0.18697688],
          ...,
          [ 0.89712525,  0.86421085,  0.8511747 , ..., -0.19753959,
           -0.19620063, -0.17258012],
          [ 0.90349904,  0.8701319 ,  0.85670111, ..., -0.19436011,
           -0.19302766, -0.16918082],
          [ 0.92912358,  0.89522807,  0.88058214, ...,  0.82067065,
           -0.17803635, -0.15351811]],
 
         [[-0.13502782, -0.15822406, -0.16651772, ..., -0.18971657,
           -0.18774822, -0.16236429],
          [-0.15475578, -0.1783551 , -0.18647553, ..., -0.20822163,
           -0.20641802, -0.18260141],
          [-0.15610422, -0.17971   , -0.18781179, ..., -0.20944927,
           -0.20765969, -0.18398996],
          ...,
          [ 0.87392191,  0.8440517 ,  0.83299236, ..., -0.20174718,
           -0.19992807, -0.17456929],
          [ 0.8790075 ,  0.84876702,  0.83743698, ..., -0.19874486,
           -0.19690603, -0.17119563],
          [ 0.90082383,  0.86994997,  0.85773928, ..., -0.18418439,
           -0.18228949, -0.15545296]],
 
         [[-0.0922914 , -0.10914678, -0.11573867, ..., -0.14515763,
           -0.14488201, -0.13352932],
          [-0.11250344, -0.1318902 , -0.13914526, ..., -0.16935395,
           -0.1690744 , -0.15745806],
          [-0.11415773, -0.13370961, -0.14100056, ..., -0.17114985,
           -0.17086481, -0.15922215],
          ...,
          [ 0.87727714,  0.85466366,  0.44646934, ..., -0.17727546,
           -0.17573322, -0.15587262],
          [ 0.88015044,  0.85759127,  0.84937087, ..., -0.17424065,
           -0.17262199, -0.15221639],
          [ 0.89365987,  0.87181146,  0.86362408, ..., -0.15942958,
           -0.15754432, -0.13537563]]],
 
 
        [[[ 0.9456012 ,  0.31121148,  0.29615171, ..., -0.16484704,
           -0.16255529, -0.12453604],
          [ 0.91325481,  0.88221099,  0.86920217, ..., -0.18132519,
           -0.17905893, -0.14229432],
          [ 0.91031128,  0.87975795,  0.8670146 , ..., -0.18222054,
           -0.17994743, -0.14328361],
          ...,
          [ 0.85224118,  0.84126165,  0.83967854, ...,  0.88759024,
            0.29354657, -0.06561701],
          [ 0.85442463,  0.8437751 ,  0.84243816, ...,  0.89539025,
            0.90165147, -0.0574607 ],
          [ 0.86729763,  0.85768117,  0.85714136, ...,  0.92801893,
            0.93535636,  0.77494813]],
 
         [[ 0.94847654,  0.31287971,  0.29717536, ..., -0.16730339,
           -0.16513319, -0.1276644 ],
          [ 0.91686533,  0.88476848,  0.8711418 , ..., -0.18331114,
           -0.18120406, -0.14527399],
          [ 0.91394951,  0.8823741 ,  0.86902354, ..., -0.18415824,
           -0.18204779, -0.14624148],
          ...,
          [ 0.85333554,  0.8428695 ,  0.84152794, ...,  0.88999834,
            0.29571258, -0.06541161],
          [ 0.85537943,  0.84526208,  0.84418295, ...,  0.89807718,
            0.90411155, -0.05698239],
          [ 0.86770949,  0.85866203,  0.85842069, ...,  0.93161986,
            0.93881557,  0.7763062 ]],
 
         [[ 0.9986713 ,  0.35042535,  0.32509232, ..., -0.18406871,
           -0.18313217,  0.24815913],
          [ 0.98677561,  0.944964  ,  0.92014708, ..., -0.19495411,
           -0.19453515, -0.16742819],
          [ 0.98522946,  0.94439402,  0.91999346, ..., -0.19525176,
           -0.19487923, -0.1681432 ],
          ...,
          [ 0.89043479,  0.89440067,  0.90048811, ...,  0.94829041,
            0.94932913,  0.35341264],
          [ 0.88971661,  0.89378307,  0.90025711, ...,  0.9576013 ,
            0.9593107 ,  0.36469209],
          [ 0.89193146,  0.89554737,  0.90247798, ...,  0.99138463,
            0.99561615,  0.80558535]],
 
         ...,
 
         [[ 2.13917929,  1.53448898,  1.52580736, ...,  0.86851122,
            0.86094317,  0.8542783 ],
          [ 2.17915581,  2.18304945,  2.17990093, ...,  0.90222135,
            0.88329529,  0.84498538],
          [ 2.18219544,  2.18673204,  2.18401158, ...,  0.90911942,
            0.88817022,  0.84521273],
          ...,
          [ 2.20815173,  2.23054615,  2.23734195, ...,  2.22828299,
            2.22139284,  1.53914456],
          [ 2.20453141,  2.22783848,  2.2349992 , ...,  2.22883192,
            2.22230922,  2.14480905],
          [ 2.18578944,  2.21291568,  2.22166636, ...,  2.22601796,
            2.22053311,  2.15623786]],
 
         [[ 2.13665727,  1.53385525,  1.52675307, ...,  0.88336789,
            0.87459411,  0.86203754],
          [ 2.17685886,  2.18176338,  2.17940473, ...,  0.91958472,
            0.8982956 ,  0.85151484],
          [ 2.17995437,  2.18543654,  2.18344978, ...,  0.9268521 ,
            0.90338775,  0.85162224],
          ...,
          [ 2.21017579,  2.23135041,  2.23771857, ...,  2.2259057 ,
            2.21872701,  1.53331428],
          [ 2.20682918,  2.22879914,  2.23548748, ...,  2.22637167,
            2.21955825,  2.13904724],
          [ 2.18929119,  2.21459662,  2.22267882, ...,  2.22329208,
            2.21751988,  2.15080534]],
 
         [[ 2.10761195,  2.1073407 ,  2.10331269, ...,  0.93337326,
            0.92302738,  0.89579821],
          [ 2.14359774,  2.14821051,  2.14643129, ...,  0.94716437,
            0.92843629,  0.87612071],
          [ 2.14659672,  2.15163427,  2.15007627, ...,  0.95041589,
            0.93056309,  0.87517756],
          ...,
          [ 2.1840647 ,  2.20169643,  2.20681088, ...,  2.1732663 ,
            2.16259814,  1.05030086],
          [ 2.18104758,  2.19915424,  2.20449625, ...,  2.17430956,
            2.16425009,  1.05981559],
          [ 2.16521945,  2.18506778,  2.19130096, ...,  2.17271064,
            2.1645457 ,  1.08364663]]],
 
 
        [[[-0.10100535,  0.51938892,  0.54069389, ...,  2.18059831,
            2.18276677,  2.16732636],
          [-0.15065574, -0.14673823, -0.13761174, ...,  2.16808817,
            2.17378345,  2.17165632],
          [-0.15440424, -0.15159354, -0.14328142, ...,  2.16559077,
            2.17177813,  2.17130787],
          ...,
          [-0.1938452 , -0.20912104, -0.21292044, ..., -0.15227624,
            0.46215345,  0.98881777],
          [-0.19071684, -0.20639855, -0.21039019, ..., -0.15366647,
           -0.14005564,  0.97854307],
          [-0.17456706, -0.19178741, -0.19651355, ..., -0.15208609,
           -0.1408919 ,  0.55350066]],
 
         [[-0.10073424,  0.5214086 ,  0.54428564, ...,  2.18477562,
            2.18681591,  2.17101141],
          [-0.15122325, -0.1465607 , -0.13683228, ...,  2.17302706,
            2.17845249,  2.17556555],
          [-0.15502833, -0.15153299, -0.14267434, ...,  2.17063111,
            2.17652857,  2.17524151],
          ...,
          [-0.19535925, -0.21064439, -0.21444591, ..., -0.15361701,
            0.4609786 ,  0.99076829],
          [-0.1922383 , -0.20793951, -0.21193684, ..., -0.15512573,
           -0.1413816 ,  0.97982857],
          [-0.17608277, -0.19336846, -0.19811593, ..., -0.1538311 ,
           -0.14257528,  0.55324891]],
 
         [[-0.09062905,  0.55463893,  0.59799051, ...,  2.21634996,
            2.21743739,  2.19975332],
          [-0.15027273, -0.13643502, -0.11873492, ...,  2.21034855,
            2.21373474,  2.20593408],
          [-0.15472258, -0.14287274, -0.12685665, ...,  2.20875256,
            2.21245407,  2.20580993],
          ...,
          [-0.20555316, -0.22118297, -0.22510312, ..., -0.16360686,
           -0.14763837,  1.01107437],
          [-0.20253635, -0.21867332, -0.22282291, ..., -0.16634541,
           -0.15157052,  0.99461522],
          [-0.18650136, -0.20462633, -0.20967064, ..., -0.16805658,
           -0.15648282,  0.55119206]],
 
         ...,
 
         [[ 0.99602501,  1.66540232,  1.69767519, ...,  2.22579077,
            2.22663567,  2.21357459],
          [ 0.92560033,  0.98401269,  1.02370393, ...,  2.22971187,
            2.23195196,  2.2267657 ],
          [ 0.91924178,  0.97380474,  1.01374867, ...,  2.2292921 ,
            2.23170387,  2.22732835],
          ...,
          [-0.1821708 , -0.20165121, -0.20696587, ...,  0.86080033,
            0.88215819,  1.68684699],
          [-0.17958281, -0.19976333, -0.20541523, ...,  0.85298784,
            0.87166849,  1.06887361],
          [-0.16452217, -0.18724042, -0.19413456, ..., -0.1612597 ,
            0.85119212,  0.99137148]],
 
         [[ 0.99243157,  1.65461573,  1.68534939, ...,  2.2171798 ,
            2.21840237,  2.20768478],
          [ 0.92867761,  0.98106963,  1.017354  , ...,  2.2226541 ,
            2.22532921,  2.22260944],
          [ 0.92282734,  0.97183211,  1.00844287, ...,  2.22236725,
            2.22521815,  2.22331947],
          ...,
          [-0.17422199, -0.19322672, -0.1983623 , ...,  0.87898663,
            0.90312433,  1.70011296],
          [-0.1716048 , -0.19130494, -0.19678708, ...,  0.86990611,
            0.89084076,  1.08436177],
          [-0.15645782, -0.17859786, -0.18532733, ...,  0.85244168,
            0.86598021,  1.01576324]],
 
         [[ 0.97121412,  1.00192438,  1.02131595, ...,  2.16797948,
            2.17175602,  2.17570719],
          [ 0.93370726,  0.95800488,  0.9799719 , ...,  2.18276901,
            2.18803081,  2.19859842],
          [ 0.93015282,  0.95329494,  0.97478093, ...,  2.18336134,
            2.18878191,  2.20000225],
          ...,
          [-0.13146693, -0.14385753,  0.25387582, ...,  1.01785812,
            1.04708837,  2.1363593 ],
          [-0.12861337, -0.14159155, -0.1442291 , ...,  1.0018736 ,
            1.03177279,  2.12708193],
          [-0.11327957, -0.12800249, -0.13185768, ...,  0.95479379,
            0.97659007,  2.08633669]]]]),
 'values': [array([36.3, 36.6, 39.5, 41.4, 43.6, 44.1, 45.5, 46.1, 46.5, 47.8, 49.8,
         50.2, 50.5, 50.8, 51.3, 52. , 53.5, 59.6]),
  array([13.9, 15. , 15.1, 16.7, 17. , 17.3, 17.8, 17.9, 18.2, 18.4, 18.5,
         18.6, 18.8, 19.2, 19.5, 19.7, 19.9, 20.7]),
  array([178., 185., 188., 190., 191., 192., 193., 196., 197., 198., 200.,
         201., 202., 205., 210., 215., 217., 230.])]}
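
The raw output above is hard to read. Because three features were passed, the model was evaluated on the Cartesian product of the three grids in 'values', with one block of averages per class. To build intuition for what each number means, here is a minimal sketch of the brute-force idea for a single feature. It uses NumPy only; `toy_model` and `manual_partial_dependence` are hypothetical stand-ins for illustration and are not part of scikit-learn.

```python
import numpy as np


def toy_model(X):
    # Hypothetical stand-in for a fitted model's prediction function.
    return 2.0 * X[:, 0] + X[:, 1] ** 2


def manual_partial_dependence(predict, X, feature, grid):
    """Average prediction when one feature is forced to each grid value."""
    averages = []
    for value in grid:
        X_mod = X.copy()
        X_mod[:, feature] = value  # overwrite the feature for every row
        averages.append(predict(X_mod).mean())
    return np.array(averages)


rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))
grid = np.linspace(-1.0, 1.0, 5)

pd_curve = manual_partial_dependence(toy_model, X, feature=0, grid=grid)
print(pd_curve)
```

Since the toy model is linear in the first feature with slope 2, the curve rises by 2 times the grid spacing at each step, while the contribution of the second feature is averaged out into a constant offset.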
In [5]:
PartialDependenceDisplay.from_estimator(model, X_train, [0,1,2], target=list(y_train.unique())[0])
Out[5]:
<sklearn.inspection._plot.partial_dependence.PartialDependenceDisplay at 0x21e4dbf8580>
In [6]:
PartialDependenceDisplay.from_estimator(model, X_train, [0,1,2], target=list(y_train.unique())[1])
Out[6]:
<sklearn.inspection._plot.partial_dependence.PartialDependenceDisplay at 0x21e4ea84370>
In [7]:
PartialDependenceDisplay.from_estimator(model, X_train, [0,1,2], target=list(y_train.unique())[2])
Out[7]:
<sklearn.inspection._plot.partial_dependence.PartialDependenceDisplay at 0x21e4eac7af0>

Tree importance vs Permutation importance

In [8]:
from sklearn.ensemble import RandomForestClassifier

num_transformer = StandardScaler()
cat_transformer = OneHotEncoder(handle_unknown='ignore')

preprocessor = ColumnTransformer(transformers=[
    ('num', num_transformer, num_features),
    ('cat', cat_transformer, cat_features)
])

rf = Pipeline(steps=[
    ('preprocessor', preprocessor),
    ('classifier', RandomForestClassifier()),
])

rf.fit(X_train, y_train)
rf.score(X_test, y_test)
Out[8]:
0.9900990099009901
In [9]:
pd.Series(rf.named_steps["classifier"].feature_importances_, index=num_features+['F', 'M']).plot.bar()
Out[9]:
<AxesSubplot:>
In [10]:
from sklearn.inspection import permutation_importance

result = permutation_importance(
    rf, X_test, y_test, n_repeats=10, random_state=42
)

pd.Series(result.importances_mean, index=features).plot.bar()
Out[10]:
<AxesSubplot:>
In [11]:
result = permutation_importance(
    model, X_test, y_test, n_repeats=10, random_state=42
)

pd.Series(result.importances_mean, index=features).plot.bar()
Out[11]:
<AxesSubplot:>
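
Permutation importance measures how much a score drops when one input column is shuffled, which breaks that column's relationship with the target while leaving the model untouched. The following is a minimal NumPy sketch of that idea, not scikit-learn's implementation; `toy_predict`, `accuracy` and `manual_permutation_importance` are hypothetical names chosen for illustration.

```python
import numpy as np


def accuracy(predict, X, y):
    return float((predict(X) == y).mean())


def manual_permutation_importance(predict, X, y, n_repeats=10, seed=42):
    """Mean drop in accuracy after shuffling each column in turn."""
    rng = np.random.default_rng(seed)
    baseline = accuracy(predict, X, y)
    importances = np.zeros(X.shape[1])
    for j in range(X.shape[1]):
        drops = []
        for _ in range(n_repeats):
            X_perm = X.copy()
            X_perm[:, j] = rng.permutation(X_perm[:, j])  # shuffle one column
            drops.append(baseline - accuracy(predict, X_perm, y))
        importances[j] = np.mean(drops)
    return importances


def toy_predict(X):
    # Hypothetical "model" that only looks at the first column.
    return (X[:, 0] > 0).astype(int)


rng = np.random.default_rng(0)
X = rng.normal(size=(500, 2))
y = (X[:, 0] > 0).astype(int)

importances = manual_permutation_importance(toy_predict, X, y)
print(importances)
```

The second column is never used by the toy model, so shuffling it cannot change the score and its importance is exactly zero. Shuffling the first column pushes the accuracy towards chance level.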

Shap Inspection

In [12]:
import shap

rf = RandomForestClassifier()
rf.fit(X_train[num_features], y_train)

explainer = shap.Explainer(rf)
explainer
Out[12]:
<shap.explainers._tree.Tree at 0x21e4ee10040>
In [13]:
shap_values = explainer.shap_values(X_test[num_features])
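
The attributions computed by shap are based on Shapley values. As a hedged illustration of the underlying definition, the sketch below computes exact Shapley values by brute force for a three-feature toy function. This is not how shap's tree explainer works internally (it uses a much faster tree-specific algorithm); `toy_model`, `coalition_value` and `exact_shapley` are hypothetical names, and features outside a coalition are assumed to be filled in from a background sample.

```python
from itertools import combinations
from math import factorial

import numpy as np


def toy_model(X):
    # Hypothetical linear stand-in for a fitted model.
    return 2.0 * X[:, 0] + X[:, 1] - 3.0 * X[:, 2]


def coalition_value(predict, x, background, subset):
    """Mean prediction when features in `subset` are fixed to the values in x."""
    X_mix = background.copy()
    idx = list(subset)
    if idx:
        X_mix[:, idx] = x[idx]
    return predict(X_mix).mean()


def exact_shapley(predict, x, background):
    n = x.shape[0]
    phi = np.zeros(n)
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Shapley weight for coalitions of this size
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                with_i = coalition_value(predict, x, background, subset + (i,))
                without_i = coalition_value(predict, x, background, subset)
                phi[i] += weight * (with_i - without_i)
    return phi


rng = np.random.default_rng(0)
background = rng.normal(size=(50, 3))
x = np.array([1.0, -2.0, 0.5])

phi = exact_shapley(toy_model, x, background)
expected_value = toy_model(background).mean()
prediction = toy_model(x[None, :])[0]
print(phi, expected_value, prediction)
```

The attributions add up to the difference between the prediction for `x` and the average prediction on the background data. This additivity is what a force plot visualises: it starts at the expected value and lets each feature push the output towards the final prediction.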
In [14]:
shap.initjs()
In [15]:
shap.force_plot(explainer.expected_value[0], shap_values[0][0], feature_names=num_features)
In [16]:
shap.force_plot(explainer.expected_value[0], shap_values[0], feature_names=num_features)

Model Inspection

Several tools help verify that a model is doing what it's supposed to do. Scikit-learn classifiers mostly work out of the box, so we rarely have to debug the models themselves.

Some scikit-learn estimators apply regularization by default, so sometimes we have to switch it off to reach the model state we expect.
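
As a small illustration, `LogisticRegression` applies an L2 penalty with `C=1.0` by default. The sketch below uses synthetic data (an assumption for illustration, not the penguins dataset) and a very large `C` to make the penalty negligible, then compares the size of the fitted coefficients. Recent scikit-learn releases also accept `penalty=None`; the large-`C` variant is used here because it does not depend on the installed version.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Synthetic, noisy two-class problem (illustrative data only)
X, y = make_classification(
    n_samples=60, n_features=4, n_informative=2, n_redundant=0,
    flip_y=0.2, random_state=0,
)

# Default settings: L2 penalty with C=1.0
default_clf = LogisticRegression(max_iter=5000).fit(X, y)

# Very large C: the penalty term becomes negligible
weak_clf = LogisticRegression(C=1e6, max_iter=5000).fit(X, y)

norm_default = np.linalg.norm(default_clf.coef_)
norm_weak = np.linalg.norm(weak_clf.coef_)
print(norm_default, norm_weak)
```

The regularized coefficients are shrunk towards zero, so a model fitted with default settings may not match the result expected from a textbook, unpenalized fit.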

In neural networks we are working with many moving parts. A practical first step is to overfit the network on a small batch of data. This confirms that the model is capable of learning and that all the connections are made as expected. It works as a first-order sanity check that the model is functioning.
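
The check can be sketched without a deep learning framework. The example below is a minimal illustration that uses NumPy instead of PyTorch, because torch is not installed in this environment. It trains a tiny one-hidden-layer network with plain gradient descent on a single batch of eight samples with arbitrary labels and records the loss. The architecture, learning rate and step count are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# One small batch: 8 samples, 5 features, arbitrary binary labels
X = rng.normal(size=(8, 5))
y = np.array([0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0])

# Tiny network: 5 inputs -> 16 tanh units -> 1 logit
W1 = rng.normal(scale=0.5, size=(5, 16))
b1 = np.zeros(16)
W2 = rng.normal(scale=0.5, size=16)
b2 = 0.0


def forward(X):
    H = np.tanh(X @ W1 + b1)
    z = H @ W2 + b2
    return H, z


def bce_loss(z, y):
    # Numerically stable binary cross-entropy on logits
    return float(np.mean(np.logaddexp(0.0, z) - y * z))


lr = 0.1
losses = []
for step in range(3000):
    H, z = forward(X)
    losses.append(bce_loss(z, y))

    # Backpropagation through the two layers
    p = 1.0 / (1.0 + np.exp(-z))
    dz = (p - y) / len(y)
    dH = np.outer(dz, W2)
    dA = dH * (1.0 - H ** 2)

    W2 -= lr * (H.T @ dz)
    b2 -= lr * dz.sum()
    W1 -= lr * (X.T @ dA)
    b1 -= lr * dA.sum(axis=0)

_, z = forward(X)
train_accuracy = float(((z > 0) == (y == 1)).mean())
print(losses[0], losses[-1], train_accuracy)
```

If the loss on such a tiny batch refuses to go down, the problem is more likely in the wiring (shapes, loss, gradients, data pipeline) than in the model capacity.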

A more in-depth solution for PyTorch is PyTorch Surgeon, which can be used to extract submodels of the complete architecture for debugging purposes.

Some example code from the PyTorch Surgeon docs (torch and surgeon are not installed to save space):

In [17]:
import torch
import torch.nn as nn
from surgeon_pytorch import Extract, get_nodes

class SomeModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(5, 3)
        self.layer2 = nn.Linear(3, 2)
        self.layer3 = nn.Linear(2, 1)

    def forward(self, x):
        x1 = torch.relu(self.layer1(x))
        x2 = torch.sigmoid(self.layer2(x1))
        y = self.layer3(x2).tanh()
        return y

model = SomeModel()
print(get_nodes(model)) # ['x', 'layer1', 'relu', 'layer2', 'sigmoid', 'layer3', 'tanh']
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
Input In [17], in <cell line: 1>()
----> 1 import torch
      2 import torch.nn as nn
      3 from surgeon_pytorch import Extract, get_nodes

ModuleNotFoundError: No module named 'torch'

This enables us to extract the model at one of the nodes above:

In [18]:
model_ext = Extract(model, node_out='sigmoid')
x = torch.rand(1, 5)
sigmoid = model_ext(x)
print(sigmoid) # tensor([[0.5570, 0.3652]], grad_fn=<SigmoidBackward0>)
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
Input In [18], in <cell line: 1>()
----> 1 model_ext = Extract(model, node_out='sigmoid')
      2 x = torch.rand(1, 5)
      3 sigmoid = model_ext(x)

NameError: name 'Extract' is not defined