Growing a (Decision) Tree
In this post I will complete the creation of the decision tree by building it iteratively with the Gini metric; I will then show a pre-order depth-first traversal, which will be used to create tree charts in the graphviz dot language.
The code for this post is available here.
In the last post, Sowing a (Decision) Tree, we saw how to evaluate the Gini impurity index across all features.
Basic data structure
Here I will describe the content of each node in the decision tree.
Each tree node contains the classification of the current dataset according to the majority class; in case of an even count between two or more classes the result is arbitrary.
Along with the predicted class a confidence is calculated, which is the frequency of that class in the current dataset.
Optionally a split may be added, which contains the chosen feature and a floating point number representing the threshold separating the left sub-branch from the right; here we can also save the metric used to evaluate the split.
#[derive(Debug)]
pub struct Rule {
    dimension: String,
    cutoff: f64,
    metric: f64,
}

#[derive(Debug)]
pub struct Decision {
    rule: Option<Rule>,
    confidence: f64,
    prediction: String,
}
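For instance, with values like those in the chart generated at the end of this post, a pure leaf and the root split node could look as follows (the numbers are illustrative):

// a hypothetical pure leaf: no rule, only one class left
let leaf = Decision {
    rule: None,
    prediction: "Setosa".into(),
    confidence: 1.0,
};
// a hypothetical internal node splitting on petal_length at 2.45
let root = Decision {
    rule: Some(Rule {
        dimension: "petal_length".into(),
        cutoff: 2.45,
        metric: 0.33, // illustrative Gini impurity value
    }),
    prediction: "Setosa".into(),
    confidence: 0.33,
};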
Predicting class at a given node
Here I will describe how each tree node predicts a class from the training dataset.
In the training stage, at each node we have an associated subset of the original training set.
The predicted category will be the most populated one; we can associate a confidence with this prediction by estimating the probability of finding a sample of the predicted class.
pub fn predict_majority_dataframe<'a>(
    data: &'a DataFrame,
    target: &str,
) -> PolarsResult<Decision> {
    // extract the categorical target column
    let labels = data.column(target)?.categorical()?;
    let total = labels.len() as f64;
    // count all categories and sort them
    let result_count = labels.value_counts()?;
    // get the most frequent category
    let result_cat = result_count.head(Some(1));
    // transform the series into a categorical vector
    let actual_cat = result_cat
        .column(target)?
        .categorical()?;
    // collect all categories as strings
    let string_cat: Vec<String> = actual_cat
        .iter_str()
        .flatten()
        .map(|name| (*name).into())
        .collect();
    let probability: Vec<f64> = result_cat
        .column("counts")?
        .u32()?
        .iter()
        .flatten()
        .map(|c| (c as f64) / total)
        .collect();
    // return the most common category as a string
    Ok(Decision {
        rule: None,
        prediction: string_cat.get(0).unwrap().to_string(),
        confidence: probability.get(0).unwrap().to_owned(),
    })
}
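Assuming, as in the previous posts, an iris dataframe whose "species" column has already been cast to a categorical type (both names are illustrative), the usage looks like this:

// `iris` is a DataFrame with a categorical "species" column
let decision = predict_majority_dataframe(&iris, "species")?;
println!(
    "predicted {} with confidence {}",
    decision.prediction, decision.confidence
);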
Evaluating the best split point
The best feature and its split point are found by choosing the pair that optimizes the metric:
- for each feature
- the metric is evaluated at all split points
- all results are joined in a single dataframe
- the best pair of feature and split point is selected according to the metric
pub fn evaluate_best_split<'a>(
    data: &DataFrame,
    features: &HashSet<&str>,
    target: &str,
) -> PolarsResult<Rule> {
    // iteratively evaluate the metric on all features
    let metrics: PolarsResult<Vec<LazyFrame>> = features
        .iter()
        .map(|feature| {
            Ok(evaluate_metric(&data, feature, target)?
                .lazy()
                .with_column(feature.lit().alias("feature")))
        })
        .collect();
    // join all results in a single dataframe
    let concat_rules = UnionArgs {
        parallel: true,
        rechunk: true,
        to_supertypes: true,
    };
    let concat_metrics: DataFrame = concat(metrics?, concat_rules)?.collect()?;
    // search for the best split
    let expr: Expr = col("metrics").lt_eq(col("metrics").min());
    let best_split: DataFrame = concat_metrics
        .clone()
        .lazy()
        .filter(expr)
        .select([col("feature"), col("split"), col("metrics")])
        .collect()?;
    let chosen_features: Vec<String> = best_split
        .column("feature")?
        .str()?
        .iter()
        .flatten()
        .map(|name| <&str as Into<String>>::into(name))
        .collect();
    let chosen_split_point: f64 = best_split.column("split")?.f64()?.get(0).unwrap();
    let split_metric: f64 = best_split.column("metrics")?.f64()?.get(0).unwrap();
    Ok(Rule {
        dimension: chosen_features.get(0).unwrap().to_string(),
        cutoff: chosen_split_point,
        metric: split_metric,
    })
}
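A minimal usage sketch, with the feature and target names of the iris dataset used throughout this series:

use std::collections::HashSet;

let features: HashSet<&str> = HashSet::from([
    "sepal_length", "sepal_width", "petal_length", "petal_width",
]);
// Rule derives Debug, so the chosen split can be printed directly
let rule = evaluate_best_split(&iris, &features, "species")?;
println!("{:?}", rule);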
Stopping rules
We are going to create the tree using a greedy algorithm, i.e. growing one node at a time, recursively; while this does not guarantee the best possible result, it makes the problem tractable.
How long should this recursion go on?
I’d like to implement three basic stopping rules:
- the current node contains one class only
- the current level is equal to the maximum depth provided by the user
- the current node contains fewer elements than the minimum decided by the user
It is reasonable to split multiple times along the same axis for continuous features, but I’d like to see the effect of dropping a feature once it has been used, so I will leave this as a build option.
// each condition must hold to keep splitting;
// when one fails, the corresponding stop rule fires
if (!current_features.is_empty()) &&    // features are exhausted
    (confidence < 1.0) &&               // all elements belong to one category
    (data.shape().0 > self.min_size) && // size is below the minimum threshold
    (level <= self.max_level) {         // maximum depth is reached
Builder pattern in Rust
Rust does not have optional parameters with default values; to emulate this functionality the “builder” pattern is used.
This pattern consists of the following:
- create a default constructor for your structure which requires only the mandatory inputs
- add a method for each optional field which receives the structure itself (so it takes ownership) and returns it mutated
- this allows chaining calls, and makes sure that no other part of the code can access the structure while we are setting it up
In our case we want to store all the relevant tree creation options. The following are mandatory:
- the names of the features
- the name of the target column
The following are optional:
- the maximum depth of the tree (we may set the default to 3)
- if we want to reuse all features after each split (usually true)
- the minimum size of a dataframe: only larger dataframes will be split
#[derive(Debug)]
pub struct DTreeBuilder<'a> {
    max_level: usize,
    min_size: usize,
    features: HashSet<&'a str>,
    target: &'a str,
    reuse_features: bool,
}

// uses a struct to define tree constraints
impl<'a> DTreeBuilder<'a> {
    pub fn new(features: HashSet<&'a str>, target: &'a str) -> DTreeBuilder<'a> {
        DTreeBuilder {
            max_level: 3,
            min_size: 1,
            features,
            target,
            reuse_features: true,
        }
    }
    pub fn set_max_level(mut self, max_level: usize) -> DTreeBuilder<'a> {
        self.max_level = max_level;
        self
    }
    pub fn set_min_size(mut self, min_size: usize) -> DTreeBuilder<'a> {
        self.min_size = min_size;
        self
    }
    pub fn set_reuse_features(mut self, reuse_features: bool) -> DTreeBuilder<'a> {
        self.reuse_features = reuse_features;
        self
    }
}
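Since every setter consumes the builder and returns it, the whole configuration becomes a single chain of calls; a small sketch, with the features and target used above:

let builder = DTreeBuilder::new(features, "species")
    .set_max_level(5)
    .set_min_size(5)
    .set_reuse_features(false);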
Iterative node building
There is a public access point which receives only the original training dataset:
impl<'a> DTreeBuilder<'a> {
    // ...
    pub fn build(
        &self,
        data: &DataFrame,
    ) -> PolarsResult<btree::Tree<Decision>> {
        // when features must not be reused, track the remaining ones
        let current_features = if !self.reuse_features {
            Some(self.features.clone())
        } else {
            None
        };
        // debug print of the builder configuration
        println!("{1:->0$}{2:?}{1:-<0$}", 20, "\n", self);
        let root = self.build_node(data, 1, &current_features)?;
        Ok(btree::Tree::from_node(root))
    }
    // ...
}
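Putting builder and access point together, training a whole tree looks like this (dataset and column names as in the sketches above):

let tree: btree::Tree<Decision> = DTreeBuilder::new(features, "species")
    .set_max_level(4)
    .build(&iris)?;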
Until a stopping condition is met, at each node all features are evaluated to find the most effective split according to our current metric (the Gini impurity index); the dataset is then divided at the split point and the procedure recurses on the two sub-branches:
impl<'a> DTreeBuilder<'a> {
    // ...
    fn build_node(
        &self,
        data: &DataFrame,
        level: usize,                     // tree depth
        features: &Option<HashSet<&str>>, // optionally used to remove features
    ) -> PolarsResult<btree::Node<Decision>> {
        let prediction = predict_majority_dataframe(data, self.target)?;
        let confidence = prediction.confidence;
        let mut node = btree::Node::new(prediction);
        let current_features = features.clone().unwrap_or(self.features.clone());
        // each condition must hold to keep splitting;
        // when one fails, the corresponding stop rule fires
        if (!current_features.is_empty()) &&    // features are exhausted
            (confidence < 1.0) &&               // all elements belong to one category
            (data.shape().0 > self.min_size) && // size is below the minimum threshold
            (level <= self.max_level) {         // maximum depth is reached
            let rule = evaluate_best_split(data, &current_features, self.target)?;
            let higher: DataFrame = data
                .clone()
                .lazy()
                .filter(col(&rule.dimension).gt(rule.cutoff))
                .collect()?;
            let lower: DataFrame = data
                .clone()
                .lazy()
                .filter(col(&rule.dimension).lt_eq(rule.cutoff))
                .collect()?;
            // remove the used feature only if requested by the user
            let next_features = match features {
                None => None,
                Some(feats) => {
                    let mut reduced_features = feats.clone();
                    reduced_features.remove(rule.dimension.as_str());
                    Some(reduced_features)
                }
            };
            node.value.rule = Some(rule);
            // creates leaves
            node.left = self
                .build_node(&higher, level + 1, &next_features)?
                .into();
            node.right = self
                .build_node(&lower, level + 1, &next_features)?
                .into();
        }
        Ok(node)
    }
    // ...
}
Dumping the tree
Pre-order depth first traversal
In a previous post I showed how to create a depth-first traversal algorithm.
To be more specific, it was a post-order traversal: you can find more details about the kinds of traversal algorithms on this Wikipedia page.
To draw our tree we now need a pre-order traversal iterator: Caleb Sander Mateos suggested to me in a comment how to use a stack to implement this kind of traversal; my code follows.
I added some more useful information to the iterator result:
- the current node depth
- its number according to the binary position described here: the root is 1 and the children of node n are 2n and 2n + 1
- a boolean describing whether the current node is a leaf
pub struct PreOrderTraversalIter<'a, T> {
    stack: Vec<TreeStackItem<'a, T>>,
}

pub struct TreeItem<'a, T> {
    pub id: usize,
    pub level: usize,
    pub value: &'a T,
    pub leaf: bool,
}

struct TreeStackItem<'a, T> {
    id: usize,
    level: usize,
    node: &'a Node<T>,
}

impl<'a, T> PreOrderTraversalIter<'a, T> {
    fn new(tree: &'a Tree<T>) -> PreOrderTraversalIter<'a, T> {
        match tree.root {
            None => PreOrderTraversalIter { stack: Vec::new() },
            Some(ref node) => PreOrderTraversalIter {
                stack: vec![TreeStackItem {
                    id: 1,
                    level: 1,
                    node: &node,
                }],
            },
        }
    }
}

impl<'a, T> Iterator for PreOrderTraversalIter<'a, T> {
    type Item = TreeItem<'a, T>;
    fn next(&mut self) -> Option<Self::Item> {
        if let Some(item) = self.stack.pop() {
            let mut leaf: bool = true;
            if let Some(ref left) = item.node.left {
                self.stack.push(TreeStackItem {
                    id: item.id << 1,
                    level: item.level + 1,
                    node: &left,
                });
                leaf = false;
            }
            if let Some(ref right) = item.node.right {
                self.stack.push(TreeStackItem {
                    id: (item.id << 1) + 1,
                    level: item.level + 1,
                    node: &right,
                });
                leaf = false;
            }
            Some(TreeItem {
                id: item.id,
                level: item.level,
                value: &item.node.value,
                leaf,
            })
        } else {
            None
        }
    }
}
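A small usage sketch (assuming the iterator is made reachable, e.g. through a public constructor or a method on Tree wrapping it):

// walk the decision tree in pre-order, printing one line per node
for item in PreOrderTraversalIter::new(&tree) {
    println!(
        "node{} at level {} (leaf: {}): {:?}",
        item.id, item.level, item.leaf, item.value
    );
}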
Creating a Dot DSL reification
This is an example of the chart of a sorting tree.
I chose graphviz to automatically generate a chart of my tree; I used a subset of its graph language, dot.
In these cases the best way for me to create a language generator is to choose which parts of its grammar to transform into data objects; I chose:
- nodes
- edges
- ranks to put nodes at the same level in the same row
A rank is actually a list of node names, i.e. strings, thus a vector of strings should be enough; but we need a specialized text representation, so I used a wrapper type:
struct DotNode {
    name: String,
    label: String,
    shape: String,
    style: Option<String>,
    fillcolor: Option<String>,
}

struct DotEdge {
    first: String,
    second: String,
    label: String,
}

// wrapper type
#[derive(Default)]
struct DotRank(Vec<String>);

impl DotRank {
    fn new() -> DotRank {
        DotRank(Vec::new())
    }
}
For each one I created its text representation following the dot grammar:
impl Display for DotNode {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let style: String = match self.style {
            None => "".into(),
            Some(ref kind) => format!(", style=\"{}\"", kind),
        };
        let fillcolor: String = match self.fillcolor {
            None => "".into(),
            Some(ref kind) => format!(", fillcolor=\"{}\"", kind),
        };
        write!(
            f,
            "{} [label=\"{}\", shape=\"{}\"{}{}];",
            self.name, self.label, self.shape, style, fillcolor
        )
    }
}

impl Display for DotEdge {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{} -> {} [label=\"{}\"]", self.first, self.second, self.label)
    }
}

impl Display for DotRank {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{{rank = same; {};}}", self.0.join("; "))
    }
}
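As a quick check (e.g. in a unit test inside the same module, since these structs are private), a node built with the values of a leaf from the chart below renders as expected:

let node = DotNode {
    name: "node3".into(),
    label: "Setosa 1".into(),
    shape: "box".into(),
    style: Some("rounded,filled".into()),
    fillcolor: Some("green".into()),
};
// prints: node3 [label="Setosa 1", shape="box", style="rounded,filled", fillcolor="green"];
println!("{}", node);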
Finally I created a full object which contains all of these elements:
pub struct Dot {
    nodes: Vec<DotNode>,
    edges: Vec<DotEdge>,
    ranks: Vec<DotRank>,
}
To simplify the building I created a method to add each kind of element:
impl Dot {
    pub fn new() -> Self {
        Dot {
            nodes: Vec::new(),
            edges: Vec::new(),
            ranks: Vec::new(),
        }
    }
    pub fn add_node(
        &mut self,
        name: String,
        label: String,
        shape: String,
        style: Option<String>,
        fillcolor: Option<String>,
    ) {
        let node = DotNode { name, label, shape, style, fillcolor };
        self.nodes.push(node);
    }
    pub fn add_edge(&mut self, first: String, second: String, label: String) {
        let node = DotEdge { first, second, label };
        self.edges.push(node);
    }
    pub fn append_rank(&mut self, index: usize, node: String) {
        // ensure space
        while self.ranks.len() <= index {
            self.ranks.push(DotRank::new())
        }
        // update the rank at index adding the node
        let mut bin = take(&mut self.ranks[index]);
        bin.0.push(node);
        let _ = replace(&mut self.ranks[index], bin);
    }
}
Finally, here is its transformation into a string:
impl Display for Dot {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut graph: Vec<String> = vec![
            "digraph {".into(),
            "rankdir = BT;".into(),
            "subgraph{".into(),
        ];
        for node in &self.nodes {
            graph.push(node.to_string());
        }
        for edge in &self.edges {
            graph.push(edge.to_string());
        }
        for rank in &self.ranks {
            graph.push(rank.to_string());
        }
        graph.push("}".into());
        graph.push("}".into());
        write!(f, "{}", graph.join("\n"))
    }
}
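The actual glue between the decision tree and the Dot object lives in the repository; what follows is only my sketch of how the pre-order iterator can feed it (the pre_order accessor and the label formats are assumptions inferred from the output below):

// build a Dot object out of a trained tree
fn tree_to_dot(tree: &btree::Tree<Decision>) -> Dot {
    let mut dot = Dot::new();
    for item in tree.pre_order() { // hypothetical accessor returning PreOrderTraversalIter
        let name = format!("node{}", item.id);
        // internal nodes show their rule, leaves their prediction and confidence
        let label = match item.value.rule {
            Some(ref rule) => format!("{} > {:e}", rule.dimension, rule.cutoff),
            None => format!("{} {}", item.value.prediction, item.value.confidence),
        };
        // leaves are drawn rounded and filled in green, as in the chart below
        let (style, fillcolor) = if item.leaf {
            (Some("rounded,filled".to_string()), Some("green".to_string()))
        } else {
            (None, None)
        };
        dot.add_node(name.clone(), label, "box".into(), style, fillcolor);
        // every node but the root is linked to its parent (id / 2);
        // odd ids are right children, which hold the "no" branch
        if item.id > 1 {
            let parent = format!("node{}", item.id >> 1);
            let answer = if item.id % 2 == 1 { "no" } else { "yes" };
            dot.add_edge(parent, name.clone(), answer.into());
        }
        // nodes at the same depth share a rank, so they end up on the same row
        dot.append_rank(item.level - 1, name);
    }
    dot
}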
We are now finally able to generate the graph we saw at the beginning of the article:
digraph {
rankdir = BT;
subgraph{
node1 [label="petal_length > 2.45e0", shape="box"];
node3 [label="Setosa 1", shape="box", style="rounded,filled", fillcolor="green"];
node2 [label="petal_width > 1.75e0", shape="box"];
node5 [label="petal_length > 4.95e0", shape="box"];
node11 [label="petal_width > 1.65e0", shape="box"];
node23 [label="Versicolor 1", shape="box", style="rounded,filled", fillcolor="green"];
node22 [label="Virginica 1", shape="box", style="rounded,filled", fillcolor="green"];
node10 [label="petal_width > 1.55e0", shape="box"];
node21 [label="Virginica 1", shape="box", style="rounded,filled", fillcolor="green"];
node20 [label="sepal_length > 6.95e0", shape="box"];
node41 [label="Versicolor 1", shape="box", style="rounded,filled", fillcolor="green"];
node40 [label="Virginica 1", shape="box", style="rounded,filled", fillcolor="green"];
node4 [label="petal_length > 4.85e0", shape="box"];
node9 [label="sepal_length > 5.95e0", shape="box"];
node19 [label="Versicolor 1", shape="box", style="rounded,filled", fillcolor="green"];
node18 [label="Virginica 1", shape="box", style="rounded,filled", fillcolor="green"];
node8 [label="Virginica 1", shape="box", style="rounded,filled", fillcolor="green"];
node1 -> node3 [label="no"]
node1 -> node2 [label="yes"]
node2 -> node5 [label="no"]
node5 -> node11 [label="no"]
node11 -> node23 [label="no"]
node11 -> node22 [label="yes"]
node5 -> node10 [label="yes"]
node10 -> node21 [label="no"]
node10 -> node20 [label="yes"]
node20 -> node41 [label="no"]
node20 -> node40 [label="yes"]
node2 -> node4 [label="yes"]
node4 -> node9 [label="no"]
node9 -> node19 [label="no"]
node9 -> node18 [label="yes"]
node4 -> node8 [label="yes"]
{rank = same; node1;}
{rank = same; node3; node2;}
{rank = same; node5; node4;}
{rank = same; node11; node10; node9; node8;}
{rank = same; node23; node22; node21; node20; node19; node18;}
{rank = same; node41; node40;}
}
}
Some effects of tree creation options
Here I show the effect of a couple of options:
- not reusing features (quite uncommon for this kind of tree)
- limiting tree depth
Not reusing features
This is typically done when using categorical features. A popular approach is to transform these features using one-hot encoding, so that each column contains only a boolean value; in this case it does not make sense to find more than one split value per column.
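A minimal sketch of this encoding with polars expressions (the column and category names are purely illustrative; recent polars versions also offer a to_dummies helper on DataFrame):

// turn one categorical column into one boolean column per category,
// so each new column has a single meaningful split point
let encoded = data
    .clone()
    .lazy()
    .with_columns([
        col("color").eq(lit("red")).alias("color_red"),
        col("color").eq(lit("blue")).alias("color_blue"),
    ])
    .collect()?;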
Limiting tree depth
In this case we allow each branch to keep splitting until a fixed depth limit is reached.
More sophisticated approaches are possible, like pruning each branch dynamically.
Conclusions?
We still need to evaluate the performance of our tree on our data and compare it with other implementations.
In order to complete this step we need an effective way to:
- create predictions from some unknown data
- cross-validate hyperparameter tuning
But these are subjects for another post.