| | |
| | |
| | |
| |
|
| | import os |
| | join = os.path.join |
| | import argparse |
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import tifffile as tif |
| | import monai |
| | from tqdm import tqdm |
| | from utils.postprocess import mask_overlay |
| | from monai.transforms import Activations, AddChanneld, AsChannelFirstd, AsDiscrete, Compose, EnsureTyped, EnsureType |
| | from models.unicell_modules import MiT_B2_UNet_MultiHead, MiT_B3_UNet_MultiHead |
| | import matplotlib.pyplot as plt |
| | from skimage import io, exposure, segmentation, morphology |
| | from utils.postprocess import watershed_post |
| | from utils.multi_task_sliding_window_inference import multi_task_sliding_window_inference |
| | import gradio as gr |
| |
|
def normalize_channel(img, lower=0.1, upper=99.9, min_range=0.001):
    """Contrast-stretch one image channel into the uint8 range.

    The ``lower``/``upper`` percentiles are computed over the non-zero pixels
    only, so a zero background does not drag the lower bound down. The
    channel is then rescaled with ``skimage.exposure.rescale_intensity`` using
    those percentiles as the input range and ``'uint8'`` as the output range.

    Args:
        img: Array holding a single channel (callers in this file pass 2-D
            slices).
        lower: Lower percentile (0-100) used as the input minimum.
        upper: Upper percentile (0-100) used as the input maximum.
        min_range: The percentile span must exceed this for the channel to be
            stretched; otherwise it is treated as flat.

    Returns:
        The rescaled channel. If the channel has no non-zero pixels, or its
        percentile span is not larger than ``min_range``, ``img`` itself is
        returned unchanged (original dtype).
    """
    non_zero_vals = img[np.nonzero(img)]
    if non_zero_vals.size == 0:
        # np.percentile raises on an empty array; an all-zero channel has
        # nothing to stretch, so handle it like the flat case below.
        return img
    low, high = np.percentile(non_zero_vals, [lower, upper])
    if high - low > min_range:
        return exposure.rescale_intensity(img, in_range=(low, high), out_range='uint8')
    return img
| |
|
def preprocess(img_data):
    """Convert an image to a 3-channel uint8 array with per-channel stretching.

    Shape handling:
      * 2-D input, or 3-D input with a single channel, is treated as
        grayscale and replicated to three channels.
      * 3-D input with more than three channels keeps only the first three.
      * 3-D input with two channels is zero-padded to three.

    Each channel that has any non-zero pixel is stretched between its 1st and
    99th non-zero percentiles via ``normalize_channel``; all-zero channels
    stay zero.

    Args:
        img_data: Image array of shape (H, W) or (H, W, C).

    Returns:
        uint8 array of shape (H, W, 3).

    Raises:
        ValueError: If ``img_data`` is neither 2-D nor 3-D.
    """
    if img_data.ndim == 3 and img_data.shape[-1] == 1:
        img_data = img_data[:, :, 0]
    if img_data.ndim == 2:
        img_data = np.repeat(img_data[:, :, np.newaxis], 3, axis=-1)
    elif img_data.ndim != 3:
        raise ValueError(f"expected a 2-D or 3-D image, got shape {img_data.shape}")

    height, width = img_data.shape[:2]
    pre_img_data = np.zeros((height, width, 3), dtype=np.uint8)
    # Channels beyond the third (e.g. alpha) are ignored; missing ones stay zero.
    for i in range(min(3, img_data.shape[-1])):
        img_channel_i = img_data[:, :, i]
        if np.any(img_channel_i):
            pre_img_data[:, :, i] = normalize_channel(img_channel_i, lower=1, upper=99)
    return pre_img_data
| |
|
| |
|
# Loaded models keyed by device string, so the checkpoint is read only once
# per device instead of on every request.
_MODEL_CACHE = {}


