# Source code for gluonts.nursery.anomaly_detection.supervised_metrics._buffered_precision_recall

# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import List, Tuple

from .utils import range_overlap


def extend_ranges(
    ranges: List[range], extension_length: int, direction: str = "right"
) -> List[range]:
    """
    Grow each range in `ranges` by up to `extension_length` steps on one side.

    An extension is clipped so that it never runs into the neighbouring range.
    When extending to the right, a range's new stop is at most the start of the
    following range. When extending to the left, a range's new start is at
    least the stop of the preceding range, or 0 for the first range.

    Parameters
    ----------
    ranges: List[range]
        Ranges to be extended, non-overlapping and in sorted order.
    extension_length: int
        Non-negative integer, the maximum number of steps to add to each range.
    direction: str
        Either "right" (extend the stop of each range) or "left" (extend the
        start of each range). Default "right".

    Returns
    -------
    out_ranges: List[range]
        The extended ranges, in the same order as `ranges`. An empty list is
        returned for empty input, without validating the other arguments.
    """
    if not ranges:
        return []
    assert extension_length >= 0, "`extension_length` must be zero or above"
    assert direction in [
        "left",
        "right",
    ], "`direction` must be one of 'left' or 'right'."

    count = len(ranges)
    extended = []

    if direction == "right":
        # Ground-truth anomalies get a slack window after them in which a late
        # detection still counts as catching them.
        for idx, current in enumerate(ranges):
            new_stop = current.stop + extension_length
            if idx + 1 < count:
                # Do not let the slack window reach into the next range.
                new_stop = min(ranges[idx + 1].start, new_stop)
            extended.append(range(current.start, new_stop))
    else:
        for idx, current in enumerate(ranges):
            # Floor is the previous range's stop; the first range is floored at 0.
            floor = ranges[idx - 1].stop if idx > 0 else 0
            new_start = max(floor, current.start - extension_length)
            extended.append(range(new_start, current.stop))

    return extended


def buffered_precision_recall(
    real_ranges: List[range],
    pred_ranges: List[range],
    buffer_length: int = 5,
) -> Tuple[float, float]:
    """
    Range-based precision/recall that tolerates predictions raised with a lag.

    Measures how well anomalies (`real_ranges`) are caught by labels
    (`pred_ranges`). Each anomaly range is extended by `buffer_length` time
    steps to accommodate labels raised with a lag. For example, if an annotator
    has marked `range(5, 9)` and the model has labeled `range(11, 13)` as an
    anomaly, we would often like to count this as a correctly raised anomaly.
    There are two reasons for this:

    (i) Human annotators often draw boxes around anomalies with a certain
        "margin", i.e., with a lead and a lag around the true anomaly.
    (ii) The low-pass filter raises anomalies with a certain latency.

    This function therefore looks for intersections between the "extended"
    anomaly ranges and the predicted ranges. An extended range is the annotated
    range with a buffer of `buffer_length` steps added after it, clipped at the
    start of the next anomaly. Any intersection counts as a success. More
    precisely:

    - If an extended anomaly range intersects any predicted range, the anomaly
      is "caught". Otherwise it is not caught. Recall is
      `n_caught_anomalies / n_all_anomalies`.
    - If a predicted range intersects any extended anomaly range, it is a good
      alarm. Precision is `n_good_pred_ranges / n_pred_ranges`.

    The numerators (numbers of true positives) differ between precision and
    recall. This is because one anomaly can be caught by several predicted
    ranges, and one predicted range can mark two separate anomalies. Each
    predicted range is either "good" (true positive: it intersects an extended
    anomaly range) or "bad" (false positive), never both. This differs from
    `segment_precision_recall`, where a predicted range spanning both an
    anomaly segment and a non-anomaly segment counts towards both true and
    false positives.

    Parameters
    ----------
    real_ranges: List[range]
        Python range objects representing ground truth anomalies (e.g., as
        annotated by human labelers). Ranges must be non-overlapping and in
        sorted order.
    pred_ranges: List[range]
        Python range objects representing labels produced by the model. Ranges
        must be non-overlapping and in sorted order.
    buffer_length: int
        The number of time steps a predicted range is allowed to lag after an
        anomaly and still be counted as a "good" raise. For example, take the
        actual range `range(5, 7)` and the predicted range `range(8, 9)`. This
        prediction is deemed accurate with a buffer length of 2 or above. Must
        be non-negative; `extend_ranges` raises AssertionError otherwise (when
        both range lists are non-empty).

    Returns
    -------
    precision: float
        Fraction of predicted ranges that overlap an (extended) anomaly range.
    recall: float
        Fraction of (extended) anomaly ranges that overlap at least one
        predicted range.

    Conventions for empty inputs: no anomalies and no predictions gives
    (1.0, 1.0). No predictions gives (1.0, 0.0). No anomalies gives (0.0, 1.0).
    """
    if len(real_ranges) == 0 and len(pred_ranges) == 0:
        return 1.0, 1.0
    if len(pred_ranges) == 0:
        return 1.0, 0.0
    if len(real_ranges) == 0:
        return 0.0, 1.0

    extended_ranges = extend_ranges(real_ranges, buffer_length)

    # Recall numerator: extended anomaly ranges hit by at least one prediction.
    n_caught = sum(
        any(range_overlap(tr, pr) for pr in pred_ranges)
        for tr in extended_ranges
    )
    # Precision numerator: predictions hitting at least one extended anomaly.
    n_good_preds = sum(
        any(range_overlap(tr, pr) for tr in extended_ranges)
        for pr in pred_ranges
    )

    return n_good_preds / len(pred_ranges), n_caught / len(extended_ranges)