"""Per-request logits processor that forces the end-of-thinking token.

The thinking phase is closed (by making the think-end token the only viable
choice) when either the thinking budget is used up or the generated text is
found to be stuck repeating large chunks.
"""

import os
from collections import Counter
from typing import List, Optional

import torch

from vllm.config import VllmConfig
from vllm.sampling_params import SamplingParams
from vllm.v1.sample.logits_processor import (
    AdapterLogitsProcessor,
    RequestLogitsProcessor,
)

# Length (in tokens) of the chunks compared by the repetition check.
CHUNK_SIZE = 16384
# Stride (in tokens) between the start positions of consecutive chunks.
WINDOW_SIZE = 256
# A chunk must occur MORE than this many times to count as repetition.
# (Name kept as-is, including spelling, since other modules may import it.)
MAX_REPETATION_COUNT = 7


class ThinkLogitsProcessor:
    """Callable ``(prompt_token_ids, past_token_ids, logits) -> logits``.

    While ``think_end_token`` has not yet been generated, the logits are
    overwritten so that only ``think_end_token`` can be chosen when:

    * the thinking budget derived from ``max_len`` and ``ratio`` is used up, or
    * every ``interval`` generated tokens, the repetition check fires.

    Otherwise the logits are returned untouched.
    """

    def __init__(
        self,
        think_end_token: int = 219406,
        max_len: int = 131072,
        ratio: float = 0.95,
    ):
        """Store the budget configuration.

        Args:
            think_end_token: Id of the token that closes the thinking phase.
                NOTE(review): the default is model/tokenizer specific --
                confirm it matches the served model.
            max_len: Total token limit (prompt + generated) used for the
                budget computation.
            ratio: Fraction of the post-prompt budget that thinking may use;
                the remaining ``1 - ratio`` is reserved for the answer.
        """
        self.think_end_token = think_end_token
        # Lower bound on the number of tokens reserved for the answer.
        self.min_answer_budget = 4096
        self.max_len = max_len
        self.ratio = ratio
        # The repetition check only runs when the generated length is a
        # multiple of this value.
        self.interval = 4096

    def find_repeated_ngrams(self, input_ids, n=512):
        """Return the n-grams of ``input_ids`` that are repeated too often.

        Only n-grams starting at offsets ``0, WINDOW_SIZE, 2 * WINDOW_SIZE, ...``
        are considered, so this is a strided sample rather than every n-gram.

        Args:
            input_ids: Sequence of integer token ids.
            n: n-gram length in tokens.

        Returns:
            Dict mapping ``ngram_tuple -> count`` for every sampled n-gram
            whose count is greater than ``MAX_REPETATION_COUNT``. Empty when
            ``input_ids`` is shorter than ``n``.
        """
        starts = range(0, len(input_ids) - n + 1, WINDOW_SIZE)
        freq = Counter(tuple(input_ids[i:i + n]) for i in starts)
        return {
            ngram: count
            for ngram, count in freq.items()
            if count > MAX_REPETATION_COUNT
        }

    def _force_think_end(self, logits: torch.Tensor) -> torch.Tensor:
        """Return new logits where only ``think_end_token`` is viable."""
        # Every entry gets the most negative finite value of the dtype, then
        # the think-end entry is raised above all of them.
        forced = torch.full_like(logits, torch.finfo(logits.dtype).min)
        forced[self.think_end_token] = 1.0
        return forced

    def __call__(
        self,
        prompt_token_ids: List[int],
        past_token_ids: List[int],
        logits: torch.Tensor,
    ) -> torch.Tensor:
        """Force the think-end token when the budget or repetition rule fires.

        Args:
            prompt_token_ids: Token ids of the prompt.
            past_token_ids: Token ids generated so far for this request.
            logits: Logits for the next token; indexed by token id.

        Returns:
            ``logits`` unchanged, or a new tensor in which only
            ``think_end_token`` can be selected.
        """
        # Thinking already closed: never interfere with the answer.
        if self.think_end_token in past_token_ids:
            return logits

        generated = len(past_token_ids)
        available = self.max_len - len(prompt_token_ids)

        # Budget rule: reserve the larger of the fixed minimum and the
        # (1 - ratio) share of the post-prompt budget for the answer.
        response_budget = max(
            self.min_answer_budget,
            int(available * (1 - self.ratio)),
        )
        remaining_budget = available - response_budget - generated
        if remaining_budget <= 0:
            return self._force_think_end(logits)

        # Repetition rule: checked only every `interval` tokens because
        # hashing CHUNK_SIZE-long windows over the whole history is costly.
        if generated % self.interval == 0 and self.find_repeated_ngrams(
            past_token_ids, n=CHUNK_SIZE
        ):
            return self._force_think_end(logits)

        return logits


class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
    """Adapter that hands out one ``ThinkLogitsProcessor`` per request.

    The model length comes from ``vllm_config.model_config.max_model_len`` and
    the thinking ratio from the ``VLLM_THINK_BUDGET_RATIO`` environment
    variable.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        is_pin_memory: bool,
    ):
        """Read and validate the configuration.

        Raises:
            ValueError: If the max model length is unset/zero, or if
                ``VLLM_THINK_BUDGET_RATIO`` is missing, not a number, or
                outside ``(0, 1]``.
        """
        super().__init__(vllm_config, device, is_pin_memory)

        # Explicit raises instead of `assert`: assertions are removed under
        # `python -O`, which would let an invalid configuration through.
        self.model_max_len = vllm_config.model_config.max_model_len
        if not self.model_max_len:
            raise ValueError(
                "specify --max-model-len if using ratiologitprocessor"
            )

        # Default "0.0" is deliberately invalid so the variable must be set.
        self.ratio = float(os.environ.get("VLLM_THINK_BUDGET_RATIO", "0.0"))
        if not 0 < self.ratio <= 1:
            raise ValueError(
                "specify env var VLLM_THINK_BUDGET_RATIO in 0.0 < R =< 1.0 if using ratiologitprocessor"
            )

    def is_argmax_invariant(self) -> bool:
        """Return False: forcing the think-end token can change the argmax."""
        return False

    def new_req_logits_processor(
        self,
        params: SamplingParams,
    ) -> Optional[RequestLogitsProcessor]:
        """Create the processor for a new request.

        ``params`` is not consulted; every request gets the same settings.
        """
        return ThinkLogitsProcessor(
            max_len=self.model_max_len,
            ratio=self.ratio,
        )