Sowing a (Decision) Tree
Photo by Ramona Edwards on Unsplash
In this post I start building a decision tree in Rust. The complete description will span several posts.
Decision trees are used for classification or regression and may accept categorical or continuous features. In this example I will start with a classification tree that accepts continuous features.
The algorithm will be greedy, i.e. I will build one level at a time by choosing the most effective split across all features.
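The greedy step itself can be sketched in plain Rust. This is only an illustration and not the code of this post: the `Candidate` struct and the `best_candidate` function are hypothetical names, and each candidate is assumed to already carry an impurity score where lower is better.

```rust
/// A candidate split: which feature, at which threshold, and how impure
/// the resulting partition is (lower is better). Hypothetical type,
/// introduced only for this illustration.
#[derive(Debug, Clone, PartialEq)]
struct Candidate {
    feature: String,
    threshold: f64,
    impurity: f64,
}

/// Greedy step: pick the candidate with the lowest impurity.
/// Returns None when there are no candidates; on ties the first
/// candidate in the slice wins.
fn best_candidate(candidates: &[Candidate]) -> Option<&Candidate> {
    candidates
        .iter()
        .min_by(|a, b| a.impurity.total_cmp(&b.impurity))
}
```

Being greedy means that only the current node is considered: the split that looks best now is kept, without checking whether a different choice would lead to a better tree further down.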
To simplify the comparison of this code with existing implementations (e.g. scikit-learn) I will use a well-known dataset: the Iris dataset.
The code for this post is available here.
Loading data and choosing the environment
For this experiment I chose the Polars crate to manage data loading and manipulation. While there are different ways to read data from a file, the main reasons that led me to this decision are the following:
- In this algorithm I need to access the dataset features in a simple way, choosing from a list of column names;
- I need to filter the dataset iteratively; moreover, I’d like to avoid duplicating data within this process if possible. Polars provides a nice way to share dataframes, which can be filtered using reified filters in the Lazy API.
There are however some cons in choosing this excellent crate:
- A Series hides the type of the data it holds, so in several places I have to handle possible errors even though I know the data type in advance: I decided to bubble all errors up to the main function;
- it is a very big crate compared to the small example I’m trying to build: in this case I use a small fraction of its functionality;
- Polars is designed for exceptional performance on large datasets, which is more than this small example needs.
For a quick experiment the pros outweigh the cons, but I may consider a smaller solution in specific future projects.
Evaluation of the most effective split
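The literature suggests two possible metrics to evaluate the best split: Gini’s impurity index or Shannon’s information entropy gain.
Let’s start with the Gini impurity index: it is the probability of misclassifying a sample, i.e. the probability that a sample drawn at random from the node, and labelled at random according to the class frequencies in that node, receives a category other than its own. Writing $p_i$ for the fraction of samples that belong to category $i$:
$$G = \sum_i p_i \sum_{k \neq i} p_k = \sum_i p_i (1 - p_i)$$
as
$$\sum_i p_i = 1$$
we have
$$G = 1 - \sum_i p_i^2$$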
use polars::prelude::*;

// Gini impurity metric
fn estimate_gini(data: &DataFrame, target: &str) -> PolarsResult<f64> {
    // count the occurrences of each label
    let label_count: DataFrame = data
        .column(target)?
        .categorical()?
        .value_counts()?;

    // squared relative frequency of each label
    let expr: Expr = (col("counts").cast(DataType::Float64) / col("counts").sum())
        .pow(2)
        .alias("squares");

    let squared: DataFrame = label_count
        .lazy()
        .select([expr])
        .collect()?;

    let square_sum: f64 = squared
        .column("squares")?
        .sum()?;

    // impurity is one minus the sum of the squared frequencies
    Ok(1.0 - square_sum)
}
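To check the arithmetic independently of Polars, here is a minimal sketch of the same metric that uses only the standard library. The function name `gini_impurity` is hypothetical, labels are assumed to be plain string slices, and an empty node is treated as pure by convention.

```rust
use std::collections::HashMap;

/// Gini impurity of a set of class labels: 1 - sum_i p_i^2,
/// where p_i is the fraction of samples carrying label i.
/// Assumption: an empty slice is treated as pure (impurity 0.0).
fn gini_impurity(labels: &[&str]) -> f64 {
    if labels.is_empty() {
        return 0.0;
    }
    // count how many times each label occurs
    let mut counts: HashMap<&str, usize> = HashMap::new();
    for &label in labels {
        *counts.entry(label).or_insert(0) += 1;
    }
    // sum the squared relative frequencies
    let n = labels.len() as f64;
    let square_sum: f64 = counts
        .values()
        .map(|&c| {
            let p = c as f64 / n;
            p * p
        })
        .sum();
    1.0 - square_sum
}
```

With three equally populated classes, as in the full Iris dataset, this gives 1 - 3·(1/3)² = 2/3, while a node that holds a single class gives 0.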
As a first implementation I will calculate this metric splitting the dataset in all possible ways along a given feature. I expect to optimize this step in the future. Moreover in this post I assume the feature has no missing values: I will address this in future posts as well.
fn evaluate_metric(data: &DataFrame, feature: &str, target: &str) -> PolarsResult<DataFrame> {
    // grabs the unique values
    let values = data.column(feature)?;
    let unique = values.unique()?;

    // create a lagged column to identify split points
    let split = df!(feature => unique)?
        .lazy()
        .with_columns([
            ((col(feature) + col(feature).shift(lit(-1))) / lit(2.0)).alias("split")
        ])
        .collect()?;

    let split_values: Vec<f64> = split
        .column("split")?
        .f64()?
        .iter()
        .flatten() // drop missing values created by lag
        .collect();

    // iterate over split points
    let metrics: PolarsResult<Series> = split_values
        .iter()
        .map(|sp| {
            // split dataframe
            let higher = data.clone().filter(&values.gt_eq(*sp)?)?;
            let lower = data.clone().filter(&values.lt(*sp)?)?;

            // calculate metrics
            let higher_metric = estimate_gini(&higher, target)?;
            let lower_metric = estimate_gini(&lower, target)?;

            Ok(
                ((higher.shape().0 as f64) * higher_metric
                    + (lower.shape().0 as f64) * lower_metric)
                    / (values.len() as f64),
            )
        })
        .collect();

    // return a dataframe with a metric evaluation
    // for each split point
    return Ok(df!(
        "split" => Series::new("split", split_values),
        "metrics" => metrics?,
    )?);
}
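The same evaluation can be sketched without Polars, which may help when checking results by hand. This is an illustration under stated assumptions rather than the implementation used in this post: the names `gini`, `split_points` and `evaluate_splits` are hypothetical, the feature is a slice of `f64` with no missing or NaN values, and the unique values are sorted explicitly before the midpoints are taken.

```rust
use std::collections::HashMap;

/// Gini impurity of a set of labels (0.0 for an empty set by convention).
fn gini(labels: &[&str]) -> f64 {
    if labels.is_empty() {
        return 0.0;
    }
    let mut counts: HashMap<&str, usize> = HashMap::new();
    for &label in labels {
        *counts.entry(label).or_insert(0) += 1;
    }
    let n = labels.len() as f64;
    1.0 - counts.values().map(|&c| (c as f64 / n).powi(2)).sum::<f64>()
}

/// Candidate thresholds: midpoints between consecutive distinct values.
fn split_points(values: &[f64]) -> Vec<f64> {
    let mut sorted = values.to_vec();
    sorted.sort_by(|a, b| a.total_cmp(b));
    sorted.dedup();
    sorted.windows(2).map(|w| (w[0] + w[1]) / 2.0).collect()
}

/// For every candidate threshold, the Gini impurity of the two halves
/// weighted by their sizes. Returns (threshold, weighted impurity) pairs.
fn evaluate_splits(values: &[f64], labels: &[&str]) -> Vec<(f64, f64)> {
    assert_eq!(values.len(), labels.len());
    split_points(values)
        .into_iter()
        .map(|sp| {
            // rows below the threshold go left, the others go right
            let mut lower = Vec::new();
            let mut higher = Vec::new();
            for (&v, &l) in values.iter().zip(labels) {
                if v < sp {
                    lower.push(l);
                } else {
                    higher.push(l);
                }
            }
            let weighted = (lower.len() as f64 * gini(&lower)
                + higher.len() as f64 * gini(&higher))
                / values.len() as f64;
            (sp, weighted)
        })
        .collect()
}
```

For the values `[1.0, 2.0, 3.0, 4.0]` with labels `["a", "a", "b", "b"]` the candidate split points are 1.5, 2.5 and 3.5. The split at 2.5 separates the two classes perfectly, so its weighted impurity is 0, while the other two splits each score 1/3.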
Here are the plots of the metric at the root node: for some features the metric appears to have more than one local minimum.
Predicting a category
Given the dataset associated with a decision tree node, we need a way to return the predicted class: this can be done by choosing the most populated class.
If several classes are equally populated, the function just grabs the first one it finds. In this implementation I do not return the probability, but I will add this in the next posts.
use std::ops::Deref; // needed for the explicit .deref() call below

fn predict_majority_dataframe(data: &DataFrame, target: &str) -> PolarsResult<String> {
    // extract the categorical target column
    let labels = data
        .column(target)?
        .categorical()?;

    // count all categories and sort them
    let result_count = labels.value_counts()?;
    println!("{1:->0$}{2:?}{1:-<0$}", 20, "\n", result_count);

    // get the most frequent category
    let result_cat = result_count
        .column(target)?
        .head(Some(1));
    println!("{1:->0$}{2:?}{1:-<0$}", 20, "\n", result_cat);

    // transform the series into a categorical vector
    let actual_cat = result_cat
        .categorical()?;

    // collect all categories as strings
    let string_cat: Vec<String> = actual_cat
        .iter_str()
        .flatten()
        .map(|name| (*name).into())
        .collect();
    println!("{1:->0$}{2:?}{1:-<0$}", 20, "\n", string_cat);

    // return the most common category as a string
    return Ok(string_cat.get(0)
        .unwrap()
        .deref()
        .into());
}
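The majority rule can also be written against a plain slice of labels, as in the minimal sketch below. The function name `predict_majority` is hypothetical, and the tie-breaking rule is an assumption made for this illustration: when two classes are equally populated, the one that appears first in the data wins, which may differ from the order produced by `value_counts` in Polars.

```rust
/// Most frequent label in `labels`.
/// Assumption: ties go to the label that appears first in the slice.
/// Returns None when the slice is empty.
fn predict_majority(labels: &[&str]) -> Option<String> {
    // (label, count) pairs kept in first-seen order, so ties are deterministic
    let mut counts: Vec<(&str, usize)> = Vec::new();
    for &label in labels {
        if let Some(pos) = counts.iter().position(|&(name, _)| name == label) {
            counts[pos].1 += 1;
        } else {
            counts.push((label, 1));
        }
    }

    // scan the counts; the strict '>' keeps the earlier label on ties
    let mut best: Option<(&str, usize)> = None;
    for &(name, count) in &counts {
        if best.map_or(true, |(_, c)| count > c) {
            best = Some((name, count));
        }
    }
    best.map(|(name, _)| name.to_string())
}
```

Returning an `Option` instead of calling `unwrap` makes the empty-node case explicit, which leaves the caller free to decide how to handle a node without samples.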