In [1]:
# This line will add a button to toggle visibility of code blocks,
# for use with the HTML export version
from IPython.core.display import HTML
HTML('''<button style="margin:0 auto; display: block;" onclick="jQuery('.code_cell .input_area').toggle();
    jQuery('.prompt').toggle();">Toggle code</button>''')
Out[1]:

1. Introduction

Tree-based methods stratify or segment the predictor space into a number of simple regions1.

As the spliting rules to make these decision regions can be summerised in a tree structure, these approaches are called decision trees.

A decision tree can be thought of as breaking data down by asking a series of questions in order to categorise samples into the same class.

In [2]:
%matplotlib inline

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
import matplotlib
from IPython.display import Image
import os
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer
from time import time
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
import warnings # prevent warnings
import joblib # saving models
import itertools

matplotlib.rcParams['animation.embed_limit'] = 30000000.0

# colours for printing outputs
class color:
   PURPLE = '\033[95m'
   CYAN = '\033[96m'
   DARKCYAN = '\033[36m'
   BLUE = '\033[94m'
   GREEN = '\033[92m'
   YELLOW = '\033[93m'
   RED = '\033[91m'
   BOLD = '\033[1m'
   UNDERLINE = '\033[4m'
   END = '\033[0m'
    
image_dir = os.path.join(os.getcwd(),"Images")
data_dir = os.path.join(os.getcwd(),"..","Data")

fig_num=0
plt.rcParams['figure.dpi'] = 120
# golden ratio for figures ()
gr = 1.618

height_pix = 500
width_pix = height_pix*gr

height_inch = 4
width_inch = height_inch*gr
In [3]:
# Centered figures in the notebook and presentation
# ...was a real pain to find this:
# https://gist.githubusercontent.com/maxalbert/800b9f06c7b2dd365ea5

import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import urllib
import base64
from io import BytesIO, StringIO

def fig2str(fig, format='svg'):
    """
    Return a string containing the raw data of the matplotlib figure in the given format.

    """
    assert isinstance(fig, matplotlib.figure.Figure)
    imgdata = BytesIO()
    fig.savefig(imgdata, format=format, bbox_inches='tight')
    imgdata.seek(0)  # rewind the data
    output = imgdata.getvalue()
    if format == 'svg':
        return output
    else:
        return urllib.parse.quote(base64.b64encode(output))

class MatplotlibFigure(object):
    """
    Thin wrapper around a matplotlib figure which provides a custom
    HTML representation that allows tweaking the appearance

    """
    def __init__(self, fig, centered=False):
        assert isinstance(fig, matplotlib.figure.Figure)
        self.centered = centered

    def _repr_html_(self):
        img_str_png = fig2str(fig, format='png')
        uri = 'data:image/png;base64,' + img_str_png
        html_repr = "<img src='{}'>".format(uri)
        if self.centered:
            html_repr = "<center>" + html_repr + "</center>"
        return html_repr
In [4]:
# TODO: Tidy code here, maybe split it up thoughout the notebook... 
# not sure why I thought this way would be helpful looking back its just confusing :P

col_dict = {
    "Adelie":"#ff7600",
    "Chinstrap":"#c65dcb",
    "Gentoo":"#057576"}

shape_dict = {
    "Adelie":"o",
    "Chinstrap":"s",
    "Gentoo":"X"}

datasets = {}
penguins = sns.load_dataset("penguins")

# dropna
penguins_rm = penguins.dropna()

# keep categorical features
penguins_cat = penguins_rm[["island", "sex", "species"]]
penguins_bin = penguins_cat[penguins_cat.species != "Chinstrap"]
y_bin = penguins_bin[["species"]].replace({'Adelie': 0, 'Gentoo': 1}).values.flatten()
# for the classification data we don't want species there either
penguins_class_feat = penguins_bin.drop("species", axis=1)
datasets['cat'] = {"df":penguins_bin, 
                    "X": penguins_class_feat[["island", "sex"]].values,
                    "y": y_bin,
                    "feats": ["island", "sex"],
                     "class": ['Adelie', 'Gentoo']}


# drop continuous features
penguins_cont = penguins_rm.drop(["island", "sex", "species"], axis=1)

# for the regression data we don't want body_mass_g there either
penguins_reg_feat = penguins_cont.drop(["body_mass_g"], axis=1)

datasets['reg_full'] = {"df":penguins_cont, 
                        "X":penguins_reg_feat.values,
                        "y": penguins_rm[["body_mass_g"]].values.flatten(),
                        "feats": list(penguins_reg_feat.columns)}

# regression datset just to compare flipper length to body mass
datasets['flbm'] = {"df":penguins_cont, 
                    "X":penguins_reg_feat[["flipper_length_mm"]].values,
                    "y": penguins_rm[["body_mass_g"]].values.flatten(),
                    "feats": ["Flipper Length (mm)"]}

# make a binary classification dataset
penguins_class = penguins_rm.drop(["island", "sex"], axis=1)
# for the classification data we don't want species there either
penguins_class_feat = penguins_class.drop("species", axis=1)
y_multi = penguins_class[["species"]].replace({'Adelie': 0, 'Gentoo': 1, "Chinstrap":2}).values.flatten()

datasets['multi'] = {"df":penguins_class, 
                     "X": penguins_class_feat.values,
                     "y": y_multi,
                     "feats": list(penguins_class_feat.columns),
                     "class": ['Adelie', 'Gentoo', "Chinstrap"]}

datasets['multi2'] = {"df":penguins_class, 
                     "X": penguins_class_feat[["flipper_length_mm", "bill_length_mm"]].values,
                     "y": y_multi,
                     "feats": ["flipper_length_mm", "bill_length_mm"],
                     "class": ['Adelie', 'Gentoo', "Chinstrap"]}


penguins_bin = penguins_class[penguins_class.species != "Chinstrap"]
y_bin = penguins_bin[["species"]].replace({'Adelie': 0, 'Gentoo': 1}).values.flatten()
# for the classification data we don't want species there either
penguins_class_feat = penguins_bin.drop("species", axis=1)

datasets['bin'] = {"df":penguins_bin, 
                   "X": penguins_class_feat.values,
                   "y": y_bin,
                   "feats": list(penguins_class_feat.columns),
                   "class": ['Adelie', 'Gentoo']}

