Source code for error_parity.cvxpy_utils

"""A set of helper functions for using cvxpy."""
from __future__ import annotations
import logging
from itertools import product

import numpy as np
import cvxpy as cp
from cvxpy.expressions.variable import Variable
from cvxpy.expressions.expression import Expression

from .roc_utils import calc_cost_of_point, compute_global_roc_from_groupwise


# Maximum distance from solution to feasibility or optimality
SOLUTION_TOLERANCE = 1e-9

# Set of all fairness constraints with a cvxpy LP implementation
ALL_CONSTRAINTS = {
    "equalized_odds",           # equal TPR and equal FPR across groups
    "true_positive_rate_parity",    # TPR parity, same as FNR parity
    "false_positive_rate_parity",   # FPR parity, same as TNR parity
    "true_negative_rate_parity",    # TNR parity, same as FPR parity
    "false_negative_rate_parity",   # FNR parity, same as TPR parity
    "demographic_parity",       # equal positive prediction rates across groups
}

NOT_SUPPORTED_CONSTRAINTS_ERROR_MESSAGE = (
    "Currently only the following constraints are supported: {}.".format(
        ", ".join(sorted(ALL_CONSTRAINTS))
    )
)


[docs] def compute_line(p1: np.ndarray, p2: np.ndarray) -> tuple[float, float]: """Computes the slope and intercept of the line that passes through the two given points. The intercept is the value at x=0! (or NaN for vertical lines) For vertical lines just use the x-value of one of the points to find the intercept at y=0. Parameters ---------- p1 : np.ndarray A 2-D point. p2 : np.ndarray A 2-D point. Returns ------- tuple[float, float] A tuple pair with (slope, intercept) of the line that goes from p1 to p2. Raises ------ ValueError Raised when input is invalid, e.g., when p1 == p2. """ p1x, p1y = p1 p2x, p2y = p2 if all(p1 == p2): raise ValueError("Invalid points: p1==p2;") # Vertical line if np.isclose(p2x, p1x): slope = np.inf intercept = np.nan # Diagonal or horizontal line else: slope = (p2y - p1y) / (p2x - p1x) intercept = p1y - slope * p1x return slope, intercept
[docs] def compute_halfspace_inequality( # noqa: C901 p1: np.ndarray, p2: np.ndarray, ) -> tuple[float, float, float]: """Computes the halfspace inequality defined by the vector p1->p2, such that Ax + b <= 0, where A and b are extracted from the line that goes through p1->p2. As such, the inequality enforces that points must lie on the LEFT of the line defined by the p1->p2 vector. In other words, input points are assumed to be in COUNTER CLOCK-WISE order (right-hand rule). Parameters ---------- p1 : np.ndarray A point in the halfspace. p2 : np.ndarray Another point in the halfspace. Returns ------- tuple[float, float, float] Returns an array of size=(n_dims + 1), with format [A; b], representing the inequality Ax + b <= 0. Raises ------ RuntimeError Thrown in case if inconsistent internal state variables. """ slope, intercept = compute_line(p1, p2) # Unpack the points for ease of use p1x, p1y = p1 p2x, p2y = p2 # if slope is infinity, the constraint only applies to the values of x; # > the halfspace's b intercept value will correspond to this value of x; if np.isinf(slope): # Sanity check for vertical line if not np.isclose(p1x, p2x): raise RuntimeError( "Got infinite slope for line containing two points with " "different x-axis coordinates." ) # Vector pointing downwards? then, x >= b if p2y < p1y: return [-1, 0, p1x] # Vector pointing upwards? then, x <= b elif p2y > p1y: return [1, 0, -p1x] # elif slope is zero, the constraint only applies to the values of y; # > the halfspace's b intercept value will correspond to this value of y; elif np.isclose(slope, 0.0): # Sanity checks for horizontal line if not np.isclose(p1y, p2y) or not np.isclose(p1y, intercept): raise RuntimeError( f"Invalid horizontal line; points p1 and p2 should have same " f"y-axis value as intercept ({p1y}, {p2y}, {intercept})." ) # Vector pointing leftwards? then, y <= b if p2x < p1x: return [0, 1, -p1y] # Vector pointing rightwards? then, y >= b elif p2x > p1x: return [0, -1, p1y] # else, we have a standard diagonal line else: # Vector points left? # then, y <= mx + b <=> -mx + y - b <= 0 if p2x < p1x: return [-slope, 1, -intercept] # Vector points right? # then, y >= mx + b <=> mx - y + b <= 0 elif p2x > p1x: return [slope, -1, intercept] logging.error(f"No constraint can be concluded from points p1={p1} and p2={p2};") return [0, 0, 0]
[docs] def make_cvxpy_halfspace_inequality( p1: np.ndarray, p2: np.ndarray, cvxpy_point: Variable, ) -> Expression: """Creates a single cvxpy inequality constraint that enforces the given point, `cvxpy_point`, to lie on the left of the vector p1->p2. Points must be sorted in counter clock-wise order! Parameters ---------- p1 : np.ndarray A point p1. p2 : np.ndarray Another point p2. cvxpy_point : Variable The cvxpy variable over which the constraint will be applied. Returns ------- Expression A linear inequality constraint of type Ax + b <= 0. """ x_coeff, y_coeff, b = compute_halfspace_inequality(p1, p2) return np.array([x_coeff, y_coeff]) @ cvxpy_point + b <= 0
[docs] def make_cvxpy_point_in_polygon_constraints( polygon_vertices: np.ndarray, cvxpy_point: Variable, ) -> list[Expression]: """Creates the set of cvxpy constraints that force the given cvxpy variable point to lie within the polygon defined by the given vertices. Parameters ---------- polygon_vertices : np.ndarray A sequence of points that make up a polygon. Points must be sorted in COUNTER CLOCK-WISE order! (right-hand rule) cvxpy_point : cvxpy.Variable A cvxpy variable representing a point, over which the constraints will be applied. Returns ------- list[Expression] A list of cvxpy constraints. """ return [ make_cvxpy_halfspace_inequality( polygon_vertices[i], polygon_vertices[(i + 1) % len(polygon_vertices)], cvxpy_point, ) for i in range(len(polygon_vertices)) ]
[docs] def compute_fair_optimum( # noqa: C901 *, fairness_constraint: str, tolerance: float, groupwise_roc_hulls: dict[int, np.ndarray], group_sizes_label_pos: np.ndarray, group_sizes_label_neg: np.ndarray, groupwise_prevalence: np.ndarray, global_prevalence: float, false_positive_cost: float = 1.0, false_negative_cost: float = 1.0, l_p_norm: int | str = np.inf, ) -> tuple[np.ndarray, np.ndarray]: """Computes the solution to finding the optimal fair (equal odds) classifier. Can relax the equal odds constraint by some given tolerance. Parameters ---------- fairness_constraint : str The name of the fairness constraint under which the LP will be optimized. Possible inputs are: 'equalized_odds' match true positive and false positive rates across groups tolerance : float A value for the tolerance when enforcing the fairness constraint. groupwise_roc_hulls : dict[int, np.ndarray] A dict mapping each group to the convex hull of the group's ROC curve. The convex hull is an np.array of shape (n_points, 2), containing the points that form the convex hull of the ROC curve, sorted in COUNTER CLOCK-WISE order. group_sizes_label_pos : np.ndarray The relative or absolute number of positive samples in each group. group_sizes_label_neg : np.ndarray The relative or absolute number of negative samples in each group. global_prevalence : float The global prevalence of positive samples. false_positive_cost : float, optional The cost of a FALSE POSITIVE error, by default 1. false_negative_cost : float, optional The cost of a FALSE NEGATIVE error, by default 1. l_p_norm : int | str, optional The type of l-p norm to use when computing the distance between two ROC points. Used only for the "equalized_odds" constraint. By default uses `np.inf` (l-infinity distance): the maximum between groups' TPR and FPR differences. Using `l_p_norm=1` will correspond to the `average_abs_odds_difference`. See the following link for more information on this parameter: https://www.cvxpy.org/api_reference/cvxpy.atoms.other_atoms.html#norm Returns ------- (groupwise_roc_points, global_roc_point) : tuple[np.ndarray, np.ndarray] A tuple pair, (<1>, <2>), containing: 1: an array with the group-wise ROC points for the solution. 2: an array with the single global ROC point for the solution. """ if fairness_constraint not in ALL_CONSTRAINTS: raise ValueError(NOT_SUPPORTED_CONSTRAINTS_ERROR_MESSAGE) n_groups = len(groupwise_roc_hulls) if n_groups != len(group_sizes_label_neg) or n_groups != len(group_sizes_label_pos): raise ValueError( "Invalid arguments; all of the following should have the same " "length: groupwise_roc_hulls, group_sizes_label_neg, group_sizes_label_pos;" f"got: {len(groupwise_roc_hulls)}, {len(group_sizes_label_neg)}, {len(group_sizes_label_pos)};" ) # Group-wise ROC points --- in the form (FPR, TPR) groupwise_roc_points_vars = [ cp.Variable(shape=2, name=f"ROC point for group {i}", nonneg=True) for i in range(n_groups) ] # Define global ROC point as a linear combination of the group-wise ROC points global_roc_point_var = cp.Variable(shape=2, name="Global ROC point", nonneg=True) constraints = [ # Global FPR is the average of group FPRs weighted by LNs in each group global_roc_point_var[0] == group_sizes_label_neg @ np.array([p[0] for p in groupwise_roc_points_vars]), # Global TPR is the average of group TPRs weighted by LPs in each group global_roc_point_var[1] == group_sizes_label_pos @ np.array([p[1] for p in groupwise_roc_points_vars]), ] # ** APPLY FAIRNESS CONSTRAINTS ** # NOTE: feature request: compatibility with multiple constraints simultaneously # If "equalized_odds" # - i.e., l-p distance between any two groups' ROC points must be less than `tolerance`; # - DEFAULT: l-infinity distance (max distance between any two points in the ROC curve); if fairness_constraint == "equalized_odds": constraints += [ cp.norm( groupwise_roc_points_vars[i] - groupwise_roc_points_vars[j], p=l_p_norm, ) <= tolerance for i, j in product(range(n_groups), range(n_groups)) if i < j ] # If some rate parity, i.e., parity of one of {TPR, FPR, TNR, FNR} # i.e., constrain absolute distance between any two groups' rate metric elif fairness_constraint.endswith("rate_parity"): roc_idx_of_interest: int if ( fairness_constraint == "true_positive_rate_parity" # TPR or fairness_constraint == "false_negative_rate_parity" # FNR ): roc_idx_of_interest = 1 elif ( fairness_constraint == "false_positive_rate_parity" # FPR or fairness_constraint == "true_negative_rate_parity" # TNR ): roc_idx_of_interest = 0 else: # This point should never be reached as fairness_constraint was previously validated raise ValueError(NOT_SUPPORTED_CONSTRAINTS_ERROR_MESSAGE) constraints += [ cp.abs( groupwise_roc_points_vars[i][roc_idx_of_interest] - groupwise_roc_points_vars[j][roc_idx_of_interest] ) <= tolerance for i, j in product(range(n_groups), range(n_groups)) if i < j ] # If demographic parity, i.e., equal positive prediction rates across groups # note: this ignores the labels Y and only considers predictions Y_hat elif fairness_constraint == "demographic_parity": # NOTE: PPR = TPR * prevalence + FPR * (1 - prevalence) def group_positive_prediction_rate(group_idx: int): """Computes group-wise PPR as a function of the ROC cvxpy vars.""" group_prevalence = groupwise_prevalence[group_idx] group_tpr = groupwise_roc_points_vars[group_idx][1] group_fpr = groupwise_roc_points_vars[group_idx][0] return group_tpr * group_prevalence + group_fpr * (1 - group_prevalence) # Add constraints on the absolute difference between group-wos constraints += [ cp.abs( group_positive_prediction_rate(i) - group_positive_prediction_rate(j) ) <= tolerance for i, j in product(range(n_groups), range(n_groups)) if i < j ] # NOTE: implement other constraints here else: raise NotImplementedError(NOT_SUPPORTED_CONSTRAINTS_ERROR_MESSAGE) # Constraints for points in respective group-wise ROC curves for idx in range(n_groups): constraints += make_cvxpy_point_in_polygon_constraints( polygon_vertices=groupwise_roc_hulls[idx], cvxpy_point=groupwise_roc_points_vars[idx], ) # Define cost function obj = cp.Minimize( calc_cost_of_point( fpr=global_roc_point_var[0], fnr=1 - global_roc_point_var[1], prevalence=global_prevalence, false_pos_cost=false_positive_cost, false_neg_cost=false_negative_cost, ) ) # Define cvxpy problem prob = cp.Problem(obj, constraints) # Run solver prob.solve(solver=cp.ECOS, abstol=SOLUTION_TOLERANCE, feastol=SOLUTION_TOLERANCE) # NOTE: these tolerances are supposed to be smaller than the default np.isclose tolerances # (useful when comparing if two points are the same, within the cvxpy accuracy tolerance) # Log solution logging.info( f"cvxpy solver took {prob.solver_stats.solve_time}s; status is {prob.status}." ) if prob.status not in ["infeasible", "unbounded"]: # Otherwise, problem.value is inf or -inf, respectively. logging.info(f"Optimal solution value: {prob.value}") for variable in prob.variables(): logging.info(f"Variable {variable.name()}: value {variable.value}") else: # This line should never be reached (there are always trivial fair # solutions in the ROC diagonal) raise ValueError(f"cvxpy problem has no solution; status={prob.status}") groupwise_roc_points = np.vstack([p.value for p in groupwise_roc_points_vars]) global_roc_point = global_roc_point_var.value # Validating solution cost solution_cost = calc_cost_of_point( fpr=global_roc_point[0], fnr=1 - global_roc_point[1], prevalence=global_prevalence, false_pos_cost=false_positive_cost, false_neg_cost=false_negative_cost, ) if not np.isclose(solution_cost, prob.value): logging.error( f"Solution was found but cost did not pass validation! " f"Found solution ROC point {global_roc_point} with theoretical cost " f"{prob.value}, but actual cost is {solution_cost};" ) # Validating congruency between group-wise ROC points and global ROC point global_roc_from_groupwise = compute_global_roc_from_groupwise( groupwise_roc_points=groupwise_roc_points, groupwise_label_pos_weight=group_sizes_label_pos, groupwise_label_neg_weight=group_sizes_label_neg, ) if not all(np.isclose(global_roc_from_groupwise, global_roc_point)): logging.error( f"Solution: global ROC point ({global_roc_point}) does not seem to " f"match group-wise ROC points; global should be " f"({global_roc_from_groupwise}) to be consistent with group-wise;" ) return groupwise_roc_points, global_roc_point