| | import json |
| | import os |
| | from pathlib import Path |
| |
|
| | import gradio as gr |
| | import numpy as np |
| | import torch |
| | from monai.bundle import ConfigParser |
| |
|
| | from utils import page_utils |
| |
|
# Inference configuration for the MONAI bundle. JSON text is UTF-8 by spec, so
# decode explicitly instead of relying on the platform's default encoding
# (which is not UTF-8 on e.g. Windows).
with open("configs/inference.json", encoding="utf-8") as f:
    inference_config = json.load(f)
| |
|
# Prefer the first CUDA GPU when one is present; otherwise fall back to CPU.
# The chosen device is stored in the bundle config dict before it is parsed
# (presumably referenced from the config, e.g. as `@device` — not visible here).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
inference_config["device"] = device
| |
|
# Build the MONAI bundle from the in-memory inference config plus metadata,
# then instantiate the components this app needs, in the original order:
# inferer, network definition, preprocessing and postprocessing pipelines.
parser = ConfigParser()
parser.read_config(f=inference_config)
parser.read_meta(f="configs/metadata.json")

inference, network, preprocess, postprocess = (
    parser.get_parsed_content(section)
    for section in ("inferer", "network_def", "preprocessing", "postprocessing")
)
| |
|
# Opt-in half precision via the USE_FP16 environment variable. Environment
# values are strings, so a raw "0" or "false" would be truthy; treat the usual
# "off" spellings (and unset/empty) as False and anything else as True.
use_fp16 = os.environ.get('USE_FP16', '').strip().lower() not in {
    '', '0', 'false', 'no', 'off',
}
| |
|
# Load the trained weights. map_location remaps stored tensors onto the
# selected device, so a checkpoint saved on a GPU can still be loaded on a
# CPU-only host instead of failing during deserialization.
# NOTE(review): torch.load unpickles the file. It is a local artifact here;
# consider weights_only=True if the installed torch version supports it.
state_dict = torch.load("models/model.pt", map_location=device)
network.load_state_dict(state_dict, strict=True)

network = network.to(device)
network.eval()

# Half precision is only enabled when CUDA is available.
if use_fp16 and torch.cuda.is_available():
    network = network.half()
| |
|
# RGB colour (0-255 per channel) for each label value that may appear in the
# mask given to visualize_instance_seg_mask; label N maps to entry N below.
label2color = dict(enumerate([
    (0, 0, 0),
    (225, 24, 69),
    (135, 233, 17),
    (0, 87, 233),
    (242, 202, 25),
    (137, 49, 239),
]))
| |
|
# Sample PNGs offered as clickable examples in the UI. Path.glob yields
# entries in filesystem-dependent order, so sort for a stable gallery order.
example_files = sorted(Path("sample_data").glob("*.png"))
| |
|
def visualize_instance_seg_mask(mask, palette=None):
    """Colourise a 2-D label mask using a label -> RGB lookup table.

    Args:
        mask: 2-D array-like of labels, indexed as (row, column). It is
            converted with ``np.asarray``, so lists and CPU tensors also work.
        palette: optional mapping from label value to an (R, G, B) tuple in
            the 0-255 range. Defaults to the module-level ``label2color``.

    Returns:
        float64 array of shape (H, W, 3) with colours scaled to [0, 1].

    Raises:
        KeyError: if ``mask`` contains a label that ``palette`` lacks.
    """
    if palette is None:
        palette = label2color
    mask = np.asarray(mask)
    image = np.zeros((mask.shape[0], mask.shape[1], 3))
    # One vectorised assignment per distinct label, instead of a Python-level
    # loop over every pixel (which is extremely slow for full-size images).
    for label in np.unique(mask):
        image[mask == label] = palette[label]
    return image / 255
| |
|
def query_image(img):
    """Segment one image and return a colour overlay of the predicted types.

    Args:
        img: value from the Gradio ``gr.Image(type="filepath")`` input, i.e.
            the path of the uploaded image; it is fed to ``preprocess`` under
            the ``"image"`` key.

    Returns:
        numpy array: an equal-weight blend of the preprocessed image
        (channels moved last) and the colourised ``type_map``. The blend is
        flipped left-right and then rotated 90 degrees counter-clockwise.
    """
    batch = preprocess({"image": img})

    image = batch['image'].to(device)
    if use_fp16 and torch.cuda.is_available():
        image = image.half()
    batch['image'] = image

    # The inferer expects a leading batch dimension; add it for the forward
    # pass and strip it from every prediction afterwards.
    with torch.no_grad():
        pred = inference(image.unsqueeze(dim=0), network)

    batch["pred"] = pred
    for key in list(pred):
        pred[key] = pred[key].squeeze(dim=0)

    batch = postprocess(batch)

    colour_mask = visualize_instance_seg_mask(batch["type_map"].squeeze())
    # Move channels last (C, H, W -> H, W, C) so it lines up with the mask.
    base = batch["image"].permute(1, 2, 0).cpu().numpy()
    overlay = base * 0.5 + colour_mask * 0.5

    # Re-orient the overlay for display.
    return np.rot90(np.fliplr(overlay), k=1)
| |
|
| | |
# HTML shown as the interface description, decoded as UTF-8.
html_content = Path('index.html').read_text(encoding='utf-8')
| |
|
# One image-path input, one image output, styled with the Kalbe brand hue.
_demo_inputs = [gr.Image(type="filepath")]
_demo_theme = gr.themes.Default(
    primary_hue=page_utils.KALBE_THEME_COLOR,
    secondary_hue=page_utils.KALBE_THEME_COLOR,
).set(
    button_primary_background_fill="*primary_600",
    button_primary_background_fill_hover="*primary_500",
    button_primary_text_color="white",
)

demo = gr.Interface(
    query_image,
    inputs=_demo_inputs,
    outputs="image",
    theme=_demo_theme,
    description=html_content,
    examples=example_files,
)
| |
|
# Enable request queuing (at most 10 pending events) and start the server.
queued_demo = demo.queue(max_size=10)
queued_demo.launch()
| |
|