Skip to content

Detection Schema

# Closed set of category names, expressed as a Literal type so it can be used
# directly as a field annotation (see COCOCategory.name).
CATEGORY_NAMES = Literal["foo", "bar"]
# Runtime view of the literal values above; its length is used as the upper
# bound for category ids in the models below.
# NOTE(review): if get_args is typing.get_args, this is a tuple rather than a
# list, despite the name — confirm against the file's imports.
CATEGORY_NAMES_LIST = get_args(CATEGORY_NAMES)


class COCOCategory(BaseModel):
    # Schema for a single category entry: an integer id, a name restricted to
    # CATEGORY_NAMES, and a supercategory string.

    # Constrained to the inclusive range [0, len(CATEGORY_NAMES_LIST)]
    # (assuming conint's ge/le are inclusive bounds, as in pydantic).
    # NOTE(review): an inclusive upper bound of len(...) admits one more id
    # value than there are names — confirm whether ids are meant to be
    # 0-based, 1-based, or either.
    id: conint(ge=0, le=len(CATEGORY_NAMES_LIST))
    # Must be one of the CATEGORY_NAMES literal values.
    name: CATEGORY_NAMES
    # Optional; defaults to "object" when not supplied.
    supercategory: str = "object"


class COCOImage(BaseModel):
    # Schema for a single image entry. All four fields are required; no range
    # constraints are declared on any of them here.

    # Integer identifier of the image.
    id: int
    # Image dimensions — presumably in pixels; TODO confirm. Non-positive
    # values are not rejected by this model.
    width: int
    height: int
    # Name of the image file as a plain string; not validated as a path here.
    file_name: str


class COCOAnnotation(BaseModel):
    # Schema for a single annotation entry. Bounds described below assume the
    # con* helpers are pydantic's constrained types — confirm against imports.

    # Integer identifier of the annotation.
    id: int
    # Presumably refers to a COCOImage.id; this model does not check that the
    # referenced image exists.
    image_id: int
    # Must be strictly greater than 0.0 (gt, not ge), so zero-area annotations
    # are rejected.
    area: confloat(gt=0.0)
    # Exactly four integers (min_items == max_items == 4).
    # NOTE(review): element type is int; confirm that fractional box
    # coordinates never need to be represented, and confirm the coordinate
    # convention (e.g. x, y, width, height) with the producer of this data.
    bbox: conlist(int, min_items=4, max_items=4)
    # Integer flag restricted to 0 or 1; defaults to 0.
    iscrowd: conint(ge=0, le=1) = 0
    # Value in the closed interval [0, 1]; defaults to 1.0 when not supplied.
    score: confloat(ge=0, le=1) = 1.0
    # Same inclusive range as COCOCategory.id: [0, len(CATEGORY_NAMES_LIST)].
    # NOTE(review): this admits one more value than there are category names;
    # confirm the intended id base. Membership in an actual category list is
    # not checked by this model.
    category_id: conint(ge=0, le=len(CATEGORY_NAMES_LIST))


class COCODetectionDataset(BaseModel):
    # Top-level container bundling the three record lists of a detection
    # dataset. All three fields are required. Each element is validated
    # against its own model, but no cross-references between the lists
    # (e.g. image_id -> images, category_id -> categories) are checked in
    # this class as shown.

    # Image records.
    images: List[COCOImage]
    # Annotation records.
    annotations: List[COCOAnnotation]
    # Category records.
    categories: List[COCOCategory]