# Copyright © The Debusine Developers
# See the AUTHORS file at the top-level directory of this distribution
#
# This file is part of Debusine. It is subject to the license terms
# in the LICENSE file found in the top-level directory of this
# distribution. No part of Debusine, including this file, may be copied,
# modified, propagated, or distributed except according to the terms
# contained in the LICENSE file.

"""Tests for emitting metrics."""

import json
from typing import Literal
from unittest import TestCase as SimpleTestCase
from unittest import mock

import requests
import responses
from celery import shared_task, signals
from celery.contrib.testing.worker import start_worker
from celery.utils.nodenames import gethostname, nodename
from django.conf import settings
from django.test import override_settings
from django_prometheus.testutils import assert_metric_diff, save_registry
from prometheus_client import (
    CollectorRegistry,
    Counter,
    Gauge,
    Histogram,
    Summary,
)

from debusine.db.metrics import (
    BadMetricType,
    MetricNotFound,
    emit_metric,
    logger,
)
from debusine.project.celery import make_app
from debusine.test.django import TestCase, TransactionTestCase


class ExceptionTests(SimpleTestCase):
    """Check how our custom metric exceptions render as strings."""

    def test_MetricNotFound_str(self) -> None:
        # The message quotes the name the exception was constructed with.
        rendered = str(MetricNotFound("some.metric"))
        self.assertEqual(rendered, "Metric 'some.metric' not found")

    def test_BadMetricType_str(self) -> None:
        # The message combines the metric name and the metric class name.
        rendered = str(BadMetricType("some.metric", "Histogram"))
        self.assertEqual(rendered, "Metric 'some.metric' is not a Histogram")


class EmitMetricDirectlyTests(TestCase):
    """Exercise :py:func:`emit_metric` outside of any Celery task."""

    def setUp(self) -> None:
        super().setUp()
        # Each test gets its own registry, holding one metric of every
        # supported type; metric names are prefixed with the test class name.
        registry = CollectorRegistry()
        prefix = type(self).__name__
        label_names = ("foo",)

        self.registry = registry
        self.prefix = prefix
        self.counter = Counter(
            f"{prefix}_test_counter",
            "test_counter",
            labelnames=label_names,
            registry=registry,
        )
        self.gauge = Gauge(
            f"{prefix}_test_gauge",
            "test_gauge",
            labelnames=label_names,
            registry=registry,
        )
        self.summary = Summary(
            f"{prefix}_test_summary",
            "test_summary",
            labelnames=label_names,
            registry=registry,
        )
        self.histogram = Histogram(
            f"{prefix}_test_histogram",
            "test_histogram",
            labelnames=label_names,
            registry=registry,
        )

        # Replace debusine.db.metrics.REGISTRY with the registry built above
        # for the duration of the test.
        registry_patch = mock.patch(
            "debusine.db.metrics.REGISTRY", new=registry
        )
        self.enterContext(registry_patch)
        # Saved state that test_success compares against.
        self.frozen_registry = save_registry(registry=registry)

    def test_metric_not_found(self) -> None:
        missing_name = "nonexistent"
        with self.assertRaises(MetricNotFound) as raised:
            emit_metric(
                metric_type="counter",
                name=missing_name,
                labels={},
                value=1,
            )
        self.assertEqual(raised.exception.name, missing_name)

    def test_bad_metric_type(self) -> None:
        # Pair each requested type with a metric that setUp registered as a
        # different type.
        mismatches: tuple[
            tuple[Literal["counter", "gauge", "summary", "histogram"], str],
            ...,
        ] = (
            ("counter", f"{self.prefix}_test_gauge"),
            ("gauge", f"{self.prefix}_test_counter"),
            ("summary", f"{self.prefix}_test_histogram"),
            ("histogram", f"{self.prefix}_test_summary"),
        )
        for requested_type, metric_name in mismatches:
            with self.subTest(metric_type=requested_type):
                with self.assertRaises(BadMetricType) as raised:
                    emit_metric(
                        metric_type=requested_type,
                        name=metric_name,
                        labels={},
                        value=1,
                    )
                error = raised.exception
                self.assertEqual(error.name, metric_name)
                self.assertEqual(
                    error.metric_class_name, requested_type.capitalize()
                )

    def test_success(self) -> None:
        prefix = self.prefix
        # Each case: (metric type, metric name, emitted value, expected
        # sample diffs), where every diff is (sample name, expected change,
        # labels in addition to foo="bar").
        cases: tuple[
            tuple[
                Literal["counter", "gauge", "summary", "histogram"],
                str,
                float,
                tuple[tuple[str, float, dict[str, str]], ...],
            ],
            ...,
        ] = (
            (
                "counter",
                f"{prefix}_test_counter",
                1,
                ((f"{prefix}_test_counter_total", 1, {}),),
            ),
            (
                "gauge",
                f"{prefix}_test_gauge",
                2.5,
                ((f"{prefix}_test_gauge", 2.5, {}),),
            ),
            (
                "summary",
                f"{prefix}_test_summary",
                512,
                (
                    (f"{prefix}_test_summary_count", 1, {}),
                    (f"{prefix}_test_summary_sum", 512, {}),
                ),
            ),
            (
                "histogram",
                f"{prefix}_test_histogram",
                10.0,
                (
                    (f"{prefix}_test_histogram_bucket", 0, {"le": "7.5"}),
                    (f"{prefix}_test_histogram_bucket", 1, {"le": "10.0"}),
                    (f"{prefix}_test_histogram_count", 1, {}),
                    (f"{prefix}_test_histogram_sum", 10.0, {}),
                ),
            ),
        )
        for metric_type, metric_name, emitted, sample_diffs in cases:
            with self.subTest(metric_type=metric_type):
                emit_metric(
                    metric_type=metric_type,
                    name=metric_name,
                    labels={"foo": "bar"},
                    value=emitted,
                )
                for sample_name, delta, extra_labels in sample_diffs:
                    with self.subTest(expected_name=sample_name):
                        assert_metric_diff(
                            self.frozen_registry,
                            delta,
                            sample_name,
                            registry=self.registry,
                            foo="bar",
                            **extra_labels,
                        )


