Quantile Regression Forests vs. Random Forests

This example compares the predictions generated by a quantile regression forest (QRF) and a standard random forest regressor (RF) on a synthetic right-skewed dataset. In a right-skewed distribution, the mean typically lies to the right of (that is, is greater than) the median. The example shows how the median (quantile = 0.5) predicted by the QRF can be a more reliable estimator than the mean predicted by the standard RF when the target distribution is skewed.

import altair as alt
import numpy as np
import pandas as pd
import scipy as sp
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.utils.validation import check_random_state

from quantile_forest import RandomForestQuantileRegressor

# Module-level settings: one seeded generator shared by the estimators and
# the train/test split, the number of samples to draw, and the grid of
# quantiles 0.00, 0.01, ..., 1.00 requested from the QRF.
random_state = np.random.RandomState(0)
n_samples = 5_000
quantiles = np.round(np.linspace(0.0, 1.0, 101), 2).tolist()

def make_skewed_dataset(n_samples, a=7, loc=-1, scale=1, random_state=None):
    """Generate a toy regression dataset with a skew-normal target.

    The target is sampled from ``scipy.stats.skewnorm(a, loc, scale)``. Each
    of the two feature columns is standard-normal noise multiplied
    element-wise by the target, so feature magnitude grows with the target.

    Parameters
    ----------
    n_samples : int
        Number of rows to generate.
    a : float, default=7
        Shape (skewness) parameter of the skew-normal distribution.
    loc : float, default=-1
        Location parameter of the skew-normal distribution.
    scale : float, default=1
        Scale parameter of the skew-normal distribution.
    random_state : None, int or RandomState, default=None
        Seed or generator, resolved with ``check_random_state``.

    Returns
    -------
    X : ndarray of shape (n_samples, 2)
        Target-scaled noise features.
    y : ndarray of shape (n_samples,)
        Skew-normal target values.
    """
    rng = check_random_state(random_state)

    # Attach the resolved generator to the frozen distribution so the target
    # draw and the feature noise come from the same random stream.
    target_dist = sp.stats.skewnorm(a, loc, scale)
    target_dist.random_state = rng
    target = target_dist.rvs(size=n_samples)

    # The noise is drawn after the target; keep this order to reproduce
    # the same dataset for a given seed.
    noise = rng.randn(n_samples, 2)
    features = noise * target.reshape(-1, 1)
    return features, target

# Create a right-skewed toy dataset.
# The dataset uses its own integer seed (0), independent of the shared
# ``random_state`` generator used below.
X, y = make_skewed_dataset(n_samples, a=7, loc=-1, scale=1, random_state=0)

# Both regressors use default hyperparameters and receive the same
# module-level ``random_state`` RandomState instance.
regr_rf = RandomForestRegressor(random_state=random_state)
regr_qrf = RandomForestQuantileRegressor(random_state=random_state)

# NOTE: the same RandomState instance is also passed to train_test_split, so
# the split and the two fits below all draw from one generator; reordering
# these calls changes the resulting split and fitted forests.
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=random_state)

regr_rf.fit(X_train, y_train)
regr_qrf.fit(X_train, y_train)

# The RF returns one (mean) prediction per test row; the QRF is asked for
# every value in ``quantiles`` (presumably returning an array of shape
# (len(X_test), len(quantiles)) -- confirm against quantile_forest docs).
y_pred_rf = regr_rf.predict(X_test)  # standard RF predictions (mean)
y_pred_qrf = regr_qrf.predict(X_test, quantiles=quantiles)  # QRF predictions (quantiles)