datasets['flbl'] = {"df":penguins_bin, 
                    "X": penguins_class_feat[["flipper_length_mm", "bill_length_mm"]].values,
                    "y": y_bin,
                    "feats": ["flipper_length_mm", "bill_length_mm"],
                     "class": ['Adelie', 'Gentoo']}

datasets['blbd'] = {"df":penguins_bin, 
                    "X": penguins_class_feat[["bill_length_mm", "bill_depth_mm"]].values,
                    "y": y_bin,
                    "feats": ["bill_length_mm", "bill_depth_mm"],
                    "class": ['Adelie', 'Gentoo']}
In [5]:
if os.path.exists(os.path.join(data_dir,"palmerpenguins")):
    print("Already Cloned")
else:
    import git
    git.Git(os.getcwd()).clone("https://github.com/allisonhorst/palmerpenguins.git")
    
penguins_fig_dir = os.path.join(data_dir,"palmerpenguins", "man", "figures")
Already Cloned

Terminology5

Root node: no incoming edge, zero, or more outgoing edges.

Internal node: one incoming edge, two (or more) outgoing edges.

Leaf node: each leaf node is assigned a class label if nodes are pure; otherwise, the class label is determined by majority vote.

Parent and child nodes: If a node is split, we refer to that given node as the parent node, and the resulting nodes are called child nodes.

Notes

  • Leaves are typically drawn upside down, so they are at the bottom of the tree
In [6]:
# taken from https://github.com/rasbt/stat479-machine-learning-fs19/blob/master/06_trees/06-trees__slides.pdf
# TODO: If I had time I would make my own diagram with "Global Pandemic?" as the second box.
fig_num+=1
print(color.BOLD+color.UNDERLINE+"Figure %d: Categorical Decision Tree"%fig_num+color.END)
Image(os.path.join(image_dir, "tree_terms.png"), width=550)
Figure 1: Categorical Decision Tree
Out[6]:

Dataset Example: Penguins

The "palmer penguins" dataset2 contains data for 344 penguins from 3 different species and from 3 islands in the Palmer Archipelago, Antarctica.


Artwork by @allison_horst

In [7]:
fig_num+=1
print(color.BOLD+color.UNDERLINE+"Figure %d: Penguin Buddies"%fig_num+color.END)
Image(os.path.join(penguins_fig_dir, "lter_penguins.png"), width=600)
Figure 2: Penguin Buddies
Out[7]:
In [8]:
display(penguins.head())
species island bill_length_mm bill_depth_mm flipper_length_mm body_mass_g sex
0 Adelie Torgersen 39.1 18.7 181.0 3750.0 Male
1 Adelie Torgersen 39.5 17.4 186.0 3800.0 Female
2 Adelie Torgersen 40.3 18.0 195.0 3250.0 Female
3 Adelie Torgersen NaN NaN NaN NaN NaN
4 Adelie Torgersen 36.7 19.3 193.0 3450.0 Female
In [9]:
fig_num+=1
g = sns.pairplot(penguins, hue="species", markers=shape_dict, palette=col_dict, height=height_inch/2, aspect = gr)
plt.suptitle("Figure %d: Penguins Data Pairplot"%fig_num)
plt.tight_layout()
plt.close()
fig = g.fig
display(MatplotlibFigure(fig, centered=True))
In [10]:
g = sns.scatterplot(data=datasets['flbl']['df'], x = "flipper_length_mm", 
                    y = "bill_length_mm", hue="species", style = "species",
                    markers = shape_dict,  palette= col_dict)
g.axes.get_legend().set_title(False)
plt.suptitle('Extra: Example plot')
plt.show()

2. General Decision Tree Algorithm

An algorithm starts at a tree root and then splits the data based on the features that give the best split based on a splitting criterion.

Generally this splitting procedure occours until3,4...

  • ...all the samples within a given node all belong to the same class,
  • ...the maximum depth of the tree is reached,
  • ...a split cannot be found that improves the model.

NOTES

  • The process of growing a decision tree can be expressed as a recursive algorithm as follows5:
    1. Pick a feature such that when parent node is split, it results in the largest information gain and stopping if information gain is not positive.
    2. Stop if child nodes are pure or no improvement in class purity can be made.
    3. Go back to step 1 for each of the two child nodes.
In [11]:
from sklearn.tree import DecisionTreeClassifier
from mlxtend.plotting import plot_decision_regions
from sklearn import tree
from re import search
import matplotlib as mpl

DT_g1 = DecisionTreeClassifier(criterion='gini',
                               max_depth = 1,
                               random_state=42)

l_labels = [[189, 57.5],[219, 57.5]]
r_labels = [[0.225, 0.4],[0.725, 0.4]]
tp_labels = [[0.28, 0.5],[0.64, 0.5]]


# TODO: this is a bit of a large function that I should come back to and
# refactor
def regions_tree(DT, X, y, feature_names, class_names, col_dict, l_label_pos=None, 
                 r_label_pos=None, tp_label_pos=None, impurity = False, 
                 xaxis_lim = None, yaxis_lim = None, color=False, title=None, 
                 savefig=None, alpha=1):
    DT.fit(X,y)

    fig, axes = plt.subplots(ncols=2, figsize = (width_inch*2, height_inch))
    plt.sca(axes[0])

    scatter_kwargs = {'alpha': alpha}
    ax = plot_decision_regions(X, y, clf = DT,
                               markers = ','.join([shape_dict[x] for x in class_names]),
                               colors = ','.join([col_dict[x] for x in class_names]),
                               scatter_kwargs = scatter_kwargs
                              )

    handles, labels = ax.get_legend_handles_labels()
    
    if l_label_pos:
        plt.text(l_label_pos[0][0], l_label_pos[0][1], "$R_1$", bbox=dict(facecolor='white', alpha=0.3))
        plt.text(l_label_pos[1][0], l_label_pos[1][1], "$R_2$", bbox=dict(facecolor='white', alpha=0.3))
    
    ax.legend(handles, 
              class_names, 
              framealpha=0.3, scatterpoints=1)

    plt.xlabel(feature_names[0]) 
    plt.ylabel(feature_names[1])
    if xaxis_lim:
        plt.xlim(xaxis_lim)
    if yaxis_lim:
        plt.ylim(yaxis_lim)
    
    # The arrows dont show up on versions of Scikit Learn due to a weird interaction with sns 
    # so I need to use `plt.style.context("classic")`.
    with plt.style.context("classic"):
        plt.sca(axes[1])
        tp = tree.plot_tree(DT,
                       feature_names=feature_names, 
                       class_names=class_names,
                       filled = True,
                           impurity = impurity)
        if r_label_pos:
            axes[1].text(r_label_pos[0][0], r_label_pos[0][1],
                         "$R_1$")
            axes[1].text(r_label_pos[1][0], r_label_pos[1][1], 
                         "$R_2$")
            
        if tp_label_pos:
            axes[1].text(tp_label_pos[0][0], tp_label_pos[0][1],
                         "True", {'fontweight':'bold'})
            axes[1].text(tp_label_pos[1][0], tp_label_pos[1][1], 
                         "False", {'fontweight':'bold'})
        
        for i, node in enumerate(tp):
            for class_name in class_names:
                if search(class_name, node.get_text()):
                    tp[i].set_backgroundcolor(col_dict[class_name])

    if title:
        plt.suptitle(title)
        
    if savefig:
        plt.savefig(savefig)
    
    return fig

