# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from typing import Optional
from logging import Logger
from invokeai.app.services.board_image_record_storage import (
    SqliteBoardImageRecordStorage,
)
from invokeai.app.services.board_images import (
    BoardImagesService,
    BoardImagesServiceDependencies,
)
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__

from ..services.default_graphs import create_system_graphs
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ..services.graph import GraphExecutionState, LibraryGraph
from ..services.image_file_storage import DiskImageFileStorage
from ..services.invocation_queue import MemoryInvocationQueue
from ..services.invocation_services import InvocationServices
from ..services.invoker import Invoker
from ..services.processor import DefaultInvocationProcessor
from ..services.sqlite import SqliteItemStorage
from ..services.model_manager_service import ModelManagerService
from ..services.invocation_stats import InvocationStatsService
from .events import FastAPIEventService


# TODO: is there a better way to achieve this?
def check_internet() -> bool:
    """
    Return true if the internet is reachable.
    It does this by pinging huggingface.co.
    """
    import urllib.request

    host = "http://huggingface.co"
    try:
        urllib.request.urlopen(host, timeout=1)
        return True
    except:
        return False


logger = InvokeAILogger.getLogger()


class ApiDependencies:
    """Contains and initializes all dependencies for the API"""

    invoker: Invoker

    @staticmethod
    def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger):
        logger.info(f"InvokeAI version {__version__}")
        logger.info(f"Root directory = {str(config.root_path)}")
        logger.debug(f"Internet connectivity is {config.internet_available}")

        events = FastAPIEventService(event_handler_id)

        output_folder = config.output_path

        # TODO: build a file/path manager?
        db_path = config.db_path
        db_path.parent.mkdir(parents=True, exist_ok=True)
        db_location = str(db_path)

        graph_execution_manager = SqliteItemStorage[GraphExecutionState](
            filename=db_location, table_name="graph_executions"
        )

        urls = LocalUrlService()
        image_record_storage = SqliteImageRecordStorage(db_location)
        image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
        names = SimpleNameService()
        latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))

        board_record_storage = SqliteBoardRecordStorage(db_location)
        board_image_record_storage = SqliteBoardImageRecordStorage(db_location)

        boards = BoardService(
            services=BoardServiceDependencies(
                board_image_record_storage=board_image_record_storage,
                board_record_storage=board_record_storage,
                image_record_storage=image_record_storage,
                url=urls,
                logger=logger,
            )
        )

        board_images = BoardImagesService(
            services=BoardImagesServiceDependencies(
                board_image_record_storage=board_image_record_storage,
                board_record_storage=board_record_storage,
                image_record_storage=image_record_storage,
                url=urls,
                logger=logger,
            )
        )

        images = ImageService(
            services=ImageServiceDependencies(
                board_image_record_storage=board_image_record_storage,
                image_record_storage=image_record_storage,
                image_file_storage=image_file_storage,
                url=urls,
                logger=logger,
                names=names,
                graph_execution_manager=graph_execution_manager,
            )
        )

        services = InvocationServices(
            model_manager=ModelManagerService(config, logger),
            events=events,
            latents=latents,
            images=images,
            boards=boards,
            board_images=board_images,
            queue=MemoryInvocationQueue(),
            graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
            graph_execution_manager=graph_execution_manager,
            processor=DefaultInvocationProcessor(),
            configuration=config,
            performance_statistics=InvocationStatsService(graph_execution_manager),
            logger=logger,
        )

        create_system_graphs(services.graph_library)

        ApiDependencies.invoker = Invoker(services)

    @staticmethod
    def shutdown():
        if ApiDependencies.invoker:
            ApiDependencies.invoker.stop()