"""Module to access ACS data using the folktables package.
"""
from __future__ import annotations
import logging
from pathlib import Path
import pandas as pd
from folktables import ACSDataSource
from folktables.load_acs import state_list
from ..dataset import Dataset
from .acs_tasks import ACSTaskMetadata
DEFAULT_DATA_DIR = Path("~/data").expanduser().resolve()
DEFAULT_TEST_SIZE = 0.1
DEFAULT_VAL_SIZE = 0.1
DEFAULT_SEED = 42
DEFAULT_SURVEY_YEAR = "2018"
DEFAULT_SURVEY_HORIZON = "1-Year"
DEFAULT_SURVEY_UNIT = "person"
[docs]
class ACSDataset(Dataset):
"""Wrapper for ACS folktables datasets."""
def __init__(
self,
data: pd.DataFrame,
full_acs_data: pd.DataFrame,
task: ACSTaskMetadata,
test_size: float = DEFAULT_TEST_SIZE,
val_size: float = DEFAULT_VAL_SIZE,
subsampling: float = None,
seed: int = 42,
):
self._full_acs_data = full_acs_data
super().__init__(
data=data,
task=task,
test_size=test_size,
val_size=val_size,
subsampling=subsampling,
seed=seed,
)
[docs]
@classmethod
def make_from_task(
cls,
task: str | ACSTaskMetadata,
cache_dir: str | Path = None,
survey_year: str = DEFAULT_SURVEY_YEAR,
horizon: str = DEFAULT_SURVEY_HORIZON,
survey: str = DEFAULT_SURVEY_UNIT,
seed: int = DEFAULT_SEED,
**kwargs,
):
"""Construct an ACSDataset object from a given ACS task.
Can customize survey sample parameters (survey year, horizon, survey type).
Parameters
----------
task : str | ACSTaskMetadata
The name of the ACS task or the task object itself.
cache_dir : str | Path, optional
The directory where ACS data is (or will be) saved to, by default
uses DEFAULT_DATA_DIR.
survey_year : str, optional
The year from which to load survey data, by default DEFAULT_SURVEY_YEAR.
horizon : str, optional
The time horizon of survey data to load, by default DEFAULT_SURVEY_HORIZON.
survey : str, optional
The name of the survey unit to load, by default DEFAULT_SURVEY_UNIT.
seed : int, optional
The random seed, by default DEFAULT_SEED.
**kwargs
Extra key-word arguments to be passed to the Dataset constructor.
"""
# Create "folktables" sub-folder under the given cache dir
cache_dir = Path(cache_dir or DEFAULT_DATA_DIR).expanduser().resolve() / "folktables"
if not cache_dir.exists():
logging.warning(f"Creating cache directory '{cache_dir}' for ACS data.")
cache_dir.mkdir(exist_ok=True, parents=False)
# Parse task if given a string
task_obj = ACSTaskMetadata.get_task(task) if isinstance(task, str) else task
# Load ACS data source
print("Loading ACS data...")
data_source = ACSDataSource(
survey_year=survey_year, horizon=horizon, survey=survey,
root_dir=cache_dir.as_posix(),
)
# Get full ACS dataset
full_acs_data = data_source.get_data(
states=state_list, download=True, random_seed=seed)
# Parse data for this task
parsed_data = cls._parse_task_data(full_acs_data, task_obj)
return cls(
data=parsed_data,
full_acs_data=full_acs_data,
task=task_obj,
seed=seed,
**kwargs,
)
@property
def task(self) -> ACSTaskMetadata:
return self._task
@task.setter
def task(self, new_task: ACSTaskMetadata):
# Parse data rows for new ACS task
self._data = self._parse_task_data(self._full_acs_data, new_task)
# Re-make train/test/val split
self._train_indices, self._test_indices, self._val_indices = (
self._make_train_test_val_split(
self._data, self.test_size, self.val_size, self._rng)
)
# Check if sub-sampling is necessary (it's applied only to train/test/val indices)
if self.subsampling is not None:
self._subsample_train_test_val_indices(self.subsampling)
self._task = new_task
@classmethod
def _parse_task_data(cls, full_df: pd.DataFrame, task: ACSTaskMetadata) -> pd.DataFrame:
"""Parse a DataFrame for compatibility with the given task object.
Parameters
----------
full_df : pd.DataFrame
Full DataFrame. Some rows and/or columns may be discarded for each
task.
task : ACSTaskMetadata
The task object used to parse the given data.
Returns
-------
parsed_df : pd.DataFrame
Parsed DataFrame in accordance with the given task.
"""
# Pre-process the data if necessary
if isinstance(task, ACSTaskMetadata) and task.folktables_obj is not None:
parsed_df = task.folktables_obj._preprocess(full_df)
else:
parsed_df = full_df
# Threshold the target column if necessary
if task.target is not None and task.target_threshold is not None and task.get_target() not in parsed_df.columns:
parsed_df[task.get_target()] = task.target_threshold.apply_to_column_data(parsed_df[task.target])
return parsed_df