Skip to content

process_image_classification

process_image_classification(config_file, model_checkpoint, image_path, output_folder, device='cuda:0')

Process an image with a classification model and write the prediction to a COCO-style annotations file (`coco_results.json`).

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `config_file` | `str` | Path to the mmcls model config file. | *required* |
| `model_checkpoint` | `str` | PyTorch `.pth` checkpoint file. | *required* |
| `image_path` | `str` | Path to the input image. | *required* |
| `output_folder` | `str` | Folder where the annotation results are written. If it does not exist, it will be created. | *required* |
| `device` | `str` | Device used for inference. | `'cuda:0'` |
Source code in src/stages/model/explore/process_image_classification.py
def process_image_classification(
    config_file: str,
    model_checkpoint: str,
    image_path: str,
    output_folder: str,
    device: str = "cuda:0",
) -> None:
    """Process image with [classification](https://gradiant.github.io/ai-project-template/supported_tasks/#classification) model generating a COCO annotations file.

    The image is classified once. The result is written to
    ``<output_folder>/coco_results.json`` as a COCO-style dict with:

    - ``images``: a single entry (``id=1``) holding the image size, its file
      name and the predicted label as ``category_id``.
    - ``categories``: one entry per model class, with ids matching the
      index of the class in ``model.CLASSES``.

    Args:
        config_file:
            path to mmcls model config file
        model_checkpoint:
            pytorch `.pth` checkpoint file
        image_path:
            path to input image
        output_folder:
            annotation results will be generated in this folder.
            If it does not exist, it will be created.
        device:
            device used for inference
    """
    import json
    from pathlib import Path

    import mmcv
    from mmcls.apis import inference_model, init_model

    model = init_model(config_file, model_checkpoint, device=device)
    img = mmcv.imread(str(image_path))
    # Image arrays are indexed (rows, cols[, channels]): height comes first.
    height, width = img.shape[:2]
    results = inference_model(model, img)

    output_path = Path(output_folder)
    output_path.mkdir(parents=True, exist_ok=True)

    images = [
        dict(
            id=1,
            width=width,
            height=height,
            file_name=Path(image_path).name,
            # The predicted label may be a NumPy integer (e.g. from np.argmax),
            # which the json module cannot serialize; coerce to a plain int.
            category_id=int(results["pred_label"]),
        )
    ]
    # Category ids are the class indices, so they line up with category_id above.
    categories = [
        dict(supercategory="object", id=i, name=category)
        for i, category in enumerate(model.CLASSES)
    ]

    coco_result = dict(categories=categories, images=images)

    with open(output_path / "coco_results.json", "w", encoding="utf-8") as f:
        json.dump(coco_result, f)