{ "cells": [ { "cell_type": "markdown", "id": "e85e7f8a-dcfa-489a-9a8e-77225d2e3499", "metadata": {}, "source": [ "# Example: ACS benchmark with a web-API model (gpt-5-mini)" ] }, { "cell_type": "code", "execution_count": 1, "id": "98180e30-da85-41fb-a360-35b6acacb170", "metadata": { "execution": { "iopub.execute_input": "2026-06-09T16:08:13.969476Z", "iopub.status.busy": "2026-06-09T16:08:13.969383Z", "iopub.status.idle": "2026-06-09T16:08:20.231555Z", "shell.execute_reply": "2026-06-09T16:08:20.231020Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "folktexts.__version__='0.6.0'\n" ] } ], "source": [ "from pathlib import Path\n", "import folktexts\n", "print(f\"{folktexts.__version__=}\")" ] }, { "cell_type": "markdown", "id": "bb6ff853", "metadata": {}, "source": [ "Load environment variables (e.g., API keys) and configure `litellm`:" ] }, { "cell_type": "code", "execution_count": 2, "id": "4850aae9-7276-46a6-94ed-30243cbb76ff", "metadata": { "execution": { "iopub.execute_input": "2026-06-09T16:08:20.233389Z", "iopub.status.busy": "2026-06-09T16:08:20.233107Z", "iopub.status.idle": "2026-06-09T16:08:22.324826Z", "shell.execute_reply": "2026-06-09T16:08:22.324273Z" } }, "outputs": [], "source": [ "from dotenv import load_dotenv\n", "load_dotenv()\n", "\n", "# gpt-5 reasoning models only accept temperature=1; let litellm drop the\n", "# library's default temperature=0 (and any other unsupported params).\n", "import litellm\n", "litellm.drop_params = True" ] }, { "cell_type": "markdown", "id": "efbc677d-5208-4a8a-aab8-9cb7282ca39f", "metadata": {}, "source": [ "Directories where the ACS data is cached and where results will be saved (_**change as appropriate**_):" ] }, { "cell_type": "code", "execution_count": 3, "id": "8040b3cd-da53-4b62-8aae-58b2bb7cba69", "metadata": { "execution": { "iopub.execute_input": "2026-06-09T16:08:22.327187Z", "iopub.status.busy": "2026-06-09T16:08:22.326913Z", "iopub.status.idle": "2026-06-09T16:08:22.329572Z", "shell.execute_reply": 
"2026-06-09T16:08:22.329069Z" } }, "outputs": [], "source": [ "ROOT_DIR = Path(\"~/\").expanduser()\n", "DATA_DIR = Path(\"/fast/groups/sf/data\") # pre-cached folktables data on this cluster\n", "RESULTS_ROOT_DIR = ROOT_DIR / \"folktexts-results\"\n", "RESULTS_ROOT_DIR.mkdir(parents=True, exist_ok=True)" ] }, { "cell_type": "markdown", "id": "985344ec-6c9c-4e8b-8d8e-26f009937a7c", "metadata": {}, "source": [ "Set LLM and task name:" ] }, { "cell_type": "code", "execution_count": 4, "id": "a4316d94-110f-4014-a823-95300ac87eb8", "metadata": { "execution": { "iopub.execute_input": "2026-06-09T16:08:22.330676Z", "iopub.status.busy": "2026-06-09T16:08:22.330575Z", "iopub.status.idle": "2026-06-09T16:08:22.332322Z", "shell.execute_reply": "2026-06-09T16:08:22.331932Z" } }, "outputs": [], "source": [ "MODEL_NAME = \"openai/gpt-5-mini\"\n", "TASK_NAME = \"ACSIncome\" # ACSIncome -> income prediction" ] }, { "cell_type": "markdown", "id": "2d87da67-3857-421f-b761-81814ea5498d", "metadata": {}, "source": [ "### Construct LLM Classifier\n", "Load the task (which maps tabular data to text prompts) and configure the `WebAPILLMClassifier`:" ] }, { "cell_type": "code", "execution_count": 5, "id": "4876c2b3", "metadata": { "execution": { "iopub.execute_input": "2026-06-09T16:08:22.333287Z", "iopub.status.busy": "2026-06-09T16:08:22.333190Z", "iopub.status.idle": "2026-06-09T16:08:22.334801Z", "shell.execute_reply": "2026-06-09T16:08:22.334447Z" } }, "outputs": [], "source": [ "# This demo uses chain-of-thought (CoT) prompting: the model reasons step-by-step\n", "# and ends with a \"Probability: X%\" line, which folktexts extracts via regex.\n", "# CoT suits chat/reasoning models like gpt-5-mini that answer verbosely rather\n", "# than emitting a bare probability (numeric prompting collapses on chat APIs)." 
] }, { "cell_type": "code", "execution_count": 6, "id": "7f0b725d-d97a-40a9-a71d-88fb5fefbddf", "metadata": { "execution": { "iopub.execute_input": "2026-06-09T16:08:22.335819Z", "iopub.status.busy": "2026-06-09T16:08:22.335726Z", "iopub.status.idle": "2026-06-09T16:08:22.337834Z", "shell.execute_reply": "2026-06-09T16:08:22.337432Z" } }, "outputs": [], "source": [ "from folktexts.acs import ACSTaskMetadata\n", "from folktexts.qa_interface import ChainOfThoughtQA\n", "\n", "task = ACSTaskMetadata.get_task(name=TASK_NAME)\n", "# Build a chain-of-thought question from the task's numeric question.\n", "_base_q = task.direct_numeric_qa\n", "task.set_question(ChainOfThoughtQA(column=_base_q.column, text=_base_q.text))" ] }, { "cell_type": "code", "execution_count": 7, "id": "0cfa6291-5c06-4618-a14d-86e49f2f4969", "metadata": { "execution": { "iopub.execute_input": "2026-06-09T16:08:22.338832Z", "iopub.status.busy": "2026-06-09T16:08:22.338740Z", "iopub.status.idle": "2026-06-09T16:08:22.344559Z", "shell.execute_reply": "2026-06-09T16:08:22.344180Z" } }, "outputs": [], "source": [ "from folktexts.classifier import WebAPILLMClassifier\n", "\n", "llm_clf = WebAPILLMClassifier(\n", " model_name=MODEL_NAME,\n", " task=task,\n", " batch_size=20,\n", ")" ] }, { "cell_type": "markdown", "id": "53f969d6-9d0e-48c2-9b34-a57aef0edb5c", "metadata": {}, "source": [ "### Load Dataset" ] }, { "cell_type": "code", "execution_count": 8, "id": "820004cb-25ce-4802-85c3-6ab251881141", "metadata": { "execution": { "iopub.execute_input": "2026-06-09T16:08:22.345539Z", "iopub.status.busy": "2026-06-09T16:08:22.345443Z", "iopub.status.idle": "2026-06-09T16:08:59.933495Z", "shell.execute_reply": "2026-06-09T16:08:59.932724Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading ACS data...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "dataset.subsampling=0.0005\n", "CPU times: user 22.8 s, sys: 14.7 s, total: 37.5 s\n", "Wall time: 37.6 s\n" ] } ], 
"source": [ "%%time\n", "from folktexts.acs import ACSDataset\n", "dataset = ACSDataset.make_from_task(\n", " task=task,\n", " survey_year=\"2018\",\n", " subsampling=0.0005, # gpt-5-mini CoT is slow (~17s/row); keep the demo small\n", " cache_dir=DATA_DIR,\n", ")\n", "print(f\"{dataset.subsampling=}\")" ] }, { "cell_type": "markdown", "id": "8a648f32-05f4-4057-8624-2a39bfe7ab55", "metadata": {}, "source": [ "### Example tabular row, prompt, and LLM prediction" ] }, { "cell_type": "code", "execution_count": 9, "id": "a83832fc-47c4-4e14-adc9-6393f197e2c2", "metadata": { "execution": { "iopub.execute_input": "2026-06-09T16:08:59.935179Z", "iopub.status.busy": "2026-06-09T16:08:59.935048Z", "iopub.status.idle": "2026-06-09T16:08:59.945397Z", "shell.execute_reply": "2026-06-09T16:08:59.945002Z" } }, "outputs": [ { "data": { "text/html": [ "
| \n", " | AGEP | \n", "COW | \n", "SCHL | \n", "MAR | \n", "OCCP | \n", "POBP | \n", "RELP | \n", "WKHP | \n", "SEX | \n", "RAC1P | \n", "
|---|---|---|---|---|---|---|---|---|---|---|
| 1068274 | \n", "59 | \n", "1.0 | \n", "18.0 | \n", "2 | \n", "3603.0 | \n", "17 | \n", "0 | \n", "40.0 | \n", "2 | \n", "1 | \n", "