diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index dcbdbec82d..8dbeaa3d05 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -1,75 +1,30 @@ -# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername) +# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2024 Lincoln Stein -from typing import Literal, Optional, Union -from fastapi import Query, Body -from fastapi.routing import APIRouter, HTTPException -from pydantic import BaseModel, Field, parse_obj_as -from ..dependencies import ApiDependencies +from typing import Literal, List, Optional, Union + +from fastapi import Body, Path, Query, Response +from fastapi.routing import APIRouter +from pydantic import BaseModel, parse_obj_as +from starlette.exceptions import HTTPException + from invokeai.backend import BaseModelType, ModelType -from invokeai.backend.model_management import AddModelResult -from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS, SchedulerPredictionType -MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)] +from invokeai.backend.model_management.models import ( + OPENAPI_MODEL_CONFIGS, + SchedulerPredictionType, +) +from invokeai.backend.model_management import MergeInterpolationMethod +from ..dependencies import ApiDependencies models_router = APIRouter(prefix="/v1/models", tags=["models"]) -class VaeRepo(BaseModel): - repo_id: str = Field(description="The repo ID to use for this VAE") - path: Optional[str] = Field(description="The path to the VAE") - subfolder: Optional[str] = Field(description="The subfolder to use for this VAE") - -class ModelInfo(BaseModel): - description: Optional[str] = Field(description="A description of the model") - model_name: str = Field(description="The name of the model") - model_type: str = Field(description="The type of the model") - -class DiffusersModelInfo(ModelInfo): - format: Literal['folder'] = 'folder' - - vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model") - repo_id: Optional[str] = Field(description="The repo ID to use for this model") - path: Optional[str] = Field(description="The path to the model") - -class CkptModelInfo(ModelInfo): - format: Literal['ckpt'] = 'ckpt' - - config: str = Field(description="The path to the model config") - weights: str = Field(description="The path to the model weights") - vae: str = Field(description="The path to the model VAE") - width: Optional[int] = Field(description="The width of the model") - height: Optional[int] = Field(description="The height of the model") - -class SafetensorsModelInfo(CkptModelInfo): - format: Literal['safetensors'] = 'safetensors' - -class CreateModelRequest(BaseModel): - name: str = Field(description="The name of the model") - info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info") - -class CreateModelResponse(BaseModel): - name: str = Field(description="The name of the new model") - info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info") - status: str = Field(description="The status of the API response") - -class ImportModelResponse(BaseModel): - name: str = Field(description="The name of the imported model") -# base_model: str = Field(description="The base model") -# model_type: str = Field(description="The model type") - info: AddModelResult = Field(description="The model info") - status: str = Field(description="The status of the API response") - -class ConversionRequest(BaseModel): - name: str = Field(description="The name of the new model") - info: CkptModelInfo = Field(description="The converted model info") - save_location: str = Field(description="The path to save the converted model weights") - -class ConvertedModelResponse(BaseModel): - name: str = Field(description="The name of the new model") - info: DiffusersModelInfo = Field(description="The converted model info") +UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] +ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] +ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] +MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] class ModelsList(BaseModel): - models: list[MODEL_CONFIGS] - + models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]] @models_router.get( "/", @@ -77,75 +32,103 @@ class ModelsList(BaseModel): responses={200: {"model": ModelsList }}, ) async def list_models( - base_model: Optional[BaseModelType] = Query( - default=None, description="Base model" - ), - model_type: Optional[ModelType] = Query( - default=None, description="The type of model to get" - ), + base_model: Optional[BaseModelType] = Query(default=None, description="Base model"), + model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"), ) -> ModelsList: """Gets a list of models""" models_raw = ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type) models = parse_obj_as(ModelsList, { "models": models_raw }) return models -@models_router.post( - "/", +@models_router.patch( + "/{base_model}/{model_type}/{model_name}", operation_id="update_model", - responses={200: {"status": "success"}}, + responses={200: {"description" : "The model was updated successfully"}, + 404: {"description" : "The model could not be found"}, + 400: {"description" : "Bad request"} + }, + status_code = 200, + response_model = UpdateModelResponse, ) async def update_model( - model_request: CreateModelRequest -) -> CreateModelResponse: + base_model: BaseModelType = Path(description="Base model"), + model_type: ModelType = Path(description="The type of model"), + model_name: str = Path(description="model name"), + info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"), +) -> UpdateModelResponse: """ Add Model """ - model_request_info = model_request.info - info_dict = model_request_info.dict() - model_response = CreateModelResponse(name=model_request.name, info=model_request.info, status="success") - - ApiDependencies.invoker.services.model_manager.add_model( - model_name=model_request.name, - model_attributes=info_dict, - clobber=True, - ) + try: + ApiDependencies.invoker.services.model_manager.update_model( + model_name=model_name, + base_model=base_model, + model_type=model_type, + model_attributes=info.dict() + ) + model_raw = ApiDependencies.invoker.services.model_manager.list_model( + model_name=model_name, + base_model=base_model, + model_type=model_type, + ) + model_response = parse_obj_as(UpdateModelResponse, model_raw) + except KeyError as e: + raise HTTPException(status_code=404, detail=str(e)) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) return model_response @models_router.post( - "/import", + "/", operation_id="import_model", responses= { 201: {"description" : "The model imported successfully"}, 404: {"description" : "The model could not be found"}, + 424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"}, + 409: {"description" : "There is already a model corresponding to this path or repo_id"}, }, status_code=201, response_model=ImportModelResponse ) async def import_model( - name: str = Query(description="A model path, repo_id or URL to import"), - prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = Query(description='Prediction type for SDv2 checkpoint files', default="v_prediction"), + location: str = Body(description="A model path, repo_id or URL to import"), + prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \ + Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"), ) -> ImportModelResponse: """ Add a model using its local path, repo_id, or remote URL """ - items_to_import = {name} + + items_to_import = {location} prediction_types = { x.value: x for x in SchedulerPredictionType } logger = ApiDependencies.invoker.services.logger - - installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import( - items_to_import = items_to_import, - prediction_type_helper = lambda x: prediction_types.get(prediction_type) - ) - if info := installed_models.get(name): - logger.info(f'Successfully imported {name}, got {info}') - return ImportModelResponse( - name = name, - info = info, - status = "success", + + try: + installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import( + items_to_import = items_to_import, + prediction_type_helper = lambda x: prediction_types.get(prediction_type) ) - else: - logger.error(f'Model {name} not imported') - raise HTTPException(status_code=404, detail=f'Model {name} not found') + info = installed_models.get(location) + + if not info: + logger.error("Import failed") + raise HTTPException(status_code=424) + + logger.info(f'Successfully imported {location}, got {info}') + model_raw = ApiDependencies.invoker.services.model_manager.list_model( + model_name=info.name, + base_model=info.base_model, + model_type=info.model_type + ) + return parse_obj_as(ImportModelResponse, model_raw) + + except KeyError as e: + logger.error(str(e)) + raise HTTPException(status_code=404, detail=str(e)) + except ValueError as e: + logger.error(str(e)) + raise HTTPException(status_code=409, detail=str(e)) + @models_router.delete( - "/{model_name}", + "/{base_model}/{model_type}/{model_name}", operation_id="del_model", responses={ 204: { @@ -156,144 +139,95 @@ async def import_model( } }, ) -async def delete_model(model_name: str) -> None: +async def delete_model( + base_model: BaseModelType = Path(description="Base model"), + model_type: ModelType = Path(description="The type of model"), + model_name: str = Path(description="model name"), +) -> Response: """Delete Model""" - model_names = ApiDependencies.invoker.services.model_manager.model_names() logger = ApiDependencies.invoker.services.logger - model_exists = model_name in model_names - - # check if model exists - logger.info(f"Checking for model {model_name}...") - - if model_exists: - logger.info(f"Deleting Model: {model_name}") - ApiDependencies.invoker.services.model_manager.del_model(model_name, delete_files=True) - logger.info(f"Model Deleted: {model_name}") - raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully") - else: - logger.error("Model not found") + try: + ApiDependencies.invoker.services.model_manager.del_model(model_name, + base_model = base_model, + model_type = model_type + ) + logger.info(f"Deleted model: {model_name}") + return Response(status_code=204) + except KeyError: + logger.error(f"Model not found: {model_name}") raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") - - # @socketio.on("convertToDiffusers") - # def convert_to_diffusers(model_to_convert: dict): - # try: - # if model_info := self.generate.model_manager.model_info( - # model_name=model_to_convert["model_name"] - # ): - # if "weights" in model_info: - # ckpt_path = Path(model_info["weights"]) - # original_config_file = Path(model_info["config"]) - # model_name = model_to_convert["model_name"] - # model_description = model_info["description"] - # else: - # self.socketio.emit( - # "error", {"message": "Model is not a valid checkpoint file"} - # ) - # else: - # self.socketio.emit( - # "error", {"message": "Could not retrieve model info."} - # ) - - # if not ckpt_path.is_absolute(): - # ckpt_path = Path(Globals.root, ckpt_path) - - # if original_config_file and not original_config_file.is_absolute(): - # original_config_file = Path(Globals.root, original_config_file) - - # diffusers_path = Path( - # ckpt_path.parent.absolute(), f"{model_name}_diffusers" - # ) - - # if model_to_convert["save_location"] == "root": - # diffusers_path = Path( - # global_converted_ckpts_dir(), f"{model_name}_diffusers" - # ) - - # if ( - # model_to_convert["save_location"] == "custom" - # and model_to_convert["custom_location"] is not None - # ): - # diffusers_path = Path( - # model_to_convert["custom_location"], f"{model_name}_diffusers" - # ) - - # if diffusers_path.exists(): - # shutil.rmtree(diffusers_path) - - # self.generate.model_manager.convert_and_import( - # ckpt_path, - # diffusers_path, - # model_name=model_name, - # model_description=model_description, - # vae=None, - # original_config_file=original_config_file, - # commit_to_conf=opt.conf, - # ) - - # new_model_list = self.generate.model_manager.list_models() - # socketio.emit( - # "modelConverted", - # { - # "new_model_name": model_name, - # "model_list": new_model_list, - # "update": True, - # }, - # ) - # print(f">> Model Converted: {model_name}") - # except Exception as e: - # self.handle_exceptions(e) - - # @socketio.on("mergeDiffusersModels") - # def merge_diffusers_models(model_merge_info: dict): - # try: - # models_to_merge = model_merge_info["models_to_merge"] - # model_ids_or_paths = [ - # self.generate.model_manager.model_name_or_path(x) - # for x in models_to_merge - # ] - # merged_pipe = merge_diffusion_models( - # model_ids_or_paths, - # model_merge_info["alpha"], - # model_merge_info["interp"], - # model_merge_info["force"], - # ) - - # dump_path = global_models_dir() / "merged_models" - # if model_merge_info["model_merge_save_path"] is not None: - # dump_path = Path(model_merge_info["model_merge_save_path"]) - - # os.makedirs(dump_path, exist_ok=True) - # dump_path = dump_path / model_merge_info["merged_model_name"] - # merged_pipe.save_pretrained(dump_path, safe_serialization=1) - - # merged_model_config = dict( - # model_name=model_merge_info["merged_model_name"], - # description=f'Merge of models {", ".join(models_to_merge)}', - # commit_to_conf=opt.conf, - # ) - - # if vae := self.generate.model_manager.config[models_to_merge[0]].get( - # "vae", None - # ): - # print(f">> Using configured VAE assigned to {models_to_merge[0]}") - # merged_model_config.update(vae=vae) - - # self.generate.model_manager.import_diffuser_model( - # dump_path, **merged_model_config - # ) - # new_model_list = self.generate.model_manager.list_models() - - # socketio.emit( - # "modelsMerged", - # { - # "merged_models": models_to_merge, - # "merged_model_name": model_merge_info["merged_model_name"], - # "model_list": new_model_list, - # "update": True, - # }, - # ) - # print(f">> Models Merged: {models_to_merge}") - # print(f">> New Model Added: {model_merge_info['merged_model_name']}") - # except Exception as e: +@models_router.put( + "/convert/{base_model}/{model_type}/{model_name}", + operation_id="convert_model", + responses={ + 200: { "description": "Model converted successfully" }, + 400: {"description" : "Bad request" }, + 404: { "description": "Model not found" }, + }, + status_code = 200, + response_model = ConvertModelResponse, +) +async def convert_model( + base_model: BaseModelType = Path(description="Base model"), + model_type: ModelType = Path(description="The type of model"), + model_name: str = Path(description="model name"), +) -> ConvertModelResponse: + """Convert a checkpoint model into a diffusers model""" + logger = ApiDependencies.invoker.services.logger + try: + logger.info(f"Converting model: {model_name}") + ApiDependencies.invoker.services.model_manager.convert_model(model_name, + base_model = base_model, + model_type = model_type + ) + model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name, + base_model = base_model, + model_type = model_type) + response = parse_obj_as(ConvertModelResponse, model_raw) + except KeyError: + raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + return response + +@models_router.put( + "/merge/{base_model}", + operation_id="merge_models", + responses={ + 200: { "description": "Model converted successfully" }, + 400: { "description": "Incompatible models" }, + 404: { "description": "One or more models not found" }, + }, + status_code = 200, + response_model = MergeModelResponse, +) +async def merge_models( + base_model: BaseModelType = Path(description="Base model"), + model_names: List[str] = Body(description="model name", min_items=2, max_items=3), + merged_model_name: Optional[str] = Body(description="Name of destination model"), + alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5), + interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"), + force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False), +) -> MergeModelResponse: + """Convert a checkpoint model into a diffusers model""" + logger = ApiDependencies.invoker.services.logger + try: + logger.info(f"Merging models: {model_names}") + result = ApiDependencies.invoker.services.model_manager.merge_models(model_names, + base_model, + merged_model_name or "+".join(model_names), + alpha, + interp, + force) + model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name, + base_model = base_model, + model_type = ModelType.Main, + ) + response = parse_obj_as(ConvertModelResponse, model_raw) + except KeyError: + raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found") + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + return response diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index b25136c240..63c13f3460 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -2,22 +2,29 @@ from __future__ import annotations -import torch from abc import ABC, abstractmethod from pathlib import Path -from typing import Optional, Union, Callable, List, Tuple, types, TYPE_CHECKING -from dataclasses import dataclass +from pydantic import Field +from typing import Optional, Union, Callable, List, Tuple, TYPE_CHECKING +from types import ModuleType -from invokeai.backend.model_management.model_manager import ( +from invokeai.backend.model_management import ( ModelManager, BaseModelType, ModelType, SubModelType, ModelInfo, + AddModelResult, + SchedulerPredictionType, + ModelMerger, + MergeInterpolationMethod, ) + + +import torch from invokeai.app.models.exceptions import CanceledException -from .config import InvokeAIAppConfig from ...backend.util import choose_precision, choose_torch_device +from .config import InvokeAIAppConfig if TYPE_CHECKING: from ..invocations.baseinvocation import BaseInvocation, InvocationContext @@ -30,7 +37,7 @@ class ModelManagerServiceBase(ABC): def __init__( self, config: InvokeAIAppConfig, - logger: types.ModuleType, + logger: ModuleType, ): """ Initialize with the path to the models.yaml config file. @@ -73,13 +80,7 @@ class ModelManagerServiceBase(ABC): def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: """ Given a model name returns a dict-like (OmegaConf) object describing it. - """ - pass - - @abstractmethod - def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]: - """ - Returns a list of all the model names known. + Uses the exact format as the omegaconf stanza. """ pass @@ -101,7 +102,20 @@ class ModelManagerServiceBase(ABC): } """ pass + + @abstractmethod + def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: + """ + Return information about the model using the same format as list_models() + """ + pass + @abstractmethod + def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]: + """ + Returns a list of all the model names known. + """ + pass @abstractmethod def add_model( @@ -111,7 +125,7 @@ class ModelManagerServiceBase(ABC): model_type: ModelType, model_attributes: dict, clobber: bool = False - ) -> None: + ) -> AddModelResult: """ Update the named model with a dictionary of attributes. Will fail with an assertion error if the name already exists. Pass clobber=True to overwrite. @@ -121,6 +135,24 @@ class ModelManagerServiceBase(ABC): """ pass + @abstractmethod + def update_model( + self, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + model_attributes: dict, + ) -> AddModelResult: + """ + Update the named model with a dictionary of attributes. Will fail with a + KeyErrorException if the name does not already exist. + + On a successful update, the config will be changed in memory. Will fail + with an assertion error if provided attributes are incorrect or + the model name is missing. Call commit() to write changes to disk. + """ + pass + @abstractmethod def del_model( self, @@ -135,11 +167,32 @@ class ModelManagerServiceBase(ABC): """ pass + @abstractmethod + def convert_model( + self, + model_name: str, + base_model: BaseModelType, + model_type: Union[ModelType.Main,ModelType.Vae], + ) -> AddModelResult: + """ + Convert a checkpoint file into a diffusers folder, deleting the cached + version and deleting the original checkpoint file if it is in the models + directory. + :param model_name: Name of the model to convert + :param base_model: Base model type + :param model_type: Type of model ['vae' or 'main'] + + This will raise a ValueError unless the model is not a checkpoint. It will + also raise a ValueError in the event that there is a similarly-named diffusers + directory already in place. + """ + pass + @abstractmethod def heuristic_import(self, - items_to_import: Set[str], - prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None, - )->Dict[str, AddModelResult]: + items_to_import: set[str], + prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None, + )->dict[str, AddModelResult]: '''Import a list of paths, repo_ids or URLs. Returns the set of successfully imported items. :param items_to_import: Set of strings corresponding to models to be imported. @@ -159,7 +212,27 @@ class ModelManagerServiceBase(ABC): pass @abstractmethod - def commit(self, conf_file: Path = None) -> None: + def merge_models( + self, + model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"), + base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"), + merged_model_name: str = Field(default=None, description="Name of destination model after merging"), + alpha: Optional[float] = 0.5, + interp: Optional[MergeInterpolationMethod] = None, + force: Optional[bool] = False, + ) -> AddModelResult: + """ + Merge two to three diffusrs pipeline models and save as a new model. + :param model_names: List of 2-3 models to merge + :param base_model: Base model to use for all models + :param merged_model_name: Name of destination merged model + :param alpha: Alpha strength to apply to 2d and 3d model + :param interp: Interpolation method. None (default) + """ + pass + + @abstractmethod + def commit(self, conf_file: Optional[Path] = None) -> None: """ Write current configuration out to the indicated file. If no conf_file is provided, then replaces the @@ -173,7 +246,7 @@ class ModelManagerService(ModelManagerServiceBase): def __init__( self, config: InvokeAIAppConfig, - logger: types.ModuleType, + logger: ModuleType, ): """ Initialize with the path to the models.yaml config file. @@ -301,12 +374,19 @@ class ModelManagerService(ModelManagerServiceBase): base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None ) -> list[dict]: - # ) -> dict: """ Return a list of models. """ return self.mgr.list_models(base_model, model_type) + def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: + """ + Return information about the model using the same format as list_models() + """ + return self.mgr.list_model(model_name=model_name, + base_model=base_model, + model_type=model_type) + def add_model( self, model_name: str, @@ -322,9 +402,28 @@ class ModelManagerService(ModelManagerServiceBase): with an assertion error if provided attributes are incorrect or the model name is missing. Call commit() to write changes to disk. """ + self.logger.debug(f'add/update model {model_name}') return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber) - + def update_model( + self, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + model_attributes: dict, + ) -> AddModelResult: + """ + Update the named model with a dictionary of attributes. Will fail with a + KeyError exception if the name does not already exist. + On a successful update, the config will be changed in memory. Will fail + with an assertion error if provided attributes are incorrect or + the model name is missing. Call commit() to write changes to disk. + """ + self.logger.debug(f'update model {model_name}') + if not self.model_exists(model_name, base_model, model_type): + raise KeyError(f"Unknown model {model_name}") + return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True) + def del_model( self, model_name: str, @@ -336,8 +435,29 @@ class ModelManagerService(ModelManagerServiceBase): then the underlying weight file or diffusers directory will be deleted as well. Call commit() to write to disk. """ + self.logger.debug(f'delete model {model_name}') self.mgr.del_model(model_name, base_model, model_type) + def convert_model( + self, + model_name: str, + base_model: BaseModelType, + model_type: Union[ModelType.Main,ModelType.Vae], + ) -> AddModelResult: + """ + Convert a checkpoint file into a diffusers folder, deleting the cached + version and deleting the original checkpoint file if it is in the models + directory. + :param model_name: Name of the model to convert + :param base_model: Base model type + :param model_type: Type of model ['vae' or 'main'] + + This will raise a ValueError unless the model is not a checkpoint. It will + also raise a ValueError in the event that there is a similarly-named diffusers + directory already in place. + """ + self.logger.debug(f'convert model {model_name}') + return self.mgr.convert_model(model_name, base_model, model_type) def commit(self, conf_file: Optional[Path]=None): """ @@ -389,9 +509,9 @@ class ModelManagerService(ModelManagerServiceBase): return self.mgr.logger def heuristic_import(self, - items_to_import: Set[str], - prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None, - )->Dict[str, AddModelResult]: + items_to_import: set[str], + prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None, + )->dict[str, AddModelResult]: '''Import a list of paths, repo_ids or URLs. Returns the set of successfully imported items. :param items_to_import: Set of strings corresponding to models to be imported. @@ -408,4 +528,35 @@ class ModelManagerService(ModelManagerServiceBase): of the set is a dict corresponding to the newly-created OmegaConf stanza for that model. ''' - return self.mgr.heuristic_import(items_to_import, prediction_type_helper) + return self.mgr.heuristic_import(items_to_import, prediction_type_helper) + + def merge_models( + self, + model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"), + base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"), + merged_model_name: str = Field(default=None, description="Name of destination model after merging"), + alpha: Optional[float] = 0.5, + interp: Optional[MergeInterpolationMethod] = None, + force: Optional[bool] = False, + ) -> AddModelResult: + """ + Merge two to three diffusrs pipeline models and save as a new model. + :param model_names: List of 2-3 models to merge + :param base_model: Base model to use for all models + :param merged_model_name: Name of destination merged model + :param alpha: Alpha strength to apply to 2d and 3d model + :param interp: Interpolation method. None (default) + """ + merger = ModelMerger(self.mgr) + try: + result = merger.merge_diffusion_models_and_save( + model_names = model_names, + base_model = base_model, + merged_model_name = merged_model_name, + alpha = alpha, + interp = interp, + force = force, + ) + except AssertionError as e: + raise ValueError(e) + return result diff --git a/invokeai/backend/install/model_install_backend.py b/invokeai/backend/install/model_install_backend.py index 8da0ab0bd2..c5f15a3ce9 100644 --- a/invokeai/backend/install/model_install_backend.py +++ b/invokeai/backend/install/model_install_backend.py @@ -173,15 +173,19 @@ class ModelInstall(object): # add requested models for path in selections.install_models: logger.info(f'Installing {path} [{job}/{jobs}]') - self.heuristic_import(path) + try: + self.heuristic_import(path) + except (ValueError, KeyError) as e: + logger.error(str(e)) job += 1 dlogging.set_verbosity(verbosity) self.mgr.commit() def heuristic_import(self, - model_path_id_or_url: Union[str,Path], - models_installed: Set[Path]=None)->Dict[str, AddModelResult]: + model_path_id_or_url: Union[str,Path], + models_installed: Set[Path]=None, + )->Dict[str, AddModelResult]: ''' :param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL :param models_installed: Set of installed models, used for recursive invocation @@ -194,62 +198,53 @@ class ModelInstall(object): # A little hack to allow nested routines to retrieve info on the requested ID self.current_id = model_path_id_or_url path = Path(model_path_id_or_url) - - try: - # checkpoint file, or similar - if path.is_file(): - models_installed.update(self._install_path(path)) + # checkpoint file, or similar + if path.is_file(): + models_installed.update({str(path):self._install_path(path)}) - # folders style or similar - elif path.is_dir() and any([(path/x).exists() for x in \ - {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'} - ] - ): - models_installed.update(self._install_path(path)) + # folders style or similar + elif path.is_dir() and any([(path/x).exists() for x in \ + {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'} + ] + ): + models_installed.update(self._install_path(path)) - # recursive scan - elif path.is_dir(): - for child in path.iterdir(): - self.heuristic_import(child, models_installed=models_installed) + # recursive scan + elif path.is_dir(): + for child in path.iterdir(): + self.heuristic_import(child, models_installed=models_installed) - # huggingface repo - elif len(str(model_path_id_or_url).split('/')) == 2: - models_installed.update(self._install_repo(str(model_path_id_or_url))) + # huggingface repo + elif len(str(model_path_id_or_url).split('/')) == 2: + models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))}) - # a URL - elif model_path_id_or_url.startswith(("http:", "https:", "ftp:")): - models_installed.update(self._install_url(model_path_id_or_url)) + # a URL + elif str(model_path_id_or_url).startswith(("http:", "https:", "ftp:")): + models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)}) - else: - logger.warning(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping') - - except ValueError as e: - logger.error(str(e)) + else: + raise KeyError(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping') return models_installed # install a model from a local path. The optional info parameter is there to prevent # the model from being probed twice in the event that it has already been probed. - def _install_path(self, path: Path, info: ModelProbeInfo=None)->Dict[str, AddModelResult]: - try: - model_result = None - info = info or ModelProbe().heuristic_probe(path,self.prediction_helper) - model_name = path.stem if path.is_file() else path.name - if self.mgr.model_exists(model_name, info.base_type, info.model_type): - raise ValueError(f'A model named "{model_name}" is already installed.') - attributes = self._make_attributes(path,info) - model_result = self.mgr.add_model(model_name = model_name, - base_model = info.base_type, - model_type = info.model_type, - model_attributes = attributes, - ) - except Exception as e: - logger.warning(f'{str(e)} Skipping registration.') - return {} - return {str(path): model_result} + def _install_path(self, path: Path, info: ModelProbeInfo=None)->AddModelResult: + info = info or ModelProbe().heuristic_probe(path,self.prediction_helper) + if not info: + logger.warning(f'Unable to parse format of {path}') + return None + model_name = path.stem if path.is_file() else path.name + if self.mgr.model_exists(model_name, info.base_type, info.model_type): + raise ValueError(f'A model named "{model_name}" is already installed.') + attributes = self._make_attributes(path,info) + return self.mgr.add_model(model_name = model_name, + base_model = info.base_type, + model_type = info.model_type, + model_attributes = attributes, + ) - def _install_url(self, url: str)->dict: - # copy to a staging area, probe, import and delete + def _install_url(self, url: str)->AddModelResult: with TemporaryDirectory(dir=self.config.models_path) as staging: location = download_with_resume(url,Path(staging)) if not location: @@ -261,7 +256,7 @@ class ModelInstall(object): # staged version will be garbage-collected at this time return self._install_path(Path(models_path), info) - def _install_repo(self, repo_id: str)->dict: + def _install_repo(self, repo_id: str)->AddModelResult: hinfo = HfApi().model_info(repo_id) # we try to figure out how to download this most economically diff --git a/invokeai/backend/model_management/__init__.py b/invokeai/backend/model_management/__init__.py index 34e0b15728..e31085acef 100644 --- a/invokeai/backend/model_management/__init__.py +++ b/invokeai/backend/model_management/__init__.py @@ -1,7 +1,8 @@ """ Initialization file for invokeai.backend.model_management """ -from .model_manager import ModelManager, ModelInfo, AddModelResult +from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType from .model_cache import ModelCache from .models import BaseModelType, ModelType, SubModelType, ModelVariantType +from .model_merge import ModelMerger, MergeInterpolationMethod diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py index ae576e39d9..e98d71e85c 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_management/lora.py @@ -2,16 +2,14 @@ from __future__ import annotations import copy from contextlib import contextmanager +from typing import Optional, Dict, Tuple, Any, Union, List from pathlib import Path -from typing import Any, Dict, Optional, Tuple import torch from compel.embeddings_provider import BaseTextualInversionManager from diffusers.models import UNet2DConditionModel from safetensors.torch import load_file -from torch.utils.hooks import RemovableHandle -from transformers import CLIPTextModel - +from transformers import CLIPTextModel, CLIPTokenizer class LoRALayerBase: #rank: Optional[int] @@ -123,8 +121,8 @@ class LoRALayer(LoRALayerBase): def get_weight(self): if self.mid is not None: - up = self.up.reshape(up.shape[0], up.shape[1]) - down = self.down.reshape(up.shape[0], up.shape[1]) + up = self.up.reshape(self.up.shape[0], self.up.shape[1]) + down = self.down.reshape(self.down.shape[0], self.down.shape[1]) weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down) else: weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1) @@ -410,7 +408,7 @@ class LoRAModel: #(torch.nn.Module): else: # TODO: diff/ia3/... format print( - f">> Encountered unknown lora layer module in {self.name}: {layer_key}" + f">> Encountered unknown lora layer module in {model.name}: {layer_key}" ) return diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 35250c73da..d092e05c05 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -234,7 +234,7 @@ import textwrap from dataclasses import dataclass from pathlib import Path from typing import Optional, List, Tuple, Union, Dict, Set, Callable, types -from shutil import rmtree +from shutil import rmtree, move import torch from omegaconf import OmegaConf @@ -279,7 +279,7 @@ class InvalidModelError(Exception): pass class AddModelResult(BaseModel): - name: str = Field(description="The name of the model after import") + name: str = Field(description="The name of the model after installation") model_type: ModelType = Field(description="The type of model") base_model: BaseModelType = Field(description="The base model") config: ModelConfigBase = Field(description="The configuration of the model") @@ -490,17 +490,32 @@ class ModelManager(object): """ return [(self.parse_key(x)) for x in self.models.keys()] + def list_model( + self, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + ) -> dict: + """ + Returns a dict describing one installed model, using + the combined format of the list_models() method. + """ + models = self.list_models(base_model,model_type,model_name) + return models[0] if models else None + def list_models( self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None, + model_name: Optional[str] = None, ) -> list[dict]: """ Return a list of models. """ + model_keys = [self.create_key(model_name, base_model, model_type)] if model_name else sorted(self.models, key=str.casefold) models = [] - for model_key in sorted(self.models, key=str.casefold): + for model_key in model_keys: model_config = self.models[model_key] cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key) @@ -545,10 +560,7 @@ class ModelManager(object): model_cfg = self.models.pop(model_key, None) if model_cfg is None: - self.logger.error( - f"Unknown model {model_key}" - ) - return + raise KeyError(f"Unknown model {model_key}") # note: it not garantie to release memory(model can has other references) cache_ids = self.cache_keys.pop(model_key, []) @@ -614,6 +626,7 @@ class ModelManager(object): self.cache.uncache_model(cache_id) self.models[model_key] = model_config + self.commit() return AddModelResult( name = model_name, model_type = model_type, @@ -621,6 +634,60 @@ class ModelManager(object): config = model_config, ) + def convert_model ( + self, + model_name: str, + base_model: BaseModelType, + model_type: Union[ModelType.Main,ModelType.Vae], + ) -> AddModelResult: + ''' + Convert a checkpoint file into a diffusers folder, deleting the cached + version and deleting the original checkpoint file if it is in the models + directory. + :param model_name: Name of the model to convert + :param base_model: Base model type + :param model_type: Type of model ['vae' or 'main'] + + This will raise a ValueError unless the model is a checkpoint. + ''' + info = self.model_info(model_name, base_model, model_type) + if info["model_format"] != "checkpoint": + raise ValueError(f"not a checkpoint format model: {model_name}") + + # We are taking advantage of a side effect of get_model() that converts check points + # into cached diffusers directories stored at `location`. It doesn't matter + # what submodeltype we request here, so we get the smallest. + submodel = {"submodel_type": SubModelType.Tokenizer} if model_type==ModelType.Main else {} + model = self.get_model(model_name, + base_model, + model_type, + **submodel, + ) + checkpoint_path = self.app_config.root_path / info["path"] + old_diffusers_path = self.app_config.models_path / model.location + new_diffusers_path = self.app_config.models_path / base_model.value / model_type.value / model_name + if new_diffusers_path.exists(): + raise ValueError(f"A diffusers model already exists at {new_diffusers_path}") + + try: + move(old_diffusers_path,new_diffusers_path) + info["model_format"] = "diffusers" + info["path"] = str(new_diffusers_path.relative_to(self.app_config.root_path)) + info.pop('config') + + result = self.add_model(model_name, base_model, model_type, + model_attributes = info, + clobber=True) + except: + # something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error! + rmtree(new_diffusers_path) + raise + + if checkpoint_path.exists() and checkpoint_path.is_relative_to(self.app_config.models_path): + checkpoint_path.unlink() + + return result + def search_models(self, search_folder): self.logger.info(f"Finding Models In: {search_folder}") models_folder_ckpt = Path(search_folder).glob("**/*.ckpt") @@ -821,6 +888,10 @@ class ModelManager(object): The result is a set of successfully installed models. Each element of the set is a dict corresponding to the newly-created OmegaConf stanza for that model. + + May return the following exceptions: + - KeyError - one or more of the items to import is not a valid path, repo_id or URL + - ValueError - a corresponding model already exists ''' # avoid circular import here from invokeai.backend.install.model_install_backend import ModelInstall @@ -830,11 +901,7 @@ class ModelManager(object): prediction_type_helper = prediction_type_helper, model_manager = self) for thing in items_to_import: - try: - installed = installer.heuristic_import(thing) - successfully_installed.update(installed) - except Exception as e: - self.logger.warning(f'{thing} could not be imported: {str(e)}') - + installed = installer.heuristic_import(thing) + successfully_installed.update(installed) self.commit() return successfully_installed diff --git a/invokeai/backend/model_management/model_merge.py b/invokeai/backend/model_management/model_merge.py new file mode 100644 index 0000000000..39f951d2b4 --- /dev/null +++ b/invokeai/backend/model_management/model_merge.py @@ -0,0 +1,131 @@ +""" +invokeai.backend.model_management.model_merge exports: +merge_diffusion_models() -- combine multiple models by location and return a pipeline object +merge_diffusion_models_and_commit() -- combine multiple models by ModelManager ID and write to models.yaml + +Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team +""" + +import warnings +from enum import Enum +from pathlib import Path +from diffusers import DiffusionPipeline +from diffusers import logging as dlogging +from typing import List, Union + +import invokeai.backend.util.logging as logger + +from ...backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult + +class MergeInterpolationMethod(str, Enum): + WeightedSum = "weighted_sum" + Sigmoid = "sigmoid" + InvSigmoid = "inv_sigmoid" + AddDifference = "add_difference" + +class ModelMerger(object): + def __init__(self, manager: ModelManager): + self.manager = manager + + def merge_diffusion_models( + self, + model_paths: List[Path], + alpha: float = 0.5, + interp: MergeInterpolationMethod = None, + force: bool = False, + **kwargs, + ) -> DiffusionPipeline: + """ + :param model_paths: up to three models, designated by their local paths or HuggingFace repo_ids + :param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha + would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2 + :param interp: The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None. + Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. + :param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False. + + **kwargs - the default DiffusionPipeline.get_config_dict kwargs: + cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map + """ + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + verbosity = dlogging.get_verbosity() + dlogging.set_verbosity_error() + + pipe = DiffusionPipeline.from_pretrained( + model_paths[0], + custom_pipeline="checkpoint_merger", + ) + merged_pipe = pipe.merge( + pretrained_model_name_or_path_list=model_paths, + alpha=alpha, + interp=interp.value if interp else None, #diffusers API treats None as "weighted sum" + force=force, + **kwargs, + ) + dlogging.set_verbosity(verbosity) + return merged_pipe + + + def merge_diffusion_models_and_save ( + self, + model_names: List[str], + base_model: Union[BaseModelType,str], + merged_model_name: str, + alpha: float = 0.5, + interp: MergeInterpolationMethod = None, + force: bool = False, + **kwargs, + ) -> AddModelResult: + """ + :param models: up to three models, designated by their InvokeAI models.yaml model name + :param base_model: base model (must be the same for all merged models!) + :param merged_model_name: name for new model + :param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha + would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2 + :param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None. + Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C). + :param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False. + + **kwargs - the default DiffusionPipeline.get_config_dict kwargs: + cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map + """ + model_paths = list() + config = self.manager.app_config + base_model = BaseModelType(base_model) + vae = None + + for mod in model_names: + info = self.manager.list_model(mod, base_model=base_model, model_type=ModelType.Main) + assert info, f"model {mod}, base_model {base_model}, is unknown" + assert info["model_format"] == "diffusers", f"{mod} is not a diffusers model. It must be optimized before merging" + assert info["variant"] == "normal", f"{mod} is a {info['variant']} model, which cannot currently be merged" + assert len(model_names) <= 2 or \ + interp==MergeInterpolationMethod.AddDifference, "When merging three models, only the 'add_difference' merge method is supported" + # pick up the first model's vae + if mod == model_names[0]: + vae = info.get("vae") + model_paths.extend([config.root_path / info["path"]]) + + merge_method = None if interp == 'weighted_sum' else MergeInterpolationMethod(interp) + logger.debug(f'interp = {interp}, merge_method={merge_method}') + merged_pipe = self.merge_diffusion_models( + model_paths, alpha, merge_method, force, **kwargs + ) + dump_path = config.models_path / base_model.value / ModelType.Main.value + dump_path.mkdir(parents=True, exist_ok=True) + dump_path = dump_path / merged_model_name + + merged_pipe.save_pretrained(dump_path, safe_serialization=1) + attributes = dict( + path = str(dump_path), + description = f"Merge of models {', '.join(model_names)}", + model_format = "diffusers", + variant = ModelVariantType.Normal.value, + vae = vae, + ) + return self.manager.add_model(merged_model_name, + base_model = base_model, + model_type = ModelType.Main, + model_attributes = attributes, + clobber = True + ) diff --git a/invokeai/backend/model_management/models/stable_diffusion.py b/invokeai/backend/model_management/models/stable_diffusion.py index a5d43c98a2..c98d5a0ae8 100644 --- a/invokeai/backend/model_management/models/stable_diffusion.py +++ b/invokeai/backend/model_management/models/stable_diffusion.py @@ -116,7 +116,7 @@ class StableDiffusion1Model(DiffusersModel): version=BaseModelType.StableDiffusion1, model_config=config, output_path=output_path, - ) + ) else: return model_path diff --git a/invokeai/frontend/merge/__init__.py b/invokeai/frontend/merge/__init__.py index fb892fd7db..f1fc66c39e 100644 --- a/invokeai/frontend/merge/__init__.py +++ b/invokeai/frontend/merge/__init__.py @@ -1,4 +1,5 @@ """ Initialization file for invokeai.frontend.merge """ -from .merge_diffusers import main as invokeai_merge_diffusers, merge_diffusion_models +from .merge_diffusers import main as invokeai_merge_diffusers + diff --git a/invokeai/frontend/merge/merge_diffusers.py b/invokeai/frontend/merge/merge_diffusers.py index 9da04b97f8..c20d913883 100644 --- a/invokeai/frontend/merge/merge_diffusers.py +++ b/invokeai/frontend/merge/merge_diffusers.py @@ -6,9 +6,7 @@ Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team """ import argparse import curses -import os import sys -import warnings from argparse import Namespace from pathlib import Path from typing import List, Union @@ -20,99 +18,15 @@ from npyscreen import widget from omegaconf import OmegaConf import invokeai.backend.util.logging as logger -from invokeai.services.config import InvokeAIAppConfig -from ...backend.model_management import ModelManager -from ...frontend.install.widgets import FloatTitleSlider +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.backend.model_management import ( + ModelMerger, MergeInterpolationMethod, + ModelManager, ModelType, BaseModelType, +) +from invokeai.frontend.install.widgets import FloatTitleSlider, TextBox, SingleSelectColumns -DEST_MERGED_MODEL_DIR = "merged_models" config = InvokeAIAppConfig.get_config() -def merge_diffusion_models( - model_ids_or_paths: List[Union[str, Path]], - alpha: float = 0.5, - interp: str = None, - force: bool = False, - **kwargs, -) -> DiffusionPipeline: - """ - model_ids_or_paths - up to three models, designated by their local paths or HuggingFace repo_ids - alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha - would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2 - interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None. - Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. - force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False. - - **kwargs - the default DiffusionPipeline.get_config_dict kwargs: - cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map - """ - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - verbosity = dlogging.get_verbosity() - dlogging.set_verbosity_error() - - pipe = DiffusionPipeline.from_pretrained( - model_ids_or_paths[0], - cache_dir=kwargs.get("cache_dir", config.cache_dir), - custom_pipeline="checkpoint_merger", - ) - merged_pipe = pipe.merge( - pretrained_model_name_or_path_list=model_ids_or_paths, - alpha=alpha, - interp=interp, - force=force, - **kwargs, - ) - dlogging.set_verbosity(verbosity) - return merged_pipe - - -def merge_diffusion_models_and_commit( - models: List["str"], - merged_model_name: str, - alpha: float = 0.5, - interp: str = None, - force: bool = False, - **kwargs, -): - """ - models - up to three models, designated by their InvokeAI models.yaml model name - merged_model_name = name for new model - alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha - would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2 - interp - The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None. - Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C). - force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False. - - **kwargs - the default DiffusionPipeline.get_config_dict kwargs: - cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map - """ - config_file = config.model_conf_path - model_manager = ModelManager(OmegaConf.load(config_file)) - for mod in models: - assert mod in model_manager.model_names(), f'** Unknown model "{mod}"' - assert ( - model_manager.model_info(mod).get("format", None) == "diffusers" - ), f"** {mod} is not a diffusers model. It must be optimized before merging." - model_ids_or_paths = [model_manager.model_name_or_path(x) for x in models] - - merged_pipe = merge_diffusion_models( - model_ids_or_paths, alpha, interp, force, **kwargs - ) - dump_path = config.models_dir / DEST_MERGED_MODEL_DIR - - os.makedirs(dump_path, exist_ok=True) - dump_path = dump_path / merged_model_name - merged_pipe.save_pretrained(dump_path, safe_serialization=1) - import_args = dict( - model_name=merged_model_name, description=f'Merge of models {", ".join(models)}' - ) - if vae := model_manager.config[models[0]].get("vae", None): - logger.info(f"Using configured VAE assigned to {models[0]}") - import_args.update(vae=vae) - model_manager.import_diffuser_model(dump_path, **import_args) - model_manager.commit(config_file) - - def _parse_args() -> Namespace: parser = argparse.ArgumentParser(description="InvokeAI model merging") parser.add_argument( @@ -131,10 +45,17 @@ def _parse_args() -> Namespace: ) parser.add_argument( "--models", + dest="model_names", type=str, nargs="+", help="Two to three model names to be merged", ) + parser.add_argument( + "--base_model", + type=str, + choices=[x.value for x in BaseModelType], + help="The base model shared by the models to be merged", + ) parser.add_argument( "--merged_model_name", "--destination", @@ -192,6 +113,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction): window_height, window_width = curses.initscr().getmaxyx() self.model_names = self.get_model_names() + self.current_base = 0 max_width = max([len(x) for x in self.model_names]) max_width += 6 horizontal_layout = max_width * 3 < window_width @@ -208,12 +130,26 @@ class mergeModelsForm(npyscreen.FormMultiPageAction): value="Use up and down arrows to move, to select an item, and to move from one field to the next.", editable=False, ) + self.nextrely += 1 + self.base_select = self.add_widget_intelligent( + SingleSelectColumns, + values=[ + 'Models Built on SD-1.x', + 'Models Built on SD-2.x', + ], + value=[self.current_base], + columns = 4, + max_height = 2, + relx=8, + scroll_exit = True, + ) + self.base_select.on_changed = self._populate_models self.add_widget_intelligent( npyscreen.FixedText, value="MODEL 1", color="GOOD", editable=False, - rely=4 if horizontal_layout else None, + rely=6 if horizontal_layout else None, ) self.model1 = self.add_widget_intelligent( npyscreen.SelectOne, @@ -222,7 +158,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction): max_height=len(self.model_names), max_width=max_width, scroll_exit=True, - rely=5, + rely=7, ) self.add_widget_intelligent( npyscreen.FixedText, @@ -230,7 +166,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction): color="GOOD", editable=False, relx=max_width + 3 if horizontal_layout else None, - rely=4 if horizontal_layout else None, + rely=6 if horizontal_layout else None, ) self.model2 = self.add_widget_intelligent( npyscreen.SelectOne, @@ -240,7 +176,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction): max_height=len(self.model_names), max_width=max_width, relx=max_width + 3 if horizontal_layout else None, - rely=5 if horizontal_layout else None, + rely=7 if horizontal_layout else None, scroll_exit=True, ) self.add_widget_intelligent( @@ -249,7 +185,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction): color="GOOD", editable=False, relx=max_width * 2 + 3 if horizontal_layout else None, - rely=4 if horizontal_layout else None, + rely=6 if horizontal_layout else None, ) models_plus_none = self.model_names.copy() models_plus_none.insert(0, "None") @@ -262,24 +198,26 @@ class mergeModelsForm(npyscreen.FormMultiPageAction): max_width=max_width, scroll_exit=True, relx=max_width * 2 + 3 if horizontal_layout else None, - rely=5 if horizontal_layout else None, + rely=7 if horizontal_layout else None, ) for m in [self.model1, self.model2, self.model3]: m.when_value_edited = self.models_changed self.merged_model_name = self.add_widget_intelligent( - npyscreen.TitleText, + TextBox, name="Name for merged model:", labelColor="CONTROL", + max_height=3, value="", scroll_exit=True, ) self.force = self.add_widget_intelligent( npyscreen.Checkbox, - name="Force merge of incompatible models", + name="Force merge of models created by different diffusers library versions", labelColor="CONTROL", - value=False, + value=True, scroll_exit=True, ) + self.nextrely += 1 self.merge_method = self.add_widget_intelligent( npyscreen.TitleSelectOne, name="Merge Method:", @@ -341,7 +279,8 @@ class mergeModelsForm(npyscreen.FormMultiPageAction): interp = self.interpolations[self.merge_method.value[0]] args = dict( - models=models, + model_names=models, + base_model=tuple(BaseModelType)[self.base_select.value[0]], alpha=self.alpha.value, interp=interp, force=self.force.value, @@ -379,21 +318,30 @@ class mergeModelsForm(npyscreen.FormMultiPageAction): else: return True - def get_model_names(self) -> List[str]: + def get_model_names(self, base_model: BaseModelType=None) -> List[str]: model_names = [ - name - for name in self.model_manager.model_names() - if self.model_manager.model_info(name).get("format") == "diffusers" + info["name"] + for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model) + if info["model_format"] == "diffusers" ] return sorted(model_names) + def _populate_models(self,value=None): + base_model = tuple(BaseModelType)[value[0]] + self.model_names = self.get_model_names(base_model) + + models_plus_none = self.model_names.copy() + models_plus_none.insert(0, "None") + self.model1.values = self.model_names + self.model2.values = self.model_names + self.model3.values = models_plus_none + + self.display() + class Mergeapp(npyscreen.NPSAppManaged): - def __init__(self): + def __init__(self, model_manager:ModelManager): super().__init__() - conf = OmegaConf.load(config.model_conf_path) - self.model_manager = ModelManager( - conf, "cpu", "float16" - ) # precision doesn't really matter here + self.model_manager = model_manager def onStart(self): npyscreen.setTheme(npyscreen.Themes.ElegantTheme) @@ -401,44 +349,41 @@ class Mergeapp(npyscreen.NPSAppManaged): def run_gui(args: Namespace): - mergeapp = Mergeapp() + model_manager = ModelManager(config.model_conf_path) + mergeapp = Mergeapp(model_manager) mergeapp.run() args = mergeapp.merge_arguments - merge_diffusion_models_and_commit(**args) + merger = ModelMerger(model_manager) + merger.merge_diffusion_models_and_save(**args) logger.info(f'Models merged into new model: "{args["merged_model_name"]}".') def run_cli(args: Namespace): assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1" assert ( - args.models and len(args.models) >= 1 and len(args.models) <= 3 + args.model_names and len(args.model_names) >= 1 and len(args.model_names) <= 3 ), "Please provide the --models argument to list 2 to 3 models to merge. Use --help for full usage." if not args.merged_model_name: - args.merged_model_name = "+".join(args.models) + args.merged_model_name = "+".join(args.model_names) logger.info( f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"' ) - model_manager = ModelManager(OmegaConf.load(config.model_conf_path)) - assert ( - args.clobber or args.merged_model_name not in model_manager.model_names() - ), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.' + model_manager = ModelManager(config.model_conf_path) + assert ( + not model_manager.model_exists(args.merged_model_name, args.base_model, ModelType.Main) or args.clobber + ), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.' - merge_diffusion_models_and_commit(**vars(args)) - logger.info(f'Models merged into new model: "{args.merged_model_name}".') + merger = ModelMerger(model_manager) + merger.merge_diffusion_models_and_save(**vars(args)) + logger.info(f'Models merged into new model: "{args.merged_model_name}".') def main(): args = _parse_args() - config.root = args.root_dir - - cache_dir = config.cache_dir - os.environ[ - "HF_HOME" - ] = cache_dir # because not clear the merge pipeline is honoring cache_dir - args.cache_dir = cache_dir + config.parse_args(['--root',str(args.root_dir)]) try: if args.front_end: diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts index cab4738373..fe4bce682b 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts @@ -1,6 +1,5 @@ import { log } from 'app/logging/useLogger'; import { appSocketConnected, socketConnected } from 'services/events/actions'; -import { receivedPageOfImages } from 'services/api/thunks/image'; import { receivedOpenAPISchema } from 'services/api/thunks/schema'; import { startAppListening } from '../..'; @@ -14,19 +13,10 @@ export const addSocketConnectedEventListener = () => { moduleLog.debug({ timestamp }, 'Connected'); - const { nodes, config, gallery } = getState(); + const { nodes, config } = getState(); const { disabledTabs } = config; - if (!gallery.ids.length) { - dispatch( - receivedPageOfImages({ - categories: ['general'], - is_intermediate: false, - }) - ); - } - if (!nodes.schema && !disabledTabs.includes('nodes')) { dispatch(receivedOpenAPISchema()); } diff --git a/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx b/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx index 97e33f300b..9a0bc865a4 100644 --- a/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx +++ b/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx @@ -1,15 +1,16 @@ import { Tooltip, useColorMode, useToken } from '@chakra-ui/react'; import { MultiSelect, MultiSelectProps } from '@mantine/core'; import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens'; -import { memo } from 'react'; +import { RefObject, memo } from 'react'; import { mode } from 'theme/util/mode'; type IAIMultiSelectProps = MultiSelectProps & { tooltip?: string; + inputRef?: RefObject; }; const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => { - const { searchable = true, tooltip, ...rest } = props; + const { searchable = true, tooltip, inputRef, ...rest } = props; const { base50, base100, @@ -33,6 +34,7 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => { return ( ({ label: { diff --git a/invokeai/frontend/web/src/common/hooks/useIsReadyToInvoke.ts b/invokeai/frontend/web/src/common/hooks/useIsReadyToInvoke.ts index c75041eb6c..605aa8b162 100644 --- a/invokeai/frontend/web/src/common/hooks/useIsReadyToInvoke.ts +++ b/invokeai/frontend/web/src/common/hooks/useIsReadyToInvoke.ts @@ -6,10 +6,15 @@ import { validateSeedWeights } from 'common/util/seedWeightPairs'; import { generationSelector } from 'features/parameters/store/generationSelectors'; import { systemSelector } from 'features/system/store/systemSelectors'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; +import { + modelsApi, + useGetMainModelsQuery, +} from '../../services/api/endpoints/models'; const readinessSelector = createSelector( [stateSelector, activeTabNameSelector], - ({ generation, system, batch }, activeTabName) => { + (state, activeTabName) => { + const { generation, system, batch } = state; const { shouldGenerateVariations, seedWeights, initialImage, seed } = generation; @@ -32,6 +37,13 @@ const readinessSelector = createSelector( reasonsWhyNotReady.push('No initial image selected'); } + const { isSuccess: mainModelsSuccessfullyLoaded } = + modelsApi.endpoints.getMainModels.select()(state); + if (!mainModelsSuccessfullyLoaded) { + isReady = false; + reasonsWhyNotReady.push('Models are not loaded'); + } + // TODO: job queue // Cannot generate if already processing an image if (isProcessing) { diff --git a/invokeai/frontend/web/src/features/embedding/components/AddEmbeddingButton.tsx b/invokeai/frontend/web/src/features/embedding/components/AddEmbeddingButton.tsx new file mode 100644 index 0000000000..1dae6f56e6 --- /dev/null +++ b/invokeai/frontend/web/src/features/embedding/components/AddEmbeddingButton.tsx @@ -0,0 +1,33 @@ +import IAIIconButton from 'common/components/IAIIconButton'; +import { memo } from 'react'; +import { BiCode } from 'react-icons/bi'; + +type Props = { + onClick: () => void; +}; + +const AddEmbeddingButton = (props: Props) => { + const { onClick } = props; + return ( + } + sx={{ + p: 2, + color: 'base.700', + _hover: { + color: 'base.550', + }, + _active: { + color: 'base.500', + }, + }} + variant="link" + onClick={onClick} + /> + ); +}; + +export default memo(AddEmbeddingButton); diff --git a/invokeai/frontend/web/src/features/embedding/components/ParamEmbeddingPopover.tsx b/invokeai/frontend/web/src/features/embedding/components/ParamEmbeddingPopover.tsx new file mode 100644 index 0000000000..3c2ded0166 --- /dev/null +++ b/invokeai/frontend/web/src/features/embedding/components/ParamEmbeddingPopover.tsx @@ -0,0 +1,151 @@ +import { + Flex, + Popover, + PopoverBody, + PopoverContent, + PopoverTrigger, + Text, +} from '@chakra-ui/react'; +import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect'; +import { forEach } from 'lodash-es'; +import { + PropsWithChildren, + forwardRef, + useCallback, + useMemo, + useRef, +} from 'react'; +import { useGetTextualInversionModelsQuery } from 'services/api/endpoints/models'; +import { PARAMETERS_PANEL_WIDTH } from 'theme/util/constants'; + +type EmbeddingSelectItem = { + label: string; + value: string; + description?: string; +}; + +type Props = PropsWithChildren & { + onSelect: (v: string) => void; + isOpen: boolean; + onClose: () => void; +}; + +const ParamEmbeddingPopover = (props: Props) => { + const { onSelect, isOpen, onClose, children } = props; + const { data: embeddingQueryData } = useGetTextualInversionModelsQuery(); + const inputRef = useRef(null); + + const data = useMemo(() => { + if (!embeddingQueryData) { + return []; + } + + const data: EmbeddingSelectItem[] = []; + + forEach(embeddingQueryData.entities, (embedding, _) => { + if (!embedding) return; + + data.push({ + value: embedding.name, + label: embedding.name, + description: embedding.description, + }); + }); + + return data; + }, [embeddingQueryData]); + + const handleChange = useCallback( + (v: string[]) => { + if (v.length === 0) { + return; + } + + onSelect(v[0]); + }, + [onSelect] + ); + + return ( + + {children} + + + {data.length === 0 ? ( + + + No Embeddings Loaded + + + ) : ( + + item.label.toLowerCase().includes(value.toLowerCase().trim()) || + item.value.toLowerCase().includes(value.toLowerCase().trim()) + } + onChange={handleChange} + /> + )} + + + + ); +}; + +export default ParamEmbeddingPopover; + +interface ItemProps extends React.ComponentPropsWithoutRef<'div'> { + value: string; + label: string; + description?: string; +} + +const SelectItem = forwardRef( + ({ label, description, ...others }: ItemProps, ref) => { + return ( +
+
+ {label} + {description && ( + + {description} + + )} +
+
+ ); + } +); + +SelectItem.displayName = 'SelectItem'; diff --git a/invokeai/frontend/web/src/features/embedding/store/embeddingSlice.ts b/invokeai/frontend/web/src/features/embedding/store/embeddingSlice.ts new file mode 100644 index 0000000000..e69de29bb2 diff --git a/invokeai/frontend/web/src/features/gallery/components/GalleryImage.tsx b/invokeai/frontend/web/src/features/gallery/components/GalleryImage.tsx index ea0b3b0fd8..a8d4c84adc 100644 --- a/invokeai/frontend/web/src/features/gallery/components/GalleryImage.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/GalleryImage.tsx @@ -23,6 +23,7 @@ export const makeSelector = (image_name: string) => ({ gallery }) => { const isSelected = gallery.selection.includes(image_name); const selectionCount = gallery.selection.length; + return { isSelected, selectionCount, @@ -117,7 +118,7 @@ const GalleryImage = (props: HoverableImageProps) => { resetIcon={} resetTooltip="Delete image" imageSx={{ w: 'full', h: 'full' }} - withResetIcon + // withResetIcon // removed bc it's too easy to accidentally delete images isDropDisabled={true} isUploadDisabled={true} /> diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx index 33edb303e3..a5fc653913 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx @@ -182,6 +182,15 @@ const ImageGalleryContent = () => { return () => osInstance()?.destroy(); }, [scroller, initialize, osInstance]); + useEffect(() => { + dispatch( + receivedPageOfImages({ + categories: ['general'], + is_intermediate: false, + }) + ); + }, [dispatch]); + const handleClickImagesCategory = useCallback(() => { dispatch(imageCategoriesChanged(IMAGE_CATEGORIES)); dispatch(setGalleryView('images')); diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx index 23459e9410..4ca9700a8c 100644 --- a/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx +++ b/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx @@ -4,7 +4,12 @@ import IAIIconButton from 'common/components/IAIIconButton'; import IAISlider from 'common/components/IAISlider'; import { memo, useCallback } from 'react'; import { FaTrash } from 'react-icons/fa'; -import { Lora, loraRemoved, loraWeightChanged } from '../store/loraSlice'; +import { + Lora, + loraRemoved, + loraWeightChanged, + loraWeightReset, +} from '../store/loraSlice'; type Props = { lora: Lora; @@ -22,7 +27,7 @@ const ParamLora = (props: Props) => { ); const handleReset = useCallback(() => { - dispatch(loraWeightChanged({ id: lora.id, weight: 1 })); + dispatch(loraWeightReset(lora.id)); }, [dispatch, lora.id]); const handleRemoveLora = useCallback(() => { diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx index 54ac3d615d..9168814f35 100644 --- a/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx +++ b/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx @@ -1,4 +1,4 @@ -import { Text } from '@chakra-ui/react'; +import { Flex, Text } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; @@ -61,6 +61,16 @@ const ParamLoraSelect = () => { [dispatch, lorasQueryData?.entities] ); + if (lorasQueryData?.ids.length === 0) { + return ( + + + No LoRAs Loaded + + + ); + } + return ( = { - weight: 1, + weight: 0.75, }; export type LoraState = { @@ -38,9 +38,14 @@ export const loraSlice = createSlice({ const { id, weight } = action.payload; state.loras[id].weight = weight; }, + loraWeightReset: (state, action: PayloadAction) => { + const id = action.payload; + state.loras[id].weight = defaultLoRAConfig.weight; + }, }, }); -export const { loraAdded, loraRemoved, loraWeightChanged } = loraSlice.actions; +export const { loraAdded, loraRemoved, loraWeightChanged, loraWeightReset } = + loraSlice.actions; export default loraSlice.reducer; diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamNegativeConditioning.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamNegativeConditioning.tsx index 589b751d6b..3e5320ad47 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamNegativeConditioning.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamNegativeConditioning.tsx @@ -1,29 +1,107 @@ -import { FormControl } from '@chakra-ui/react'; +import { Box, FormControl, useDisclosure } from '@chakra-ui/react'; import type { RootState } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAITextarea from 'common/components/IAITextarea'; +import AddEmbeddingButton from 'features/embedding/components/AddEmbeddingButton'; +import ParamEmbeddingPopover from 'features/embedding/components/ParamEmbeddingPopover'; import { setNegativePrompt } from 'features/parameters/store/generationSlice'; +import { ChangeEvent, KeyboardEvent, useCallback, useRef } from 'react'; +import { flushSync } from 'react-dom'; import { useTranslation } from 'react-i18next'; const ParamNegativeConditioning = () => { const negativePrompt = useAppSelector( (state: RootState) => state.generation.negativePrompt ); - + const promptRef = useRef(null); + const { isOpen, onClose, onOpen } = useDisclosure(); const dispatch = useAppDispatch(); const { t } = useTranslation(); + const handleChangePrompt = useCallback( + (e: ChangeEvent) => { + dispatch(setNegativePrompt(e.target.value)); + }, + [dispatch] + ); + const handleKeyDown = useCallback( + (e: KeyboardEvent) => { + if (e.key === '<') { + onOpen(); + } + }, + [onOpen] + ); + + const handleSelectEmbedding = useCallback( + (v: string) => { + if (!promptRef.current) { + return; + } + + // this is where we insert the TI trigger + const caret = promptRef.current.selectionStart; + + if (caret === undefined) { + return; + } + + let newPrompt = negativePrompt.slice(0, caret); + + if (newPrompt[newPrompt.length - 1] !== '<') { + newPrompt += '<'; + } + + newPrompt += `${v}>`; + + // we insert the cursor after the `>` + const finalCaretPos = newPrompt.length; + + newPrompt += negativePrompt.slice(caret); + + // must flush dom updates else selection gets reset + flushSync(() => { + dispatch(setNegativePrompt(newPrompt)); + }); + + // set the caret position to just after the TI trigger promptRef.current.selectionStart = finalCaretPos; + promptRef.current.selectionEnd = finalCaretPos; + onClose(); + }, + [dispatch, onClose, negativePrompt] + ); + return ( - dispatch(setNegativePrompt(e.target.value))} - placeholder={t('parameters.negativePromptPlaceholder')} - fontSize="sm" - minH={16} - /> + + + + {!isOpen && ( + + + + )} ); }; diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamPositiveConditioning.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamPositiveConditioning.tsx index f42942a84b..cbff29e89c 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamPositiveConditioning.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamPositiveConditioning.tsx @@ -1,4 +1,4 @@ -import { Box, FormControl } from '@chakra-ui/react'; +import { Box, FormControl, useDisclosure } from '@chakra-ui/react'; import { RootState } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { ChangeEvent, KeyboardEvent, useCallback, useRef } from 'react'; @@ -11,12 +11,15 @@ import { } from 'features/parameters/store/generationSlice'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; -import { isEqual } from 'lodash-es'; -import { useHotkeys } from 'react-hotkeys-hook'; -import { useTranslation } from 'react-i18next'; import { userInvoked } from 'app/store/actions'; import IAITextarea from 'common/components/IAITextarea'; import { useIsReadyToInvoke } from 'common/hooks/useIsReadyToInvoke'; +import AddEmbeddingButton from 'features/embedding/components/AddEmbeddingButton'; +import ParamEmbeddingPopover from 'features/embedding/components/ParamEmbeddingPopover'; +import { isEqual } from 'lodash-es'; +import { flushSync } from 'react-dom'; +import { useHotkeys } from 'react-hotkeys-hook'; +import { useTranslation } from 'react-i18next'; const promptInputSelector = createSelector( [(state: RootState) => state.generation, activeTabNameSelector], @@ -40,14 +43,15 @@ const ParamPositiveConditioning = () => { const dispatch = useAppDispatch(); const { prompt, activeTabName } = useAppSelector(promptInputSelector); const isReady = useIsReadyToInvoke(); - const promptRef = useRef(null); - + const { isOpen, onClose, onOpen } = useDisclosure(); const { t } = useTranslation(); - - const handleChangePrompt = (e: ChangeEvent) => { - dispatch(setPositivePrompt(e.target.value)); - }; + const handleChangePrompt = useCallback( + (e: ChangeEvent) => { + dispatch(setPositivePrompt(e.target.value)); + }, + [dispatch] + ); useHotkeys( 'alt+a', @@ -57,6 +61,45 @@ const ParamPositiveConditioning = () => { [] ); + const handleSelectEmbedding = useCallback( + (v: string) => { + if (!promptRef.current) { + return; + } + + // this is where we insert the TI trigger + const caret = promptRef.current.selectionStart; + + if (caret === undefined) { + return; + } + + let newPrompt = prompt.slice(0, caret); + + if (newPrompt[newPrompt.length - 1] !== '<') { + newPrompt += '<'; + } + + newPrompt += `${v}>`; + + // we insert the cursor after the `>` + const finalCaretPos = newPrompt.length; + + newPrompt += prompt.slice(caret); + + // must flush dom updates else selection gets reset + flushSync(() => { + dispatch(setPositivePrompt(newPrompt)); + }); + + // set the caret position to just after the TI trigger + promptRef.current.selectionStart = finalCaretPos; + promptRef.current.selectionEnd = finalCaretPos; + onClose(); + }, + [dispatch, onClose, prompt] + ); + const handleKeyDown = useCallback( (e: KeyboardEvent) => { if (e.key === 'Enter' && e.shiftKey === false && isReady) { @@ -64,25 +107,50 @@ const ParamPositiveConditioning = () => { dispatch(clampSymmetrySteps()); dispatch(userInvoked(activeTabName)); } + if (e.key === '<') { + onOpen(); + } }, - [dispatch, activeTabName, isReady] + [isReady, dispatch, activeTabName, onOpen] ); + // const handleSelect = (e: MouseEvent) => { + // const target = e.target as HTMLTextAreaElement; + // setCaret({ start: target.selectionStart, end: target.selectionEnd }); + // }; + return ( - + + + + {!isOpen && ( + + + + )} ); }; diff --git a/invokeai/frontend/web/src/features/ui/components/PinParametersPanelButton.tsx b/invokeai/frontend/web/src/features/ui/components/PinParametersPanelButton.tsx index a742e2a587..30cc1d2158 100644 --- a/invokeai/frontend/web/src/features/ui/components/PinParametersPanelButton.tsx +++ b/invokeai/frontend/web/src/features/ui/components/PinParametersPanelButton.tsx @@ -1,4 +1,3 @@ -import { Tooltip } from '@chakra-ui/react'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIIconButton, { IAIIconButtonProps, @@ -25,26 +24,25 @@ const PinParametersPanelButton = (props: PinParametersPanelButtonProps) => { }; return ( - - : } - variant="ghost" - size="sm" - sx={{ - color: 'base.700', - _hover: { - color: 'base.550', - }, - _active: { - color: 'base.500', - }, - ...sx, - }} - /> - + : } + variant="ghost" + size="sm" + sx={{ + color: 'base.700', + _hover: { + color: 'base.550', + }, + _active: { + color: 'base.500', + }, + ...sx, + }} + /> ); }; diff --git a/invokeai/frontend/web/src/features/ui/store/uiSlice.ts b/invokeai/frontend/web/src/features/ui/store/uiSlice.ts index 38af668cac..861bf49405 100644 --- a/invokeai/frontend/web/src/features/ui/store/uiSlice.ts +++ b/invokeai/frontend/web/src/features/ui/store/uiSlice.ts @@ -1,10 +1,10 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; import { initialImageChanged } from 'features/parameters/store/generationSlice'; +import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas'; import { setActiveTabReducer } from './extraReducers'; import { InvokeTabName } from './tabMap'; import { AddNewModelType, UIState } from './uiTypes'; -import { SchedulerParam } from 'features/parameters/store/parameterZodSchemas'; export const initialUIState: UIState = { activeTab: 0, @@ -19,6 +19,7 @@ export const initialUIState: UIState = { shouldShowGallery: true, shouldHidePreview: false, shouldShowProgressInViewer: true, + shouldShowEmbeddingPicker: false, favoriteSchedulers: [], }; @@ -96,6 +97,9 @@ export const uiSlice = createSlice({ ) => { state.favoriteSchedulers = action.payload; }, + toggleEmbeddingPicker: (state) => { + state.shouldShowEmbeddingPicker = !state.shouldShowEmbeddingPicker; + }, }, extraReducers(builder) { builder.addCase(initialImageChanged, (state) => { @@ -122,6 +126,7 @@ export const { toggleGalleryPanel, setShouldShowProgressInViewer, favoriteSchedulersChanged, + toggleEmbeddingPicker, } = uiSlice.actions; export default uiSlice.reducer; diff --git a/invokeai/frontend/web/src/features/ui/store/uiTypes.ts b/invokeai/frontend/web/src/features/ui/store/uiTypes.ts index d55a1d8fcf..ad0250e56d 100644 --- a/invokeai/frontend/web/src/features/ui/store/uiTypes.ts +++ b/invokeai/frontend/web/src/features/ui/store/uiTypes.ts @@ -27,5 +27,6 @@ export interface UIState { shouldPinGallery: boolean; shouldShowGallery: boolean; shouldShowProgressInViewer: boolean; + shouldShowEmbeddingPicker: boolean; favoriteSchedulers: SchedulerParam[]; } diff --git a/invokeai/frontend/web/src/services/events/middleware.ts b/invokeai/frontend/web/src/services/events/middleware.ts index 85641b88a0..665761a626 100644 --- a/invokeai/frontend/web/src/services/events/middleware.ts +++ b/invokeai/frontend/web/src/services/events/middleware.ts @@ -1,18 +1,18 @@ import { Middleware, MiddlewareAPI } from '@reduxjs/toolkit'; -import { io, Socket } from 'socket.io-client'; +import { Socket, io } from 'socket.io-client'; +import { AppThunkDispatch, RootState } from 'app/store/store'; +import { getTimestamp } from 'common/util/getTimestamp'; +import { sessionCreated } from 'services/api/thunks/session'; import { ClientToServerEvents, ServerToClientEvents, } from 'services/events/types'; import { socketSubscribed, socketUnsubscribed } from './actions'; -import { AppThunkDispatch, RootState } from 'app/store/store'; -import { getTimestamp } from 'common/util/getTimestamp'; -import { sessionCreated } from 'services/api/thunks/session'; // import { OpenAPI } from 'services/api/types'; -import { setEventListeners } from 'services/events/util/setEventListeners'; import { log } from 'app/logging/useLogger'; import { $authToken, $baseUrl } from 'services/api/client'; +import { setEventListeners } from 'services/events/util/setEventListeners'; const socketioLog = log.child({ namespace: 'socketio' }); @@ -88,7 +88,7 @@ export const socketMiddleware = () => { socketSubscribed({ sessionId: sessionId, timestamp: getTimestamp(), - boardId: getState().boards.selectedBoardId, + boardId: getState().gallery.selectedBoardId, }) ); }