rpmjp/projects/communityshield/shap_explainer.py
Completed · May – August 2025
CommunityShield
ML-powered crime pattern explorer for Chicago. 8.5M rows, 4 XGBoost models with SHAP explanations, beat-level heatmap, and an honest methodology page about what the data can and cannot tell you.
Python 3.12 · FastAPI · PostgreSQL 16 · PostGIS · XGBoost · SHAP · React 19 · MapLibre GL
Languages
TypeScript 52.4%
Python 41.8%
CSS 3.2%
Other 2.6%
shap_explainer.py
"""SHAP explainers for the binary XGBoost models.
Uses TreeSHAP — fast, exact for tree ensembles. Explains a single prediction
by attributing the log-odds delta from baseline to each feature.
For the 4-class hierarchical model we skip SHAP for now (combining
supercategory × subtype attributions is non-trivial and adds little
interpretive value over showing the binary explanations).
Key design decision below: we cache explainers by id() of the booster
instance, NOT by model_id. Boosters are loaded once at FastAPI startup and
live for the process lifetime, so id() is stable. Caching avoids
reconstructing the TreeExplainer on every prediction request (which would
add ~200ms of latency to an otherwise sub-50ms endpoint).
"""
from __future__ import annotations
from typing import Any
import numpy as np
import shap
# Display names for the model's feature columns. The SHAP attribution panel
# in the frontend renders these rather than the raw column names, e.g.
# "Hour of day" instead of "hour". Columns without an entry here fall back
# to their raw name at the lookup site.
FEATURE_LABELS: dict[str, str] = {
    # Time-of-occurrence features.
    "hour": "Hour of day",
    "day_of_week": "Day of week",
    "month": "Month",
    "is_weekend": "Weekend",
    "quarter": "Quarter",
    "shift": "Shift (day/evening/night)",
    # Location features.
    "beat_num": "Beat",
    "community_area": "Community area",
    "latitude": "Latitude",
    "longitude": "Longitude",
    # Encoded (`*_enc`) columns.
    "district_enc": "Police district",
    "location_enc": "Location type",
    "type_enc": "Crime type",
}
# Module-level cache of TreeExplainers, keyed by id(booster).
#
# Keying on id() is only sound while each booster object stays alive, because
# id() values are unique only among live objects. This relies on the boosters
# being loaded once at startup and kept for the process lifetime, as described
# in the module docstring. Entries are never evicted.
_EXPLAINER_CACHE: dict[int, shap.TreeExplainer] = {}


def build_explainer(booster) -> shap.TreeExplainer:
    """Return the TreeExplainer for ``booster``, building it on first use.

    Args:
        booster: The trained tree model to explain. It is passed unchanged to
            ``shap.TreeExplainer`` and identified in the cache by ``id()``.

    Returns:
        The cached explainer for this exact booster instance; a new one is
        constructed and stored only when none exists yet.
    """
    key = id(booster)
    # Single dict lookup on the hot path; an explainer object is never None,
    # so None unambiguously means "not cached yet".
    explainer = _EXPLAINER_CACHE.get(key)
    if explainer is None:
        explainer = shap.TreeExplainer(booster)
        _EXPLAINER_CACHE[key] = explainer
    return explainer
def explain_binary_prediction(
    booster, X: np.ndarray, feature_cols: list[str]
) -> dict[str, Any]:
    """Compute SHAP values for one row and return a structured explanation.

    Args:
        booster: Trained binary model; its explainer is obtained through
            ``build_explainer`` so repeated calls reuse the cached instance.
        X: Feature matrix of shape (1, n_features). Only row 0 is explained.
        feature_cols: Column names, in the same order as the columns of ``X``.

    Returns:
        {
            "base_value": float (model's average log-odds output),
            "prediction_value": float (this row's log-odds output),
            "contributions": [
                {"feature": str, "label": str, "value": float (input), "shap": float},
                ...
            ] sorted by abs(shap) descending
        }

        Every number is converted to a built-in ``float``. The structure ships
        straight to the frontend's ExplanationPanel component, which renders
        the top contributions as a horizontal bar chart with the
        human-readable label, the input value, and the SHAP magnitude.

    Raises:
        ValueError: If the explainer reports an empty ``expected_value``.
    """
    explainer = build_explainer(booster)

    # Depending on the SHAP version and model configuration, shap_values() may
    # return a 2-D array (n_samples, n_features), a per-class list of such
    # arrays, or a 3-D array with a trailing output axis. Normalise all three
    # to the 2-D attributions of the positive class: index 1 when there is
    # more than one output, otherwise the only output.
    shap_values = explainer.shap_values(X)
    if isinstance(shap_values, list):
        shap_values = shap_values[1] if len(shap_values) > 1 else shap_values[0]
    shap_values = np.asarray(shap_values)
    if shap_values.ndim == 3:
        shap_values = shap_values[:, :, 1 if shap_values.shape[2] > 1 else 0]
    shap_row = shap_values[0]

    # expected_value can be a Python/NumPy scalar, a 0-d array, a list, or a
    # 1-D array with one entry per output. Flattening first avoids calling
    # len() on a 0-d array (which raises TypeError), and using the same
    # positive-class index as above keeps the SHAP invariant intact:
    # prediction = base + sum(contributions).
    base_values = np.asarray(explainer.expected_value, dtype=float).ravel()
    if base_values.size == 0:
        raise ValueError("explainer.expected_value is empty")
    base = float(base_values[1] if base_values.size > 1 else base_values[0])

    input_row = X[0]
    contributions = [
        {
            "feature": col,
            "label": FEATURE_LABELS.get(col, col),
            "value": float(input_row[i]),
            "shap": float(shap_row[i]),
        }
        for i, col in enumerate(feature_cols)
    ]
    # Most influential feature first, regardless of whether it pushed the
    # prediction up or down.
    contributions.sort(key=lambda c: abs(c["shap"]), reverse=True)

    return {
        "base_value": base,
        "prediction_value": base + float(shap_row.sum()),
        "contributions": contributions,
    }