leejunhyeok committed
Commit bc61b50 · verified · 1 Parent(s): 390323a

add logit processor

logit_processors/__init__.py ADDED
File without changes
logit_processors/logit_.py ADDED
@@ -0,0 +1,75 @@
+ from typing import List, Optional
+ import torch
+ from vllm.config import VllmConfig
+ from vllm.v1.sample.logits_processor import (
+     AdapterLogitsProcessor,
+     RequestLogitsProcessor,
+ )
+ from vllm.sampling_params import SamplingParams
+ import os
+ from loguru import logger
+ from collections import Counter
+
+ class ThinkLogitsProcessor:
+     def __init__(self, think_end_token: int = 219406, max_len: int = 131072, ratio: float = 0.95):
+         self.think_end_token = think_end_token
+         self.min_answer_budget = 4096   # minimum number of tokens reserved for the final answer
+         self.max_len = max_len          # model context length
+         self.ratio = ratio              # fraction of the post-prompt budget allotted to thinking
+         self.interval = 4096            # how often (in generated tokens) the n-gram repetition check runs
+
+     def find_repeated_ngrams(self, input_ids, n=512):
+         """
+         input_ids: list of integer tokens
+         n: n-gram size
+         returns dict of {ngram_tuple: count} for n-grams seen more than 7 times (sampled with a stride of 256)
+         """
+         ngrams = [tuple(input_ids[i:i + n]) for i in range(0, len(input_ids) - n + 1, 256)]
+         freq = Counter(ngrams)
+         return {ng: c for ng, c in freq.items() if c > 7}
+
+     def __call__(
+         self,
+         prompt_token_ids: List[int],
+         past_token_ids: List[int],
+         logits: torch.Tensor
+     ) -> torch.Tensor:
+         if self.think_end_token not in past_token_ids:
+             # n-gram repetition check, run every `interval` generated tokens
+             if len(past_token_ids) % self.interval == 0:
+                 # If repetition is detected, force </think>
+                 if self.find_repeated_ngrams(past_token_ids, n=16384):
+                     # Set all other logits to -inf except for </think>
+                     logits = torch.full_like(logits, torch.finfo(torch.bfloat16).min)
+                     logits[self.think_end_token] = 1.0
+             else:
+                 # thinking-budget ratio check
+                 tokens_since_think = len(past_token_ids)
+
+                 response_budget = max(self.min_answer_budget, int((self.max_len - len(prompt_token_ids)) * (1 - self.ratio)))
+                 remaining_budget = self.max_len - len(prompt_token_ids) - response_budget - tokens_since_think
+
+                 # Thinking budget exhausted: force </think> so the answer still fits
+                 if remaining_budget <= 0:
+                     logits = torch.full_like(logits, torch.finfo(torch.bfloat16).min)
+                     logits[self.think_end_token] = 1.0
+         return logits
+
+
+ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
+     def __init__(self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool):
+         super().__init__(vllm_config, device, is_pin_memory)
+         self.model_max_len = vllm_config.model_config.max_model_len
+         assert self.model_max_len, "specify --max-model-len when using the ratio logits processor"
+         self.ratio = float(os.environ.get("VLLM_THINK_BUDGET_RATIO", "0.0"))
+         assert 0 < self.ratio <= 1, "set env var VLLM_THINK_BUDGET_RATIO to a value in 0.0 < R <= 1.0 when using the ratio logits processor"
+
+     def is_argmax_invariant(self) -> bool:
+         # Forcing </think> can change the argmax token, so this processor is not argmax-invariant
+         return False
+
+     def new_req_logits_processor(
+         self,
+         params: SamplingParams,
+     ) -> Optional[RequestLogitsProcessor]:
+         return ThinkLogitsProcessor(max_len=self.model_max_len, ratio=self.ratio)
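
For reference, a minimal usage sketch (not part of this commit): it assumes a vLLM V1 build that accepts custom logits processor classes via the `logits_processors` argument of `LLM`, and the model path and budget ratio below are placeholder values.

    import os
    from vllm import LLM, SamplingParams
    from logit_processors.logit_ import WrappedPerReqLogitsProcessor

    # Placeholder value: reserve 95% of the post-prompt budget for thinking.
    # Must be set before the engine is built, since the wrapper reads it at init.
    os.environ["VLLM_THINK_BUDGET_RATIO"] = "0.95"

    llm = LLM(
        model="path/to/model",  # placeholder model path
        max_model_len=131072,
        logits_processors=[WrappedPerReqLogitsProcessor],  # assumed V1 custom logits processor hook
    )
    outputs = llm.generate(["..."], SamplingParams(max_tokens=2048))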