add logitprocessor
Browse files- logit_processors/__init__.py +0 -0
- logit_processors/logit_.py +75 -0
logit_processors/__init__.py
ADDED
|
File without changes
|
logit_processors/logit_.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional
|
| 2 |
+
import torch
|
| 3 |
+
from vllm.config import VllmConfig
|
| 4 |
+
from vllm.v1.sample.logits_processor import (
|
| 5 |
+
AdapterLogitsProcessor,
|
| 6 |
+
RequestLogitsProcessor,
|
| 7 |
+
)
|
| 8 |
+
from vllm.sampling_params import SamplingParams
|
| 9 |
+
import os
|
| 10 |
+
from loguru import logger
|
| 11 |
+
from collections import Counter
|
| 12 |
+
|
| 13 |
+
class ThinkLogitsProcessor:
|
| 14 |
+
def __init__(self, think_end_token = 219406, max_len: int = 131072, ratio: float = 0.95):
|
| 15 |
+
self.think_end_token = think_end_token
|
| 16 |
+
self.min_answer_budget = 4096
|
| 17 |
+
self.max_len = max_len
|
| 18 |
+
self.ratio = ratio
|
| 19 |
+
self.interval = 4096
|
| 20 |
+
|
| 21 |
+
def find_repeated_ngrams(self, input_ids, n=512):
|
| 22 |
+
"""
|
| 23 |
+
input_ids: list of integer tokens
|
| 24 |
+
n: n-gram size
|
| 25 |
+
returns dict of {ngram_tuple: count} for repeated n-grams
|
| 26 |
+
"""
|
| 27 |
+
ngrams = [tuple(input_ids[i:i+n]) for i in range(0, len(input_ids) - n + 1, 256)]
|
| 28 |
+
freq = Counter(ngrams)
|
| 29 |
+
return {ng: c for ng, c in freq.items() if c > 7}
|
| 30 |
+
|
| 31 |
+
def __call__(
|
| 32 |
+
self,
|
| 33 |
+
prompt_token_ids: List[int],
|
| 34 |
+
past_token_ids: List[int],
|
| 35 |
+
logits: torch.Tensor
|
| 36 |
+
) -> torch.Tensor:
|
| 37 |
+
if self.think_end_token not in past_token_ids:
|
| 38 |
+
|
| 39 |
+
# ngram
|
| 40 |
+
if len(past_token_ids) % self.interval == 0:
|
| 41 |
+
# If repetation detected, force </think>
|
| 42 |
+
if self.find_repeated_ngrams(past_token_ids, n=16384):
|
| 43 |
+
# Set all other logits to -inf except for </think>
|
| 44 |
+
logits = torch.full_like(logits, torch.finfo(torch.bfloat16).min)
|
| 45 |
+
logits[self.think_end_token] = 1.0
|
| 46 |
+
else:
|
| 47 |
+
# ratio
|
| 48 |
+
tokens_since_think = len(past_token_ids)
|
| 49 |
+
|
| 50 |
+
response_budget = max(self.min_answer_budget, int((self.max_len - len(prompt_token_ids)) * (1-self.ratio)))
|
| 51 |
+
remaining_budget = self.max_len - len(prompt_token_ids) - response_budget - tokens_since_think
|
| 52 |
+
|
| 53 |
+
if 0 >= remaining_budget:
|
| 54 |
+
logits = torch.full_like(logits, torch.finfo(torch.bfloat16).min)
|
| 55 |
+
logits[self.think_end_token] = 1.0
|
| 56 |
+
return logits
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
|
| 60 |
+
def __init__(self, vllm_config: VllmConfig, device: torch.device,is_pin_memory: bool):
|
| 61 |
+
super().__init__(vllm_config, device, is_pin_memory)
|
| 62 |
+
self.model_max_len = vllm_config.model_config.max_model_len
|
| 63 |
+
assert self.model_max_len, "specify --model-max-len if using ratiologitprocessor"
|
| 64 |
+
self.ratio = float(os.environ.get("VLLM_THINK_BUDGET_RATIO", "0.0"))
|
| 65 |
+
assert 1 >= self.ratio > 0, "specify env var VLLM_THINK_BUDGET_RATIO in 0.0 < R =< 1.0 if using ratiologitprocessor"
|
| 66 |
+
|
| 67 |
+
def is_argmax_invariant(self) -> bool:
|
| 68 |
+
return False
|
| 69 |
+
|
| 70 |
+
def new_req_logits_processor(
|
| 71 |
+
self,
|
| 72 |
+
params: SamplingParams,
|
| 73 |
+
) -> Optional[RequestLogitsProcessor]:
|
| 74 |
+
|
| 75 |
+
return ThinkLogitsProcessor(max_len = self.model_max_len, ratio = self.ratio)
|