Below is an example of a very shallow decision tree where we have set max_depth = 1.

Terminology (Reminder)1

  • The regions $R_1$ and $R_2$ are known as terminal nodes or leaves of the tree.
  • The points where the predictor space is split are known as the internal nodes. In this case as we only have one split this is the "root node".
  • The segments of the trees that connect the nodes are branches.
In [12]:
fig_num+=1
fig = regions_tree(DT_g1, datasets['flbl']['X'], datasets['flbl']['y'], 
             datasets['flbl']['feats'], datasets['flbl']['class'], 
             col_dict, l_labels, r_labels, tp_labels, 
             title="Figure %d: Shallow Penguins Tree"%fig_num)
plt.close()
display(MatplotlibFigure(fig, centered=True))

Extra: dtreeviz

Heres an additional visualisation package with extra features such as bein able to follow the path of a hypothetical test sample.

I don't use dtreeviz in the lectures, as it can be a bit of a hassle to setup. However you may also find this a useful way of thinking about the splitting.

Notes

In [13]:
from re import search

# Change to True if you want to run the dtreeviz code
DTREEVIS = True

if DTREEVIS:
    if not search("graphviz", os.environ.get('PATH')):
        # CHANGE THIS TO WHERE graphviz IS ON YOUR COMPUTER!
        GRAPHVIS_PATH = 'C:\\Program Files\\Graphviz\\bin'
        #C:\Users\delliot2\.conda\envs\mlp\Lib\site-packages\graphviz
        os.environ["PATH"] += os.pathsep + GRAPHVIS_PATH
        
    from dtreeviz.trees import dtreeviz
    DT_g1.fit(datasets['flbl']['X'], datasets['flbl']['y'])

    viz = dtreeviz(DT_g1, datasets['flbl']['X'], datasets['flbl']['y'],
                   feature_names=datasets['flbl']['feats'],
                   class_names=['Adelie', 'Gentoo'],
                   orientation ='LR', colors = col_dict, scale=2.0
                  )
    display(viz)
G cluster_legend node0 2021-03-14T18:06:22.820080 image/svg+xml Matplotlib v3.3.4, https://matplotlib.org/ leaf1 2021-03-14T18:06:22.940780 image/svg+xml Matplotlib v3.3.4, https://matplotlib.org/ node0->leaf1 < leaf2 2021-03-14T18:06:22.987652 image/svg+xml Matplotlib v3.3.4, https://matplotlib.org/ node0->leaf2 legend 2021-03-14T18:06:22.648587 image/svg+xml Matplotlib v3.3.4, https://matplotlib.org/

We can make the tree "deeper", and therefore more complex, by setting the max_depth = 3.

In [14]:
DT_g3 = DecisionTreeClassifier(criterion='gini',
                               max_depth = 3,
                               random_state=42)

fig_num+=1
fig = regions_tree(DT_g3, datasets['flbl']['X'], datasets['flbl']['y'], 
             datasets['flbl']['feats'], datasets['flbl']['class'], 
             col_dict, [[219, 57.5],[219, 37.5]], 
             title="Figure %d: Penguins Tree (max_depth = 3)"%fig_num,
             savefig = os.path.join(image_dir,"Soft_Margin_Hyperplane.png")
            )
plt.close()
display(MatplotlibFigure(fig, centered=True))

Extrra

You may be wondering, where is the decision boundary part for the node on the 3rd level in figure 5? Why bother doing this split if it doesn't do anything to our decision boundary?

Well, although not contributing to the decision boundry (because either side of the split is still going to be classified as Gentoo) this split does improve our "splitting criterion" (information gain), as the right leaf node is now pure. More on this later in the notes!

We could also use more than 2 features as seen below.

NOTES

  • Although more features are harder to plot decision boundaries (see dtreeviz).
In [15]:
# TODO: time allowing, do a 3d plot of the decision boundary.
DT_g = DecisionTreeClassifier(criterion='gini',
                              max_depth = 3,
                              random_state=42)
DT_g.fit(datasets['bin']['X'], datasets['bin']['y'])

fig, axes = plt.subplots(figsize = (width_inch, height_inch))
tp = tree.plot_tree(DT_g,
                    feature_names=datasets['bin']['feats'], 
                    class_names=datasets['bin']['class'],
                    filled = True,
                   impurity = False,
               ax=axes
              )

for i, node in enumerate(tp):
    for class_name in datasets['bin']['class']:
        if search(class_name, node.get_text()):
            tp[i].set_backgroundcolor(col_dict[class_name])

fig_num+=1            
plt.title("Figure %d: Multiple Features Penguins Tree (max_depth = 3)"%fig_num)
plt.savefig(os.path.join(image_dir, "figurex.png"))
plt.close()
display(MatplotlibFigure(fig, centered=True))

Notes

  • There is an example of a 3D decision region on pg. 547 in Murphy, K. P. (2012). Machine learning: a probabilistic perspective. MIT press.

We could also also easily extend this to have more than a 2 (binary) class labels.

In [16]:
fig_num+=1   
fig = regions_tree(DT_g, datasets['multi2']['X'], datasets['multi2']['y'], 
             datasets['multi2']['feats'], datasets['multi2']['class'], 
             col_dict,
             title="Figure %d: Multiple Penguin Classes Tree (max_depth = 3)"%fig_num
            )         
