Source code for error_parity.plotting

"""Utils for plotting postprocessing frontier and postprocessing solution."""

import numpy as np
import pandas as pd

import seaborn as sns
import matplotlib.figure
from matplotlib import pyplot as plt

from .pareto_curve import compute_inner_and_outer_adjustment_ci, get_envelope_of_postprocessing_frontier
from .threshold_optimizer import RelaxedThresholdOptimizer
from .classifiers import RandomizedClassifier


def plot_polygon_edges(polygon_points, **kwargs):
    """Draw the closed outline of a polygon with ``plt.plot``.

    Parameters
    ----------
    polygon_points : array-like
        The polygon's vertices, in drawing order. Columns 0 and 1 of the
        stacked points are used as the x and y coordinates, respectively.
    **kwargs
        Forwarded unchanged to ``plt.plot`` (e.g., ``color``, ``ls``).
    """
    # Repeat the first vertex at the end so the final edge closes the outline
    closed_outline = np.vstack((polygon_points, polygon_points[0]))
    xs = closed_outline[:, 0]
    ys = closed_outline[:, 1]
    plt.plot(xs, ys, **kwargs)
def plot_postprocessing_solution(
    *,
    postprocessed_clf: RelaxedThresholdOptimizer,
    plot_roc_curves: bool = False,
    plot_roc_hulls: bool = True,
    plot_group_optima: bool = True,
    plot_group_triangulation: bool = True,
    plot_global_optimum: bool = True,
    plot_diagonal: bool = True,
    plot_relaxation: bool = False,
    group_name_map: dict = None,
    figure: matplotlib.figure.Figure = None,
    **fig_kwargs,
):
    """Plots the group-specific solutions found for this predictor.

    All drawing is done through pyplot on the active figure: either a newly
    created one, or the figure passed via `figure`.

    Parameters
    ----------
    postprocessed_clf : RelaxedThresholdOptimizer
        A postprocessed classifier already fitted on some data.
    plot_roc_curves : bool, optional
        Whether to plot the group-wise (actual) ROC curves, by default False.
    plot_roc_hulls : bool, optional
        Whether to plot the group-wise ROC convex hulls, by default True.
    plot_group_optima : bool, optional
        Whether to plot the group-specific optima, by default True.
    plot_group_triangulation : bool, optional
        Whether to plot the triangulation of a group-specific solution, when
        such triangulation is needed to achieve a target ROC point, by default
        True.
    plot_global_optimum : bool, optional
        Whether to plot the global optimum ROC point, by default True.
    plot_diagonal : bool, optional
        Whether to plot the ROC diagonal with FPR=TPR, by default True.
    plot_relaxation : bool, optional
        Whether to plot the constraint relaxation bounding box, by default
        False.
    group_name_map : dict, optional
        A dictionary mapping each group's value to an appropriate name to show
        in the plot legend, by default None (groups are labeled "group <idx>").
    figure : matplotlib.figure.Figure, optional
        A matplotlib figure to use when plotting, by default will generate a
        new figure for plotting.
    **fig_kwargs
        Forwarded to ``plt.figure`` when a new figure is created; unused when
        `figure` is provided.
    """
    postprocessed_clf._check_fit_status()

    from matplotlib.patches import Rectangle

    n_groups = len(postprocessed_clf.groupwise_roc_hulls)

    # Set group-wise colors and global color
    palette = sns.color_palette(n_colors=n_groups + 1)
    global_color = palette[0]
    all_group_colors = palette[1:]

    if figure is None:
        figure = plt.figure(**fig_kwargs)
    elif getattr(figure, "number", None) is not None:
        # Make the caller's figure the active pyplot figure; otherwise every
        # `plt.*` call below would draw on whichever figure happened to be
        # current, silently ignoring the `figure` argument.
        # Figures without a pyplot number (not managed by pyplot) cannot be
        # activated this way, so drawing falls back to the current figure.
        plt.figure(figure.number)

    # Line styles are cycled so that any number of groups can be drawn
    line_styles = ("--", ":", "-.")

    # For each group `idx`
    for idx in range(n_groups):
        group_ls = line_styles[idx % len(line_styles)]
        group_color = all_group_colors[idx]

        # Plot group-wise (actual) ROC curves
        if plot_roc_curves:
            roc_points = np.stack(
                postprocessed_clf.groupwise_roc_data[idx], axis=1,
            )[:, 0:2]
            # The extra (1, 0) vertex closes the curve into a polygon
            plot_polygon_edges(
                np.vstack((roc_points, [1, 0])),
                color=group_color,
                ls=group_ls,
                alpha=0.5,
            )

        # Plot group-wise ROC hulls
        if plot_roc_hulls:
            plot_polygon_edges(
                postprocessed_clf.groupwise_roc_hulls[idx],
                color=group_color,
                ls=group_ls,
            )

        # Plot group-wise fair optimum
        group_optimum = postprocessed_clf.groupwise_roc_points[idx]
        if plot_group_optima:
            plt.plot(
                group_optimum[0],
                group_optimum[1],
                label=group_name_map[idx] if group_name_map else f"group {idx}",
                color=group_color,
                marker="^",
                markersize=5,
                lw=0,
            )

        # Plot triangulation of target point
        if plot_group_triangulation:
            # NOTE(review): this reads the private `_groupwise_roc_data`, while
            # the ROC-curve branch above reads the public `groupwise_roc_data`;
            # confirm both refer to the same data.
            (
                _weights,
                triangulated_points,
            ) = RandomizedClassifier.find_points_for_target_ROC(
                roc_curve_data=postprocessed_clf._groupwise_roc_data[idx],
                target_roc_point=group_optimum,
            )
            plt.plot(
                triangulated_points[:, 0],
                triangulated_points[:, 1],
                color=group_color,
                marker="x",
                lw=0,
            )
            plt.fill(
                triangulated_points[:, 0],
                triangulated_points[:, 1],
                color=group_color,
                alpha=0.1,
            )

    # Plot global optimum
    if plot_global_optimum:
        plt.plot(
            postprocessed_clf.global_roc_point[0],
            postprocessed_clf.global_roc_point[1],
            label="global",
            marker="*",
            color=global_color,
            alpha=0.6,
            markersize=5,
            lw=0,
        )

    # TODO: plot maximum constraint relaxation
    # (may differ from realized relaxation, for example, if either the TPR or
    # FPR diff violation is strictly smaller than the allowed tolerance)

    # Plot rectangle to visualize constraint relaxation
    if plot_relaxation:
        # Bounding box of the group-wise optima: FPR on x, TPR on y
        group_points = postprocessed_clf.groupwise_roc_points
        min_x, max_x = np.min(group_points[:, 0]), np.max(group_points[:, 0])
        min_y, max_y = np.min(group_points[:, 1]), np.max(group_points[:, 1])

        # Draw relaxation rectangle
        rect = Rectangle(
            xy=(min_x, min_y),
            width=max_x - min_x,
            height=max_y - min_y,
            facecolor="grey",
            alpha=0.3,
            label="relaxation",
        )

        # Add the patch to the Axes
        ax = plt.gca()
        ax.add_patch(rect)

    # Plot diagonal
    if plot_diagonal:
        plt.plot(
            [0, 1],
            [0, 1],
            ls="--",
            color="grey",
            alpha=0.5,
            label="random clf.",
        )

    # Set axis settings
    fairness_constr_str = postprocessed_clf.constraint.replace("_", " ")
    if postprocessed_clf.constraint == "equalized_odds":
        l_p_norm = (
            postprocessed_clf.l_p_norm
            if postprocessed_clf.l_p_norm != np.inf
            else r"\infty"
        )
        # Brace the subscript so multi-character norms (e.g., 10) render whole
        fairness_constr_str += f" $\\ell_{{{l_p_norm}}}$"

    plt.suptitle(f"Solution to {postprocessed_clf.tolerance}-relaxed optimum", y=0.96)
    plt.title(
        f"(fairness constraint: {fairness_constr_str})",
        fontsize="small",
    )
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.legend(loc="lower right", borderaxespad=2)
def plot_postprocessing_frontier(
    postproc_results_df: pd.DataFrame,
    *,
    perf_metric: str,
    disp_metric: str,
    show_data_type: str,
    constant_clf_perf: float,
    model_name: str = None,
    color: str = "black",
):
    """Draw the post-processing frontier contained in a results DataFrame.

    When the DataFrame holds bootstrap means (a column named
    ``{perf_metric}_mean_{show_data_type}``), those are plotted together with
    a shaded confidence band; otherwise the plain
    ``{metric}_{show_data_type}`` columns are used and no band is drawn.

    Parameters
    ----------
    postproc_results_df : pd.DataFrame
        The DataFrame containing postprocessing results. This should be the
        output of a call to `compute_postprocessing_curve(.)`.
    perf_metric : str
        Which performance metric to plot (horizontal axis).
    disp_metric : str
        Which disparity metric to plot (vertical axis).
    show_data_type : str
        The type of data to show results for; usually this will be "test".
    constant_clf_perf : float
        Performance achieved by the constant classifier; this is the point of
        lowest performance and lowest disparity achievable by postprocessing.
    model_name : str, optional
        Name of the model being postprocessed; shown in the plot legend.
    color : str, optional
        Color of the postprocessing curve and its confidence band, by default
        "black".

    Raises
    ------
    AssertionError
        If the selected performance or disparity column is missing from
        `postproc_results_df`.
    """
    available_columns = postproc_results_df.columns

    # Prefer bootstrap means; fall back to the plain columns when absent
    has_bootstrap_results = (
        f"{perf_metric}_mean_{show_data_type}" in available_columns
    )
    if has_bootstrap_results:
        column_suffix = f"_mean_{show_data_type}"
    else:
        column_suffix = f"_{show_data_type}"

    perf_col = f"{perf_metric}{column_suffix}"
    disp_col = f"{disp_metric}{column_suffix}"

    # Both columns must exist before anything is computed or drawn
    for column, kind, metric in (
        (perf_col, "perf", perf_metric),
        (disp_col, "disp", disp_metric),
    ):
        assert column in postproc_results_df.columns, (
            f"Could not find the column '{column}' for the {kind}. metric "
            f"'{metric}' on data type '{show_data_type}'.")

    # Envelope of the postprocessing adjustment frontier
    frontier = get_envelope_of_postprocessing_frontier(
        postproc_results_df,
        perf_col=perf_col,
        disp_col=disp_col,
        constant_clf_perf=constant_clf_perf,
    )

    # Inner and outer confidence-interval curves (bootstrap results only)
    if has_bootstrap_results:
        ci_xticks, ci_inner_yticks, ci_outer_yticks = (
            compute_inner_and_outer_adjustment_ci(
                postproc_results_df,
                perf_metric=perf_metric,
                disp_metric=disp_metric,
                data_type=show_data_type,
                constant_clf_perf=constant_clf_perf,
            )
        )

    # Dotted segment from the last frontier point up to disparity 1.0
    # (dominated but not feasible); its x value is nudged by 1e-6, presumably
    # so that the two end points do not share the exact same x coordinate.
    last_point = frontier[-1]
    dominated_segment = np.array([
        last_point,
        (last_point[0] - 1e-6, 1.0),
    ])
    sns.lineplot(
        x=dominated_segment[:, 0],
        y=dominated_segment[:, 1],
        linestyle=":",
        color="grey",
    )

    # The postprocessing frontier itself
    if model_name is None:
        frontier_label = "post-processing"
    else:
        frontier_label = f"post-processing of {model_name}"

    sns.lineplot(
        x=frontier[:, 0],
        y=frontier[:, 1],
        label=frontier_label,
        linestyle="-.",
        color=color,
    )

    # Shaded area between the inner and outer confidence-interval curves
    if has_bootstrap_results:
        plt.gca().fill_between(
            x=ci_xticks,
            y1=ci_inner_yticks,
            y2=ci_outer_yticks,
            interpolate=True,
            color=color,
            alpha=0.1,
            label=r"$95\%$ conf. interv.",
        )