# Source code for folktexts.llm_utils

"""Common functions to use with transformer LLMs."""
from __future__ import annotations

import logging
import os
import re
from pathlib import Path
from typing import TYPE_CHECKING

import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

if TYPE_CHECKING:
    from folktexts.qa_interface import DirectNumericQA, MultipleChoiceQA

# Threshold on the total probability mass assigned to digit tokens. When
# `query_model_batch_multiple_passes` runs with `digits_only=True` and the
# first-pass sum for any input falls below this value, the problem is reported
# through `logging.error` (execution continues).
PROB_WARN_THR = 0.5


def _apply_chat_template_batch(
    inputs: list[str],
    *,
    tokenizer: AutoTokenizer,
    enable_thinking: bool | None,
) -> list[str]:
    """Apply tokenizer chat template to a list of prompts, if requested.

    Each prompt is wrapped as a single ``user`` message and rendered with
    ``add_generation_prompt=True`` and ``tokenize=False``.

    Parameters
    ----------
    inputs : list[str]
        Raw (non-chat) prompts.
    tokenizer : AutoTokenizer
        Tokenizer used to apply the chat template.
    enable_thinking : bool | None
        If None, no chat template is applied. If True/False, chat template is
        applied and (if supported) the `enable_thinking` kwarg is forwarded.

    Returns
    -------
    formatted_inputs : list[str]
        Prompts formatted using the tokenizer's chat template when requested.
        The original `inputs` list is returned unchanged when
        `enable_thinking` is None or the tokenizer has no chat template.
    """
    if enable_thinking is None:
        return inputs

    # Base models (e.g. raw Llama-3-8B, GPT-2) have no chat template; falling
    # through to apply_chat_template would raise ValueError mid-batch. Use the
    # raw prompts instead — the CoT prompt itself is already self-contained.
    if getattr(tokenizer, "chat_template", None) is None:
        if enable_thinking:
            logging.warning(
                "Tokenizer has no chat_template; cannot honor enable_thinking=True. "
                "Falling back to raw prompts (base model)."
            )
        else:
            logging.info("Tokenizer has no chat_template; using raw prompts (base model).")
        return inputs

    # Whether the tokenizer accepts the `enable_thinking` kwarg is a property
    # of the tokenizer, not of the individual prompt: once one call rejects it,
    # stop forwarding it so the batch neither retries a doomed call nor logs
    # the same warning once per prompt.
    forward_thinking_kwarg = True

    processed: list[str] = []
    for text in inputs:
        messages = [{"role": "user", "content": text}]

        if forward_thinking_kwarg:
            try:
                processed.append(
                    tokenizer.apply_chat_template(
                        messages,
                        tokenize=False,
                        add_generation_prompt=True,
                        enable_thinking=enable_thinking,
                    )
                )
                continue
            except TypeError:
                forward_thinking_kwarg = False
                if enable_thinking:
                    logging.warning(
                        "Tokenizer does not support 'enable_thinking' parameter. "
                        "Falling back to standard chat template."
                    )

        processed.append(
            tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        )

    logging.debug(f"Applied chat template (enable_thinking={enable_thinking})")
    return processed


def _postprocess_generated_text(
    full_text: str,
    *,
    enable_thinking: bool | None,
    i: int,
    n: int,
) -> str:
    """Select the part of a generation that downstream extraction should see.

    Outside thinking mode the whole (stripped) generation is used. In thinking
    mode the text is cut at the first `</think>` marker: what precedes it is
    only logged at debug level, what follows it is returned. If the marker is
    missing, the whole (stripped) generation is returned with a warning.

    Parameters
    ----------
    full_text : str
        Full decoded generation (new tokens only).
    enable_thinking : bool | None
        Whether thinking mode is enabled; only `True` triggers the split.
    i : int
        Index of this generation in the batch (0-based), used for logging.
    n : int
        Total number of generations in the batch, used for logging.

    Returns
    -------
    response_text : str
        Text to be used for probability extraction; empty when nothing
        follows the `</think>` marker.
    """
    banner = f"=== Generated output {i+1}/{n} ==="

    # Anything other than an explicit True means "no thinking block expected"
    if enable_thinking is not True:
        logging.debug(banner)
        logging.debug(f"Content ({len(full_text)} chars):\n{full_text[:500]}...")
        return full_text.strip()

    # partition() cuts at the first marker; `marker` is empty when absent
    reasoning, marker, answer = full_text.partition("</think>")
    if not marker:
        logging.warning(
            f"</think> marker not found in output (thinking mode was enabled). "
            f"Using full generated text ({len(full_text)} chars)."
        )
        return full_text.strip()

    reasoning = reasoning.strip()
    answer = answer.strip()

    logging.debug(banner)
    logging.debug(f"Thinking content ({len(reasoning)} chars) [IGNORED for extraction]:")
    if len(reasoning) > 500:
        logging.debug(f"{reasoning[:500]}...")
    else:
        logging.debug(reasoning)
    logging.debug(f"Response content ({len(answer)} chars) [USED for extraction]:")
    logging.debug(answer)

    if not answer:
        logging.warning(
            "Response content after </think> is empty. "
            "Model may not have generated a proper response. "
            "Probability extraction will likely fail."
        )
    return answer


