Growing a (Decision) Tree

post017_full_tree_result.png

In this post I will complete the decision tree by building it node by node with the Gini metric. I will also show a pre-order depth-first traversal, which will be used to draw tree charts with the Graphviz dot language.

The code for this post is available here

In the last post, Sowing a (Decision) Tree, we saw how to evaluate the Gini impurity index across all features.

Basic data structure

Here I describe the content of each node in the decision tree.

The main tree node contains the classification of the current dataset according to the majority class; if two or more classes are tied for the highest count, the result is random.

Together with the predicted class a confidence is calculated, which is the frequency of that class in the current dataset.

Optionally a split may be added, which contains the chosen feature and a floating point threshold separating the left sub-branch from the right; here we can also save the value of the metric used to evaluate the split.

#[derive(Debug)]
pub struct Rule {
    dimension: String,
    cutoff: f64,
    metric: f64,
}

#[derive(Debug)]
pub struct Decision {
    rule: Option<Rule>,
    confidence: f64,
    prediction: String,
}

Predicting class at a given node

Here I will describe how each tree node will predict a class from the training dataset

In the training stage, at each node we have an associated subset of the original training set.

The predicted category is the most populated one; we can attach a confidence to this prediction by estimating the probability of finding a sample of the predicted class in that subset.

pub fn predict_majority_dataframe<'a>(
    data: & 'a DataFrame, target: & str
) -> PolarsResult<Decision> {
    // extract the categorical target column
    let labels = data.column(target)?.categorical()?;

    let total = labels.len() as f64;

    // count all categories and sort them
    let result_count = labels.value_counts()?;

    // get the most frequent category
    let result_cat = result_count.head(Some(1));

    // transform the series into a categorical vector
    let actual_cat = result_cat
        .column(target)?
        .categorical()?;

    // collect all categories as strings
    let string_cat: Vec<String> = actual_cat
        .iter_str()
        .flatten()
        .map(|name| (*name).into())
        .collect();

    // relative frequency of the most common category
    let probability: Vec<f64> = result_cat
        .column("counts")?
        .u32()?
        .iter()
        .flatten()
        .map(|c| (c as f64) / total)
        .collect();
    // return the most common category as a string
    return Ok(
        Decision{
            rule: None,
            prediction: string_cat
                .get(0)
                .unwrap()
                .to_string(),
            confidence: probability
                .get(0)
                .unwrap()
                .to_owned()
        }
    );
}
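
The same idea can be sketched without Polars, over a plain slice of labels. This is only an illustration: the function name `predict_majority` and the tie-breaking rule (the alphabetically smallest label wins) are assumptions of the sketch, while the function above leaves ties to the ordering returned by `value_counts`.

```rust
use std::collections::HashMap;

/// Returns the most frequent label and its relative frequency,
/// or `None` for an empty slice.
/// Assumption of this sketch: ties go to the alphabetically smallest label.
fn predict_majority(labels: &[&str]) -> Option<(String, f64)> {
    // count the occurrences of each label
    let mut counts: HashMap<&str, usize> = HashMap::new();
    for label in labels {
        *counts.entry(*label).or_insert(0) += 1;
    }
    let total = labels.len() as f64;
    counts
        .into_iter()
        // highest count wins; on equal counts the smaller label wins
        .max_by(|a, b| a.1.cmp(&b.1).then_with(|| b.0.cmp(a.0)))
        // the confidence is the share of samples in the winning class
        .map(|(label, count)| (label.to_string(), count as f64 / total))
}
```

With three "a" labels out of four samples the sketch predicts "a" with a confidence of 0.75.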

Evaluating the best split point

The best feature and its split point are found by choosing the pair that optimizes the metric:

  • for each feature
    • the metric is evaluated at all split points
  • all results are joined
  • the best pair of feature and split point is selected according to the metric

pub fn evaluate_best_split<'a>(
    data: & DataFrame,
    features: & HashSet <& str>,
    target: & str,
) -> PolarsResult<Rule> {

    // iteratively evaluate the metric on all features
    let metrics: PolarsResult<Vec<LazyFrame>> = features
        .iter()
        .map(|feature| {
            Ok(evaluate_metric(& data, feature, target)?
                .lazy()
                .with_column(feature.lit().alias("feature")))
        })
        .collect();

    // join all results in a single dataframe
    let concat_rules = UnionArgs {
        parallel: true,
        rechunk: true,
        to_supertypes: true,
    };
    let concat_metrics: DataFrame = concat(metrics?, concat_rules)?.collect()?;

    // search for the best split
    let expr: Expr = col("metrics").lt_eq(col("metrics").min());
    let best_split: DataFrame = concat_metrics
        .clone()
        .lazy()
        .filter(expr)
        .select([col("feature"), col("split"), col("metrics")])
        .collect()?;

    let chosen_features: Vec<String> = best_split
        .column("feature")?
        .str()?
        .iter()
        .flatten()
        .map(String::from)
        .collect();

    let chosen_split_point: f64 = best_split.column("split")?.f64()?.get(0).unwrap();

    let split_metric: f64 = best_split.column("metrics")?.f64()?.get(0).unwrap();
    Ok(Rule {
        dimension: chosen_features
            .get(0)
            .unwrap()
            .to_string(),
        cutoff: chosen_split_point,
        metric: split_metric,
    })
}
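
To make the search concrete, here is a minimal sketch for a single numeric feature that uses only the standard library. It rests on two assumptions that are not confirmed by this post: the candidate split points are the midpoints between consecutive distinct values, and the metric is the Gini impurity of the two halves weighted by their size. The names `gini` and `best_split` are hypothetical.

```rust
use std::collections::HashMap;

/// Gini impurity of a set of labels: 1 minus the sum of squared class frequencies.
fn gini(labels: &[&str]) -> f64 {
    if labels.is_empty() {
        return 0.0;
    }
    let mut counts: HashMap<&str, usize> = HashMap::new();
    for label in labels {
        *counts.entry(*label).or_insert(0) += 1;
    }
    let total = labels.len() as f64;
    1.0 - counts.values().map(|&c| (c as f64 / total).powi(2)).sum::<f64>()
}

/// Tries the midpoint between each pair of consecutive distinct values and
/// returns the (cutoff, weighted impurity) pair with the lowest impurity.
/// Returns `None` when the feature has fewer than two distinct values.
fn best_split(values: &[f64], labels: &[&str]) -> Option<(f64, f64)> {
    let mut sorted: Vec<f64> = values.to_vec();
    sorted.sort_by(|a, b| a.total_cmp(b));
    sorted.dedup();
    let total = values.len() as f64;
    let mut best: Option<(f64, f64)> = None;
    for pair in sorted.windows(2) {
        let cutoff = (pair[0] + pair[1]) / 2.0;
        // same convention as the tree: "higher" is strictly above the cutoff
        let (mut lower, mut higher) = (Vec::new(), Vec::new());
        for (value, label) in values.iter().zip(labels) {
            if *value > cutoff {
                higher.push(*label)
            } else {
                lower.push(*label)
            }
        }
        let metric = (lower.len() as f64 * gini(&lower)
            + higher.len() as f64 * gini(&higher))
            / total;
        // keep the first cutoff that reaches the lowest impurity
        if best.map_or(true, |(_, m)| metric < m) {
            best = Some((cutoff, metric));
        }
    }
    best
}
```

For the values 1, 2, 3, 4 labelled a, a, b, b the sketch picks the cutoff 2.5, which separates the two classes perfectly and therefore has zero impurity.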

Stopping rules

We are going to create the tree using a greedy algorithm, i.e. one node at a time, recursively; while this does not guarantee the best possible result, it makes the problem tractable.

When should this recursion stop?

I’d like to implement three basic stopping rules:

  • the current node contains one class only
  • the current level is equal to the maximum depth provided by the user
  • the current node does not contain more elements than the minimum decided by the user

It is reasonable to split multiple times along the same axis for continuous features, but I’d like to see the effect of dropping a feature once it has been used, so I will leave this as a build option.

        if (!current_features.is_empty()) && // there are still features to split on
            (confidence < 1.0) && // more than one category is present
            (data.shape().0 > self.min_size) && // size is above the minimum threshold
            (level <= self.max_level){ // maximum depth not exceeded
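
The condition above is the negation of the stopping rules: the node is split only while none of them applies. A minimal sketch that gives each check a name may make this easier to read; the `Limits` struct and the `should_split` function are hypothetical helpers, not part of the code of this post.

```rust
/// Hypothetical container for the limits used by the stopping rules.
struct Limits {
    max_level: usize,
    min_size: usize,
}

/// Returns true only when none of the stopping rules applies.
fn should_split(
    limits: &Limits,
    remaining_features: usize,
    confidence: f64,
    rows: usize,
    level: usize,
) -> bool {
    // at least one feature is still available
    let features_left = remaining_features > 0;
    // a confidence of 1.0 means that only one class is left
    let mixed_classes = confidence < 1.0;
    // only datasets larger than the minimum size are split
    let large_enough = rows > limits.min_size;
    // the current level must not exceed the maximum depth
    let depth_allowed = level <= limits.max_level;
    features_left && mixed_classes && large_enough && depth_allowed
}
```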

Builder pattern in Rust

Rust does not have optional parameters with default values; the “builder” pattern is used to emulate this functionality.

This pattern consists of the following:

  • create a default constructor for your structure which requires only the mandatory inputs
  • add a method for each optional field which receives the actual structure (so it takes ownership) and returns it mutated
    • this allows chaining the calls and makes sure that no other part of the code can access the same structure while we are setting it up

In our case we want to store all the relevant tree creation options. The following are mandatory:

  • the names of the features
  • the name of the target column

The following are optional:

  • the maximum depth of the tree (we may set the default to 3)
  • whether to reuse all features after each split (usually true)
  • the minimum size of a dataframe: only larger dataframes will be split

#[derive(Debug)]
pub struct DTreeBuilder<'a>{
    max_level: usize,
    min_size: usize,
    features: HashSet<& 'a str>,
    target: & 'a str,
    reuse_features: bool
}

// uses a struct to define the tree constraints
impl <'a>DTreeBuilder<'a> {
    pub fn new(features: HashSet<& 'a str>, target : & 'a str) -> DTreeBuilder<'a>{
        DTreeBuilder{
            max_level: 3,
            min_size: 1,
            features,
            target,
            reuse_features: true
        }
    }

    pub fn set_max_level(mut self, max_level: usize) -> DTreeBuilder<'a>{
        self.max_level = max_level;
        self
    }

    pub fn set_min_size(mut self, min_size: usize) -> DTreeBuilder<'a>{
        self.min_size = min_size;
        self
    }

    pub fn set_reuse_features(mut self, reuse_features : bool) -> DTreeBuilder<'a>{
        self.reuse_features = reuse_features;
        self
    }
}
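
A short usage sketch shows how the chained setters read at the call site. To keep it self-contained it repeats a trimmed-down copy of the builder; the `example_builder` function and the column names are only illustrative assumptions.

```rust
use std::collections::HashSet;

// Trimmed-down copy of the builder, repeated only to keep the sketch self-contained.
#[derive(Debug)]
pub struct DTreeBuilder<'a> {
    max_level: usize,
    min_size: usize,
    features: HashSet<&'a str>,
    target: &'a str,
    reuse_features: bool,
}

impl<'a> DTreeBuilder<'a> {
    // mandatory inputs only, everything else gets a default value
    pub fn new(features: HashSet<&'a str>, target: &'a str) -> Self {
        DTreeBuilder { max_level: 3, min_size: 1, features, target, reuse_features: true }
    }

    // each setter takes ownership and gives the structure back
    pub fn set_max_level(mut self, max_level: usize) -> Self {
        self.max_level = max_level;
        self
    }

    pub fn set_reuse_features(mut self, reuse_features: bool) -> Self {
        self.reuse_features = reuse_features;
        self
    }
}

/// Hypothetical configuration: a deeper tree that drops each feature after use.
fn example_builder() -> DTreeBuilder<'static> {
    let features = HashSet::from(["petal_length", "petal_width"]);
    DTreeBuilder::new(features, "species")
        .set_max_level(5)
        .set_reuse_features(false)
}
```

Options that are not mentioned in the chain, like the minimum size here, keep the default chosen in `new`.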

Iterative node building

The public entry point receives only the original training dataset.

impl <'a>DTreeBuilder<'a> {
    // ...
    pub fn build(
        & self,
        data: & DataFrame,
    ) -> PolarsResult<btree::Tree<Decision>> {
        let current_features = if !self.reuse_features {
            let feats = self.features.clone();
            Some(feats)
        }else{
            None
        };
        println!("{1:->0$}{2:?}{1:-<0$}", 20, "\n", self);
        let root = self.build_node(data, 1, & current_features)?;
        Ok(btree::Tree::from_node(root))
    }
    // ...
}

Until a stopping condition is met, all features are evaluated at each node to find the most effective split according to our current metric (the Gini impurity index); then the dataset is divided at the chosen threshold and a child node is built recursively for each part.

impl <'a>DTreeBuilder<'a> {
    // ...
    fn build_node(
        & self,
        data: & DataFrame,
        level: usize, // tree depth
        features: & Option<HashSet<& str>>, // optionally used to remove features
    ) -> PolarsResult<btree::Node<Decision>> {
        let prediction = predict_majority_dataframe(data, self.target)?;
        let confidence = prediction.confidence;
        let mut node = btree::Node::new(prediction);
        let current_features = features.clone().unwrap_or(self.features.clone());
        // check stop conditions
        if (!current_features.is_empty()) && // there are still features to split on
            (confidence < 1.0) && // more than one category is present
            (data.shape().0 > self.min_size) && // size is above the minimum threshold
            (level <= self.max_level){ // maximum depth not exceeded
                let rule = evaluate_best_split(data, & current_features, self.target)?;
                let higher: DataFrame = data
                    .clone()
                    .lazy()
                    .filter(col(& rule.dimension).gt(rule.cutoff))
                    .collect()?;
                let lower: DataFrame = data
                    .clone()
                    .lazy()
                    .filter(col(& rule.dimension).lt_eq(rule.cutoff))
                    .collect()?;
                // remove features only if requested by the user
                let next_features = match features {
                    None => None,
                    Some(feats) => {
                        let mut reduced_features =
                            feats.clone();
                        reduced_features.remove(rule.dimension.as_str());
                        Some(reduced_features)
                    }
                };
                node.value.rule = Some(rule);
                // recursively build the two child nodes
                node.left = self
                    .build_node(& higher, level + 1, & next_features)?
                    .into();
                node.right = self
                    .build_node(& lower, level + 1, & next_features)?
                    .into();
            }
        Ok(node)
    }
    // ...
}

Dumping the tree

Pre-order depth first traversal

In a previous post I showed how to create a depth-first traversal algorithm.

To be more specific, it was a post-order traversal: you can find more details about the kinds of traversal algorithms in this Wikipedia page.

To draw our tree we now need a pre-order traversal iterator. Caleb Sander Mateos suggested in a comment how to use a stack to implement this kind of traversal; my code follows.

I added some more useful information to the iterator result:

  • the current node depth
  • its number according to the binary position described here
  • a boolean describing if the current node is a leaf

pub struct PreOrderTraversalIter<'a, T>{
    stack: Vec<TreeStackItem<'a, T>>
}

pub struct TreeItem<'a, T>{
    pub id: usize,
    pub level: usize,
    pub value: & 'a T,
    pub leaf: bool
}

struct TreeStackItem<'a, T>{
    id: usize,
    level: usize,
    node: & 'a Node<T>
}

impl<'a, T> PreOrderTraversalIter<'a, T>{
    fn new(tree: & 'a Tree<T>) -> PreOrderTraversalIter<'a, T>{
        match tree.root {
            None => PreOrderTraversalIter { stack: Vec::new() },
            Some(ref node) => PreOrderTraversalIter {
                stack: vec![
                    TreeStackItem{
                        id: 1,
                        level: 1,
                        node: & node
                    }],
            },
        }
    }
}

impl<'a, T> Iterator for PreOrderTraversalIter<'a, T>{
    type Item = TreeItem<'a, T>;
    fn next(& mut self) -> Option<Self::Item> {
        if let Some(item) = self.stack.pop() {
            let mut leaf: bool = true;
            if let Some(ref left) = item.node.left{
                self.stack.push(
                    TreeStackItem{
                        id: item.id << 1,
                        level: item.level + 1,
                        node: & left
                    });
                leaf = false;
            }
            if let Some(ref right) = item.node.right{
                self.stack.push(
                    TreeStackItem{
                        id: (item.id << 1) + 1,
                        level: item.level + 1,
                        node: & right
                    });
                leaf = false;
            }
            Some(
                TreeItem{
                    id: item.id,
                    level: item.level,
                    value: & item.node.value,
                    leaf
                })
        }else{
            None
        }
    }
}
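
The numbering scheme can be checked on a tiny tree with a self-contained sketch. It uses a minimal `Node` type as a stand-in for `btree::Node` and collects the items in a vector instead of implementing `Iterator`. Note one deliberate difference: the sketch pushes the right child first, so the left child is visited first, while the iterator above pushes the left child first and therefore visits the right branch first. Both orders are pre-order, because the parent always comes before its children.

```rust
/// Minimal binary tree node, a stand-in for the `btree::Node` used in the post.
struct Node<T> {
    value: T,
    left: Option<Box<Node<T>>>,
    right: Option<Box<Node<T>>>,
}

/// Visits the tree in pre-order and returns (id, level, leaf flag, value) tuples.
/// The root has id 1; the children of node `n` are `2n` (left) and `2n + 1` (right).
fn pre_order<T>(root: &Node<T>) -> Vec<(usize, usize, bool, &T)> {
    let mut visited = Vec::new();
    let mut stack = vec![(1usize, 1usize, root)];
    while let Some((id, level, node)) = stack.pop() {
        let leaf = node.left.is_none() && node.right.is_none();
        visited.push((id, level, leaf, &node.value));
        // push right first so that left is popped, and therefore visited, first
        if let Some(right) = node.right.as_deref() {
            stack.push((2 * id + 1, level + 1, right));
        }
        if let Some(left) = node.left.as_deref() {
            stack.push((2 * id, level + 1, left));
        }
    }
    visited
}
```

Because of the binary numbering, the parent of any node other than the root is found by dropping the last bit of its id (`id >> 1`), and an even id always marks a left child.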

Creating a Dot DSL reification

This is an example of the chart of a sorting tree:

post017_example_tree.png

I chose Graphviz to automatically generate a chart of my tree graph, using a subset of its graph language, dot.

In these cases the best way for me to create a language generator is to choose which parts of its grammar to turn into data objects; I chose:

  • nodes
  • edges
  • ranks to put nodes at the same level in the same row

A rank is actually a list of node names, i.e. strings, so a vector of strings would be enough; but we need a specialized representation, so I used a wrapper type.

struct DotNode{
    name: String,
    label: String,
    shape: String,
    style: Option<String>,
    fillcolor: Option<String>
}

struct DotEdge{
    first: String,
    second: String,
    label: String
}

// wrapper type
#[derive(Default)]
struct DotRank(Vec<String>);

impl DotRank{
    fn new() -> DotRank{
        DotRank(
            Vec::new()
        )
    }
}

For each of them I created a text representation following the dot grammar.

impl Display for DotNode{
    fn fmt(& self, f: & mut fmt::Formatter) -> fmt::Result {
        let style: String = match self.style {
            None => "".into(),
            Some(ref kind) => format!(", style=\"{}\"",kind)
        };
        let fillcolor: String = match self.fillcolor {
            None => "".into(),
            Some(ref kind) => format!(", fillcolor=\"{}\"",kind)
        };
        write!(f,"{} [label=\"{}\", shape=\"{}\"{}{}];",self.name, self.label, self.shape, style, fillcolor)
    }
}


impl Display for DotEdge{
    fn fmt(& self, f: & mut fmt::Formatter) -> fmt::Result {
        write!(f,"{} -> {} [label=\"{}\"]",self.first,self.second,self.label)
    }
}

impl Display for DotRank{
    fn fmt(& self, f: & mut fmt::Formatter) -> fmt::Result {
        write!(f,"{{rank = same; {};}}",self.0.join("; "))
    }
}

Finally I created a full object which contains all of these elements:

pub struct Dot{
    nodes: Vec<DotNode>,
    edges: Vec<DotEdge>,
    ranks: Vec<DotRank>
}

To simplify the building I created a method to add each kind of element

impl Dot{
    pub fn new() -> Self{
        Dot{
            nodes: Vec::new(),
            edges: Vec::new(),
            ranks: Vec::new(),
        }
    }

    pub fn add_node(
        & mut self,
        name: String,
        label: String,
        shape: String,
        style: Option<String>,
        fillcolor: Option<String>
    ) -> () {
        let node = DotNode{
            name,
            label,
            shape,
            style,
            fillcolor
        };
        self.nodes.push(node);
    }

    pub fn add_edge(
        & mut self,
        first: String,
        second: String,
        label: String
    ) -> () {
        let node = DotEdge{first, second, label};
        self.edges.push(node);
    }

    pub fn append_rank(
        & mut self,
        index: usize,
        node: String
    ) -> (){
        //ensure space
        while self.ranks.len() <= index {
            self.ranks.push(DotRank::new())
        }
        // update the rank at index adding the node
        self.ranks[index].0.push(node);
    }
}

Finally, its transformation into a string:

impl Display for Dot{
    fn fmt(& self, f: & mut fmt::Formatter) -> fmt::Result {
        let mut graph: Vec<String>=vec!["digraph {".into(),"rankdir = BT;".into(),"subgraph{".into()];
        for node in & self.nodes{
            graph.push(node.to_string());
        }
        for edge in & self.edges{
            graph.push(edge.to_string());
        }
        for rank in & self.ranks{
            graph.push(rank.to_string());
        }
        graph.push("}".into());
        graph.push("}".into());
        write!(f,"{}",graph.join("\n"))
    }
}
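
The post does not show how the traversal items are turned into nodes, edges and ranks, so the following is only a sketch of one possible mapping, written without the `Dot` type to stay self-contained. The assumptions are read off the generated output: a node is named `node` followed by its id, the parent id is `id >> 1`, even ids hang on the "yes" edge, and the rank index is the level minus one. The `Item` alias and the `to_dot` function are hypothetical, and the inner `subgraph` wrapper is left out.

```rust
/// One traversal item: (id, level, label, is_leaf). A simplified stand-in
/// for the `TreeItem` yielded by the pre-order iterator.
type Item<'a> = (usize, usize, &'a str, bool);

/// Renders traversal items as dot source, one rank per tree level.
/// Levels are assumed to start at 1, as in the iterator.
fn to_dot(items: &[Item<'_>]) -> String {
    let mut nodes = Vec::new();
    let mut edges = Vec::new();
    let mut ranks: Vec<Vec<String>> = Vec::new();
    for &(id, level, label, leaf) in items {
        let name = format!("node{id}");
        // leaves get the rounded green style seen in the generated output
        let style = if leaf { ", style=\"rounded,filled\", fillcolor=\"green\"" } else { "" };
        nodes.push(format!("{name} [label=\"{label}\", shape=\"box\"{style}];"));
        if id > 1 {
            // the parent id drops the last bit; even ids are assumed to be the "yes" branch
            let answer = if id % 2 == 0 { "yes" } else { "no" };
            edges.push(format!("node{} -> {name} [label=\"{answer}\"]", id >> 1));
        }
        // make room for this level, then record the node in its rank
        if ranks.len() < level {
            ranks.resize(level, Vec::new());
        }
        ranks[level - 1].push(name);
    }
    let mut lines = vec!["digraph {".to_string(), "rankdir = BT;".to_string()];
    lines.extend(nodes);
    lines.extend(edges);
    lines.extend(ranks.iter().map(|r| format!("{{rank = same; {};}}", r.join("; "))));
    lines.push("}".to_string());
    lines.join("\n")
}
```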

We are now finally able to generate the graph we saw at the beginning of the article

digraph {
rankdir = BT;
subgraph{
node1 [label="petal_length > 2.45e0", shape="box"];
node3 [label="Setosa 1", shape="box", style="rounded,filled", fillcolor="green"];
node2 [label="petal_width > 1.75e0", shape="box"];
node5 [label="petal_length > 4.95e0", shape="box"];
node11 [label="petal_width > 1.65e0", shape="box"];
node23 [label="Versicolor 1", shape="box", style="rounded,filled", fillcolor="green"];
node22 [label="Virginica 1", shape="box", style="rounded,filled", fillcolor="green"];
node10 [label="petal_width > 1.55e0", shape="box"];
node21 [label="Virginica 1", shape="box", style="rounded,filled", fillcolor="green"];
node20 [label="sepal_length > 6.95e0", shape="box"];
node41 [label="Versicolor 1", shape="box", style="rounded,filled", fillcolor="green"];
node40 [label="Virginica 1", shape="box", style="rounded,filled", fillcolor="green"];
node4 [label="petal_length > 4.85e0", shape="box"];
node9 [label="sepal_length > 5.95e0", shape="box"];
node19 [label="Versicolor 1", shape="box", style="rounded,filled", fillcolor="green"];
node18 [label="Virginica 1", shape="box", style="rounded,filled", fillcolor="green"];
node8 [label="Virginica 1", shape="box", style="rounded,filled", fillcolor="green"];
node1 -> node3 [label="no"]
node1 -> node2 [label="yes"]
node2 -> node5 [label="no"]
node5 -> node11 [label="no"]
node11 -> node23 [label="no"]
node11 -> node22 [label="yes"]
node5 -> node10 [label="yes"]
node10 -> node21 [label="no"]
node10 -> node20 [label="yes"]
node20 -> node41 [label="no"]
node20 -> node40 [label="yes"]
node2 -> node4 [label="yes"]
node4 -> node9 [label="no"]
node9 -> node19 [label="no"]
node9 -> node18 [label="yes"]
node4 -> node8 [label="yes"]
{rank = same; node1;}
{rank = same; node3; node2;}
{rank = same; node5; node4;}
{rank = same; node11; node10; node9; node8;}
{rank = same; node23; node22; node21; node20; node19; node18;}
{rank = same; node41; node40;}
}
}

post017_tree_result.png

Some effects of tree creation options

Here I show the effect of a couple of options

  • not reusing features (quite uncommon for this kind of tree)
  • limiting tree depth

Not reusing features

This is typically done with categorical features. A popular approach is to transform these features using one-hot encoding, so that each column contains only a boolean value. In this case it does not make sense to look for more than one split value.
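
As a reminder of what one-hot encoding produces, here is a minimal sketch in plain Rust; the `one_hot` function is a hypothetical helper, not part of the code of this post. Each resulting column holds only two values, so it admits a single meaningful split, which is why reusing it further down the same branch brings nothing.

```rust
use std::collections::BTreeSet;

/// One-hot encodes a categorical column: one boolean column per distinct
/// category, returned in alphabetical order of the category names.
fn one_hot<'a>(column: &[&'a str]) -> Vec<(&'a str, Vec<bool>)> {
    // a sorted set gives the distinct categories in a stable order
    let categories: BTreeSet<&str> = column.iter().copied().collect();
    categories
        .into_iter()
        .map(|category| {
            // true where the row belongs to this category
            let flags: Vec<bool> = column.iter().map(|value| *value == category).collect();
            (category, flags)
        })
        .collect()
}
```

For the column red, green, red the sketch returns a "green" column with values false, true, false and a "red" column with values true, false, true.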

post017_tree_no_reuse.png

Limiting tree depth

In this case we allow each branch to keep splitting up to a fixed depth limit.

More sophisticated approaches are possible, such as pruning each branch dynamically.

post017_tree_limited.png

Conclusions?

We still need to evaluate the performance of our tree on our data and compare it with other implementations.

To complete this step we need an effective way to

  • create predictions from unknown data
  • cross-validate the hyperparameter tuning

But these are subjects for another post.

marco.p.v.vezzoli

Self taught assembler programming at 11 on my C64 (1983). Never stopped since then -- always looking up for curious things in the software development, data science and AI. Linux and FOSS user since 1994. MSc in physics in 1996. Working in large semiconductor companies since 1997 (STM, Micron) developing analytics and full stack web infrastructures, microservices, ML solutions