process_image_segmentation
Process image with segmentation model.
process_image_segmentation(config_file, model_checkpoint, image_path, output_folder, device='cuda:0')
Process image with segmentation model.
Process image with segmentation model generating image result and COCO annotations file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config_file |
str |
path to mmdet model config file |
required |
model_checkpoint |
str |
pytorch `.pth` checkpoint file |
required |
image_path |
str |
path to input image |
required |
output_folder |
str |
plotting and annotation results will be generated in this folder. If it does not exist, it will be created. |
required |
device |
str |
device used for inference |
'cuda:0' |
Source code in src/stages/model/explore/process_image_segmentation.py
def process_image_segmentation(
    config_file: str,
    model_checkpoint: str,
    image_path: str,
    output_folder: str,
    device: str = "cuda:0",
) -> None:
    """Process image with segmentation model.

    Process image with [segmentation](https://gradiant.github.io/ai-project-template/supported_tasks/#segmentation)
    model generating image result and COCO annotations file.

    The image is run through the model once; for every class known to the
    model a binary mask is extracted from the prediction, RLE-encoded and
    stored as one COCO annotation. The result is written to
    ``coco_results.json`` inside ``output_folder``.

    NOTE(review): only the annotations file is written by this function as
    shown here; no plotted image is saved - confirm whether that is intended.

    Args:
        config_file:
            path to mmdet model config file
        model_checkpoint:
            pytorch `.pth` checkpoint file
        image_path:
            path to input image
        output_folder:
            plotting and annotation results will be generated in this folder.
            If it does not exist, it will be created.
        device:
            device used for inference
    """
    import mmcv
    import numpy as np
    from mmseg.apis import inference_segmentor, init_segmentor
    from pycocotools.mask import area, encode

    model = init_segmentor(config_file, model_checkpoint, device=device)
    img = mmcv.imread(str(image_path))
    # Image arrays are indexed rows first, so the shape is
    # (height, width[, channels]); unpacking it the other way round would
    # write swapped dimensions into the COCO "images" entry.
    height, width = img.shape[:2]
    result = inference_segmentor(model, img)[0]

    output_dir = Path(output_folder)
    output_dir.mkdir(parents=True, exist_ok=True)

    images = [
        dict(id=1, width=width, height=height, file_name=Path(image_path).name)
    ]
    categories = [
        dict(supercategory="object", id=i, name=category)
        for i, category in enumerate(model.CLASSES)
    ]

    annotations = []
    for cat_idx, category in enumerate(categories):
        # One annotation per class: every pixel predicted as this class.
        bin_mask = (result == cat_idx).astype(np.uint8)
        # pycocotools' encoder expects a Fortran-ordered (column-major) array.
        rle_mask = encode(np.asfortranarray(bin_mask))
        # "counts" comes back as bytes, which json cannot serialize.
        rle_mask["counts"] = rle_mask["counts"].decode("utf-8")
        annotations.append(
            dict(
                id=cat_idx,
                segmentation=rle_mask,
                area=int(area(rle_mask)),
                image_id=1,
                category_id=category["id"],
            )
        )

    coco_result = dict(categories=categories, annotations=annotations, images=images)
    with open(output_dir / "coco_results.json", "w", encoding="utf-8") as f:
        json.dump(coco_result, f)