| | import re |
| | import torch |
| | import requests |
| | from PIL import Image, ImageDraw |
| | from transformers import AutoProcessor, Kosmos2_5ForConditionalGeneration |
| |
|
# Checkpoint id plus the device/dtype settings reused by the rest of the script.
repo = "microsoft/kosmos-2.5"
device = "cuda:0"
dtype = torch.bfloat16

# Load the model with `device` as its device map and `dtype` as its torch
# dtype, then the processor published under the same checkpoint id.
model = Kosmos2_5ForConditionalGeneration.from_pretrained(
    repo,
    device_map=device,
    torch_dtype=dtype,
)
processor = AutoProcessor.from_pretrained(repo)
| |
|
| | |
# Sample receipt image stored in the model's Hugging Face repository.
url = "https://huggingface.co/microsoft/kosmos-2.5/resolve/main/receipt_00008.png"
# Fail fast: the timeout keeps a stalled connection from hanging the script,
# and raise_for_status() surfaces HTTP errors here instead of letting PIL
# choke on an error page later.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)
| |
|
| | |
# Task prompt handed to the processor together with the image.
prompt = "<ocr>"
inputs = processor(text=prompt, images=image, return_tensors="pt")

# Take `height` and `width` out of the processor output so they are not
# forwarded to generate(); they are only needed to compute the scale factors.
height = inputs.pop("height")
width = inputs.pop("width")

# Ratio of the original image size to the processor-reported size, used to
# map predicted box coordinates back onto the original image.
raw_width, raw_height = image.size
scale_height = raw_height / height
scale_width = raw_width / width
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
# Move every processor output onto the target device; None entries are kept.
inputs = {
    name: (tensor.to(device) if tensor is not None else None)
    for name, tensor in inputs.items()
}
# Cast the image patches to the dtype the model was loaded with.
inputs["flattened_patches"] = inputs["flattened_patches"].to(dtype)

generated_ids = model.generate(**inputs, max_new_tokens=1024)

# Decode the generated token ids back to text, dropping special tokens.
generated_text = processor.batch_decode(
    generated_ids,
    skip_special_tokens=True,
)
def post_process(y, scale_height, scale_width, task_prompt=None):
    """Turn decoded model text into rescaled quadrilateral rows.

    Args:
        y: Decoded text for one sample; the prompt is removed from it first.
        scale_height: Multiplier applied to every y coordinate.
        scale_width: Multiplier applied to every x coordinate.
        task_prompt: Prompt used for generation. When omitted, the
            module-level ``prompt`` is used, which is what the original
            three-argument call relied on.

    Returns:
        If the prompt contains "<md>", the text with the prompt removed.
        Otherwise the concatenation of one
        "x0,y0,x1,y0,x1,y1,x0,y1,<text>" row per bounding-box tag, where
        <text> is whatever followed the tag (rows are not given an extra
        separator). Boxes with x0 >= x1 or y0 >= y1 are dropped together
        with their text.
    """
    if task_prompt is None:
        # Backward-compatible default: fall back to the script-level prompt.
        task_prompt = prompt

    y = y.replace(task_prompt, "")
    if "<md>" in task_prompt:
        return y

    pattern = r"<bbox><x_\d+><y_\d+><x_\d+><y_\d+></bbox>"
    # The pattern has no capture groups, so split() yields the text before the
    # first tag followed by one piece per tag; drop that leading piece so the
    # texts line up one-to-one with the tags.
    texts = re.split(pattern, y)[1:]
    boxes = [
        [int(number) for number in re.findall(r"\d+", tag)]
        for tag in re.findall(pattern, y)
    ]

    rows = []
    for (x0, y0, x1, y1), text in zip(boxes, texts):
        if x0 >= x1 or y0 >= y1:
            # Degenerate box: skip it and its text.
            continue
        x0 = int(x0 * scale_width)
        y0 = int(y0 * scale_height)
        x1 = int(x1 * scale_width)
        y1 = int(y1 * scale_height)
        # Corners in order: top-left, top-right, bottom-right, bottom-left.
        rows.append(f"{x0},{y0},{x1},{y0},{x1},{y1},{x0},{y1},{text}")
    return "".join(rows)
| |
|
# Only the first decoded sequence is post-processed and shown.
first_text = generated_text[0]
output_text = post_process(first_text, scale_height, scale_width)
print(output_text)
| |
|
# Outline each parsed quadrilateral on the image, then save the result.
draw = ImageDraw.Draw(image)
for row in output_text.split("\n"):
    fields = row.split(",")
    if len(fields) < 8:
        # Too few fields to hold four corner points (e.g. a blank line).
        continue
    try:
        polygon = [int(value) for value in fields[:8]]
    except ValueError:
        # A text line that merely contains many commas (for instance the
        # markdown returned for a "<md>" prompt) is not a coordinate row.
        continue
    draw.polygon(polygon, outline="red")
image.save("output.png")
| |
|