.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/Interpretation.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_Interpretation.py:


=====================
6. Interpretation
=====================

This notebook applies several post-hoc interpretation methods to our model.
To do so, we rebuild and retrain the decision tree model, and then apply
SHAP, partial dependence plots (PDP) and accumulated local effects (ALE)
to the trained model.

.. GENERATED FROM PYTHON SOURCE LINES 10-41

.. code-block:: Python


    import numpy as np
    import pandas as pd
    #np.bool = np.bool_
    import shap

    import matplotlib as mpl
    import matplotlib.pyplot as plt
    import matplotlib.colors as mcolors

    from easy_mpl import pie
    from easy_mpl import bar_chart
    from easy_mpl.utils import create_subplots

    from shap.plots import waterfall
    from shap import summary_plot, Explanation

    from ai4water import Model
    from ai4water.utils.utils import TrainTestSplit
    from ai4water.postprocessing import PartialDependencePlot

    from utils import LABEL_MAP
    from utils import version_info
    from utils import shap_scatter
    from utils import make_classes
    from utils import shap_scatter_plots
    from utils import prepare_data, set_rcParams, plot_ale, SAVE

.. GENERATED FROM PYTHON SOURCE LINES 42-46

.. code-block:: Python


    for lib, ver in version_info().items():
        print(lib, ver)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    python 3.12.10 (main, May 6 2025, 10:49:23) [GCC 11.4.0]
    os posix
    ai4water 1.07
    lightgbm 4.6.0
    catboost 1.2.10
    xgboost 3.2.0
    easy_mpl 0.21.5
    SeqMetrics 2.0.0
    numpy 1.26.4
    pandas 2.2.3
    matplotlib 3.10.8
    h5py 3.16.0
    sklearn 1.3.1
    optuna 4.8.0
    skopt 0.10.2
    plotly 6.6.0
    seaborn 0.13.2
    crepes 0.9.0
    mapie 0.9.2
    shap 0.49.1
    scipy 1.17.1

.. GENERATED FROM PYTHON SOURCE LINES 47-50

.. code-block:: Python


    set_rcParams()

.. GENERATED FROM PYTHON SOURCE LINES 51-59

.. code-block:: Python


    inputs = ['Solution pH', 'Time (m)', 'Anions', 'Ni (At%)', 'HA (mg/L)',
              'loading (g)', 'Pore size (nm)', 'O (At%)', 'Light intensity (watt)',
              'Mo (At%)', 'Dye concentration (mg/L)']

    data, encoders = prepare_data(inputs=inputs, outputs="k")
    print(data.shape)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    (1527, 12)

.. GENERATED FROM PYTHON SOURCE LINES 60-80

.. code-block:: Python


    input_features = data.columns.tolist()[0:-1]
    output_features = data.columns.tolist()[-1:]

    TrainX, TestX, TrainY, TestY = TrainTestSplit(seed=313).split_by_random(
        data[input_features],
        data[output_features]
    )

    print(TrainX.shape, TrainY.shape, TestX.shape, TestY.shape)

    model = Model(
        model = "DecisionTreeRegressor",
        input_features=input_features,
        output_features=output_features,
        verbosity=-1,
    )

    model.fit(TrainX, TrainY.values)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    (1068, 11) (1068, 1) (459, 11) (459, 1)

.. raw:: html
DecisionTreeRegressor(random_state=313)


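
The shapes printed above (1527 rows split into 1068 training and 459 test rows) are consistent with a random 70/30 split in which the test size is rounded up, as scikit-learn's ``train_test_split`` does. The sketch below reproduces that arithmetic under this assumption; the 30% fraction and the shuffling scheme are our assumptions, ``random_split`` is a hypothetical helper, and it does not reproduce the exact rows that ai4water's ``TrainTestSplit`` selects.

.. code-block:: Python

    import math

    import numpy as np


    def random_split(n_samples, test_fraction=0.3, seed=313):
        """Return shuffled (train, test) row indices; test size is rounded up (assumed 70/30 split)."""
        n_test = math.ceil(test_fraction * n_samples)
        rng = np.random.default_rng(seed)  # not the same permutation that ai4water draws
        idx = rng.permutation(n_samples)
        return idx[n_test:], idx[:n_test]


    train_idx, test_idx = random_split(1527)
    print(len(train_idx), len(test_idx))

With 1527 rows, ``ceil(0.3 * 1527)`` gives 459 test rows and leaves 1068 for training, matching the shapes above.
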
.. GENERATED FROM PYTHON SOURCE LINES 81-83

.. code-block:: Python


    train_p = model.predict(TrainX, process_results=False)

.. GENERATED FROM PYTHON SOURCE LINES 84-86

.. code-block:: Python


    test_p = model.predict(TestX, process_results=False)

.. GENERATED FROM PYTHON SOURCE LINES 87-88

Average prediction on the training data:

.. GENERATED FROM PYTHON SOURCE LINES 88-90

.. code-block:: Python


    print(train_p.mean())

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    0.006940225438833623

.. GENERATED FROM PYTHON SOURCE LINES 91-92

Default (impurity-based) feature importances of the decision tree:

.. GENERATED FROM PYTHON SOURCE LINES 92-95

.. code-block:: Python


    print(model._model.feature_importances_)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    [7.40882935e-02 1.94588607e-01 9.18145954e-02 3.19367156e-01
     3.43836446e-02 9.41130328e-02 5.70019588e-03 1.71579275e-02
     3.37566239e-02 2.51049444e-04 1.34778874e-01]

.. GENERATED FROM PYTHON SOURCE LINES 96-103

.. code-block:: Python


    bar_chart(model._model.feature_importances_,
              [LABEL_MAP[n] if n in LABEL_MAP else n for n in model.input_features],
              sort=True,
              show=False)
    plt.tight_layout()
    plt.show()

.. image-sg:: /auto_examples/images/sphx_glr_Interpretation_001.png
   :alt: Interpretation
   :srcset: /auto_examples/images/sphx_glr_Interpretation_001.png, /auto_examples/images/sphx_glr_Interpretation_001_2_00x.png 2.00x
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 104-106

SHAP
======

.. GENERATED FROM PYTHON SOURCE LINES 106-112

.. code-block:: Python


    exp = shap.TreeExplainer(model=model._model,
                             data=TrainX,
                             feature_names=input_features)
    print(exp.expected_value)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    0.0066904566801259955

.. GENERATED FROM PYTHON SOURCE LINES 113-126

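
With a background dataset (``data=TrainX``), TreeExplainer uses the interventional formulation of Shapley values: the value of a feature subset S is the average prediction obtained when the features in S are fixed to the explained row and the remaining features are taken from background rows. The base value (``expected_value``) is then the mean prediction over the background rows the explainer uses, and each row's SHAP values add up to the difference between its prediction and that base value. Recent shap versions appear to subsample large background sets, which may explain why the printed base value differs slightly from the average training prediction; we have not verified this. The brute-force sketch below illustrates the definition on a hypothetical three-feature ``toy_model``. It enumerates every subset, so it is only practical for a handful of features, whereas TreeExplainer exploits the tree structure to avoid that cost.

.. code-block:: Python

    from itertools import combinations
    from math import factorial

    import numpy as np


    def subset_value(f, x, background, subset):
        """Interventional v(S): fix the features in S to x, take the others from background rows."""
        X = background.copy()
        cols = list(subset)
        if cols:
            X[:, cols] = x[cols]
        return f(X).mean()


    def exact_shap(f, x, background):
        """Brute-force Shapley values for one row; cost grows exponentially with the feature count."""
        n = x.shape[0]
        phi = np.zeros(n)
        for i in range(n):
            others = [j for j in range(n) if j != i]
            for k in range(len(others) + 1):
                # Shapley weight |S|! (n - |S| - 1)! / n!
                weight = factorial(k) * factorial(n - k - 1) / factorial(n)
                for S in combinations(others, k):
                    gain = (subset_value(f, x, background, S + (i,))
                            - subset_value(f, x, background, S))
                    phi[i] += weight * gain
        return phi


    def toy_model(X):
        # hypothetical model: additive effect of x0 plus an x1 * x2 interaction
        return 2.0 * X[:, 0] + X[:, 1] * X[:, 2]


    rng = np.random.default_rng(0)
    background = rng.normal(size=(50, 3))
    x = np.array([1.0, 2.0, -1.0])

    phi = exact_shap(toy_model, x, background)
    base = toy_model(background).mean()  # plays the role of exp.expected_value
    print(np.isclose(base + phi.sum(), toy_model(x[None, :])[0]))  # → True

Because ``x0`` enters the toy model only additively, its SHAP value reduces to ``2 * (x0 - mean(x0 over the background))``; the interaction term is shared between ``x1`` and ``x2``.
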
.. code-block:: Python


    shap_values = exp.shap_values(TrainX, TrainY)

    summary_plot(shap_values, TrainX,
                 max_display=34,
                 feature_names=[LABEL_MAP[n] if n in LABEL_MAP else n for n in input_features],
                 show=False)
    plt.tight_layout()
    if SAVE:
        plt.savefig("results/figures/shap_summary.png", dpi=600, bbox_inches="tight")
    plt.show()

.. image-sg:: /auto_examples/images/sphx_glr_Interpretation_002.png
   :alt: Interpretation
   :srcset: /auto_examples/images/sphx_glr_Interpretation_002.png, /auto_examples/images/sphx_glr_Interpretation_002_2_00x.png 2.00x
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 127-138

.. code-block:: Python


    sv_bar = np.mean(np.abs(shap_values), axis=0)

    classes, colors, colors_ = make_classes(exp)

    df_with_classes = pd.DataFrame({'features': exp.feature_names,
                                    'classes': classes,
                                    'mean_shap': sv_bar})

    print(df_with_classes)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

                        features                     classes  mean_shap
    0                Solution pH     Experimental Conditions   0.001051
    1                   Time (m)     Experimental Conditions   0.001239
    2                     Anions     Experimental Conditions   0.000390
    3                   Ni (At%)          Atomic Composition   0.003584
    4                  HA (mg/L)     Experimental Conditions   0.000281
    5                loading (g)     Experimental Conditions   0.001199
    6             Pore size (nm)  Physicochemical Properties   0.000322
    7                    O (At%)          Atomic Composition   0.000349
    8     Light intensity (watt)     Experimental Conditions   0.000395
    9                   Mo (At%)          Atomic Composition   0.000014
    10  Dye concentration (mg/L)     Experimental Conditions   0.001875

.. GENERATED FROM PYTHON SOURCE LINES 139-170

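
The impurity-based importances printed earlier and the mean |SHAP| values in this table measure different things: the former record how much each feature reduced impurity while the tree was grown, the latter how far each feature moves individual predictions. A quick way to check whether they rank the features similarly is a Spearman rank correlation. The sketch below applies it to the rounded values printed above, in input-feature order; the ``spearman`` helper is our own and assumes there are no ties.

.. code-block:: Python

    import numpy as np


    def spearman(a, b):
        """Spearman rank correlation for tie-free data: Pearson correlation of the ranks."""
        rank_a = np.argsort(np.argsort(a))
        rank_b = np.argsort(np.argsort(b))
        return np.corrcoef(rank_a, rank_b)[0, 1]


    # values copied (rounded) from the two printouts above, in input-feature order
    impurity = [0.0741, 0.1946, 0.0918, 0.3194, 0.0344, 0.0941,
                0.0057, 0.0172, 0.0338, 0.00025, 0.1348]
    mean_abs_shap = [0.001051, 0.001239, 0.000390, 0.003584, 0.000281, 0.001199,
                     0.000322, 0.000349, 0.000395, 0.000014, 0.001875]

    print(round(spearman(impurity, mean_abs_shap), 3))  # → 0.9

On these rounded values the two rankings largely agree (Ni first, Mo last), while mid-ranked features such as Anions and HA swap places.
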
.. code-block:: Python


    f, ax = plt.subplots(figsize=(7, 9))

    ax = bar_chart(
        sv_bar,
        [LABEL_MAP[n] if n in LABEL_MAP else n for n in exp.feature_names],
        bar_labels=np.round(sv_bar, 4),
        bar_label_kws={'label_type': 'edge',
                       'fontsize': 10,
                       'weight': 'bold',
                       "fmt": '%.4f',
                       'padding': 1.5
                       },
        show=False,
        sort=True,
        color=colors_,
        ax=ax
    )

    ax.spines[['top', 'right']].set_visible(False)
    ax.set_xlabel(xlabel='mean(|SHAP value|)')
    # Let the locator place the ticks and keep the default formatter: fixing the
    # tick labels first and then replacing the locator would attach those labels
    # to the wrong tick positions.
    ax.xaxis.set_major_locator(plt.MaxNLocator(4))

    labels = df_with_classes['classes'].unique()
    handles = [plt.Rectangle((0, 0), 1, 1, color=colors[l]) for l in labels]
    plt.legend(handles, labels, loc='lower right')

    plt.tight_layout()
    if SAVE:
        plt.savefig("results/figures/shap_bar.png", dpi=600, bbox_inches="tight")
    plt.show()

.. image-sg:: /auto_examples/images/sphx_glr_Interpretation_003.png
   :alt: Interpretation
   :srcset: /auto_examples/images/sphx_glr_Interpretation_003.png, /auto_examples/images/sphx_glr_Interpretation_003_2_00x.png 2.00x
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 171-215

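
The pie chart sums the mean |SHAP| values within each feature class and normalises the sums to fractions of the total. The same aggregation can be written in plain Python; the sketch below uses a hypothetical ``class_fractions`` helper and the rounded values from the table above. An optional class such as 'Dye Properties' is appended only when it contributes. Note that ``list.append`` modifies the list in place and returns ``None``, so its return value must not be used as the new list.

.. code-block:: Python

    def class_fractions(mean_shap, classes, main_classes, optional_class="Dye Properties"):
        """Sum mean |SHAP| per feature class and normalise the sums to fractions."""
        totals = {}
        for value, cls in zip(mean_shap, classes):
            totals[cls] = totals.get(cls, 0.0) + value
        names = list(main_classes)
        sums = [totals.get(c, 0.0) for c in names]
        # append the optional class only when it contributes (in place; append returns None)
        if totals.get(optional_class, 0.0) > 0.0:
            names.append(optional_class)
            sums.append(totals[optional_class])
        total = sum(sums)
        return {name: s / total for name, s in zip(names, sums)}


    EC, AC, PP = "Experimental Conditions", "Atomic Composition", "Physicochemical Properties"
    mean_shap = [0.001051, 0.001239, 0.000390, 0.003584, 0.000281, 0.001199,
                 0.000322, 0.000349, 0.000395, 0.000014, 0.001875]
    classes = [EC, EC, EC, AC, EC, EC, PP, AC, EC, AC, EC]

    fr = class_fractions(mean_shap, classes, [EC, PP, AC])
    print({name: round(v, 3) for name, v in fr.items()})

With these rounded values the experimental conditions, atomic composition and physicochemical properties account for roughly 60%, 37% and 3% of the total mean |SHAP|, respectively.
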
.. code-block:: Python


    seg_colors = list(colors.values())

    # Scale the saturation of seg_colors to 70% for the interior segments
    rgb = mcolors.to_rgba_array(seg_colors)[:, :-1]
    hsv = mcolors.rgb_to_hsv(rgb)
    hsv[:, 1] = 0.7 * hsv[:, 1]
    interior_colors = mcolors.hsv_to_rgb(hsv)

    fractions = np.array([
        df_with_classes.loc[df_with_classes['classes']=='Experimental Conditions']['mean_shap'].sum(),
        df_with_classes.loc[df_with_classes['classes']=='Physicochemical Properties']['mean_shap'].sum(),
        df_with_classes.loc[df_with_classes['classes']=='Atomic Composition']['mean_shap'].sum(),
    ])

    dye_frac = df_with_classes.loc[df_with_classes['classes']=='Dye Properties']['mean_shap'].sum()

    labels = ['Experimental \nConditions', 'Physicochemical \nProperties',
              'Atomic \nComposition']
    if dye_frac > 0.0:
        # list.append returns None, so extend the array with np.append instead
        fractions = np.append(fractions, dye_frac)
        labels.append('Dye Properties')

    fractions /= fractions.sum()

    _, texts = pie(fractions=fractions,
                   colors=seg_colors,
                   labels=labels,
                   wedgeprops=dict(edgecolor="w", width=0.03), radius=1,
                   autopct=None,
                   textprops=dict(fontsize=12),
                   startangle=90, counterclock=False, show=False)

    texts[0].set_fontsize(12)

    _, texts, autotexts = pie(fractions=fractions,
                              colors=interior_colors,
                              autopct='%1.0f%%',
                              textprops=dict(fontsize=24),
                              wedgeprops=dict(edgecolor="w"), radius=1-2*0.03,
                              startangle=90, counterclock=False, ax=plt.gca(), show=False)

    texts[0].set_fontsize(12)

    plt.tight_layout()
    if SAVE:
        plt.savefig("results/figures/shap_pie.png", dpi=600, bbox_inches="tight")
    plt.show()

.. image-sg:: /auto_examples/images/sphx_glr_Interpretation_004.png
   :alt: Interpretation
   :srcset: /auto_examples/images/sphx_glr_Interpretation_004.png, /auto_examples/images/sphx_glr_Interpretation_004_2_00x.png 2.00x
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 216-220

.. code-block:: Python


    index = train_p.argmax()
    print(index, train_p.max())

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    355 0.036809739452414975

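
As we understand shap's waterfall layout, the plot starts from the base value at the bottom and adds one SHAP value per row until it reaches the model's prediction for the explained sample. The largest contributions sit at the top, and when there are more features than ``max_display``, the smallest ones are merged into a single "other features" row at the bottom. The hypothetical helper below reproduces only that bookkeeping (it is not shap's plotting code) on made-up SHAP values.

.. code-block:: Python

    def waterfall_steps(base_value, shap_values, feature_names, max_display=10):
        """Return (row label, running total) pairs from the bottom to the top of a waterfall plot."""
        order = sorted(range(len(shap_values)), key=lambda i: abs(shap_values[i]), reverse=True)
        # keep max_display - 1 individual rows and merge the rest, if there are too many features
        shown = order[:max_display - 1] if len(order) > max_display else order
        rest = order[len(shown):]
        total = base_value
        steps = []
        if rest:
            total += sum(shap_values[i] for i in rest)
            steps.append((f"{len(rest)} other features", total))
        for i in reversed(shown):  # smallest shown contribution first, largest last (top)
            total += shap_values[i]
            steps.append((feature_names[i], total))
        return steps


    # made-up values for illustration only
    base = 0.0067
    phi = [0.010, -0.002, 0.015, 0.0005, 0.001]
    names = ["a", "b", "c", "d", "e"]
    for label, running_total in waterfall_steps(base, phi, names, max_display=4):
        print(f"{label:>18s}  {running_total:.4f}")

Whatever ``max_display`` is, the last running total equals the base value plus the sum of all SHAP values, i.e. the prediction being explained.
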
.. GENERATED FROM PYTHON SOURCE LINES 221-234

.. code-block:: Python


    e = Explanation(
        shap_values[index],
        base_values=exp.expected_value,
        data=TrainX.values[index],
        feature_names=input_features
    )

    waterfall(e, max_display=20, show=False)
    plt.tight_layout()
    if SAVE:
        plt.savefig(f"results/figures/shap_local_{index}.png", dpi=600, bbox_inches="tight")
    plt.show()

.. image-sg:: /auto_examples/images/sphx_glr_Interpretation_005.png
   :alt: Interpretation
   :srcset: /auto_examples/images/sphx_glr_Interpretation_005.png, /auto_examples/images/sphx_glr_Interpretation_005_2_00x.png 2.00x
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 235-238

The following figures show SHAP interaction plots. They depict the interaction
effect of two features on the model's prediction. In these figures, the numbers
in the legends for Anions have the following meanings:

.. GENERATED FROM PYTHON SOURCE LINES 238-240

.. code-block:: Python


    encoders['Anions'].inverse_transform(np.arange(6).reshape(-1, 1))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    array(['N/A', 'Na2HPO4', 'Na2SO4', 'NaCO3', 'NaCl', 'NaHCO3'], dtype=object)

.. GENERATED FROM PYTHON SOURCE LINES 241-249

Similarly, for the catalyst, the numbers in the legends have the following meanings:

- Ag-BFO: 0
- BFO: 1
- LM: 2
- LTH: 3
- Pd-BFO: 4
- Photolysis: 5
- Pt-BFO: 6

.. GENERATED FROM PYTHON SOURCE LINES 251-254

Dye Concentration
------------------
It represents the initial concentration of the dye.

.. GENERATED FROM PYTHON SOURCE LINES 254-261

.. code-block:: Python


    feature_name = 'Dye concentration (mg/L)'
    if feature_name in TrainX:
        shap_scatter_plots(shap_values, TrainX, feature_name, encoders=encoders,
                           save=SAVE)

.. image-sg:: /auto_examples/images/sphx_glr_Interpretation_006.png
   :alt: Interpretation
   :srcset: /auto_examples/images/sphx_glr_Interpretation_006.png, /auto_examples/images/sphx_glr_Interpretation_006_2_00x.png 2.00x
   :class: sphx-glr-single-img

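
The integer codes listed for Anions and for the catalyst coincide with alphabetical (lexicographic) ordering of the category names, which is how scikit-learn's ``LabelEncoder`` and ``OrdinalEncoder`` assign codes by default. Assuming ``prepare_data`` uses such an encoder (its internals are not shown here), the legend mapping can be rebuilt from the category names alone; ``code_to_label`` is a hypothetical helper.

.. code-block:: Python

    def code_to_label(categories):
        """Map integer codes to category names, assuming codes follow sorted (lexicographic) order."""
        return dict(enumerate(sorted(set(categories))))


    catalysts = ["Pt-BFO", "Pd-BFO", "LM", "Ag-BFO", "Photolysis", "LTH", "BFO"]
    anions = ["NaCl", "N/A", "Na2SO4", "NaHCO3", "Na2HPO4", "NaCO3"]

    print(code_to_label(catalysts))
    print(code_to_label(anions))

When the fitted encoder is available, ``inverse_transform`` remains the safer choice, because it does not rely on this ordering assumption.
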
.. GENERATED FROM PYTHON SOURCE LINES 262-265

Ni (At%)
------------------

.. GENERATED FROM PYTHON SOURCE LINES 265-271

.. code-block:: Python


    feature_name = 'Ni (At%)'
    if feature_name in TrainX:
        shap_scatter_plots(shap_values, TrainX, feature_name, encoders=encoders,
                           save=SAVE)

.. image-sg:: /auto_examples/images/sphx_glr_Interpretation_007.png
   :alt: Interpretation
   :srcset: /auto_examples/images/sphx_glr_Interpretation_007.png, /auto_examples/images/sphx_glr_Interpretation_007_2_00x.png 2.00x
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 272-275

Loading
-----------
It represents how much photocatalyst is present.

.. GENERATED FROM PYTHON SOURCE LINES 275-281

.. code-block:: Python


    feature_name = 'loading (g)'
    shap_scatter_plots(shap_values, TrainX, feature_name, encoders=encoders,
                       save=SAVE)

.. image-sg:: /auto_examples/images/sphx_glr_Interpretation_008.png
   :alt: Interpretation
   :srcset: /auto_examples/images/sphx_glr_Interpretation_008.png, /auto_examples/images/sphx_glr_Interpretation_008_2_00x.png 2.00x
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 282-284

Time
-------

.. GENERATED FROM PYTHON SOURCE LINES 284-290

.. code-block:: Python


    feature_name = 'Time (m)'
    shap_scatter_plots(shap_values, TrainX, feature_name, encoders=encoders,
                       save=SAVE)

.. image-sg:: /auto_examples/images/sphx_glr_Interpretation_009.png
   :alt: Interpretation
   :srcset: /auto_examples/images/sphx_glr_Interpretation_009.png, /auto_examples/images/sphx_glr_Interpretation_009_2_00x.png 2.00x
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 291-293

Solution pH
--------------

.. GENERATED FROM PYTHON SOURCE LINES 293-299

.. code-block:: Python


    feature_name = 'Solution pH'
    shap_scatter_plots(shap_values, TrainX, feature_name, encoders=encoders,
                       save=SAVE)

.. image-sg:: /auto_examples/images/sphx_glr_Interpretation_010.png
   :alt: Interpretation
   :srcset: /auto_examples/images/sphx_glr_Interpretation_010.png, /auto_examples/images/sphx_glr_Interpretation_010_2_00x.png 2.00x
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 300-302

Light intensity
---------------

.. GENERATED FROM PYTHON SOURCE LINES 302-309

.. code-block:: Python


    feature_name = 'Light intensity (watt)'
    shap_scatter_plots(shap_values, TrainX, feature_name, encoders=encoders,
                       save=SAVE)

.. image-sg:: /auto_examples/images/sphx_glr_Interpretation_011.png
   :alt: Interpretation
   :srcset: /auto_examples/images/sphx_glr_Interpretation_011.png, /auto_examples/images/sphx_glr_Interpretation_011_2_00x.png 2.00x
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 310-312

Oxygen
---------------

.. GENERATED FROM PYTHON SOURCE LINES 312-318

.. code-block:: Python


    feature_name = 'O (At%)'
    shap_scatter_plots(shap_values, TrainX, feature_name, encoders=encoders,
                       save=SAVE)

.. image-sg:: /auto_examples/images/sphx_glr_Interpretation_012.png
   :alt: Interpretation
   :srcset: /auto_examples/images/sphx_glr_Interpretation_012.png, /auto_examples/images/sphx_glr_Interpretation_012_2_00x.png 2.00x
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 319-321

Humic Acid
-----------

.. GENERATED FROM PYTHON SOURCE LINES 321-327

.. code-block:: Python


    feature_name = 'HA (mg/L)'
    shap_scatter_plots(shap_values, TrainX, feature_name, encoders=encoders,
                       save=SAVE)

.. image-sg:: /auto_examples/images/sphx_glr_Interpretation_013.png
   :alt: Interpretation
   :srcset: /auto_examples/images/sphx_glr_Interpretation_013.png, /auto_examples/images/sphx_glr_Interpretation_013_2_00x.png 2.00x
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 328-330

Pore size
-----------

.. GENERATED FROM PYTHON SOURCE LINES 330-337

.. code-block:: Python


    feature_name = 'Pore size (nm)'
    if feature_name in TrainX:
        shap_scatter_plots(shap_values, TrainX, feature_name, encoders=encoders,
                           save=SAVE)

.. image-sg:: /auto_examples/images/sphx_glr_Interpretation_014.png
   :alt: Interpretation
   :srcset: /auto_examples/images/sphx_glr_Interpretation_014.png, /auto_examples/images/sphx_glr_Interpretation_014_2_00x.png 2.00x
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 338-340

Anions
--------

.. GENERATED FROM PYTHON SOURCE LINES 340-346

.. code-block:: Python


    feature_name = 'Anions'
    shap_scatter_plots(shap_values, TrainX, feature_name, encoders=encoders,
                       save=SAVE)

.. image-sg:: /auto_examples/images/sphx_glr_Interpretation_015.png
   :alt: Interpretation
   :srcset: /auto_examples/images/sphx_glr_Interpretation_015.png, /auto_examples/images/sphx_glr_Interpretation_015_2_00x.png 2.00x
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 347-353

.. code-block:: Python


    feature_name = 'Mass ratio (Catalyst/Dye)'
    if feature_name in TrainX:
        shap_scatter_plots(shap_values, TrainX, feature_name, encoders=encoders,
                           save=SAVE)

.. GENERATED FROM PYTHON SOURCE LINES 354-356

S
---

.. GENERATED FROM PYTHON SOURCE LINES 356-364

.. code-block:: Python


    feature_name = 'S (At%)'
    if feature_name in TrainX:
        shap_scatter_plots(shap_values, TrainX, feature_name, encoders=encoders,
                           save=SAVE)

.. GENERATED FROM PYTHON SOURCE LINES 365-367

Surface Area
--------------

.. GENERATED FROM PYTHON SOURCE LINES 367-374

.. code-block:: Python


    feature_name = 'Surface area (m2/g)'
    if feature_name in TrainX:
        shap_scatter_plots(shap_values, TrainX, feature_name, encoders=encoders,
                           save=SAVE)

.. GENERATED FROM PYTHON SOURCE LINES 375-377

Mo
-----------

.. GENERATED FROM PYTHON SOURCE LINES 377-382

.. code-block:: Python


    feature_name = 'Mo (At%)'
    if feature_name in TrainX:
        shap_scatter_plots(shap_values, TrainX, feature_name, encoders=encoders,
                           save=SAVE)

.. image-sg:: /auto_examples/images/sphx_glr_Interpretation_016.png
   :alt: Interpretation
   :srcset: /auto_examples/images/sphx_glr_Interpretation_016.png, /auto_examples/images/sphx_glr_Interpretation_016_2_00x.png 2.00x
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 383-476

.. code-block:: Python


    # (feature whose SHAP values are plotted, feature used for colouring, x-axis label)
    panels = [
        ('loading (g)', 'Ni (At%)', 'Cat. Loading (g/L)'),
        ('loading (g)', 'Pore size (nm)', 'Cat. Loading (g/L)'),
        ('loading (g)', 'Solution pH', 'Cat. Loading (g/L)'),
        ('loading (g)', 'Mo (At%)', 'Cat. Loading (g/L)'),
        ('Ni (At%)', 'Pore size (nm)', 'Ni (At%)'),
        ('Ni (At%)', 'Solution pH', 'Ni (At%)'),
        ('Ni (At%)', 'Mo (At%)', 'Ni (At%)'),
        ('Solution pH', 'O (At%)', 'Solution pH'),
    ]

    fig, axes = plt.subplots(2, 4, figsize=(15, 8))

    for ax, (feature, color_feature, label) in zip(axes.flat, panels):
        ax = shap_scatter(
            shap_values[:, input_features.index(feature)],
            TrainX.loc[:, feature],
            TrainX.loc[:, color_feature],
            feature_name=label,
            ax=ax,
            show=False,
        )
        ax.set_ylabel('')

    plt.tight_layout()
    if SAVE:
        plt.savefig("results/figures/shap_dep.png", dpi=600, bbox_inches="tight")
    plt.show()

.. image-sg:: /auto_examples/images/sphx_glr_Interpretation_017.png
   :alt: Interpretation
   :srcset: /auto_examples/images/sphx_glr_Interpretation_017.png, /auto_examples/images/sphx_glr_Interpretation_017_2_00x.png 2.00x
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 477-479

Partial Dependence Plot
==========================

.. GENERATED FROM PYTHON SOURCE LINES 479-490

.. code-block:: Python


    pdp = PartialDependencePlot(
        model.predict,
        TrainX,
        num_points=20,
        feature_names=TrainX.columns.tolist(),
        show=False,
        save=False
    )

.. GENERATED FROM PYTHON SOURCE LINES 491-522

.. code-block:: Python


    mpl.rcParams.update(mpl.rcParamsDefault)

    colors = ["#DB0007", "#670E36", "#e30613", "#0057B8", "#6C1D45",
              "#034694", "#1B458F", "#003399", "#FFCD00", "#003090",
              "#C8102E", "#6CABDD", "#DA291C", "#241F20", "#00A650",
              "#D71920", "#132257", "#ED2127", "#7A263A", "#FDB913",
              "#DB0007", "#670E36", "#e30613", "#0057B8", "#6C1D45",
              "#034694", "#1B458F", "#003399", "#FFCD00", "#003090",
              ]

    f, axes = create_subplots(TrainX.shape[1], figsize=(10, 12))

    for ax, feature, clr in zip(axes.flat, TrainX.columns, colors):

        pdp_vals, ice_vals = pdp.calc_pdp_1dim(TrainX.values, feature)

        ax = pdp.plot_pdp_1dim(pdp_vals, ice_vals, TrainX.values,
                               feature,
                               pdp_line_kws={'color': clr, 'zorder': 3},
                               ice_color="gray",
                               ice_lines_kws=dict(zorder=2, alpha=0.15),
                               ax=ax,
                               show=False,
                               )
        ax.set_xlabel(LABEL_MAP.get(feature, feature))
        ax.set_ylabel(f"E[f(x) | {feature}]")

    plt.tight_layout()
    if SAVE:
        plt.savefig("results/figures/pdp.png", dpi=600, bbox_inches="tight")
    plt.show()

.. image-sg:: /auto_examples/images/sphx_glr_Interpretation_018.png
   :alt: Interpretation
   :srcset: /auto_examples/images/sphx_glr_Interpretation_018.png, /auto_examples/images/sphx_glr_Interpretation_018_2_00x.png 2.00x
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 523-525

Accumulated Local Effects
==========================

.. GENERATED FROM PYTHON SOURCE LINES 525-539

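
Partial dependence averages the prediction over the observed values of all other features. When inputs are correlated, this can push the model into unrealistic feature combinations. ALE instead bins the feature of interest, averages the change in prediction across each bin using only the rows that fall in that bin, and accumulates these local changes. The numpy sketch below implements minimal first-order versions of both estimators. The quantile binning and the count-weighted centring are our assumptions, the code is not what ``PartialDependencePlot`` or ``plot_ale`` run internally, and ``toy_predict`` is a hypothetical model.

.. code-block:: Python

    import numpy as np


    def partial_dependence(predict, X, j, grid):
        """PDP: set column j to each grid value for every row and average the predictions."""
        out = []
        for value in grid:
            X_mod = X.copy()
            X_mod[:, j] = value
            out.append(predict(X_mod).mean())
        return np.array(out)


    def ale_1d(predict, X, j, n_bins=10):
        """First-order ALE with quantile bins; returns bin edges and centred ALE at the edges."""
        edges = np.unique(np.quantile(X[:, j], np.linspace(0, 1, n_bins + 1)))
        n_intervals = len(edges) - 1
        # interval index of each row; rows at the lowest edge go into the first interval
        idx = np.clip(np.searchsorted(edges, X[:, j], side="left") - 1, 0, n_intervals - 1)
        local = np.zeros(n_intervals)
        counts = np.zeros(n_intervals)
        for k in range(n_intervals):
            rows = X[idx == k]
            if len(rows) == 0:
                continue
            lower, upper = rows.copy(), rows.copy()
            lower[:, j], upper[:, j] = edges[k], edges[k + 1]
            # average effect of moving feature j across interval k, for rows inside it
            local[k] = (predict(upper) - predict(lower)).mean()
            counts[k] = len(rows)
        ale = np.concatenate([[0.0], np.cumsum(local)])
        # centre: subtract the count-weighted mean of the interval midpoints
        mids = (ale[:-1] + ale[1:]) / 2
        return edges, ale - np.sum(mids * counts) / counts.sum()


    def toy_predict(X):
        # hypothetical model: linear in x0, plus a product term in x1 and x2
        return 3.0 * X[:, 0] + X[:, 1] * X[:, 2]


    rng = np.random.default_rng(1)
    X = rng.uniform(size=(500, 3))
    pdp_vals = partial_dependence(toy_predict, X, 0, np.array([0.0, 1.0]))
    edges, ale = ale_1d(toy_predict, X, 0)
    print(np.isclose(pdp_vals[1] - pdp_vals[0], 3.0))        # → True
    print(np.allclose(np.diff(ale), 3.0 * np.diff(edges)))   # → True

For a purely additive feature, both methods recover the same slope. They diverge when the feature interacts with, or is correlated with, other inputs.
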
.. code-block:: Python


    class MyModel:
        def predict(self, X):
            return model.predict(X).reshape(-1,)

    f, axes = create_subplots(TrainX.shape[1], figsize=(10, 12))

    for ax, feature, clr in zip(axes.flat, TrainX.columns, colors):
        plot_ale(MyModel().predict, TrainX, feature, ax=ax, show=False,
                 color=clr,
                 )

    plt.tight_layout()
    plt.show()

.. image-sg:: /auto_examples/images/sphx_glr_Interpretation_019.png
   :alt: Interpretation
   :srcset: /auto_examples/images/sphx_glr_Interpretation_019.png, /auto_examples/images/sphx_glr_Interpretation_019_2_00x.png 2.00x
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 540-544

All Features model interpretation
==================================

For comparison, we also interpret a model that uses all available features
as input.

.. GENERATED FROM PYTHON SOURCE LINES 544-551

.. code-block:: Python


    set_rcParams()

    data, encoders = prepare_data(outputs="k")
    print(data.shape)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    (1527, 35)

.. GENERATED FROM PYTHON SOURCE LINES 552-573

.. code-block:: Python


    input_features = data.columns.tolist()[0:-1]
    output_features = data.columns.tolist()[-1:]

    TrainX, TestX, TrainY, TestY = TrainTestSplit(seed=313).split_by_random(
        data[input_features],
        data[output_features]
    )

    print(TrainX.shape, TrainY.shape, TestX.shape, TestY.shape)

    model = Model(
        model = "DecisionTreeRegressor",
        input_features=input_features,
        output_features=output_features,
        verbosity=-1,
    )

    model.fit(TrainX, TrainY.values)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    (1068, 34) (1068, 1) (459, 34) (459, 1)

.. raw:: html
DecisionTreeRegressor(random_state=313)


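
As for the first model, the all-features tree's ``feature_importances_`` are scikit-learn's impurity-based (mean decrease in impurity) values. Each split credits its feature with the sample-weighted impurity decrease it achieves, the totals are divided by the root's weight and finally normalised to sum to one. A feature that is never chosen for a split therefore receives exactly zero, which is why several of the 34 inputs print as ``0.00000000e+00``. The sketch below computes these values from flat tree arrays laid out like scikit-learn's ``tree_`` attribute (``children_left``, ``children_right``, ``feature``, ``impurity``, ``weighted_n_node_samples``) on a hand-built toy tree; ``impurity_importances`` is our own helper, not a library function.

.. code-block:: Python

    import numpy as np


    def impurity_importances(children_left, children_right, feature, impurity,
                             weighted_n_node_samples, n_features):
        """Mean-decrease-in-impurity importances from flat tree arrays (leaf children are -1)."""
        scores = np.zeros(n_features)
        w = weighted_n_node_samples
        for node in range(len(feature)):
            left, right = children_left[node], children_right[node]
            if left == -1:  # leaf: no split, no contribution
                continue
            scores[feature[node]] += (w[node] * impurity[node]
                                      - w[left] * impurity[left]
                                      - w[right] * impurity[right])
        scores /= w[0]
        total = scores.sum()
        return scores / total if total > 0 else scores


    # toy tree: the root splits on feature 0, its right child splits on feature 2,
    # and feature 1 is never used (leaves carry feature -2, as in scikit-learn)
    children_left = np.array([1, -1, 3, -1, -1])
    children_right = np.array([2, -1, 4, -1, -1])
    feature = np.array([0, -2, 2, -2, -2])
    impurity = np.array([1.0, 0.0, 0.5, 0.1, 0.1])
    n_samples = np.array([10.0, 4.0, 6.0, 3.0, 3.0])

    imp = impurity_importances(children_left, children_right, feature, impurity,
                               n_samples, n_features=3)
    print(imp)

Applied to ``model._model.tree_`` with ``n_features=34``, this should match ``feature_importances_`` up to floating-point error; we have not run that comparison here.
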
.. GENERATED FROM PYTHON SOURCE LINES 574-576

.. code-block:: Python


    train_p = model.predict(TrainX, process_results=False)

.. GENERATED FROM PYTHON SOURCE LINES 577-579

.. code-block:: Python


    test_p = model.predict(TestX, process_results=False)

.. GENERATED FROM PYTHON SOURCE LINES 580-581

Average prediction on the training data:

.. GENERATED FROM PYTHON SOURCE LINES 581-583

.. code-block:: Python


    print(train_p.mean())

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    0.0069402254388336235

.. GENERATED FROM PYTHON SOURCE LINES 584-585

Default (impurity-based) feature importances of the decision tree:

.. GENERATED FROM PYTHON SOURCE LINES 585-588

.. code-block:: Python


    print(model._model.feature_importances_)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    [3.81006042e-05 1.73386029e-06 1.19712970e-02 0.00000000e+00
     7.23370042e-03 1.99232034e-05 3.97314739e-04 2.75412342e-01
     1.25751993e-03 4.10362230e-02 9.97703911e-04 0.00000000e+00
     2.34364582e-05 2.25969210e-04 1.03530578e-02 2.14975624e-03
     5.75046118e-04 0.00000000e+00 8.39858878e-02 3.36541611e-02
     2.61657725e-03 1.93957659e-01 0.00000000e+00 0.00000000e+00
     0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
     0.00000000e+00 0.00000000e+00 1.34414955e-01 7.38624495e-02
     3.42792786e-02 9.15359069e-02]

.. GENERATED FROM PYTHON SOURCE LINES 589-599

.. code-block:: Python


    fig, ax = plt.subplots(figsize=(6, 8))
    bar_chart(model._model.feature_importances_,
              [LABEL_MAP[n] if n in LABEL_MAP else n for n in model.input_features],
              sort=True,
              show=False,
              ax=ax)
    plt.tight_layout()
    plt.show()

.. image-sg:: /auto_examples/images/sphx_glr_Interpretation_020.png
   :alt: Interpretation
   :srcset: /auto_examples/images/sphx_glr_Interpretation_020.png, /auto_examples/images/sphx_glr_Interpretation_020_2_00x.png 2.00x
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 600-602

SHAP all features
==================

.. GENERATED FROM PYTHON SOURCE LINES 602-608

.. code-block:: Python


    exp = shap.TreeExplainer(model=model._model,
                             data=TrainX,
                             feature_names=input_features)
    print(exp.expected_value)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    0.00671328508398034

.. GENERATED FROM PYTHON SOURCE LINES 609-622

.. code-block:: Python


    shap_values = exp.shap_values(TrainX, TrainY)

    summary_plot(shap_values, TrainX,
                 max_display=34,
                 feature_names=[LABEL_MAP[n] if n in LABEL_MAP else n for n in input_features],
                 show=False)
    plt.tight_layout()
    if SAVE:
        plt.savefig("results/figures/shap_summary_all.png", dpi=600, bbox_inches="tight")
    plt.show()

.. image-sg:: /auto_examples/images/sphx_glr_Interpretation_021.png
   :alt: Interpretation
   :srcset: /auto_examples/images/sphx_glr_Interpretation_021.png, /auto_examples/images/sphx_glr_Interpretation_021_2_00x.png 2.00x
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 623-710

.. code-block:: Python


    # sv_bar = np.mean(np.abs(shap_values), axis=0)
    # classes, colors, colors_ = make_classes(exp)
    # df_with_classes = pd.DataFrame({'features': exp.feature_names,
    #                                 'classes': classes,
    #                                 'mean_shap': sv_bar})
    # print(df_with_classes)
    #
    # # %%
    # f, ax = plt.subplots(figsize=(7, 9))
    # ax = bar_chart(
    #     sv_bar,
    #     [LABEL_MAP[n] if n in LABEL_MAP else n for n in exp.feature_names],
    #     bar_labels=np.round(sv_bar, 4),
    #     bar_label_kws={'label_type': 'edge',
    #                    'fontsize': 10,
    #                    'weight': 'bold',
    #                    "fmt": '%.4f',
    #                    'padding': 1.5
    #                    },
    #     show=False,
    #     sort=True,
    #     color=colors_,
    #     ax=ax
    # )
    # ax.spines[['top', 'right']].set_visible(False)
    # ax.set_xlabel(xlabel='mean(|SHAP value|)')
    # ax.xaxis.set_major_locator(plt.MaxNLocator(4))
    # labels = df_with_classes['classes'].unique()
    # handles = [plt.Rectangle((0, 0), 1, 1, color=colors[l]) for l in labels]
    # plt.legend(handles, labels, loc='lower right')
    # plt.tight_layout()
    # if SAVE:
    #     plt.savefig("results/figures/shap_bar_all.png", dpi=600, bbox_inches="tight")
    # plt.show()
    #
    # # %%
    # seg_colors = list(colors.values())
    # # Scale the saturation of seg_colors to 70% for the interior segments
    # rgb = mcolors.to_rgba_array(seg_colors)[:, :-1]
    # hsv = mcolors.rgb_to_hsv(rgb)
    # hsv[:, 1] = 0.7 * hsv[:, 1]
    # interior_colors = mcolors.hsv_to_rgb(hsv)
    # fractions = np.array([
    #     df_with_classes.loc[df_with_classes['classes']=='Experimental Conditions']['mean_shap'].sum(),
    #     df_with_classes.loc[df_with_classes['classes']=='Physicochemical Properties']['mean_shap'].sum(),
    #     df_with_classes.loc[df_with_classes['classes']=='Atomic Composition']['mean_shap'].sum(),
    # ])
    # dye_frac = df_with_classes.loc[df_with_classes['classes']=='Dye Properties']['mean_shap'].sum()
    # labels = ['Experimental \nConditions', 'Physicochemical \nProperties',
    #           'Atomic \nComposition']
    # if dye_frac > 0.0:
    #     fractions = np.append(fractions, dye_frac)
    #     labels.append('Dye Properties')
    # fractions /= fractions.sum()
    # _, texts = pie(fractions=fractions,
    #                colors=seg_colors,
    #                labels=labels,
    #                wedgeprops=dict(edgecolor="w", width=0.03), radius=1,
    #                autopct=None,
    #                textprops=dict(fontsize=12),
    #                startangle=90, counterclock=False, show=False)
    # texts[0].set_fontsize(12)
    # _, texts, autotexts = pie(fractions=fractions,
    #                           colors=interior_colors,
    #                           autopct='%1.0f%%',
    #                           textprops=dict(fontsize=24),
    #                           wedgeprops=dict(edgecolor="w"), radius=1-2*0.03,
    #                           startangle=90, counterclock=False, ax=plt.gca(), show=False)
    # texts[0].set_fontsize(12)
    # plt.tight_layout()
    # if SAVE:
    #     plt.savefig("results/figures/shap_pie_all.png", dpi=600, bbox_inches="tight")
    # plt.show()


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 56.549 seconds)


.. _sphx_glr_download_auto_examples_Interpretation.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: Interpretation.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: Interpretation.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: Interpretation.zip `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_