"""Helper functions to solve the relaxed equal odds problem.
"""
import logging
import numpy as np
from scipy.spatial import ConvexHull
from sklearn.metrics import confusion_matrix
[docs]
def calc_cost_of_point(
    fpr: float,
    fnr: float,
    prevalence: float,
    false_pos_cost: float = 1.0,
    false_neg_cost: float = 1.0,
) -> float:
    """Compute the expected per-sample cost of operating at a ROC point.

    The cost is
    ``false_pos_cost * (1 - prevalence) * fpr + false_neg_cost * prevalence * fnr``.

    Parameters
    ----------
    fpr : float
        The false positive rate (FPR) of the ROC point.
    fnr : float
        The false negative rate (FNR) of the ROC point.
    prevalence : float
        The fraction of positive samples in the dataset,
        i.e., np.sum(y_true) / len(y_true)
    false_pos_cost : float, optional
        The cost of a single false positive error, by default 1.
    false_neg_cost : float, optional
        The cost of a single false negative error, by default 1.

    Returns
    -------
    cost : float
        The cost of the given ROC point (divided by the size of the dataset).
    """
    error_rates = np.array([fpr, fnr])

    # Each error type is weighted by its unit cost times the share of samples
    # that can incur it (negatives for FPs, positives for FNs).
    per_error_weights = np.array([false_pos_cost, false_neg_cost]) * np.array(
        [1 - prevalence, prevalence]
    )

    return per_error_weights @ error_rates
[docs]
def compute_roc_point_from_predictions(y_true, y_pred_binary):
    """Computes the ROC point associated with the provided binary predictions.

    Parameters
    ----------
    y_true : np.ndarray
        The true labels.
    y_pred_binary : np.ndarray
        The binary predictions.

    Returns
    -------
    tuple[float, float]
        The resulting ROC point, i.e., a tuple (FPR, TPR).
        A rate whose denominator is zero (no label-negative samples for FPR,
        no label-positive samples for TPR) is undefined and is expected to
        come out as NaN from the numpy division.

    Raises
    ------
    ValueError
        If the confusion matrix is not 2x2 after the single-class fallback
        below (e.g., more than two distinct classes are present).
    """
    matrix = confusion_matrix(y_true, y_pred_binary)

    # When labels and predictions jointly contain a single class, sklearn
    # infers a 1x1 matrix and the 4-way unpacking below would fail with a
    # cryptic error. Only in that case, rebuild the matrix with explicit
    # labels (assumes the usual 0/1 encoding of binary labels); every input
    # that already produced a 2x2 matrix follows the exact same path as before.
    if matrix.shape == (1, 1):
        matrix = confusion_matrix(y_true, y_pred_binary, labels=[0, 1])

    tn, fp, fn, tp = matrix.ravel()

    # FPR = FP / LN
    fpr = fp / (fp + tn)

    # TPR = TP / LP
    tpr = tp / (tp + fn)

    return (fpr, tpr)
[docs]
def compute_global_roc_from_groupwise(
    groupwise_roc_points: np.ndarray,
    groupwise_label_pos_weight: np.ndarray,
    groupwise_label_neg_weight: np.ndarray,
) -> np.ndarray:
    """Computes the global ROC point that corresponds to the provided group-wise
    ROC points.

    The global ROC is a linear combination of the group-wise points, with
    different weights for computing FPR and TPR -- the first related to LNs, and
    the second to LPs.

    The weight arguments are never modified: normalization is performed on
    local float arrays, so raw (possibly integer) counts may be passed safely.

    Parameters
    ----------
    groupwise_roc_points : np.ndarray
        An array of shape (n_groups, n_roc_dims) containing one ROC point per
        group. Column 0 is read as FPR and column 1 as TPR.
    groupwise_label_pos_weight : np.ndarray
        The relative size of each group in terms of its label POSITIVE samples
        (out of all POSITIVE samples, how many are in each group).
        Re-normalized to sum to 1 if it does not already.
    groupwise_label_neg_weight : np.ndarray
        The relative size of each group in terms of its label NEGATIVE samples
        (out of all NEGATIVE samples, how many are in each group).
        Re-normalized to sum to 1 if it does not already.

    Returns
    -------
    global_roc_point : np.ndarray
        A single point (global FPR, global TPR) that corresponds to the global
        outcome of the given group-wise ROC points.

    Raises
    ------
    ValueError
        If the number of weights does not match the number of groups, or if
        `groupwise_roc_points` is not two-dimensional.
    """
    groupwise_roc_points = np.asarray(groupwise_roc_points)
    n_groups, _ = groupwise_roc_points.shape

    # Some initial sanity checks
    if (
        len(groupwise_label_pos_weight) != len(groupwise_label_neg_weight)
        or len(groupwise_label_pos_weight) != n_groups
    ):
        raise ValueError(
            "Invalid input shapes: length of all arguments must be equal (the "
            "number of different sensitive groups)."
        )

    # Use float arrays so that integer counts can be normalized; the divisions
    # below are deliberately NOT in-place so the caller's arrays stay intact.
    pos_weight = np.asarray(groupwise_label_pos_weight, dtype=float)
    neg_weight = np.asarray(groupwise_label_neg_weight, dtype=float)

    # Normalize group LP (/LN) weights by their size
    if not np.isclose(pos_weight.sum(), 1.0):
        pos_weight = pos_weight / pos_weight.sum()
    if not np.isclose(neg_weight.sum(), 1.0):
        neg_weight = neg_weight / neg_weight.sum()

    # Compute global FPR (weighted by relative number of LNs in each group)
    global_fpr = neg_weight @ groupwise_roc_points[:, 0]

    # Compute global TPR (weighted by relative number of LPs in each group)
    global_tpr = pos_weight @ groupwise_roc_points[:, 1]

    return np.array([global_fpr, global_tpr])
[docs]
def roc_convex_hull(roc_points: np.ndarray) -> np.ndarray:
    """Return the vertices of the convex hull of a set of ROC points.

    Parameters
    ----------
    roc_points : np.ndarray
        An array of shape (n_points, n_dims) containing all points
        of a provided ROC curve.

    Returns
    -------
    hull_points : np.ndarray
        An array of shape (n_hull_points, n_dim) with the input points that
        are vertices of the convex hull, in the order given by
        ``scipy.spatial.ConvexHull(...).vertices``.
    """
    # Unpacking the shape requires a 2-D input; the point count is only
    # needed for the log message below.
    n_input_points, _ = roc_points.shape

    # NOTE: points below the main diagonal are deliberately kept. Discarding
    # them here was observed to cause bugs later on (cause not yet understood),
    # so only points in the interior of the hull are dropped.
    vertex_idx = ConvexHull(roc_points).vertices

    kept_fraction = len(vertex_idx) / n_input_points
    logging.info(
        f"ROC convex hull contains {kept_fraction:.1%} "
        f"of the original points."
    )

    return roc_points[vertex_idx]