{ "cells": [ { "cell_type": "markdown", "id": "e85e7f8a-dcfa-489a-9a8e-77225d2e3499", "metadata": {}, "source": [ "# Example: Run ACS benchmark task" ] }, { "cell_type": "code", "execution_count": 1, "id": "98180e30-da85-41fb-a360-35b6acacb170", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "folktexts.__version__='0.0.21'\n" ] } ], "source": [ "import folktexts\n", "print(f\"{folktexts.__version__=}\")" ] }, { "cell_type": "code", "execution_count": 2, "id": "e5c41432-8ee0-4fbb-b37b-4eea92504382", "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "import torch\n", "import numpy as np\n", "import pandas as pd" ] }, { "cell_type": "code", "execution_count": 3, "id": "7a4b1e7f-a893-488e-96f1-4094857115f2", "metadata": {}, "outputs": [], "source": [ "import logging\n", "logging.getLogger().setLevel(logging.INFO)" ] }, { "cell_type": "markdown", "id": "04d5e737-9b3c-468a-bd68-b1c161f57c84", "metadata": {}, "source": [ "## Set important local paths\n", "**NOTE:** If you haven't previously downloaded the model, the local model path can be ignored: just call `load_model_tokenizer` with the model's name on Hugging Face."
] }, { "cell_type": "markdown", "id": "1cf6dfe9-4ddc-448f-afb9-0e4802ae480d", "metadata": {}, "source": [ "Set your root directory (_**change as appropriate**_):" ] }, { "cell_type": "code", "execution_count": 4, "id": "048f506d-83a6-4911-b5f3-2fc290ed9fd8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "PosixPath('/fast/groups/sf')" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ROOT_DIR = Path(\"/fast/groups/sf\")\n", "ROOT_DIR" ] }, { "cell_type": "markdown", "id": "0cb6c30e-5af1-4733-8732-788cba086eaf", "metadata": {}, "source": [ "Directory where LLMs are saved (_**change as appropriate**_):" ] }, { "cell_type": "code", "execution_count": 5, "id": "a7e60a72-e0c4-4fa8-a5d5-402f82098872", "metadata": {}, "outputs": [], "source": [ "MODELS_DIR = ROOT_DIR / \"huggingface-models\"" ] }, { "cell_type": "markdown", "id": "efbc677d-5208-4a8a-aab8-9cb7282ca39f", "metadata": {}, "source": [ "Directory where data is saved or will be saved to (_**change as appropriate**_):" ] }, { "cell_type": "code", "execution_count": 6, "id": "8040b3cd-da53-4b62-8aae-58b2bb7cba69", "metadata": {}, "outputs": [], "source": [ "DATA_DIR = ROOT_DIR / \"data\"" ] }, { "cell_type": "markdown", "id": "985344ec-6c9c-4e8b-8d8e-26f009937a7c", "metadata": {}, "source": [ "Other configs:" ] }, { "cell_type": "code", "execution_count": 7, "id": "a4316d94-110f-4014-a823-95300ac87eb8", "metadata": {}, "outputs": [], "source": [ "MODEL_NAME = \"meta-llama/Meta-Llama-3-8B-Instruct\"\n", "# MODEL_NAME = \"google/gemma-2b\" # Smaller model that is faster to run\n", "\n", "TASK_NAME = \"ACSIncome\"\n", "\n", "RESULTS_ROOT_DIR = ROOT_DIR / \"folktexts-results\"" ] }, { "cell_type": "code", "execution_count": 8, "id": "7e0f598d-dd37-471f-a0d4-4b3c1ae07167", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:root:Loading model 
'/lustre/fast/fast/groups/sf/huggingface-models/meta-llama--Meta-Llama-3-8B-Instruct'\n", "INFO:accelerate.utils.modeling:We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f0a08cdd63fa4bf4815578e005a80497", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/4 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:root:Moving model to device: cuda\n" ] } ], "source": [ "from folktexts.llm_utils import load_model_tokenizer, get_model_folder_path\n", "model_folder_path = get_model_folder_path(model_name=MODEL_NAME, root_dir=MODELS_DIR)\n", "\n", "# Prefer the local copy under MODELS_DIR; if that folder does not exist, pass the\n", "# model's name instead so that `load_model_tokenizer` resolves it on Hugging Face.\n", "# NOTE(review): assumes `get_model_folder_path` only builds the path and does not\n", "# raise when the folder is missing -- confirm.\n", "model_source = model_folder_path if Path(model_folder_path).exists() else MODEL_NAME\n", "model, tokenizer = load_model_tokenizer(model_source)" ] }, { "cell_type": "code", "execution_count": 9, "id": "ab088b1f-37ba-4d2e-b5d2-b829787b7fed", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "PosixPath('/fast/groups/sf/folktexts-results/meta-llama--Meta-Llama-3-8B-Instruct')" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "results_dir = RESULTS_ROOT_DIR / Path(model_folder_path).name\n", "results_dir.mkdir(exist_ok=True, parents=True)\n", "results_dir" ] }, { "cell_type": "markdown", "id": "2d87da67-3857-421f-b761-81814ea5498d", "metadata": {}, "source": [ "### Construct LLM Classifier\n", "\n", "**NOTE:** Also compatible with models hosted through a web API by using the `WebAPILLMClassifier` class instead of `TransformersLLMClassifier`."
] }, { "cell_type": "markdown", "id": "e664c322-4da5-49b3-a755-3ba0a075ee64", "metadata": {}, "source": [ "Load prediction task (which maps tabular data to text):" ] }, { "cell_type": "code", "execution_count": 10, "id": "7f0b725d-d97a-40a9-a71d-88fb5fefbddf", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:root:Changing Q&A mode for task 'ACSIncome' to multiple-choice.\n" ] } ], "source": [ "from folktexts.acs import ACSTaskMetadata\n", "task = ACSTaskMetadata.get_task(TASK_NAME, use_numeric_qa=False)" ] }, { "cell_type": "code", "execution_count": 11, "id": "0cfa6291-5c06-4618-a14d-86e49f2f4969", "metadata": {}, "outputs": [], "source": [ "from folktexts.classifier import TransformersLLMClassifier\n", "llm_clf = TransformersLLMClassifier(\n", " model=model,\n", " tokenizer=tokenizer,\n", " task=task,\n", " batch_size=32,\n", " context_size=1000,\n", ")" ] }, { "cell_type": "markdown", "id": "53f969d6-9d0e-48c2-9b34-a57aef0edb5c", "metadata": {}, "source": [ "### Load Dataset" ] }, { "cell_type": "code", "execution_count": 12, "id": "820004cb-25ce-4802-85c3-6ab251881141", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading ACS data...\n", "CPU times: user 58 s, sys: 27.5 s, total: 1min 25s\n", "Wall time: 1min 26s\n" ] } ], "source": [ "%%time\n", "from folktexts.acs import ACSDataset\n", "dataset = ACSDataset.make_from_task(task=task, cache_dir=DATA_DIR)" ] }, { "cell_type": "markdown", "id": "8145eddd-09da-4db9-8779-c61b04c3d777", "metadata": {}, "source": [ "Optionally, subsample to quickly get approximate results:" ] }, { "cell_type": "code", "execution_count": 13, "id": "c977cac3-162c-4cde-8594-e6795f62e29a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:root:Train size: 13316, Test size: 1665, Val size: 1665;\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "dataset.subsampling=0.01\n" ] } ], "source": [ 
"dataset.subsample(0.01)\n", "print(f\"{dataset.subsampling=}\")" ] }, { "cell_type": "markdown", "id": "a3b4ce61-ecbc-4491-bd7b-0cfae38c6c10", "metadata": {}, "source": [ "### Load and run ACS Benchmark\n", "\n", "**_Note:_** Helper constructors exist at `Benchmark.make_acs_benchmark` and `Benchmark.make_benchmark` that avoid the above boilerplate code." ] }, { "cell_type": "code", "execution_count": 14, "id": "10d362b6-aac8-49dd-84c3-0ce799ffabba", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:root:\n", "** Benchmark initialization **\n", "Model: meta-llama--Meta-Llama-3-8B-Instruct;\n", "Task: ACSIncome;\n", "Hash: 516109582;\n", "\n" ] } ], "source": [ "from folktexts.benchmark import BenchmarkConfig, Benchmark\n", "\n", "bench = Benchmark(llm_clf=llm_clf, dataset=dataset)" ] }, { "cell_type": "markdown", "id": "8a648f32-05f4-4057-8624-2a39bfe7ab55", "metadata": {}, "source": [ "Here's an example prompt for the current prediction task:" ] }, { "cell_type": "code", "execution_count": 15, "id": "a83832fc-47c4-4e14-adc9-6393f197e2c2", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The following data corresponds to a survey respondent. The survey was conducted among US residents in 2018. Please answer the question based on the information provided. 
The data provided is enough to reach an approximate answer.\n", "\n", "Information:\n", "- The age is: 37 years old.\n", "- The class of worker is: Owner of non-incorporated business, professional practice, or farm.\n", "- The highest educational attainment is: Regular high school diploma.\n", "- The marital status is: Married.\n", "- The occupation is: Painters and paperhangers.\n", "- The place of birth is: New Jersey.\n", "- The relationship to the reference person in the survey is: The reference person itself.\n", "- The usual number of hours worked per week is: 40 hours.\n", "- The sex is: Male.\n", "- The race is: White.\n", "\n", "Question: What is this person's estimated yearly income?\n", "A. Below $50,000.\n", "B. Above $50,000.\n", "Answer:\n" ] } ], "source": [ "X_sample, _y_sample = dataset.sample_n_train_examples(n=1)\n", "print(llm_clf.encode_row(X_sample.iloc[0], question=llm_clf.task.question))" ] }, { "cell_type": "markdown", "id": "e107c703-17d4-4543-8f56-73c7b9583589", "metadata": {}, "source": [ "Optionally, you can fit the model's threshold on a few data samples.\n", "\n", "This is generally quite fast as it is _not fine-tuning_; it only changes one parameter: the `llm_clf.threshold`." ] }, { "cell_type": "code", "execution_count": 16, "id": "16f88399-1b9e-4fd3-a1ca-0aa9f2b91080", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c0e4d1784cc8416eb3f30a9b542754b5", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Computing risk estimates: 0%| | 0/2 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. 
Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\n", "WARNING:root:Setting meta-llama--Meta-Llama-3-8B-Instruct threshold to 0.7918923636748516.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 7.57 s, sys: 382 ms, total: 7.95 s\n", "Wall time: 9 s\n" ] }, { "data": { "text/html": [ "
TransformersLLMClassifier(encode_row=functools.partial(<function encode_row_prompt at 0x14802eedf600>, task=ACSTaskMetadata(name='ACSIncome', description="predict whether an individual's income is above $50,000", features=['AGEP', 'COW', 'SCHL', 'MAR', 'OCCP', 'POBP', 'RELP', 'WKHP', 'SEX', 'RAC1P'], target='PINCP', cols_to_text={'AGEP': <folktexts.col_to_text.ColumnToText...\n", "\t128253: AddedToken("<|reserved_special_token_248|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n", "\t128254: AddedToken("<|reserved_special_token_249|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n", "\t128255: AddedToken("<|reserved_special_token_250|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n", "})In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
TransformersLLMClassifier(encode_row=functools.partial(<function encode_row_prompt at 0x14802eedf600>, task=ACSTaskMetadata(name='ACSIncome', description="predict whether an individual's income is above $50,000", features=['AGEP', 'COW', 'SCHL', 'MAR', 'OCCP', 'POBP', 'RELP', 'WKHP', 'SEX', 'RAC1P'], target='PINCP', cols_to_text={'AGEP': <folktexts.col_to_text.ColumnToText...\n", "\t128253: AddedToken("<|reserved_special_token_248|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n", "\t128254: AddedToken("<|reserved_special_token_249|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n", "\t128255: AddedToken("<|reserved_special_token_250|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n", "})