Confusion Matrix
class ConfusionMatrix
Calculate and visualize the confusion matrix of an object detection model.
from_detections
def from_detections(cls, true_batches, detection_batches, num_classes, conf_threshold, iou_threshold)
Args
- true_batches:
List[np.ndarray] representing ground-truth objects across all images in the dataset. Each element of the true_batches list describes a single image and has shape = (N, 5), where N is the number of ground-truth objects. Each row is expected to be in (x_min, y_min, x_max, y_max, class) format.
- detection_batches:
List[np.ndarray] representing detected objects across all images in the dataset. Each element of the detection_batches list describes a single image and has shape = (M, 6), where M is the number of detected objects. Each row is expected to be in (x_min, y_min, x_max, y_max, class, conf) format.
- num_classes:
int number of classes detected by the model.
- conf_threshold:
float detection confidence threshold between 0 and 1. Detections with confidence below this threshold are excluded.
- iou_threshold:
float detection IoU threshold between 0 and 1. Detections with IoU below this threshold are classified as false positives (FP).
Returns
- confusion_matrix:
ConfusionMatrix object holding the raw confusion matrix as a 2D np.ndarray.
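The two thresholds described above can be illustrated with a minimal sketch. The helpers box_iou and filter_by_confidence are hypothetical names introduced here for illustration and are not part of the library's documented API. The sketch assumes boxes in (x_min, y_min, x_max, y_max) format and reads "lower confidence will be excluded" literally, so detections with confidence equal to the threshold are kept; the library's own comparison may differ.

```python
import numpy as np


def box_iou(box_a, box_b) -> float:
    """Intersection over union of two (x_min, y_min, x_max, y_max) boxes."""
    # Corners of the overlapping rectangle, if any.
    x_min = max(box_a[0], box_b[0])
    y_min = max(box_a[1], box_b[1])
    x_max = min(box_a[2], box_b[2])
    y_max = min(box_a[3], box_b[3])

    # Width and height are clamped at zero for boxes that do not overlap.
    intersection = max(0.0, x_max - x_min) * max(0.0, y_max - y_min)

    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - intersection

    return float(intersection / union) if union > 0 else 0.0


def filter_by_confidence(detections: np.ndarray, conf_threshold: float) -> np.ndarray:
    """Keep rows of an (M, 6) detection array whose last column (conf) reaches the threshold."""
    # Assumption: confidence equal to the threshold is kept.
    return detections[detections[:, 5] >= conf_threshold]


detections = np.array([
    [0.0, 0.0, 3.0, 3.0, 1, 0.9],
    [1.0, 6.0, 2.0, 7.0, 1, 0.2],
])
kept = filter_by_confidence(detections, conf_threshold=0.5)
overlap = box_iou([0.0, 0.0, 3.0, 3.0], [2.0, 2.0, 5.0, 5.0])
```

Here kept retains only the first detection, and overlap is 1/17, because the two 3x3 boxes share a 1x1 region.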
Example usage
>>> import numpy as np
>>> from onemetric.cv.object_detection import ConfusionMatrix
>>> true_batches = [
...     np.array([
...         [0.0, 0.0, 3.0, 3.0, 1],
...         [2.0, 2.0, 5.0, 5.0, 1],
...         [6.0, 1.0, 8.0, 3.0, 2],
...     ]),
...     np.array([
...         [1.0, 1.0, 2.0, 2.0, 2],
...     ]),
... ]
>>> detection_batches = [
...     np.array([
...         [0.0, 0.0, 3.0, 3.0, 1, 0.9],
...         [0.1, 0.1, 3.0, 3.0, 0, 0.9],
...         [6.0, 1.0, 8.0, 3.0, 1, 0.8],
...         [1.0, 6.0, 2.0, 7.0, 1, 0.8],
...     ]),
...     np.array([
...         [1.0, 1.0, 2.0, 2.0, 2, 0.8],
...     ]),
... ]
>>> confusion_matrix = ConfusionMatrix.from_detections(
...     true_batches=true_batches,
...     detection_batches=detection_batches,
...     num_classes=3
... )
>>> confusion_matrix.matrix
array([[0., 0., 0., 0.],
       [0., 1., 0., 1.],
       [0., 1., 1., 0.],
       [1., 1., 0., 0.]])
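Judging from the example above, rows appear to index the ground-truth class and columns the predicted class, while the extra last row collects detections with no matching ground truth and the extra last column collects ground-truth objects that were missed. This reading is inferred from the example data, not stated explicitly by the documentation. Under that assumption, per-class counts can be derived from the raw matrix as sketched below; the variable names are illustrative only.

```python
import numpy as np

# Matrix copied from the example output above (num_classes=3, plus one extra row/column).
matrix = np.array([
    [0., 0., 0., 0.],
    [0., 1., 0., 1.],
    [0., 1., 1., 0.],
    [1., 1., 0., 0.],
])
num_classes = matrix.shape[0] - 1

# Diagonal entries of the class block: ground truth and prediction agree.
true_positives = np.diag(matrix)[:num_classes]

# Everything else predicted as a class (column sum) is a false positive for it.
false_positives = matrix[:, :num_classes].sum(axis=0) - true_positives

# Everything else belonging to a class (row sum) is a false negative for it.
false_negatives = matrix[:num_classes, :].sum(axis=1) - true_positives

print(true_positives)   # [0. 1. 1.]
print(false_positives)  # [1. 2. 0.]
print(false_negatives)  # [0. 1. 1.]
```

For class 1, for instance, the two false positives are the detection that overlaps a class 2 object and the detection that matches no ground-truth box, and the single false negative is the ground-truth box at (2.0, 2.0, 5.0, 5.0) that no detection covers.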
plot
def plot(target_path, title, class_names, normalize)
Args
- target_path:
str target location of the confusion matrix plot.
- title:
Optional[str] title displayed at the top of the confusion matrix plot. Default None.
- class_names:
Optional[List[str]] list of class names detected by the model. If none are given, class indexes will be used. Default None.
- normalize:
bool if set to False, the chart displays the absolute number of detections falling into each category. Otherwise, the percentage of detections is displayed.
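The documentation does not say how the percentages are computed when normalize is enabled. As one plausible illustration only, the sketch below divides each column by its sum, so that every non-empty column adds up to 1. The function normalize_columns is a hypothetical helper, and the choice of column-wise normalization is an assumption, not the library's confirmed behavior.

```python
import numpy as np


def normalize_columns(matrix: np.ndarray) -> np.ndarray:
    """Scale each column of a confusion matrix so that it sums to 1."""
    column_sums = matrix.sum(axis=0, keepdims=True)
    # Empty columns are divided by 1 instead of 0 and therefore stay all-zero.
    safe_sums = np.where(column_sums == 0, 1.0, column_sums)
    return matrix / safe_sums


matrix = np.array([
    [0., 0., 0., 0.],
    [0., 1., 0., 1.],
    [0., 1., 1., 0.],
    [1., 1., 0., 0.],
])
normalized = normalize_columns(matrix)
```

With this choice, the second column of normalized holds 1/3 in each of its three non-zero cells, since three detections in total were assigned to that column.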