import sqlite3
from threading import Lock
from typing import Generic, TypeVar, Union, get_args

from pydantic import BaseModel, parse_raw_as

from .item_storage import ItemStorageABC, PaginatedResults

T = TypeVar("T", bound=BaseModel)

sqlite_memory = ":memory:"


class SqliteItemStorage(ItemStorageABC, Generic[T]):
    _filename: str
    _table_name: str
    _conn: sqlite3.Connection
    _cursor: sqlite3.Cursor
    _id_field: str
    _lock: Lock

    def __init__(self, filename: str, table_name: str, id_field: str = "id"):
        super().__init__()

        self._filename = filename
        self._table_name = table_name
        self._id_field = id_field  # TODO: validate that T has this field
        self._lock = Lock()
        self._conn = sqlite3.connect(
            self._filename, check_same_thread=False
        )  # TODO: figure out a better threading solution
        self._cursor = self._conn.cursor()

        self._create_table()

    def _create_table(self):
        try:
            self._lock.acquire()
            self._cursor.execute(
                f"""CREATE TABLE IF NOT EXISTS {self._table_name} (
                item TEXT,
                id TEXT GENERATED ALWAYS AS (json_extract(item, '$.{self._id_field}')) VIRTUAL NOT NULL);"""
            )
            self._cursor.execute(
                f"""CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);"""
            )
        finally:
            self._lock.release()

    def _parse_item(self, item: str) -> T:
        item_type = get_args(self.__orig_class__)[0]
        return parse_raw_as(item_type, item)

    def set(self, item: T):
        try:
            self._lock.acquire()
            self._cursor.execute(
                f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
                (item.json(),),
            )
            self._conn.commit()
        finally:
            self._lock.release()
        self._on_changed(item)

    def get(self, id: str) -> Union[T, None]:
        try:
            self._lock.acquire()
            self._cursor.execute(
                f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
            )
            result = self._cursor.fetchone()
        finally:
            self._lock.release()

        if not result:
            return None

        return self._parse_item(result[0])

    def delete(self, id: str):
        try:
            self._lock.acquire()
            self._cursor.execute(
                f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),)
            )
            self._conn.commit()
        finally:
            self._lock.release()
        self._on_deleted(id)

    def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
        try:
            self._lock.acquire()
            self._cursor.execute(
                f"""SELECT item FROM {self._table_name} LIMIT ? OFFSET ?;""",
                (per_page, page * per_page),
            )
            result = self._cursor.fetchall()

            items = list(map(lambda r: self._parse_item(r[0]), result))

            self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
            count = self._cursor.fetchone()[0]
        finally:
            self._lock.release()

        pageCount = int(count / per_page) + 1

        return PaginatedResults[T](
            items=items, page=page, pages=pageCount, per_page=per_page, total=count
        )

    def search(
        self, query: str, page: int = 0, per_page: int = 10
    ) -> PaginatedResults[T]:
        try:
            self._lock.acquire()
            self._cursor.execute(
                f"""SELECT item FROM {self._table_name} WHERE item LIKE ? LIMIT ? OFFSET ?;""",
                (f"%{query}%", per_page, page * per_page),
            )
            result = self._cursor.fetchall()

            items = list(map(lambda r: self._parse_item(r[0]), result))

            self._cursor.execute(
                f"""SELECT count(*) FROM {self._table_name} WHERE item LIKE ?;""",
                (f"%{query}%",),
            )
            count = self._cursor.fetchone()[0]
        finally:
            self._lock.release()

        pageCount = int(count / per_page) + 1

        return PaginatedResults[T](
            items=items, page=page, pages=pageCount, per_page=per_page, total=count
        )