# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from __future__ import annotations

from abc import ABC, abstractmethod
from collections import deque
from inspect import signature
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict, TYPE_CHECKING

from pydantic import BaseModel, Field

if TYPE_CHECKING:
    from ..services.invocation_services import InvocationServices


class InvocationContext:
    """Bundle of services and the graph execution id passed to ``invoke``."""

    services: InvocationServices
    graph_execution_state_id: str

    def __init__(self, services: InvocationServices, graph_execution_state_id: str):
        # Keep both constructor arguments verbatim as public attributes.
        self.services, self.graph_execution_state_id = (
            services,
            graph_execution_state_id,
        )


class BaseInvocationOutput(BaseModel):
    """Base class for all invocation outputs"""

    # All outputs must include a type name like this:
    # type: Literal['your_output_name']

    @classmethod
    def get_all_subclasses_tuple(cls):
        """Return every transitive subclass of ``cls`` as a tuple, breadth-first.

        Direct subclasses come first, followed by their subclasses, and so on.
        No de-duplication is done, so a class reachable through more than one
        parent appears once per path.
        """
        found = []
        # deque gives O(1) pops from the left; list.pop(0) is O(n).
        queue = deque([cls])
        while queue:
            klass = queue.popleft()
            children = klass.__subclasses__()
            found.extend(children)
            queue.extend(children)
        return tuple(found)


class BaseInvocation(ABC, BaseModel):
    """A node to process inputs and produce outputs.
    May use dependency injection in __init__ to receive providers.
    """

    # All invocations must include a type name like this:
    # type: Literal['your_invocation_name']

    @classmethod
    def get_all_subclasses(cls):
        """Return every transitive subclass of ``cls`` as a list, breadth-first.

        Direct subclasses come first, followed by their subclasses, and so on.
        No de-duplication is done, so a class reachable through more than one
        parent appears once per path.
        """
        found = []
        # deque gives O(1) pops from the left; list.pop(0) is O(n).
        queue = deque([cls])
        while queue:
            klass = queue.popleft()
            children = klass.__subclasses__()
            found.extend(children)
            queue.extend(children)
        return found

    @classmethod
    def get_invocations(cls):
        """Return every ``BaseInvocation`` subclass as a tuple.

        Always walks from ``BaseInvocation`` itself, regardless of which
        class this is called on.
        """
        return tuple(BaseInvocation.get_all_subclasses())

    @classmethod
    def get_invocations_map(cls):
        """Map each subclass's ``type`` literal string to the subclass.

        The key is the first argument of the ``Literal`` annotation on the
        subclass's ``type`` field. Always walks from ``BaseInvocation``.

        Raises:
            KeyError: if a subclass has no ``type`` annotation.
        """
        return {
            get_args(get_type_hints(invocation)["type"])[0]: invocation
            for invocation in BaseInvocation.get_all_subclasses()
        }

    @classmethod
    def get_output_type(cls):
        """Return the return annotation of ``cls.invoke``.

        Under postponed evaluation of annotations (``from __future__ import
        annotations``, as used in this module), the raw annotation is a
        string. In that case, try to resolve it to the real type. If it
        cannot be resolved, the string is returned unchanged.
        """
        annotation = signature(cls.invoke).return_annotation
        if isinstance(annotation, str):
            try:
                annotation = get_type_hints(cls.invoke).get("return", annotation)
            except (NameError, TypeError):
                # Best effort: e.g. names only imported under TYPE_CHECKING
                # cannot be resolved at runtime; keep the original string.
                pass
        return annotation

    @abstractmethod
    def invoke(self, context: InvocationContext) -> BaseInvocationOutput:
        """Invoke with provided context and return outputs."""
        pass

    # fmt: off
    id: str = Field(description="The id of this node. Must be unique among all nodes.")
    is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.")
    # fmt: on


# TODO: figure out a better way to provide these hints
# TODO: when we can upgrade to Python 3.11, we can use the `NotRequired` type instead of `total=False`
class UIConfig(TypedDict, total=False):
    """UI hints for an invocation, stored under the ``ui`` key of its schema extra.

    Every key is optional (``total=False``).
    """

    # Field name -> UI type that overrides the type taken from the field definition.
    type_hints: Dict[
        str,
        Literal[
            "integer", "float", "boolean", "string", "enum",
            "image", "latents", "model", "control", "image_collection",
        ],
    ]
    # Category labels for the invocation.
    tags: List[str]
    # Presumably a human-readable title for the invocation; confirm with UI code.
    title: str

class CustomisedSchemaExtra(TypedDict):
    """Shape of an invocation config's ``schema_extra``: one required ``ui`` entry."""

    # UI hints; see ``UIConfig`` for the accepted keys.
    ui: UIConfig


class InvocationConfig(BaseModel.Config):
    """Customizes pydantic's BaseModel.Config class for use by Invocations.

    Provide `schema_extra` a `ui` dict to add hints for generated UIs.
    Every key of the `ui` dict is optional:

    `tags`
    - A list of strings, used to categorise invocations.

    `title`
    - A string title for the invocation.

    `type_hints`
    - A dict of field types which override the types in the invocation definition.
    - Each key should be the name of one of the invocation's fields.
    - Each value should be one of the valid types:
      - `integer`, `float`, `boolean`, `string`, `enum`, `image`, `latents`,
        `model`, `control`, `image_collection`

    ```python
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "tags": ["stable-diffusion", "image"],
                "type_hints": {
                    "initial_image": "image",
                },
            },
        }
    ```
    """

    schema_extra: CustomisedSchemaExtra