chouss committed on
Commit 990cf91 · verified · 1 Parent(s): 3e0b3b8

Uploading folder contents

__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .alpha_clip_new import *
__pycache__/__init__.cpython-310.pyc ADDED
Binary file (175 Bytes).

__pycache__/__init__.cpython-39.pyc ADDED
Binary file (183 Bytes).

__pycache__/alpha_clip.cpython-310.pyc ADDED
Binary file (9.42 kB).

__pycache__/alpha_clip.cpython-39.pyc ADDED
Binary file (9.39 kB).

__pycache__/alpha_clip_new.cpython-39.pyc ADDED
Binary file (9.4 kB).

__pycache__/model.cpython-310.pyc ADDED
Binary file (28.9 kB).

__pycache__/model.cpython-39.pyc ADDED
Binary file (28.8 kB).

__pycache__/model_new.cpython-39.pyc ADDED
Binary file (27.9 kB).

__pycache__/simple_tokenizer.cpython-310.pyc ADDED
Binary file (5.7 kB).

__pycache__/simple_tokenizer.cpython-39.pyc ADDED
Binary file (5.76 kB).
 
alpha_clip_new.py ADDED
@@ -0,0 +1,250 @@
1
+ import hashlib
2
+ import os
3
+ import urllib
4
+ import warnings
5
+ from typing import Any, Union, List
6
+ from pkg_resources import packaging
7
+
8
+ import torch
9
+ from PIL import Image
10
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
11
+ from tqdm import tqdm
12
+
13
+ from .model_new import build_model
14
+ from .simple_tokenizer import SimpleTokenizer as _Tokenizer
15
+
16
+ try:
17
+ from torchvision.transforms import InterpolationMode
18
+ BICUBIC = InterpolationMode.BICUBIC
19
+ except ImportError:
20
+ BICUBIC = Image.BICUBIC
21
+
22
+
23
+ if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
24
+ warnings.warn("PyTorch version 1.7.1 or higher is recommended")
25
+
26
+
27
+ __all__ = ["available_models", "load", "tokenize"]
28
+ _tokenizer = _Tokenizer()
29
+
30
+ _MODELS = {
31
+ "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
32
+ "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
33
+ "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
34
+ "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
35
+ "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
36
+ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
37
+ "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
38
+ "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
39
+ "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
40
+ }
41
+
42
+
43
+ def _download(url: str, root: str):
44
+ os.makedirs(root, exist_ok=True)
45
+ filename = os.path.basename(url)
46
+
47
+ expected_sha256 = url.split("/")[-2]
48
+ download_target = os.path.join(root, filename)
49
+
50
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
51
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
52
+
53
+ if os.path.isfile(download_target):
54
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
55
+ return download_target
56
+ else:
57
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
58
+
59
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
60
+ with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
61
+ while True:
62
+ buffer = source.read(8192)
63
+ if not buffer:
64
+ break
65
+
66
+ output.write(buffer)
67
+ loop.update(len(buffer))
68
+
69
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
70
+ raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
71
+
72
+ return download_target
73
+
74
+
75
+ def _convert_image_to_rgb(image):
76
+ return image.convert("RGB")
77
+
78
+
79
+ def _transform(n_px):
80
+ return Compose([
81
+ Resize(n_px, interpolation=BICUBIC),
82
+ CenterCrop(n_px),
83
+ _convert_image_to_rgb,
84
+ ToTensor(),
85
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
86
+ ])
87
+
88
+
89
+ def available_models() -> List[str]:
90
+ """Returns the names of available CLIP models"""
91
+ return list(_MODELS.keys())
92
+
93
+
94
+ def load(name: str, alpha_vision_ckpt_pth="None", device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None, lora_adapt=False, rank=16):
95
+ """Load a CLIP model
96
+
97
+ Parameters
98
+ ----------
99
+ name : str
100
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
101
+
102
+ alpha_vision_ckpt_pth: str
103
+ Path to a fine-tuned vision-branch checkpoint loaded into model.visual; set this only for inference, not during training
104
+
105
+ device : Union[str, torch.device]
106
+ The device to put the loaded model
107
+
108
+ jit : bool
109
+ Whether to load the optimized JIT model or more hackable non-JIT model (default).
110
+
111
+ download_root: str
112
+ path to download the model files; by default, it uses "~/.cache/clip"
113
+
114
+ Returns
115
+ -------
116
+ model : torch.nn.Module
117
+ The CLIP model
118
+
119
+ preprocess : Callable[[PIL.Image], torch.Tensor]
120
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
121
+ """
122
+ if name in _MODELS:
123
+ model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
124
+ elif os.path.isfile(name):
125
+ model_path = name
126
+ else:
127
+ raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
128
+
129
+ with open(model_path, 'rb') as opened_file:
130
+ try:
131
+ # loading JIT archive
132
+ model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
133
+ state_dict = None
134
+ except RuntimeError:
135
+ # loading saved state dict
136
+ if jit:
137
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
138
+ jit = False
139
+ state_dict = torch.load(opened_file, map_location="cpu")
140
+
141
+ if not jit:
142
+ model = build_model(state_dict or model.state_dict(), lora_adapt=lora_adapt, rank=rank).to(device)
143
+ if str(device) == "cpu":
144
+ model.float()
145
+ if alpha_vision_ckpt_pth != "None":
146
+ model.visual.load_state_dict(torch.load(alpha_vision_ckpt_pth))
147
+ model.eval() # merge lora params if they exist (inference only)
148
+ return model, _transform(model.visual.input_resolution)
149
+
150
+ # patch the device names
151
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
152
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
153
+
154
+ def _node_get(node: torch._C.Node, key: str):
155
+ """Gets attributes of a node which is polymorphic over return type.
156
+
157
+ From https://github.com/pytorch/pytorch/pull/82628
158
+ """
159
+ sel = node.kindOf(key)
160
+ return getattr(node, sel)(key)
161
+
162
+ def patch_device(module):
163
+ try:
164
+ graphs = [module.graph] if hasattr(module, "graph") else []
165
+ except RuntimeError:
166
+ graphs = []
167
+
168
+ if hasattr(module, "forward1"):
169
+ graphs.append(module.forward1.graph)
170
+
171
+ for graph in graphs:
172
+ for node in graph.findAllNodes("prim::Constant"):
173
+ if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"):
174
+ node.copyAttributes(device_node)
175
+
176
+ model.apply(patch_device)
177
+ patch_device(model.encode_image)
178
+ patch_device(model.encode_text)
179
+
180
+ # patch dtype to float32 on CPU
181
+ if str(device) == "cpu":
182
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
183
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
184
+ float_node = float_input.node()
185
+
186
+ def patch_float(module):
187
+ try:
188
+ graphs = [module.graph] if hasattr(module, "graph") else []
189
+ except RuntimeError:
190
+ graphs = []
191
+
192
+ if hasattr(module, "forward1"):
193
+ graphs.append(module.forward1.graph)
194
+
195
+ for graph in graphs:
196
+ for node in graph.findAllNodes("aten::to"):
197
+ inputs = list(node.inputs())
198
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
199
+ if _node_get(inputs[i].node(), "value") == 5:
200
+ inputs[i].node().copyAttributes(float_node)
201
+
202
+ model.apply(patch_float)
203
+ patch_float(model.encode_image)
204
+ patch_float(model.encode_text)
205
+
206
+ model.float()
207
+ return model, _transform(model.input_resolution.item())
208
+
209
+
210
+ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = True) -> Union[torch.IntTensor, torch.LongTensor]:
211
+ """
212
+ Returns the tokenized representation of given input string(s)
213
+
214
+ Parameters
215
+ ----------
216
+ texts : Union[str, List[str]]
217
+ An input string or a list of input strings to tokenize
218
+
219
+ context_length : int
220
+ The context length to use; all CLIP models use 77 as the context length
221
+
222
+ truncate: bool
223
+ Whether to truncate the text in case its encoding is longer than the context length
224
+
225
+ Returns
226
+ -------
227
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
228
+ We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
229
+ """
230
+ if isinstance(texts, str):
231
+ texts = [texts]
232
+
233
+ sot_token = _tokenizer.encoder["<|startoftext|>"]
234
+ eot_token = _tokenizer.encoder["<|endoftext|>"]
235
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
236
+ if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
237
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
238
+ else:
239
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
240
+
241
+ for i, tokens in enumerate(all_tokens):
242
+ if len(tokens) > context_length:
243
+ if truncate:
244
+ tokens = tokens[:context_length]
245
+ tokens[-1] = eot_token
246
+ else:
247
+ raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
248
+ result[i, :len(tokens)] = torch.tensor(tokens)
249
+
250
+ return result
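The load()/tokenize() helpers above keep the original CLIP interface but thread an extra alpha/depth map through the visual encoder. Below is a minimal usage sketch, not part of the commit: the package name alpha_clip, the image paths, and the all-zero alpha maps are illustrative assumptions; the model is loaded on CPU so every module stays in float32, and loralib plus spconv must be installed because model_new.py imports them at module level.

import torch
from PIL import Image
import alpha_clip  # hypothetical import name for this folder

model, preprocess = alpha_clip.load("ViT-B/16", device="cpu")

# batch of two images; RPE.forward squeezes dim 0 of the coordinate tensor,
# which appears to assume a batch size greater than one
images = torch.stack([preprocess(Image.open(p)) for p in ["dog.jpg", "cat.jpg"]])
alpha = torch.zeros(2, 1, 224, 224)  # placeholder alpha/depth maps, one per image
text = alpha_clip.tokenize(["a photo of a dog", "a photo of a cat"])

with torch.no_grad():
    logits_per_image, logits_per_text = model(images, text, alpha)
    probs = logits_per_image.softmax(dim=-1)
print(probs)  # (2, 2) image-to-text probabilities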
bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
model_new.py ADDED
@@ -0,0 +1,985 @@
1
+ from collections import OrderedDict
2
+ from typing import Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ import loralib as lora
9
+ import math
10
+ import collections
11
+ import torch.nn.init as init
12
+ import spconv.pytorch as spconv
13
+
14
+ class CPEconv(nn.Module):
15
+ def __init__(self, in_channels, spatial_shape, kernel_size=(3, 3, 3), padding=(1, 1, 1)):
16
+ super(CPEconv, self).__init__()
17
+ self.in_channels = in_channels
18
+ self.spatial_shape = 6
19
+ self.conv3d = nn.Conv3d(in_channels, in_channels, kernel_size=kernel_size, padding=padding,groups=in_channels)
20
+ nn.init.zeros_(self.conv3d.weight)
21
+ if self.conv3d.bias is not None:
22
+ nn.init.zeros_(self.conv3d.bias)
23
+
24
+ self.register_buffer('target_tensor_template', torch.zeros(1, in_channels, self.spatial_shape, 1, 1))
25
+
26
+ def generate_3d_coords_from_depth(self, depth_maps):
27
+ # depth_maps is assumed to have shape (B, H, W)
28
+ B, H, W = depth_maps.shape
29
+ z_min = depth_maps.min(dim=-1, keepdim=True)[0].min(dim=-2, keepdim=True)[0] # (B, 1, 1)
30
+ z_max = depth_maps.max(dim=-1, keepdim=True)[0].max(dim=-2, keepdim=True)[0] # (B, 1, 1)
31
+ z = (depth_maps - z_min) / (z_max - z_min + 1e-8)
32
+ # z = depth_maps # use the raw depth value as the z coordinate, shape (B, H, W)
33
+
34
+ return z
35
+
36
+ def forward(self, features, depth):
37
+ #features [197,256,768] depth [256,14,14]
38
+ B,h,w=depth.shape
39
+ _,_,C=features.shape
40
+ D = self.spatial_shape
41
+ features = features[1:,:,:]
42
+ features = features.permute(1,0,2)
43
+ coord=self.generate_3d_coords_from_depth(depth)
44
+ bnd=self.spatial_shape - 1
45
+ coord = (coord *bnd).to(torch.int64)
46
+ coord = (
47
+ coord.clamp(0, bnd) # clamp into bnd
48
+ )
49
+ target_tensor = self.target_tensor_template.expand(B, C, D, h, w).clone()
50
+ # target_tensor = torch.zeros(B, C, D, h, w).to(device=features.device)
51
+ # return 0
52
+
53
+ coord = coord.unsqueeze(1).expand(-1, C, -1, -1) # [B, C, H, W]
54
+ # reshape features so it can be indexed together with coord
55
+ features = features.view(B, h, w, C) # [B, H, W, C]
56
+ features = features.permute(0, 3, 1, 2) # [B, C, H, W]
57
+ features = features.unsqueeze(2).to(dtype=target_tensor.dtype)
58
+ coord = coord.unsqueeze(2)
59
+ # import pdb;pdb.set_trace()
60
+
61
+ # scatter features into target_tensor
62
+ target_tensor = target_tensor.scatter_(2, coord, features)
63
+ # 2. use the coord values as indices to copy features into the matching positions of the target tensor
64
+ # 3. for-loop version of copying the features values into the target tensor
65
+ # for i in range(B):
66
+ # for j in range(h):
67
+ # for k in range(w):
68
+ # # get the index into features
69
+ # index = coord[i, j, k] # take the index from coord
70
+ # target_tensor[i, :,index, j, k] = features[i, j * 14 + k, :] # copy the corresponding features values
71
+ output = self.conv3d(target_tensor).mean(dim=2) #(B,768,14,14)
72
+ output = output.reshape(-1,output.size(0),output.size(1))
73
+ cls_feat = torch.zeros(1,output.size(-2), output.size(-1)).to(device=output.device,dtype=output.dtype)
74
+ out_feat = torch.cat([cls_feat,output],dim=0)
75
+
76
+ return out_feat
77
+ class RPE(torch.nn.Module):
78
+ def __init__(self, patch_num, num_heads):
79
+ super(RPE, self).__init__()
80
+ self.num_heads = num_heads
81
+ self.pos_bnd = patch_num
82
+ self.rpe_num = 2 * self.pos_bnd + 1
83
+ self.rpe_table = torch.nn.Parameter(torch.zeros(3 * self.rpe_num, num_heads))
84
+ # torch.nn.init.trunc_normal_(self.rpe_table, std=0.02)
85
+
86
+ def generate_3d_coords_from_depth(self,depth_maps):
87
+ # depth_maps is assumed to have shape (B, H, W)
88
+ B, H, W = depth_maps.shape
89
+
90
+ # build the i, j index grids of shape (H, W)
91
+ i, j = torch.meshgrid(torch.arange(H, device=depth_maps.device), torch.arange(W, device=depth_maps.device), indexing='ij')
92
+
93
+ # normalize the x and y coordinates
94
+ x = j.float() / (W - 1) # (H, W)
95
+ y = i.float() / (H - 1) # (H, W)
96
+
97
+ # expand x and y to (B, H, W) to match depth_maps
98
+ x = x.unsqueeze(0).expand(B, -1, -1) # (B, H, W)
99
+ y = y.unsqueeze(0).expand(B, -1, -1) # (B, H, W)
100
+
101
+ z_min = depth_maps.min(dim=-1, keepdim=True)[0].min(dim=-2, keepdim=True)[0] # (B, 1, 1)
102
+ z_max = depth_maps.max(dim=-1, keepdim=True)[0].max(dim=-2, keepdim=True)[0] # (B, 1, 1)
103
+ z = (depth_maps - z_min) / (z_max - z_min + 1e-8)
104
+ # z = depth_maps # use the raw depth value as the z coordinate, shape (B, H, W)
105
+
106
+ # combine into 3D coordinates of shape (B, H, W, 3)
107
+ coords = torch.stack([x, y, z], dim=-1) # (B, H, W, 3)
108
+
109
+ return coords
110
+
111
+
112
+ def compute_relative_positions(self,absolute_coords):
113
+ """
114
+ Compute relative position encodings.
115
+ Args:
116
+ absolute_coords: tensor of absolute 3D coordinates, shape (N, 3)
117
+ Returns:
118
+ relative position encodings of shape (N, N, 3)
119
+ """
120
+ # 确保输入是一个张量
121
+ if not isinstance(absolute_coords, torch.Tensor):
122
+ raise ValueError("Input must be a PyTorch tensor.")
123
+ N = absolute_coords.shape[1]
124
+ relative_positions = absolute_coords.unsqueeze(2) - absolute_coords.unsqueeze(1)
125
+
126
+ return relative_positions
127
+
128
+
129
+ def forward(self,depth):
130
+ # B,K,K,3
131
+ # import pdb;pdb.set_trace()
132
+
133
+ depth=self.generate_3d_coords_from_depth(depth).squeeze(0)
134
+ depth=depth.reshape(depth.size(0),-1,depth.size(-1))
135
+ # zeros_tensor = torch.zeros(depth.size(0), 1, depth.size(-1))
136
+ # depth = torch.cat((zeros_tensor,depth), dim=1)
137
+ coord=self.compute_relative_positions(depth)
138
+ # map coord from the [0, 1] range to [0, width] or [0, height]
139
+ # coord = coord.reshape(coord.size(0),-1,coord.size(-1))
140
+ # import pdb;pdb.set_trace()
141
+ coord = (coord * torch.tensor([self.pos_bnd, self.pos_bnd, self.pos_bnd], device=coord.device)).round().long()
142
+ idx = (
143
+ coord.clamp(-self.pos_bnd, self.pos_bnd) # clamp into bnd
144
+ + self.pos_bnd # relative position to positive index
145
+ + torch.arange(3, device=coord.device) * self.rpe_num # x, y, z stride
146
+ )
147
+ out = self.rpe_table.index_select(0, idx.reshape(-1))
148
+ # out = out.reshape(coord.size(0) ,coord.size(1) ,coord.size(2) , -1)
149
+ out = out.view(idx.shape + (-1,)).sum(3)
150
+
151
+ out = out.permute(0, 3, 1, 2) # (N, K, K, H) -> (N, H, K, K)
152
+ # out_new=torch.zeros(out.size(0),out.size(1),out.size(2)+1,out.size(3)+1)
153
+ # out_new[:, :, 1:, 1:] = out
154
+ return out
155
+
156
+ class PositionEmbeddingCoordsSine(nn.Module):
157
+ def __init__(
158
+ self,
159
+ temperature=10000,
160
+ normalize=False,
161
+ scale=None,
162
+ pos_type="fourier",
163
+ d_pos=None,
164
+ d_in=3,
165
+ gauss_scale=1.0,
166
+ ):
167
+ super().__init__()
168
+ self.temperature = temperature
169
+ self.normalize = normalize
170
+ if scale is not None and normalize is False:
171
+ raise ValueError("normalize should be True if scale is passed")
172
+ if scale is None:
173
+ scale = 2 * math.pi
174
+ assert pos_type in ["sine", "fourier"]
175
+ self.pos_type = pos_type
176
+ self.scale = scale
177
+ self.ln = LayerNorm(768)
178
+ if pos_type == "fourier":
179
+ assert d_pos is not None
180
+ assert d_pos % 2 == 0
181
+ # define a gaussian matrix input_ch -> output_ch
182
+ B = torch.empty((d_in, d_pos // 2)).normal_()
183
+ B *= gauss_scale
184
+ # self.gauss_B = nn.Parameter(B)
185
+ self.register_buffer("gauss_B", B)
186
+ self.d_pos = d_pos
187
+ self.trans3d=nn.Conv1d(in_channels=3, out_channels=768, kernel_size=1)
188
+ init.zeros_(self.trans3d.weight)
189
+ if self.trans3d.bias is not None:
190
+ init.zeros_(self.trans3d.bias)
191
+ def get_sine_embeddings(self, xyz, num_channels, input_range):
192
+ ncoords = xyz.shape[1]
193
+ ndim = num_channels // xyz.shape[2]
194
+ if ndim % 2 != 0:
195
+ ndim -= 1
196
+ # automatically handle remainder by assigning it to the first dim
197
+ rems = num_channels - (ndim * xyz.shape[2])
198
+
199
+ assert (
200
+ ndim % 2 == 0
201
+ ), f"Cannot handle odd sized ndim={ndim} where num_channels={num_channels} and xyz={xyz.shape}"
202
+
203
+ final_embeds = []
204
+ prev_dim = 0
205
+
206
+ for d in range(xyz.shape[2]):
207
+ cdim = ndim
208
+ if rems > 0:
209
+ # add remainder in increments of two to maintain even size
210
+ cdim += 2
211
+ rems -= 2
212
+
213
+ if cdim != prev_dim:
214
+ dim_t = torch.arange(cdim, dtype=torch.float32, device=xyz.device)
215
+ dim_t = self.temperature ** (2 * (dim_t // 2) / cdim)
216
+
217
+ # create batch x cdim x ncoords embedding
218
+ raw_pos = xyz[:, :, d]
219
+ if self.scale:
220
+ raw_pos *= self.scale
221
+ pos = raw_pos[:, :, None] / dim_t
222
+ pos = torch.stack(
223
+ (pos[:, :, 0::2].sin(), pos[:, :, 1::2].cos()), dim=3
224
+ ).flatten(2)
225
+ final_embeds.append(pos)
226
+ prev_dim = cdim
227
+
228
+ final_embeds = torch.cat(final_embeds, dim=2)
229
+ return final_embeds
230
+ def get_fourier_embeddings(self, xyz, num_channels=None, input_range=None):
231
+ if num_channels is None:
232
+ num_channels = self.gauss_B.shape[1] * 2
233
+ bsize, npoints = xyz.shape[0], xyz.shape[1]
234
+ assert num_channels > 0 and num_channels % 2 == 0
235
+ d_in, max_d_out = self.gauss_B.shape[0], self.gauss_B.shape[1]
236
+ d_out = num_channels // 2
237
+ # assert d_out <= max_d_out
238
+ assert d_in == xyz.shape[-1]
239
+
240
+ # clone coords so that shift/scale operations do not affect original tensor
241
+ # import pdb;pdb.set_trace()
242
+ ncoords = xyz.shape[1]
243
+ if self.normalize:
244
+ # xyz = shift_scale_points(xyz, src_range=input_range)
245
+ pass
246
+
247
+ xyz *= 2 * torch.pi
248
+ xyz_proj = torch.mm(xyz.view(-1, d_in), self.gauss_B[:, :d_out]).view(
249
+ bsize, npoints, d_out
250
+ )
251
+ final_embeds = [xyz_proj.sin(), xyz_proj.cos()]
252
+
253
+ # return batch x d_pos x npoints embedding
254
+ final_embeds = torch.cat(final_embeds, dim=2)
255
+ # import pdb;pdb.set_trace()
256
+ # final_embeds = self.ln(final_embeds)
257
+ final_embeds = F.normalize(final_embeds, p=2, dim=2)
258
+
259
+ # If necessary, you can permute it back to [batch, 196, 768]
260
+ return final_embeds
261
+
262
+ def forward(self, depth_map, num_channels=None, input_range=None):
263
+ cam_coords_tensor = self.generate_3d_coords_from_depth(depth_map) # (B, H, W, 3)
264
+ # cam_coords_tensor = torch.tensor(cam_coords, dtype=torch.float16) # (B, H, W, 3)
265
+ cam_coords_tensor = cam_coords_tensor.view(cam_coords_tensor.size(0), -1, 3) # (B, H*W, 3)
266
+ xyz=cam_coords_tensor
267
+ # import pdb;pdb.set_trace()
268
+ assert xyz.ndim == 3
269
+ # xyz is batch x npoints x 3
270
+ if self.pos_type == "sine":
271
+ with torch.no_grad():
272
+ return self.get_sine_embeddings(xyz, 768, input_range)
273
+ elif self.pos_type == "fourier":
274
+ with torch.no_grad():
275
+ return self.get_fourier_embeddings(xyz, num_channels, input_range)
276
+ else:
277
+ raise ValueError(f"Unknown {self.pos_type}")
278
+
279
+ def positiontrans3d(self,depth_map):
280
+ cam_coords_tensor = self.generate_3d_coords_from_depth(depth_map) # (B, H, W, 3)
281
+ # cam_coords_tensor = torch.tensor(cam_coords, dtype=torch.float16) # (B, H, W, 3)
282
+ cam_coords_tensor = cam_coords_tensor.view(cam_coords_tensor.size(0), -1, 3) # (B, H*W, 3)
283
+ x=cam_coords_tensor
284
+ x = x.permute(0, 2, 1) # (B, H*W, 3) -> (B, 3, H*W)
285
+ x = self.trans3d(x) # 1D convolution projection, (B, 768, H*W)
286
+ x = x.permute(0, 2, 1) # back to (B, H*W, 768)
287
+ return x
288
+ def generate_3d_coords_from_depth(self, depth_maps):
289
+ # depth_maps is assumed to have shape (B, H, W)
290
+ B, H, W = depth_maps.shape
291
+
292
+ # build the i, j index grids of shape (H, W)
293
+ i, j = torch.meshgrid(torch.arange(H, device=depth_maps.device), torch.arange(W, device=depth_maps.device), indexing='ij')
294
+
295
+ # normalize the x and y coordinates
296
+ x = j.float() / (W - 1) # (H, W)
297
+ y = i.float() / (H - 1) # (H, W)
298
+
299
+ # expand x and y to (B, H, W) to match depth_maps
300
+ x = x.unsqueeze(0).expand(B, -1, -1) # (B, H, W)
301
+ y = y.unsqueeze(0).expand(B, -1, -1) # (B, H, W)
302
+
303
+ z = depth_maps # use the raw depth value as the z coordinate, shape (B, H, W)
304
+
305
+ # combine into 3D coordinates of shape (B, H, W, 3)
306
+ coords = torch.stack([x, y, z], dim=-1) # (B, H, W, 3)
307
+
308
+ return coords
309
+
310
+
311
+ class Bottleneck(nn.Module):
312
+ expansion = 4
313
+
314
+ def __init__(self, inplanes, planes, stride=1):
315
+ super().__init__()
316
+
317
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
318
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
319
+ self.bn1 = nn.BatchNorm2d(planes)
320
+ self.relu1 = nn.ReLU(inplace=True)
321
+
322
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
323
+ self.bn2 = nn.BatchNorm2d(planes)
324
+ self.relu2 = nn.ReLU(inplace=True)
325
+
326
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
327
+
328
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
329
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
330
+ self.relu3 = nn.ReLU(inplace=True)
331
+
332
+ self.downsample = None
333
+ self.stride = stride
334
+
335
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
336
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
337
+ self.downsample = nn.Sequential(OrderedDict([
338
+ ("-1", nn.AvgPool2d(stride)),
339
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
340
+ ("1", nn.BatchNorm2d(planes * self.expansion))
341
+ ]))
342
+
343
+ def forward(self, x: torch.Tensor):
344
+ identity = x
345
+
346
+ out = self.relu1(self.bn1(self.conv1(x)))
347
+ out = self.relu2(self.bn2(self.conv2(out)))
348
+ out = self.avgpool(out)
349
+ out = self.bn3(self.conv3(out))
350
+
351
+ if self.downsample is not None:
352
+ identity = self.downsample(x)
353
+
354
+ out += identity
355
+ out = self.relu3(out)
356
+ return out
357
+
358
+
359
+ class AttentionPool2d(nn.Module):
360
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
361
+ super().__init__()
362
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
363
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
364
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
365
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
366
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
367
+ self.num_heads = num_heads
368
+
369
+ def forward(self, x):
370
+ x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
371
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
372
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
373
+ x, _ = F.multi_head_attention_forward(
374
+ query=x[:1], key=x, value=x,
375
+ embed_dim_to_check=x.shape[-1],
376
+ num_heads=self.num_heads,
377
+ q_proj_weight=self.q_proj.weight,
378
+ k_proj_weight=self.k_proj.weight,
379
+ v_proj_weight=self.v_proj.weight,
380
+ in_proj_weight=None,
381
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
382
+ bias_k=None,
383
+ bias_v=None,
384
+ add_zero_attn=False,
385
+ dropout_p=0,
386
+ out_proj_weight=self.c_proj.weight,
387
+ out_proj_bias=self.c_proj.bias,
388
+ use_separate_proj_weight=True,
389
+ training=self.training,
390
+ need_weights=False
391
+ )
392
+ return x.squeeze(0)
393
+
394
+
395
+ class ModifiedResNet(nn.Module):
396
+ """
397
+ A ResNet class that is similar to torchvision's but contains the following changes:
398
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
399
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
400
+ - The final pooling layer is a QKV attention instead of an average pool
401
+ """
402
+
403
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
404
+ super().__init__()
405
+ self.output_dim = output_dim
406
+ self.input_resolution = input_resolution
407
+
408
+ # the 3-layer stem
409
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
410
+ self.conv1_alpha = nn.Conv2d(in_channels=1, out_channels=width // 2, kernel_size=3, stride=2, padding=1, bias=False)
411
+ self.bn1 = nn.BatchNorm2d(width // 2)
412
+ self.relu1 = nn.ReLU(inplace=True)
413
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
414
+ self.bn2 = nn.BatchNorm2d(width // 2)
415
+ self.relu2 = nn.ReLU(inplace=True)
416
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
417
+ self.bn3 = nn.BatchNorm2d(width)
418
+ self.relu3 = nn.ReLU(inplace=True)
419
+ self.avgpool = nn.AvgPool2d(2)
420
+
421
+ # residual layers
422
+ self._inplanes = width # this is a *mutable* variable used during construction
423
+ self.layer1 = self._make_layer(width, layers[0])
424
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
425
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
426
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
427
+
428
+ embed_dim = width * 32 # the ResNet feature dimension
429
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
430
+
431
+ def _make_layer(self, planes, blocks, stride=1):
432
+ layers = [Bottleneck(self._inplanes, planes, stride)]
433
+
434
+ self._inplanes = planes * Bottleneck.expansion
435
+ for _ in range(1, blocks):
436
+ layers.append(Bottleneck(self._inplanes, planes))
437
+
438
+ return nn.Sequential(*layers)
439
+
440
+ def forward(self, x, alpha=None):
441
+ def stem(x):
442
+ x = self.relu1(self.bn1(self.conv1(x) + self.conv1_alpha(alpha)))
443
+ x = self.relu2(self.bn2(self.conv2(x)))
444
+ x = self.relu3(self.bn3(self.conv3(x)))
445
+ x = self.avgpool(x)
446
+ return x
447
+
448
+ x = x.type(self.conv1.weight.dtype)
449
+ x = stem(x)
450
+ x = self.layer1(x)
451
+ x = self.layer2(x)
452
+ x = self.layer3(x)
453
+ x = self.layer4(x)
454
+ x = self.attnpool(x)
455
+
456
+ return x
457
+
458
+
459
+ class LayerNorm(nn.LayerNorm):
460
+ """Subclass torch's LayerNorm to handle fp16."""
461
+
462
+ def forward(self, x: torch.Tensor):
463
+ orig_type = x.dtype
464
+ ret = super().forward(x.type(torch.float32))
465
+ return ret.type(orig_type)
466
+
467
+
468
+ class QuickGELU(nn.Module):
469
+ def forward(self, x: torch.Tensor):
470
+ return x * torch.sigmoid(1.702 * x)
471
+
472
+ class Attention(nn.Module):
473
+ def __init__(
474
+ self,
475
+ dim,
476
+ num_heads=8,
477
+ qkv_bias=True,
478
+ scaled_cosine=False,
479
+ scale_heads=False,
480
+ logit_scale_max=math.log(1. / 0.01),
481
+ attn_drop=0.,
482
+ proj_drop=0.,
483
+ lora_adapt=False,
484
+ rank=16,
485
+ patch_num=16
486
+ ):
487
+ super().__init__()
488
+ self.scaled_cosine = scaled_cosine
489
+ self.scale_heads = scale_heads
490
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
491
+ self.num_heads = num_heads
492
+ self.head_dim = dim // num_heads
493
+ self.scale = self.head_dim ** -0.5
494
+ self.logit_scale_max = logit_scale_max
495
+ self.use_rel_pos = True # flag for whether relative position encoding is used
496
+ self.rpe = RPE(patch_num=patch_num,num_heads=self.num_heads)
497
+ self.rpe.requires_grad=True
498
+ # import pdb;pdb.set_trace()
499
+ # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
500
+ if lora_adapt:
501
+ print("!!!!!!!!!!using lora for qkv projection!!!!!!!!!!")
502
+ self.in_proj = lora.MergedLinear(dim, 3*dim, r=rank, enable_lora=[True, False, True])
503
+ else:
504
+ self.in_proj = nn.Linear(dim, dim * 3)
505
+ # self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
506
+ # if qkv_bias:
507
+ # self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
508
+ # else:
509
+ # self.in_proj_bias = None
510
+
511
+ if self.scaled_cosine:
512
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
513
+ else:
514
+ self.logit_scale = None
515
+ self.attn_drop = nn.Dropout(attn_drop)
516
+ if self.scale_heads:
517
+ self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
518
+ else:
519
+ self.head_scale = None
520
+ self.out_proj = nn.Linear(dim, dim) if not lora_adapt else lora.Linear(dim, dim, r=rank)
521
+ self.out_drop = nn.Dropout(proj_drop)
522
+
523
+ def forward(self, x, attn_mask = None,depth=None):
524
+ L, N, C = x.shape
525
+ q, k, v = self.in_proj(x).chunk(3, dim=-1)
526
+ q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
527
+ k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
528
+ v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
529
+
530
+ if self.logit_scale is not None:
531
+ attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
532
+ logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
533
+ attn = attn.view(N, self.num_heads, L, L) * logit_scale
534
+ attn = attn.view(-1, L, L)
535
+ else:
536
+ q = q * self.scale
537
+ attn = torch.bmm(q, k.transpose(-2, -1))
538
+
539
+ if depth is not None:
540
+ depth=depth.squeeze(1)
541
+ res= self.rpe(depth)
542
+ res=res.reshape(-1,res.size(-2),res.size(-1))
543
+ # import pdb;pdb.set_trace()
544
+ attn[:,1:,1:]=attn[:,1:,1:]+res
545
+
546
+ if attn_mask is not None:
547
+ if attn_mask.dtype == torch.bool:
548
+ new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
549
+ new_attn_mask.masked_fill_(attn_mask, float("-inf"))
550
+ attn_mask = new_attn_mask
551
+ attn += attn_mask
552
+
553
+ attn = attn.softmax(dim=-1)
554
+ attn = self.attn_drop(attn)
555
+
556
+ x = torch.bmm(attn, v)
557
+ if self.head_scale is not None:
558
+ x = x.view(N, self.num_heads, L, C) * self.head_scale
559
+ x = x.view(-1, L, C)
560
+ x = x.transpose(0, 1).reshape(L, N, C)
561
+ x = self.out_proj(x)
562
+ x = self.out_drop(x)
563
+ return x, attn
564
+
565
+
566
+ class CustomResidualAttentionBlock(nn.Module):
567
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, lora_adapt=False, rank=16,patch_num=16):
568
+ super().__init__()
569
+
570
+ self.attn = Attention(d_model, n_head, lora_adapt=lora_adapt, rank=rank,patch_num=patch_num)
571
+ self.ln_1 = LayerNorm(d_model)
572
+ self.mlp = nn.Sequential(OrderedDict([
573
+ ("c_fc", nn.Linear(d_model, d_model * 4) if not lora_adapt else lora.Linear(d_model, d_model*4, r=rank)),
574
+ ("gelu", QuickGELU()),
575
+ ("c_proj", nn.Linear(d_model * 4, d_model) if not lora_adapt else lora.Linear(d_model*4, d_model, r=rank))
576
+ ]))
577
+ self.ln_2 = LayerNorm(d_model)
578
+ self.ln_cpe = LayerNorm(d_model)
579
+ self.attn_mask = attn_mask
580
+ self.cpe=CPEconv(d_model,patch_num)
581
+
582
+
583
+ def attention(self, x: torch.Tensor,depth=None):
584
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
585
+ return self.attn(x, attn_mask=self.attn_mask,depth=depth)
586
+
587
+
588
+ def forward(self, x: torch.Tensor, return_attn=False,depth=None):
589
+ # import pdb;pdb.set_trace()
590
+ # x ([577, 50, 1024])
591
+ # if None:
592
+ shortcut=x
593
+ # import pdb;pdb.set_trace()
594
+ # shapes=x.shape
595
+ # x= x.reshape(-1,x.size(-1))
596
+ # import pdb;pdb.set_trace()
597
+ # cposi = self.cpe(x, depth).reshape(shapes)
598
+ cposi = self.cpe(self.ln_cpe(x), depth)
599
+ x =shortcut+cposi
600
+
601
+ attn_out, attn = self.attention(self.ln_1(x),depth)
602
+ x = x + attn_out
603
+ x = x + self.mlp(self.ln_2(x))
604
+ if return_attn:
605
+ return x, attn
606
+ else:
607
+ return x
608
+
609
+ class ResidualAttentionBlock(nn.Module):
610
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
611
+ super().__init__()
612
+
613
+ self.attn = nn.MultiheadAttention(d_model, n_head)
614
+ self.ln_1 = LayerNorm(d_model)
615
+ self.mlp = nn.Sequential(OrderedDict([
616
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
617
+ ("gelu", QuickGELU()),
618
+ ("c_proj", nn.Linear(d_model * 4, d_model))
619
+ ]))
620
+ self.ln_2 = LayerNorm(d_model)
621
+ self.attn_mask = attn_mask
622
+
623
+ def attention(self, x: torch.Tensor):
624
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
625
+ return self.attn(x, x, x, attn_mask=self.attn_mask)[0]
626
+
627
+ def forward(self, x: torch.Tensor):
628
+ x = x + self.attention(self.ln_1(x))
629
+ x = x + self.mlp(self.ln_2(x))
630
+ return x
631
+
632
+ class Transformer(nn.Module):
633
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
634
+ super().__init__()
635
+ self.width = width
636
+ self.layers = layers
637
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
638
+
639
+ def forward(self, x: torch.Tensor):
640
+ return self.resblocks(x)
641
+
642
+ class CustomTransformer(nn.Module):
643
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, lora_adapt=False, rank=16,patch_num=16):
644
+ super().__init__()
645
+ self.width = width
646
+ self.layers = layers
647
+ self.resblocks = nn.Sequential(*[CustomResidualAttentionBlock(width, heads, attn_mask, lora_adapt=lora_adapt, rank=rank,patch_num=patch_num) for _ in range(layers)])
648
+
649
+ def forward(self, x: torch.Tensor, return_attn=False,depth=None):
650
+ # import pdb;pdb.set_trace()
651
+ if return_attn:
652
+ for i, block in enumerate(self.resblocks):
653
+ if i == len(self.resblocks) - 1:
654
+ return block(x, return_attn=True,depth=depth)
655
+ else:
656
+ x = block(x,depth=depth)
657
+ assert False
658
+ for block in self.resblocks:
659
+ # import pdb;pdb.set_trace()
660
+ x = block(x, depth=depth) # pass depth to every block
661
+ return x
662
+ # return self.resblocks(x)
663
+
664
+ # ////////////////////////////////////////////////////////////////////////////////////////////
665
+ class VisionTransformer(nn.Module):
666
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int, lora_adapt=False, rank=16):
667
+ super().__init__()
668
+ self.input_resolution = input_resolution
669
+ self.output_dim = output_dim
670
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
671
+ self.conv1_alpha = nn.Conv2d(in_channels=1, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
672
+ nn.init.zeros_(self.conv1_alpha.weight)
673
+ scale = width ** -0.5
674
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
675
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
676
+ # self.depth_positional_embedding = nn.Parameter(scale * torch.zeros((input_resolution // patch_size) ** 2, width)) # depth encoding for the alpha channel
677
+ # self.depth_positional_embedding = PositionEmbeddingCoordsSine(temperature=10000,
678
+ # normalize=True,
679
+ # scale=2 * torch.pi,
680
+ # pos_type="fourier",
681
+ # d_pos=768, # example output dimension
682
+ # d_in=3,
683
+ # gauss_scale=1.0
684
+ # )
685
+ # self.sine_positional_embedding = PositionEmbeddingCoordsSine(temperature=10000,
686
+ # normalize=True,
687
+ # scale=2 * torch.pi,
688
+ # pos_type="sine",
689
+ # d_pos=768, # example output dimension
690
+ # d_in=3,
691
+ # gauss_scale=1.0
692
+ # )
693
+ # self.large_positional_embedding = PositionEmbeddingCoordsSine(temperature=10000,
694
+ # normalize=True,
695
+ # scale=2 * torch.pi,
696
+ # pos_type="sine",
697
+ # d_pos=1024, # example output dimension
698
+ # d_in=3,
699
+ # gauss_scale=1.0
700
+ # )
701
+ # self.depth_mlp=nn.Linear(768,768)
702
+ # nn.init.zeros_(self.depth_mlp.weight)
703
+ # if self.depth_mlp.bias is not None:
704
+ # nn.init.zeros_(self.depth_mlp.bias)
705
+ self.patch_size=patch_size
706
+
707
+ self.ln_pre = LayerNorm(width)
708
+ self.transformer = CustomTransformer(width, layers, heads, lora_adapt=lora_adapt, rank=rank,patch_num=input_resolution // patch_size)
709
+
710
+ self.ln_post = LayerNorm(width)
711
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
712
+
713
+ def forward(self, x: torch.Tensor, alpha=None, return_attn=False,pos_embed=None):
714
+ # import pdb;pdb.set_trace()
715
+ x = self.conv1(x) # shape = [*, width, grid, grid]
716
+ # ASSUME alpha is always not None!
717
+ # import pdb;pdb.set_trace()
718
+ # if pos_embed == "nodepth":
719
+ # pass
720
+ # else:
721
+ # x = x + self.conv1_alpha(alpha)
722
+ # import pdb;pdb.set_trace()
723
+
724
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
725
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
726
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
727
+ # import pdb;pdb.set_trace()
728
+ alpha_resized = F.adaptive_avg_pool2d(alpha, (self.input_resolution // self.patch_size, self.input_resolution // self.patch_size))
729
+ # alpha_flattened = alpha_resized.flatten(start_dim=2).permute(0, 2, 1)
730
+ alpha_resized = alpha_resized.squeeze(1)
731
+ # x[:, 1:] += self.depth_positional_embedding.to(x.dtype) * alpha_flattened
732
+ # import pdb;pdb.set_trace()
733
+ # if pos_embed == "fourier":
734
+ # depth_embedding = self.depth_positional_embedding(alpha_resized)
735
+ # x[:, 1:] +=self.depth_mlp(depth_embedding)
736
+ # elif pos_embed == "sine":
737
+ # depth_embedding = self.sine_positional_embedding(alpha_resized)
738
+ # x[:, 1:] +=self.depth_mlp(depth_embedding)
739
+ # elif pos_embed == "3d":
740
+ # depth_embedding = self.depth_positional_embedding.positiontrans3d(alpha_resized)
741
+ # x[:, 1:] +=self.depth_mlp(depth_embedding)
742
+
743
+ x = x + self.positional_embedding.to(x.dtype)
744
+ x = self.ln_pre(x)
745
+ # import pdb;pdb.set_trace()
746
+ x = x.permute(1, 0, 2) # NLD -> LND
747
+ if return_attn:
748
+ x, attn_last = self.transformer(x, return_attn=True,depth=alpha_resized)
749
+ else:
750
+ x = self.transformer(x, return_attn=False,depth=alpha_resized)
751
+ x = x.permute(1, 0, 2) # LND -> NLD
752
+
753
+ x = self.ln_post(x[:, 0, :])
754
+
755
+ if self.proj is not None:
756
+ x = x @ self.proj
757
+ if return_attn:
758
+ return x, attn_last
759
+ else:
760
+ return x
761
+ # /////////////////////////////////////////////////////////////////////////////////////////////////////
762
+
763
+ class CLIP(nn.Module):
764
+ def __init__(self,
765
+ embed_dim: int,
766
+ # vision
767
+ image_resolution: int,
768
+ vision_layers: Union[Tuple[int, int, int, int], int],
769
+ vision_width: int,
770
+ vision_patch_size: int,
771
+ # text
772
+ context_length: int,
773
+ vocab_size: int,
774
+ transformer_width: int,
775
+ transformer_heads: int,
776
+ transformer_layers: int,
777
+ lora_adapt = False,
778
+ rank = 16,
779
+ ):
780
+ super().__init__()
781
+
782
+ self.context_length = context_length
783
+
784
+ if isinstance(vision_layers, (tuple, list)):
785
+ vision_heads = vision_width * 32 // 64
786
+ self.visual = ModifiedResNet(
787
+ layers=vision_layers,
788
+ output_dim=embed_dim,
789
+ heads=vision_heads,
790
+ input_resolution=image_resolution,
791
+ width=vision_width
792
+ )
793
+ else:
794
+ vision_heads = vision_width // 64
795
+ self.visual = VisionTransformer(
796
+ input_resolution=image_resolution,
797
+ patch_size=vision_patch_size,
798
+ width=vision_width,
799
+ layers=vision_layers,
800
+ heads=vision_heads,
801
+ output_dim=embed_dim,
802
+ lora_adapt=lora_adapt,
803
+ rank=rank
804
+ )
805
+
806
+ self.transformer = Transformer(
807
+ width=transformer_width,
808
+ layers=transformer_layers,
809
+ heads=transformer_heads,
810
+ attn_mask=self.build_attention_mask()
811
+ )
812
+
813
+ self.vocab_size = vocab_size
814
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
815
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
816
+ self.ln_final = LayerNorm(transformer_width)
817
+
818
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
819
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
820
+
821
+ self.initialize_parameters()
822
+
823
+ def initialize_parameters(self):
824
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
825
+ nn.init.normal_(self.positional_embedding, std=0.01)
826
+
827
+ if isinstance(self.visual, ModifiedResNet):
828
+ if self.visual.attnpool is not None:
829
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
830
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
831
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
832
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
833
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
834
+
835
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
836
+ for name, param in resnet_block.named_parameters():
837
+ if name.endswith("bn3.weight"):
838
+ nn.init.zeros_(param)
839
+
840
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
841
+ attn_std = self.transformer.width ** -0.5
842
+ fc_std = (2 * self.transformer.width) ** -0.5
843
+ for block in self.transformer.resblocks:
844
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
845
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
846
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
847
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
848
+
849
+ if self.text_projection is not None:
850
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
851
+
852
+ def build_attention_mask(self):
853
+ # lazily create causal attention mask, with full attention between the vision tokens
854
+ # pytorch uses additive attention mask; fill with -inf
855
+ mask = torch.empty(self.context_length, self.context_length)
856
+ mask.fill_(float("-inf"))
857
+ mask.triu_(1) # zero out the lower diagonal
858
+ return mask
859
+
860
+ @property
861
+ def dtype(self):
862
+ if not hasattr(self.visual, "conv1"):
863
+ return self.visual.module.conv1.weight.dtype
864
+ return self.visual.conv1.weight.dtype
865
+
866
+ def encode_image(self, image, alpha):
867
+ assert alpha is not None
868
+ return self.visual(image.type(self.dtype), alpha.type(self.dtype))
869
+
870
+ def encode_text(self, text):
871
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
872
+
873
+ x = x + self.positional_embedding.type(self.dtype)
874
+ x = x.permute(1, 0, 2) # NLD -> LND
875
+ x = self.transformer(x)
876
+ x = x.permute(1, 0, 2) # LND -> NLD
877
+ x = self.ln_final(x).type(self.dtype)
878
+
879
+ # x.shape = [batch_size, n_ctx, transformer.width]
880
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
881
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
882
+
883
+ return x
884
+
885
+
886
+
887
+ def forward(self, image, text, alpha):
888
+ image_features = self.encode_image(image, alpha)
889
+ text_features = self.encode_text(text)
890
+
891
+ # normalized features
892
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
893
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
894
+
895
+ # cosine similarity as logits
896
+ logit_scale = self.logit_scale.exp()
897
+ logits_per_image = logit_scale * image_features @ text_features.t()
898
+ logits_per_text = logits_per_image.t()
899
+
900
+ # shape = [global_batch_size, global_batch_size]
901
+ return logits_per_image, logits_per_text
902
+
903
+
904
+ def convert_weights(model: nn.Module):
905
+ """Convert applicable model parameters to fp16"""
906
+
907
+ def _convert_weights_to_fp16(l):
908
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
909
+ l.weight.data = l.weight.data.half()
910
+ if l.bias is not None:
911
+ l.bias.data = l.bias.data.half()
912
+
913
+ if isinstance(l, nn.MultiheadAttention):
914
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
915
+ tensor = getattr(l, attr)
916
+ if tensor is not None:
917
+ tensor.data = tensor.data.half()
918
+
919
+ for name in ["text_projection", "proj"]:
920
+ if hasattr(l, name):
921
+ attr = getattr(l, name)
922
+ if attr is not None:
923
+ attr.data = attr.data.half()
924
+
925
+ model.apply(_convert_weights_to_fp16)
926
+
927
+
928
+ def build_model(state_dict: dict, lora_adapt=False, rank=16):
929
+ vit = "visual.proj" in state_dict
930
+
931
+ if vit:
932
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
933
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
934
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
935
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
936
+ image_resolution = vision_patch_size * grid_size
937
+ else:
938
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
939
+ vision_layers = tuple(counts)
940
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
941
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
942
+ vision_patch_size = None
943
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
944
+ image_resolution = output_width * 32
945
+
946
+ embed_dim = state_dict["text_projection"].shape[1]
947
+ context_length = state_dict["positional_embedding"].shape[0]
948
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
949
+ transformer_width = state_dict["ln_final.weight"].shape[0]
950
+ transformer_heads = transformer_width // 64
951
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
952
+
953
+ # always load lora version
954
+ model = CLIP(
955
+ embed_dim,
956
+ image_resolution, vision_layers, vision_width, vision_patch_size,
957
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers,
958
+ lora_adapt=lora_adapt, rank=rank,
959
+ )
960
+
961
+ for key in ["input_resolution", "context_length", "vocab_size"]:
962
+ if key in state_dict:
963
+ del state_dict[key]
964
+ # para_wb to linear
965
+ new_state_dict = collections.OrderedDict()
966
+ for k, v in state_dict.items():
967
+ if 'visual' in k:
968
+ if 'in_proj_weight' in k:
969
+ new_state_dict[k.replace('in_proj_weight', 'in_proj.weight')] = v
970
+ elif 'in_proj_bias' in k:
971
+ new_state_dict[k.replace('in_proj_bias', 'in_proj.bias')] = v
972
+ else:
973
+ new_state_dict[k] = v
974
+ else:
975
+ new_state_dict[k] = v
976
+
977
+ state_dict = new_state_dict
978
+ # add rgba_conv_weight
979
+ if 'visual.conv1_alpha.weight' not in state_dict.keys(): # zero initialization on alpha channel
980
+ rgb_weight = state_dict['visual.conv1.weight'].clone().detach()
981
+ rgba_weight = torch.zeros_like(rgb_weight)[:, 0:1, :, :]
982
+ state_dict['visual.conv1_alpha.weight'] = rgba_weight
983
+ convert_weights(model)
984
+ model.load_state_dict(state_dict, strict=False)
985
+ return model.eval()
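The depth-aware attention bias in model_new.py works by quantizing relative (x, y, z) offsets between patches into a learned per-head table (RPE) and adding the gathered values onto the attention logits. Below is a self-contained sketch of that indexing with toy sizes, written against plain torch so it does not need loralib or spconv; the variable names and sizes are illustrative, not the module's own.

import torch

pos_bnd = 4                                      # plays the role of patch_num above
rpe_num = 2 * pos_bnd + 1                        # discrete relative offsets per axis
num_heads = 2
rpe_table = torch.zeros(3 * rpe_num, num_heads)  # one sub-table per x/y/z axis

coords = torch.rand(1, 16, 3)                    # (B, N, 3) normalized xyz per patch
rel = coords.unsqueeze(2) - coords.unsqueeze(1)  # (B, N, N, 3) pairwise offsets
coord = (rel * pos_bnd).round().long()
idx = (
    coord.clamp(-pos_bnd, pos_bnd)   # clamp offsets into [-pos_bnd, pos_bnd]
    + pos_bnd                        # shift to non-negative table indices
    + torch.arange(3) * rpe_num      # move each axis into its own sub-table
)
bias = rpe_table.index_select(0, idx.reshape(-1))   # gather per-axis biases
bias = bias.view(idx.shape + (-1,)).sum(3)          # (B, N, N, num_heads), summed over xyz
print(bias.permute(0, 3, 1, 2).shape)               # (B, num_heads, N, N) additive attention bias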
simple_tokenizer.py ADDED
@@ -0,0 +1,132 @@
1
+ import gzip
2
+ import html
3
+ import os
4
+ from functools import lru_cache
5
+
6
+ import ftfy
7
+ import regex as re
8
+
9
+
10
+ @lru_cache()
11
+ def default_bpe():
12
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13
+
14
+
15
+ @lru_cache()
16
+ def bytes_to_unicode():
17
+ """
18
+ Returns a list of utf-8 bytes and a corresponding list of unicode strings.
19
+ The reversible bpe codes work on unicode strings.
20
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22
+ This is a signficant percentage of your normal, say, 32K bpe vocab.
23
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
25
+ """
26
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27
+ cs = bs[:]
28
+ n = 0
29
+ for b in range(2**8):
30
+ if b not in bs:
31
+ bs.append(b)
32
+ cs.append(2**8+n)
33
+ n += 1
34
+ cs = [chr(n) for n in cs]
35
+ return dict(zip(bs, cs))
36
+
37
+
38
+ def get_pairs(word):
39
+ """Return set of symbol pairs in a word.
40
+ Word is represented as tuple of symbols (symbols being variable-length strings).
41
+ """
42
+ pairs = set()
43
+ prev_char = word[0]
44
+ for char in word[1:]:
45
+ pairs.add((prev_char, char))
46
+ prev_char = char
47
+ return pairs
48
+
49
+
50
+ def basic_clean(text):
51
+ text = ftfy.fix_text(text)
52
+ text = html.unescape(html.unescape(text))
53
+ return text.strip()
54
+
55
+
56
+ def whitespace_clean(text):
57
+ text = re.sub(r'\s+', ' ', text)
58
+ text = text.strip()
59
+ return text
60
+
61
+
62
+ class SimpleTokenizer(object):
63
+ def __init__(self, bpe_path: str = default_bpe()):
64
+ self.byte_encoder = bytes_to_unicode()
65
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67
+ merges = merges[1:49152-256-2+1]
68
+ merges = [tuple(merge.split()) for merge in merges]
69
+ vocab = list(bytes_to_unicode().values())
70
+ vocab = vocab + [v+'</w>' for v in vocab]
71
+ for merge in merges:
72
+ vocab.append(''.join(merge))
73
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74
+ self.encoder = dict(zip(vocab, range(len(vocab))))
75
+ self.decoder = {v: k for k, v in self.encoder.items()}
76
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
77
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79
+
80
+ def bpe(self, token):
81
+ if token in self.cache:
82
+ return self.cache[token]
83
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84
+ pairs = get_pairs(word)
85
+
86
+ if not pairs:
87
+ return token+'</w>'
88
+
89
+ while True:
90
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91
+ if bigram not in self.bpe_ranks:
92
+ break
93
+ first, second = bigram
94
+ new_word = []
95
+ i = 0
96
+ while i < len(word):
97
+ try:
98
+ j = word.index(first, i)
99
+ new_word.extend(word[i:j])
100
+ i = j
101
+ except:
102
+ new_word.extend(word[i:])
103
+ break
104
+
105
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
106
+ new_word.append(first+second)
107
+ i += 2
108
+ else:
109
+ new_word.append(word[i])
110
+ i += 1
111
+ new_word = tuple(new_word)
112
+ word = new_word
113
+ if len(word) == 1:
114
+ break
115
+ else:
116
+ pairs = get_pairs(word)
117
+ word = ' '.join(word)
118
+ self.cache[token] = word
119
+ return word
120
+
121
+ def encode(self, text):
122
+ bpe_tokens = []
123
+ text = whitespace_clean(basic_clean(text)).lower()
124
+ for token in re.findall(self.pat, text):
125
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127
+ return bpe_tokens
128
+
129
+ def decode(self, tokens):
130
+ text = ''.join([self.decoder[token] for token in tokens])
131
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132
+ return text
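simple_tokenizer.py is the standard CLIP byte-pair-encoding tokenizer. A quick round-trip sketch, assuming ftfy and regex are installed and the script is run from inside this folder so the bundled bpe_simple_vocab_16e6.txt.gz is found:

from simple_tokenizer import SimpleTokenizer

tok = SimpleTokenizer()
ids = tok.encode("A photo of a cat.")
print(ids)              # BPE ids, without the <|startoftext|>/<|endoftext|> markers
print(tok.decode(ids))  # roughly "a photo of a cat . " (lower-cased, re-spaced)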