Render paper plots and tables

[1]:
import math
import logging
from pathlib import Path

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

Note: change the following path to point to the aggregated results file on your local system (it can be produced with the parse-acs-results.ipynb notebook).

[2]:
ACS_AGG_RESULTS_PATH = Path("../results") / "aggregated-results.2024-08.csv"
ACS_AGG_RESULTS_PATH = Path(ACS_AGG_RESULTS_PATH)
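# Optional sanity check (not in the original run): fail fast if the path is wrong
assert ACS_AGG_RESULTS_PATH.exists(), f"Results file not found: {ACS_AGG_RESULTS_PATH.as_posix()}"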
[3]:
results_df = pd.read_csv(ACS_AGG_RESULTS_PATH, index_col=0)
print(f"{results_df.shape=}")
results_df.head(2)
results_df.shape=(210, 64)
[3]:
accuracy accuracy_diff accuracy_ratio balanced_accuracy balanced_accuracy_diff balanced_accuracy_ratio brier_score_loss ece ece_quantile equalized_odds_diff ... name is_inst num_features uses_all_features fit_thresh_on_100 fit_thresh_accuracy optimal_thresh optimal_thresh_accuracy score_stdev score_mean
penai/gpt-4o-mini__ACSTravelTime__-1__Num 0.570906 0.315286 0.624340 0.553706 0.070915 0.876372 0.267584 0.149517 NaN 0.382569 ... GPT 4o mini (it) True -1 True 0.250000 0.510013 0.450000 0.570968 0.223652 0.384149
penai/gpt-4o-mini__ACSTravelTime__-1__QA 0.551154 0.119294 0.812222 0.588927 0.113947 0.823835 0.404025 0.393327 0.393301 0.274655 ... GPT 4o mini (it) True -1 True 0.998499 0.602202 0.970688 0.593339 0.365205 0.772858

2 rows × 64 columns

Remove gemma-2-* results (need to investigate why results are so poor):

[4]:
results_df = results_df.drop(index=[id_ for id_ in results_df.index if "gemma-2-" in id_.lower()])
print(f"{results_df.shape=}")
results_df.shape=(170, 64)

Run baseline ML classifiers on the benchmark ACS tasks

[5]:
DATA_DIR = Path("/fast/groups/sf") / "data"
[6]:
ALL_TASKS = [
    "ACSIncome",
    "ACSMobility",
    "ACSEmployment",
    "ACSTravelTime",
    "ACSPublicCoverage",
]

model_col = "config_model_name"
task_col = "config_task_name"
numeric_prompt_col = "config_numeric_risk_prompting"

List all baseline classifiers here:

[7]:
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import HistGradientBoostingClassifier
from xgboost import XGBClassifier    # NOTE: requires `pip install xgboost`

baselines = {
    "LR": LogisticRegression(),
    "GBM": HistGradientBoostingClassifier(),
    "XGBoost": XGBClassifier(),
}
[8]:
from folktexts.acs.acs_dataset import ACSDataset
from folktexts.evaluation import evaluate_predictions
from collections import defaultdict

def fit_and_eval(
    clf,
    X_train, y_train,
    X_test, y_test, s_test,
    fillna=False,
    save_scores_df_path: str = None,
) -> dict:
    """Fit and evaluate a given classifier on the given data."""
    assert len(X_train) == len(y_train) and len(X_test) == len(y_test) == len(s_test)

    train_nan_count = X_train.isna().any(axis=1).sum()
    if fillna and train_nan_count > 0:
        # Fill NaNs with value=-1
        X_train = X_train.fillna(axis="columns", value=-1)
        X_test = X_test.fillna(axis="columns", value=-1)

    # Fit on train data
    clf.fit(X_train, y_train)

    # Evaluate on test data
    y_test_scores = clf.predict_proba(X_test)[:, -1]
    test_results = evaluate_predictions(
        y_true=y_test.to_numpy(),
        y_pred_scores=y_test_scores,
        sensitive_attribute=s_test,
        threshold=0.5,
    )

    # Optionally, save test scores DF
    if save_scores_df_path:
        scores_df = pd.DataFrame({
            "risk_score": y_test_scores,
            "label": y_test,
        })
        scores_df.to_csv(save_scores_df_path)
        test_results["predictions_path"] = save_scores_df_path

    return test_results

def run_baselines(baselines, tasks) -> dict:
    """Run baseline classifiers on all acs tasks."""
    baseline_results = defaultdict(dict)

    # Prepare progress bar
    progress_bar = tqdm(
        total=len(tasks) * len(baselines),
        leave=True,
    )

    for task in tasks:
        progress_bar.set_postfix({"task": task})

        # Load ACS task data
        acs_dataset = ACSDataset.make_from_task(task=task, cache_dir=DATA_DIR)

        # Get train/test data
        X_train, y_train = acs_dataset.get_train()
        X_test, y_test = acs_dataset.get_test()

        # Get sensitive attribute test data
        s_test = None
        if acs_dataset.task.sensitive_attribute is not None:
            s_test = acs_dataset.get_sensitive_attribute_data().loc[y_test.index]

        for clf_name, clf in baselines.items():
            progress_bar.set_postfix({"task": task, "clf": clf_name})

            try:
                baseline_results[task][clf_name] = fit_and_eval(
                    clf=clf,
                    X_train=X_train, y_train=y_train,
                    X_test=X_test, y_test=y_test, s_test=s_test,
                    fillna=(clf_name == "LR"),
                    save_scores_df_path=ACS_AGG_RESULTS_PATH.parent / f"baseline_scores.{clf_name}.{task}.csv"
                )
            except Exception as err:
                logging.error(err)
            finally:
                progress_bar.update()

    return baseline_results
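
As a quick reference for how `fit_and_eval` behaves in isolation, here is a hypothetical smoke test on synthetic data (toy names such as `X_toy` and `toy_results` are illustrative only; the real runs below use the ACS tasks):

# Hypothetical smoke test on random data (illustrative, not part of the paper results)
rng = np.random.default_rng(0)
X_toy = pd.DataFrame(rng.normal(size=(200, 3)), columns=["a", "b", "c"])
y_toy = pd.Series((X_toy["a"] + rng.normal(scale=0.5, size=200) > 0).astype(int))
s_toy = pd.Series(rng.integers(0, 2, size=200))

toy_results = fit_and_eval(
    clf=LogisticRegression(),
    X_train=X_toy.iloc[:150], y_train=y_toy.iloc[:150],
    X_test=X_toy.iloc[150:], y_test=y_toy.iloc[150:], s_test=s_toy.iloc[150:],
)
print({k: v for k, v in toy_results.items() if k in ("accuracy", "roc_auc", "ece")})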

Flatten results and add extra columns.

[9]:
def parse_baseline_results(baseline_results) -> list:
    """Flatten and parse baseline results."""
    parsed_results_list = list()

    for task, task_results in baseline_results.items():

        for clf, clf_results in task_results.items():
            parsed_results = clf_results.copy()

            parsed_results["config_task_name"] = task
            parsed_results["config_model_name"] = clf
            parsed_results["name"] = clf
            parsed_results["num_features"] = -1
            parsed_results["uses_all_features"] = True

            parsed_results_list.append(parsed_results)

    return parsed_results_list

Check if baseline results were already computed. If so, load csv; otherwise, compute and save.

[10]:
BASELINE_RESULTS_PATH = ACS_AGG_RESULTS_PATH.parent / ("baseline-results." + ".".join(sorted(baselines.keys())) + ".csv")

# If saved results exist: load them
if BASELINE_RESULTS_PATH.exists():
    print(f"Loading pre-computed baseline results from {BASELINE_RESULTS_PATH.as_posix()}")
    baselines_df = pd.read_csv(BASELINE_RESULTS_PATH, index_col=0)

# Compute baseline results
else:
    print(f"Computing baseline results and saving to {BASELINE_RESULTS_PATH.as_posix()}")
    # Compute baseline results
    baseline_results = run_baselines(baselines, tasks=ALL_TASKS)

    # Parse results
    parsed_results_list = parse_baseline_results(baseline_results)

    # Construct DF
    baselines_df = pd.DataFrame(parsed_results_list, index=[r["name"] for r in parsed_results_list])

    # Save DF to disk
    baselines_df.to_csv(BASELINE_RESULTS_PATH)

# Make the index unique by combining model name and task
baselines_df["new_index"] = [(r["name"] + "_" + r[task_col]) for _, r in baselines_df.iterrows()]
baselines_df = baselines_df.set_index("new_index", drop=True)

# Show 2 random rows
baselines_df.sample(2)
Loading pre-computed baseline results from ../results/baseline-results.GBM.LR.XGBoost.csv
[10]:
threshold n_samples n_positives n_negatives model_name accuracy tpr fnr fpr tnr ... equalized_odds_diff roc_auc ece ece_quantile predictions_path config_task_name config_model_name name num_features uses_all_features
new_index
XGBoost_ACSPublicCoverage 0.5 113829 33971 79858 NaN 0.801650 0.515175 0.484825 0.076486 0.923514 ... 0.368044 0.839742 0.004371 0.004271 /fast/groups/sf/folktexts-results/2024-07-03/b... ACSPublicCoverage XGBoost XGBoost -1 True
GBM_ACSIncome 0.5 166450 61233 105217 NaN 0.813584 0.727973 0.272027 0.136594 0.863406 ... 0.630389 0.890792 0.007721 0.007146 /fast/groups/sf/folktexts-results/2024-07-03/b... ACSIncome GBM GBM -1 True

2 rows × 42 columns

[11]:
all_results_df = pd.concat((results_df, baselines_df))
print(f"{all_results_df.shape=}")
all_results_df.shape=(185, 64)

Prepare results table for each task

[12]:
table_metrics = ["ece", "brier_score_loss", "roc_auc", "accuracy", "fit_thresh_accuracy", "score_stdev"] #, "score_mean"]

Add model size and model family columns:

[13]:
from folktexts.llm_utils import get_model_size_B

all_results_df["model_size"] = [
    (
        get_model_size_B(row["name"], default=300)  # Set GPT 4o size as the largest (300B)
        if row["name"] not in baselines else "-"
    )
    for _, row in all_results_df.iterrows()
]

def get_model_family(model_name) -> str:
    if "gpt" in model_name.lower():
        return "OpenAI"
    elif "llama" in model_name.lower():
        return "Llama"
    elif "mistral" in model_name.lower() or "mixtral" in model_name.lower():
        return "Mistral"
    elif "gemma" in model_name.lower():
        return "Gemma"
    elif "yi" in model_name.lower():
        return "Yi"
    elif "qwen" in model_name.lower():
        return "Qwen"
    else:
        return "-"

all_results_df["model_family"] = [get_model_family(row[model_col]) for _, row in all_results_df.iterrows()]
all_results_df.groupby([task_col, "model_family"])["accuracy"].count()
[13]:
config_task_name   model_family
ACSEmployment      -                3
                   Gemma            8
                   Llama            8
                   Mistral         12
                   OpenAI           2
                   Yi               4
ACSIncome          -                3
                   Gemma            8
                   Llama            8
                   Mistral         12
                   OpenAI           2
                   Yi               4
ACSMobility        -                3
                   Gemma            8
                   Llama            8
                   Mistral         12
                   OpenAI           2
                   Yi               4
ACSPublicCoverage  -                3
                   Gemma            8
                   Llama            8
                   Mistral         12
                   OpenAI           2
                   Yi               4
ACSTravelTime      -                3
                   Gemma            8
                   Llama            8
                   Mistral         12
                   OpenAI           2
                   Yi               4
Name: accuracy, dtype: int64
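A couple of quick spot checks of the family mapping (illustrative additions, not in the original run):

assert get_model_family("Mixtral-8x22B-Instruct-v0.1") == "Mistral"
assert get_model_family("gemma-1.1-2b-it") == "Gemma"
assert get_model_family("XGBoost") == "-"   # baselines fall through to "-"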
[14]:
def model_sort_key(name, task_df):
    """Sort key for paper table rows."""

    if "gpt" in name.lower():
        key = 11
    elif "llama" in name.lower():
        key = 10
    elif "mixtral" in name.lower():
        key = 9
    elif "mistral" in name.lower():
        key = 8
    elif "yi" in name.lower():
        key = 7
    elif "gemma" in name.lower():
        key = 6
    else:
        return 0

    row = task_df.loc[name]
    # return (1e5 * key) + (row["model_size"] * 100) + int(row["is_inst"] * 10) + (1 if row[numeric_prompt_col] else 0)
    return (row["model_size"] * 1e9) + int(row["is_inst"] * 1e6) + 1e3 * (1 if row[numeric_prompt_col] else 0) + key

[15]:
def latex_colored_float_format(val, all_values, higher_is_better=True):
    """Map a cell's value to its colored latex code.

    Current definition:
    - use cyan color gradient for good values;
    - use orange color gradient for bad values;
    - use no color for anything in between;
    """
    min_val, max_val = np.min(all_values), np.max(all_values)

    low_pct_val, high_pct_val = [
        min_val + (max_val - min_val) * interp_point
        for interp_point in [0.1, 0.9]
    ]

    # Use rounded value or original value for coloring?
    # > Currently using the original value; uncomment below to round first
    # val = np.round(val, decimals=2)

    # Use no color for the middle 80% of the value range
    if low_pct_val <= val <= high_pct_val:
        return f"{val:.2f}"

    if val < low_pct_val:
        color = "orange" if higher_is_better else "cyan"
        color_value = 100 * (
            (low_pct_val - val) / (low_pct_val - min_val))

    elif val > high_pct_val:
        color = "cyan" if higher_is_better else "orange"
        color_value = 100 * (
            (val - high_pct_val) / (max_val - high_pct_val))

    else:
        raise RuntimeError(f"{val}")

    # Note: dividing `color_value` by 4 to have softer colors
    color_value /= 4

    return (
        r"\cellcolor{"
        + f"{color}!{color_value:.1f}"
        + r"} "
        + f"{val:.2f}"
    )

higher_is_better_cols = {"roc_auc", "accuracy", "fit_thresh_accuracy"}
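
As a toy illustration of the color mapping (made-up values, not taken from the results):

toy_vals = np.array([0.10, 0.50, 0.55, 0.60, 0.90])
for v in toy_vals:
    print(latex_colored_float_format(v, toy_vals, higher_is_better=True))
# The lowest value prints as "\cellcolor{orange!25.0} 0.10", the highest as
# "\cellcolor{cyan!25.0} 0.90", and the mid-range values print without color.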

Output colored LaTeX results tables.

Set whether to use numeric risk prompting or multiple-choice Q&A prompting:

[16]:
# NUMERIC_PROMPTING = False
NUMERIC_PROMPTING = True

Whether to combine results from both prompting schemes in the latex tables:

[17]:
COMBINE_PROMPTING_RESULTS = True
# COMBINE_PROMPTING_RESULTS = False
[18]:
from utils import prettify_model_name

for task in ALL_TASKS:
    task_df = all_results_df[all_results_df[task_col] == task]

    # Sort table rows
    sorted_df_index = sorted(
        task_df.index.tolist(),
        key=lambda id_: model_sort_key(id_, task_df),
        reverse=True,
    )

    latex_tab_cols = [numeric_prompt_col] + table_metrics

    # latex_table = task_df.sort_values(["model_family", "model_size", "is_inst"], ascending=False).set_index(model_col)[table_metrics].round(3)
    latex_table = task_df.loc[sorted_df_index].set_index(model_col)[latex_tab_cols].round(3)

    # Keep rows for a single selected prompting scheme
    latex_table = latex_table[
        (latex_table[numeric_prompt_col] == NUMERIC_PROMPTING)
        | latex_table[numeric_prompt_col].isna()
    ]

    # But optionally join with the columns of the other prompting scheme
    if COMBINE_PROMPTING_RESULTS:
        cols_to_keep = ["ece", "brier_score_loss", "roc_auc", "accuracy"]
        latex_table = latex_table[cols_to_keep]

        latex_table_other_prompt = task_df.loc[sorted_df_index].set_index(model_col)
        latex_table_other_prompt = latex_table_other_prompt[
            (latex_table_other_prompt[numeric_prompt_col] != NUMERIC_PROMPTING)
            | latex_table_other_prompt[numeric_prompt_col].isna()
        ][cols_to_keep].round(3)

        latex_table = latex_table.join(
            latex_table_other_prompt,
            how="left",
            lsuffix=" (num)" if NUMERIC_PROMPTING else " (mult. choice)",
            validate="1:1",
        )

    # Fill NaNs
    latex_table = latex_table.fillna("-")

    # Prettify model names
    latex_table["Model"] = [
        prettify_model_name(id_) if id_ not in baselines.keys() else id_
        for id_, row in latex_table.iterrows()
    ]
    latex_table.set_index("Model", drop=True, inplace=True)

    for col in set(latex_table.columns.tolist()) - {"score_stdev", numeric_prompt_col, "model_family"}:
        index_without_baselines = [name for name in latex_table.index if name not in baselines]
        index_baselines = list(baselines.keys())

        col_data = latex_table.loc[index_without_baselines, col].copy()

        new_col_data = [
            latex_colored_float_format(
                val=col_data.loc[id_], all_values=col_data,
                higher_is_better=any(col_name in col for col_name in higher_is_better_cols),
            )
            for id_ in index_without_baselines
        ] + [
            latex_table.loc[id_, col]
            for id_ in index_baselines
        ]

        # Set compatible dtype
        latex_table = latex_table.astype({col: str})

        # Set new data
        latex_table.loc[index_without_baselines + index_baselines, col] = new_col_data

    # Drop the prompting column if not combining both schemes
    latex_table = latex_table.drop(columns=numeric_prompt_col, errors="ignore")

    # Rename columns to be latex compatible
    latex_table = latex_table.rename(columns={col: col.replace("_", " ") for col in latex_table.columns})

    print(f"*** {task.upper()} ***\n")

    latex_table_str = latex_table.to_latex(float_format="%.2f")
    print(latex_table_str)

    tables_dir = ACS_AGG_RESULTS_PATH.parent / "tables"
    tables_dir.mkdir(exist_ok=True)
    with open(tables_dir / f"{task}-table.{COMBINE_PROMPTING_RESULTS=}.{NUMERIC_PROMPTING=}.tex", "w") as f_out:
        print(latex_table_str, file=f_out)

    print("")
*** ACSINCOME ***

\begin{tabular}{lllllllll}
\toprule
 & ece (num) & brier score loss (num) & roc auc (num) & accuracy (num) & ece & brier score loss & roc auc & accuracy \\
Model &  &  &  &  &  &  &  &  \\
\midrule
GPT 4o mini (it) & \cellcolor{cyan!25.0} 0.05 & \cellcolor{cyan!25.0} 0.16 & \cellcolor{cyan!20.6} 0.83 & \cellcolor{cyan!24.4} 0.78 & 0.24 & 0.24 & \cellcolor{cyan!17.6} 0.85 & 0.74 \\
Mixtral 8x22B (it) & 0.11 & \cellcolor{cyan!15.5} 0.17 & \cellcolor{cyan!25.0} 0.84 & \cellcolor{cyan!18.9} 0.77 & 0.21 & \cellcolor{cyan!3.6} 0.22 & \cellcolor{cyan!11.2} 0.85 & \cellcolor{cyan!11.1} 0.76 \\
Mixtral 8x22B & 0.13 & \cellcolor{cyan!3.6} 0.18 & \cellcolor{cyan!9.6} 0.82 & \cellcolor{cyan!2.4} 0.74 & \cellcolor{cyan!13.2} 0.17 & \cellcolor{cyan!21.6} 0.19 & \cellcolor{cyan!16.5} 0.85 & 0.68 \\
Llama 3 70B (it) & 0.25 & 0.23 & \cellcolor{cyan!22.1} 0.84 & 0.67 & 0.27 & 0.27 & \cellcolor{cyan!25.0} 0.86 & 0.69 \\
Llama 3 70B & 0.27 & 0.24 & \cellcolor{cyan!6.7} 0.82 & 0.54 & 0.20 & \cellcolor{cyan!14.9} 0.20 & \cellcolor{cyan!20.8} 0.86 & 0.70 \\
Mixtral 8x7B (it) & 0.10 & \cellcolor{cyan!15.5} 0.17 & \cellcolor{cyan!22.8} 0.84 & \cellcolor{cyan!12.8} 0.76 & \cellcolor{cyan!16.8} 0.16 & \cellcolor{cyan!25.0} 0.18 & \cellcolor{cyan!19.7} 0.86 & \cellcolor{cyan!25.0} 0.78 \\
Mixtral 8x7B & \cellcolor{cyan!4.0} 0.07 & \cellcolor{cyan!13.1} 0.17 & \cellcolor{cyan!3.7} 0.81 & \cellcolor{cyan!25.0} 0.78 & \cellcolor{cyan!10.6} 0.17 & \cellcolor{cyan!11.5} 0.21 & 0.83 & 0.65 \\
Yi 34B (it) & 0.22 & 0.21 & 0.80 & 0.48 & \cellcolor{cyan!1.3} 0.19 & \cellcolor{cyan!18.8} 0.19 & \cellcolor{cyan!19.7} 0.86 & 0.72 \\
Yi 34B & 0.15 & 0.19 & \cellcolor{cyan!17.7} 0.83 & 0.61 & 0.25 & \cellcolor{cyan!2.5} 0.22 & \cellcolor{cyan!13.3} 0.85 & 0.62 \\
Llama 3 8B (it) & 0.23 & 0.23 & \cellcolor{cyan!0.1} 0.81 & 0.67 & 0.32 & 0.30 & \cellcolor{cyan!13.3} 0.85 & 0.62 \\
Llama 3 8B & 0.14 & 0.24 & 0.63 & \cellcolor{orange!3.7} 0.40 & 0.25 & 0.26 & 0.81 & \cellcolor{orange!20.8} 0.38 \\
Mistral 7B (it) & 0.16 & 0.19 & \cellcolor{cyan!14.7} 0.83 & 0.70 & 0.21 & \cellcolor{cyan!4.7} 0.22 & 0.83 & \cellcolor{cyan!16.5} 0.77 \\
Gemma 7B (it) & 0.33 & 0.30 & 0.78 & 0.42 & \cellcolor{orange!14.7} 0.61 & \cellcolor{orange!4.7} 0.59 & \cellcolor{cyan!3.8} 0.84 & \cellcolor{orange!25.0} 0.37 \\
Mistral 7B & \cellcolor{orange!15.7} 0.36 & 0.32 & 0.75 & 0.49 & 0.20 & 0.23 & 0.80 & 0.73 \\
Gemma 7B & 0.15 & 0.20 & 0.80 & 0.73 & 0.24 & 0.27 & 0.76 & \cellcolor{orange!25.0} 0.37 \\
Gemma 2B (it) & 0.28 & 0.31 & \cellcolor{orange!25.0} 0.50 & \cellcolor{orange!25.0} 0.37 & \cellcolor{orange!25.0} 0.63 & \cellcolor{orange!25.0} 0.63 & 0.73 & \cellcolor{orange!25.0} 0.37 \\
Gemma 2B & \cellcolor{orange!25.0} 0.37 & \cellcolor{orange!25.0} 0.37 & \cellcolor{orange!25.0} 0.50 & 0.63 & \cellcolor{cyan!25.0} 0.14 & 0.25 & \cellcolor{orange!25.0} 0.62 & 0.45 \\
LR & 0.03 & 0.18 & 0.79 & 0.74 & 0.03 & 0.18 & 0.79 & 0.74 \\
GBM & 0.01 & 0.13 & 0.89 & 0.81 & 0.01 & 0.13 & 0.89 & 0.81 \\
XGBoost & 0.00 & 0.13 & 0.90 & 0.82 & 0.00 & 0.13 & 0.90 & 0.82 \\
\bottomrule
\end{tabular}


*** ACSMOBILITY ***

\begin{tabular}{lllllllll}
\toprule
 & ece (num) & brier score loss (num) & roc auc (num) & accuracy (num) & ece & brier score loss & roc auc & accuracy \\
Model &  &  &  &  &  &  &  &  \\
\midrule
GPT 4o mini (it) & 0.22 & 0.25 & 0.49 & \cellcolor{cyan!25.0} 0.73 & 0.26 & 0.26 & \cellcolor{cyan!0.6} 0.57 & \cellcolor{cyan!25.0} 0.73 \\
Mixtral 8x22B (it) & 0.05 & \cellcolor{cyan!17.9} 0.20 & \cellcolor{cyan!25.0} 0.54 & \cellcolor{cyan!25.0} 0.73 & 0.40 & 0.40 & \cellcolor{orange!12.8} 0.51 & 0.39 \\
Mixtral 8x22B & 0.13 & 0.22 & 0.49 & \cellcolor{cyan!25.0} 0.73 & \cellcolor{cyan!10.2} 0.11 & \cellcolor{cyan!21.2} 0.21 & 0.55 & \cellcolor{cyan!25.0} 0.73 \\
Llama 3 70B (it) & 0.05 & \cellcolor{cyan!10.7} 0.20 & 0.52 & \cellcolor{cyan!25.0} 0.73 & 0.20 & \cellcolor{cyan!1.5} 0.25 & 0.57 & 0.58 \\
Llama 3 70B & 0.06 & \cellcolor{cyan!7.1} 0.20 & 0.53 & \cellcolor{cyan!20.7} 0.73 & 0.22 & \cellcolor{cyan!4.8} 0.24 & 0.55 & 0.53 \\
Mixtral 8x7B (it) & 0.11 & 0.21 & 0.51 & \cellcolor{cyan!25.0} 0.73 & 0.26 & 0.26 & \cellcolor{cyan!12.8} 0.58 & \cellcolor{cyan!25.0} 0.73 \\
Mixtral 8x7B & 0.24 & 0.25 & \cellcolor{orange!25.0} 0.48 & \cellcolor{cyan!25.0} 0.73 & \cellcolor{cyan!0.8} 0.14 & \cellcolor{cyan!18.9} 0.21 & 0.57 & \cellcolor{cyan!25.0} 0.73 \\
Yi 34B (it) & 0.23 & 0.25 & 0.50 & \cellcolor{orange!25.0} 0.27 & \cellcolor{cyan!18.2} 0.09 & \cellcolor{cyan!23.1} 0.20 & 0.56 & \cellcolor{cyan!15.4} 0.72 \\
Yi 34B & 0.15 & 0.23 & 0.52 & 0.44 & \cellcolor{cyan!25.0} 0.07 & \cellcolor{cyan!25.0} 0.20 & 0.57 & \cellcolor{cyan!20.7} 0.73 \\
Llama 3 8B (it) & 0.11 & 0.21 & 0.49 & \cellcolor{cyan!23.4} 0.73 & 0.15 & \cellcolor{cyan!13.7} 0.22 & 0.56 & \cellcolor{cyan!5.3} 0.70 \\
Llama 3 8B & 0.14 & 0.21 & 0.51 & \cellcolor{cyan!19.1} 0.72 & \cellcolor{cyan!12.5} 0.10 & \cellcolor{cyan!21.7} 0.20 & 0.55 & \cellcolor{cyan!24.5} 0.73 \\
Mistral 7B (it) & 0.17 & 0.23 & 0.49 & \cellcolor{cyan!24.5} 0.73 & 0.26 & 0.26 & 0.57 & \cellcolor{cyan!25.0} 0.73 \\
Gemma 7B (it) & \cellcolor{orange!8.9} 0.25 & 0.26 & \cellcolor{orange!3.1} 0.49 & \cellcolor{cyan!22.3} 0.73 & 0.25 & 0.26 & \cellcolor{cyan!25.0} 0.58 & \cellcolor{cyan!21.8} 0.73 \\
Mistral 7B & \cellcolor{orange!25.0} 0.27 & \cellcolor{orange!25.0} 0.27 & 0.50 & \cellcolor{cyan!25.0} 0.73 & 0.20 & \cellcolor{cyan!8.6} 0.23 & 0.53 & \cellcolor{cyan!23.4} 0.73 \\
Gemma 7B & 0.19 & 0.24 & 0.49 & \cellcolor{cyan!25.0} 0.73 & 0.41 & 0.37 & \cellcolor{orange!25.0} 0.50 & \cellcolor{orange!25.0} 0.27 \\
Gemma 2B (it) & \cellcolor{cyan!25.0} 0.02 & \cellcolor{cyan!25.0} 0.20 & 0.50 & \cellcolor{cyan!25.0} 0.73 & \cellcolor{orange!25.0} 0.73 & \cellcolor{orange!25.0} 0.73 & 0.52 & \cellcolor{orange!25.0} 0.27 \\
Gemma 2B & \cellcolor{orange!25.0} 0.27 & \cellcolor{orange!25.0} 0.27 & 0.50 & \cellcolor{cyan!25.0} 0.73 & 0.25 & 0.26 & 0.51 & 0.34 \\
LR & 0.02 & 0.19 & 0.61 & 0.74 & 0.02 & 0.19 & 0.61 & 0.74 \\
GBM & 0.01 & 0.17 & 0.74 & 0.76 & 0.01 & 0.17 & 0.74 & 0.76 \\
XGBoost & 0.00 & 0.16 & 0.74 & 0.76 & 0.00 & 0.16 & 0.74 & 0.76 \\
\bottomrule
\end{tabular}


*** ACSEMPLOYMENT ***

\begin{tabular}{lllllllll}
\toprule
 & ece (num) & brier score loss (num) & roc auc (num) & accuracy (num) & ece & brier score loss & roc auc & accuracy \\
Model &  &  &  &  &  &  &  &  \\
\midrule
GPT 4o mini (it) & 0.23 & 0.23 & 0.80 & 0.73 & 0.28 & 0.29 & 0.79 & 0.65 \\
Mixtral 8x22B (it) & 0.06 & \cellcolor{cyan!21.9} 0.14 & \cellcolor{cyan!24.4} 0.87 & \cellcolor{cyan!16.5} 0.79 & \cellcolor{orange!23.0} 0.38 & \cellcolor{orange!8.0} 0.39 & 0.60 & 0.51 \\
Mixtral 8x22B & 0.15 & 0.18 & 0.82 & \cellcolor{cyan!20.7} 0.80 & 0.21 & 0.24 & \cellcolor{cyan!25.0} 0.86 & 0.52 \\
Llama 3 70B (it) & \cellcolor{cyan!0.7} 0.05 & \cellcolor{cyan!25.0} 0.14 & \cellcolor{cyan!25.0} 0.88 & \cellcolor{cyan!25.0} 0.81 & 0.17 & \cellcolor{cyan!17.1} 0.19 & \cellcolor{cyan!23.3} 0.85 & \cellcolor{cyan!15.4} 0.73 \\
Llama 3 70B & 0.05 & \cellcolor{cyan!16.7} 0.15 & \cellcolor{cyan!13.6} 0.86 & \cellcolor{cyan!5.8} 0.78 & 0.25 & 0.26 & \cellcolor{cyan!2.3} 0.82 & 0.52 \\
Mixtral 8x7B (it) & 0.07 & \cellcolor{cyan!16.7} 0.15 & \cellcolor{cyan!23.2} 0.87 & \cellcolor{cyan!7.2} 0.78 & 0.22 & 0.24 & \cellcolor{cyan!6.8} 0.82 & \cellcolor{cyan!13.7} 0.73 \\
Mixtral 8x7B & 0.08 & 0.17 & 0.81 & 0.73 & 0.30 & 0.31 & \cellcolor{cyan!1.1} 0.81 & \cellcolor{orange!24.1} 0.45 \\
Yi 34B (it) & 0.15 & 0.21 & 0.81 & 0.51 & 0.14 & 0.21 & 0.79 & 0.69 \\
Yi 34B & 0.13 & 0.23 & 0.66 & 0.50 & 0.08 & 0.23 & 0.70 & 0.62 \\
Llama 3 8B (it) & 0.08 & 0.17 & 0.82 & 0.77 & 0.07 & \cellcolor{cyan!25.0} 0.19 & 0.79 & \cellcolor{cyan!25.0} 0.74 \\
Llama 3 8B & 0.15 & 0.23 & 0.75 & \cellcolor{orange!25.0} 0.46 & 0.34 & 0.34 & 0.76 & \cellcolor{orange!25.0} 0.45 \\
Mistral 7B (it) & \cellcolor{cyan!7.4} 0.04 & 0.19 & 0.79 & 0.69 & \cellcolor{orange!6.6} 0.35 & 0.36 & 0.72 & 0.63 \\
Gemma 7B (it) & \cellcolor{cyan!8.8} 0.04 & 0.22 & 0.71 & 0.60 & \cellcolor{orange!10.0} 0.36 & 0.38 & 0.59 & 0.58 \\
Mistral 7B & 0.14 & 0.19 & 0.80 & \cellcolor{cyan!12.9} 0.79 & 0.26 & 0.30 & 0.76 & \cellcolor{orange!25.0} 0.45 \\
Gemma 7B & \cellcolor{orange!25.0} 0.35 & \cellcolor{orange!25.0} 0.38 & \cellcolor{orange!1.6} 0.50 & 0.51 & 0.15 & 0.25 & 0.65 & \cellcolor{orange!4.1} 0.48 \\
Gemma 2B (it) & 0.12 & 0.27 & \cellcolor{orange!25.0} 0.46 & \cellcolor{orange!25.0} 0.46 & \cellcolor{orange!25.0} 0.38 & \cellcolor{orange!25.0} 0.41 & \cellcolor{orange!25.0} 0.42 & \cellcolor{orange!16.3} 0.46 \\
Gemma 2B & \cellcolor{cyan!25.0} 0.01 & 0.23 & 0.57 & 0.53 & \cellcolor{cyan!25.0} 0.01 & 0.24 & 0.63 & 0.54 \\
LR & 0.02 & 0.15 & 0.86 & 0.78 & 0.02 & 0.15 & 0.86 & 0.78 \\
GBM & 0.00 & 0.12 & 0.91 & 0.83 & 0.00 & 0.12 & 0.91 & 0.83 \\
XGBoost & 0.00 & 0.12 & 0.91 & 0.83 & 0.00 & 0.12 & 0.91 & 0.83 \\
\bottomrule
\end{tabular}


*** ACSTRAVELTIME ***

\begin{tabular}{lllllllll}
\toprule
 & ece (num) & brier score loss (num) & roc auc (num) & accuracy (num) & ece & brier score loss & roc auc & accuracy \\
Model &  &  &  &  &  &  &  &  \\
\midrule
GPT 4o mini (it) & 0.15 & 0.27 & 0.58 & 0.57 & 0.39 & 0.40 & 0.65 & 0.55 \\
Mixtral 8x22B (it) & 0.12 & \cellcolor{cyan!23.7} 0.24 & \cellcolor{cyan!23.5} 0.64 & \cellcolor{cyan!25.0} 0.59 & 0.31 & 0.33 & 0.66 & \cellcolor{cyan!17.1} 0.59 \\
Mixtral 8x22B & 0.30 & 0.34 & 0.57 & \cellcolor{cyan!2.4} 0.58 & 0.20 & 0.28 & 0.63 & \cellcolor{orange!25.0} 0.44 \\
Llama 3 70B (it) & 0.12 & \cellcolor{cyan!22.5} 0.24 & \cellcolor{cyan!25.0} 0.64 & 0.53 & 0.15 & \cellcolor{cyan!22.6} 0.24 & \cellcolor{cyan!25.0} 0.70 & \cellcolor{cyan!25.0} 0.60 \\
Llama 3 70B & 0.08 & \cellcolor{cyan!12.4} 0.25 & 0.52 & 0.46 & 0.09 & \cellcolor{cyan!21.4} 0.24 & 0.67 & 0.55 \\
Mixtral 8x7B (it) & 0.09 & \cellcolor{cyan!25.0} 0.24 & 0.61 & 0.57 & \cellcolor{orange!25.0} 0.45 & \cellcolor{orange!25.0} 0.45 & 0.66 & 0.52 \\
Mixtral 8x7B & 0.07 & \cellcolor{cyan!16.2} 0.25 & 0.57 & \cellcolor{cyan!2.4} 0.58 & 0.28 & 0.32 & 0.60 & \cellcolor{orange!25.0} 0.44 \\
Yi 34B (it) & \cellcolor{cyan!3.7} 0.06 & \cellcolor{cyan!11.2} 0.25 & 0.50 & \cellcolor{orange!25.0} 0.44 & 0.35 & 0.36 & 0.65 & 0.56 \\
Yi 34B & 0.14 & 0.27 & 0.53 & \cellcolor{orange!25.0} 0.44 & \cellcolor{cyan!4.5} 0.08 & \cellcolor{cyan!25.0} 0.24 & 0.62 & 0.56 \\
Llama 3 8B (it) & 0.11 & \cellcolor{cyan!12.4} 0.25 & 0.56 & 0.56 & 0.19 & 0.28 & 0.60 & 0.57 \\
Llama 3 8B & 0.12 & 0.26 & \cellcolor{orange!25.0} 0.48 & \cellcolor{orange!25.0} 0.44 & 0.08 & \cellcolor{cyan!7.1} 0.25 & 0.53 & 0.56 \\
Mistral 7B (it) & 0.11 & \cellcolor{cyan!11.2} 0.25 & 0.55 & 0.56 & \cellcolor{orange!2.7} 0.41 & 0.42 & 0.59 & 0.57 \\
Gemma 7B (it) & 0.10 & \cellcolor{cyan!4.9} 0.26 & \cellcolor{orange!4.9} 0.49 & \cellcolor{orange!25.0} 0.44 & \cellcolor{orange!12.0} 0.42 & 0.43 & 0.53 & 0.56 \\
Mistral 7B & \cellcolor{orange!25.0} 0.44 & \cellcolor{orange!25.0} 0.44 & 0.50 & 0.56 & \cellcolor{cyan!22.5} 0.05 & \cellcolor{cyan!14.2} 0.25 & 0.57 & 0.56 \\
Gemma 7B & \cellcolor{cyan!25.0} 0.03 & \cellcolor{cyan!13.7} 0.25 & 0.52 & 0.55 & \cellcolor{cyan!25.0} 0.04 & \cellcolor{cyan!19.0} 0.24 & 0.61 & 0.58 \\
Gemma 2B (it) & 0.19 & 0.28 & 0.50 & 0.56 & 0.34 & 0.36 & \cellcolor{orange!9.3} 0.49 & 0.56 \\
Gemma 2B & \cellcolor{orange!25.0} 0.44 & \cellcolor{orange!25.0} 0.44 & 0.50 & 0.56 & 0.09 & \cellcolor{cyan!4.7} 0.26 & \cellcolor{orange!25.0} 0.48 & \cellcolor{orange!25.0} 0.44 \\
LR & 0.04 & 0.24 & 0.58 & 0.56 & 0.04 & 0.24 & 0.58 & 0.56 \\
GBM & 0.02 & 0.20 & 0.75 & 0.69 & 0.02 & 0.20 & 0.75 & 0.69 \\
XGBoost & 0.02 & 0.19 & 0.77 & 0.70 & 0.02 & 0.19 & 0.77 & 0.70 \\
\bottomrule
\end{tabular}


*** ACSPUBLICCOVERAGE ***

\begin{tabular}{lllllllll}
\toprule
 & ece (num) & brier score loss (num) & roc auc (num) & accuracy (num) & ece & brier score loss & roc auc & accuracy \\
Model &  &  &  &  &  &  &  &  \\
\midrule
GPT 4o mini (it) & 0.10 & 0.20 & 0.68 & \cellcolor{cyan!17.6} 0.73 & 0.33 & 0.34 & \cellcolor{cyan!25.0} 0.71 & 0.60 \\
Mixtral 8x22B (it) & \cellcolor{cyan!25.0} 0.04 & \cellcolor{cyan!25.0} 0.18 & \cellcolor{cyan!11.0} 0.71 & \cellcolor{cyan!25.0} 0.75 & 0.24 & 0.25 & \cellcolor{cyan!11.7} 0.70 & \cellcolor{cyan!12.7} 0.72 \\
Mixtral 8x22B & \cellcolor{orange!15.5} 0.29 & \cellcolor{orange!3.4} 0.29 & 0.54 & \cellcolor{cyan!0.5} 0.70 & 0.32 & 0.30 & 0.59 & \cellcolor{orange!25.0} 0.30 \\
Llama 3 70B (it) & 0.13 & 0.20 & \cellcolor{cyan!25.0} 0.73 & \cellcolor{cyan!24.4} 0.75 & 0.16 & \cellcolor{cyan!16.2} 0.21 & \cellcolor{cyan!7.0} 0.69 & \cellcolor{cyan!25.0} 0.75 \\
Llama 3 70B & 0.12 & 0.21 & 0.64 & 0.53 & 0.18 & \cellcolor{cyan!8.3} 0.22 & 0.67 & 0.63 \\
Mixtral 8x7B (it) & \cellcolor{cyan!0.3} 0.06 & \cellcolor{cyan!18.5} 0.19 & 0.69 & \cellcolor{cyan!22.2} 0.74 & 0.20 & \cellcolor{cyan!7.4} 0.23 & \cellcolor{cyan!12.7} 0.70 & \cellcolor{cyan!23.3} 0.74 \\
Mixtral 8x7B & 0.20 & 0.25 & 0.56 & 0.70 & 0.41 & 0.37 & 0.57 & \cellcolor{orange!25.0} 0.30 \\
Yi 34B (it) & 0.22 & 0.24 & 0.57 & \cellcolor{orange!25.0} 0.31 & \cellcolor{cyan!16.7} 0.06 & \cellcolor{cyan!25.0} 0.19 & 0.67 & \cellcolor{cyan!19.4} 0.74 \\
Yi 34B & 0.09 & 0.20 & 0.67 & 0.64 & \cellcolor{cyan!25.0} 0.04 & \cellcolor{cyan!17.2} 0.21 & 0.59 & \cellcolor{cyan!2.1} 0.70 \\
Llama 3 8B (it) & 0.17 & 0.22 & 0.64 & 0.68 & 0.11 & \cellcolor{cyan!12.7} 0.21 & 0.59 & \cellcolor{cyan!7.7} 0.71 \\
Llama 3 8B & 0.20 & 0.25 & 0.51 & \cellcolor{orange!6.2} 0.34 & 0.41 & 0.38 & 0.55 & \cellcolor{orange!25.0} 0.30 \\
Mistral 7B (it) & 0.07 & 0.20 & 0.67 & 0.65 & 0.30 & 0.30 & 0.61 & \cellcolor{cyan!0.4} 0.70 \\
Gemma 7B (it) & 0.18 & 0.24 & 0.57 & 0.61 & 0.30 & 0.34 & \cellcolor{orange!14.6} 0.46 & 0.50 \\
Mistral 7B & \cellcolor{orange!22.1} 0.30 & \cellcolor{orange!18.5} 0.30 & 0.50 & 0.70 & 0.29 & 0.30 & \cellcolor{orange!25.0} 0.45 & \cellcolor{orange!25.0} 0.30 \\
Gemma 7B & 0.18 & 0.26 & 0.48 & 0.70 & 0.15 & \cellcolor{cyan!3.4} 0.23 & 0.49 & 0.49 \\
Gemma 2B (it) & 0.24 & \cellcolor{orange!14.2} 0.29 & \cellcolor{orange!25.0} 0.42 & 0.42 & \cellcolor{orange!25.0} 0.70 & \cellcolor{orange!25.0} 0.70 & 0.54 & \cellcolor{orange!25.0} 0.30 \\
Gemma 2B & \cellcolor{orange!25.0} 0.30 & \cellcolor{orange!25.0} 0.30 & 0.50 & 0.70 & 0.26 & 0.28 & 0.54 & \cellcolor{orange!25.0} 0.30 \\
LR & 0.03 & 0.19 & 0.70 & 0.72 & 0.03 & 0.19 & 0.70 & 0.72 \\
GBM & 0.01 & 0.14 & 0.83 & 0.80 & 0.01 & 0.14 & 0.83 & 0.80 \\
XGBoost & 0.00 & 0.14 & 0.84 & 0.80 & 0.00 & 0.14 & 0.84 & 0.80 \\
\bottomrule
\end{tabular}


Render paper plots

Filter by the currently selected prompting scheme:

[19]:
all_results_df_original = all_results_df.copy()

all_results_df = all_results_df[
    (all_results_df[numeric_prompt_col] == NUMERIC_PROMPTING)
    | (all_results_df[numeric_prompt_col].isna())
]
[20]:
from matplotlib import pyplot as plt
import seaborn as sns
sns.set_style("whitegrid", rc={"grid.linestyle": "--"})
plt.rc("figure", figsize=(4.5, 3.5), dpi=200)

palette_name = "tab10"   # "colorblind"
palette = sns.color_palette(palette_name)
sns.set_palette(palette_name)
palette
[20]:
[21]:
IMGS_DIR = ACS_AGG_RESULTS_PATH.parent / "imgs"
IMGS_DIR.mkdir(exist_ok=True)

def save_fig(fig, name, add_prompt_suffix=True):
    num_or_multiple_choice = "numeric-prompt" if NUMERIC_PROMPTING else "multiple-choice-prompt"
    save_path = (IMGS_DIR / name).with_suffix(".pdf")
    if add_prompt_suffix:
        save_path = save_path.with_suffix(f".{num_or_multiple_choice}.pdf")

    fig.savefig(save_path, bbox_inches="tight")
    print(f"Saved figure to '{save_path.as_posix()}'")

Load score distributions for each model (and with varying degrees of information).

[22]:
predictions_path_col = "predictions_path"

def load_model_scores_df(df_row: pd.Series) -> pd.DataFrame:
    """Loads csv containing model scores corresponding to the given DF row."""
    if predictions_path_col in df_row and not pd.isna(df_row[predictions_path_col]):
        return pd.read_csv(df_row[predictions_path_col], index_col=0)
    return None
[23]:
scores_df_map = {
    id_: load_model_scores_df(row)
    for id_, row in tqdm(all_results_df.iterrows(), total=len(all_results_df))
}

First-page illustrative (teaser) plot

[24]:
example_task = "ACSIncome"

# example_model = "Mistral-7B-v0.1"
example_model = "Meta-Llama-3-70B"

baseline_model = "XGBoost"
[25]:
# Data for baseline model
baseline_row = all_results_df[(all_results_df["name"] == baseline_model) & (all_results_df[task_col] == example_task)].iloc[0]

# Data for base and instruct models
example_df = all_results_df[
    (all_results_df[task_col] == example_task)
    & (all_results_df["base_name"] == example_model)
    & (all_results_df[numeric_prompt_col] == NUMERIC_PROMPTING)
]

# Sort example_df to have (base, instruct) ordering
example_df = example_df.sort_values("is_inst", ascending=True)
example_df
[25]:
accuracy accuracy_diff accuracy_ratio balanced_accuracy balanced_accuracy_diff balanced_accuracy_ratio brier_score_loss ece ece_quantile equalized_odds_diff ... num_features uses_all_features fit_thresh_on_100 fit_thresh_accuracy optimal_thresh optimal_thresh_accuracy score_stdev score_mean model_size model_family
Meta-Llama-3-70B__ACSIncome__-1__Num 0.543953 0.134809 0.771466 0.636284 0.078752 0.887756 0.237630 0.270157 NaN 0.191080 ... -1 True 0.89560 0.784722 0.8956 0.784722 0.253299 0.638033 70 Llama
Meta-Llama-3-70B-Instruct__ACSIncome__-1__Num 0.665731 0.086680 0.880953 0.722658 0.061832 0.919317 0.230512 0.246092 NaN 0.232483 ... -1 True 0.75718 0.735764 0.8239 0.784848 0.290286 0.613968 70 Llama

2 rows × 66 columns

[26]:
ALPHA = 0.7
N_BINS = 20

fig, (ax1, ax2, ax3) = plt.subplots(ncols=3, figsize=(11, 2.5), gridspec_kw=dict(wspace=0.4))

###
# ROC barplot
###
sns.barplot(
    data=example_df,
    x="is_inst",
    y="roc_auc",
    hue="name",
    alpha=ALPHA,
    width=0.5,
    ax=ax1,
)

# Add horizontal line and label
ax1.axhline(y=baseline_row["roc_auc"], label=baseline_model, ls="--", color="black", xmin=0.05, xmax=0.95)
ax1.text(
    x=1.15,
    y=baseline_row["roc_auc"] + 1e-3,
    s=baseline_model,
    color="black",
    fontsize=9,
    ha='center',
    va='bottom',
    backgroundcolor='white',
    zorder=-1,
)

ax1.set_ylim(0, baseline_row["roc_auc"] + 9e-2)
# ax1.set_ylim(0.5, baseline_row["roc_auc"] + 7e-2)
ax1.set_ylabel("ROC AUC")
ax1.set_xlabel(None)
ax1.legend().remove()   # Remove the legend
ax1.set_xticklabels(["Base", "Instr."])
ax1.set_title("Predictive signal" + r"$\uparrow$")

###
# Score distribution
###
bins = np.histogram_bin_edges([], bins=N_BINS, range=(0, 1))
for id_ in example_df.index.tolist():
    sns.histplot(
        scores_df_map[id_]["risk_score"],
        alpha=ALPHA * 0.9,
        stat="proportion",
        bins=bins,
        zorder=100,
        ax=ax2,
    )

# # Draw baseline score distribution
# sns.histplot(
#     scores_df_map[baseline_row.name]["risk_score"],
#     alpha=1,
#     bins=bins,
#     stat="proportion",
#     fill=False,
#     color=palette[2],
#     edgecolor=palette[2],
#     hatch="/",
#     zorder=-1,
#     ax=ax2,
# )

ax2.set_xlabel(None)
ax2.set_title("Score distribution")

###
# Calibration error
###
sns.barplot(
    data=example_df,
    x="is_inst",
    y="ece",
    hue="name",
    alpha=ALPHA,
    width=0.5,
    ax=ax3,
)

# Add horizontal line and label
ax3.axhline(y=baseline_row["ece"], label=baseline_model, ls="--", color="black", xmin=0.05, xmax=0.95)

ax3.set_ylabel("ECE")
ax3.set_xlabel(None)
ax3.set_xticklabels(["Base", "Instr."])
ax3.set_title("Calibration error" + "$\downarrow$")
# ax3.legend(
#     loc="center left",
#     bbox_to_anchor=(1.05, 0.5),
# )

hs, ls = ax3.get_legend_handles_labels()
ax3.legend().remove()
ax2.legend(handles=hs[:2], labels=ls[:2], loc="upper center")

save_fig(fig, f"teaser-{example_model}")
/tmp/ipykernel_3388197/4070898515.py:38: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
  ax1.set_xticklabels(["Base", "Instr."])
/tmp/ipykernel_3388197/4070898515.py:90: UserWarning: set_ticklabels() should only be used with a fixed number of ticks, i.e. after set_ticks() or using a FixedLocator.
  ax3.set_xticklabels(["Base", "Instr."])
Saved figure to '../results/imgs/teaser-Meta-Llama-3-70B.numeric-prompt.pdf'

Score distributions for base/instruct pairs

[27]:
print(f"Plotting score distributions for the following task: {example_task}.")

task_df = all_results_df[
    (all_results_df[task_col] == example_task)
    & (
        (all_results_df[numeric_prompt_col] == NUMERIC_PROMPTING)
        | (all_results_df[numeric_prompt_col].isna())                  # Baseline models will have NaNs in this col.
    )]
print(f"{len(task_df)=}")
Plotting score distributions for the following task: ACSIncome.
len(task_df)=20

Sort model pairs in the order desired for the paper plot:

[28]:
sorted_model_families = ["OpenAI", "Llama", "Mistral", "Yi", "Gemma", "-"]

model_pairs = sorted([
        (row["base_name"], row[model_col])
        for id_, row in task_df[task_df["is_inst"] == True].iterrows()
    ],
    key=lambda val: (
        1e6 * sorted_model_families.index((row := task_df[task_df[model_col] == val[0]].iloc[0])["model_family"])
        + row["model_size"]
    )
)
model_pairs
[28]:
[('penai/gpt-4o-mini', 'penai/gpt-4o-mini'),
 ('Meta-Llama-3-8B', 'Meta-Llama-3-8B-Instruct'),
 ('Meta-Llama-3-70B', 'Meta-Llama-3-70B-Instruct'),
 ('Mistral-7B-v0.1', 'Mistral-7B-Instruct-v0.2'),
 ('Mixtral-8x7B-v0.1', 'Mixtral-8x7B-Instruct-v0.1'),
 ('Mixtral-8x22B-v0.1', 'Mixtral-8x22B-Instruct-v0.1'),
 ('Yi-34B', 'Yi-34B-Chat'),
 ('gemma-2b', 'gemma-1.1-2b-it'),
 ('gemma-7b', 'gemma-1.1-7b-it')]
[29]:
PLOT_THRESHOLD_LINE = False

N_PLOTS = len(model_pairs) + 1  # plus 1 for the XGBoost baseline

N_COLS = 5
N_ROWS = math.ceil(N_PLOTS / N_COLS)

fig, axes = plt.subplots(
    nrows=N_ROWS, ncols=N_COLS,
    figsize=(12, 3 * N_ROWS),
    sharey=True, sharex=False,
    gridspec_kw=dict(
        hspace=0.5,
        wspace=0.06,
    ),
)

# Plot settings
plot_config = dict(
    bins=N_BINS,
    binrange=(0, 1),
    stat="percent",
)

# Plot score distribution for model pairs
for idx, (base_name, instr_name) in enumerate(model_pairs):
    ax_row = idx // N_COLS
    ax_col = idx % N_COLS

    ax = axes[ax_row, ax_col]
    base_row = task_df[task_df[model_col] == base_name].iloc[0]
    it_row = task_df[task_df[model_col] == instr_name].iloc[0]

    # Get scores
    base_scores = scores_df_map[base_row.name]["risk_score"]
    it_scores = scores_df_map[it_row.name]["risk_score"]

    ax.set_title(base_row["name"])

    ax.set_xlabel("risk scores")
    if ax_col == 0:
        ax.set_ylabel("density (%)")

    if ax_row == 0:
        ax.set_ylim(0, 70)

    ax.set_xlim(-0.03, 1.03)

    # Render plot
    if "gpt" not in instr_name.lower():
        sns.histplot(base_scores, label="Base", color=palette[0], ax=ax, **plot_config)
    sns.histplot(it_scores, label="Instruction-tuned", color=palette[1], ax=ax, **plot_config)

    # Draw faint vertical line with t^* of each model
    threshold_col = "optimal_thresh"
    # threshold_col = "fit_thresh_on_100"

    if PLOT_THRESHOLD_LINE:
        ax.axvline(x=base_row[threshold_col], ls="--", color=palette[0], alpha=0.5)
        ax.axvline(x=it_row[threshold_col], ls="--", color=palette[1], alpha=0.5)

    # Render legend on a few sub-plots
    if (ax_col == 0 and ax_row == 1) or (ax_col == N_COLS-1 and ax_row == 0):
        ax.legend(loc="upper center")

        if PLOT_THRESHOLD_LINE:
            ax.text(
                x=base_row[threshold_col] - 5e-2,
                y=43,
                s="$t^*$",
                color=palette[0],
                fontsize=10,
                ha='center',
                va='bottom',
                # zorder=-1,
            ) #.set_bbox(dict(facecolor='white', alpha=0.5))

            ax.text(
                x=it_row[threshold_col] + 5e-2,
                y=43,
                s="$t^*$",
                color=palette[1],
                fontsize=10,
                ha='center',
                va='bottom',
                # zorder=-1,
            )


# Plot score distribution for XGBoost baseline
baseline_name = "XGBoost"
ax = axes[1, -1]
ax.set_title(baseline_name)

baseline_row = task_df[task_df["name"] == baseline_name].iloc[0]
xgboost_scores = scores_df_map[baseline_row.name]["risk_score"]

sns.histplot(
    xgboost_scores,
    label=baseline_name,
    color=palette[2],
    ax=ax,
    **plot_config,
)

ax.set_xlabel("risk scores")

# # Remove unused plot?
# axes[1, 3].remove()

axes[1, 4].set_ylabel("density (%)")
axes[1, 4].set_yticks([0, 20, 40, 60])
axes[1, 4].set_yticklabels([0, 20, 40, 60])

plt.plot()

# Save figure
save_fig(fig, f"score-distribution" + ("_with-threshold" if PLOT_THRESHOLD_LINE else ""))
Saved figure to '../results/imgs/score-distribution.numeric-prompt.pdf'

Calibration curves

Base models vs Instruction-tuned models:

[30]:
TASK = "ACSIncome"

# df = results_df[(results_df[task_col] == TASK) & (results_df[numeric_prompt_col] == NUMERIC_PROMPTING)]

df = all_results_df[(all_results_df[task_col] == TASK) & (all_results_df[numeric_prompt_col] == NUMERIC_PROMPTING)]
print(f"{len(df)=}")
len(df)=17
[31]:
from sklearn.calibration import CalibrationDisplay
from utils import prettify_model_name

base_models_to_plot = [m for m in df["base_name"].unique() if "gemma" not in m.lower()]

base_models_to_plot = [
    "Meta-Llama-3-8B",
    "Meta-Llama-3-70B",
    "Mistral-7B-v0.1",
    "Mixtral-8x7B-v0.1",
    "Mixtral-8x22B-v0.1",
    # "Yi-34B",
]

instr_models_to_plot = [
    df[
        (df["base_name"] == m)
        & (df[model_col] != m)
    ][model_col].iloc[0]
    for m in base_models_to_plot
]

# Add GPT models if any
instr_models_to_plot += [m[model_col] for _, m in df.iterrows() if "gpt" in m["name"].lower()]

fig, axes = plt.subplots(ncols=2, figsize=(8, 3), gridspec_kw=dict(wspace=0.33), sharey=False)

# Plot base models
for ax, model_list in zip(axes, (base_models_to_plot, instr_models_to_plot)):

    for idx, m in enumerate(model_list):

        curr_row = df[df[model_col] == m].iloc[0]
        curr_scores = scores_df_map[curr_row.name]

        disp = CalibrationDisplay.from_predictions(
            y_true=curr_scores["label"],
            y_prob=curr_scores["risk_score"],
            n_bins=10,
            strategy="quantile",
            name=prettify_model_name(curr_row["base_name"]),
            linestyle=["-.", ":", "--"][idx % 3],
            ax=ax,
        )

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)

ylabel = "Fraction of positives " + r"$\left(\mathrm{P}[Y=1]\right)$"
axes[0].set_ylabel(ylabel)
axes[1].set_ylabel(ylabel)

xlabel = "Mean predicted probability" # + r" $(\mathbb{E}[R])$"
axes[0].set_xlabel(xlabel)
axes[1].set_xlabel(xlabel)

axes[0].set_title("Base models")
axes[1].set_title("Instruction-tuned models")

axes[0].legend().remove()

# Legend to the right of the right-most plot
h, l = axes[1].get_legend_handles_labels()
l[0] = "Perf. calibrated"
axes[1].legend(handles=h, labels=l, loc="upper left", bbox_to_anchor=(1.1, 1.1))


plt.plot()

save_fig(fig, f"calibration-curves-base-and-instr.pdf")
Saved figure to '../results/imgs/calibration-curves-base-and-instr.numeric-prompt.pdf'
[32]:
from sklearn.calibration import CalibrationDisplay
from utils import prettify_model_name

base_models_to_plot = [m for m in df["base_name"].unique() if "gemma" not in m.lower()]

base_models_to_plot = [
    "Meta-Llama-3-70B",
    "Mixtral-8x7B-v0.1",
    "Mixtral-8x22B-v0.1",
    "Yi-34B",
]

instr_models_to_plot = [
    df[
        (df["base_name"] == m)
        & (df[model_col] != m)
    ][model_col].iloc[0]
    for m in base_models_to_plot
]

# Add GPT models if any
instr_models_to_plot += [m[model_col] for _, m in df.iterrows() if "gpt" in m["name"].lower()]

fig, axes = plt.subplots(ncols=2, figsize=(8, 3), gridspec_kw=dict(wspace=0.33), sharey=False)

# Plot base models
for ax, model_list in zip(axes, (base_models_to_plot, instr_models_to_plot)):

    for idx, m in enumerate(model_list):

        curr_row = df[df[model_col] == m].iloc[0]
        curr_scores = scores_df_map[curr_row.name]

        disp = CalibrationDisplay.from_predictions(
            y_true=curr_scores["label"],
            y_prob=curr_scores["risk_score"],
            n_bins=10,
            strategy="quantile",
            name=prettify_model_name(curr_row["base_name"]),
            linestyle=["-.", ":", "--"][idx % 3],
            ax=ax,
        )

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)

ylabel = "Fraction of positives " + r"$\left(\mathrm{P}[Y=1]\right)$"
axes[0].set_ylabel(ylabel)
axes[1].set_ylabel(ylabel)

xlabel = "Mean predicted probability" # + r" $(\mathbb{E}[R])$"
axes[0].set_xlabel(xlabel)
axes[1].set_xlabel(xlabel)

axes[0].set_title("Base models")
axes[1].set_title("Instruction-tuned models")

axes[0].legend().remove()

# Legend to the right of the right-most plot
h, l = axes[1].get_legend_handles_labels()
l[0] = "Perf. calibrated"
axes[1].legend(handles=h, labels=l, loc="upper left", bbox_to_anchor=(1.1, 1.1))

plt.plot()

save_fig(fig, f"calibration-curves-base-and-instr.large-models.pdf")
Saved figure to '../results/imgs/calibration-curves-base-and-instr.large-models.numeric-prompt.pdf'
[33]:
from sklearn.calibration import CalibrationDisplay
from utils import prettify_model_name

base_models_to_plot = [m for m in df["base_name"].unique() if "gemma" not in m.lower()]

base_models_to_plot = [
    "Meta-Llama-3-70B",
    "Mixtral-8x7B-v0.1",
    "Mixtral-8x22B-v0.1",
    "Yi-34B",
]

instr_models_to_plot = [
    df[
        (df["base_name"] == m)
        & (df[model_col] != m)
    ][model_col].iloc[0]
    for m in base_models_to_plot
]

# Add GPT models if any
instr_models_to_plot += [m[model_col] for _, m in df.iterrows() if "gpt" in m["name"].lower()]

fig, axes = plt.subplots(ncols=2, figsize=(5, 2.2), gridspec_kw=dict(wspace=0.15), sharey=True)

# Plot base models
for ax, model_list in zip(axes, (base_models_to_plot, instr_models_to_plot)):

    for idx, m in enumerate(model_list):

        curr_row = df[df[model_col] == m].iloc[0]
        curr_scores = scores_df_map[curr_row.name]

        disp = CalibrationDisplay.from_predictions(
            y_true=curr_scores["label"],
            y_prob=curr_scores["risk_score"],
            n_bins=10,
            strategy="quantile",
            name=prettify_model_name(curr_row["base_name"]),
            linestyle=["-.", ":", "--"][idx % 3],
            ax=ax,
            ms=3.5,
        )

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)

ylabel = "Fraction of positives " + r"$\left(\mathrm{P}[Y=1]\right)$"
axes[0].set_ylabel(ylabel)
# axes[1].set_ylabel(ylabel)
axes[1].set_ylabel("")

# xlabel = "Mean predicted probability" # + r" $(\mathbb{E}[R])$"
xlabel = "Mean risk score"
axes[0].set_xlabel(xlabel)
axes[1].set_xlabel(xlabel)

axes[0].set_title("Base models")
axes[1].set_title("Instruction-tuned")

# Set prettier x ticks
xticks = (0, 0.25, 0.5, 0.75, 1)
xticks_labels = ("0.0", "0.25", "0.5", "0.75", "1.0")
axes[0].set_xticks(xticks, labels=xticks_labels)
axes[1].set_xticks(xticks, labels=xticks_labels)

axes[0].legend().remove()

# Legend to the right of the right-most plot
if NUMERIC_PROMPTING:
    h, l = axes[1].get_legend_handles_labels()
    l[0] = "Perf. calibrated"
    axes[1].legend(handles=h, labels=l, loc="upper left", bbox_to_anchor=(1, 1))
else:
    axes[1].legend().remove()

# Figure title
if NUMERIC_PROMPTING:
    title = "Numeric Prompting"
else:
    title = "Multiple-choice Prompting"
plt.suptitle(title, y=1.1, fontsize=14)

plt.plot()

save_fig(fig, f"calibration-curves-base-and-instr.large-models.smaller.pdf")
Saved figure to '../results/imgs/calibration-curves-base-and-instr.large-models.smaller.numeric-prompt.pdf'

Calibration per sub-group

Plot calibration curves per race sub-group, for each model pair.

[34]:
%%time
acs_dataset = ACSDataset.make_from_task(task=TASK, cache_dir=DATA_DIR)
Loading ACS data...
CPU times: user 41.3 s, sys: 28.3 s, total: 1min 9s
Wall time: 1min 9s

Plot:

[35]:
# Omit small groups from the plots (results can be too noisy)
OMIT_GROUPS_BELOW_PCT = 4
N_BINS = 10
[36]:
sens_attr_col = "RAC1P"
sens_attr_map = {
    1: "White",
    2: "Black",
    6: "Asian",
    # 8: "Other",   # some other race alone
}

name_to_sens_attr_map = {val: key for key, val in sens_attr_map.items()}

def plot_cali_curve(model_id, ax, df):
    """Plot sub-group calibration curves using sklearn"""

    curr_row = df.loc[model_id]
    curr_scores_df = scores_df_map[curr_row.name]

    y_true = curr_scores_df["label"].to_numpy()
    y_scores = curr_scores_df["risk_score"].to_numpy()

    sens_attr = acs_dataset.data.loc[curr_scores_df.index][sens_attr_col].to_numpy()

    for idx, group_val in enumerate(np.unique(sens_attr)):
        group_filter = sens_attr == group_val

        if np.sum(group_filter) < len(group_filter) * 0.01 * OMIT_GROUPS_BELOW_PCT:
            # print(f"Skipping group={group_val}")
            continue

        disp = CalibrationDisplay.from_predictions(
            y_true=y_true[group_filter],
            y_prob=y_scores[group_filter],
            n_bins=N_BINS,
            # strategy="uniform",
            strategy="quantile",
            name=str(sens_attr_map.get(group_val, group_val)),
            linestyle=["-.", ":", "--"][idx % 3],
            ax=ax,
        )

    # Set other plot settings
    ax.set_title(curr_row["name"])
    return disp
[37]:
from folktexts.evaluation import bootstrap_estimate

def plot_cali_curve_ci(model_id, ax, df, ci=95):
    """Plot sub-group calibration curves with bootstrap confidence intervals"""

    curr_row = df.loc[model_id]
    curr_scores_df = scores_df_map[curr_row.name]

    y_true = curr_scores_df["label"].to_numpy()
    y_scores = curr_scores_df["risk_score"].to_numpy()

    sens_attr = acs_dataset.data.loc[curr_scores_df.index][sens_attr_col].to_numpy()

    # Compute bins
    if len(np.unique(y_scores)) == 1:
        ax.set_title(curr_row["name"])
        return

    y_scores_binned, bin_edges = pd.qcut(y_scores, q=N_BINS, retbins=True, duplicates="drop")
    bin_names = y_scores_binned.categories   # Name used for each bin: "(bin_low_range, bin_high_range)"

    # To hold plot render artifacts
    plot_artifacts = {}

    for idx, group_val in enumerate(np.unique(sens_attr)):
        group_filter = sens_attr == group_val

        if (
            np.sum(group_filter) < len(group_filter) * 0.01 * OMIT_GROUPS_BELOW_PCT
            or group_val not in sens_attr_map.keys()
        ):
            # print(f"Skipping group={group_val}")
            continue

        group_labels = y_true[group_filter]
        group_scores = y_scores[group_filter]

        group_cali_points = []
        for b in bin_names:
            bin_filter = y_scores_binned == b

            group_labels_current_bin = y_true[group_filter & bin_filter]
            group_scores_current_bin = y_scores[group_filter & bin_filter]

            results = bootstrap_estimate(
                eval_func=lambda labels, scores, sensattr: {
                    "mean_score": np.mean(scores),
                    "mean_label": np.mean(labels),
                },
                y_true=group_labels_current_bin,
                y_pred_scores=group_scores_current_bin,
                sensitive_attribute=None,
            )

            # Store plot point for this bin
            group_cali_points.append((
                results["mean_score_mean"],    # x coordinate
                results["mean_label_mean"],    # y coordinate
                (                              # y coordinate c.i. range
                    results["mean_label_low-percentile"],
                    results["mean_label_high-percentile"],
                ),
            ))

        # Get group label from sensitive attribute map
        group_label = str(sens_attr_map.get(group_val, group_val))

        # Plot this group's calibration curve w/ conf. intervals!
        line, caps, bars = ax.errorbar(
            # x=[mean_score for mean_score, *_ in group_cali_points],
            x=[(bin_edges[idx-1] + bin_edges[idx]) / 2. for idx in range(1, len(bin_edges))],
            y=[mean_label for _, mean_label, _ in group_cali_points],
            yerr=np.hstack(
                [
                    np.abs(np.array(yerr_range).reshape(2, 1) - y)
                    for _x, y, yerr_range in group_cali_points]
            ),
            fmt="o", markersize=1,
            linestyle=["-.", "-", "--"][idx % 3],
            linewidth=2,
            label=group_label,
        )
        plot_artifacts[group_label] = (line, caps, bars)

    # Global plot settings
    ax.plot(
        [0,1], [0,1],
        ls="dotted", color="black",
    )

    # Draw custom legend
    handles = []
    for group, (l, c, b) in plot_artifacts.items():
        l.set_label(group)
        l.set_markersize(0)
        handles.append(l)

    # Perfectly calibrated line
    perf_cal_line = plt.Line2D([], [], ls=":", color="black", label="Perf. calibrated")

    # Confidence interv. marker line
    conf_int_line = plt.Line2D([], [], lw=0, ms=10, marker="|", markeredgewidth=2, color="grey", label=r"$95\%$ conf. int.")

    ax.legend(handles=handles + [conf_int_line, perf_cal_line])
    ax.set_title(curr_row["name"])

    return plot_artifacts
[38]:
from sklearn.calibration import CalibrationDisplay

# Get results for the current task
task_df = all_results_df[all_results_df[task_col] == TASK]

# Get all base/instruct model pairs
base_models_to_plot = [
    "Meta-Llama-3-8B",
    "Meta-Llama-3-70B",
    "Mistral-7B-v0.1",
    "Mixtral-8x7B-v0.1",
    "Mixtral-8x22B-v0.1",
    "Yi-34B",
]
# NOTE: Gemma models have degenerate score distributions (they always predict the same score), so their calibration curves can't be plotted properly.

instr_models_to_plot = [
    df[
        (df["base_name"] == m)
        & (df[model_col] != m)
    ][model_col].iloc[0]
    for m in base_models_to_plot
]

# Add GPT models if any
instr_models_to_plot += [m[model_col] for _, m in df.iterrows() if "gpt" in m["name"].lower()]
base_models_to_plot += [None for _ in range(len(instr_models_to_plot) - len(base_models_to_plot))]

base_it_pairs = list(zip(base_models_to_plot, instr_models_to_plot))

for base_model, it_model in tqdm(base_it_pairs):
    fig, axes = plt.subplots(ncols=2, figsize=(7, 2.8), gridspec_kw=dict(wspace=0.35), sharey=False)

    for idx, ax, m in zip(range(len(axes)), axes, (base_model, it_model)):
        if m is None:
            ax.remove()
            continue

        curr_row = task_df[(task_df[model_col] == m) & (task_df["uses_all_features"])].iloc[0]
        curr_id = curr_row.name

        # # Original sklearn calibration plot:
        # plot_cali_curve(curr_id, ax=ax, df=task_df)

        # # Plot with confidence intervals:
        plot_cali_curve_ci(curr_id, ax=ax, df=task_df)

        # Remove left plot legend
        if idx == 0:
            ax.get_legend().remove()
        else:
            ax.get_legend().set(loc="upper left", bbox_to_anchor=(1.1, 1))

        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)

        ax.set_xlabel("Mean risk score")
        ax.set_ylabel("Fraction of positives")

    plt.plot()
    save_fig(fig, f"calibration-curves-per-subgroup.{base_model}.pdf")

Saved figure to '../results/imgs/calibration-curves-per-subgroup.Meta-Llama-3-8B.numeric-prompt.pdf'
Saved figure to '../results/imgs/calibration-curves-per-subgroup.Meta-Llama-3-70B.numeric-prompt.pdf'
Saved figure to '../results/imgs/calibration-curves-per-subgroup.Mistral-7B-v0.1.numeric-prompt.pdf'
Saved figure to '../results/imgs/calibration-curves-per-subgroup.Mixtral-8x7B-v0.1.numeric-prompt.pdf'
Saved figure to '../results/imgs/calibration-curves-per-subgroup.Mixtral-8x22B-v0.1.numeric-prompt.pdf'
Saved figure to '../results/imgs/calibration-curves-per-subgroup.Yi-34B.numeric-prompt.pdf'
Saved figure to '../results/imgs/calibration-curves-per-subgroup.None.numeric-prompt.pdf'
[39]:
from sklearn.calibration import CalibrationDisplay

# Get results for the current task
task_df = all_results_df[all_results_df[task_col] == TASK]

# Get all base/instruct model pairs
base_models_to_plot = [
    "Meta-Llama-3-8B",
    "Meta-Llama-3-70B",
    "Mistral-7B-v0.1",
    "Mixtral-8x7B-v0.1",
    "Mixtral-8x22B-v0.1",
    "Yi-34B",
]
# NOTE: Gemma models produce degenerate score distributions (they always predict the same score), so their calibration curves cannot be plotted properly.

instr_models_to_plot = [
    df[
        (df["base_name"] == m)
        & (df[model_col] != m)
    ][model_col].iloc[0]
    for m in base_models_to_plot
]

# Add GPT models if any
instr_models_to_plot += [m[model_col] for _, m in df.iterrows() if "gpt" in m["name"].lower()]
base_models_to_plot += [None for _ in range(len(instr_models_to_plot) - len(base_models_to_plot))]

base_it_pairs = list(zip(base_models_to_plot, instr_models_to_plot))

for base_model, it_model in tqdm(base_it_pairs):
    fig, axes = plt.subplots(ncols=2, figsize=(4.8, 2.2), gridspec_kw=dict(wspace=0.15), sharey=True)

    for idx, ax, m in zip(range(len(axes)), axes, (base_model, it_model)):
        if m is None:
            ax.remove()
            continue

        curr_row = task_df[(task_df[model_col] == m) & (task_df["uses_all_features"])].iloc[0]
        curr_id = curr_row.name

        # # Original sklearn calibration plot:
        # plot_cali_curve(curr_id, ax=ax, df=task_df)

        # # Plot with confidence intervals:
        plot_cali_curve_ci(curr_id, ax=ax, df=task_df)

        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)

        ax.set_xlabel("Mean risk score")
        if idx == 0:
            ax.set_ylabel("Fraction of positives")

        # Set prettier x-ticks
        xticks = (0, 0.25, 0.5, 0.75, 1)
        xticks_labels = ("0.0", "0.25", "0.5", "0.75", "1.0")
        ax.set_xticks(xticks, labels=xticks_labels)

    # Prepare legend
    axes[0].legend().remove()
    if NUMERIC_PROMPTING:
        axes[1].get_legend().set(loc="upper left", bbox_to_anchor=(1.03, 1))
    else:
        axes[1].legend().remove()

    # Figure title
    if NUMERIC_PROMPTING:
        title = "Numeric Prompting"
    else:
        title = "Multiple-choice Prompting"

    # Don't use a suptitle for GPT models (they have no base-model counterpart)
    if base_model is not None:
        plt.suptitle(title, y=1.1, fontsize=14)
    else:
        model_title = axes[1].get_title()
        plt.title(f"{title}\n{model_title}")

    plt.plot()
    save_fig(fig, f"calibration-curves-per-subgroup.{base_model}.smaller.pdf")

Saved figure to '../results/imgs/calibration-curves-per-subgroup.Meta-Llama-3-8B.smaller.numeric-prompt.pdf'
Saved figure to '../results/imgs/calibration-curves-per-subgroup.Meta-Llama-3-70B.smaller.numeric-prompt.pdf'
Saved figure to '../results/imgs/calibration-curves-per-subgroup.Mistral-7B-v0.1.smaller.numeric-prompt.pdf'
Saved figure to '../results/imgs/calibration-curves-per-subgroup.Mixtral-8x7B-v0.1.smaller.numeric-prompt.pdf'
Saved figure to '../results/imgs/calibration-curves-per-subgroup.Mixtral-8x22B-v0.1.smaller.numeric-prompt.pdf'
Saved figure to '../results/imgs/calibration-curves-per-subgroup.Yi-34B.smaller.numeric-prompt.pdf'
WARNING:matplotlib.legend:No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
Saved figure to '../results/imgs/calibration-curves-per-subgroup.None.smaller.numeric-prompt.pdf'

Score distribution per sub-group

[40]:
sens_attr_data = acs_dataset.data[sens_attr_col]
sens_attr_data.head(3)
[40]:
1    2
2    1
8    1
Name: RAC1P, dtype: int64
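For reference, sens_attr_map and name_to_sens_attr_map (both defined earlier in the notebook) translate between RAC1P codes and group names. An illustrative stand-in, assuming the standard ACS PUMS coding for the three groups analysed below (1 = White alone, 2 = Black alone, 6 = Asian alone), would be:

# Illustrative stand-in only -- the actual mappings are defined earlier in the notebook
sens_attr_map = {1: "White", 2: "Black", 6: "Asian"}
name_to_sens_attr_map = {name: code for code, name in sens_attr_map.items()}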
[41]:
# Plot settings
plot_config = dict(
    bins=20,
    binrange=(0, 1),
    stat="percent",
)

for base_model, it_model in tqdm(base_it_pairs):
    fig, axes = plt.subplots(ncols=2, figsize=(8, 3), gridspec_kw=dict(wspace=0.33), sharey=False)

    for idx, ax, m in zip(range(len(axes)), axes, (base_model, it_model)):
        if m is None:
            ax.remove()
            continue

        model_row = task_df[(task_df[model_col] == m) & (task_df["uses_all_features"])].iloc[0]
        model_id = model_row.name

        scores_df = scores_df_map[model_id]
        scores_df["group"] = sens_attr_data.loc[scores_df.index]

        # sns.histplot(
        #     data=scores_df,
        #     x="risk_score",
        #     hue="group",
        #     ax=ax,
        #     **plot_config,
        # )

        for group_val, group_name in sens_attr_map.items():
            group_scores_df = scores_df[sens_attr_data.loc[scores_df.index] == group_val]
            sns.histplot(
                group_scores_df["risk_score"],
                label=group_name,
                ax=ax,
                alpha=0.5,
                **plot_config,
            )

        # Show the legend only on the left plot
        if idx == 0:
            ax.legend(loc="upper left")

        ax.set_xlim(0, 1)
        # ax.set_ylim(0, 1)

        ax.set_title(model_row["name"])

    plt.plot()
    save_fig(fig, f"score-per-subgroup.{base_model}.pdf")

Saved figure to '../results/imgs/score-per-subgroup.Meta-Llama-3-8B.numeric-prompt.pdf'
Saved figure to '../results/imgs/score-per-subgroup.Meta-Llama-3-70B.numeric-prompt.pdf'
Saved figure to '../results/imgs/score-per-subgroup.Mistral-7B-v0.1.numeric-prompt.pdf'
Saved figure to '../results/imgs/score-per-subgroup.Mixtral-8x7B-v0.1.numeric-prompt.pdf'
Saved figure to '../results/imgs/score-per-subgroup.Mixtral-8x22B-v0.1.numeric-prompt.pdf'
Saved figure to '../results/imgs/score-per-subgroup.Yi-34B.numeric-prompt.pdf'
Saved figure to '../results/imgs/score-per-subgroup.None.numeric-prompt.pdf'

Compute mean under-/over-estimation of risk score

It's basically akin to computing ECE without taking the absolute value of each bin's gap.
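Concretely, the implementation below splits the predictions into 10 equal-width score bins and takes a bin-size-weighted average of the signed gap between confidence and outcome: for a bin with mean score s and mean label y, the gap is s - y when s >= 0.5 and (1 - s) - (1 - y) otherwise, so positive values indicate over-confidence and negative values indicate under-confidence in the predicted class.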

[42]:
def compute_confidence_bias_metric_helper(data):
    """Helper to compute the under- / over-confidence metric.
    """
    mean_label = data["label"].mean()
    mean_score = data["risk_score"].mean()

    if mean_score < 0.5:
        return ((1 - mean_score) - (1 - mean_label))
    else:
        return (mean_score - mean_label)

def compute_confidence_bias_metric(model_id, quantiled_bins=False, scores_df=None):
    if scores_df is None:
        scores_df = scores_df_map[model_id]

    discretization_func = pd.qcut if quantiled_bins else pd.cut
    scores_df["score_bin"] = discretization_func(scores_df["risk_score"], 10, labels=range(10))

    bias_metric_val = (1 / len(scores_df)) * sum(
        len(bin_data) * compute_confidence_bias_metric_helper(bin_data)
        for bin_ in range(10)
        if len(bin_data := scores_df[scores_df["score_bin"] == bin_]) > 0
    )
    return float(bias_metric_val)
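As a quick sanity check (a hypothetical example, not part of the recorded run), the metric can also be computed for an arbitrary scores frame by passing scores_df explicitly:

toy_scores = pd.DataFrame({
    "risk_score": np.linspace(0.05, 0.95, 100),   # evenly spread risk scores
    "label": np.r_[np.zeros(50), np.ones(50)],    # half negatives, then half positives
})
# model_id is ignored when scores_df is given explicitly
print(compute_confidence_bias_metric(model_id=None, scores_df=toy_scores))
# Same metric with quantile-based bins instead of equal-width bins
print(compute_confidence_bias_metric(model_id=None, scores_df=toy_scores, quantiled_bins=True))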
[43]:
confidence_scores = {
    m: compute_confidence_bias_metric(m)
    for m in tqdm(scores_df_map.keys())
}
[44]:
all_results_df["under_over_score"] = all_results_df.apply(lambda row: confidence_scores[row.name], axis=1)
all_results_df.sort_values("under_over_score", ascending=True)["under_over_score"]
[44]:
Meta-Llama-3-8B-Instruct__ACSPublicCoverage__-1__Num     -0.161384
Mixtral-8x7B-Instruct-v0.1__ACSPublicCoverage__-1__Num   -0.061690
Mixtral-8x7B-Instruct-v0.1__ACSEmployment__-1__Num       -0.056597
Meta-Llama-3-70B-Instruct__ACSPublicCoverage__-1__Num    -0.052221
gemma-7b__ACSIncome__-1__Num                             -0.033482
                                                            ...
gemma-7b__ACSEmployment__-1__Num                          0.341868
Mistral-7B-v0.1__ACSIncome__-1__Num                       0.355603
gemma-2b__ACSIncome__-1__Num                              0.367876
Mistral-7B-v0.1__ACSTravelTime__-1__Num                   0.438141
gemma-2b__ACSTravelTime__-1__Num                          0.438312
Name: under_over_score, Length: 100, dtype: float64
[45]:
TASK = "ACSIncome"
# TASK = "ACSTravelTime"
# TASK = "ACSMobility"
# TASK = "ACSEmployment"
[46]:
plot_df = all_results_df[
    (all_results_df[task_col] == TASK)
    & (~all_results_df[numeric_prompt_col].isna())
]

# Omit Gemma and Yi models (degenerate / poor score distributions)
plot_df = plot_df.drop(index=[id_ for id_ in plot_df.index if "gemma" in id_.lower() or "yi" in id_.lower()])


fig, ax = plt.subplots(figsize=(4, 3))
sns.barplot(
    data=plot_df.sort_values("under_over_score", ascending=True),
    x="under_over_score",
    y="name",
    # y=task_col,
    hue="is_inst",
    ax=ax,
)

handles, _ = plt.gca().get_legend_handles_labels()
plt.legend(handles=handles, labels=["Base", "Instruct"])

plt.ylabel("")
plt.xlabel(r"$\mathrm{R}_\mathrm{bias}$")

plt.title(f"Risk score confidence bias")

save_fig(fig, f"under_over_score.{TASK}.pdf")
Saved figure to '../results/imgs/under_over_score.ACSIncome.numeric-prompt.pdf'

Compute risk score fairness per sub-group

[47]:
sens_attr_data = acs_dataset.data[sens_attr_col]
sens_attr_data.head(3)
[47]:
1    2
2    1
8    1
Name: RAC1P, dtype: int64
[48]:
def compute_signed_score_bias(scores_df, quantiled_bins=False, n_bins=10):
    discretization_func = pd.qcut if quantiled_bins else pd.cut
    scores_df["score_bin"] = discretization_func(scores_df["risk_score"], n_bins, duplicates="drop", retbins=False)
    bin_categories = scores_df["score_bin"].dtype.categories

    return (1 / len(scores_df)) * sum(
        len(bin_data) * (bin_data["risk_score"].mean() - bin_data["label"].mean())
        for bin_ in bin_categories
        if len(bin_data := scores_df[scores_df["score_bin"] == bin_]) > 0
    )
[49]:
def compute_signed_subgroup_score_bias(model, group: int | str) -> float:
    scores_df = scores_df_map[model]
    if isinstance(group, str):
        group = name_to_sens_attr_map[group]

    scores_df = scores_df[sens_attr_data.loc[scores_df.index] == group]
    # return compute_confidence_metric(model, scores_df=scores_df)
    return compute_signed_score_bias(scores_df)


def compute_calibration_fairness(model: str, priv_group: int | str, unpriv_group: int | str) -> float:
    return (
        compute_signed_subgroup_score_bias(model, group=priv_group)
        - compute_signed_subgroup_score_bias(model, group=unpriv_group)
    )
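For example, the signed score-bias gap between two groups can be queried directly (a hypothetical call, assuming this model id is present in scores_df_map for the current task):

# Positive values: the first group's scores are more over-estimated (relative to its labels) than the second group's
compute_calibration_fairness(
    "Mistral-7B-v0.1__ACSIncome__-1__Num",
    priv_group="White",
    unpriv_group="Black",
)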
[50]:
calibration_per_subgroup_df = pd.DataFrame([
    pd.Series(
        {
            f"{group}_score_bias": compute_signed_subgroup_score_bias(model_id, group)
            for group in sens_attr_map.values()
        },
        name=model_id,
    )
    for model_id in task_df.index
])

calibration_per_subgroup_df["White_v_Black_score_bias"] = \
    calibration_per_subgroup_df["White_score_bias"] - calibration_per_subgroup_df["Black_score_bias"]

calibration_per_subgroup_df["White_v_Asian_score_bias"] = \
    calibration_per_subgroup_df["White_score_bias"] - calibration_per_subgroup_df["Asian_score_bias"]

calibration_per_subgroup_df["Asian_v_Black_score_bias"] = \
    calibration_per_subgroup_df["Asian_score_bias"] - calibration_per_subgroup_df["Black_score_bias"]

calibration_per_subgroup_df.sample(3)
/tmp/ipykernel_3388197/3464681156.py:3: SettingWithCopyWarning:
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  scores_df["score_bin"] = discretization_func(scores_df["risk_score"], n_bins, duplicates="drop", retbins=False)
[50]:
White_score_bias Black_score_bias Asian_score_bias White_v_Black_score_bias White_v_Asian_score_bias Asian_v_Black_score_bias
Mistral-7B-v0.1__ACSIncome__-1__Num 0.353045 0.329428 0.319214 0.023617 0.033831 -0.010215
gemma-1.1-2b-it__ACSIncome__-1__Num 0.260100 0.403764 0.202488 -0.143664 0.057612 -0.201276
Yi-34B-Chat__ACSIncome__-1__Num 0.210531 0.229600 0.212839 -0.019069 -0.002308 -0.016761

Note: This is only an example of comparing risk score bias between different sub-groups. A more detailed analysis would be needed before drawing any conclusions about the prediction fairness of these LLMs.

[51]:
GROUP_A = "Asian"
# GROUP_A = "White"
GROUP_B = "Black"
# GROUP_B = "Asian"

COLUMN = f"{GROUP_A}_v_{GROUP_B}_score_bias"

plot_df = calibration_per_subgroup_df.join(results_df, how="left", validate="1:1")
plot_df = plot_df.drop(index=[id_ for id_ in plot_df.index if "gemma" in id_.lower()])   # omit Gemma models (degenerate score distributions)

fig, ax = plt.subplots(figsize=(3, 3))
sns.barplot(
    data=plot_df.sort_values(COLUMN, ascending=True),
    x=COLUMN,
    y="name",
    hue="is_inst",
    ax=ax,
)

handles, _ = plt.gca().get_legend_handles_labels()
plt.legend(handles=handles, labels=["Base", "Instruct"])

plt.ylabel("")
plt.xlabel("$\Delta_{SCE}$")

plt.title(
    "Signed risk score bias\n"
    + r"($SCE_{" + GROUP_A
    + r"} - SCE_{" + GROUP_B
    + r"}$)"
)

plt.plot()

save_fig(fig, f"score_bias.{COLUMN}.pdf")
Saved figure to '../results/imgs/score_bias.Asian_v_Black_score_bias.numeric-prompt.pdf'
[52]:
plot_df.sort_values(COLUMN, ascending=False).head(3)
[52]:
White_score_bias Black_score_bias Asian_score_bias White_v_Black_score_bias White_v_Asian_score_bias Asian_v_Black_score_bias accuracy accuracy_diff accuracy_ratio balanced_accuracy ... name is_inst num_features uses_all_features fit_thresh_on_100 fit_thresh_accuracy optimal_thresh optimal_thresh_accuracy score_stdev score_mean
Mixtral-8x22B-v0.1__ACSIncome__-1__Num 0.054322 -0.069436 0.130153 0.123758 -0.075831 0.199589 0.740583 0.148527 0.828562 0.761428 ... Mixtral 8x22B False -1.0 True 0.65 0.769300 0.45 0.740595 0.379947 0.411113
Mixtral-8x22B-Instruct-v0.1__ACSIncome__-1__Num 0.110399 0.055872 0.156710 0.054527 -0.046311 0.100837 0.767882 0.096716 0.885938 0.770218 ... Mixtral 8x22B (it) True -1.0 True 0.65 0.769967 0.55 0.767882 0.298323 0.474955
penai/gpt-4o-mini__ACSIncome__-1__Num -0.001300 -0.019960 -0.000407 0.018661 -0.000893 0.019553 0.777393 0.080767 0.904748 0.758922 ... GPT 4o mini (it) True -1.0 True 0.65 0.777302 0.35 0.775512 0.317359 0.363743

3 rows × 70 columns


Plots comparing different prompting schemes

[53]:
### Plot all models / i.t. / base?
IT_OR_BASE = "any"
# IT_OR_BASE = "it"
# IT_OR_BASE = "base"

### Plot which metric
METRIC = "ece"
# METRIC = "roc_auc"
[54]:
# Prepare data
df = all_results_df_original.copy()

min_max_per_task = {
    task: (np.min(values), np.max(values))
    for task in ALL_TASKS
    if len(values := df[(df[task_col] == task) & (~df[numeric_prompt_col].isna())][METRIC]) > 0
}

# Sort models
sorted_df_index = sorted(
    df.index.tolist(),
    key=lambda id_: model_sort_key(id_, df),
    reverse=True,
)
df = df.loc[sorted_df_index]

if IT_OR_BASE == "it":
    print("Plotting IT models")
    df = df[df["is_inst"] == True]
elif IT_OR_BASE == "base":
    print("Plotting BASE models")
    df = df[df["is_inst"] == False]
else:
    print("Plotting ALL models")
    df = df.sort_values(by=["is_inst", "model_size"], ascending=False)

# Data for multiple-choice prompting
df_mult_choice = df[df[numeric_prompt_col] == False].set_index(["name", task_col], drop=False)

# Data for numeric risk prompting
df_numeric = df[df[numeric_prompt_col] == True].set_index(["name", task_col], drop=False)

# Sorted list of model names
model_names = df_numeric["name"].unique().tolist()[::-1]
n_models = len(model_names)
print(f"{model_names=}")
Plotting ALL models
model_names=['Gemma 2B', 'Gemma 7B', 'Mistral 7B', 'Llama 3 8B', 'Yi 34B', 'Mixtral 8x7B', 'Llama 3 70B', 'Mixtral 8x22B', 'Gemma 2B (it)', 'Gemma 7B (it)', 'Mistral 7B (it)', 'Llama 3 8B (it)', 'Yi 34B (it)', 'Mixtral 8x7B (it)', 'Llama 3 70B (it)', 'Mixtral 8x22B (it)', 'GPT 4o mini (it)']
[55]:
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch

# Create a figure and axis
fig, axes = plt.subplots(ncols=len(ALL_TASKS), figsize=(8, n_models / 3), sharey=True)

for task_i, task in enumerate(ALL_TASKS):

    # Get current plot axis
    ax = axes[task_i]

    # Plot arrows for each model
    for i, model_name in enumerate(model_names):

        curr_mult_choice_metric = df_mult_choice.loc[model_name, task][METRIC]
        curr_numeric_metric = df_numeric.loc[model_name, task][METRIC]

        ax.scatter(
            curr_mult_choice_metric, i,
            color=palette[0], marker="s",
            label='Multiple-choice' if i == 0 else "",
        )
        ax.scatter(
            curr_numeric_metric, i,
            color=palette[1], marker="o",
            label='Numeric' if i == 0 else "",
        )

        # Draw a FancyArrowPatch arrow from the multiple-choice result to the numeric-prompting result
        if METRIC == "ece":
            curr_metric_improved = curr_numeric_metric <= curr_mult_choice_metric
        else:
            curr_metric_improved = curr_numeric_metric >= curr_mult_choice_metric

        arrow = FancyArrowPatch(
            (curr_mult_choice_metric, i), (curr_numeric_metric, i),
            arrowstyle='->',
            color=palette[2] if curr_metric_improved else palette[3],
            mutation_scale=15,
            lw=1.5,
        )
        ax.add_patch(arrow)

        # Set axes ranges according to overall min/max values
        lo, hi = min_max_per_task[task]
        if METRIC == "ece":
            lo = max(lo - 5e-2, 0)
            hi = min(hi + 5e-2, 0.475)
        elif METRIC == "roc_auc":
            lo = max(lo - 5e-2, 0.475)
            hi = min(hi + 4e-2, 1)

        if task == "ACSIncome" and METRIC == "roc_auc":
            lo = 0.58
        ax.set_xlim(lo, hi)

    # Set y-ticks and y-labels
    ax.set_yticks(range(n_models))
    ax.set_yticklabels(model_names)

    # # Prettify x-axis ticks
    # xticks = ax.get_xticks()
    # ax.set_xticks(xticks, [f"{x}" for x in xticks])

    # Add labels and title
    # ax.set_xlabel('ECE')
    ax.set_title(task[3:])

    # Add horizontal line separating instruct and base models
    if IT_OR_BASE == "any":
        n_instruct_models = len(df_mult_choice[(df_mult_choice["is_inst"] == True) & (df_mult_choice[task_col] == task)])
        line_y = n_instruct_models - 1.5
        ax.axhline(line_y, color='black', linestyle='--', linewidth=1) #, label="instruct / base")

        # Add text label near the horizontal line
        if task_i > len(ALL_TASKS) - 3:
            text_kwargs = dict(
                horizontalalignment='right',
                transform=ax.get_yaxis_transform(),  # Use axis coordinates for y-axis
                fontsize=9, color='black',
            )

            ax.text(1, line_y + 0.1, 'i.t.', verticalalignment='bottom', **text_kwargs)
            ax.text(1, line_y - 0.1, 'base', verticalalignment='top', **text_kwargs)



# Add legend for marker
axes[0].legend(loc="upper right", bbox_to_anchor=(-0.15, 0.01), title="Prompting scheme")

# X-axis label
if METRIC == "ece":
    x_axis_label = "Expected Calibration Error (ECE)"
elif METRIC == "roc_auc":
    x_axis_label = "Predictive Power (ROC AUC)"

fig.text(0.5, 0.02 if IT_OR_BASE == "any" else -0.04, x_axis_label, horizontalalignment="center", verticalalignment="center")

# Save figure
save_fig(fig=fig, name=f"{METRIC}_comparison_diff_prompts.{IT_OR_BASE}-models.pdf", add_prompt_suffix=False)

# Show plot
plt.show()
Saved figure to '../results/imgs/ece_comparison_diff_prompts.any-models.pdf'

Scatter plot of ECE vs AUC for a few models (multiple-choice vs. numeric prompting)

[56]:
X_METRIC = "ece"
Y_METRIC = "roc_auc"

# Filter out small models
df_mult_choice = df_mult_choice[df_mult_choice["model_size"] > 50]
df_numeric = df_numeric[df_numeric["model_size"] > 50]

fig, axes = plt.subplots(ncols=len(ALL_TASKS), figsize=(len(ALL_TASKS) * 2.6, 2.2), gridspec_kw=dict(wspace=0.4))

unique_model_names = df_mult_choice.name.unique()
# NOTE: zip() silently truncates if there are more model names than marker symbols
markers = dict(zip(unique_model_names, ["s", "o", "D", "P", "X", "<", ">"]))

for task_i, task in enumerate(ALL_TASKS):

    # Create a figure and axis
    # fig, ax = plt.subplots(figsize=(2.5, 2.5))

    ax = axes[task_i]

    # Plot arrows for each model
    for i, model_name in enumerate(df_mult_choice.name.unique()):

        # Get data for each plotting scheme
        curr_mult_choice_row = df_mult_choice.loc[model_name, task]
        curr_numeric_row = df_numeric.loc[model_name, task]

        # Get points to plot
        curr_mult_choice_point = (curr_mult_choice_row[X_METRIC], curr_mult_choice_row[Y_METRIC])
        curr_numeric_point = (curr_numeric_row[X_METRIC], curr_numeric_row[Y_METRIC])

        # Plot point for multiple-choice prompting
        ax.scatter(
            *curr_mult_choice_point,
            color=palette[0], marker=markers[model_name],
            label='Multiple-choice' if i == 0 else "",
        )

        # Plot point for numeric-style prompting
        ax.scatter(
            *curr_numeric_point,
            color=palette[1], marker=markers[model_name],
            label='Numeric' if i == 0 else "",
        )

        # Color the arrow by whether numeric prompting improves both metrics (lower ECE and
        # higher AUC), hurts both, or gives a mixed result (grey)
        if curr_numeric_point[0] <= curr_mult_choice_point[0] and curr_numeric_point[1] >= curr_mult_choice_point[1]:
            arrow_color = palette[2]
        elif curr_numeric_point[0] > curr_mult_choice_point[0] and curr_numeric_point[1] < curr_mult_choice_point[1]:
            arrow_color = palette[3]
        else:
            arrow_color = "grey"

        # Draw a FancyArrowPatch arrow from the multiple-choice point to the numeric-prompting point
        arrow = FancyArrowPatch(
            curr_mult_choice_point, curr_numeric_point,
            arrowstyle='->',
            color=arrow_color,
            mutation_scale=15,
            lw=1,
            ls="--",
        )
        ax.add_patch(arrow)

    ### Plot configs
    # Add labels and title
    ax.set_title(task[3:])
    ax.set_xlabel("ECE")
    if task_i == 0:
        ax.set_ylabel("AUC")

# Create custom global legend
from matplotlib.lines import Line2D

# Create a global legend for markers (model names)
marker_handles = [Line2D([0], [0], marker=markers[model], color='w', markerfacecolor='black', markersize=8, label=model) for model in unique_model_names]

# Create a global legend for colors (prompting schemes)
color_handles = [
    Line2D([0], [0], marker='o', color='w', markerfacecolor=palette[0], markersize=8, label="Multiple-choice"),
    Line2D([0], [0], marker='o', color='w', markerfacecolor=palette[1], markersize=8, label="Numeric")
]

# Position the legends outside the subplots
fig.legend(handles=marker_handles, title="Models", loc='upper left', bbox_to_anchor=(0.13, -0.12), ncol=min(4, len(unique_model_names)))
fig.legend(handles=color_handles, title="Prompting Scheme", loc='upper right', bbox_to_anchor=(0.9, -0.12), ncol=2)

# Save figure
save_fig(fig=fig, name=f"comparison_diff_prompts.{IT_OR_BASE}-models.pdf", add_prompt_suffix=False)

plt.show()
Saved figure to '../results/imgs/comparison_diff_prompts.any-models.pdf'