Comparing Quantile Interpolation Methods

This example illustrates different interpolation methods that can be used during prediction in quantile regression forests (QRF). When a desired quantile lies between two data points, interpolation methods determine the predicted value. In this toy example, the QRF creates a split that divides the samples into two groups (samples 1–3 and samples 4–5), with quantiles calculated separately for each. The interpolation methods demonstrate how predictions are handled when a quantile does not exactly match a data point.

import altair as alt
import numpy as np
import pandas as pd

from quantile_forest import RandomForestQuantileRegressor

# Random state with a fixed seed (0); it is later handed to the forest.
random_state = np.random.RandomState(seed=0)
# Prediction interval widths from 0 to 1 inclusive in steps of 0.01 (101 values),
# rounded to two decimals and stored as a plain list of floats.
intervals = np.round(np.linspace(0.0, 1.0, 101), 2).tolist()

# Toy dataset: three identical samples at (-1, -1) followed by two identical
# samples at (1, 1), with consecutive integer targets from -2 to 2.
X = np.array([[-1, -1]] * 3 + [[1, 1]] * 2)
y = np.arange(-2, 3)

# A single-tree forest, fit on the full dataset (no bootstrap), that keeps every
# training sample in its leaves. Because of how the data is constructed, the
# samples end up divided between two terminal leaf nodes.
qrf_params = {
    "n_estimators": 1,
    "max_samples_leaf": None,
    "bootstrap": False,
    "random_state": random_state,
}
qrf = RandomForestQuantileRegressor(**qrf_params)
qrf.fit(X, y)

# Display name of each interpolation method mapped to its hex color.
# Insertion order is the order in which the methods are evaluated and shown.
interpolations = dict(
    Linear="#006aff",
    Lower="#ffd237",
    Higher="#0d4599",
    Midpoint="#f2a619",
    Nearest="#a6e5ff",
)

# Legend entries: the actual target values (black) first, then every method.
legend = {"Actual": "#000000", **interpolations}

# Build one long-format table with a row per (method, sample) for the actual
# values, plus a row per (interval, method, sample) for the predictions.

# The sample labels are the same for every method and interval; build them once.
sample_labels = [f"Sample {idx + 1} ({x})" for idx, x in enumerate(X.tolist())]
n_samples = len(y)

# The actual values do not depend on the prediction interval, so they are added
# a single time. Their interval-specific columns are left empty (None).
data = {
    "method": ["Actual"] * n_samples,
    "X": list(sample_labels),
    "y_pred": y.tolist(),
    "y_pred_low": [None] * n_samples,
    "y_pred_upp": [None] * n_samples,
    "quantile_low": [None] * n_samples,
    "quantile_upp": [None] * n_samples,
}

for interval in intervals:
    # Make predictions at the median and at the lower/upper bounds of a
    # prediction interval of the given width centered on the median.
    quantiles = [0.5, round(0.5 - interval / 2, 3), round(0.5 + interval / 2, 3)]

    # Populate data based on prediction results with different interpolations.
    for interpolation in interpolations:
        # Get predictions using the specified quantiles and interpolation method.
        # Columns of `y_pred` follow the order of `quantiles`: median, lower, upper.
        y_pred = qrf.predict(X, quantiles=quantiles, interpolation=interpolation.lower())

        data["method"].extend([interpolation] * n_samples)
        data["X"].extend(sample_labels)
        data["y_pred"].extend(y_pred[:, 0])
        data["y_pred_low"].extend(y_pred[:, 1])
        data["y_pred_upp"].extend(y_pred[:, 2])
        data["quantile_low"].extend([quantiles[1]] * n_samples)
        data["quantile_upp"].extend([quantiles[2]] * n_samples)

df = pd.DataFrame(data)