# celery.shared_task is untyped, so mypy complains about the decorated
# function; that is accurate, but it can only be silenced here, not fixed.
@shared_task  # type: ignore[misc]
def _emit_metric(
    metric_type: Literal["counter", "gauge", "summary", "histogram"],
    name: str,
    labels: dict[str, str],
    value: float,
) -> None:
    """Celery task used by the tests: pass everything to emit_metric."""
    emit_metric(
        metric_type=metric_type,
        name=name,
        labels=labels,
        value=value,
    )


@override_settings(CELERY_BROKER_URL="memory://")
class EmitMetricFromCeleryTaskTests(TransactionTestCase):
    """
    Test :py:func:`emit_metric` when running in a Celery task.

    Each test starts a Celery test worker (with ``CELERY_BROKER_URL``
    overridden to ``memory://``), runs the ``_emit_metric`` sample task on
    it, and uses :py:class:`responses.RequestsMock` to control the HTTP
    request made while the task runs.
    """

    def test_success(self) -> None:
        url = f"https://{settings.DEBUSINE_FQDN}/api/1.0/open-metrics/emit/"

        with (
            mock.patch("debusine.django.django_utils.test_data_override"),
            start_worker(
                make_app(),
                perform_ping_check=False,
                hostname=nodename("celery", gethostname()),
            ) as celery_worker,
        ):
            # Only the request body is checked here, so each case is just
            # (metric type, metric name, value).
            metric_type: Literal["counter", "gauge", "summary", "histogram"]
            for metric_type, name, value in (
                ("counter", "test_counter", 1),
                ("gauge", "test_gauge", 2.5),
                ("summary", "test_summary", 512),
                ("histogram", "test_histogram", 10.0),
            ):
                with (
                    self.subTest(metric_type=metric_type),
                    responses.RequestsMock() as rsps,
                ):
                    rsps.add(
                        responses.POST, url, status=requests.codes.no_content
                    )
                    # .get() blocks until the worker has run the task.
                    _emit_metric.delay(
                        metric_type=metric_type,
                        name=name,
                        labels={"foo": "bar"},
                        value=value,
                    ).get()
                    # Exactly one POST to the emit endpoint, carrying the
                    # task's arguments as JSON.
                    rsps.assert_call_count(url, 1)
                    assert rsps.calls[0].request.body is not None
                    self.assertEqual(
                        json.loads(rsps.calls[0].request.body),
                        {
                            "metric_type": metric_type,
                            "name": name,
                            "labels": {"foo": "bar"},
                            "value": value,
                        },
                    )

            signals.worker_shutdown.send(sender=celery_worker)

    def test_failure(self) -> None:
        with (
            mock.patch("debusine.django.django_utils.test_data_override"),
            start_worker(
                make_app(),
                perform_ping_check=False,
                hostname=nodename("celery", gethostname()),
            ) as celery_worker,
        ):
            # No response is registered on the mock, so the request made by
            # the task fails; the task is still expected to finish (.get()
            # returns) and the failure to show up in the metrics logger.
            with (
                responses.RequestsMock(),
                self.assertLogs(
                    logger=logger,
                ) as logs,
            ):
                _emit_metric.delay(
                    metric_type="counter",
                    name="test_counter",
                    labels={"foo": "bar"},
                    value=1.0,
                ).get()
        # NOTE(review): unlike test_success, the signal is sent after the
        # worker context has exited -- confirm whether the ordering matters.
        signals.worker_shutdown.send(sender=celery_worker)
        self.assertRegex(
            logs.output[0],
            (
                r"Failed to report metric back to the server: .* "
                r"Error: Connection refused by Responses .*"
            ),
        )