plt.close()
display(MatplotlibFigure(fig, centered=True))
In [17]:
DT_g2 = DecisionTreeClassifier(criterion='gini',
                               random_state=42)
DT_g2.fit(datasets['multi']['X'], datasets['multi']['y'])

fig, axes = plt.subplots(figsize=(10, 5))
tp = tree.plot_tree(DT_g2,
                    feature_names=datasets['multi']['feats'], 
                    class_names=datasets['multi']['class'],
                    filled = True,
                   impurity = False,
               ax=axes
              )
for i, node in enumerate(tp):
    for class_name in datasets['bin']['class']:
        if search(class_name, node.get_text()):
            tp[i].set_backgroundcolor(col_dict[class_name])

plt.title("Multiple Features and Penguin Classes Tree")
plt.savefig(os.path.join(image_dir, "figurey.png"))
plt.show()

Estimating Class Probabilities3

We can estimate the probability that an instance belongs to a particular class easily.

For our new observation we...

  1. traverse the tree to find the leaf node the instance is assigned to,
  2. return the ratio of training instances of class $c$ in that node,
  3. assign our observation to the class with the highest probability.

Example

Using the model in figure 7, we could find the probability of the following penguins species:

In [18]:
test_x = datasets['multi2']['df'].sample(random_state=1)[['bill_length_mm', 'flipper_length_mm']]
display(test_x)
test_x = test_x.values
DT_g.fit(datasets['multi2']['X'], datasets['multi2']['y'])
display(pd.DataFrame(DT_g.predict_proba(test_x), columns = datasets['multi2']['class']).round(4))
print("Predicted Species is " + datasets['multi2']['class'][DT_g.predict(test_x)[0]])
bill_length_mm flipper_length_mm
65 41.6 192.0
Adelie Gentoo Chinstrap
0 0.0 0.0204 0.9796
Predicted Species is Chinstrap

In a general sense this approach is pretty simple, however there are a number of design choices and considerations we have to make including5:

  • How do we decide which features to select for splitting a parent node into child nodes?
  • How do we decide where to split features?
  • When do we stop growing a tree?
  • How do we make predictions if no attributes exist to perfectly separate non-pure nodes further?

3. Specific Decision Tree Algorithms

Most decision tree algorithms address the following implimentation choices differently5:

  • Splitting Criterion: Information gain, statistical tests, objective function, etc.
  • Number of Splits: Binary or multi-way.
  • Variables: Discrete vs. Continuous.
  • Pruning: Pre- vs. Post-pruning.

There are a number of decision tree algorithms, prominant ones include:

  • Iterative Dichotomizer 3 (ID3)
  • C4.5
  • CART

Notes

  • Binary means nodes always have two children.

CART

Scikit-Learn uses an optimised version of the Classification And Regression Tree (CART) algorithm.

  • Splitting Criterion: Information gain
  • Number of Splits: Binary
  • Independent Variables (Features): Continuous
  • Dependent variable: Continuous or Categorical
  • Pruning: Pre- & Post-pruning

Notes

Information Gain4

An algorithm starts at a tree root and then splits the data based on the feature, $f$, that gives the largest information gain, $IG$.

To split using information gain relies on calculating the difference between an impurity measure of a parent node, $D_p$, and the impurities of its child nodes, $D_j$; information gain being high when the sum of the impurity of the child nodes is low.

We can maximise the information gain at each split using,

$$IG(D_p,f) = I(D_p)-\sum^m_{j=1}\frac{N_j}{N_p}I(D_j),$$

where $I$ is out impurity measure, $N_p$ is the total number of samples at the parent node, and $N_j$ is the number of samples in the $j$th child node.

Some algorithms, such as Scikit-learn's implimentation of CART, reduce the potential search space by implimenting binary trees:

$$IG(D_p,f) = I(D_p) - \frac{N_\text{left}}{N_p}I(D_\text{left})-\frac{N_\text{right}}{N_p}I(D_\text{right}).$$

Notes

  • The CART algorithm is greedy - meaning it searches for the optimum split at each level. It does not check if this is the best split to improve impurity further down the tree3.
  • To find the optimal tree is known as an NP-Complete problem, meaning it is intractable even for small training sets3.

Three impurity measures that are commonly used in binary decision trees are the classification error ($I_E$), gini impurity ($I_G$), and entropy ($I_H$)4.

In [19]:
def gini(p):
    return p * (1 - p) + (1 - p) * (1 - (1 - p))


def entropy(p):
    return - p * np.log2(p) - (1 - p) * np.log2((1 - p))


def error(p):
    return 1 - np.max([p, 1 - p])

# edited from https://github.com/rasbt/python-machine-learning-book-3rd-edition/blob/master/ch03/ch03.ipynb
def inf_gain_plot(plot_type="All", title=None):

    x = np.arange(0.0, 1.0, 0.01)

    ent = [entropy(p) if p != 0 else None for p in x]
    sc_ent = [e * 0.5 if e else None for e in ent]
    err = [error(i) for i in x]
    
    lab_ls_c = zip([ent, sc_ent, gini(x), err], 
                        ['Entropy', 'Entropy (scaled)', 
                         'Gini impurity', 'Misclassification error'],
                        ['-', '-', '--', '-.'],
                        ['black', 'gray', 'red', 'green', 'cyan'])
    
    if plot_type not in ["All", "Entropy", "Gini", "Misclassification"]:
        print("Not a valid `plot_type`")
        return

    fig = plt.figure(figsize = (width_inch, height_inch))
    ax = plt.subplot(111)
    for i, lab, ls, c, in lab_ls_c:
        alpha = 0.1
        if plot_type == "All":
            alpha = 1.
        elif plot_type == "Entropy" and lab in ['Entropy', 'Entropy (scaled)']:
            alpha = 1.
        elif plot_type == "Gini" and lab == "Gini impurity":
            alpha = 1.
        elif plot_type == "Misclassification" and lab == "Misclassification error":
            alpha = 1.
        
        line = ax.plot(x, i, label=lab, linestyle=ls, lw=2, color=c, alpha=alpha)

    ax.legend(loc='lower center', bbox_to_anchor=(0.5,-0.3),
              ncol=5, fancybox=True, shadow=False)

    ax.axhline(y=0.5, linewidth=1, color='k', linestyle='--')
    ax.axhline(y=1.0, linewidth=1, color='k', linestyle='--')
    plt.ylim([0, 1.1])
    plt.xlabel('p(i=1)')
    plt.ylabel('impurity index')
    if title:
        plt.title(title)
    return fig
    
