import re
from typing import List, Tuple

import invokeai.backend.util.logging as logger
from invokeai.app.services.model_records import UnknownModelException
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType, ModelType
from invokeai.backend.textual_inversion import TextualInversionModelRaw


def extract_ti_triggers_from_prompt(prompt: str) -> List[str]:
    """Return every angle-bracketed textual-inversion trigger found in ``prompt``.

    A trigger is ``<`` followed by one or more ASCII letters, digits, ``.``,
    ``,``, spaces, ``_`` or ``-``, followed by ``>``. Matches keep their
    brackets, are returned in order of appearance, and duplicates are kept.
    """
    trigger_pattern = r"<[a-zA-Z0-9., _-]+>"
    return [str(match) for match in re.findall(trigger_pattern, prompt)]


def _load_ti_by_key_or_name(name_or_key: str, base: BaseModelType, context: InvocationContext):
    """Load a model by key, falling back to a lookup by name, ``base`` and TI model type.

    Returns the handle produced by ``context.models.load`` / ``load_by_attrs``,
    or ``None`` when neither lookup raises anything but ``UnknownModelException``.
    Any other exception (e.g. ``ValueError``) propagates to the caller.
    """
    try:
        return context.models.load(name_or_key)
    except UnknownModelException:
        pass
    try:
        return context.models.load_by_attrs(name=name_or_key, base=base, type=ModelType.TextualInversion)
    except UnknownModelException:
        return None


def generate_ti_list(
    prompt: str, base: BaseModelType, context: InvocationContext
) -> List[Tuple[str, TextualInversionModelRaw]]:
    """Resolve the ``<...>`` triggers in ``prompt`` to loaded textual-inversion models.

    Each trigger's bracket-stripped text is looked up first as a model key and
    then as a textual-inversion model name for ``base``. Per-trigger failures
    never propagate:

    * triggers that match no model are skipped silently;
    * ambiguous names, models that are not a ``TextualInversionModelRaw`` of
      ``base``, and any other load failure are skipped with a warning.

    Args:
        prompt: Prompt text to scan for triggers.
        base: Base model type the textual inversions must belong to.
        context: Invocation context whose ``models`` service loads the models.

    Returns:
        ``(name_or_key, model)`` pairs in prompt order, where ``name_or_key`` is
        the trigger text without its angle brackets.
    """
    ti_list: List[Tuple[str, TextualInversionModelRaw]] = []
    for trigger in extract_ti_triggers_from_prompt(prompt):
        name_or_key = trigger[1:-1]
        # Both lookups run inside this single try so that failures in the
        # name-based fallback reach the handlers below instead of escaping to
        # the caller.
        try:
            loaded_model = _load_ti_by_key_or_name(name_or_key, base, context)
            if loaded_model is None:
                continue  # no model with this key or name: skip silently
            model = loaded_model.model
            # Explicit checks rather than ``assert`` so validation still happens
            # under ``python -O``.
            if not isinstance(model, TextualInversionModelRaw) or loaded_model.config.base != base:
                logger.warning(f'trigger: "{trigger}" not a valid textual inversion model for this graph')
                continue
            ti_list.append((name_or_key, model))
        except ValueError:
            logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
        except AssertionError:
            # Assertion failures raised while loading are reported as invalid models.
            logger.warning(f'trigger: "{trigger}" not a valid textual inversion model for this graph')
        except Exception:
            logger.warning(f'Failed to load TI model for trigger: "{trigger}"')
    return ti_list