import re

import pytest
from pydantic import BaseModel

from invokeai.app.services.item_storage.item_storage_common import ItemNotFoundError
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory


class MockItemModel(BaseModel):
    # Minimal pydantic model used as the stored item type throughout these tests.
    # Used as the storage key under ItemStorageMemory's default id_field ("id").
    id: str
    # Arbitrary payload; lets tests tell apart two versions of an item with the same id.
    value: int


@pytest.fixture
def item_storage_memory():
    """Provide a fresh ItemStorageMemory[MockItemModel] constructed with default settings."""
    storage = ItemStorageMemory[MockItemModel]()
    return storage


def test_item_storage_memory_initializes():
    """Check constructor defaults, explicit arguments, and argument validation."""
    # Defaults: empty store, keyed on "id", capped at 10 items.
    default_storage = ItemStorageMemory[MockItemModel]()
    assert default_storage._items == {}
    assert default_storage._id_field == "id"
    assert default_storage._max_items == 10

    # Explicit arguments override the defaults.
    custom_storage = ItemStorageMemory[MockItemModel](id_field="bananas", max_items=20)
    assert custom_storage._id_field == "bananas"
    assert custom_storage._max_items == 20

    # Invalid arguments are rejected at construction time. The result is not
    # bound to a name: the constructor raises, so no instance is ever produced.
    with pytest.raises(ValueError, match=re.escape("max_items must be at least 1")):
        ItemStorageMemory[MockItemModel](max_items=0)
    with pytest.raises(ValueError, match=re.escape("id_field must not be empty")):
        ItemStorageMemory[MockItemModel](id_field="")


def test_item_storage_memory_sets(item_storage_memory: ItemStorageMemory[MockItemModel]):
    """`set` stores items keyed by id and replaces an existing entry that has the same id."""
    first = MockItemModel(id="1", value=1)
    second = MockItemModel(id="2", value=2)

    item_storage_memory.set(first)
    assert item_storage_memory._items == {"1": first}

    item_storage_memory.set(second)
    assert item_storage_memory._items == {"1": first, "2": second}

    # Setting id "2" again overwrites the stored value rather than adding an entry.
    second_replacement = MockItemModel(id="2", value=9001)
    item_storage_memory.set(second_replacement)
    assert item_storage_memory._items == {"1": first, "2": second_replacement}


def test_item_storage_memory_gets(item_storage_memory: ItemStorageMemory[MockItemModel]):
    """`get` returns the stored item for a known id and raises ItemNotFoundError otherwise."""
    # Store each item, then immediately read it back by id.
    for key, number in (("1", 1), ("2", 2)):
        expected = MockItemModel(id=key, value=number)
        item_storage_memory.set(expected)
        assert item_storage_memory.get(key) == expected

    # An id that was never stored is reported with a descriptive error.
    with pytest.raises(ItemNotFoundError, match=re.escape("Item with id 3 not found")):
        item_storage_memory.get("3")


def test_item_storage_memory_deletes(item_storage_memory: ItemStorageMemory[MockItemModel]):
    """`delete` removes only the entry with the given id and leaves the others in place."""
    kept, removed = MockItemModel(id="1", value=1), MockItemModel(id="2", value=2)
    for model in (kept, removed):
        item_storage_memory.set(model)

    item_storage_memory.delete("2")
    assert item_storage_memory._items == {"1": kept}


def test_item_storage_memory_respects_max():
    """With max_items=3, only the three most recently inserted items remain after ten inserts."""
    capped_storage = ItemStorageMemory[MockItemModel](max_items=3)
    for n in range(10):
        capped_storage.set(MockItemModel(id=str(n), value=n))

    # Ids 7, 8 and 9 were inserted last, so they are the survivors.
    expected = {str(n): MockItemModel(id=str(n), value=n) for n in range(7, 10)}
    assert capped_storage._items == expected


def test_item_storage_memory_calls_set_callback(item_storage_memory: ItemStorageMemory[MockItemModel]):
    """A listener registered via `on_changed` is invoked with the item passed to `set`."""
    last_changed = None
    item = MockItemModel(id="1", value=1)

    def record_change(changed: MockItemModel):
        # Remember the most recent item the storage reported as changed.
        nonlocal last_changed
        last_changed = changed

    item_storage_memory.on_changed(record_change)
    item_storage_memory.set(item)
    assert last_changed == item


def test_item_storage_memory_calls_delete_callback(item_storage_memory: ItemStorageMemory[MockItemModel]):
    """A listener registered via `on_deleted` is invoked with the id passed to `delete`."""
    last_deleted_id = None

    def record_deletion(deleted_id: str):
        # Remember the most recent id the storage reported as deleted.
        nonlocal last_deleted_id
        last_deleted_id = deleted_id

    item_storage_memory.on_deleted(record_deletion)
    item_storage_memory.set(MockItemModel(id="1", value=1))
    item_storage_memory.delete("1")
    assert last_deleted_id == "1"