def query_model_batch(
    text_inputs: list[str],
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    context_size: int,
) -> np.ndarray:
    """Queries the model with a batch of text inputs.

    Runs a single forward pass and returns the next-token distribution that
    follows each prompt.

    Parameters
    ----------
    text_inputs : list[str]
        The inputs to the model as a list of strings.
    model : AutoModelForCausalLM
        The model to query.
    tokenizer : AutoTokenizer
        The tokenizer used to encode the text inputs.
    context_size : int
        The maximum context size to consider for each input (in tokens).
        Longer inputs keep only their trailing `context_size` tokens.

    Returns
    -------
    last_token_probs : np.ndarray
        Model's last token *linear* probabilities for each input as an
        np.ndarray of shape (batch_size, vocab_size), stored as float16.
    """
    model_device = next(model.parameters()).device

    # Tokenize, keeping only the trailing `context_size` tokens of each prompt
    token_inputs = [
        tokenizer.encode(text, return_tensors="pt").flatten()[-context_size:]
        for text in text_inputs
    ]
    idx_last_token = [tok_seq.shape[0] - 1 for tok_seq in token_inputs]

    # Pad (on the right) up to the longest sequence in the batch
    tensor_inputs = torch.nn.utils.rnn.pad_sequence(
        token_inputs,
        batch_first=True,
        padding_value=tokenizer.pad_token_id,
    ).to(model_device)

    # Mask padded context. The mask is derived from the true sequence lengths
    # rather than from `input_ids != pad_token_id`: the pad token is commonly
    # an alias of the EOS token (see `add_pad_token`), so a value-based mask
    # would also hide genuine EOS tokens that appear inside a prompt.
    seq_lengths = torch.tensor(
        [last_idx + 1 for last_idx in idx_last_token],
        device=model_device,
    )
    positions = torch.arange(tensor_inputs.shape[1], device=model_device)
    attention_mask = positions.unsqueeze(0) < seq_lengths.unsqueeze(1)

    # Query: run one forward pass, i.e., generate the next token
    with torch.no_grad():
        logits = model(input_ids=tensor_inputs, attention_mask=attention_mask).logits

    # Probabilities corresponding to the last token after the prompt
    last_token_logits = logits[torch.arange(len(idx_last_token)), idx_last_token]
    last_token_probs = torch.nn.functional.softmax(last_token_logits, dim=-1)
    return last_token_probs.to(dtype=torch.float16).cpu().numpy()