fig_num+=1
fig = inf_gain_plot(plot_type="All", 
              title="Figure %d: Impurity Measures"%fig_num)
plt.close()
display(MatplotlibFigure(fig, centered=True))

Classification Error4

This is simply the fraction of the training observations in a region that does belongs to the most common class:

$$I_E = 1 - \max\left\{p(i|t)\right\}$$

Here $p(i|t)$ is the proportion of the samples that belong to the $i$th class $c$, for node $t$.

Notes

  • Another way of writing this is $E = 1 - \substack{max\\k}(\hat p_{mk})$, where $\hat p_{mk}$ is the proportion of training observations in the $m$th region that are from the $k$th class1.
In [20]:
fig_num+=1
fig = inf_gain_plot(plot_type="Misclassification",
              title="Figure %d: Classification Error"%fig_num
             )
plt.close()
display(MatplotlibFigure(fig, centered=True))

Entropy Impurity4

For all non-empty classes ($p(i|t) \neq 0$), entropy is given by

$$I_H=-\sum^c_{i=1}p(i|t)\log_2p(i|t).$$

The entropy is therefore 0 if all samples at the node belong to the same class and maximal if we have a uniform class distribution.

For example in binary classification ($c=2$):

  • $I_H=0 \text{ if } p(i=1|t)=1 \text{ or } p(i=0|t)=0$
  • $I_H=1 \text{ if } p(i=1|t)=0.5 \text{ or } p(i=0|t)=0.5$

Notes

  • In binary classification, entropy reaches its minimum (0) when all examples in a given node have the same label; on the other hand, the entropy is at its maximum of 1 when exactly one-half of examples a node are labeled with 1, which would be useless for classification.6
  • Another way of writing this is1 $D = -\sum^K_{k=1}\hat p_{mk}log \hat p_{mk}$.
    • Entropy will take on a value near 0 if the $\hat p_{mk}$'s are all near 0 or 1, therefore will take on a small value if the $m$th node is pure.
In [21]:
fig_num+=1
fig = inf_gain_plot(plot_type="Entropy",
              title="Figure %d: Entropy Impurity"%fig_num
             )
plt.close()
display(MatplotlibFigure(fig, centered=True))
In [22]:
DT_e1 = DecisionTreeClassifier(criterion='entropy',
                               max_depth = 1,
                               random_state=42)

fig = regions_tree(DT_e1, datasets['flbl']['X'], datasets['flbl']['y'], 
             datasets['flbl']['feats'], datasets['flbl']['class'], 
             col_dict,
             l_labels, r_labels, tp_labels, impurity = True,
             title = "Extra: Shallow Tree with Entropy"
            )
plt.close()
display(MatplotlibFigure(fig, centered=True))

Gini Impurity4

Gini Impurity is an alternative measurement, which minimises the probabilty of misclassification,

$$ \begin{align} I_G(t) &= \sum^c_{i=1}p(i|t)(1-p(i|t)) \\ &= 1-\sum^c_{i=1}p(i|t)^2. \end{align} $$

This measure is also maximal when classes are perfectly mixed (e.g. $c=2$):

$$ \begin{align} I_G(t) &= 1 - \sum^c_{i=1}0.5^2 = 0.5. \end{align} $$

Notes

  • It is also a measure of node "purity" as a small value indicates a node contains predominantly observations from a single class.
  • Another way of writing this is1 for $K$ classes, $G = \sum^K_{k=1}\hat p_{mk}(1-\hat p_{mk})$.
    • It takes a small value if all of the $\hat p_{mk}$'s are close to 0 or 1.
  • Whether we use entropy or Gini impurity generally does not really matter, because both have the same concave/bell shape.
    • Gini is computationally more efficient to compute than entropy due to the lack of the log.
    • When they do differ, Gini is more likely to isolate the most frequenct class to its own branch and entropy produce slightly more ballanced trees3.
In [23]:
fig_num+=1
fig = inf_gain_plot(plot_type="Gini",
              title="Figure %d: Gini Impurity"%fig_num)
plt.close()
display(MatplotlibFigure(fig, centered=True))
In [24]:
fig = regions_tree(DT_g1, datasets['flbl']['X'], datasets['flbl']['y'], 
             datasets['flbl']['feats'], datasets['flbl']['class'], 
             col_dict, l_labels, r_labels, tp_labels, impurity = True,
             title = "Extra: Shallow Tree with Gini"
            )
plt.close()
display(MatplotlibFigure(fig, centered=True))

Why not Classification Error?

Classification Error is rarely used for information gain in practice.

This is because it can mean that tree growth gets stuck and error doesnt improve, this is not the case for a concave function such as entropy or gini.

Notes

  • Classification Error its not even an option in scikit-learn.
  • Another example of this is given in exersise 3.
In [25]:
# taken from https://github.com/rasbt/stat479-machine-learning-fs19/blob/master/06_trees/06-trees__notes.pdf
fig_num+=1
print(color.BOLD+color.UNDERLINE+"Figure %d: Child Node Averages"%fig_num+color.END)
Image(os.path.join(image_dir, "entropy_pc.png"), width=550)
Figure 12: Child Node Averages
Out[25]:

Feature Importance10,11

Decision trees allow us assess the importance of each feature for classifying the data,

$$ fi_j = \frac{\sum_{t \in s} ni_t}{\sum^m_t ni_t} $$

where $ni_t$ is the $t$th nodes importance, and $s$ are the indices of nodes that split on feature $fi_j$.

We often assess the normalized total reduction of the criterion (e.g. Gini) brought by that feature,

