from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
from pydantic import BaseModel, Field


class TestModel(BaseModel):
    """Minimal pydantic model used as the stored item type in these tests."""

    # The "Test" prefix makes pytest try to collect this class as a test case
    # (and emit a collection warning, since it defines __init__); opt out.
    # Dunder names are not treated as model fields by pydantic.
    __test__ = False

    id: str = Field(description="ID")
    name: str = Field(description="Name")


def test_sqlite_service_can_create_and_get():
    """An item written with `set` can be read back by its ID."""
    storage = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
    storage.set(TestModel(id="1", name="Test"))

    fetched = storage.get("1")
    expected = TestModel(id="1", name="Test")
    assert fetched == expected


def test_sqlite_service_can_list():
    """Calling `list()` with no arguments returns all three items on a single page."""
    storage = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
    item_ids = ("1", "2", "3")
    for item_id in item_ids:
        storage.set(TestModel(id=item_id, name="Test"))

    listing = storage.list()

    # Expected paging when no page/per_page is passed: page 0, 10 per page.
    assert listing.page == 0
    assert listing.pages == 1
    assert listing.per_page == 10
    assert listing.total == 3
    assert listing.items == [TestModel(id=item_id, name="Test") for item_id in item_ids]


def test_sqlite_service_can_delete():
    """After `delete`, looking the item up by ID yields None."""
    storage = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
    target_id = "1"
    storage.set(TestModel(id=target_id, name="Test"))
    storage.delete(target_id)

    remaining = storage.get(target_id)
    assert remaining is None


def test_sqlite_service_calls_set_callback():
    """`set` notifies registered on_changed callbacks with the stored item."""
    db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
    changed_items = []

    def on_changed(item: TestModel):
        changed_items.append(item)

    db.on_changed(on_changed)
    db.set(TestModel(id="1", name="Test"))
    # Check the payload, not just a flag, so a callback fired with the wrong
    # item or fired more than once is also caught.
    assert changed_items == [TestModel(id="1", name="Test")]


def test_sqlite_service_calls_delete_callback():
    """`delete` notifies registered on_deleted callbacks with the deleted ID."""
    db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
    deleted_ids = []

    def on_deleted(item_id: str):
        deleted_ids.append(item_id)

    db.on_deleted(on_deleted)
    db.set(TestModel(id="1", name="Test"))
    db.delete("1")
    # Exactly one notification, carrying the ID that was deleted; this also
    # verifies that the preceding `set` did not trigger on_deleted.
    assert deleted_ids == ["1"]


def test_sqlite_service_can_list_with_pagination():
    """With 2 items per page, page 0 holds the first two of three items."""
    storage = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
    for n in range(1, 4):
        storage.set(TestModel(id=str(n), name="Test"))

    first_page = storage.list(page=0, per_page=2)

    assert first_page.page == 0
    assert first_page.pages == 2
    assert first_page.per_page == 2
    assert first_page.total == 3
    assert first_page.items == [TestModel(id=str(n), name="Test") for n in (1, 2)]


def test_sqlite_service_can_list_with_pagination_and_offset():
    """With 2 items per page, page 1 holds only the third item."""
    storage = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
    for n in range(1, 4):
        storage.set(TestModel(id=str(n), name="Test"))

    second_page = storage.list(page=1, per_page=2)

    assert second_page.page == 1
    assert second_page.pages == 2
    assert second_page.per_page == 2
    assert second_page.total == 3
    expected = [TestModel(id="3", name="Test")]
    assert second_page.items == expected


def test_sqlite_service_can_search():
    """A query matching every item returns all of them on one default-sized page."""
    storage = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
    stored = [TestModel(id=item_id, name="Test") for item_id in ("1", "2", "3")]
    for model in stored:
        storage.set(model)

    found = storage.search(query="Test")

    assert found.page == 0
    assert found.pages == 1
    assert found.per_page == 10
    assert found.total == 3
    assert found.items == [TestModel(id=item_id, name="Test") for item_id in ("1", "2", "3")]


def test_sqlite_service_can_search_with_pagination():
    """Paging search results 2 at a time, page 0 holds the first two matches."""
    storage = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
    for n in (1, 2, 3):
        storage.set(TestModel(id=str(n), name="Test"))

    found = storage.search(query="Test", page=0, per_page=2)

    assert found.page == 0
    assert found.pages == 2
    assert found.per_page == 2
    assert found.total == 3
    first_two = [TestModel(id="1", name="Test"), TestModel(id="2", name="Test")]
    assert found.items == first_two


def test_sqlite_service_can_search_with_pagination_and_offset():
    """Paging search results 2 at a time, page 1 holds only the last match."""
    storage = SqliteItemStorage[TestModel](sqlite_memory, "test", "id")
    for n in (1, 2, 3):
        storage.set(TestModel(id=str(n), name="Test"))

    found = storage.search(query="Test", page=1, per_page=2)

    assert found.page == 1
    assert found.pages == 2
    assert found.per_page == 2
    assert found.total == 3
    last_only = [TestModel(id="3", name="Test")]
    assert found.items == last_only