def query_model_batch_multiple_passes(
    text_inputs: list[str],
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    context_size: int,
    n_passes: int,
    digits_only: bool = False,
) -> np.ndarray:
    """Queries an LM for multiple forward passes.

    Greedy token search over multiple forward passes: Each forward pass takes
    the highest likelihood token from the previous pass.

    NOTE: could use model.generate in the future!

    Parameters
    ----------
    text_inputs : list[str]
        The batch inputs to the model as a list of strings.
    model : AutoModelForCausalLM
        The model to query.
    tokenizer : AutoTokenizer
        The tokenizer used to encode the text inputs.
    context_size : int
        The maximum context size to consider for each input (in tokens).
    n_passes : int
        The number of forward passes to run; must be at least 1.
    digits_only : bool, optional
        Whether to only sample for digit tokens.

    Returns
    -------
    last_token_probs : np.ndarray
        Last token *linear* probabilities for each forward pass, for each text
        in the input batch. The output has shape
        (batch_size, n_passes, vocab_size). Probabilities of disallowed tokens
        are zeroed and the remaining ones are NOT re-normalised.

    Raises
    ------
    ValueError
        If `n_passes` is smaller than 1.
    AttributeError
        If the vocabulary size cannot be read from the model config.
    """
    if n_passes < 1:
        raise ValueError(f"n_passes must be >= 1, got {n_passes}.")

    # Mask is sized to the model's logits dim, not the tokenizer's vocab dict:
    # neither `len(tokenizer.vocab)` nor `tokenizer.vocab_size` is reliable
    # (Gemma-3 has `len(vocab) == vocab_size + 1`; Llama-3.2 has
    # `len(vocab) == vocab_size + 256`). Only `model.config.vocab_size` matches
    # the actual logits axis we're masking. Multimodal Gemma-3 puts vocab_size
    # under `config.text_config` instead of the top-level config.
    vocab_dim = getattr(model.config, "vocab_size", None)
    if vocab_dim is None:
        vocab_dim = getattr(getattr(model.config, "text_config", None), "vocab_size", None)
    if vocab_dim is None:
        raise AttributeError(
            f"Could not resolve vocab_size from {type(model.config).__name__} "
            "(checked top-level and text_config)."
        )

    allowed_tokens_filter = np.ones(vocab_dim, dtype=bool)
    if digits_only:
        # Explicit integer dtype: an empty list would otherwise become a float
        # array, which numpy refuses to use as an index.
        allowed_token_ids = np.array(
            [
                tok_id
                for token, tok_id in tokenizer.vocab.items()
                if token.isdecimal() and tok_id < vocab_dim
            ],
            dtype=np.int64,
        )
        allowed_tokens_filter = np.zeros(vocab_dim, dtype=bool)
        allowed_tokens_filter[allowed_token_ids] = True

    # Current text batch
    current_batch = text_inputs

    # For each forward pass, add one token to each text in the batch
    last_token_probs = []
    for pass_idx in range(n_passes):

        # Query the model with the current batch
        current_probs = query_model_batch(current_batch, model, tokenizer, context_size)

        # Filter out probabilities for tokens that are not allowed
        current_probs[:, ~allowed_tokens_filter] = 0

        # Sanity check digit probabilities (first pass only). Accumulate in
        # float32 since the per-token probabilities arrive as float16.
        if pass_idx == 0 and digits_only:
            total_digit_probs = np.sum(current_probs, axis=-1, dtype=np.float32)
            if np.any(total_digit_probs < PROB_WARN_THR):
                logging.error(f"Digit probabilities are too low: {total_digit_probs}")

        # Add the highest likelihood token to each text in the batch
        next_tokens = [tokenizer.decode([np.argmax(probs)]) for probs in current_probs]
        current_batch = [text + next_token for text, next_token in zip(current_batch, next_tokens)]

        # Store the probabilities of the last token for each text in the batch
        last_token_probs.append(current_probs)

    # Stack passes along axis 1 -> (batch_size, n_passes, vocab_dim)
    last_token_probs_array = np.stack(last_token_probs, axis=1)
    assert last_token_probs_array.shape == (len(text_inputs), n_passes, vocab_dim)
    return last_token_probs_array