def _load_model(device):
    """Return the MiT-B2 multi-head UNet with ./model.pth weights, in eval mode on *device*."""
    key = str(device)
    model = _MODEL_CACHE.get(key)
    if model is None:
        model = MiT_B2_UNet_MultiHead(in_channels=3, out_channels=3, regress_class=1, img_size=256).to(device)
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load('./model.pth', map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        _MODEL_CACHE[key] = model
    return model


def _overlay_boundaries(image, mask):
    """Return a copy of *image* with the (dilated) inner boundaries of *mask* painted green."""
    boundary = segmentation.find_boundaries(mask, connectivity=1, mode='inner')
    # Thicken the 1-px boundary so it is visible in the overlay.
    boundary = morphology.binary_dilation(boundary, morphology.disk(1))
    overlay = image.copy()
    overlay[boundary, 0] = 0
    overlay[boundary, 1] = 255
    overlay[boundary, 2] = 0
    return overlay


def inference(pre_img_data):
    """Segment cell instances in a preprocessed image.

    Args:
        pre_img_data: (H, W, 3) image array; it is scaled by its maximum
            before being fed to the model. It is not modified.

    Returns:
        (test_pred_mask, overlay): the uint16 instance-label mask produced by
        ``watershed_post``, and a copy of ``pre_img_data`` with the instance
        boundaries drawn in green.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = _load_model(device)

    max_val = np.max(pre_img_data)
    if max_val > 0:
        test_npy = pre_img_data / max_val
    else:
        # A blank image would otherwise divide by zero and feed NaNs to the model.
        test_npy = np.zeros(pre_img_data.shape, dtype=np.float64)

    with torch.no_grad():
        # HWC -> NCHW float32 tensor with a batch dimension of 1.
        test_tensor = torch.from_numpy(np.expand_dims(test_npy, 0)).permute(0, 3, 1, 2).float().to(device)
        val_pred, val_pred_dist = multi_task_sliding_window_inference(
            inputs=test_tensor, roi_size=(256, 256), sw_batch_size=8, predictor=model)

    # NOTE(review): index 1 of the class prediction is assumed to be the
    # cell/foreground channel and val_pred_dist a distance-style regression
    # map — confirm against the model's training setup.
    val_seg_inst = watershed_post(val_pred_dist.squeeze(1).cpu().numpy(), val_pred.squeeze(1).cpu().numpy()[:, 1])
    test_pred_mask = val_seg_inst.squeeze().astype(np.uint16)

    return test_pred_mask, _overlay_boundaries(pre_img_data, test_pred_mask)
| |
|
| |
|
def _read_image(path):
    """Read *path* with tifffile for .tif/.tiff files (any case), otherwise with skimage.io."""
    if path.lower().endswith(('.tif', '.tiff')):
        return tif.imread(path)
    return io.imread(path)


def _to_rgb_uint8(img_data):
    """Stretch each channel (0.1-99.9 non-zero percentiles) into an (H, W, 3) uint8 array.

    Grayscale input (2-D or single-channel 3-D) is replicated to three
    channels; only the first three channels of wider input are used; missing
    channels and all-zero channels stay zero.
    """
    if img_data.ndim == 3 and img_data.shape[-1] == 1:
        img_data = img_data[:, :, 0]
    if img_data.ndim not in (2, 3):
        raise ValueError(f"expected a 2-D or 3-D image, got shape {img_data.shape}")

    height, width = img_data.shape[:2]
    rgb = np.zeros((height, width, 3), dtype=np.uint8)
    if img_data.ndim == 2:
        # normalize_channel cannot handle an all-zero channel; leave it black.
        if np.any(img_data):
            rgb[:] = normalize_channel(img_data, lower=0.1, upper=99.9)[:, :, np.newaxis]
        return rgb
    for i in range(min(3, img_data.shape[-1])):
        channel = img_data[:, :, i]
        if np.any(channel):
            rgb[:, :, i] = normalize_channel(channel, lower=0.1, upper=99.9)
    return rgb


def predict(img):
    """Gradio callback: segment an uploaded image file.

    Args:
        img: The uploaded file from ``gr.File`` — either an object exposing
            the on-disk path as ``.name`` (what the original code assumed) or
            a path string / path-like object.

    Returns:
        (seg_overlay, path): the boundary overlay returned by ``inference``
        and the path of the zlib-compressed ``segmentation.tiff`` label file
        written to the current working directory.
    """
    img_name = os.fspath(img) if isinstance(img, (str, os.PathLike)) else img.name
    print('##########', img_name)

    img_data = _read_image(img_name)
    pre_img_data = _to_rgb_uint8(img_data)
    seg_labels, seg_overlay = inference(pre_img_data)

    out_path = join(os.getcwd(), 'segmentation.tiff')
    # NOTE(review): fixed output name in the CWD — concurrent requests would
    # overwrite each other's result file.
    tif.imwrite(out_path, seg_labels, compression='zlib')
    return seg_overlay, out_path
| |
|
# Web UI: upload an image file, get back the boundary overlay and a
# downloadable TIFF of instance labels (both produced by ``predict``).
unicell_api = gr.Interface(
    predict,
    inputs=gr.File(label="Input image (png, bmp, jpg, tif, tiff)"),
    outputs=[gr.Image(label="Segmentation overlay"), gr.File(label="Download segmentation")],
    title="UniCell Online Demo",
    examples=['demo.png', 'demo.tif'],
)

# Launch only when executed as a script, so importing this module (e.g. to
# reuse ``predict``/``inference``) does not start a web server.
if __name__ == "__main__":
    unicell_api.launch()
| |
|
| |
|