$$ normfi_j = \frac{fi_j}{\sum^p_j fi_j}. $$
In [26]:
def tree_feat_import(DT, X, y, feat_names, class_names, title):

    DT.fit(X, y)
    # get the importances for the features
    importances = DT.feature_importances_

    importances_series = pd.Series(importances,index=feat_names).sort_values(ascending = False)
    
    fig, axes = plt.subplots(ncols = 2, figsize = (width_inch*2, height_inch))
    axes = axes.flatten()
    
    with plt.style.context("classic"):
        plt.sca(axes[0])
        tp = tree.plot_tree(DT,
                       feature_names=feat_names, 
                       class_names=class_names,
                       filled = True)
        for i, node in enumerate(tp):
            for class_name in datasets['bin']['class']:
                if search(class_name, node.get_text()):
                    tp[i].set_backgroundcolor(col_dict[class_name])

    plt.sca(axes[1])
    # plot the important features
    importances_series.plot.barh(legend =False, grid=False)
    plt.suptitle(title)
    plt.tight_layout()

    # summarize feature importance
    #for i,v in enumerate(importances):
    #    print(color.BOLD+feat_names[i]+color.END+": %.3f" % (v))

    #print(color.BOLD+"total: "+color.END + str(round(sum(importances),2)))
    
    return fig

fig_num+=1
fig = tree_feat_import(DT_g, datasets['bin']['X'], datasets['bin']['y'], 
                 datasets['bin']['feats'],datasets['bin']['class'],
                 'Figure %d: Feature Importances for Classifying Adelie and Gentoo Penguins'%fig_num)
plt.close()
display(MatplotlibFigure(fig, centered=True))

Pruning

Question: When do we stop growing a tree?

Occam’s razor: Favor a simpler hypothesis, because a simpler hypothesis that fits the data equally well is more likely or plausible than a complex one5.

To minimize overfitting, we can either set limits on the trees before building them (pre-pruning), or reduce the tree by removing branches that do not significantly contribute (post-pruning).

NOTES

  • In other words, if decision trees are not pruned, they have a high risk of overfitting to the training data5.

Dataset Example: Breast Cancer Wisconsin Dataset12

Digitized image of a fine needle aspirate of a breast mass. The features describe characteristics of the cell nuclei present in the image.

The dataset was created from digitized images of healthy (benign) and cancerous (malignant) tissues.


Image from Levenson et al. (2015), PLOS ONE, doi:10.1371/journal.pone.0141357.

In [27]:
Image(os.path.join(image_dir, "cancer_tissue.png"), width=600)
Out[27]:

Notes

  • Fine needle aspiration is a type of biopsy procedure.

Extra

You can explore the data below although I recomend limiting the number of features to plot for ease of viewing.

In [28]:
X, y = load_breast_cancer(return_X_y=True, as_frame=True)
display(X.info())
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 569 entries, 0 to 568
Data columns (total 30 columns):
 #   Column                   Non-Null Count  Dtype  
---  ------                   --------------  -----  
 0   mean radius              569 non-null    float64
 1   mean texture             569 non-null    float64
 2   mean perimeter           569 non-null    float64
 3   mean area                569 non-null    float64
 4   mean smoothness          569 non-null    float64
 5   mean compactness         569 non-null    float64
 6   mean concavity           569 non-null    float64
 7   mean concave points      569 non-null    float64
 8   mean symmetry            569 non-null    float64
 9   mean fractal dimension   569 non-null    float64
 10  radius error             569 non-null    float64
 11  texture error            569 non-null    float64
 12  perimeter error          569 non-null    float64
 13  area error               569 non-null    float64
 14  smoothness error         569 non-null    float64
 15  compactness error        569 non-null    float64
 16  concavity error          569 non-null    float64
 17  concave points error     569 non-null    float64
 18  symmetry error           569 non-null    float64
 19  fractal dimension error  569 non-null    float64
 20  worst radius             569 non-null    float64
 21  worst texture            569 non-null    float64
 22  worst perimeter          569 non-null    float64
 23  worst area               569 non-null    float64
 24  worst smoothness         569 non-null    float64
 25  worst compactness        569 non-null    float64
 26  worst concavity          569 non-null    float64
 27  worst concave points     569 non-null    float64
 28  worst symmetry           569 non-null    float64
 29  worst fractal dimension  569 non-null    float64
dtypes: float64(30)
memory usage: 133.5 KB
None
In [29]:
X_plot = X.copy()
X_plot['type'] = y

fig_num+=1
sns.pairplot(X_plot.iloc[:,:8].join(X_plot.iloc[:,-1]), hue='type')
plt.suptitle("Figure %d: Breat Cancer Pairplot"%fig_num)
plt.tight_layout()
plt.show()
In [30]:
col_dict = {
    "benign":"#ff7600",
    "malignant":"#057576"}

shape_dict = {
    "benign":"o",
    "malignant":"X"}

DT_g = DecisionTreeClassifier(criterion='gini',
                               random_state=42)

cancer_features = ['mean radius','mean smoothness']
fig = regions_tree(DT_g, X[cancer_features].values, y.values, cancer_features, 
             load_breast_cancer().target_names, col_dict, yaxis_lim=[0.05,0.20],
             title = "Figure %d: Breast Cancer Overfitting"%fig_num
            )
plt.close()
display(MatplotlibFigure(fig, centered=True))

Pre-Pruning

An a priori limit on nodes, or tree depth, is often set to avoid overfitting due to a deep tree4,5.

Notes

  • Another one is to stop growing if a split is not statistically significant (e.g., $\chi^2$ test)5. However this is not yet availble in sklearn, although I'm sure you can find some code for it somewhere on the internet.
In [31]:
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV, KFold, StratifiedKFold

def hyper_search(model, params, X, y, save_path, n_iter=60, metric="accuracy", 
                 cv = KFold(5), random_state=42, refit=True,
                 overwrite=False, warning=False):
    if os.path.exists(save_path) and overwrite==False:
        #load the model
        models = joblib.load(save_path)
    else:
        # check all param inputs are lists
        if all(type(x)==list for x in params.values()):
            search_type = "Gridsearch"
            models = GridSearchCV(model, param_grid=params, scoring=metric, cv=cv, 
                                  refit=refit, return_train_score=True)
            n_iter = len(list(itertools.product(*list(iter(params.values())))))
        else:
            search_type = "Randomsearch"
            models = RandomizedSearchCV(model, param_distributions=params,
                                        n_iter=n_iter, scoring=metric, cv=cv,
                                        refit=refit, random_state=random_state,
                                        return_train_score=True)
        
        start = time()
        if warning:
            models.fit(X, y)
        else:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                models.fit(X, y)
        
        print(search_type + " took %.2f seconds for %d candidates" % ((time() - start), n_iter))
        joblib.dump(models, save_path)
    
    return models