def plot_interpolation_predictions(df, legend):
    """Plot predictions by quantile interpolation methods.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format data with the columns ``method``, ``X``, ``y_pred``,
        ``y_pred_low``, ``y_pred_upp``, ``quantile_low`` and ``quantile_upp``.
    legend : dict
        Mapping of method name to color. The key order sets the display order
        of the methods and the values are used as the categorical color range.

    Returns
    -------
    chart
        Faceted Altair chart with one column per distinct ``X`` value, showing
        a bar from ``y_pred_low`` to ``y_pred_upp`` and a circle at ``y_pred``
        for each method.
    """
    method_order = list(legend)

    # Slider for varying the prediction interval that determines the quantiles being interpolated.
    slider = alt.binding_range(name="Prediction Interval: ", min=0, max=1, step=0.01)
    interval_val = alt.param(name="interval", value=0.8, bind=slider)

    # Clicking a legend entry selects that method; unselected methods are grayed out.
    click = alt.selection_point(bind="legend", fields=["method"], on="click")

    color = alt.condition(
        click,
        alt.Color("method:N", sort=method_order, title=None),
        alt.value("lightgray"),
    )

    tooltip = [
        alt.Tooltip("method:N", title="Method"),
        alt.Tooltip("X:N", title="X Values"),
        alt.Tooltip("y_pred:Q", format=",.3f", title="Predicted Y"),
        alt.Tooltip("y_pred_low:Q", format=",.3f", title="Predicted Lower Y"),
        alt.Tooltip("y_pred_upp:Q", format=",.3f", title="Predicted Upper Y"),
        alt.Tooltip("quantile_low:Q", format=".3f", title="Lower Quantile"),
        alt.Tooltip("quantile_upp:Q", format=".3f", title="Upper Quantile"),
    ]

    # Both layers place the methods along the same unlabeled x-axis.
    x_method = alt.X(
        "method:N",
        axis=alt.Axis(labels=False, tickSize=0),
        sort=method_order,
        title=None,
    )

    # Bars span the lower and upper predicted values.
    bar_pred = (
        alt.Chart(df)
        .mark_bar()
        .encode(
            x=x_method,
            y=alt.Y("y_pred_low:Q", title=""),
            y2=alt.Y2("y_pred_upp:Q", title=None),
            color=color,
            tooltip=tooltip,
        )
    )

    # Circles mark the value in `y_pred` (the actual value for the "Actual" rows).
    circle_pred = (
        alt.Chart(df, width=alt.Step(20))
        .mark_circle(opacity=1, size=75)
        .encode(
            x=x_method,
            y=alt.Y("y_pred:Q", title="Actual and Predicted Values"),
            color=color,
            tooltip=tooltip,
        )
    )

    chart = (
        (bar_pred + circle_pred)
        .add_params(interval_val, click)
        # Always keep the actual values; keep predictions only for the quantiles
        # matching the slider's interval. The expression uses logical OR (`||`)
        # and rounds the bounds to 3 decimals before the equality comparison.
        .transform_filter(
            "(datum.method == 'Actual')"
            " || (datum.quantile_low == round((0.5 - interval / 2) * 1000) / 1000)"
            " || (datum.quantile_upp == round((0.5 + interval / 2) * 1000) / 1000)"
        )
        .properties(height=400)
        .facet(
            column=alt.Column(
                "X:N",
                header=alt.Header(labelOrient="bottom", titleOrient="bottom"),
                title="Samples (Feature Values)",
            ),
            title="QRF Predictions by Quantile Interpolation on Toy Dataset",
        )
        .configure_facet(spacing=15)
        .configure_range(category=alt.RangeScheme(list(legend.values())))
        .configure_scale(bandPaddingInner=0.9)
        .configure_title(anchor="middle")
        .configure_view(stroke=None)
    )

    return chart


# Build the interactive chart from the assembled predictions.
chart = plot_interpolation_predictions(df=df, legend=legend)
# Bare expression as the last statement; presumably so that notebook-style
# renderers display the chart.
chart