def decode_topk_logprobs_to_risk_estimate(
    per_pass_topk: list[dict[int, float]],
    *,
    tokenizer_vocab: dict[str, int],
    vocab_dim: int,
    question: "MultipleChoiceQA | DirectNumericQA",
) -> float:
    """Turn per-position top-K log-probabilities into one risk estimate.

    The sparse top-K dicts are expanded into a dense linear-probability array
    of shape (n_passes, vocab_dim), which is then handed to the question's own
    decoder together with the in-range part of the tokenizer vocabulary.

    Parameters
    ----------
    per_pass_topk : list[dict[int, float]]
        One dict per generated token position, mapping token_id -> log-prob.
        Token ids outside `[0, vocab_dim)` are ignored; tokens absent from a
        dict keep probability 0.
    tokenizer_vocab : dict[str, int]
        Token string -> token_id map forwarded to the QA decoder. Entries
        whose id lies outside `[0, vocab_dim)` are removed first.
    vocab_dim : int
        Size of the vocab axis of the dense probability array.
    question : MultipleChoiceQA | DirectNumericQA
        The QA interface used to interpret the probabilities.

    Returns
    -------
    risk_estimate : float
        Whatever `question.get_answer_from_model_output` returns for the
        dense probabilities.
    """
    dense_probs = np.zeros((len(per_pass_topk), vocab_dim), dtype=np.float64)

    # Each `row` is a view into `dense_probs`, so writes land in the array
    for row, topk in zip(dense_probs, per_pass_topk):
        for token_id, logprob in topk.items():
            if token_id < 0 or token_id >= vocab_dim:
                continue
            row[token_id] = float(np.exp(logprob))

    # Tokenizers may carry added tokens whose ids sit past the array's vocab
    # axis; hide them from the decoder so it never indexes out of bounds.
    usable_vocab = {}
    for token, token_id in tokenizer_vocab.items():
        if 0 <= token_id < vocab_dim:
            usable_vocab[token] = token_id

    return question.get_answer_from_model_output(
        dense_probs,
        tokenizer_vocab=usable_vocab,
    )
def generate_text_batch(
    text_inputs: list[str],
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    max_new_tokens: int = 1024,
    context_size: int | None = None,
    enable_thinking: bool | None = None,
) -> list[str]:
    """Generate text completions for a batch of prompts.

    Uses the model's generate() method for autoregressive text generation,
    suitable for chain-of-thought Q&A where the model needs to produce
    free-form text before outputting a probability estimate.

    Generation is greedy (do_sample=False) so runs are reproducible — matches
    the web-API path's temperature=0 contract.

    Parameters
    ----------
    text_inputs : list[str]
        The input prompts as a list of strings. An empty list yields an empty
        result without touching the model.
    model : AutoModelForCausalLM
        The model to use for generation.
    tokenizer : AutoTokenizer
        The tokenizer used to encode/decode text. Its `padding_side` (and,
        when truncating, `truncation_side`) is switched to "left" for the
        duration of the call and restored afterwards.
    max_new_tokens : int, optional
        Maximum number of new tokens to generate, by default 1024.
    context_size : int, optional
        The maximum context size for input tokens. If None (or 0), no
        truncation is applied to inputs. When truncating, the *trailing*
        `context_size` tokens are kept, consistent with `query_model_batch`.
    enable_thinking : bool, optional
        Controls chat template application and thinking mode:
        - None: Do not apply chat template (use raw prompts, for base models)
        - False: Apply chat template WITHOUT thinking mode (for
          instruction-tuned models)
        - True: Apply chat template WITH thinking mode, and extract response
          content after </think> marker (for thinking models like Qwen3)

    Returns
    -------
    generated_texts : list[str]
        The generated text completions for each input prompt. Only the newly
        generated tokens are returned (not the input prompt).
    """
    if not text_inputs:
        return []

    model_device = next(model.parameters()).device
    truncate = bool(context_size)

    # Save original padding side and set to left for generation
    # (decoder-only models require left-padding for correct generation)
    original_padding_side = tokenizer.padding_side

    # When truncating, drop tokens from the START of the prompt: the end holds
    # the question (and the chat template's generation prompt), which is what
    # the model must see to answer.
    swap_truncation_side = truncate and hasattr(tokenizer, "truncation_side")
    original_truncation_side = tokenizer.truncation_side if swap_truncation_side else None

    try:
        tokenizer.padding_side = "left"
        if swap_truncation_side:
            tokenizer.truncation_side = "left"

        formatted_inputs = _apply_chat_template_batch(
            text_inputs,
            tokenizer=tokenizer,
            enable_thinking=enable_thinking,
        )

        tokenized = tokenizer(
            formatted_inputs,
            return_tensors="pt",
            padding=True,
            truncation=truncate,
            max_length=context_size,
        )

        tensor_inputs = tokenized.input_ids.to(model_device)
        attention_mask = tokenized.attention_mask.to(model_device)
        # With left padding every row shares the same prompt length, so the
        # generated tokens start at this column for all rows.
        input_seq_length = tensor_inputs.shape[1]

        with torch.no_grad():
            outputs = model.generate(
                input_ids=tensor_inputs,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                do_sample=False,
            )

        generated_texts: list[str] = []
        for i, output in enumerate(outputs):
            generated_tokens = output[input_seq_length:]
            full_generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
            generated_texts.append(
                _postprocess_generated_text(
                    full_generated_text,
                    enable_thinking=enable_thinking,
                    i=i,
                    n=len(outputs),
                )
            )

        return generated_texts
    finally:
        tokenizer.padding_side = original_padding_side
        if swap_truncation_side:
            tokenizer.truncation_side = original_truncation_side
