# Copyright 2019, 2021-2024 The Debusine Developers
# See the AUTHORS file at the top-level directory of this distribution
#
# This file is part of Debusine. It is subject to the license terms
# in the LICENSE file found in the top-level directory of this
# distribution. No part of Debusine, including this file, may be copied,
# modified, propagated, or distributed except according to the terms
# contained in the LICENSE file.

"""Unit tests for the workspace models."""
from typing import ClassVar

from django.test import TestCase

from debusine.artifacts.models import CollectionCategory
from debusine.db.models import (
    Collection,
    FileInStore,
    FileStore,
    Workspace,
    default_workspace,
)
from debusine.db.models.auth import User
from debusine.db.models.workspaces import WorkspaceChain
from debusine.test import TestHelpersMixin


class WorkspaceManagerTests(TestCase):
    """Tests for the WorkspaceManager class (``Workspace.objects``)."""

    def test_create_with_name(self) -> None:
        """Check create_with_name returns a workspace on the default store."""
        workspace_name = "nas-01"
        workspace = Workspace.objects.create_with_name(workspace_name)
        # Reload from the database so the checks below see persisted values.
        workspace.refresh_from_db()
        self.assertEqual(workspace.name, workspace_name)
        self.assertEqual(workspace.default_file_store, FileStore.default())


