Example: Run ACS benchmark task
[1]:
import folktexts
print(f"{folktexts.__version__=}")
folktexts.__version__='0.0.21'
[2]:
from pathlib import Path
import torch
import numpy as np
import pandas as pd
[3]:
import logging
logging.getLogger().setLevel(logging.INFO)
Set important local paths
NOTE: This step can be skipped if you haven’t previously downloaded the model; in that case, pass the model’s name on Hugging Face directly to load_model_tokenizer.
Set your root directory (change as appropriate):
[4]:
ROOT_DIR = Path("/fast/groups/sf")
ROOT_DIR
[4]:
PosixPath('/fast/groups/sf')
Directory where LLMs are saved (change as appropriate):
[5]:
MODELS_DIR = ROOT_DIR / "huggingface-models"
Directory where data is saved or will be saved to (change as appropriate):
[6]:
DATA_DIR = ROOT_DIR / "data"
Other configs:
[7]:
MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
# MODEL_NAME = "google/gemma-2b" # Smaller model that is faster to run
TASK_NAME = "ACSIncome"
RESULTS_ROOT_DIR = ROOT_DIR / "folktexts-results"
[8]:
from folktexts.llm_utils import load_model_tokenizer, get_model_folder_path
model_folder_path = get_model_folder_path(model_name=MODEL_NAME, root_dir=MODELS_DIR)
model, tokenizer = load_model_tokenizer(model_folder_path)
INFO:root:Loading model '/lustre/fast/fast/groups/sf/huggingface-models/meta-llama--Meta-Llama-3-8B-Instruct'
INFO:accelerate.utils.modeling:We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).
INFO:root:Moving model to device: cuda
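The choice between a local model folder and the Hugging Face name can be made automatically. The following is a minimal sketch, not part of folktexts: resolve_model_source is a hypothetical helper, and the folder naming ("/" replaced by "--") is only inferred from the log output above.

```python
from pathlib import Path


def resolve_model_source(model_name: str, models_dir: Path) -> str:
    """Return the local model folder if it exists, else the Hub model name.

    Hypothetical helper (not a folktexts function). The folder naming scheme,
    "/" replaced by "--", is an assumption inferred from the logged path.
    """
    local_folder = Path(models_dir) / model_name.replace("/", "--")
    if local_folder.is_dir():
        return local_folder.as_posix()
    return model_name


# Intended use (requires folktexts and access to the model weights):
# model, tokenizer = load_model_tokenizer(resolve_model_source(MODEL_NAME, MODELS_DIR))
```

With this fallback the same notebook can run both on a machine with pre-downloaded weights and on a fresh one.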
[9]:
results_dir = RESULTS_ROOT_DIR / Path(model_folder_path).name
results_dir.mkdir(exist_ok=True, parents=True)
results_dir
[9]:
PosixPath('/fast/groups/sf/folktexts-results/meta-llama--Meta-Llama-3-8B-Instruct')
Construct LLM Classifier
NOTE: Models hosted through a web API are also supported: use the WebAPILLMClassifier class instead of TransformersLLMClassifier.
Load prediction task (which maps tabular data to text):
[10]:
from folktexts.acs import ACSTaskMetadata
task = ACSTaskMetadata.get_task(TASK_NAME, use_numeric_qa=False)
INFO:root:Changing Q&A mode for task 'ACSIncome' to multiple-choice.
[11]:
from folktexts.classifier import TransformersLLMClassifier
llm_clf = TransformersLLMClassifier(
    model=model,
    tokenizer=tokenizer,
    task=task,
    batch_size=32,
    context_size=1000,
)
Load Dataset
[12]:
%%time
from folktexts.acs import ACSDataset
dataset = ACSDataset.make_from_task(task=task, cache_dir=DATA_DIR)
Loading ACS data...
CPU times: user 58 s, sys: 27.5 s, total: 1min 25s
Wall time: 1min 26s
Optionally, subsample to quickly get approximate results:
[13]:
dataset.subsample(0.01)
print(f"{dataset.subsampling=}")
INFO:root:Train size: 13316, Test size: 1665, Val size: 1665;
dataset.subsampling=0.01
Load and run ACS Benchmark
Note: The helper constructors Benchmark.make_acs_benchmark and Benchmark.make_benchmark avoid the above boilerplate code.
[14]:
from folktexts.benchmark import BenchmarkConfig, Benchmark
bench = Benchmark(llm_clf=llm_clf, dataset=dataset)
INFO:root:
** Benchmark initialization **
Model: meta-llama--Meta-Llama-3-8B-Instruct;
Task: ACSIncome;
Hash: 516109582;
Here’s an example prompt for the current prediction task:
[15]:
X_sample, _y_sample = dataset.sample_n_train_examples(n=1)
print(llm_clf.encode_row(X_sample.iloc[0], question=llm_clf.task.question))
The following data corresponds to a survey respondent. The survey was conducted among US residents in 2018. Please answer the question based on the information provided. The data provided is enough to reach an approximate answer.
Information:
- The age is: 37 years old.
- The class of worker is: Owner of non-incorporated business, professional practice, or farm.
- The highest educational attainment is: Regular high school diploma.
- The marital status is: Married.
- The occupation is: Painters and paperhangers.
- The place of birth is: New Jersey.
- The relationship to the reference person in the survey is: The reference person itself.
- The usual number of hours worked per week is: 40 hours.
- The sex is: Male.
- The race is: White.
Question: What is this person's estimated yearly income?
A. Below $50,000.
B. Above $50,000.
Answer:
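To make the row-to-text mapping concrete, here is a rough sketch of how a prompt like the one above could be assembled from a tabular row. It is an illustration only: FEATURE_TEMPLATES and row_to_prompt are hypothetical names, the templates cover just three columns, and folktexts defines its own per-column mappings and preamble.

```python
# Hypothetical templates: (short description, function turning a raw value into text).
# SEX uses the ACS coding 1 = Male, 2 = Female.
FEATURE_TEMPLATES = {
    "AGEP": ("age", lambda v: f"{v} years old"),
    "WKHP": ("usual number of hours worked per week", lambda v: f"{v} hours"),
    "SEX": ("sex", lambda v: {1: "Male", 2: "Female"}[v]),
}


def row_to_prompt(row: dict, question: str, choices: list) -> str:
    """Render one row as an 'Information / Question / Answer' text prompt."""
    lines = ["Information:"]
    for column, value in row.items():
        description, value_to_text = FEATURE_TEMPLATES[column]
        lines.append(f"- The {description} is: {value_to_text(value)}.")
    lines.append("")
    lines.append(f"Question: {question}")
    # Label the answer choices A, B, C, ...
    for letter, choice in zip("ABCDEFGH", choices):
        lines.append(f"{letter}. {choice}.")
    lines.append("Answer:")
    return "\n".join(lines)


prompt = row_to_prompt(
    {"AGEP": 37, "WKHP": 40, "SEX": 1},
    question="What is this person's estimated yearly income?",
    choices=["Below $50,000", "Above $50,000"],
)
print(prompt)
```

Ending the prompt with "Answer:" means the model's next-token distribution over the choice letters can be read as its answer.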
Optionally, you can fit the model’s threshold on a few data samples.
This is generally quite fast because it is not fine-tuning; it only changes one parameter, llm_clf.threshold.
[16]:
%%time
X_sample, y_sample = dataset.sample_n_train_examples(n=50)
llm_clf.fit(X_sample, y_sample)
We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)
WARNING:root:Setting meta-llama--Meta-Llama-3-8B-Instruct threshold to 0.7918923636748516.
CPU times: user 7.57 s, sys: 382 ms, total: 7.95 s
Wall time: 9 s
[16]:
TransformersLLMClassifier(encode_row=functools.partial(<function encode_row_prompt at 0x14802eedf600>, task=ACSTaskMetadata(name='ACSIncome', description="predict whether an individual's income is above $50,000", features=['AGEP', 'COW', 'SCHL', 'MAR', 'OCCP', 'POBP', 'RELP', 'WKHP', 'SEX', 'RAC1P'], target='PINCP', cols_to_text={'AGEP': <folktexts.col_to_text.ColumnToText... 128253: AddedToken("<|reserved_special_token_248|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), 128254: AddedToken("<|reserved_special_token_249|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), 128255: AddedToken("<|reserved_special_token_250|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), })
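To illustrate why fitting a single threshold is cheap, the sketch below picks a cut-off from a handful of scored examples. It is an assumption-laden illustration: fit_threshold is a hypothetical function that maximizes accuracy on the sample, and the criterion folktexts actually uses inside llm_clf.fit may differ.

```python
def fit_threshold(scores, labels):
    """Return the cut-off that maximizes accuracy on a small labeled sample.

    A row is predicted positive when its score is >= the threshold.
    Hypothetical criterion, chosen for illustration only.
    """
    best_threshold, best_accuracy = 0.5, -1.0
    # Only observed scores need to be tried as candidate cut-offs.
    for candidate in sorted(set(scores)):
        correct = sum((s >= candidate) == bool(y) for s, y in zip(scores, labels))
        accuracy = correct / len(labels)
        if accuracy > best_accuracy:
            best_threshold, best_accuracy = candidate, accuracy
    return best_threshold


# Toy scores that skew high, so a 0.5 cut-off would label everything positive.
scores = [0.55, 0.60, 0.70, 0.85, 0.90, 0.95]
labels = [0, 0, 0, 1, 1, 1]
threshold = fit_threshold(scores, labels)
print(threshold)  # 0.85
```

No model weights change here: the scores are computed once and only the decision rule on top of them is adjusted.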
Run benchmark…
[17]:
%%time
bench.run(results_root_dir=results_dir)
INFO:root:Test data features shape: (1665, 10)
INFO:root:
** Test results **
Model: meta-llama--Meta-Llama-3-8B-Instruct;
ECE: 24.8%;
ROC AUC : 82.8%;
Accuracy: 73.8%;
Bal. acc.: 74.1%;
INFO:root:Skipping group 'American Indian' as it's too small.
INFO:root:Skipping group 'Alaska Native' as it's too small.
INFO:root:Skipping group 'American Indian and Alaska Native tribes specified; or American Indian or Alaska Native, not specified and no other races' as it's too small.
INFO:root:Skipping group 'Native Hawaiian and Other Pacific Islander' as it's too small.
INFO:root:Skipping group 'Some other race alone (non-White)' as it's too small.
INFO:root:Skipping group 'Two or more races' as it's too small.
INFO:root:Skipping group American Indian plot as it's too small.
INFO:root:Skipping group Alaska Native plot as it's too small.
INFO:root:Skipping group American Indian and Alaska Native tribes specified; or American Indian or Alaska Native, not specified and no other races plot as it's too small.
INFO:root:Skipping group Native Hawaiian and Other Pacific Islander plot as it's too small.
INFO:root:Skipping group Some other race alone (non-White) plot as it's too small.
INFO:root:Skipping group Two or more races plot as it's too small.
INFO:root:Saving JSON file to '/lustre/fast/fast/groups/sf/folktexts-results/meta-llama--Meta-Llama-3-8B-Instruct/meta-llama--Meta-Llama-3-8B-Instruct_bench-3191680725/results.bench-3191680725.json'
INFO:root:Saved experiment results to '/lustre/fast/fast/groups/sf/folktexts-results/meta-llama--Meta-Llama-3-8B-Instruct/meta-llama--Meta-Llama-3-8B-Instruct_bench-3191680725/results.bench-3191680725.json'
CPU times: user 3min 51s, sys: 7.12 s, total: 3min 58s
Wall time: 4min 23s
[17]:
{'threshold': np.float64(0.7918923636748516),
'n_samples': 1665,
'n_positives': 605,
'n_negatives': 1060,
'model_name': 'meta-llama--Meta-Llama-3-8B-Instruct',
'accuracy': 0.7375375375375376,
'tpr': 0.7553719008264462,
'fnr': 0.24462809917355371,
'fpr': 0.27264150943396226,
'tnr': 0.7273584905660377,
'balanced_accuracy': 0.7413651956962419,
'precision': 0.6126005361930295,
'ppr': 0.44804804804804804,
'log_loss': 0.8504629791262521,
'brier_score_loss': np.float64(0.24624856734149014),
'ppr_ratio': 0.0,
'ppr_diff': 0.5961538461538461,
'precision_ratio': 0.0,
'precision_diff': 1.0,
'balanced_accuracy_ratio': 0.0,
'balanced_accuracy_diff': 1.0,
'accuracy_ratio': 0.0,
'accuracy_diff': 1.0,
'tnr_ratio': 0.0,
'tnr_diff': 1.0,
'fnr_ratio': 0.0,
'fnr_diff': 1.0,
'fpr_ratio': 0.0,
'fpr_diff': 1.0,
'tpr_ratio': 0.0,
'tpr_diff': 1.0,
'equalized_odds_ratio': 0.0,
'equalized_odds_diff': 1.0,
'roc_auc': np.float64(0.8282707001403399),
'ece': 0.2475394044687882,
'ece_quantile': 0.24753940446878786,
'predictions_path': '/lustre/fast/fast/groups/sf/folktexts-results/meta-llama--Meta-Llama-3-8B-Instruct/meta-llama--Meta-Llama-3-8B-Instruct_bench-3191680725/ACSIncome_subsampled-0.01_seed-42_hash-373553086.test_predictions.csv',
'config': {'numeric_risk_prompting': False,
'few_shot': None,
'reuse_few_shot_examples': False,
'batch_size': None,
'context_size': None,
'correct_order_bias': True,
'feature_subset': None,
'population_filter': None,
'seed': 42,
'model_name': 'meta-llama--Meta-Llama-3-8B-Instruct',
'model_hash': 1896439497,
'task_name': 'ACSIncome',
'task_hash': 662640833,
'dataset_name': 'ACSIncome_subsampled-0.01_seed-42_hash-373553086',
'dataset_subsampling': 0.01,
'dataset_hash': 373553086},
'plots': {'roc_curve_path': '/lustre/fast/fast/groups/sf/folktexts-results/meta-llama--Meta-Llama-3-8B-Instruct/meta-llama--Meta-Llama-3-8B-Instruct_bench-3191680725/imgs/roc_curve.pdf',
'calibration_curve_path': '/lustre/fast/fast/groups/sf/folktexts-results/meta-llama--Meta-Llama-3-8B-Instruct/meta-llama--Meta-Llama-3-8B-Instruct_bench-3191680725/imgs/calibration_curve.pdf',
'score_distribution_path': '/lustre/fast/fast/groups/sf/folktexts-results/meta-llama--Meta-Llama-3-8B-Instruct/meta-llama--Meta-Llama-3-8B-Instruct_bench-3191680725/imgs/score_distribution.pdf',
'score_distribution_per_label_path': '/lustre/fast/fast/groups/sf/folktexts-results/meta-llama--Meta-Llama-3-8B-Instruct/meta-llama--Meta-Llama-3-8B-Instruct_bench-3191680725/imgs/score_distribution_per_label.pdf',
'roc_curve_per_subgroup_path': '/lustre/fast/fast/groups/sf/folktexts-results/meta-llama--Meta-Llama-3-8B-Instruct/meta-llama--Meta-Llama-3-8B-Instruct_bench-3191680725/imgs/roc_curve_per_subgroup.pdf',
'calibration_curve_per_subgroup_path': '/lustre/fast/fast/groups/sf/folktexts-results/meta-llama--Meta-Llama-3-8B-Instruct/meta-llama--Meta-Llama-3-8B-Instruct_bench-3191680725/imgs/calibration_curve_per_subgroup.pdf'}}
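The results are also written to disk (see the "Saving JSON file" log line above), so they can be reloaded later without re-running the model. A minimal sketch, assuming the saved file follows the results.bench-*.json naming seen in the logs and holds a JSON object; load_latest_results is a hypothetical helper, not a folktexts function.

```python
import json
from pathlib import Path


def load_latest_results(results_dir: Path) -> dict:
    """Load the most recently modified results JSON found below `results_dir`."""
    candidates = sorted(
        Path(results_dir).glob("**/results.bench-*.json"),
        key=lambda path: path.stat().st_mtime,
    )
    if not candidates:
        raise FileNotFoundError(f"No results file found under {results_dir}")
    with candidates[-1].open() as f:
        return json.load(f)


# Intended use with the directory defined earlier in this notebook:
# results = load_latest_results(results_dir)
# print(results["accuracy"])
```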
[18]:
bench.plot_results();
INFO:root:Skipping group 'American Indian' as it's too small.
INFO:root:Skipping group 'Alaska Native' as it's too small.
INFO:root:Skipping group 'American Indian and Alaska Native tribes specified; or American Indian or Alaska Native, not specified and no other races' as it's too small.
INFO:root:Skipping group 'Native Hawaiian and Other Pacific Islander' as it's too small.
INFO:root:Skipping group 'Some other race alone (non-White)' as it's too small.
INFO:root:Skipping group 'Two or more races' as it's too small.
INFO:root:Skipping group American Indian plot as it's too small.
INFO:root:Skipping group Alaska Native plot as it's too small.
INFO:root:Skipping group American Indian and Alaska Native tribes specified; or American Indian or Alaska Native, not specified and no other races plot as it's too small.
INFO:root:Skipping group Native Hawaiian and Other Pacific Islander plot as it's too small.
INFO:root:Skipping group Some other race alone (non-White) plot as it's too small.
INFO:root:Skipping group Two or more races plot as it's too small.
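The calibration curve is summarized by the ece value in the results. As a rough guide to what that number measures, here is the standard equal-width-bin definition of expected calibration error. This is a sketch under assumptions: the number of bins and other details of the folktexts implementation are not shown in this notebook and may differ.

```python
def expected_calibration_error(scores, labels, n_bins=10):
    """Equal-width-bin ECE: weighted mean of |mean label - mean score| per bin."""
    bins = [[] for _ in range(n_bins)]
    for score, label in zip(scores, labels):
        index = min(int(score * n_bins), n_bins - 1)  # a score of 1.0 goes in the last bin
        bins[index].append((score, label))

    n_total = len(scores)
    ece = 0.0
    for members in bins:
        if not members:
            continue
        mean_score = sum(s for s, _ in members) / len(members)
        mean_label = sum(y for _, y in members) / len(members)
        ece += len(members) / n_total * abs(mean_label - mean_score)
    return ece


# Overconfident toy scores: the model says 0.9 but only half the rows are positive.
print(expected_calibration_error([0.9, 0.9, 0.9, 0.9], [1, 1, 0, 0]))  # roughly 0.4
```

Read this way, the ECE of about 0.25 reported above means the predicted scores are off from the observed positive rates by roughly 25 percentage points on average.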
[19]:
from pprint import pprint
pprint(bench.results, depth=1)
{'accuracy': 0.7375375375375376,
'accuracy_diff': 1.0,
'accuracy_ratio': 0.0,
'balanced_accuracy': 0.7413651956962419,
'balanced_accuracy_diff': 1.0,
'balanced_accuracy_ratio': 0.0,
'brier_score_loss': np.float64(0.24624856734149014),
'config': {...},
'ece': 0.2475394044687882,
'ece_quantile': 0.24753940446878786,
'equalized_odds_diff': 1.0,
'equalized_odds_ratio': 0.0,
'fnr': 0.24462809917355371,
'fnr_diff': 1.0,
'fnr_ratio': 0.0,
'fpr': 0.27264150943396226,
'fpr_diff': 1.0,
'fpr_ratio': 0.0,
'log_loss': 0.8504629791262521,
'model_name': 'meta-llama--Meta-Llama-3-8B-Instruct',
'n_negatives': 1060,
'n_positives': 605,
'n_samples': 1665,
'plots': {...},
'ppr': 0.44804804804804804,
'ppr_diff': 0.5961538461538461,
'ppr_ratio': 0.0,
'precision': 0.6126005361930295,
'precision_diff': 1.0,
'precision_ratio': 0.0,
'predictions_path': '/lustre/fast/fast/groups/sf/folktexts-results/meta-llama--Meta-Llama-3-8B-Instruct/meta-llama--Meta-Llama-3-8B-Instruct_bench-3191680725/ACSIncome_subsampled-0.01_seed-42_hash-373553086.test_predictions.csv',
'roc_auc': np.float64(0.8282707001403399),
'threshold': np.float64(0.7918923636748516),
'tnr': 0.7273584905660377,
'tnr_diff': 1.0,
'tnr_ratio': 0.0,
'tpr': 0.7553719008264462,
'tpr_diff': 1.0,
'tpr_ratio': 0.0}
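As a sanity check, the headline metrics can be reproduced from the reported counts and rates alone. The sketch below rebuilds the confusion matrix from n_positives, n_negatives, tpr and tnr as printed above, then recomputes the derived metrics using their standard definitions.

```python
# Values copied from the results above.
n_positives, n_negatives = 605, 1060
tpr, tnr = 0.7553719008264462, 0.7273584905660377

# Recover the confusion-matrix counts from the rates.
tp = round(tpr * n_positives)
tn = round(tnr * n_negatives)
fn = n_positives - tp
fp = n_negatives - tn

n_samples = n_positives + n_negatives
accuracy = (tp + tn) / n_samples
balanced_accuracy = (tpr + tnr) / 2
precision = tp / (tp + fp)
ppr = (tp + fp) / n_samples  # positive prediction rate

print(tp, fn, fp, tn)  # 457 148 289 771
print(accuracy, balanced_accuracy, precision, ppr)
```

The recomputed values agree with the accuracy, balanced_accuracy, precision and ppr entries reported above.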