Sowing a (Decision) Tree

Photo by Ramona Edwards on Unsplash

In this post I start to build a decision tree in Rust. The complete description will span several posts.

Decision trees are used for classification or regression and may accept categorical or continuous features: in this example I will start building a classification decision tree that accepts continuous features.

The algorithm will be greedy, i.e. I will build one level at a time by choosing the most effective split across all features.

In order to simplify the evaluation of this code against existing implementations (e.g. scikit-learn) I will use a well-known dataset: the Iris dataset.

The code for this post is available here.

Loading data and choosing the environment

For this experiment I chose the Polars crate to manage data loading and manipulation. While there are different ways to read data from a file, the main reasons that led me to this decision are the following:

  • In this algorithm I need to access the dataset features in a simple way, choosing from a list of column names;
  • I need to filter the dataset iteratively; moreover, I’d like to avoid duplicating data within this process if possible. Polars provides a nice way to share dataframes, and they can be filtered using reified filters in the Lazy API.

There are, however, some cons to choosing this excellent crate:

  • Series hide the type of the data inside, so there are multiple places where I have to handle possible errors even though I know the data type in advance: I decided to bubble up all errors to the main function;
  • it is a very big crate compared to the small example I’m trying to build: in this case I use only a small fraction of its functionality;
  • Polars is designed for exceptional performance on large datasets, which is far more than this small example needs.

For a quick experiment the pros outweigh the cons, but I may consider a smaller solution in specific future projects.
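
As a minimal sketch of the loading step (the file path iris.csv, the label column name variety, and the exact reader and cast APIs are assumptions that depend on the CSV file and on the Polars version in use), reading the dataset could look like this:

use polars::prelude::*;

// load the Iris CSV and cast the label column to a categorical dtype,
// so that .categorical() can be used on it later on
// (paths, column names and API details are assumptions, not the post's code)
fn load_iris() -> PolarsResult<DataFrame> {
    let df = CsvReader::from_path("iris.csv")?
        .has_header(true)
        .finish()?;

    df.lazy()
        .with_columns([col("variety").cast(DataType::Categorical(None))])
        .collect()
}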

Evaluation of the most effective split

The literature suggests two possible metrics to evaluate the best split: the Gini impurity index or the Shannon information entropy gain.

Let’s start with the Gini impurity index: this is equivalent to the probability of misclassifying a sample, i.e. the probability that, having extracted a sample which belongs to a given category, it is randomly assigned to any other available category.

G = \sum_{c \in C}P(x|c)\sum_{k \neq c}P(x|k)

as

1 - P(x|c) = \sum_{k \neq c}P(x|k)

we have

G = 1 - \sum_{c \in C}P(x|c)^2
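
For example, at the root node of the Iris dataset each of the three classes accounts for 50 of the 150 samples, so

G = 1 - 3 \left( \frac{50}{150} \right)^2 = 1 - \frac{1}{3} \approx 0.667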

// Gini impurity metric
fn estimate_gini(data: & DataFrame, target: & str) -> PolarsResult<f64> {
    // count the occurrences of each category in the target column
    let label_count: DataFrame = data
        .column(target)?
        .categorical()?
        .value_counts()?;

    // squared relative frequency of each category
    let expr: Expr = (col("counts")
        .cast(DataType::Float64)
        / col("counts").sum())
        .pow(2)
        .alias("squares");

    // evaluate the expression on the label counts
    let squared: DataFrame = label_count
        .lazy()
        .select([expr])
        .collect()?;

    // G = 1 - sum of the squared frequencies
    let square_sum: f64 = squared
        .column("squares")?
        .sum()?;

    Ok(1.0 - square_sum)
}
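
As a quick sanity check (reusing the dataframe df and the label column name variety assumed in the loading sketch above), calling this function on the whole dataset should return a value close to the 0.667 computed by hand:

let gini_root = estimate_gini(&df, "variety")?;
println!("Gini impurity at the root: {gini_root:.3}"); // expected ~0.667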

As a first implementation I will calculate this metric by splitting the dataset at every possible split point along a given feature; I expect to optimize this step in the future. Moreover, in this post I assume the feature has no missing values: I will address this in future posts as well.

fn evaluate_metric(data: & DataFrame, feature: & str, target: & str) -> PolarsResult<DataFrame> {
    // grabs the unique values
    let values = data.column(feature)?;
    let unique = values.unique()?;

    // create a lagged column to identify split points
    let split = df!(feature => unique)?
        .lazy()
        .with_columns([(
            (col(feature) + col(feature).shift(lit(-1))) /
                lit(2.0)).alias("split")
        ])
        .collect()?;
    let split_values : Vec<f64> = split
        .column("split")?
        .f64()?
        .iter()
        .flatten() // drop missing values created by lag
        .collect();

    // iterate over split points
    let metrics: PolarsResult<Series> = split_values
        .iter()
        .map(|sp| {
            // split dataframe
            let higher = data.clone().filter(& values.gt_eq(*sp)?)?;
            let lower = data.clone().filter(& values.lt(*sp)?)?;

            // calculate metrics
            let higher_metric = estimate_gini(& higher, target)?;
            let lower_metric = estimate_gini(& lower, target)?;

            Ok(
                ((higher.shape().0 as f64) * higher_metric
                 + (lower.shape().0 as f64) * lower_metric)
                    / (values.len() as f64),
            )
        })
        .collect();

    // return a dataframe with a metric evaluation
    // for each split point
    return Ok(df!(
        "split" => Series::new("split", split_values),
        "metrics" => metrics?,
    )?);
}
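
A minimal sketch of how this result could be used, again with the assumed dataframe and column names: pick the split point with the lowest weighted impurity for a given feature.

let metrics = evaluate_metric(&df, "petal_length", "variety")?;
let scores = metrics.column("metrics")?.f64()?;
if let Some(best) = scores.arg_min() {
    // split point with the lowest weighted Gini impurity for this feature
    println!(
        "best split at {:?} with impurity {:?}",
        metrics.column("split")?.f64()?.get(best),
        scores.get(best)
    );
}

Repeating this over all features and keeping the overall minimum gives the greedy split for the current node.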

Here are the plots of the metric at the root node: it appears that for some features the metric has more than one local minimum.

Plots of the weighted Gini metric versus split point for petal length, petal width, sepal length, and sepal width.

Predicting a category

Given the dataset associated with a decision tree node, we need a way to return the predicted class: this can be done by choosing the most populated class.

In case of equally populated classes the function just grabs the first one it finds. In this implementation I do not return the probability, but I will add this in the next posts.

fn predict_majority_dataframe(data: & DataFrame, target: & str) -> PolarsResult<String>{
    // extract the categorical target column
    let labels = data
        .column(target)?
        .categorical()?;

    // count all categories and sort them
    let result_count = labels.value_counts()?;
    println!("{1:->0$}{2:?}{1:-<0$}",20,"\n",result_count);

    // get the most frequent category
    let result_cat = result_count
        .column(target)?
        .head(Some(1));
    println!("{1:->0$}{2:?}{1:-<0$}",20,"\n",result_cat);

    // transform the series into a categorical vector
    let actual_cat = result_cat
        .categorical()?;

    // collect all categories as strings
    let string_cat: Vec<String> = actual_cat
        .iter_str()
        .flatten()
        .map(|name| (*name).into())
        .collect();
    println!("{1:->0$}{2:?}{1:-<0$}",20,"\n",string_cat);

    // return the most common category as a string
    return Ok(string_cat.get(0)
        .unwrap()
        .deref()
        .into());
}