# specify parameters and distributions to sample from
param_grid = {"max_depth":list(range(1,15))}

depth_gs = hyper_search(DT_g, param_grid, X[cancer_features].values, y,
                        os.path.join(os.getcwd(), "Models", "depth_gs_object.pkl"), 
                        overwrite=True)

pd.DataFrame(depth_gs.cv_results_).sort_values("rank_test_score")[["param_max_depth", "mean_test_score", "std_test_score"]].head()
Gridsearch took 0.18 seconds for 14 candidates
Out[31]:
param_max_depth mean_test_score std_test_score
0 1 0.889365 0.076523
2 3 0.887564 0.043385
3 4 0.876929 0.046898
1 2 0.866387 0.069261
4 5 0.866372 0.041904
In [32]:
fig_num+=1
print("Using GridSearch, the best maximum tree depth was found to be: "+str(list(depth_gs.best_params_.values())[0]))
fig = regions_tree(depth_gs.best_estimator_, X[cancer_features].values, y.values, 
             cancer_features, load_breast_cancer().target_names, col_dict, 
             yaxis_lim = [0.05,0.20], title= 'Figure %d: "Best" Maximum Tree Depth'%fig_num,
             alpha=0.1
            )
plt.close()
display(MatplotlibFigure(fig, centered=True))
Using GridSearch, the best maximum tree depth was found to be: 1
In [33]:
cols = ["param_max_depth", "mean_test_score", "mean_train_score"]

scores_df = pd.DataFrame(depth_gs.cv_results_).sort_values("param_max_depth")[cols]

scores_df = pd.melt(scores_df, value_vars=['mean_test_score','mean_train_score'],
                     id_vars=['param_max_depth'], var_name = ['type'])

fig_num+=1
fig = plt.figure(figsize = (width_inch, height_inch))
sns.lineplot(data = scores_df, x='param_max_depth', y = 'value', hue = 'type')
plt.title('Figure %d: Tree Depth and Dataset Accuracy'%fig_num)
plt.close()
display(MatplotlibFigure(fig, centered=True))

We could also set a minimum number of data points for each node5.

In [34]:
# specify parameters and distributions to sample from
param_grid = {"min_samples_leaf":list(range(1,25))}

lsamples_gs = hyper_search(DT_g, param_grid, X[cancer_features].values, y,
                           os.path.join(os.getcwd(), "Models", "lsamples_gs_object.pkl"), 
                           overwrite=True)

pd.DataFrame(lsamples_gs.cv_results_).sort_values("rank_test_score")[["param_min_samples_leaf", "mean_test_score", "std_test_score"]].head()
Gridsearch took 0.26 seconds for 24 candidates
Out[34]:
param_min_samples_leaf mean_test_score std_test_score
14 15 0.889303 0.046233
15 16 0.889303 0.046233
16 17 0.887549 0.047150
8 9 0.885779 0.057628
12 13 0.885763 0.045751
In [35]:
print("Using GridSearch, the best minimum samples per leaf was found to be: "+str(list(lsamples_gs.best_params_.values())[0]))
fig_num+=1
fig = regions_tree(lsamples_gs.best_estimator_, X[cancer_features].values, y.values, 
             cancer_features, load_breast_cancer().target_names, col_dict, 
             yaxis_lim = [0.05,0.20], title = 'Figure %d: "Best" Minimum Samples per Leaf'%fig_num,
             alpha=0.1
            )
plt.close()
display(MatplotlibFigure(fig, centered=True))
Using GridSearch, the best minimum samples per leaf was found to be: 15
In [36]:
cols = ["param_min_samples_leaf", "mean_test_score", "mean_train_score"]

scores_df = pd.DataFrame(lsamples_gs.cv_results_).sort_values("param_min_samples_leaf")[cols]

scores_df = pd.melt(scores_df, value_vars=['mean_test_score','mean_train_score'],
                     id_vars=['param_min_samples_leaf'], var_name = ['type'])

fig_num+=1
fig = plt.figure(figsize = (width_inch, height_inch))
sns.lineplot(data = scores_df, x='param_min_samples_leaf', y = 'value', hue = 'type')
plt.title('Figure %d: Minimum Samples per Leaf and Dataset Accuracy'%fig_num)
plt.close()
display(MatplotlibFigure(fig, centered=True))

Post-Pruning

In general, post-pruning consists of going back through the tree once it has been created and removing branches that do not significantly contribute to the error reduction and replacing them with leaf nodes6

Two common approaches are reduced-error pruning and cost-complexity pruning

  • Reduced-error pruning5
    • Greedily remove nodes based on validation set performance
    • Generally improves performance but can be problematic for limited data set sizes.
  • Cost-complexity pruning5
    • Recursively finds the node with the “weakest link”.
    • Nodes are characterized by $\alpha \geq 0$, and nodes with the smallest effective $\alpha$ are pruned first7.
    • The trees are then defined as $I + \alpha|N|$, where $I$ is an impurity measure, such as the total misclassification rate of the terminal nodes, $\alpha$ is a tuning parameter, and $|N|$ is the total number of nodes5.

Notes

  • An early "bad" split may lead to a good split later, therfore we may want to grow a large tree first and prune it back to obtain a subtree1.
  • $\alpha$ is the price to pay for having a tree with many nodes, so this tends to minimise to a smaller subtree. This is reministant of the lasso.

Cost-complexity pruning7

Using Scikit-learn, we can recursively fit a complex tree with no prior pruning and have a look at the effective alphas and the corresponding total leaf impurities at each step of the pruning process.

As alpha increases, more of the tree is pruned, thus creating a decision tree that generalizes better.

We can select the alpha that reduces the distance between the train and validation scores.

Notes

  • In Scikit-learn 0.22 the parameter ccp_alpha was introduced (short for Cost Complexity Pruning- Alpha)
  • DecisionTreeClassifier.cost_complexity_pruning_path returns the effective alphas and the corresponding total leaf impurities at each step of the pruning process.
In [37]:
# https://github.com/krishnaik06/Post_Pruning_DecisionTre/blob/master/plot_cost_complexity_pruning.ipynb
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X[cancer_features].values, y.values,
                                                    test_size = 0.2, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train,
                                                    test_size = 0.2, random_state=42)

