Source code for folktexts.cli.download_models

#!/usr/bin/env python3
import gc
import logging
from argparse import ArgumentParser
from pathlib import Path

import torch
from tqdm import tqdm

# Temporary directory to use as download cache
TMP_DIR = Path("/tmp/")

# Default list of models to download
DEFAULT_MODEL_LIST = [
    # Google Gemma 1 models
    "google/gemma-2b",
    "google/gemma-1.1-2b-it",
    "google/gemma-7b",
    "google/gemma-1.1-7b-it",

    # Google Gemma 2 models
    "google/gemma-2-9b",
    "google/gemma-2-9b-it",
    "google/gemma-2-27b",
    "google/gemma-2-27b-it",

    # Meta Llama 3 models
    "meta-llama/Meta-Llama-3-8B",
    "meta-llama/Meta-Llama-3-8B-Instruct",
    "meta-llama/Meta-Llama-3-70B",
    "meta-llama/Meta-Llama-3-70B-Instruct",

    # Mistral AI models
    "mistralai/Mistral-7B-v0.1",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "mistralai/Mixtral-8x7B-v0.1",
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "mistralai/Mixtral-8x22B-v0.1",
    "mistralai/Mixtral-8x22B-Instruct-v0.1",

    # Yi models
    "01-ai/Yi-34B",
    "01-ai/Yi-34B-Chat",

    # Qwen2 models
    # "Qwen/Qwen2-1.5B",
    # "Qwen/Qwen2-1.5B-Instruct",
    # "Qwen/Qwen2-7B",
    # "Qwen/Qwen2-7B-Instruct",
    # "Qwen/Qwen2-72B",
    # "Qwen/Qwen2-72B-Instruct",
]


def setup_arg_parser() -> ArgumentParser:
    # Init parser
    parser = ArgumentParser(description="Download Hugging Face transformer models and tokenizers to disk")

    parser.add_argument(
        "--model",
        type=str,
        help="[string] Model name on the Hugging Face Hub; can be passed multiple times",
        required=False,
        action="append",
    )

    parser.add_argument(
        "--save-dir",
        type=str,
        help="[string] Directory to save the downloaded models to",
        required=True,
    )

    parser.add_argument(
        "--tmp-cache-dir",
        type=str,
        help="[string] Cache dir to temporarily download models to",
        required=False,
        default=TMP_DIR,
    )

    return parser

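# Usage sketch (illustrative only, never called by the CLI): because "--model"
# uses action="append", the flag can be repeated and parses into a list. The
# model names and save path below are hypothetical examples.
def _example_parse_args() -> None:
    parser = setup_arg_parser()
    args = parser.parse_args([
        "--model", "google/gemma-2b",
        "--model", "google/gemma-2-9b-it",
        "--save-dir", "~/models",  # hypothetical path
    ])
    assert args.model == ["google/gemma-2b", "google/gemma-2-9b-it"]
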
def is_bf16_compatible() -> bool:
    """Checks if the current environment is bfloat16 compatible."""
    return torch.cuda.is_available() and torch.cuda.is_bf16_supported()

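# Sketch of a typical use for the check above (this module does not call it
# directly): pick a torch dtype for model loading based on hardware support.
# The function name below is hypothetical.
def _example_pick_dtype() -> torch.dtype:
    return torch.bfloat16 if is_bf16_compatible() else torch.float16
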
def main():
    # Parse command-line arguments
    parser = setup_arg_parser()
    args = parser.parse_args()

    # Process arguments
    model_list = args.model or DEFAULT_MODEL_LIST

    save_dir = Path(args.save_dir).expanduser().resolve()
    save_dir.mkdir(exist_ok=True, parents=False)

    cache_dir = Path(args.tmp_cache_dir).expanduser().resolve()
    cache_dir.mkdir(exist_ok=True, parents=False)

    # Imported here rather than at module level to keep CLI start-up light
    from folktexts.llm_utils import get_model_folder_path, load_model_tokenizer

    for model_name in tqdm(model_list):
        # Create sub-folder to save this model to
        curr_save_dir = get_model_folder_path(model_name, root_dir=save_dir)

        # If the model already exists on disk, skip it
        if Path(curr_save_dir).exists():
            logging.warning(f"Model '{model_name}' already exists at '{curr_save_dir}'")
            continue

        # Download the model to the temporary cache dir
        model, tokenizer = load_model_tokenizer(model_name, cache_dir=cache_dir)

        # Save model and tokenizer to disk
        print(f"Saving {model_name} to {curr_save_dir}")
        model.save_pretrained(curr_save_dir)
        tokenizer.save_pretrained(curr_save_dir)

        # Delete references to the model and tokenizer, and force garbage collection
        del model
        del tokenizer
        gc.collect()

        # Empty VRAM if a GPU is available
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

if __name__ == "__main__":
    main()
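
# Example invocation (hypothetical model choice and paths); the __main__ guard
# above makes the module directly runnable:
#
#   python -m folktexts.cli.download_models \
#       --model google/gemma-2b \
#       --model meta-llama/Meta-Llama-3-8B-Instruct \
#       --save-dir ~/models \
#       --tmp-cache-dir /tmp/
#
# Omitting --model downloads every entry in DEFAULT_MODEL_LIST.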