Sliced inference for Cell Segmentation (Large Images)

Process the input image

Convert the input image to RGB format, where the blue channel holds the nuclear signal and the green channel holds the membrane signal.

drawing

Run sliced inference using the SAHI library

Import required libraries

from sahi.utils.detectron2 import Detectron2TestConstants

# import required functions, classes
from sahi import AutoDetectionModel
from sahi_predict import get_sliced_prediction, predict, get_prediction
from sahi.utils.file import download_from_url
from sahi.utils.cv import read_image
from cellotype.predict import Detectron2DetectionModel
import numpy as np
from PIL import Image
Image.MAX_IMAGE_PIXELS = None
from tqdm import tqdm
import time
import logging
from deepcell.utils.plot_utils import create_rgb_image
from deepcell.utils.plot_utils import make_outline_overlay
from skimage import io
import matplotlib.pyplot as plt

def get_mask_from_result(result):
    """Rasterise a prediction result into a single labelled mask.

    Args:
        result: Prediction result exposing ``object_prediction_list`` (each
            entry providing ``mask.bool_mask``) and ``image.size``.

    Returns:
        A float NumPy array whose shape is ``result.image.size`` reversed.
        Pixels selected by the k-th prediction (0-based) hold ``k + 1``;
        pixels selected by no prediction stay 0. Where predictions overlap,
        the one later in the list wins.
    """
    predictions = result.object_prediction_list
    # Allocate with the reported size and transpose, so the array axes line
    # up with the per-object masks (assumes size is (width, height), as for
    # a PIL image -- TODO confirm).
    label_map = np.zeros(result.image.size).T
    for label, prediction in enumerate(tqdm(predictions), start=1):
        selected = prediction.mask.bool_mask == True
        label_map[selected] = label
    return label_map

Run sliced inference

# Build the detection model from the trained weights and its config file.
detection_model = Detectron2DetectionModel(
    model_path='cellotype/models/tissuenet_model_0019999.pth',
    config_path='cellotype/configs/maskdino_R50_bs16_50ep_4s_dowsample1_2048.yaml',
    confidence_threshold=0.4,
    image_size=512,
    device="cuda:0",  # or 'cpu'
    channels=3,
)
img_path = 'data/example/example_large.png'

logging.info("Predicting...")
start = time.time()
# Run the model slice by slice over the image: 512x512 slices with an
# overlap ratio of 0.05 along each axis.
result = get_sliced_prediction(
    img_path,
    detection_model,
    slice_height=512,
    slice_width=512,
    overlap_height_ratio=0.05,
    overlap_width_ratio=0.05,
    perform_standard_pred=False,
)
# Report how long the sliced prediction took (this is what `start` is for).
logging.info("Prediction finished in %.1f s", time.time() - start)

If GPU memory is insufficient, reduce image_size, slice_height, and slice_width to 256.

Save the result and visualization

# Build the labelled mask once and reuse it for both the saved TIFF and the
# overlay figure; rebuilding it would repeat the whole per-object loop.
output = get_mask_from_result(result)
im = Image.fromarray(output.astype('int32'))
im.save('figures/example_sahi.tif')

logging.info("Saving visualization results...")

# Keep only channels 2 and 1 of the input image, in that order, matching the
# ['blue', 'green'] colours passed to create_rgb_image below.
img_data = io.imread(img_path)[:, :, [2, 1]]
# Add a leading batch axis: (H, W, 2) -> (1, H, W, 2).
img_data = np.reshape(img_data, (1, img_data.shape[0], img_data.shape[1], 2))
rgb_image = create_rgb_image(img_data, channel_colors=['blue', 'green'])

# Give the mask batch and channel axes as well: (H, W) -> (1, H, W, 1).
mask_batch = np.reshape(output, (1, output.shape[0], output.shape[1], 1))
overlay = make_outline_overlay(rgb_image, predictions=mask_batch)[0]

fig = plt.figure(figsize=(30, 10))
ax = fig.add_subplot(111)
ax.imshow(overlay)
plt.savefig('figures/example_large.png', dpi=600)
drawing