DT = DecisionTreeClassifier(criterion='gini',
                            random_state=42)
path = DT.cost_complexity_pruning_path(X_train,y_train)
ccp_alphas, impurities = path.ccp_alphas, path.impurities

clfs = []
for ccp_alpha in ccp_alphas:
    clf = DecisionTreeClassifier(random_state=0, ccp_alpha=ccp_alpha)
    clf.fit(X_train, y_train)
    clfs.append(clf)
#print("Number of nodes in the last tree is: {} with ccp_alpha: {}".format(
#      clfs[-1].tree_.node_count, ccp_alphas[-1]))


train_scores = [clf.score(X_train, y_train) for clf in clfs]
test_scores = [clf.score(X_test, y_test) for clf in clfs]

fig, axes = plt.subplots(ncols=2, figsize = (width_inch*2, height_inch))
axes = axes.flatten()
for i, ax in enumerate(axes):
    ax.set_xlabel("alpha")
    ax.set_ylabel("accuracy")
    ax.plot(ccp_alphas, train_scores, marker='o', label="train",
            drawstyle="steps-post")
    ax.plot(ccp_alphas, test_scores, marker='o', label="test",
            drawstyle="steps-post")
    ax.legend()
    if i ==1:
        ax.set_xlim(0.,0.008)
        ax.set_ylim(0.85,1.01)

fig_num+=1
plt.suptitle("Figure %d: Accuracy vs Alpha for Training and Testing Sets"%fig_num)
plt.close()
display(MatplotlibFigure(fig, centered=True))

Then we can train a decision tree using the chosen effective alpha.

In [38]:
DT_ccp = DecisionTreeClassifier(criterion='gini',
                               random_state=42,
                               ccp_alpha = 0.0065                            
                              )
fig_num+=1
fig = regions_tree(DT_ccp, X[cancer_features].values, y.values, 
             cancer_features, load_breast_cancer().target_names, 
             col_dict, yaxis_lim = [0.05,0.20],
             title = 'Figure %d: "Best" Alpha for Accuracy'%fig_num,
             alpha=0.1)

# get the importances for the features
importances = DT_ccp.feature_importances_

importances_series = pd.Series(importances,index=cancer_features).sort_values(ascending = False)

# summarize feature importance
for i,v in enumerate(importances):
    print(color.BOLD+cancer_features[i]+color.END+": %.3f" % (v))
    
plt.close()
display(MatplotlibFigure(fig, centered=True))
mean radius: 0.902
mean smoothness: 0.098

Other Algorithms

ID3 - Iterative Dichotomizer 38

  • Splitting Criterion: Maximises Information Gain/ Minimises Entropy
  • Number of Splits: Multiway
  • Variables: Discrete binary and multi-category features
  • Pruning: None

C4.59

  • Splitting Criterion: Maximises Information Gain/ Minimises Entropy
  • Number of Splits: Multiway
  • Variables: Discrete & Continuous (expensive)
  • Pruning: Post-Pruning
  • can handle missing data
  • C5.0 is the latest release version under a proprietary license

Notes

  • you can read an introduction to ID3 in pgs.27-30 of Burkov (2019)

Associated Exercises

Now might be a good time to try exercises 1-6.

References

  1. James, Gareth, Daniela Witten, Trevor Hastie, and Robert Tibshirani. An introduction to statistical learning. Vol. 112. New York: springer, 2013.
  2. Gorman KB, Williams TD, Fraser WR (2014). Ecological sexual dimorphism and environmental variability within a community of Antarctic penguins (genus Pygoscelis). PLoS ONE 9(3):e90081. https://doi.org/10.1371/journal.pone.0090081
  3. GĂ©ron, A. (2017). Hands-on machine learning with Scikit-Learn and TensorFlow: concepts, tools, and techniques to build intelligent systems. " O'Reilly Media, Inc.".
  4. Raschka, Sebastian, and Vahid Mirjalili. Python Machine Learning, 2nd Ed. Packt Publishing, 2017.
  5. https://github.com/rasbt/stat479-machine-learning-fs19/blob/master/06_trees/06-trees__notes.pdf
  6. Burkov, A. (2019). The hundred-page machine learning book (Vol. 1). Canada: Andriy Burkov.
  7. https://scikit-learn.org/stable/auto_examples/tree/plot_cost_complexity_pruning.html#:~:text=Cost%20complexity%20pruning%20provides%20another,the%20number%20of%20nodes%20pruned.
  8. Quinlan, J. R. (1986). Induction of decision trees. Machine learning, 1 (1), 81-106
  9. Quinlan, J. R. (1993). C4. 5: Programming for machine learning. Morgan Kauffmann, 38, 48.
  10. https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html#sklearn.tree.DecisionTreeClassifier
  11. https://towardsdatascience.com/the-mathematics-of-decision-trees-random-forest-and-feature-importance-in-scikit-learn-and-spark-f2861df67e3#:~:text=Feature%20importance%20is%20calculated%20as,the%20more%20important%20the%20feature.
  12. https://archive.ics.uci.edu/ml/datasets/Breast+Cancer+Wisconsin+(Diagnostic)
In [ ]:
import sys
from shutil import copyfile

# where the HTML template is located
dst = os.path.join(sys.prefix, 'lib', 'site-packages', 'nbconvert', 'templates', "classic.tplx")

# If its not located where it should be
if not os.path.exists(dst):
    # uses a nb_pdf_template
    curr_path = os.path.join(os.getcwd(),"..", "Extra", "classic.tplx")
    # copy where it is meant to be
    copyfile(curr_path, dst)

   
# Create HTML notes document (preferred)
!jupyter nbconvert 1_Decision_Trees.ipynb \
    --to html \
    --output-dir . \
    --template classic
# Create html slides (issues)
!jupyter nbconvert 1_Decision_Trees.ipynb \
    --to slides \
    --output-dir . \
    --TemplateExporter.exclude_input=True \
    --TemplateExporter.exclude_output_prompt=True \
    --SlidesExporter.reveal_scroll=True

# Create pdf notes document (issues)
!jupyter nbconvert 1_Decision_Trees.ipynb \
    --to html \
    --output-dir ./PDF_Prep \
    --output 1_Decision_Trees_no_code \
    --TemplateExporter.exclude_input=True \
    --TemplateExporter.exclude_output_prompt=True