rpmjp/portfolio
rpmjp/projects/communityshield/shap_explainer.py
Completed May – August 2025

CommunityShield

ML-powered crime pattern explorer for Chicago. 8.5M rows, 4 XGBoost models with SHAP explanations, beat-level heatmap, and an honest methodology page about what the data can and cannot tell you.

Python 3.12 · FastAPI · PostgreSQL 16 · PostGIS · XGBoost · SHAP · React 19 · MapLibre GL
Languages
TypeScript 52.4%
Python 41.8%
CSS 3.2%
Other 2.6%
shap_explainer.py
"""SHAP explainers for the binary XGBoost models.

Uses TreeSHAP — fast, exact for tree ensembles. Explains a single prediction
by attributing the log-odds delta from baseline to each feature.

For the 4-class hierarchical model we skip SHAP for now (combining
supercategory × subtype attributions is non-trivial and adds little
interpretive value over showing the binary explanations).

Key design decision below: we cache explainers by id() of the booster
instance, NOT by model_id. Boosters are loaded once at FastAPI startup and
live for the process lifetime, so id() is stable. Caching avoids
reconstructing the TreeExplainer on every prediction request (which would
add ~200ms of latency to an otherwise sub-50ms endpoint).
"""
from __future__ import annotations

from typing import Any

import numpy as np
import shap


# Human-readable labels for model feature columns, keyed by raw column name.
# The frontend shows these instead of raw column names — "Hour of day" reads
# better than "hour" in the SHAP attribution panel. Lookups elsewhere in this
# module use FEATURE_LABELS.get(col, col), so a column missing from this map
# is displayed under its raw name instead of raising.
FEATURE_LABELS: dict[str, str] = {
    # Calendar / time-of-day features.
    "hour": "Hour of day",
    "day_of_week": "Day of week",
    "month": "Month",
    "is_weekend": "Weekend",
    "quarter": "Quarter",
    "shift": "Shift (day/evening/night)",
    # Location features.
    "beat_num": "Beat",
    "community_area": "Community area",
    "latitude": "Latitude",
    "longitude": "Longitude",
    # Categorical columns; the `_enc` suffix suggests integer-encoded
    # categories — NOTE(review): confirm against the feature pipeline.
    "district_enc": "Police district",
    "location_enc": "Location type",
    "type_enc": "Crime type",
}


# Module-level cache keyed by id(booster). Boosters are loaded once at
# startup via the FastAPI lifespan context, so id() is normally stable for
# the process lifetime. Caching here avoids rebuilding the explainer per
# request.
#
# Each entry stores the booster alongside its explainer. Python only
# guarantees id() uniqueness among objects alive at the same time. Without
# this, a replaced or garbage-collected booster's id could be reused by a
# new booster, which would then silently receive the wrong model's
# explainer. Holding a strong reference keeps the id from being recycled.
# The `is` check in build_explainer makes the ownership of each hit
# explicit. The cache is unbounded; it relies on the set of boosters being
# small and fixed.
_EXPLAINER_CACHE: dict[int, tuple[Any, shap.TreeExplainer]] = {}


def build_explainer(booster) -> shap.TreeExplainer:
    """Return a TreeExplainer for *booster*, building it at most once per instance.

    Args:
        booster: The fitted tree model to explain; passed straight to
            ``shap.TreeExplainer`` on a cache miss.

    Returns:
        The cached ``shap.TreeExplainer`` for this exact booster object, or a
        newly built one (which is then cached).
    """
    key = id(booster)
    cached = _EXPLAINER_CACHE.get(key)
    # Only trust a hit that was built for this very object, not merely for
    # an object that happened to have the same id().
    if cached is not None and cached[0] is booster:
        return cached[1]
    explainer = shap.TreeExplainer(booster)
    _EXPLAINER_CACHE[key] = (booster, explainer)
    return explainer


def explain_binary_prediction(
    booster, X: np.ndarray, feature_cols: list[str]
) -> dict[str, Any]:
    """Compute SHAP values for one row and return a structured explanation.

    Args:
        booster: The binary tree model to explain; its TreeExplainer is
            obtained via ``build_explainer`` (cached per booster instance).
        X: Model input of shape (1, n_features); only the first row is
            explained.
        feature_cols: Column names in the same order as the columns of ``X``.

    Returns:
        {
            "base_value": float (model's average log-odds output),
            "prediction_value": float (this row's log-odds output),
            "contributions": [
                {"feature": str, "label": str, "value": float (input), "shap": float},
                ...
            ] sorted by abs(shap) descending
        }

    The structured output ships straight to the frontend's ExplanationPanel
    component, which renders the top contributions as a horizontal bar chart
    with the human-readable label, the input value, and the SHAP magnitude.
    """
    explainer = build_explainer(booster)

    # Base value and SHAP row are taken from the same model output. For a
    # single-output model that is the only output; for a multi-output
    # binary model it is output 1 (the positive class). This keeps
    # base + sum(SHAP) describing a single output.
    shap_row = _positive_class_shap_row(explainer.shap_values(X))
    base = _positive_class_base_value(explainer.expected_value)
    # Read inputs through an ndarray so X[0] means "first row" whatever
    # array-like the caller passed.
    input_row = np.asarray(X)[0]

    contributions = [
        {
            "feature": col,
            "label": FEATURE_LABELS.get(col, col),
            "value": float(input_row[i]),
            "shap": float(shap_row[i]),
        }
        for i, col in enumerate(feature_cols)
    ]
    # Sort by absolute SHAP value so the top of the list is always the most
    # influential feature, whether it pushed the prediction up or down.
    contributions.sort(key=lambda c: abs(c["shap"]), reverse=True)

    # TreeSHAP is locally accurate: this row's model output equals the base
    # value plus the sum of every feature's contribution.
    return {
        "base_value": base,
        "prediction_value": base + float(shap_row.sum()),
        "contributions": contributions,
    }


def _positive_output_index(n_outputs: int) -> int:
    """Return the positive-class output index: 1 if several outputs exist, else 0."""
    return 1 if n_outputs > 1 else 0


def _positive_class_shap_row(shap_values: Any) -> np.ndarray:
    """Return the first row's SHAP vector, shape (n_features,), for the positive output.

    ``TreeExplainer.shap_values`` reports results in one of three layouts:
      * ndarray (n_rows, n_features): single-output models.
      * list of per-output (n_rows, n_features) arrays: multi-output models
        in older shap releases.
      * ndarray (n_rows, n_features, n_outputs): multi-output models in
        newer shap releases.
    """
    if isinstance(shap_values, list):
        shap_values = shap_values[_positive_output_index(len(shap_values))]
    values = np.asarray(shap_values)
    if values.ndim == 3:
        values = values[..., _positive_output_index(values.shape[-1])]
    return values[0]


def _positive_class_base_value(expected_value: Any) -> float:
    """Return the positive-output baseline (expected model output) as a float.

    ``expected_value`` may be a Python float, a NumPy scalar, a 0-d array, or
    a per-output list / 1-D array. ``np.ravel`` flattens all of these to 1-D,
    which avoids calling ``len()`` on a 0-d array (a TypeError).

    Raises:
        ValueError: if ``expected_value`` contains no values.
    """
    flat = np.ravel(expected_value)
    if flat.size == 0:
        raise ValueError("TreeExplainer.expected_value is empty")
    return float(flat[_positive_output_index(flat.size)])