"""
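Readline helper functions for cli_app.py.
The module-level singleton `completer` gives access to the completer object.
It is None until set_autocompleter() has been called. Read it as an attribute
of this module, because `from ... import completer` copies the value at import time.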
"""
import atexit
import readline
import shlex

from pathlib import Path
from typing import List, Dict, Literal, get_args, get_type_hints, get_origin

import invokeai.backend.util.logging as logger
from ...backend import ModelManager
from ..invocations.baseinvocation import BaseInvocation
from .commands import BaseCommand
from ..services.invocation_services import InvocationServices

# Module-level singleton Completer instance (a module global, not a class
# variable); None until set_autocompleter() creates it.
completer = None

class Completer(object):
    """
    Readline completer for the node CLI.

    Offers command names, the command-line switches a command accepts, and
    candidate values for switches whose type hint is a ``Literal`` (or, for
    the ``model`` switch, the model manager's model names).
    """

    def __init__(self, model_manager: "ModelManager"):
        """
        :param model_manager: object whose ``model_names()`` supplies the
            candidates offered when completing the ``model`` switch.
        """
        self.commands = self.get_commands()
        # candidates computed on the state == 0 call, indexed on later calls
        self.matches = None
        # text to pre-fill on the next input line (consumed by _pre_input_hook)
        self.linebuffer = None
        self.manager = model_manager

    def complete(self, text, state):
        """
        Complete commands and switches from the node CLI command line.
        Switches are determined in a context-specific manner.

        Readline calls this with state=0, 1, 2, ... and stops at the first
        None. Candidates are computed once, on the state == 0 call, and
        cached in ``self.matches``.

        :param text: the word being completed.
        :param state: index of the candidate to return.
        :returns: the state-th matching candidate, or None when exhausted.
        """
        if state == 0:
            buffer = readline.get_line_buffer()
            options = None
            try:
                current_command, current_switch = self.get_current_command(buffer)
                options = self.get_command_options(current_command, current_switch)
            except (IndexError, ValueError):
                # IndexError: e.g. an empty quoted token ('') in the buffer;
                # ValueError: shlex.split hit an unterminated quote mid-typing.
                # Either way, fall back to offering command names.
                pass
            options = options or list(self.parse_commands().keys())

            if not text:  # first time
                self.matches = list(options)
            else:
                self.matches = [s for s in options if s and s.startswith(text)]

        try:
            return self.matches[state]
        except IndexError:
            return None

    @classmethod
    def get_commands(cls) -> List[object]:
        """
        Return a list of all the client commands and invocations.
        """
        return BaseCommand.get_commands() + BaseInvocation.get_invocations()

    def get_current_command(self, buffer: str) -> tuple[str, str]:
        """
        Parse the readline buffer to find the most recent command and its switch.

        Alphabetic tokens seen before the first non-alphabetic token (a
        switch) are taken as the command; the last such token wins. After a
        switch has been seen, alphabetic tokens are treated as switch values,
        and the last non-alphabetic token is kept as the current switch.

        :returns: ``(command, switch)`` with '' for whichever was not found,
            or ``(None, None)`` for an empty buffer. The switch is '' when the
            buffer ends in a space, since that switch is already complete.
        :raises ValueError: from ``shlex.split`` on an unterminated quote.
        :raises IndexError: if the buffer contains an empty quoted token.
        """
        if not buffer:
            return None, None
        command = None
        switch = None
        for token in shlex.split(buffer):
            if token[0].isalpha():
                # alphabetic tokens after a switch are its value, not a command
                if switch is None:
                    command = token
            else:
                switch = token
        # don't try to autocomplete switches that are already complete
        if switch and buffer.endswith(' '):
            switch = None
        return command or '', switch or ''

    def parse_commands(self) -> Dict[str, dict]:
        """
        Return a dict in which the keys are the command names and the values
        are the command's resolved type hints (parameter name -> type hint).
        """
        result = {}
        for command in self.commands:
            hints = get_type_hints(command)
            # the command name is the first type argument of its ``type``
            # annotation (presumably ``Literal["name"]``)
            name = get_args(hints['type'])[0]
            result[name] = hints
        return result

    def get_command_options(self, command: str, switch: str) -> List[str]:
        """
        Return completion candidates for ``command`` given the switch being typed.

        - If ``switch`` names one of the command's parameters, return that
          parameter's candidate values: bare values for the
          ``--param value`` form, ``--param=value`` strings for the
          ``--param=value`` form.
        - Otherwise return every parameter as a ``--param`` switch.

        Returns None if the command is unrecognized, or if the named
        parameter has no value suggestions.
        """
        parsed_commands = self.parse_commands()
        if command not in parsed_commands:
            return None
        hints = parsed_commands[command]

        # handle switches in the format "--foo=bar"; partition splits on the
        # first '=' only, so a value that itself contains '=' is kept intact
        switch, sep, _argument = (switch or '').partition('=')
        parameter = switch.lstrip('-')
        if parameter not in hints:
            return [f"--{x}" for x in hints]

        values = self.get_parameter_options(parameter, hints[parameter])
        if values is None:
            return None
        # Literal arguments need not be strings, but readline requires str
        values = [str(x) for x in values]
        if not sep:
            return values
        return [f"--{parameter}={x}" for x in values]

    def get_parameter_options(self, parameter: str, typehint) -> List[str]:
        """
        Given a parameter and its type hint (such as Literal), offer value
        completions.

        :param parameter: parameter name; ``model`` is special-cased.
        :param typehint: the parameter's resolved type hint.
        :returns: the allowed values of a ``Literal`` hint, the model
            manager's model names for ``model``, otherwise None.
        """
        if get_origin(typehint) is Literal:
            return list(get_args(typehint))
        if parameter == 'model':
            return self.manager.model_names()
        return None

    def _pre_input_hook(self):
        """
        Readline pre-input hook: if ``self.linebuffer`` is set, insert it as
        editable text on the new input line and clear it so it is used once.
        """
        if self.linebuffer:
            readline.insert_text(self.linebuffer)
            readline.redisplay()
            self.linebuffer = None
    
def set_autocompleter(services: InvocationServices) -> Completer:
    """
    Create the global Completer once and wire it into readline.

    On the first call this installs the completer and its pre-input hook,
    configures tab completion, loads the history file
    ``<root_dir>/.invoke_history``, and registers it to be written back at
    interpreter exit. Later calls return the existing singleton unchanged.

    :param services: provides ``model_manager`` (for the completer) and
        ``configuration.root_dir`` (for the history file location).
    :returns: the module-level ``completer`` singleton.
    """
    global completer

    if completer is not None:
        return completer

    completer = Completer(services.model_manager)

    readline.set_completer(completer.complete)
    # pyreadline3 does not have a set_auto_history() method
    try:
        readline.set_auto_history(True)
    except AttributeError:
        pass
    readline.set_pre_input_hook(completer._pre_input_hook)
    readline.set_completer_delims(" ")
    readline.parse_and_bind("tab: complete")
    readline.parse_and_bind("set print-completions-horizontally off")
    readline.parse_and_bind("set page-completions on")
    readline.parse_and_bind("set skip-completed-text on")
    readline.parse_and_bind("set show-all-if-ambiguous on")

    histfile = Path(services.configuration.root_dir / ".invoke_history")
    # write_history_file() truncates to this length. Set it before loading so
    # the cap also applies when no history file exists yet or it is corrupt.
    readline.set_history_length(1000)
    try:
        readline.read_history_file(histfile)
    except FileNotFoundError:
        pass
    except OSError:  # file likely corrupted
        newname = f"{histfile}.old"
        logger.error(
            f"Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
        )
        histfile.replace(Path(newname))
    atexit.register(readline.write_history_file, histfile)
    return completer