"""
Readline helper functions for cli_app.py
You may import the global singleton `completer` to get access to the
completer object.
"""
import atexit
import readline
import shlex
from pathlib import Path
from typing import Dict, List, Literal, get_args, get_origin, get_type_hints

import invokeai.backend.util.logging as logger

from ...backend import ModelManager
from ..invocations.baseinvocation import BaseInvocation
from ..services.invocation_services import InvocationServices
from .commands import BaseCommand

# Module-level singleton (not a class variable): stays None until
# set_autocompleter() creates the Completer and stores it here.
completer = None


class Completer(object):
    """
    Readline tab-completer for the node CLI.

    Offers command names, the `--switch` names of the command currently
    being typed and, where they can be determined, the values a switch
    accepts (Literal choices, or model names for the `model` parameter).
    """

    def __init__(self, model_manager: "ModelManager"):
        """
        :param model_manager: object whose model_names() method supplies the
            completions for the `model` parameter.
        """
        self.commands = self.get_commands()
        # Candidates computed on the state == 0 call, served one per call.
        self.matches = None
        # Text that _pre_input_hook() inserts into the next prompt, if set.
        self.linebuffer = None
        self.manager = model_manager

    def complete(self, text, state):
        """
        Complete commands and switches from the node CLI command line.
        Switches are determined in a context-specific manner.

        Readline calls this repeatedly with state = 0, 1, 2, ... for the same
        `text`; the candidate list is computed when state is 0 and then
        returned one entry per call. Returns None when there are no more
        candidates.
        """
        buffer = readline.get_line_buffer()
        if state == 0:
            options = None
            try:
                current_command, current_switch = self.get_current_command(buffer)
                options = self.get_command_options(current_command, current_switch)
            except IndexError:
                # e.g. an empty token ('') in the buffer; fall back to the
                # list of command names below.
                pass
            except ValueError:
                # shlex could not tokenize the buffer (typically an unclosed
                # quote): offer nothing rather than let the error escape
                # into readline.
                self.matches = []
                return None
            options = options or list(self.parse_commands().keys())

            if not text:  # nothing typed yet: every option is a candidate
                self.matches = options
            else:
                self.matches = [s for s in options if s and s.startswith(text)]

        if self.matches is None:
            # Called with state > 0 before any state == 0 call.
            return None
        try:
            return self.matches[state]
        except IndexError:
            return None

    @classmethod
    def get_commands(cls) -> List[object]:
        """
        Return a list of all the client commands and invocations.
        """
        return BaseCommand.get_commands() + BaseInvocation.get_invocations()

    def get_current_command(self, buffer: str) -> tuple[str, str]:
        """
        Parse the readline buffer to find the most recent command and its switch.

        Returns a (command, switch) pair; either element is "" when it could
        not be determined. Raises ValueError if the buffer cannot be
        tokenized by shlex, and IndexError if it contains an empty token.
        """
        if not buffer:
            return "", ""
        command = None
        switch = None
        for t in shlex.split(buffer):
            if t[0].isalpha():
                # A word seen after a switch is that switch's value, not a
                # new command, so only words before the first switch count.
                if switch is None:
                    command = t
            else:
                switch = t
        # don't try to autocomplete switches that are already complete
        if switch and buffer.endswith(" "):
            switch = None
        return command or "", switch or ""

    def parse_commands(self) -> Dict[str, Dict[str, object]]:
        """
        Return a dict in which the keys are the command name
        and the values are the parameters the command takes
        (a mapping of parameter name to its type hint).

        The command name is the first argument of the command's `type`
        type hint.
        """
        result = {}
        for command in self.commands:
            hints = get_type_hints(command)
            name = get_args(hints["type"])[0]
            result[name] = hints
        return result

    def get_command_options(self, command: str, switch: str) -> List[str]:
        """
        Return all the parameters that can be passed to the command as
        command-line switches. Returns None if the command is unrecognized.

        If `switch` already names one of the command's parameters, the
        values that parameter accepts are returned instead; when the switch
        is written as "--name=value" each value is returned in the form
        "--name=<value>".
        """
        parsed_commands = self.parse_commands()
        if command not in parsed_commands:
            return None
        parameters = parsed_commands[command]

        # handle switches in the format "-foo=bar"; split on the first "="
        # only, so that a value which itself contains "=" is tolerated
        switch, has_argument, _ = (switch or "").partition("=")

        parameter = switch.strip("-")
        if parameter not in parameters:
            return [f"--{x}" for x in parameters]

        choices = self.get_parameter_options(parameter, parameters[parameter]) or []
        if has_argument:
            return [f"--{parameter}={x}" for x in choices]
        return choices

    def get_parameter_options(self, parameter: str, typehint) -> List[str]:
        """
        Given a parameter type (such as Literal), offers autocompletions.

        Returns the Literal's choices as strings, the model manager's model
        names for the `model` parameter, and an empty list otherwise.
        """
        if get_origin(typehint) is Literal:
            return [str(x) for x in get_args(typehint)]
        if parameter == "model":
            return self.manager.model_names()
        return []

    def _pre_input_hook(self):
        """
        Insert any text stored in self.linebuffer into the readline line,
        redisplay it, and clear the stored text so it is inserted only once.
        """
        if self.linebuffer:
            readline.insert_text(self.linebuffer)
            readline.redisplay()
            self.linebuffer = None


def set_autocompleter(services: InvocationServices) -> Completer:
    """
    Create the module-level `completer` singleton, hook it into readline and
    set up the persistent command history.

    If the singleton already exists it is returned unchanged and readline is
    not reconfigured.

    :param services: supplies `model_manager` (used for model-name
        completion) and `configuration.root_dir` (the directory holding the
        `.invoke_history` file).
    :return: the Completer singleton.
    """
    global completer

    if completer is not None:
        return completer

    completer = Completer(services.model_manager)

    readline.set_completer(completer.complete)
    try:
        readline.set_auto_history(True)
    except AttributeError:
        # pyreadline3 does not have a set_auto_history() method
        pass
    readline.set_pre_input_hook(completer._pre_input_hook)
    # Only a space separates words, so "--switch=value" is completed as a unit.
    readline.set_completer_delims(" ")
    for directive in (
        "tab: complete",
        "set print-completions-horizontally off",
        "set page-completions on",
        "set skip-completed-text on",
        "set show-all-if-ambiguous on",
    ):
        readline.parse_and_bind(directive)

    histfile = Path(services.configuration.root_dir / ".invoke_history")
    try:
        readline.read_history_file(histfile)
    except FileNotFoundError:
        # No history yet (e.g. first run); the file is written at exit.
        pass
    except OSError:  # file likely corrupted
        newname = f"{histfile}.old"
        logger.error(f"Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}")
        histfile.replace(Path(newname))
    # Apply the cap even when no history could be loaded, so the file
    # written at exit is always limited to 1000 lines.
    readline.set_history_length(1000)
    atexit.register(readline.write_history_file, histfile)
    return completer