class WorkspaceTests(TestHelpersMixin, TestCase):
    """Tests for the Workspace and WorkspaceChain models."""

    # Shared fixtures, created once per class in setUpTestData().
    a: ClassVar[Workspace]
    b: ClassVar[Workspace]
    c: ClassVar[Workspace]
    d: ClassVar[Workspace]
    user: ClassVar[User]

    @staticmethod
    def _get_collection(
        workspace: Workspace,
        name: str,
        category: CollectionCategory = CollectionCategory.SUITE,
        user: User | None = None,
    ) -> Collection:
        """
        Look up a collection by name and category from a workspace.

        Thin wrapper over ``Workspace.get_collection`` that defaults the
        category to ``CollectionCategory.SUITE``.
        """
        return workspace.get_collection(name=name, category=category, user=user)

    @classmethod
    def setUpTestData(cls) -> None:
        """Create four public workspaces (a, b, c, d) and a test user."""
        super().setUpTestData()
        cls.a = cls.create_workspace(name="a", public=True)
        cls.b = cls.create_workspace(name="b", public=True)
        cls.c = cls.create_workspace(name="c", public=True)
        cls.d = cls.create_workspace(name="d", public=True)
        cls.user = cls.get_test_user()

    def assertInherits(
        self, child: Workspace, parents: list[Workspace]
    ) -> None:
        """
        Assert that ``child`` inherits from exactly ``parents``, in order.

        The chain entries of ``child`` are compared as
        ``(child name, parent name, order)`` tuples, so this also checks
        that the ``order`` values are 0, 1, 2, ... with no gaps.
        """
        chain = list(
            child.chain_parents.order_by("order").values_list(
                "child__name", "parent__name", "order"
            )
        )
        expected: list[tuple[str, str, int]] = [
            (child.name, parent.name, idx)
            for idx, parent in enumerate(parents)
        ]
        self.assertEqual(chain, expected)

    def assertGetCollectionEqual(
        self,
        workspace: Workspace,
        name: str,
        expected: Collection,
        *,
        category: CollectionCategory = CollectionCategory.SUITE,
        user: User | None = None,
    ) -> None:
        """
        Check that collection lookup yields the given result.

        :param workspace: workspace where the lookup starts
        :param name: name of the collection to look up
        :param expected: collection the lookup must return
        :param category: category of the collection to look up
        :param user: user performing the lookup (None for anonymous)
        """
        self.assertEqual(
            self._get_collection(
                workspace, name=name, category=category, user=user
            ),
            expected,
        )

    def assertGetCollectionFails(
        self,
        workspace: Workspace,
        name: str,
        *,
        category: CollectionCategory = CollectionCategory.SUITE,
        user: User | None = None,
    ) -> None:
        """
        Check that collection lookup raises ``Collection.DoesNotExist``.

        :param workspace: workspace where the lookup starts
        :param name: name of the collection to look up
        :param category: category of the collection to look up
        :param user: user performing the lookup (None for anonymous)
        """
        with self.assertRaises(Collection.DoesNotExist):
            self._get_collection(
                workspace, name=name, category=category, user=user
            )

    def test_default_values_fields(self) -> None:
        """Test default field values and that a minimal workspace saves."""
        name = "test"
        workspace = Workspace(name=name, default_file_store=FileStore.default())

        self.assertEqual(workspace.name, name)
        self.assertFalse(workspace.public)

        workspace.clean_fields()
        workspace.save()

    def test_is_file_in_workspace_default_file_store(self) -> None:
        """
        Test is_file_in_workspace returns the correct value.

        The file is added in the default FileStore.
        """
        fileobj = self.create_file()
        workspace = default_workspace()
        self.assertFalse(workspace.is_file_in_workspace(fileobj))

        # Add file in store
        FileInStore.objects.create(
            file=fileobj, store=workspace.default_file_store
        )

        self.assertTrue(workspace.is_file_in_workspace(fileobj))

    def test_is_file_in_workspace_other_file_stores(self) -> None:
        """
        Test is_file_in_workspace returns the correct value.

        The file is added in a FileStore that is not the default one.
        """
        fileobj = self.create_file()
        workspace = default_workspace()
        self.assertFalse(workspace.is_file_in_workspace(fileobj))

        # Add file in a store which is not the default one
        store = FileStore.objects.create(
            name="nas-01", backend=FileStore.BackendChoices.LOCAL
        )
        workspace.other_file_stores.add(store)
        workspace.refresh_from_db()

        # Add file in the store
        FileInStore.objects.create(file=fileobj, store=store)

        self.assertTrue(workspace.is_file_in_workspace(fileobj))

    def test_str(self) -> None:
        """Test __str__ method."""
        # Unsaved instance: the expected string is built from its own id.
        workspace = Workspace(name="test")

        self.assertEqual(
            str(workspace),
            f"Id: {workspace.id} Name: {workspace.name}",
        )

    def test_workspacechain_str(self) -> None:
        """Test WorkspaceChain.__str__ method."""
        chain = WorkspaceChain(parent=self.a, child=self.b, order=5)
        self.assertEqual(str(chain), "5:b→a")

    def test_set_inheritance(self) -> None:
        """Test set_inheritance replaces only the target's parent chain."""
        a, b, c = self.a, self.b, self.c

        a.set_inheritance([b])
        self.assertInherits(a, [b])
        self.assertInherits(b, [])
        self.assertInherits(c, [])

        a.set_inheritance([b, c])
        self.assertInherits(a, [b, c])
        self.assertInherits(b, [])
        self.assertInherits(c, [])

        # Reordering the same parents updates the order
        a.set_inheritance([c, b])
        self.assertInherits(a, [c, b])
        self.assertInherits(b, [])
        self.assertInherits(c, [])

        # Setting c's chain leaves a's chain untouched
        c.set_inheritance([a])
        self.assertInherits(a, [c, b])
        self.assertInherits(b, [])
        self.assertInherits(c, [a])

        # An empty list clears the chain
        a.set_inheritance([])
        self.assertInherits(a, [])
        self.assertInherits(b, [])
        self.assertInherits(c, [a])

        with self.assertRaisesRegex(
            ValueError, r"duplicate workspace 'b' in inheritance chain"
        ):
            a.set_inheritance([b, b, c])

    def test_collection_lookup(self) -> None:
        """Test collection lookup locally and through inheritance chains."""
        a, b, c = self.a, self.b, self.c

        a1 = self.create_collection(
            name="a", category=CollectionCategory.SUITE, workspace=a
        )

        # Lookup locally
        self.assertGetCollectionEqual(a, "a", a1)
        self.assertGetCollectionFails(b, "a")
        self.assertGetCollectionFails(c, "a")

        # Lookup through a simple chain
        b.set_inheritance([a])
        self.assertGetCollectionEqual(b, "a", a1)

        # Walk the inheritance chain until the end
        b.set_inheritance([c, a])
        self.assertGetCollectionEqual(b, "a", a1)
        self.assertGetCollectionFails(c, "a")

        # Create a chain loop, lookup does not break
        c.set_inheritance([b])
        self.assertGetCollectionEqual(b, "a", a1)

    def test_collection_lookup_graph(self) -> None:
        """Test lookup order across a graph of inheritance chains."""
        a, b, c, d = self.a, self.b, self.c, self.d

        cb = self.create_collection(
            name="a", category=CollectionCategory.SUITE, workspace=b
        )
        cc = self.create_collection(
            name="a", category=CollectionCategory.SUITE, workspace=c
        )

        # Lookup happens in order
        a.set_inheritance([b, c])
        self.assertGetCollectionEqual(a, "a", cb)
        a.set_inheritance([c, b])
        self.assertGetCollectionEqual(a, "a", cc)

        # Lookup happens depth-first: a -> d -> b is explored before a -> c
        a.set_inheritance([d, c])
        b.set_inheritance([])
        c.set_inheritance([])
        d.set_inheritance([b])
        self.assertGetCollectionEqual(a, "a", cb)

    def test_user_restrictions(self) -> None:
        """Test user restriction enforcement."""
        wpub = self.create_workspace(name="public", public=True)
        cpub = self.create_collection(
            "test", CollectionCategory.SUITE, workspace=wpub
        )
        wpriv = self.create_workspace(name="private", public=False)
        cpriv = self.create_collection(
            "test", CollectionCategory.SUITE, workspace=wpriv
        )
        wstart = self.create_workspace(name="start", public=True)

        self.assertGetCollectionFails(wstart, "test", user=None)
        self.assertGetCollectionFails(wstart, "test", user=self.user)

        # Lookups in the workspace itself do check restrictions
        self.assertGetCollectionEqual(wpub, "test", cpub, user=None)
        self.assertGetCollectionEqual(wpub, "test", cpub, user=self.user)
        self.assertGetCollectionFails(wpriv, "test", user=None)
        self.assertGetCollectionEqual(wpriv, "test", cpriv, user=self.user)

        # Inheritance chain is always followed for public workspaces
        wstart.set_inheritance([wpub])
        self.assertGetCollectionEqual(wstart, "test", cpub, user=None)
        self.assertGetCollectionEqual(wstart, "test", cpub, user=self.user)

        # Inheritance chain on private workspaces is followed only if
        # logged in
        wstart.set_inheritance([wpriv])
        self.assertGetCollectionFails(wstart, "test", user=None)
        self.assertGetCollectionEqual(wstart, "test", cpriv, user=self.user)

        # Anonymous lookups skip private workspaces and reach public ones
        # later in the chain; logged-in lookups find the private one first
        wstart.set_inheritance([wpriv, wpub])
        self.assertGetCollectionEqual(wstart, "test", cpub, user=None)
        self.assertGetCollectionEqual(wstart, "test", cpriv, user=self.user)