def add_pad_token(tokenizer):
    """Ensure the tokenizer has a pad token, reusing the end-of-sequence token.

    Nothing happens when a pad token is already set. Otherwise the tokenizer's
    `eos_token` is registered as its `pad_token` through
    `add_special_tokens`.

    An alternative approach would be to introduce a dedicated `[PAD]` token,
    which would require extending the tokenizer vocabulary and the model's
    embedding matrix (e.g. initialising the new embedding to the average of
    the others); that is not done here.
    """
    if tokenizer.pad_token is not None:
        return

    pad_spec = {'pad_token': tokenizer.eos_token}
    tokenizer.add_special_tokens(pad_spec)
def is_bf16_compatible() -> bool:
    """Tell whether bfloat16 can be used: CUDA is available and reports bf16 support."""
    if not torch.cuda.is_available():
        return False
    return torch.cuda.is_bf16_supported()
def load_model_tokenizer(model_name_or_path: str | Path, **kwargs) -> tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Load a causal LM and its tokenizer by model name or local folder.

    Parameters
    ----------
    model_name_or_path : str | Path
        Model name or local path to the model folder.
    kwargs : dict
        Extra keyword arguments for the model's `from_pretrained` call; they
        take precedence over the defaults set here (`torch_dtype`,
        `trust_remote_code=True`, `device_map="auto"`).

    Returns
    -------
    tuple[AutoModelForCausalLM, AutoTokenizer]
        The loaded model and tokenizer, respectively. The tokenizer is
        guaranteed to have a pad token (see `add_pad_token`).
    """
    logging.info(f"Loading model '{model_name_or_path}'")

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

    # Defaults first, caller-provided overrides last
    default_dtype = torch.bfloat16 if is_bf16_compatible() else torch.float16
    model_kwargs = {
        "torch_dtype": default_dtype,
        "trust_remote_code": True,
        "device_map": "auto",
        **kwargs,
    }

    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **model_kwargs)

    # Make sure padding works even for tokenizers shipped without a pad token
    add_pad_token(tokenizer)

    # Preferred device: CUDA, then Apple MPS, then CPU
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    logging.info(f"Moving model to device: {device}")
    already_on_device = model.device.type == device
    if not already_on_device:
        model.to(device)

    return model, tokenizer
def load_vllm_model(
    model_name_or_path: str | Path,
    *,
    dtype: str = "auto",
    gpu_memory_utilization: float = 0.85,
    max_model_len: int | None = None,
    tensor_parallel_size: int = 1,
    trust_remote_code: bool = True,
    seed: int = 42,
    max_logprobs: int = 50,
    **kwargs,
):
    """Create a vLLM `LLM` engine together with its tokenizer.

    vLLM counterpart of `load_model_tokenizer`. `vllm` is imported lazily
    inside this function because it is an optional dependency; a descriptive
    ImportError is raised when it is missing.

    Parameters
    ----------
    model_name_or_path : str | Path
        Model name or local path to the model folder; passed to vLLM as a
        string.
    dtype : str, optional
        Compute dtype forwarded to vLLM; ``"auto"`` by default.
    gpu_memory_utilization : float, optional
        Fraction of GPU memory forwarded to vLLM; 0.85 by default.
    max_model_len : int, optional
        Maximum number of tokens (input + output) per request, forwarded to
        vLLM. ``None`` leaves the choice to vLLM; passing an explicit value
        sized for the workload is recommended.
    tensor_parallel_size : int, optional
        Number of GPUs to shard the model across; default 1.
    trust_remote_code : bool, optional
        Forwarded to vLLM (mirrors `load_model_tokenizer`).
    seed : int, optional
        Random seed forwarded to vLLM.
    max_logprobs : int, optional
        Engine-level cap on the number of top-K logprobs a request may ask
        for; default 50. Must be at least as large as the K requested at
        predict time.
    **kwargs
        Additional keyword arguments forwarded verbatim to ``vllm.LLM(...)``.
        ``logprobs_mode`` defaults to ``"processed_logprobs"`` unless given.

    Returns
    -------
    tuple[vllm.LLM, AutoTokenizer]
        Loaded engine and its tokenizer, with `add_pad_token` applied to the
        tokenizer.

    Raises
    ------
    ImportError
        If `vllm` cannot be imported.
    """
    try:
        from vllm import LLM
    except ImportError as import_err:  # pragma: no cover - exercised in user-facing CLI
        raise ImportError(
            "vLLM is not installed. Install the optional extra with "
            "`pip install 'folktexts[vllm]'`, or run with "
            "`--inference-backend transformers` to use the HuggingFace path."
        ) from import_err

    # Keep vLLM's loading output quiet unless the caller already chose a level
    os.environ.setdefault("VLLM_LOGGING_LEVEL", "WARNING")

    # Ask for top-K logprobs computed AFTER `allowed_token_ids` masking. With
    # the raw (unmasked) distribution, numeric decoding could see non-digit
    # tokens such as '.' or '\n' as high-probability candidates. A caller's
    # explicit `logprobs_mode` is respected.
    kwargs.setdefault("logprobs_mode", "processed_logprobs")

    logging.info(f"Loading vLLM model '{model_name_or_path}'")
    engine = LLM(
        model=str(model_name_or_path),
        dtype=dtype,
        gpu_memory_utilization=gpu_memory_utilization,
        max_model_len=max_model_len,
        tensor_parallel_size=tensor_parallel_size,
        trust_remote_code=trust_remote_code,
        seed=seed,
        max_logprobs=max_logprobs,
        **kwargs,
    )

    engine_tokenizer = engine.get_tokenizer()
    add_pad_token(engine_tokenizer)
    return engine, engine_tokenizer
def get_model_folder_path(model_name: str, root_dir="/tmp") -> str:
    """Build the resolved POSIX path of the folder used for `model_name`.

    Slashes in the model name are replaced by ``--`` so the name maps to a
    single folder directly under `root_dir`.
    """
    flattened_name = "--".join(model_name.split("/"))
    model_dir = Path(root_dir).joinpath(flattened_name)
    return model_dir.resolve().as_posix()
def get_model_size_B(model_name: str, default: int | None = None) -> int:
    """Get the model size from the model name, in Billions of parameters.

    The size is read from the first ``<number>B`` / ``<number>b`` token in the
    name, optionally prefixed by a ``<count>x`` multiplier (mixture-of-experts
    style names such as ``8x7B`` give 56).

    Fractional sizes such as ``0.5B`` or ``1.5B`` are parsed as a whole and
    rounded UP to the next integer, so the result stays an `int` and is never
    0 for a non-empty model.

    Parameters
    ----------
    model_name : str
        Model name to parse.
    default : int, optional
        Value returned when no size can be found in the name.

    Returns
    -------
    int
        The inferred size in billions of parameters, or `default` when the
        name carries no size. A warning is logged when `default` is None and
        nothing could be inferred (None is returned in that case).
    """
    # Function-scope imports keep this block independent of the module header
    import math
    from fractions import Fraction

    match = re.search(
        r"((?P<times>\d+)[xX])?(?P<size>\d+(?:\.\d+)?)[bB]",
        model_name,
    )
    if match:
        multiplier = int(match.group("times") or 1)
        # Fraction parses the decimal string exactly, avoiding float round-off
        # (e.g. 10 * 0.7) pushing the ceiling one unit too high.
        return math.ceil(Fraction(match.group("size")) * multiplier)

    if default is not None:
        return default

    logging.warning(f"Could not infer model size from name '{model_name}'.")
    return default