# Copyright 2024 The Debusine Developers
# See the AUTHORS file at the top-level directory of this distribution
#
# This file is part of Debusine. It is subject to the license terms
# in the LICENSE file found in the top-level directory of this
# distribution. No part of Debusine, including this file, may be copied,
# modified, propagated, or distributed except according to the terms
# contained in the LICENSE file.

"""Tests for application context."""

import asyncio
from collections.abc import Callable
from threading import Thread
from typing import Any

from debusine.db.context import ContextConsistencyError, context
from debusine.db.models import Scope, Workspace
from debusine.test.django import TestCase


class TestAppContext(TestCase):
    """Test application context variables (scope and workspace)."""

    def setUp(self):
        """Ensure context is not left dirty by a bug in a previous test."""
        super().setUp()
        self.assertIsNone(context.scope)
        self.assertIsNone(context.workspace)

    def run_in_thread(
        self, func: Callable[..., Any], *args: Any, **kwargs: Any
    ) -> Any:
        """Run a callable in a separate thread and return its result.

        Any exception raised by ``func`` (including assertion failures) is
        re-raised in the calling thread. Without this, an exception in the
        worker thread would only be reported by ``threading.excepthook``
        and the calling test would still pass.
        """
        result: Any = None
        error: BaseException | None = None

        def _thread_main() -> None:
            nonlocal result, error
            try:
                result = func(*args, **kwargs)
            except BaseException as exc:
                # Stored and re-raised below, so nothing is swallowed
                error = exc

        thread = Thread(target=_thread_main)
        thread.start()
        thread.join()

        if error is not None:
            raise error

        return result

    def run_in_task(
        self, func: Callable[..., Any], *args: Any, **kwargs: Any
    ) -> Any:
        """Run a callable in an asyncio task and return its result.

        Exceptions raised by ``func`` propagate out of ``asyncio.run``.
        """

        async def _task_main() -> Any:
            return func(*args, **kwargs)

        async def _async_main() -> Any:
            return await asyncio.create_task(_task_main())

        return asyncio.run(_async_main())

    def test_defaults(self) -> None:
        """Context variables are None by default."""
        self.assertIsNone(context.scope)
        self.assertIsNone(context.workspace)

    def test_local_previously_none(self) -> None:
        """Using local() restores previously unset vars."""
        scope = Scope(name="scope")
        workspace = Workspace(scope=scope, name="test")

        with context.local():
            context.workspace = workspace
            self.assertEqual(context.scope, scope)
            self.assertEqual(context.workspace, workspace)

        self.assertIsNone(context.scope)
        self.assertIsNone(context.workspace)

    def test_local_previously_set(self) -> None:
        """Using local() restores previously set vars."""
        scope = Scope(name="scope")
        workspace = Workspace(scope=scope, name="test")

        try:
            # Set outside local() on purpose, so that local() has something
            # to restore; the finally clause cleans up for later tests
            context.workspace = workspace

            with context.local():
                context.scope = None

                self.assertIsNone(context.scope)
                self.assertIsNone(context.workspace)

            self.assertEqual(context.scope, scope)
            self.assertEqual(context.workspace, workspace)
        finally:
            context.reset()

    def test_reset(self) -> None:
        """Calling reset() sets context to its initial values."""
        scope = Scope(name="scope")
        workspace = Workspace(scope=scope, name="test")

        with context.local():
            context.workspace = workspace

            self.assertEqual(context.scope, scope)
            self.assertEqual(context.workspace, workspace)

            context.reset()

            self.assertIsNone(context.scope)
            self.assertIsNone(context.workspace)

    def test_visibility_thread(self) -> None:
        """Check visibility with subthreads."""
        scope1 = Scope(name="scope1")
        workspace1 = Workspace(scope=scope1, name="test1")
        scope2 = Scope(name="scope2")
        workspace2 = Workspace(scope=scope2, name="test2")

        with context.local():
            context.workspace = workspace1

            def _test() -> None:
                # A new thread does not see the parent's application context
                self.assertIsNone(context.scope)
                self.assertIsNone(context.workspace)

                context.scope = scope2
                context.workspace = workspace2

                self.assertEqual(context.scope, scope2)
                self.assertEqual(context.workspace, workspace2)

            orig_scope = context.scope
            orig_workspace = context.workspace
            self.run_in_thread(_test)
            # Changes made in the thread do not leak back to the caller
            self.assertEqual(context.scope, orig_scope)
            self.assertEqual(context.workspace, orig_workspace)

    def test_visibility_task(self) -> None:
        """Check visibility with asyncio tasks."""
        scope1 = Scope(name="scope1")
        workspace1 = Workspace(scope=scope1, name="test1")
        scope2 = Scope(name="scope2")
        workspace2 = Workspace(scope=scope2, name="test2")

        with context.local():
            context.workspace = workspace1

            def _test() -> None:
                # Unlike threads, a task sees the caller's values
                self.assertEqual(context.scope, scope1)
                self.assertEqual(context.workspace, workspace1)

                context.scope = scope2
                context.workspace = workspace2

                self.assertEqual(context.scope, scope2)
                self.assertEqual(context.workspace, workspace2)

            orig_scope = context.scope
            orig_workspace = context.workspace
            self.run_in_task(_test)
            # Changes made inside the task do not leak back to the caller
            self.assertEqual(context.scope, orig_scope)
            self.assertEqual(context.workspace, orig_workspace)

    def test_set_scope(self) -> None:
        """Test setting scope."""
        scope = Scope(name="scope")

        with context.local():
            self.assertIsNone(context.scope)
            context.scope = scope
            self.assertEqual(context.scope, scope)
            self.assertIsNone(context.workspace)

    def test_set_scope_none(self) -> None:
        """Test setting scope to None, which also clears workspace."""
        scope = Scope(name="scope")
        workspace = Workspace(scope=scope, name="test")

        with context.local():
            context.workspace = workspace
            self.assertEqual(context.scope, scope)
            self.assertEqual(context.workspace, workspace)

            context.scope = None
            self.assertIsNone(context.scope)
            self.assertIsNone(context.workspace)

    def test_set_workspace_after_scope(self) -> None:
        """Test setting scope and workspace separately."""
        scope = Scope(name="scope")
        workspace = Workspace(scope=scope, name="test")

        with context.local():
            context.scope = scope
            context.workspace = workspace
            self.assertEqual(context.scope, scope)
            self.assertEqual(context.workspace, workspace)

    def test_set_scope_after_workspace(self) -> None:
        """Test setting workspace then a different scope clears workspace."""
        scope1 = Scope(name="scope1")
        scope2 = Scope(name="scope2")
        workspace = Workspace(scope=scope1, name="test")

        with context.local():
            context.workspace = workspace
            self.assertEqual(context.scope, scope1)
            self.assertEqual(context.workspace, workspace)

            context.scope = scope2
            self.assertEqual(context.scope, scope2)
            self.assertIsNone(context.workspace)

    def test_set_workspace_and_scope(self) -> None:
        """Test that setting workspace also sets scope from it."""
        scope = Scope(name="scope")
        workspace = Workspace(scope=scope, name="test")

        with context.local():
            context.workspace = workspace
            self.assertEqual(context.scope, scope)
            self.assertEqual(context.workspace, workspace)

    def test_set_workspace_wrong_scope(self) -> None:
        """Test scope/workspace.scope consistency checks."""
        scope = Scope(name="scope")
        scope1 = Scope(name="scope1")
        workspace = Workspace(scope=scope1, name="test")

        with context.local():
            context.scope = scope
            with self.assertRaisesRegex(
                ContextConsistencyError,
                "workspace scope 'scope1' does not match current scope 'scope'",
            ):
                context.workspace = workspace
            # A rejected assignment leaves the context untouched
            self.assertEqual(context.scope, scope)
            self.assertIsNone(context.workspace)

    def test_set_workspace_none(self) -> None:
        """Test setting workspace to None keeps the current scope."""
        scope = Scope(name="scope")
        workspace = Workspace(scope=scope, name="test")

        with context.local():
            context.workspace = workspace
            self.assertEqual(context.scope, scope)
            self.assertEqual(context.workspace, workspace)
            context.workspace = None
            self.assertEqual(context.scope, scope)
            self.assertIsNone(context.workspace)