# Series label -> hex color; passed to the chart function, which uses the
# keys to order the color encoding (and presumably the colors themselves).
# NOTE(review): the chart folds the QRF series under the label
# "QRF (Quantile)", not "QRF (Median)" -- confirm these keys agree.
legend = {
    "Actual": "#c0c0c0",
    "RF (Mean)": "#f2a619",
    "QRF (Median)": "#006aff",
# NOTE(review): the closing brace of ``legend`` and the opening/closing of the
# DataFrame's dict argument appear to be missing from this copy.

# One row per test sample: the actual target, the RF mean prediction, and one
# column per requested quantile named "qrf_<q>" using ``.3g`` formatting
# (e.g. "qrf_0", "qrf_0.5", "qrf_1"); each column is taken from a column of
# ``y_pred_qrf`` (presumably one per quantile -- confirm).
df = pd.DataFrame(
        "actual": y_test,
        "rf": y_pred_rf,
        **{f"qrf_{q_i:.3g}": y_i.ravel() for q_i, y_i in zip(quantiles, y_pred_qrf.T)},

def plot_prediction_histograms(df, legend):
    """Plot histogram of predictions by model.

    Visible steps: a slider bound to a ``quantile`` parameter, a legend-click
    selection, and a chain of Vega-Lite transforms that rounds the actual,
    RF and QRF values to one decimal place and folds them into
    ``label``/``value`` pairs, whose counts are shown on the y axis.

    Parameters
    ----------
    df : pandas.DataFrame
        Presumably the frame with ``actual``, ``rf`` and ``qrf_<q>`` columns
        built at module level; the chart's data-source line is not visible
        in this copy -- confirm.
    legend : dict
        Label -> color mapping; its keys set the sort order of the color
        encoding.

    Returns
    -------
    chart
        The Altair chart built from the transform chain.

    NOTE(review): this copy appears to be missing lines: the
    ``alt.binding_range(`` call and the ``chart = (`` expression are never
    closed, and ``.add_params`` has no receiver, so it will not run as shown.
    """
    # Slider for varying the quantile value used for generating the QRF histogram.
    # The step equals the spacing of the module-level ``quantiles`` grid
    # (0.5 when only a single quantile is requested).
    # NOTE(review): reads the global ``quantiles`` rather than a parameter.
    slider = alt.binding_range(
        name="Predicted Quantile: ",
        step=0.5 if len(quantiles) == 1 else 1 / (len(quantiles) - 1),
    quantile_val = alt.param(name="quantile", value=0.5, bind=slider)

    # Clicking a legend entry selects that ``label``; drives opacity below.
    click = alt.selection_point(bind="legend", fields=["label"], on="click")

    chart = (
        .add_params(quantile_val, click)
        # Column name for the slider-selected quantile, e.g. "qrf_0.5";
        # presumably matches the ``f"qrf_{q:.3g}"`` DataFrame columns -- confirm.
        .transform_calculate(qrf_col="'qrf_' + quantile")
        # Round each series to one decimal place before folding.
        .transform_calculate(calculate="round(datum.actual * 10) / 10", as_="Actual")
        .transform_calculate(calculate="round(datum.rf * 10) / 10", as_="RF (Mean)")
        .transform_calculate(calculate="round(datum.qrf * 10) / 10", as_="QRF (Quantile)")
        # NOTE(review): the QRF series is labeled "QRF (Quantile)" here, while
        # the module-level ``legend`` uses the key "QRF (Median)"; the color
        # sort below comes from ``legend.keys()``, so confirm the labels agree.
        .transform_fold(["Actual", "RF (Mean)", "QRF (Quantile)"], as_=["label", "value"])
                    # Show axis labels only at multiples of 0.5.
                    labelExpr="datum.value % 0.5 == 0 ? datum.value : null",
                title="Actual and Predicted Target Values",
            y=alt.Y("count():Q", axis=alt.Axis(format=",d", title="Counts")),
                alt.Color("label:N", sort=list(legend.keys()), title=None),
            # Series selected via the legend click are opaque; others at 0.5.
            opacity=alt.condition(click, alt.value(1), alt.value(0.5)),
                alt.Tooltip("label:N", title="Label"),
                alt.Tooltip("value:Q", title="Value (binned)"),
                alt.Tooltip("count():Q", format=",d", title="Counts"),
            title="Distribution of QRF vs. RF Predictions on Right-Skewed Distribution",
    return chart

# Build the interactive histogram comparing actual values with RF and QRF predictions.
chart = plot_prediction_histograms(df=df, legend=legend)