{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Example usage of `error-parity` with other fairness-constrained classifiers\n", "\n", "Contents:\n", "1. Train a standard (unconstrained) model;\n", "2. Check attainable fairness-accuracy trade-offs via post-processing, with the `error-parity` package;\n", "3. Train a fairness-constrained model (in-processing fairness intervention), with the `fairlearn` package;\n", "4. Map results for post-processing + in-processing interventions.\n", "\n", "---\n", "\n", "**NOTE**: This notebook has the following extra requirements: `fairlearn` and `lightgbm`.\n", "\n", "Install them with ```pip install fairlearn lightgbm```" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "error-parity==0.3.11\n" ] } ], "source": [ "from error_parity import __version__\n", "print(f\"error-parity=={__version__}\")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from matplotlib import pyplot as plt\n", "import seaborn as sns\n", "sns.set(palette=\"colorblind\", style=\"whitegrid\", rc={\"grid.linestyle\": \"--\", \"figure.dpi\": 200, \"figure.figsize\": (4,3)})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Some useful global constants:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "SEED = 2\n", "\n", "TEST_SIZE = 0.3\n", "VALIDATION_SIZE = None\n", "\n", "PERF_METRIC = \"accuracy\"\n", "DISP_METRIC = \"equalized_odds_diff\"\n", "\n", "# os.cpu_count() returns None when the number of CPUs cannot be determined\n", "N_JOBS = max(2, (os.cpu_count() or 1) - 2)\n", "\n", "np.random.seed(SEED)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Fetch UCI Adult data\n", "\n", "We'll use the `sex` column as the sensitive attribute.\n", "That is, false positive (FP) 
and false negative (FN) 
errors should not disproportionately impact individuals based on their sex." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "SENSITIVE_COL = \"sex\"\n", "sensitive_col_map = {\"Male\": 0, \"Female\": 1}\n", "\n", "# NOTE: You can also try to run this using the `race` column as sensitive attribute (as commented below).\n", "# SENSITIVE_COL = \"race\"\n", "# sensitive_col_map = {\"White\": 0, \"Black\": 1, \"Asian-Pac-Islander\": 1, \"Amer-Indian-Eskimo\": 1, \"Other\": 1}\n", "\n", "sensitive_col_inverse = {val: key for key, val in sensitive_col_map.items()}\n", "\n", "POS_LABEL = \">50K\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Download data." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "from fairlearn.datasets import fetch_adult\n", "\n", "X, Y = fetch_adult(\n", " as_frame=True,\n", " return_X_y=True,\n", ")\n", "\n", "# Map labels and sensitive column to numeric data\n", "Y = np.array(Y == POS_LABEL, dtype=int)\n", "S = np.array([sensitive_col_map[elem] for elem in X[SENSITIVE_COL]], dtype=int)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Split into train/test/validation data (when `VALIDATION_SIZE` is `None`, the training data is reused as validation data)." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split\n", "\n", "X_train, X_other, y_train, y_other, s_train, s_other = train_test_split(\n", " X, Y, S,\n", " test_size=TEST_SIZE + (VALIDATION_SIZE or 0),\n", " stratify=Y, random_state=SEED,\n", ")\n", "\n", "if VALIDATION_SIZE is not None and VALIDATION_SIZE > 0:\n", " X_val, X_test, y_val, y_test, s_val, s_test = train_test_split(\n", " X_other, y_other, s_other,\n", " test_size=TEST_SIZE / (TEST_SIZE + VALIDATION_SIZE),\n", " stratify=y_other, random_state=SEED,\n", " )\n", "else:\n", " X_test, y_test, s_test = X_other, y_other, s_other\n", " X_val, y_val, s_val = X_train, y_train, s_train" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Log the accuracy attainable by a dummy constant classifier." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'train': 0.7607125098715961,\n", " 'test': 0.7607315908005187,\n", " 'validation': 0.7607125098715961}" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def compute_constant_clf_accuracy(labels: np.ndarray) -> float:\n", " return max((labels == const_pred).mean() for const_pred in np.unique(labels))\n", "\n", "constant_clf_accuracy = {\n", " \"train\": compute_constant_clf_accuracy(y_train),\n", " \"test\": compute_constant_clf_accuracy(y_test),\n", " \"validation\": compute_constant_clf_accuracy(y_val),\n", "}\n", "constant_clf_accuracy" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Train a standard (unconstrained) classifier" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
LGBMClassifier(verbosity=-1)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "LGBMClassifier(verbosity=-1)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from lightgbm import LGBMClassifier\n", "\n", "unconstr_clf = LGBMClassifier(verbosity=-1)\n", "unconstr_clf.fit(X_train, y_train)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Unconstrained model: \n", "> accuracy = 0.87\n", "> equalized odds = 0.0673\n", "\n" ] } ], "source": [ "from error_parity.evaluation import evaluate_predictions_bootstrap\n", "\n", "# Evaluate the unconstrained model's test-set predictions (bootstrapped estimates)\n", "unconstr_test_results = evaluate_predictions_bootstrap(\n", " y_true=y_test,\n", " y_pred_scores=unconstr_clf.predict(X_test, random_state=SEED).astype(float),\n", " sensitive_attribute=s_test,\n", ")\n", "\n", "print(\n", " f\"Unconstrained model: \\n\"\n", " f\"> accuracy = {unconstr_test_results['accuracy_mean']:.3}\\n\"\n", " f\"> equalized odds = {unconstr_test_results['equalized_odds_diff_mean']:.3}\\n\"\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Map attainable fairness-accuracy trade-offs via (relaxed) post-processing\n", "\n", "By varying the tolerance (or slack) of the fairness constraint, we can map the different trade-offs attainable by the same model (each trade-off corresponds to a different post-processing intervention).\n", "\n", "**Post-processing** fairness methods intervene on the predictions of an already trained model, using different (possibly randomized) thresholds to binarize predictions of different groups.\n", "\n", "We'll be using the [`error-parity`](https://github.com/socialfoundations/error-parity) package [[Cruz and Hardt, 2023]](https://arxiv.org/abs/2306.07261)." 
] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e0d9ac484c1b44dcb852da934dbb5e3a", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/19 [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from error_parity.plotting import plot_postprocessing_frontier\n", "\n", "# Plot unconstrained model results with 95% CIs\n", "unconstr_performance = unconstr_test_results[f\"{PERF_METRIC}_mean\"]\n", "unconstr_disparity = unconstr_test_results[f\"{DISP_METRIC}_mean\"]\n", "\n", "sns.scatterplot(\n", " x=[unconstr_performance],\n", " y=[unconstr_disparity],\n", " color=\"black\",\n", " marker=\"*\",\n", " s=100,\n", ")\n", "\n", "plt.plot(\n", " (unconstr_test_results[f\"{PERF_METRIC}_low-percentile\"], unconstr_test_results[f\"{PERF_METRIC}_high-percentile\"]),\n", " (unconstr_disparity, unconstr_disparity),\n", " color=\"black\",\n", " ls=\":\",\n", " marker=\"|\",\n", " lw=1,\n", " ms=3,\n", ")\n", "\n", "plt.plot(\n", " (unconstr_performance, unconstr_performance),\n", " (unconstr_test_results[f\"{DISP_METRIC}_low-percentile\"], unconstr_test_results[f\"{DISP_METRIC}_high-percentile\"]),\n", " color=\"black\",\n", " ls=\":\",\n", " marker=\"_\",\n", " lw=1,\n", " ms=3,\n", ")\n", "\n", "# Plot postprocessing of unconstrained model\n", "plot_postprocessing_frontier(\n", " postproc_results_df,\n", " perf_metric=PERF_METRIC,\n", " disp_metric=DISP_METRIC,\n", " show_data_type=SHOW_RESULTS_ON,\n", " constant_clf_perf=constant_clf_accuracy[SHOW_RESULTS_ON],\n", " model_name=r\"$\\bigstar$\",\n", ")\n", "\n", "# Vertical line with minimum \"useful\" accuracy on this data\n", "curr_const_clf_acc = constant_clf_accuracy[SHOW_RESULTS_ON]\n", "plt.axvline(\n", " x=curr_const_clf_acc,\n", " ls=\"--\",\n", " color=\"grey\",\n", ")\n", "plt.gca().annotate(\n", " \"constant predictor acc.\",\n", " xy=(curr_const_clf_acc, ax_kwargs[\"ylim\"][1] / 
2),\n", " zorder=10,\n", " rotation=90,\n", " horizontalalignment=\"right\",\n", " verticalalignment=\"center\",\n", " fontsize=\"small\",\n", " \n", ")\n", "\n", "# Title and legend\n", "ax_kwargs[\"title\"] = f\"Post-processing ({SHOW_RESULTS_ON} results)\"\n", "ax_kwargs[\"xlim\"] = (curr_const_clf_acc - 1e-2, 0.885)\n", "\n", "plt.legend(\n", " loc=\"upper left\",\n", " bbox_to_anchor=(1.03, 1),\n", " borderaxespad=0)\n", "\n", "plt.gca().set(**ax_kwargs)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Let's train another type of fairness-aware model\n", "\n", "**In-processing** fairness methods introduce fairness criteria during model training.\n", "\n", "_Main disadvantage_: state-of-the-art in-processing methods can be considerably slower to run (e.g., increasing training time by 20-100 times).\n", "\n", "We'll be using the [`fairlearn`](https://github.com/fairlearn/fairlearn) package [[Weerts et al., 2020]](https://arxiv.org/abs/2303.16626)." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "from fairlearn.reductions import ExponentiatedGradient, EqualizedOdds\n", "\n", "inproc_clf = ExponentiatedGradient(\n", " estimator=unconstr_clf,\n", " constraints=EqualizedOdds(),\n", " max_iter=10,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Fit the `ExponentiatedGradient` [[Agarwal et al., 2018]](https://proceedings.mlr.press/v80/agarwal18a.html) in-processing intervention (**note**: may take a few minutes to fit)." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 49.1 s, sys: 1min 9s, total: 1min 58s\n", "Wall time: 1min 59s\n" ] }, { "data": { "text/html": [ "
ExponentiatedGradient(constraints=<fairlearn.reductions._moments.utility_parity.EqualizedOdds object at 0x16dac2650>,\n",
       "                      estimator=LGBMClassifier(verbosity=-1), max_iter=10,\n",
       "                      nu=0.000851617415307666)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "ExponentiatedGradient(constraints=,\n", " estimator=LGBMClassifier(verbosity=-1), max_iter=10,\n", " nu=0.000851617415307666)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%time\n", "inproc_clf.fit(X_train, y_train, sensitive_features=s_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Evaluate in-processing model on test data." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "In-processing model: \n", "> accuracy = 0.867\n", "> equalized odds = 0.0498\n", "\n" ] } ], "source": [ "from error_parity.evaluation import evaluate_predictions_bootstrap\n", "\n", "inproc_test_results = evaluate_predictions_bootstrap(\n", " y_true=y_test,\n", " y_pred_scores=inproc_clf.predict(X_test, random_state=SEED).astype(float),\n", " sensitive_attribute=s_test,\n", ")\n", "\n", "print(\n", " f\"In-processing model: \\n\"\n", " f\"> accuracy = {inproc_test_results['accuracy_mean']:.3}\\n\"\n", " f\"> equalized odds = {inproc_test_results['equalized_odds_diff_mean']:.3}\\n\"\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**We can go one step further and post-process this in-processing model :)**" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e9c3a354bef24dbb84b96781fb28b9ad", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/19 [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot unconstrained model results with 95% CIs\n", "sns.scatterplot(\n", " x=[unconstr_performance],\n", " y=[unconstr_disparity],\n", " color=\"black\",\n", " marker=\"*\",\n", " s=100,\n", ")\n", "\n", "plt.plot(\n", " (unconstr_test_results[f\"{PERF_METRIC}_low-percentile\"], unconstr_test_results[f\"{PERF_METRIC}_high-percentile\"]),\n", " (unconstr_disparity, 
unconstr_disparity),\n", " color=\"black\",\n", " ls=\":\",\n", " marker=\"|\",\n", " lw=1,\n", " ms=3,\n", ")\n", "\n", "plt.plot(\n", " (unconstr_performance, unconstr_performance),\n", " (unconstr_test_results[f\"{DISP_METRIC}_low-percentile\"], unconstr_test_results[f\"{DISP_METRIC}_high-percentile\"]),\n", " color=\"black\",\n", " ls=\":\",\n", " marker=\"_\",\n", " lw=1,\n", " ms=3,\n", ")\n", "\n", "# Plot postprocessing of unconstrained model\n", "plot_postprocessing_frontier(\n", " postproc_results_df,\n", " perf_metric=PERF_METRIC,\n", " disp_metric=DISP_METRIC,\n", " show_data_type=SHOW_RESULTS_ON,\n", " constant_clf_perf=constant_clf_accuracy[SHOW_RESULTS_ON],\n", " model_name=r\"$\\bigstar$\",\n", ")\n", "\n", "# Plot inprocessing intervention results with 95% CIs\n", "sns.scatterplot(\n", " x=[inproc_test_results[f\"{PERF_METRIC}_mean\"]],\n", " y=[inproc_test_results[f\"{DISP_METRIC}_mean\"]],\n", " color=\"red\",\n", " marker=\"P\",\n", " s=50,\n", ")\n", "\n", "plt.plot(\n", " (inproc_test_results[f\"{PERF_METRIC}_low-percentile\"], inproc_test_results[f\"{PERF_METRIC}_high-percentile\"]),\n", " (inproc_test_results[f\"{DISP_METRIC}_mean\"], inproc_test_results[f\"{DISP_METRIC}_mean\"]),\n", " color='red',\n", " ls=\":\",\n", " marker=\"|\",\n", " lw=1,\n", " ms=3,\n", ")\n", "\n", "plt.plot(\n", " (inproc_test_results[f\"{PERF_METRIC}_mean\"], inproc_test_results[f\"{PERF_METRIC}_mean\"]),\n", " (inproc_test_results[f\"{DISP_METRIC}_low-percentile\"], inproc_test_results[f\"{DISP_METRIC}_high-percentile\"]),\n", " color='red',\n", " ls=\":\",\n", " marker=\"_\",\n", " lw=1,\n", " ms=3,\n", ")\n", "\n", "# Plot postprocessing of inprocessing model\n", "plot_postprocessing_frontier(\n", " inproc_postproc_results_df,\n", " perf_metric=PERF_METRIC,\n", " disp_metric=DISP_METRIC,\n", " show_data_type=SHOW_RESULTS_ON,\n", " constant_clf_perf=constant_clf_accuracy[SHOW_RESULTS_ON],\n", " model_name=r\"$+$\",\n", " color=\"red\",\n", ")\n", "\n", "# 
Vertical line with minimum \"useful\" accuracy on this data\n", "curr_const_clf_acc = constant_clf_accuracy[SHOW_RESULTS_ON]\n", "plt.axvline(\n", " x=curr_const_clf_acc,\n", " ls=\"--\",\n", " color=\"grey\",\n", ")\n", "plt.gca().annotate(\n", " \"constant predictor acc.\",\n", " xy=(curr_const_clf_acc, ax_kwargs[\"ylim\"][1] / 2),\n", " zorder=10,\n", " rotation=90,\n", " horizontalalignment=\"right\",\n", " verticalalignment=\"center\",\n", " fontsize=\"small\",\n", " \n", ")\n", "\n", "# Title and legend\n", "ax_kwargs[\"title\"] = f\"Post-processing ({SHOW_RESULTS_ON})\"\n", "ax_kwargs[\"xlim\"] = (curr_const_clf_acc - 1e-2, 0.885)\n", "\n", "plt.legend(\n", " loc=\"upper left\",\n", " bbox_to_anchor=(1.03, 1),\n", " borderaxespad=0)\n", "\n", "plt.gca().set(**ax_kwargs)\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 4 }