

Train Config Evaluation App.

The pyodi train-config evaluation app evaluates how well a given MMDetection anchor generator configuration suits the ground truth boxes of your dataset under a specific training pipeline.

Procedure

The training performance of an object detection model depends on how well the generated anchors match the ground truth bounding boxes. This simple application gives an intuition of that match: it recreates the training preprocessing conditions, such as image resizing or padding, and computes several metrics based on the largest Intersection over Union (IoU) between each ground truth box and the provided anchors.
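As a rough illustration of the resizing step, the sketch below rescales box widths and heights from an original image size to the model input size. It assumes MMDetection-style resize semantics: with keep_ratio, a single factor bounded by both the long and the short edge of input_size; without it, independent x and y factors. rescale_boxes is a hypothetical helper, not pyodi's scale_bbox_dimensions, whose exact behavior may differ.

```python
from typing import Tuple

import numpy as np


def rescale_boxes(
    boxes_wh: np.ndarray,
    image_size: Tuple[int, int],
    input_size: Tuple[int, int] = (1280, 720),
    keep_ratio: bool = False,
) -> np.ndarray:
    """Rescale (width, height) box dimensions from image_size to input_size.

    Hypothetical helper assuming MMDetection-style resize semantics.
    """
    img_w, img_h = image_size
    in_w, in_h = input_size
    if keep_ratio:
        # One factor for both axes: fit the long edge to max(input_size) and the
        # short edge to min(input_size), whichever constraint is tighter.
        factor = min(
            max(in_w, in_h) / max(img_w, img_h),
            min(in_w, in_h) / min(img_w, img_h),
        )
        scale = np.array([factor, factor])
    else:
        # Independent factors: the image is stretched to exactly input_size.
        scale = np.array([in_w / img_w, in_h / img_h])
    return boxes_wh * scale
```

For a 100x50 box in a 640x480 image, this gives 200x75 without keep_ratio and 150x75 with it, since the short-edge constraint (720 / 480 = 1.5) is the tighter one.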

Each ground truth bounding box is assigned the anchor with which it shares the largest IoU. We call this maximum IoU between a ground truth box and the generated anchor set the box's overlap.
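A minimal NumPy sketch of the overlap definition, assuming boxes and anchors are given as (x1, y1, x2, y2) corners. iou_matrix and max_overlap are hypothetical illustrations; in the source code below this role is played by pyodi's get_max_overlap, whose implementation may differ.

```python
import numpy as np


def iou_matrix(boxes: np.ndarray, anchors: np.ndarray) -> np.ndarray:
    """Pairwise IoU between boxes (N, 4) and anchors (M, 4) in x1, y1, x2, y2 format."""
    # Intersection corners via broadcasting: (N, 1, 2) against (1, M, 2).
    top_left = np.maximum(boxes[:, None, :2], anchors[None, :, :2])
    bottom_right = np.minimum(boxes[:, None, 2:], anchors[None, :, 2:])
    inter_wh = np.clip(bottom_right - top_left, 0, None)
    inter = inter_wh[..., 0] * inter_wh[..., 1]
    area_boxes = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    area_anchors = (anchors[:, 2] - anchors[:, 0]) * (anchors[:, 3] - anchors[:, 1])
    union = area_boxes[:, None] + area_anchors[None, :] - inter
    return inter / union


def max_overlap(boxes: np.ndarray, anchors: np.ndarray) -> np.ndarray:
    """Overlap of each ground truth box: its largest IoU over all anchors."""
    return iou_matrix(boxes, anchors).max(axis=1)
```

For example, a box that coincides with an anchor gets an overlap of 1.0, while a 10x20 box whose best anchor is a 10x10 square aligned with its top half gets 0.5.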

Example usage:

```bash
pyodi train-config evaluation \
    $TINY_COCO_ANIMAL/annotations/train.json \
    $TINY_COCO_ANIMAL/resources/anchor_config.py \
    --input-size [1280,720]
```

The app provides four different plots:


Cumulative Overlap

It shows the cumulative distribution function of the overlap values. This view shows what percentage of bounding boxes have a very low overlap with the generated anchors, and vice versa.
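Reading the cumulative plot amounts to evaluating an empirical CDF. A minimal sketch, with the hypothetical helper names overlap_cdf and cdf_at, assuming overlaps is a 1-D array holding one overlap value per ground truth box:

```python
import numpy as np


def overlap_cdf(overlaps: np.ndarray):
    """Sorted overlap values and the fraction of boxes at or below each of them."""
    values = np.sort(np.asarray(overlaps, dtype=float))
    fractions = np.arange(1, len(values) + 1) / len(values)
    return values, fractions


def cdf_at(overlaps: np.ndarray, threshold: float) -> float:
    """Fraction of ground truth boxes whose overlap is at most threshold."""
    return float(np.mean(np.asarray(overlaps) <= threshold))
```

With overlaps of 0.2, 0.45, 0.6 and 0.8, cdf_at(overlaps, 0.5) returns 0.5: half of the boxes never reach an IoU above 0.5 with any anchor.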

It can be very useful for choosing the positive and negative IoU thresholds for your training. These thresholds determine whether a ground truth bounding box is taken into account in the loss function or is discarded and considered as background.
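To turn the overlap distribution into a quick threshold check, the sketch below counts how many ground truth boxes have a best anchor at or above a positive threshold, below a negative threshold, or in between. The parameter names mirror pos_iou_thr and neg_iou_thr from MMDetection's MaxIoUAssigner, but the values 0.5 and 0.4 are illustrative only. The sketch also ignores assigner options such as min_pos_iou that can still match low-overlap boxes. split_by_thresholds is a hypothetical helper.

```python
import numpy as np


def split_by_thresholds(overlaps, pos_iou_thr: float = 0.5, neg_iou_thr: float = 0.4):
    """Fractions of ground truth boxes whose best anchor would be positive,
    ignored (between the thresholds) or negative (background)."""
    overlaps = np.asarray(overlaps, dtype=float)
    positive = overlaps >= pos_iou_thr
    negative = overlaps < neg_iou_thr
    ignored = ~positive & ~negative
    return {
        "positive": float(positive.mean()),
        "ignored": float(ignored.mean()),
        "negative": float(negative.mean()),
    }
```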

Bounding Box Distribution

It shows a scatter plot of bounding box width vs. height, where the color of each point represents the overlap value assigned to that bounding box. This plot makes patterns easy to spot, such as low overlap values for large bounding boxes. We could take this into account and generate larger anchors to improve the matching.
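One way to see why larger anchors help large boxes is the centred IoU, which compares shapes only by assuming the box and the anchor share the same centre. This is a simplification, not what the app computes. Real anchors sit on a stride grid, so the actual overlap can be at most this value. centered_iou is a hypothetical helper.

```python
import numpy as np


def centered_iou(box_wh, anchor_wh):
    """IoU of a box and an anchor given as (width, height), both centred at the same point."""
    box_wh = np.asarray(box_wh, dtype=float)
    anchor_wh = np.asarray(anchor_wh, dtype=float)
    # With a shared centre, the intersection is the smaller extent along each axis.
    inter = np.minimum(box_wh[..., 0], anchor_wh[..., 0]) * np.minimum(
        box_wh[..., 1], anchor_wh[..., 1]
    )
    union = box_wh[..., 0] * box_wh[..., 1] + anchor_wh[..., 0] * anchor_wh[..., 1] - inter
    return inter / union
```

For a 1000x700 box, a 512x512 anchor gives a centred IoU of about 0.37. A 1024x1024 anchor raises it to about 0.67.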

Scale and Mean Overlap

This plot shows a histogram that groups bounding boxes into bins of similar scale, together with the mean overlap of each bin. It helps visualize how overlap decays as scale increases, as noted above.
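The binning behind this plot can be sketched with NumPy as follows. The sketch uses equal-width bins and works on any per-box value. Feeding it a scale such as np.sqrt(width * height) of the scaled boxes is an assumption that may not match pyodi's own scale definition. binned_mean_overlap is a hypothetical helper.

```python
import numpy as np


def binned_mean_overlap(values, overlaps, n_bins: int = 10):
    """Split values into n_bins equal-width bins and return the bin edges, the box
    count per bin and the mean overlap per bin (NaN for empty bins)."""
    values = np.asarray(values, dtype=float)
    overlaps = np.asarray(overlaps, dtype=float)
    edges = np.linspace(values.min(), values.max(), n_bins + 1)
    # Digitizing against the inner edges maps every value to a bin in [0, n_bins - 1],
    # so the maximum value lands in the last bin instead of overflowing.
    idx = np.digitize(values, edges[1:-1])
    counts = np.bincount(idx, minlength=n_bins)
    sums = np.bincount(idx, weights=overlaps, minlength=n_bins)
    with np.errstate(invalid="ignore", divide="ignore"):
        means = sums / counts
    return edges, counts, means
```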

Log Ratio and Mean Overlap

Similarly to the previous plot, it shows a histogram of bounding box log ratios together with the mean overlap of each bin. It is useful for seeing which box ratios have trouble matching the generated anchors. In this example, boxes with negative log ratios, where the width is much larger than the height, have very small overlaps. This matches the pattern observed in the Bounding Box Distribution plot, where all boxes placed near the x axis have low overlaps.
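The log ratio here is assumed to be log(height / width). This matches the text above, where wide boxes give negative values, and MMDetection's definition of anchor ratios as height over width. The sketch below measures how far each box's log ratio lies from the closest anchor ratio. A large gap hints that no anchor shape fits the box. log_ratio and log_ratio_gap are hypothetical helpers. The default ratios are taken from the faster_rcnn_r50_fpn example on this page.

```python
import numpy as np


def log_ratio(widths, heights):
    """log(height / width): negative for boxes wider than they are tall."""
    return np.log(np.asarray(heights, dtype=float) / np.asarray(widths, dtype=float))


def log_ratio_gap(widths, heights, anchor_ratios=(0.5, 1.0, 2.0)):
    """Distance in log space from each box's ratio to the closest anchor ratio."""
    box_lr = log_ratio(widths, heights)
    anchor_lr = np.log(np.asarray(anchor_ratios, dtype=float))
    return np.abs(box_lr[:, None] - anchor_lr[None, :]).min(axis=1)
```

A 100x50 box matches the 0.5 anchor ratio exactly (gap 0). A 400x40 box (ratio 0.1) is log(5), about 1.61, away from its closest anchor ratio.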


API REFERENCE

load_anchor_config_file(anchor_config_file)

Loads the anchor_config_file.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| anchor_config_file | str | File with the anchor configuration. | required |

Returns:

| Type | Description |
|---|---|
| Dict[str, Any] | Dictionary with the training configuration. |

Source code in pyodi/apps/train_config/train_config_evaluation.py
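```python
import os.path as osp
import sys
from importlib import import_module
from shutil import copyfile
from tempfile import TemporaryDirectory
from typing import Any, Dict

from loguru import logger


def load_anchor_config_file(anchor_config_file: str) -> Dict[str, Any]:
    """Loads the `anchor_config_file`.

    Args:
        anchor_config_file: File with the anchor configuration.

    Returns:
        Dictionary with the training configuration.

    """
    logger.info("Loading Train Config File")
    with TemporaryDirectory() as temp_config_dir:
        # Copy the config under a fixed module name so it can be imported
        # regardless of its original file name or location.
        copyfile(anchor_config_file, osp.join(temp_config_dir, "_tempconfig.py"))
        sys.path.insert(0, temp_config_dir)
        mod = import_module("_tempconfig")
        sys.path.pop(0)
        # Keep every top-level name except dunder attributes (__name__, __builtins__, ...).
        train_config = {
            name: value
            for name, value in mod.__dict__.items()
            if not name.startswith("__")
        }
        # delete imported module
        del sys.modules["_tempconfig"]
    return train_config
```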
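An alternative way to load a Python config file into a dictionary is importlib.util.spec_from_file_location. It avoids touching sys.path and leaves nothing behind in sys.modules. The sketch assumes the file has a .py extension; for other suffixes spec_from_file_location finds no loader, which is one reason the copy-to-a-temporary-module approach above is more permissive. load_python_config is a hypothetical helper, not part of pyodi.

```python
import importlib.util
from typing import Any, Dict


def load_python_config(path: str) -> Dict[str, Any]:
    """Execute a Python config file and return its non-dunder top-level names."""
    spec = importlib.util.spec_from_file_location("_anchor_config", path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load config from {path!r}")
    module = importlib.util.module_from_spec(spec)
    # Runs the file's code in the new module namespace without registering it.
    spec.loader.exec_module(module)
    return {
        name: value
        for name, value in vars(module).items()
        if not name.startswith("__")
    }
```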

train_config_evaluation(ground_truth_file, anchor_config, input_size=(1280, 720), show=True, output=None, output_size=(1600, 900), keep_ratio=False)

Evaluates the fitness between ground_truth_file and anchor_config.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| ground_truth_file | Union[str, pd.DataFrame] | Path to a COCO ground truth file, or a df_annotations DataFrame (as used from pyodi train-config generation) that already contains the scaled box columns, such as scaled_ratio. | required |
| anchor_config | str | Path to an MMDetection-like anchor_generator section. It can also be a dictionary with the required data. | required |
| input_size | Tuple[int, int] | Model image input size. Defaults to (1280, 720). | (1280, 720) |
| show | bool | Whether to show the results. Defaults to True. | True |
| output | Optional[str] | Output directory where results are saved. Defaults to None. | None |
| output_size | Tuple[int, int] | Size of the saved images. Defaults to (1600, 900). | (1600, 900) |
| keep_ratio | bool | Whether to keep the aspect ratio when resizing. Defaults to False. | False |

Examples:

```python
# faster_rcnn_r50_fpn.py:
anchor_generator=dict(
    type='AnchorGenerator',
    scales=[8],
    ratios=[0.5, 1.0, 2.0],
    strides=[4, 8, 16, 32, 64]
)
```
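As a worked example of what this configuration produces, the sketch below computes the base anchor widths and heights for each level. It follows MMDetection's AnchorGenerator convention as we understand it: base size equal to the stride, h_ratio = sqrt(ratio) and w_ratio = 1 / h_ratio. pyodi's AnchorGenerator may differ in details, and base_anchor_sizes is a hypothetical helper.

```python
import numpy as np


def base_anchor_sizes(scales, ratios, strides):
    """(width, height) of the base anchors at each level, keyed by stride."""
    h_ratios = np.sqrt(np.asarray(ratios, dtype=float))
    w_ratios = 1.0 / h_ratios
    scales = np.asarray(scales, dtype=float)
    sizes = {}
    for stride in strides:
        # One row per (ratio, scale) pair, ordered by ratio first.
        ws = (stride * w_ratios[:, None] * scales[None, :]).ravel()
        hs = (stride * h_ratios[:, None] * scales[None, :]).ravel()
        sizes[stride] = np.stack([ws, hs], axis=1)
    return sizes


sizes = base_anchor_sizes(scales=[8], ratios=[0.5, 1.0, 2.0], strides=[4, 8, 16, 32, 64])
```

Under this convention every anchor at a level has area (stride * 8)^2. Square anchors therefore range from 32x32 at stride 4 to 512x512 at stride 64, and the 0.5 and 2.0 variants at stride 4 are roughly 45x23 and 23x45.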
Source code in pyodi/apps/train_config/train_config_evaluation.py
```python
@logger.catch(reraise=True)
def train_config_evaluation(
    ground_truth_file: Union[str, pd.DataFrame],
    anchor_config: str,
    input_size: Tuple[int, int] = (1280, 720),
    show: bool = True,
    output: Optional[str] = None,
    output_size: Tuple[int, int] = (1600, 900),
    keep_ratio: bool = False,
) -> None:
    """Evaluates the fitness between `ground_truth_file` and `anchor_config`.

    Args:
        ground_truth_file: Path to COCO ground truth file or coco df_annotations DataFrame
            to be used from
            [`pyodi train-config generation`][pyodi.apps.train_config.train_config_generation.train_config_generation]
        anchor_config: Path to MMDetection-like `anchor_generator` section. It can also be a
            dictionary with the required data.
        input_size: Model image input size. Defaults to (1280, 720).
        show: Show results or not. Defaults to True.
        output: Output directory where results are saved. Defaults to None.
        output_size: Size of saved images. Defaults to (1600, 900).
        keep_ratio: Whether to keep the aspect ratio or not. Defaults to False.

    Examples:
        ```python
        # faster_rcnn_r50_fpn.py:
        anchor_generator=dict(
            type='AnchorGenerator',
            scales=[8],
            ratios=[0.5, 1.0, 2.0],
            strides=[4, 8, 16, 32, 64]
        )
        ```
    """
    if output is not None:
        Path(output).mkdir(parents=True, exist_ok=True)

    if isinstance(ground_truth_file, str):
        df_annotations = coco_ground_truth_to_df(ground_truth_file)

        df_annotations = filter_zero_area_bboxes(df_annotations)

        df_annotations = scale_bbox_dimensions(
            df_annotations, input_size=input_size, keep_ratio=keep_ratio
        )

        df_annotations = get_scale_and_ratio(df_annotations, prefix="scaled")

    else:
        df_annotations = ground_truth_file

    df_annotations["log_scaled_ratio"] = np.log(df_annotations["scaled_ratio"])

    if isinstance(anchor_config, str):
        anchor_config_data = load_anchor_config_file(anchor_config)
    elif isinstance(anchor_config, dict):
        anchor_config_data = anchor_config
    else:
        raise ValueError("anchor_config must be string or dictionary.")

    anchor_config_data["anchor_generator"].pop("type", None)
    anchor_generator = AnchorGenerator(**anchor_config_data["anchor_generator"])

    if isinstance(anchor_config, str):
        logger.info(anchor_generator.to_string())

    width, height = input_size
    featmap_sizes = [
        (width // stride, height // stride) for stride in anchor_generator.strides
    ]
    anchors_per_level = anchor_generator.grid_anchors(featmap_sizes=featmap_sizes)

    bboxes = get_bbox_array(
        df_annotations, prefix="scaled", output_bbox_format="corners"
    )

    overlaps = np.zeros(bboxes.shape[0])
    max_overlap_level = np.zeros(bboxes.shape[0])

    logger.info("Computing overlaps between anchors and ground truth ...")
    for i, anchor_level in enumerate(anchors_per_level):
        level_overlaps = get_max_overlap(
            bboxes.astype(np.float32), anchor_level.astype(np.float32)
        )
        max_overlap_level[level_overlaps > overlaps] = i
        overlaps = np.maximum(overlaps, level_overlaps)

    df_annotations["overlaps"] = overlaps
    df_annotations["max_overlap_level"] = max_overlap_level

    logger.info("Plotting results ...")
    plot_overlap_result(
        df_annotations, show=show, output=output, output_size=output_size
    )
```
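The core of the function above is the per-level loop. The sketch below reproduces that logic in plain NumPy so it can run without pyodi. It places base anchors on a stride grid, takes each box's best IoU per level, and records which level supplied the best overlap. grid_anchors, max_iou and overlaps_per_level are hypothetical helpers. Two details are assumptions: anchor centres sit at multiples of the stride (MMDetection's default zero centre offset), and feature map sizes are given as (width, height), as in the source.

```python
import numpy as np


def grid_anchors(base_wh, stride, featmap_size):
    """All anchors of one level in x1, y1, x2, y2 format; featmap_size is (width, height)."""
    fw, fh = featmap_size
    xs, ys = np.meshgrid(np.arange(fw) * stride, np.arange(fh) * stride)
    centres = np.stack([xs.ravel(), ys.ravel()], axis=1).astype(float)  # (K, 2)
    half = np.asarray(base_wh, dtype=float) / 2.0  # (A, 2)
    top_left = centres[:, None, :] - half[None, :, :]
    bottom_right = centres[:, None, :] + half[None, :, :]
    return np.concatenate([top_left, bottom_right], axis=2).reshape(-1, 4)


def max_iou(boxes, anchors):
    """Largest IoU of each box (N, 4) against a set of anchors (M, 4)."""
    tl = np.maximum(boxes[:, None, :2], anchors[None, :, :2])
    br = np.minimum(boxes[:, None, 2:], anchors[None, :, 2:])
    inter = np.prod(np.clip(br - tl, 0, None), axis=2)
    area_b = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1)
    area_a = np.prod(anchors[:, 2:] - anchors[:, :2], axis=1)
    return (inter / (area_b[:, None] + area_a[None, :] - inter)).max(axis=1)


def overlaps_per_level(boxes, anchors_per_level):
    """Best overlap of each box and the index of the level that provides it."""
    overlaps = np.zeros(len(boxes))
    best_level = np.zeros(len(boxes), dtype=int)
    for level, anchors in enumerate(anchors_per_level):
        level_overlaps = max_iou(boxes, anchors)
        # Update the level before the overlaps, so ties keep the earlier level.
        best_level[level_overlaps > overlaps] = level
        overlaps = np.maximum(overlaps, level_overlaps)
    return overlaps, best_level
```

Updating the level index before the running maximum, as the source does, means a later level only takes over a box when it gives a strictly larger IoU.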