Example: Run ACS benchmark task

import folktexts
from pathlib import Path
import torch
import numpy as np
import pandas as pd
import logging

Set important local paths

NOTE: The model paths can be ignored if you haven’t previously downloaded the model; in that case, pass the model’s Hugging Face name directly to load_model_tokenizer.

Set your root directory (change as appropriate):

ROOT_DIR = Path("/fast/groups/sf")

Directory where LLMs are saved (change as appropriate):

MODELS_DIR = ROOT_DIR / "huggingface-models"

Directory where data is saved or will be saved to (change as appropriate):

DATA_DIR = ROOT_DIR / "data"

Other configs:

MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
# MODEL_NAME = "google/gemma-2b"    # Smaller model that is faster to run

# Name of the ACS prediction task used throughout this example
TASK_NAME = "ACSIncome"

RESULTS_ROOT_DIR = ROOT_DIR / "folktexts-results"
from folktexts.llm_utils import load_model_tokenizer, get_model_folder_path
model_folder_path = get_model_folder_path(model_name=MODEL_NAME, root_dir=MODELS_DIR)
model, tokenizer = load_model_tokenizer(model_folder_path)
INFO:root:Loading model '/lustre/fast/fast/groups/sf/huggingface-models/meta-llama--Meta-Llama-3-8B-Instruct'
INFO:accelerate.utils.modeling:We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).
INFO:root:Moving model to device: cuda
results_dir = RESULTS_ROOT_DIR / Path(model_folder_path).name
results_dir.mkdir(exist_ok=True, parents=True)
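The log above shows the Hub name "meta-llama/Meta-Llama-3-8B-Instruct" stored in a folder named "meta-llama--Meta-Llama-3-8B-Instruct". A minimal sketch of that naming convention follows, assuming it is a plain replacement of "/" with "--". The helper hub_name_to_folder is hypothetical and is not the folktexts implementation of get_model_folder_path.

```python
from pathlib import Path


def hub_name_to_folder(model_name: str, root_dir: Path) -> Path:
    """Hypothetical helper: map a Hugging Face model name to a local folder."""
    # Assumption: "/" in the Hub name becomes "--" in the folder name,
    # which matches the path printed in the log above.
    return Path(root_dir) / model_name.replace("/", "--")


folder = hub_name_to_folder(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    Path("/fast/groups/sf/huggingface-models"),
)
print(folder.name)  # meta-llama--Meta-Llama-3-8B-Instruct
```

Reusing the folder name for the results directory, as done above, keeps each model's results in their own subfolder.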

Construct LLM Classifier

NOTE: Also compatible with models hosted through a web API by using the WebAPILLMClassifier class instead of TransformersLLMClassifier.

Load prediction task (which maps tabular data to text):

from folktexts.acs import ACSTaskMetadata
task = ACSTaskMetadata.get_task(TASK_NAME, use_numeric_qa=False)
INFO:root:Changing Q&A mode for task 'ACSIncome' to multiple-choice.
from folktexts.classifier import TransformersLLMClassifier
llm_clf = TransformersLLMClassifier(
    model=model,
    tokenizer=tokenizer,
    task=task,
)

Load Dataset

from folktexts.acs import ACSDataset
dataset = ACSDataset.make_from_task(task=task, cache_dir=DATA_DIR)
Loading ACS data...
CPU times: user 58 s, sys: 27.5 s, total: 1min 25s
Wall time: 1min 26s

Optionally, subsample to quickly get approximate results:

dataset.subsample(0.01)
INFO:root:Train size: 13316, Test size:  1665, Val size:   1665;

Load and run ACS Benchmark

Note: Helper constructors exist at Benchmark.make_acs_benchmark and Benchmark.make_benchmark that avoid the above boilerplate code.

from folktexts.benchmark import BenchmarkConfig, Benchmark

bench = Benchmark(llm_clf=llm_clf, dataset=dataset)
** Benchmark initialization **
Model: meta-llama--Meta-Llama-3-8B-Instruct;
Task: ACSIncome;
Hash: 516109582;

Here’s an example prompt for the current prediction task:

X_sample, _y_sample = dataset.sample_n_train_examples(n=1)
print(llm_clf.encode_row(X_sample.iloc[0], question=llm_clf.task.question))
The following data corresponds to a survey respondent. The survey was conducted among US residents in 2018. Please answer the question based on the information provided. The data provided is enough to reach an approximate answer.

- The age is: 37 years old.
- The class of worker is: Owner of non-incorporated business, professional practice, or farm.
- The highest educational attainment is: Regular high school diploma.
- The marital status is: Married.
- The occupation is: Painters and paperhangers.
- The place of birth is: New Jersey.
- The relationship to the reference person in the survey is: The reference person itself.
- The usual number of hours worked per week is: 40 hours.
- The sex is: Male.
- The race is: White.

Question: What is this person's estimated yearly income?
A. Below $50,000.
B. Above $50,000.
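To make the mapping from a tabular row to a prompt like the one above more concrete, here is a rough sketch. Everything in it is illustrative: the function row_to_prompt and the templates dictionary are hypothetical and only mimic the printed format; folktexts handles this internally through the task's column-to-text mappings.

```python
def row_to_prompt(row: dict, templates: dict, question: str, choices: list) -> str:
    """Hypothetical sketch: render one tabular row as a multiple-choice prompt."""
    # One bullet per feature, using a per-column sentence template.
    lines = [f"- {templates[col].format(value=value)}" for col, value in row.items()]
    # Label the answer options A, B, C, ...
    options = [f"{letter}. {text}" for letter, text in zip("ABCDEFGH", choices)]
    return "\n".join(lines + ["", f"Question: {question}"] + options)


# Assumed templates for three of the ACSIncome columns (illustrative only).
templates = {
    "AGEP": "The age is: {value} years old.",
    "WKHP": "The usual number of hours worked per week is: {value} hours.",
    "SEX": "The sex is: {value}.",
}
row = {"AGEP": 37, "WKHP": 40, "SEX": "Male"}
prompt = row_to_prompt(
    row,
    templates,
    question="What is this person's estimated yearly income?",
    choices=["Below $50,000.", "Above $50,000."],
)
print(prompt)
```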

Optionally, you can fit the model’s threshold on a few data samples.

This is generally quite fast because it is not fine-tuning: it changes a single parameter, llm_clf.threshold.

X_sample, y_sample = dataset.sample_n_train_examples(n=50)
llm_clf.fit(X_sample, y_sample)
We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)
WARNING:root:Setting meta-llama--Meta-Llama-3-8B-Instruct threshold to 0.7918923636748516.
CPU times: user 7.57 s, sys: 382 ms, total: 7.95 s
Wall time: 9 s
TransformersLLMClassifier(encode_row=functools.partial(<function encode_row_prompt at 0x14802eedf600>, task=ACSTaskMetadata(name='ACSIncome', description="predict whether an individual's income is above $50,000", features=['AGEP', 'COW', 'SCHL', 'MAR', 'OCCP', 'POBP', 'RELP', 'WKHP', 'SEX', 'RAC1P'], target='PINCP', cols_to_text={'AGEP': <folktexts.col_to_text.ColumnToText...
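Since fitting only moves the decision threshold, the idea can be sketched in a few lines. This is a simplified illustration, assuming the threshold is chosen to maximize accuracy on the labeled sample; the criterion folktexts actually uses may differ, and fit_threshold is a hypothetical name.

```python
def fit_threshold(scores, labels):
    """Hypothetical sketch: pick the score cut-off with the highest accuracy."""
    best_threshold, best_accuracy = None, -1.0
    # Candidate cut-offs: every distinct observed score, in increasing order.
    for t in sorted(set(scores)):
        # A sample is predicted positive when its score is >= the cut-off.
        correct = sum((s >= t) == bool(y) for s, y in zip(scores, labels))
        accuracy = correct / len(labels)
        if accuracy > best_accuracy:  # ties keep the lowest cut-off
            best_threshold, best_accuracy = t, accuracy
    return best_threshold


# Toy risk scores and binary labels (made up for illustration).
scores = [0.55, 0.60, 0.70, 0.80, 0.85, 0.90]
labels = [0, 0, 0, 1, 1, 1]
threshold = fit_threshold(scores, labels)
print(threshold)  # 0.8
```

The model's scores stay the same; only the cut-off that turns a score into a class label moves, which is why a handful of samples can be enough.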
Run the benchmark:

bench.run(results_root_dir=results_dir)

INFO:root:Test data features shape: (1665, 10)
** Test results **
Model: meta-llama--Meta-Llama-3-8B-Instruct;
         ECE:       24.8%;
         ROC AUC :  82.8%;
         Accuracy:  73.8%;
         Bal. acc.: 74.1%;

INFO:root:Skipping group 'American Indian' as it's too small.
INFO:root:Skipping group 'Alaska Native' as it's too small.
INFO:root:Skipping group 'American Indian and Alaska Native tribes specified; or American Indian or Alaska Native, not specified and no other races' as it's too small.
INFO:root:Skipping group 'Native Hawaiian and Other Pacific Islander' as it's too small.
INFO:root:Skipping group 'Some other race alone (non-White)' as it's too small.
INFO:root:Skipping group 'Two or more races' as it's too small.
INFO:root:Skipping group American Indian plot as it's too small.
INFO:root:Skipping group Alaska Native plot as it's too small.
INFO:root:Skipping group American Indian and Alaska Native tribes specified; or American Indian or Alaska Native, not specified and no other races plot as it's too small.
INFO:root:Skipping group Native Hawaiian and Other Pacific Islander plot as it's too small.
INFO:root:Skipping group Some other race alone (non-White) plot as it's too small.
INFO:root:Skipping group Two or more races plot as it's too small.
INFO:root:Saving JSON file to '/lustre/fast/fast/groups/sf/folktexts-results/meta-llama--Meta-Llama-3-8B-Instruct/meta-llama--Meta-Llama-3-8B-Instruct_bench-3191680725/results.bench-3191680725.json'
INFO:root:Saved experiment results to '/lustre/fast/fast/groups/sf/folktexts-results/meta-llama--Meta-Llama-3-8B-Instruct/meta-llama--Meta-Llama-3-8B-Instruct_bench-3191680725/results.bench-3191680725.json'
CPU times: user 3min 51s, sys: 7.12 s, total: 3min 58s
Wall time: 4min 23s
{'threshold': np.float64(0.7918923636748516),
 'n_samples': 1665,
 'n_positives': 605,
 'n_negatives': 1060,
 'model_name': 'meta-llama--Meta-Llama-3-8B-Instruct',
 'accuracy': 0.7375375375375376,
 'tpr': 0.7553719008264462,
 'fnr': 0.24462809917355371,
 'fpr': 0.27264150943396226,
 'tnr': 0.7273584905660377,
 'balanced_accuracy': 0.7413651956962419,
 'precision': 0.6126005361930295,
 'ppr': 0.44804804804804804,
 'log_loss': 0.8504629791262521,
 'brier_score_loss': np.float64(0.24624856734149014),
 'ppr_ratio': 0.0,
 'ppr_diff': 0.5961538461538461,
 'precision_ratio': 0.0,
 'precision_diff': 1.0,
 'balanced_accuracy_ratio': 0.0,
 'balanced_accuracy_diff': 1.0,
 'accuracy_ratio': 0.0,
 'accuracy_diff': 1.0,
 'tnr_ratio': 0.0,
 'tnr_diff': 1.0,
 'fnr_ratio': 0.0,
 'fnr_diff': 1.0,
 'fpr_ratio': 0.0,
 'fpr_diff': 1.0,
 'tpr_ratio': 0.0,
 'tpr_diff': 1.0,
 'equalized_odds_ratio': 0.0,
 'equalized_odds_diff': 1.0,
 'roc_auc': np.float64(0.8282707001403399),
 'ece': 0.2475394044687882,
 'ece_quantile': 0.24753940446878786,
 'predictions_path': '/lustre/fast/fast/groups/sf/folktexts-results/meta-llama--Meta-Llama-3-8B-Instruct/meta-llama--Meta-Llama-3-8B-Instruct_bench-3191680725/ACSIncome_subsampled-0.01_seed-42_hash-373553086.test_predictions.csv',
 'config': {'numeric_risk_prompting': False,
  'few_shot': None,
  'reuse_few_shot_examples': False,
  'batch_size': None,
  'context_size': None,
  'correct_order_bias': True,
  'feature_subset': None,
  'population_filter': None,
  'seed': 42,
  'model_name': 'meta-llama--Meta-Llama-3-8B-Instruct',
  'model_hash': 1896439497,
  'task_name': 'ACSIncome',
  'task_hash': 662640833,
  'dataset_name': 'ACSIncome_subsampled-0.01_seed-42_hash-373553086',
  'dataset_subsampling': 0.01,
  'dataset_hash': 373553086},
 'plots': {'roc_curve_path': '/lustre/fast/fast/groups/sf/folktexts-results/meta-llama--Meta-Llama-3-8B-Instruct/meta-llama--Meta-Llama-3-8B-Instruct_bench-3191680725/imgs/roc_curve.pdf',
  'calibration_curve_path': '/lustre/fast/fast/groups/sf/folktexts-results/meta-llama--Meta-Llama-3-8B-Instruct/meta-llama--Meta-Llama-3-8B-Instruct_bench-3191680725/imgs/calibration_curve.pdf',
  'score_distribution_path': '/lustre/fast/fast/groups/sf/folktexts-results/meta-llama--Meta-Llama-3-8B-Instruct/meta-llama--Meta-Llama-3-8B-Instruct_bench-3191680725/imgs/score_distribution.pdf',
  'score_distribution_per_label_path': '/lustre/fast/fast/groups/sf/folktexts-results/meta-llama--Meta-Llama-3-8B-Instruct/meta-llama--Meta-Llama-3-8B-Instruct_bench-3191680725/imgs/score_distribution_per_label.pdf',
  'roc_curve_per_subgroup_path': '/lustre/fast/fast/groups/sf/folktexts-results/meta-llama--Meta-Llama-3-8B-Instruct/meta-llama--Meta-Llama-3-8B-Instruct_bench-3191680725/imgs/roc_curve_per_subgroup.pdf',
  'calibration_curve_per_subgroup_path': '/lustre/fast/fast/groups/sf/folktexts-results/meta-llama--Meta-Llama-3-8B-Instruct/meta-llama--Meta-Llama-3-8B-Instruct_bench-3191680725/imgs/calibration_curve_per_subgroup.pdf'}}
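As a sanity check on how the headline metrics relate to each other, the arithmetic below recomputes several of them from confusion-matrix counts. Note the assumption: the true-positive and false-positive counts (457 and 289) are not reported above; they are back-derived here from the reported tpr and fpr together with n_positives and n_negatives.

```python
# Counts reported in the results above.
n_pos, n_neg = 605, 1060

# Assumption: back-derived from the reported tpr and fpr, not reported directly.
tp, fp = 457, 289
fn, tn = n_pos - tp, n_neg - fp

tpr = tp / n_pos                           # true positive rate (recall)
fpr = fp / n_neg                           # false positive rate
tnr = tn / n_neg                           # true negative rate
fnr = fn / n_pos                           # false negative rate
accuracy = (tp + tn) / (n_pos + n_neg)
balanced_accuracy = (tpr + tnr) / 2
precision = tp / (tp + fp)
ppr = (tp + fp) / (n_pos + n_neg)          # share of samples predicted positive

print(f"accuracy={accuracy:.4f} balanced_accuracy={balanced_accuracy:.4f}")
print(f"precision={precision:.4f} ppr={ppr:.4f}")
```

The recomputed values agree with the reported 'accuracy', 'balanced_accuracy', 'precision', and 'ppr' entries.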
from pprint import pprint
pprint(bench.results, depth=1)
{'accuracy': 0.7375375375375376,
 'accuracy_diff': 1.0,
 'accuracy_ratio': 0.0,
 'balanced_accuracy': 0.7413651956962419,
 'balanced_accuracy_diff': 1.0,
 'balanced_accuracy_ratio': 0.0,
 'brier_score_loss': np.float64(0.24624856734149014),
 'config': {...},
 'ece': 0.2475394044687882,
 'ece_quantile': 0.24753940446878786,
 'equalized_odds_diff': 1.0,
 'equalized_odds_ratio': 0.0,
 'fnr': 0.24462809917355371,
 'fnr_diff': 1.0,
 'fnr_ratio': 0.0,
 'fpr': 0.27264150943396226,
 'fpr_diff': 1.0,
 'fpr_ratio': 0.0,
 'log_loss': 0.8504629791262521,
 'model_name': 'meta-llama--Meta-Llama-3-8B-Instruct',
 'n_negatives': 1060,
 'n_positives': 605,
 'n_samples': 1665,
 'plots': {...},
 'ppr': 0.44804804804804804,
 'ppr_diff': 0.5961538461538461,
 'ppr_ratio': 0.0,
 'precision': 0.6126005361930295,
 'precision_diff': 1.0,
 'precision_ratio': 0.0,
 'predictions_path': '/lustre/fast/fast/groups/sf/folktexts-results/meta-llama--Meta-Llama-3-8B-Instruct/meta-llama--Meta-Llama-3-8B-Instruct_bench-3191680725/ACSIncome_subsampled-0.01_seed-42_hash-373553086.test_predictions.csv',
 'roc_auc': np.float64(0.8282707001403399),
 'threshold': np.float64(0.7918923636748516),
 'tnr': 0.7273584905660377,
 'tnr_diff': 1.0,
 'tnr_ratio': 0.0,
 'tpr': 0.7553719008264462,
 'tpr_diff': 1.0,
 'tpr_ratio': 0.0}
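The 'ece' entry is the expected calibration error, which compares predicted risk scores with observed positive rates. A minimal sketch of the standard equal-width-bin definition follows; the function name, the default of 10 bins, and the toy data are assumptions for illustration, and the exact binning used by folktexts (including the 'ece_quantile' variant) may differ.

```python
def expected_calibration_error(scores, labels, n_bins=10):
    """Sketch of ECE with equal-width bins over [0, 1]."""
    bins = [[] for _ in range(n_bins)]
    for s, y in zip(scores, labels):
        idx = min(int(s * n_bins), n_bins - 1)  # a score of 1.0 goes in the last bin
        bins[idx].append((s, y))

    n = len(scores)
    ece = 0.0
    for members in bins:
        if not members:
            continue  # empty bins contribute nothing
        mean_score = sum(s for s, _ in members) / len(members)
        positive_rate = sum(y for _, y in members) / len(members)
        # Weight each bin's calibration gap by its share of the samples.
        ece += len(members) / n * abs(mean_score - positive_rate)
    return ece


# Toy data: scores of 0.9 are right only half the time; scores of 0.1 are never positive.
scores = [0.9, 0.9, 0.9, 0.9, 0.1, 0.1, 0.1, 0.1]
labels = [1, 1, 0, 0, 0, 0, 0, 0]
ece = expected_calibration_error(scores, labels)
print(f"{ece:.2f}")
```

In the toy data the two bins have gaps of 0.4 and 0.1, each holding half of the samples, so the weighted average is about 0.25.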