Source code for error_parity.threshold_optimizer

"""Solver for the relaxed equal odds problem.

"""
from __future__ import annotations

import logging
from itertools import product
from typing import Callable

import numpy as np
from sklearn.metrics import roc_curve

from .cvxpy_utils import (
    compute_fair_optimum,
    ALL_CONSTRAINTS,
    NOT_SUPPORTED_CONSTRAINTS_ERROR_MESSAGE,
)
from .roc_utils import (
    roc_convex_hull,
    calc_cost_of_point,
)
from .classifiers import (
    Classifier,
    RandomizedClassifier,
    EnsembleGroupwiseClassifiers,
)


class RelaxedThresholdOptimizer(Classifier):
    """Class to encapsulate all the logic needed to compute the optimal equal
    odds classifier (with possibly relaxed constraints).
    """

    def __init__(
        self,
        *,
        predictor: Callable[[np.ndarray], np.ndarray],
        constraint: str = "equalized_odds",
        tolerance: float = 0.0,
        false_pos_cost: float = 1.0,
        false_neg_cost: float = 1.0,
        l_p_norm: int = np.inf,
        max_roc_ticks: int = 1000,
        seed: int = 42,
    ):
        """Initializes the relaxed equal odds wrapper.

        Parameters
        ----------
        predictor : callable[(np.ndarray), np.ndarray]
            A trained score predictor that takes in samples, X, in shape
            (num_samples, num_features), and outputs real-valued scores, R, in
            shape (num_samples,).
        constraint : str
            The fairness constraint to use. By default "equalized_odds".
        tolerance : float
            The absolute tolerance for the equal odds fairness constraint.
            Will allow for `tolerance` difference between group-wise ROC
            points.
        false_pos_cost : float, optional
            The cost of a FALSE POSITIVE error, by default 1.0.
        false_neg_cost : float, optional
            The cost of a FALSE NEGATIVE error, by default 1.0.
        l_p_norm : int, optional
            The l-p norm to use when computing distances between group ROC
            points. Used only for the "equalized_odds" constraint (different
            l-p norms lead to different equalized-odds relaxations). By
            default np.inf, which corresponds to the l-inf norm.
        max_roc_ticks : int, optional
            The maximum number of ticks (points) in each group's ROC, when
            computing the optimal fair classifier, by default 1000.
        seed : int
            A random seed used for reproducibility when producing randomized
            classifiers.
        """
        # Save arguments
        self.predictor = predictor
        self.constraint = constraint
        self.tolerance = tolerance
        self.false_pos_cost = false_pos_cost
        self.false_neg_cost = false_neg_cost
        self.l_p_norm = l_p_norm
        self.max_roc_ticks = max_roc_ticks
        self.seed = seed

        # Validate constraint
        if self.constraint not in ALL_CONSTRAINTS:
            raise ValueError(NOT_SUPPORTED_CONSTRAINTS_ERROR_MESSAGE)

        if self.l_p_norm != np.inf and self.constraint != "equalized_odds":
            raise ValueError(
                f"l-p norm is only supported for the 'equalized_odds' constraint. "
                f"Got constraint='{self.constraint}' and l_p_norm={self.l_p_norm}."
            )

        if not (isinstance(self.l_p_norm, int) or self.l_p_norm == np.inf):
            raise ValueError(
                f"Invalid l-p norm={self.l_p_norm}. Must be an integer or np.inf."
            )

        # Initialize instance variables
        self._groupwise_roc_data: dict = None
        self._groupwise_roc_hulls: dict = None
        self._groupwise_roc_points: np.ndarray = None
        self._groupwise_prevalence: np.ndarray = None
        self._global_roc_point: np.ndarray = None
        self._global_prevalence: float = None
        self._realized_classifier: EnsembleGroupwiseClassifiers = None

    @property
    def groupwise_roc_data(self) -> dict:
        """Group-specific ROC data containing (FPR, TPR, threshold) triplets."""
        return self._groupwise_roc_data

    @property
    def groupwise_roc_hulls(self) -> dict:
        """Group-specific ROC convex hulls achieved by the underlying predictor."""
        return self._groupwise_roc_hulls

    @property
    def groupwise_roc_points(self) -> np.ndarray:
        """Group-specific ROC points achieved by the solution."""
        return self._groupwise_roc_points

    @property
    def groupwise_prevalence(self) -> np.ndarray:
        """Group-specific prevalence, i.e., P(Y=1|A=a)."""
        return self._groupwise_prevalence

    @property
    def global_roc_point(self) -> np.ndarray:
        """Global ROC point achieved by the solution."""
        return self._global_roc_point

    @property
    def global_prevalence(self) -> float:
        """Global prevalence, i.e., P(Y=1)."""
        return self._global_prevalence

    def cost(
        self,
        false_pos_cost: float = None,
        false_neg_cost: float = None,
    ) -> float:
        """Computes the theoretical cost of the solution found.

        NOTE: use false_pos_cost == false_neg_cost == 1 for the 0-1 loss (the
        standard error rate), which is equal to `1 - accuracy`.

        Parameters
        ----------
        false_pos_cost : float, optional
            The cost of a FALSE POSITIVE error, by default will take the value
            given in the object's constructor.
        false_neg_cost : float, optional
            The cost of a FALSE NEGATIVE error, by default will take the value
            given in the object's constructor.

        Returns
        -------
        float
            The cost of the solution found.
        """
        self._check_fit_status()
        global_fpr, global_tpr = self.global_roc_point

        return calc_cost_of_point(
            fpr=global_fpr,
            fnr=1 - global_tpr,
            prevalence=self._global_prevalence,
            # NOTE: compare against None so an explicit cost of 0.0 is honored
            false_pos_cost=(
                false_pos_cost if false_pos_cost is not None else self.false_pos_cost
            ),
            false_neg_cost=(
                false_neg_cost if false_neg_cost is not None else self.false_neg_cost
            ),
        )

    def constraint_violation(
        self,
        constraint_name: str = None,
        l_p_norm: int = None,
    ) -> float:
        """Theoretical constraint violation of the LP solution found.

        Parameters
        ----------
        constraint_name : str, optional
            Optionally, may provide another constraint name that will be used
            instead of this classifier's self.constraint;
        l_p_norm : int, optional
            Which l-p norm to use when computing distances between group ROC
            points. Used only for the "equalized_odds" constraint.

        Returns
        -------
        float
            The fairness constraint violation.
        """
        self._check_fit_status()

        # Warn if provided a different constraint
        constraint_name = constraint_name or self.constraint
        if constraint_name != self.constraint:
            logging.warning(
                f"Calculating constraint violation for {constraint_name} constraint;\n"
                f"Note: this classifier was fitted with a {self.constraint} constraint;"
            )

        # Warn if provided a different l-p norm
        l_p_norm = l_p_norm or self.l_p_norm
        if l_p_norm != self.l_p_norm:
            logging.warning(
                f"Calculating constraint violation with l-{l_p_norm} norm;\n"
                f"Note: this classifier was fitted with l-{self.l_p_norm} norm;"
            )

        # Validate constraint
        if constraint_name not in ALL_CONSTRAINTS:
            raise ValueError(NOT_SUPPORTED_CONSTRAINTS_ERROR_MESSAGE)

        if constraint_name == "equalized_odds":
            return self.equalized_odds_violation(l_p_norm=l_p_norm)

        elif constraint_name.endswith("rate_parity"):
            constraint_to_error_type = {
                "true_positive_rate_parity": "fn",
                "false_positive_rate_parity": "fp",
                "true_negative_rate_parity": "fp",
                "false_negative_rate_parity": "fn",
            }

            return self.error_rate_parity_constraint_violation(
                error_type=constraint_to_error_type[constraint_name],
            )

        elif constraint_name == "demographic_parity":
            return self.demographic_parity_violation()

        else:
            raise NotImplementedError(
                f"Standalone constraint violation not yet computed for "
                f"constraint='{constraint_name}'."
            )

    def error_rate_parity_constraint_violation(self, error_type: str) -> float:
        """Computes the theoretical violation of an error-rate parity constraint.

        Parameters
        ----------
        error_type : str
            One of the following values:
                "fp", for false positive errors (FPR or TNR parity);
                "fn", for false negative errors (TPR or FNR parity).

        Returns
        -------
        float
            The maximum constraint violation among all groups.
        """
        self._check_fit_status()

        valid_error_types = ("fp", "fn")
        if error_type not in valid_error_types:
            raise ValueError(
                f"Invalid error_type='{error_type}', must be one of "
                f"{valid_error_types}."
            )

        roc_idx_of_interest = 0 if error_type == "fp" else 1

        return self._max_l_p_between_points(
            points=[
                # NOTE: must pass an array object, not scalars
                np.reshape(
                    # use only FPR or TPR (whichever was constrained)
                    roc_point[roc_idx_of_interest],
                    newshape=(1,),
                )
                for roc_point in self.groupwise_roc_points
            ],
            l_p_norm=np.inf,
        )

    def equalized_odds_violation(self, l_p_norm: int = None) -> float:
        """Computes the theoretical violation of the equal odds constraint
        (i.e., the maximum l-p distance between the ROC points of any pair of
        groups).

        Parameters
        ----------
        l_p_norm : int, optional
            Which l-p norm to use when computing distances between group ROC
            points.

        Returns
        -------
        float
            The equal-odds constraint violation.
        """
        self._check_fit_status()

        # Warn if provided a different l-p norm
        l_p_norm = l_p_norm or self.l_p_norm
        if l_p_norm != self.l_p_norm:
            logging.warning(
                f"Calculating constraint violation with l-{l_p_norm} norm;\n"
                f"Note: this classifier was fitted with l-{self.l_p_norm} norm;"
            )

        # Compute l-p distance between each pair of groups
        return self._max_l_p_between_points(
            points=self.groupwise_roc_points,
            l_p_norm=l_p_norm,
        )

    def demographic_parity_violation(self) -> float:
        """Computes the theoretical violation of the demographic parity constraint.

        That is, the maximum distance between groups' PPR (positive prediction
        rate).

        Returns
        -------
        float
            The demographic parity constraint violation.
        """
        self._check_fit_status()

        # Compute groups' PPR (positive prediction rate)
        return self._max_l_p_between_points(
            points=[
                # NOTE: must pass an array object, not scalars
                np.reshape(
                    group_tpr * group_prev + group_fpr * (1 - group_prev),
                    newshape=(1,),
                )
                for (group_fpr, group_tpr), group_prev in zip(
                    self.groupwise_roc_points, self.groupwise_prevalence
                )
            ],
            l_p_norm=np.inf,
        )

    @staticmethod
    def _max_l_p_between_points(
        points: list[float | np.ndarray],
        l_p_norm: int,
    ) -> float:
        # Number of points (should correspond to the number of groups)
        n_points = len(points)

        # Compute the l-p distance between each pair of groups
        l_p_constraint_violation = [
            (np.linalg.norm(points[i] - points[j], ord=l_p_norm), (i, j))
            for i, j in product(range(n_points), range(n_points))
            if i < j
        ]

        # Return the maximum
        max_violation, (groupA, groupB) = max(l_p_constraint_violation)
        logging.info(
            f"Maximum fairness violation is between "
            f"group={groupA} (p={points[groupA]}) and "
            f"group={groupB} (p={points[groupB]});"
        )

        return max_violation

    def fit(
        self,
        X: np.ndarray,
        y: np.ndarray,
        *,
        group: np.ndarray,
        y_scores: np.ndarray = None,
    ):
        """Fit this predictor to achieve the (possibly relaxed) equal odds
        constraint on the provided data.

        Parameters
        ----------
        X : np.ndarray
            The input features.
        y : np.ndarray
            The input labels.
        group : np.ndarray
            The group membership of each sample.
            Assumes groups are numbered [0, 1, ..., num_groups-1].
        y_scores : np.ndarray, optional
            The pre-computed model predictions on this data.

        Returns
        -------
        callable
            Returns self.
        """
        # Compute group stats
        self._global_prevalence = np.sum(y) / len(y)

        unique_groups = np.unique(group)
        num_groups = len(unique_groups)
        if np.max(unique_groups) > num_groups - 1:
            raise ValueError(
                f"Groups should be numbered starting at 0, and up to "
                f"num_groups-1. Got {num_groups} groups, but max value is "
                f"{np.max(unique_groups)} != num_groups-1 == {num_groups-1}."
            )

        # Relative group sizes for LN and LP samples
        group_sizes_label_neg = np.array(
            [np.sum(1 - y[group == g]) for g in unique_groups]
        )
        group_sizes_label_pos = np.array(
            [np.sum(y[group == g]) for g in unique_groups]
        )

        if np.sum(group_sizes_label_neg) + np.sum(group_sizes_label_pos) != len(y):
            raise RuntimeError("Failed sanity check. Are you using non-binary labels?")

        # Convert to relative sizes
        group_sizes_label_neg = group_sizes_label_neg.astype(float) / np.sum(
            group_sizes_label_neg
        )
        group_sizes_label_pos = group_sizes_label_pos.astype(float) / np.sum(
            group_sizes_label_pos
        )

        # Compute group-wise prevalence rates
        self._groupwise_prevalence = np.array(
            [np.mean(y[group == g]) for g in unique_groups]
        )

        # Compute group-wise ROC curves
        if y_scores is None:
            y_scores = self.predictor(X)

        # Flatten y_scores array if needed
        if isinstance(y_scores, np.ndarray) and len(y_scores.shape) > 1:
            y_scores = y_scores.ravel()

        self._groupwise_roc_data = dict()
        for g in unique_groups:
            group_filter = group == g

            roc_curve_data = roc_curve(
                y[group_filter],
                y_scores[group_filter],
            )

            # Check if max_roc_ticks is exceeded
            fpr, tpr, thrs = roc_curve_data
            if self.max_roc_ticks is not None and len(fpr) > self.max_roc_ticks:
                indices_to_keep = np.arange(
                    0, len(fpr), len(fpr) / self.max_roc_ticks
                ).astype(int)

                # Bottom-left (0, 0) and top-right (1, 1) points must be kept
                indices_to_keep[-1] = len(fpr) - 1
                roc_curve_data = (
                    fpr[indices_to_keep],
                    tpr[indices_to_keep],
                    thrs[indices_to_keep],
                )

            self._groupwise_roc_data[g] = roc_curve_data

        # Compute convex hull of each ROC curve
        self._groupwise_roc_hulls = dict()
        for g in unique_groups:
            group_fpr, group_tpr, _group_thresholds = self._groupwise_roc_data[g]

            curr_roc_points = np.stack((group_fpr, group_tpr), axis=1)
            curr_roc_points = np.vstack(
                (curr_roc_points, [1, 0])
            )  # Add point (1, 0) to ROC curve

            self._groupwise_roc_hulls[g] = roc_convex_hull(curr_roc_points)

        # Find the group-wise optima that fulfill the fairness criteria
        self._groupwise_roc_points, self._global_roc_point = compute_fair_optimum(
            fairness_constraint=self.constraint,
            tolerance=self.tolerance,
            groupwise_roc_hulls=self._groupwise_roc_hulls,
            group_sizes_label_pos=group_sizes_label_pos,
            group_sizes_label_neg=group_sizes_label_neg,
            groupwise_prevalence=self.groupwise_prevalence,
            global_prevalence=self.global_prevalence,
            false_positive_cost=self.false_pos_cost,
            false_negative_cost=self.false_neg_cost,
            l_p_norm=self.l_p_norm,
        )

        # Construct each group-specific classifier
        all_rand_clfs = {
            g: RandomizedClassifier.construct_at_target_ROC(
                predictor=self.predictor,
                roc_curve_data=self._groupwise_roc_data[g],
                target_roc_point=self._groupwise_roc_points[g],
                seed=self.seed,
            )
            for g in unique_groups
        }

        # Construct the global classifier (can be used for all groups)
        self._realized_classifier = EnsembleGroupwiseClassifiers(
            group_to_clf=all_rand_clfs
        )
        return self

    def __call__(self, X: np.ndarray, *, group: np.ndarray) -> np.ndarray:
        """Generate predictions for the given input data."""
        return self._realized_classifier(X, group)

    def predict(self, X: np.ndarray, *, group: np.ndarray) -> np.ndarray:
        """Generate predictions for the given input data.

        Parameters
        ----------
        X : np.ndarray
            Input samples.
        group : np.ndarray
            Input sensitive groups.

        Returns
        -------
        np.ndarray
            A sequence of predictions, one per input sample and input group.
        """
        return self(X, group=group)

    def _check_fit_status(self, raise_error: bool = True) -> bool:
        """Check whether this classifier has been fit on some data.

        Parameters
        ----------
        raise_error : bool, optional
            Whether to raise an error if the classifier is uninitialized
            (otherwise will just return False), by default True.

        Returns
        -------
        is_fit : bool
            Whether the classifier was already fit on some data.

        Raises
        ------
        RuntimeError
            If `raise_error==True`, raises an error if the classifier is
            uninitialized.
        """
        if self._realized_classifier is None:
            if not raise_error:
                return False

            raise RuntimeError(
                "This classifier has not yet been fitted to any data."
            )

        return True

    def __copy__(self):
        """Create a shallow copy of this object.

        The returned copy is in a blank state, i.e., it has not been fit to
        any data.
        """
        return self.__class__(
            predictor=self.predictor,
            constraint=self.constraint,
            tolerance=self.tolerance,
            false_pos_cost=self.false_pos_cost,
            false_neg_cost=self.false_neg_cost,
            l_p_norm=self.l_p_norm,
            max_roc_ticks=self.max_roc_ticks,
            seed=self.seed,
        )
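
A minimal usage sketch follows (not part of the module source). It assumes a trained scikit-learn-style binary classifier and numpy arrays `X`, `y`, and `group`, with groups coded as integers 0..num_groups-1; the names `model`, `X`, `y`, and `group` are illustrative placeholders.

    # Illustrative sketch, assuming `model`, `X`, `y`, and `group` already exist.
    from error_parity.threshold_optimizer import RelaxedThresholdOptimizer

    # Wrap the trained model as a score predictor: (num_samples, num_features) -> (num_samples,)
    predictor = lambda X: model.predict_proba(X)[:, -1]

    clf = RelaxedThresholdOptimizer(
        predictor=predictor,
        constraint="equalized_odds",
        tolerance=0.05,   # allow up to 0.05 disparity between group-wise ROC points
        seed=42,
    )
    clf.fit(X=X, y=y, group=group)

    y_pred = clf.predict(X, group=group)   # group-dependent (possibly randomized) predictions
    print(clf.cost())                      # theoretical 0-1 loss of the solution
    print(clf.constraint_violation())      # theoretical equalized-odds violation (<= tolerance)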