Source code for error_parity.binarize

"""Module to binarize continuous-score predictions.

Based on: https://github.com/AndreFCruz/hpt/blob/main/src/hpt/binarize.py
"""
import math
import logging
from typing import Optional

import numpy as np


def compute_binary_predictions(
    y_true: np.ndarray,
    y_pred_scores: np.ndarray,
    threshold: Optional[float] = None,
    tpr: Optional[float] = None,
    fpr: Optional[float] = None,
    ppr: Optional[float] = None,
    random_seed: Optional[int] = 42,
) -> np.ndarray:
    """Discretizes the given score predictions into binary labels.

    If necessary, will randomly untie binary predictions with equal score.
    Exactly one of (threshold, tpr, fpr, ppr) must be provided.

    Parameters
    ----------
    y_true : np.ndarray
        The true binary labels
    y_pred_scores : np.ndarray
        Predictions as a continuous score between 0 and 1
    threshold : Optional[float], optional
        Whether to use a specified (global) threshold, by default None
    tpr : Optional[float], optional
        Whether to target a specified TPR (true positive rate, or recall), by default None
    fpr : Optional[float], optional
        Whether to target a specified FPR (false positive rate), by default None
    ppr : Optional[float], optional
        Whether to target a specified PPR (positive prediction rate), by default None
    random_seed : Optional[int], optional
        Seed for the RNG used to untie samples with equal scores, by default 42

    Returns
    -------
    np.ndarray
        The binarized predictions according to the specified target.

    Raises
    ------
    ValueError
        If not exactly one of (threshold, tpr, fpr, ppr) is provided.
    """
    # Validate with `is not None` so a legitimate 0.0 target is accepted;
    # raise (not assert) so validation survives `python -O`.
    if sum(1 for val in (threshold, fpr, tpr, ppr) if val is not None) != 1:
        raise ValueError(
            f"Please provide exactly one of (threshold, fpr, tpr, ppr); got "
            f"{(threshold, fpr, tpr, ppr)}."
        )

    # If a threshold was provided, just binarize at it; no untying necessary.
    # NOTE: compare against None (not truthiness) so threshold=0.0 is honored.
    if threshold is not None:
        return (y_pred_scores >= threshold).astype(int)

    # Otherwise, compute the allowed value for the target metric's numerator
    # and the corresponding score threshold (may require random untying).
    label_pos = np.count_nonzero(y_true)
    label_neg = np.count_nonzero(1 - y_true)
    assert (total := label_pos + label_neg) == len(y_true)  # sanity check

    # Indices of predictions ordered by score, descending
    y_pred_sorted_indices = np.argsort(-y_pred_scores)

    # Labels ordered by descending prediction score
    y_true_sorted = y_true[y_pred_sorted_indices]

    # Number of positive predictions allowed according to the given metric
    # (the allowed budget for the metric's numerator)
    positive_preds_budget: int

    # Samples that count towards positive_preds_budget
    # (LPs for TPR, LNs for FPR, and all samples for PPR)
    # (related to the metric's denominator)
    target_samples_mask: np.ndarray

    if tpr is not None:
        # TPs budget to ensure >= the target TPR
        positive_preds_budget = math.ceil(tpr * label_pos)
        target_samples_mask = y_true_sorted == 1      # label positive samples
        non_target_samples_mask = y_true_sorted == 0  # label negative samples
    elif fpr is not None:
        # FPs budget to ensure <= the target FPR
        positive_preds_budget = math.floor(fpr * label_neg)
        target_samples_mask = y_true_sorted == 0      # label negative samples
        non_target_samples_mask = y_true_sorted == 1  # label positive samples
    else:  # ppr is not None (guaranteed by the validation above)
        # PPs budget to ensure <= the target PPR
        positive_preds_budget = math.floor(ppr * total)
        target_samples_mask = np.ones_like(y_true_sorted).astype(bool)  # all samples

    # Indices of target samples (relevant for the target metric), by descending score
    target_samples_indices = y_pred_sorted_indices[target_samples_mask]

    if positive_preds_budget <= 0:
        # Edge case: a zero budget (e.g., fpr=0.0) admits no positive predictions
        # among target samples; indexing with (budget - 1) == -1 would wrongly
        # select the LOWEST-scored target sample, so use an unattainable threshold.
        threshold = np.inf
    else:
        # Find the threshold at which the specified budget is met
        threshold_idx = target_samples_indices[positive_preds_budget - 1]
        threshold = y_pred_scores[threshold_idx]

    ####################################
    # Code for random untying follows: #
    ####################################
    y_pred_binary = (y_pred_scores >= threshold).astype(int)

    # 1. compute actual number of positive predictions (on relevant target samples)
    actual_pos_preds = np.sum(y_pred_binary[target_samples_indices])

    # 2. check if this number corresponds to the target
    if actual_pos_preds != positive_preds_budget:
        logging.warning(
            "Target metric for thresholding could not be met, will randomly "
            "untie samples with the same predicted score to fulfill target."
        )
        assert actual_pos_preds > positive_preds_budget, (
            "Sanity check: actual number of positive predictions should always "
            "be higher or equal to the target number when following this "
            f"algorithm; got actual={actual_pos_preds}, target={positive_preds_budget};"
        )

        # 2.1. if target was not met, compute number of extra predicted positives
        extra_pos_preds = actual_pos_preds - positive_preds_budget

        # 2.2. randomly select extra_pos_preds among the relevant samples
        # (either TPs or FPs or PPs) with the same score
        rng = np.random.RandomState(random_seed)
        samples_at_target_threshold_mask = (
            y_pred_scores[y_pred_sorted_indices] == threshold
        )
        target_samples_at_target_threshold_indices = y_pred_sorted_indices[
            # Samples at the target threshold AND relevant to the target metric
            samples_at_target_threshold_mask & target_samples_mask
        ]
        extra_pos_preds_indices = rng.choice(
            target_samples_at_target_threshold_indices,
            size=extra_pos_preds,
            replace=False,
        )

        # 2.3. give extra_pos_preds_indices a negative prediction
        y_pred_binary[extra_pos_preds_indices] = 0

        # 2.4. randomly sample the non-target labels at the same rate, so ties
        # at the threshold score are resolved consistently across both labels
        if tpr is not None or fpr is not None:
            sampled_fraction = 1 - (positive_preds_budget / actual_pos_preds)
            non_target_samples_at_target_threshold_indices = y_pred_sorted_indices[
                # Samples at the target threshold AND of the non-target label
                samples_at_target_threshold_mask & non_target_samples_mask
            ]
            num_samples = int(round(
                non_target_samples_at_target_threshold_indices.shape[0]
                * sampled_fraction
            ))
            if num_samples:
                extra_neg_preds_indices = rng.choice(
                    non_target_samples_at_target_threshold_indices,
                    size=num_samples,
                    replace=False,
                )
                y_pred_binary[extra_neg_preds_indices] = 0

    # Sanity check: the positive_preds_budget should now be exactly fulfilled
    assert np.sum(y_pred_binary[target_samples_indices]) == positive_preds_budget
    return y_pred_binary