From 22d2c2b3e397db563c64e7f41b6c208064606c9f Mon Sep 17 00:00:00 2001 From: Kyle Schouviller Date: Tue, 14 Mar 2023 22:15:53 -0700 Subject: [PATCH 1/7] [api] Add models router and list model API. --- invokeai/app/api/routers/models.py | 279 +++++++++++++++++++++++++++++ invokeai/app/api_app.py | 4 +- 2 files changed, 282 insertions(+), 1 deletion(-) create mode 100644 invokeai/app/api/routers/models.py diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py new file mode 100644 index 0000000000..5b3fbebddd --- /dev/null +++ b/invokeai/app/api/routers/models.py @@ -0,0 +1,279 @@ +# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) + +from typing import Annotated, Any, List, Literal, Optional, Union + +from fastapi.routing import APIRouter +from pydantic import BaseModel, Field, parse_obj_as + +from ..dependencies import ApiDependencies + +models_router = APIRouter(prefix="/v1/models", tags=["models"]) + + +class VaeRepo(BaseModel): + repo_id: str = Field(description="The repo ID to use for this VAE") + path: Optional[str] = Field(description="The path to the VAE") + subfolder: Optional[str] = Field(description="The subfolder to use for this VAE") + + +class ModelInfo(BaseModel): + description: Optional[str] = Field(description="A description of the model") + + +class CkptModelInfo(ModelInfo): + format: Literal['ckpt'] = 'ckpt' + + config: str = Field(description="The path to the model config") + weights: str = Field(description="The path to the model weights") + vae: str = Field(description="The path to the model VAE") + width: Optional[int] = Field(description="The width of the model") + height: Optional[int] = Field(description="The height of the model") + + +class DiffusersModelInfo(ModelInfo): + format: Literal['diffusers'] = 'diffusers' + + vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model") + repo_id: Optional[str] = Field(description="The repo ID to use for this model") + path: Optional[str] = Field(description="The path to the model") + + +class ModelsList(BaseModel): + models: dict[str, Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")]] + + + +@models_router.get( + "/", + operation_id="list_models", + responses={200: {"model": ModelsList }}, +) +async def list_models() -> ModelsList: + """Gets a list of models""" + models_raw = ApiDependencies.invoker.services.model_manager.list_models() + models = parse_obj_as(ModelsList, { "models": models_raw }) + return models + + # @socketio.on("requestSystemConfig") + # def handle_request_capabilities(): + # print(">> System config requested") + # config = self.get_system_config() + # config["model_list"] = self.generate.model_manager.list_models() + # config["infill_methods"] = infill_methods() + # socketio.emit("systemConfig", config) + + # @socketio.on("searchForModels") + # def handle_search_models(search_folder: str): + # try: + # if not search_folder: + # socketio.emit( + # "foundModels", + # {"search_folder": None, "found_models": None}, + # ) + # else: + # ( + # search_folder, + # found_models, + # ) = self.generate.model_manager.search_models(search_folder) + # socketio.emit( + # "foundModels", + # {"search_folder": search_folder, "found_models": found_models}, + # ) + # except Exception as e: + # self.handle_exceptions(e) + # print("\n") + + # @socketio.on("addNewModel") + # def handle_add_model(new_model_config: dict): + # try: + # model_name = new_model_config["name"] + # del new_model_config["name"] + # model_attributes = new_model_config + # if len(model_attributes["vae"]) == 0: + # del model_attributes["vae"] + # update = False + # current_model_list = self.generate.model_manager.list_models() + # if model_name in current_model_list: + # update = True + + # print(f">> Adding New Model: {model_name}") + + # self.generate.model_manager.add_model( + # model_name=model_name, + # model_attributes=model_attributes, + # clobber=True, + # ) + # self.generate.model_manager.commit(opt.conf) + + # new_model_list = self.generate.model_manager.list_models() + # socketio.emit( + # "newModelAdded", + # { + # "new_model_name": model_name, + # "model_list": new_model_list, + # "update": update, + # }, + # ) + # print(f">> New Model Added: {model_name}") + # except Exception as e: + # self.handle_exceptions(e) + + # @socketio.on("deleteModel") + # def handle_delete_model(model_name: str): + # try: + # print(f">> Deleting Model: {model_name}") + # self.generate.model_manager.del_model(model_name) + # self.generate.model_manager.commit(opt.conf) + # updated_model_list = self.generate.model_manager.list_models() + # socketio.emit( + # "modelDeleted", + # { + # "deleted_model_name": model_name, + # "model_list": updated_model_list, + # }, + # ) + # print(f">> Model Deleted: {model_name}") + # except Exception as e: + # self.handle_exceptions(e) + + # @socketio.on("requestModelChange") + # def handle_set_model(model_name: str): + # try: + # print(f">> Model change requested: {model_name}") + # model = self.generate.set_model(model_name) + # model_list = self.generate.model_manager.list_models() + # if model is None: + # socketio.emit( + # "modelChangeFailed", + # {"model_name": model_name, "model_list": model_list}, + # ) + # else: + # socketio.emit( + # "modelChanged", + # {"model_name": model_name, "model_list": model_list}, + # ) + # except Exception as e: + # self.handle_exceptions(e) + + # @socketio.on("convertToDiffusers") + # def convert_to_diffusers(model_to_convert: dict): + # try: + # if model_info := self.generate.model_manager.model_info( + # model_name=model_to_convert["model_name"] + # ): + # if "weights" in model_info: + # ckpt_path = Path(model_info["weights"]) + # original_config_file = Path(model_info["config"]) + # model_name = model_to_convert["model_name"] + # model_description = model_info["description"] + # else: + # self.socketio.emit( + # "error", {"message": "Model is not a valid checkpoint file"} + # ) + # else: + # self.socketio.emit( + # "error", {"message": "Could not retrieve model info."} + # ) + + # if not ckpt_path.is_absolute(): + # ckpt_path = Path(Globals.root, ckpt_path) + + # if original_config_file and not original_config_file.is_absolute(): + # original_config_file = Path(Globals.root, original_config_file) + + # diffusers_path = Path( + # ckpt_path.parent.absolute(), f"{model_name}_diffusers" + # ) + + # if model_to_convert["save_location"] == "root": + # diffusers_path = Path( + # global_converted_ckpts_dir(), f"{model_name}_diffusers" + # ) + + # if ( + # model_to_convert["save_location"] == "custom" + # and model_to_convert["custom_location"] is not None + # ): + # diffusers_path = Path( + # model_to_convert["custom_location"], f"{model_name}_diffusers" + # ) + + # if diffusers_path.exists(): + # shutil.rmtree(diffusers_path) + + # self.generate.model_manager.convert_and_import( + # ckpt_path, + # diffusers_path, + # model_name=model_name, + # model_description=model_description, + # vae=None, + # original_config_file=original_config_file, + # commit_to_conf=opt.conf, + # ) + + # new_model_list = self.generate.model_manager.list_models() + # socketio.emit( + # "modelConverted", + # { + # "new_model_name": model_name, + # "model_list": new_model_list, + # "update": True, + # }, + # ) + # print(f">> Model Converted: {model_name}") + # except Exception as e: + # self.handle_exceptions(e) + + # @socketio.on("mergeDiffusersModels") + # def merge_diffusers_models(model_merge_info: dict): + # try: + # models_to_merge = model_merge_info["models_to_merge"] + # model_ids_or_paths = [ + # self.generate.model_manager.model_name_or_path(x) + # for x in models_to_merge + # ] + # merged_pipe = merge_diffusion_models( + # model_ids_or_paths, + # model_merge_info["alpha"], + # model_merge_info["interp"], + # model_merge_info["force"], + # ) + + # dump_path = global_models_dir() / "merged_models" + # if model_merge_info["model_merge_save_path"] is not None: + # dump_path = Path(model_merge_info["model_merge_save_path"]) + + # os.makedirs(dump_path, exist_ok=True) + # dump_path = dump_path / model_merge_info["merged_model_name"] + # merged_pipe.save_pretrained(dump_path, safe_serialization=1) + + # merged_model_config = dict( + # model_name=model_merge_info["merged_model_name"], + # description=f'Merge of models {", ".join(models_to_merge)}', + # commit_to_conf=opt.conf, + # ) + + # if vae := self.generate.model_manager.config[models_to_merge[0]].get( + # "vae", None + # ): + # print(f">> Using configured VAE assigned to {models_to_merge[0]}") + # merged_model_config.update(vae=vae) + + # self.generate.model_manager.import_diffuser_model( + # dump_path, **merged_model_config + # ) + # new_model_list = self.generate.model_manager.list_models() + + # socketio.emit( + # "modelsMerged", + # { + # "merged_models": models_to_merge, + # "merged_model_name": model_merge_info["merged_model_name"], + # "model_list": new_model_list, + # "update": True, + # }, + # ) + # print(f">> Models Merged: {models_to_merge}") + # print(f">> New Model Added: {model_merge_info['merged_model_name']}") + # except Exception as e: + # self.handle_exceptions(e) \ No newline at end of file diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 0ce2386557..ab05cb3344 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -14,7 +14,7 @@ from pydantic.schema import schema from ..backend import Args from .api.dependencies import ApiDependencies -from .api.routers import images, sessions +from .api.routers import images, sessions, models from .api.sockets import SocketIO from .invocations import * from .invocations.baseinvocation import BaseInvocation @@ -76,6 +76,8 @@ app.include_router(sessions.session_router, prefix="/api") app.include_router(images.images_router, prefix="/api") +app.include_router(models.models_router, prefix="/api") + # Build a custom OpenAPI to include all outputs # TODO: can outputs be included on metadata of invocation schemas somehow? From 5f92f290fc65f942bf0134de92eb1faf83574d95 Mon Sep 17 00:00:00 2001 From: Kyle Schouviller Date: Tue, 14 Mar 2023 22:15:53 -0700 Subject: [PATCH 2/7] [api] Add models router and list model API. --- invokeai/app/api/routers/models.py | 279 +++++++++++++++++++++++++++++ invokeai/app/api_app.py | 4 +- 2 files changed, 282 insertions(+), 1 deletion(-) create mode 100644 invokeai/app/api/routers/models.py diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py new file mode 100644 index 0000000000..5b3fbebddd --- /dev/null +++ b/invokeai/app/api/routers/models.py @@ -0,0 +1,279 @@ +# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) + +from typing import Annotated, Any, List, Literal, Optional, Union + +from fastapi.routing import APIRouter +from pydantic import BaseModel, Field, parse_obj_as + +from ..dependencies import ApiDependencies + +models_router = APIRouter(prefix="/v1/models", tags=["models"]) + + +class VaeRepo(BaseModel): + repo_id: str = Field(description="The repo ID to use for this VAE") + path: Optional[str] = Field(description="The path to the VAE") + subfolder: Optional[str] = Field(description="The subfolder to use for this VAE") + + +class ModelInfo(BaseModel): + description: Optional[str] = Field(description="A description of the model") + + +class CkptModelInfo(ModelInfo): + format: Literal['ckpt'] = 'ckpt' + + config: str = Field(description="The path to the model config") + weights: str = Field(description="The path to the model weights") + vae: str = Field(description="The path to the model VAE") + width: Optional[int] = Field(description="The width of the model") + height: Optional[int] = Field(description="The height of the model") + + +class DiffusersModelInfo(ModelInfo): + format: Literal['diffusers'] = 'diffusers' + + vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model") + repo_id: Optional[str] = Field(description="The repo ID to use for this model") + path: Optional[str] = Field(description="The path to the model") + + +class ModelsList(BaseModel): + models: dict[str, Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")]] + + + +@models_router.get( + "/", + operation_id="list_models", + responses={200: {"model": ModelsList }}, +) +async def list_models() -> ModelsList: + """Gets a list of models""" + models_raw = ApiDependencies.invoker.services.model_manager.list_models() + models = parse_obj_as(ModelsList, { "models": models_raw }) + return models + + # @socketio.on("requestSystemConfig") + # def handle_request_capabilities(): + # print(">> System config requested") + # config = self.get_system_config() + # config["model_list"] = self.generate.model_manager.list_models() + # config["infill_methods"] = infill_methods() + # socketio.emit("systemConfig", config) + + # @socketio.on("searchForModels") + # def handle_search_models(search_folder: str): + # try: + # if not search_folder: + # socketio.emit( + # "foundModels", + # {"search_folder": None, "found_models": None}, + # ) + # else: + # ( + # search_folder, + # found_models, + # ) = self.generate.model_manager.search_models(search_folder) + # socketio.emit( + # "foundModels", + # {"search_folder": search_folder, "found_models": found_models}, + # ) + # except Exception as e: + # self.handle_exceptions(e) + # print("\n") + + # @socketio.on("addNewModel") + # def handle_add_model(new_model_config: dict): + # try: + # model_name = new_model_config["name"] + # del new_model_config["name"] + # model_attributes = new_model_config + # if len(model_attributes["vae"]) == 0: + # del model_attributes["vae"] + # update = False + # current_model_list = self.generate.model_manager.list_models() + # if model_name in current_model_list: + # update = True + + # print(f">> Adding New Model: {model_name}") + + # self.generate.model_manager.add_model( + # model_name=model_name, + # model_attributes=model_attributes, + # clobber=True, + # ) + # self.generate.model_manager.commit(opt.conf) + + # new_model_list = self.generate.model_manager.list_models() + # socketio.emit( + # "newModelAdded", + # { + # "new_model_name": model_name, + # "model_list": new_model_list, + # "update": update, + # }, + # ) + # print(f">> New Model Added: {model_name}") + # except Exception as e: + # self.handle_exceptions(e) + + # @socketio.on("deleteModel") + # def handle_delete_model(model_name: str): + # try: + # print(f">> Deleting Model: {model_name}") + # self.generate.model_manager.del_model(model_name) + # self.generate.model_manager.commit(opt.conf) + # updated_model_list = self.generate.model_manager.list_models() + # socketio.emit( + # "modelDeleted", + # { + # "deleted_model_name": model_name, + # "model_list": updated_model_list, + # }, + # ) + # print(f">> Model Deleted: {model_name}") + # except Exception as e: + # self.handle_exceptions(e) + + # @socketio.on("requestModelChange") + # def handle_set_model(model_name: str): + # try: + # print(f">> Model change requested: {model_name}") + # model = self.generate.set_model(model_name) + # model_list = self.generate.model_manager.list_models() + # if model is None: + # socketio.emit( + # "modelChangeFailed", + # {"model_name": model_name, "model_list": model_list}, + # ) + # else: + # socketio.emit( + # "modelChanged", + # {"model_name": model_name, "model_list": model_list}, + # ) + # except Exception as e: + # self.handle_exceptions(e) + + # @socketio.on("convertToDiffusers") + # def convert_to_diffusers(model_to_convert: dict): + # try: + # if model_info := self.generate.model_manager.model_info( + # model_name=model_to_convert["model_name"] + # ): + # if "weights" in model_info: + # ckpt_path = Path(model_info["weights"]) + # original_config_file = Path(model_info["config"]) + # model_name = model_to_convert["model_name"] + # model_description = model_info["description"] + # else: + # self.socketio.emit( + # "error", {"message": "Model is not a valid checkpoint file"} + # ) + # else: + # self.socketio.emit( + # "error", {"message": "Could not retrieve model info."} + # ) + + # if not ckpt_path.is_absolute(): + # ckpt_path = Path(Globals.root, ckpt_path) + + # if original_config_file and not original_config_file.is_absolute(): + # original_config_file = Path(Globals.root, original_config_file) + + # diffusers_path = Path( + # ckpt_path.parent.absolute(), f"{model_name}_diffusers" + # ) + + # if model_to_convert["save_location"] == "root": + # diffusers_path = Path( + # global_converted_ckpts_dir(), f"{model_name}_diffusers" + # ) + + # if ( + # model_to_convert["save_location"] == "custom" + # and model_to_convert["custom_location"] is not None + # ): + # diffusers_path = Path( + # model_to_convert["custom_location"], f"{model_name}_diffusers" + # ) + + # if diffusers_path.exists(): + # shutil.rmtree(diffusers_path) + + # self.generate.model_manager.convert_and_import( + # ckpt_path, + # diffusers_path, + # model_name=model_name, + # model_description=model_description, + # vae=None, + # original_config_file=original_config_file, + # commit_to_conf=opt.conf, + # ) + + # new_model_list = self.generate.model_manager.list_models() + # socketio.emit( + # "modelConverted", + # { + # "new_model_name": model_name, + # "model_list": new_model_list, + # "update": True, + # }, + # ) + # print(f">> Model Converted: {model_name}") + # except Exception as e: + # self.handle_exceptions(e) + + # @socketio.on("mergeDiffusersModels") + # def merge_diffusers_models(model_merge_info: dict): + # try: + # models_to_merge = model_merge_info["models_to_merge"] + # model_ids_or_paths = [ + # self.generate.model_manager.model_name_or_path(x) + # for x in models_to_merge + # ] + # merged_pipe = merge_diffusion_models( + # model_ids_or_paths, + # model_merge_info["alpha"], + # model_merge_info["interp"], + # model_merge_info["force"], + # ) + + # dump_path = global_models_dir() / "merged_models" + # if model_merge_info["model_merge_save_path"] is not None: + # dump_path = Path(model_merge_info["model_merge_save_path"]) + + # os.makedirs(dump_path, exist_ok=True) + # dump_path = dump_path / model_merge_info["merged_model_name"] + # merged_pipe.save_pretrained(dump_path, safe_serialization=1) + + # merged_model_config = dict( + # model_name=model_merge_info["merged_model_name"], + # description=f'Merge of models {", ".join(models_to_merge)}', + # commit_to_conf=opt.conf, + # ) + + # if vae := self.generate.model_manager.config[models_to_merge[0]].get( + # "vae", None + # ): + # print(f">> Using configured VAE assigned to {models_to_merge[0]}") + # merged_model_config.update(vae=vae) + + # self.generate.model_manager.import_diffuser_model( + # dump_path, **merged_model_config + # ) + # new_model_list = self.generate.model_manager.list_models() + + # socketio.emit( + # "modelsMerged", + # { + # "merged_models": models_to_merge, + # "merged_model_name": model_merge_info["merged_model_name"], + # "model_list": new_model_list, + # "update": True, + # }, + # ) + # print(f">> Models Merged: {models_to_merge}") + # print(f">> New Model Added: {model_merge_info['merged_model_name']}") + # except Exception as e: + # self.handle_exceptions(e) \ No newline at end of file diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 0ce2386557..ab05cb3344 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -14,7 +14,7 @@ from pydantic.schema import schema from ..backend import Args from .api.dependencies import ApiDependencies -from .api.routers import images, sessions +from .api.routers import images, sessions, models from .api.sockets import SocketIO from .invocations import * from .invocations.baseinvocation import BaseInvocation @@ -76,6 +76,8 @@ app.include_router(sessions.session_router, prefix="/api") app.include_router(images.images_router, prefix="/api") +app.include_router(models.models_router, prefix="/api") + # Build a custom OpenAPI to include all outputs # TODO: can outputs be included on metadata of invocation schemas somehow? From b52b9985bd4a1292c960ff96c5601b8a0cb12715 Mon Sep 17 00:00:00 2001 From: Kent Keirsey <31807370+hipsterusername@users.noreply.github.com> Date: Sat, 25 Mar 2023 21:51:38 -0400 Subject: [PATCH 3/7] Adds (Untested) Add Delete Load Endpoints --- invokeai/app/api/routers/models.py | 87 +++++++++++++++++++++++++++++- 1 file changed, 86 insertions(+), 1 deletion(-) diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 5b3fbebddd..c988d89596 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -2,7 +2,7 @@ from typing import Annotated, Any, List, Literal, Optional, Union -from fastapi.routing import APIRouter +from fastapi.routing import APIRouter, HTTPException from pydantic import BaseModel, Field, parse_obj_as from ..dependencies import ApiDependencies @@ -17,6 +17,7 @@ class VaeRepo(BaseModel): class ModelInfo(BaseModel): + model_name: str = Field(..., description="The name of the model") description: Optional[str] = Field(description="A description of the model") @@ -54,6 +55,90 @@ async def list_models() -> ModelsList: models = parse_obj_as(ModelsList, { "models": models_raw }) return models +# Kent wrote the below code. It is highly suspect, and should be reviewed. Once all issues have been identified and fixed, this comment can be removed. +# Seriously, check this. I have no idea how to test it. + +""" Add Model """ +@models_router.post( + "/", + operation_id="add_model", + responses={201: {"model": Union[CkptModelInfo, DiffusersModelInfo], "new_model_list": ModelsList}}, +) +async def add_model( + model_info: Union[CkptModelInfo, DiffusersModelInfo], +) -> Union[CkptModelInfo, DiffusersModelInfo]: + """Adds a new model""" + try: + model_name = model_info["model_name"] + del model_info["model_name"] + model_attributes = model_info + + if len(model_attributes.get("vae", [])) == 0: + del model_attributes["vae"] + + ApiDependencies.invoker.services.model_manager.add_model( + model_name=model_name, + model_attributes=model_attributes, + clobber=True, + ) + # How does Ckpt support deprecation change the above? + + except Exception as e: + # Handle any exceptions thrown during the execution of the method + # or raise the exception to be handled by the global exception handler + raise HTTPException(status_code=500, detail=str(e)) + + return model_info + +""" Delete Model """ +@models_router.delete( + "/{model_name}", + operation_id="del_model", + responses={204: {"description": "Model deleted"}, 404: {"description": "Model not found"}}, +) +async def delete_model(model_name: str) -> None: + try: + # check if model exists + if model_name not in ApiDependencies.invoker.services.model_manager.models: + raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") + + # delete model + print(f">> Deleting Model: {model_name}") + ApiDependencies.invoker.services.model_manager.del_model(model_name, delete_files=True) + print(f">> Model Deleted: {model_name}") + except Exception as e: + # Handle any exceptions thrown during the execution of the method + raise HTTPException(status_code=500, detail=str(e)) + + +""" Load Model """ +models_router.post( + "/load/{model_name}", + operation_id="load_model", + responses={200: {"model": Union[CkptModelInfo, DiffusersModelInfo]}, 404: {"description": "Model not found"}}, +) +async def load_model(model_name: str) -> Union[CkptModelInfo, DiffusersModelInfo]: + """ + Load an existing model by name + """ + try: + # check if model exists + if model_name not in ApiDependencies.invoker.services.model_manager.models: + raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") + + # load model + model_info = ApiDependencies.invoker.services.model_manager.load_model(model_name) + print(f">> Model Loaded: {model_name}") + return model_info + + except Exception as e: + # Handle any exceptions thrown during the execution of the method + raise HTTPException(status_code=500, detail=str(e)) + + + + + # @socketio.on("requestSystemConfig") # def handle_request_capabilities(): # print(">> System config requested") From 545b41639e1a7d069b2bada981d750b6ede410f4 Mon Sep 17 00:00:00 2001 From: Kent Keirsey <31807370+hipsterusername@users.noreply.github.com> Date: Sun, 26 Mar 2023 22:42:47 -0400 Subject: [PATCH 4/7] Addressing Review Notes --- invokeai/app/api/routers/models.py | 72 +++++++++--------------------- 1 file changed, 21 insertions(+), 51 deletions(-) diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index c988d89596..77b9b947f8 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -17,7 +17,6 @@ class VaeRepo(BaseModel): class ModelInfo(BaseModel): - model_name: str = Field(..., description="The name of the model") description: Optional[str] = Field(description="A description of the model") @@ -38,12 +37,13 @@ class DiffusersModelInfo(ModelInfo): repo_id: Optional[str] = Field(description="The repo ID to use for this model") path: Optional[str] = Field(description="The path to the model") +class CreateModelRequest (BaseModel): + name: str = Field(description="The name of the model") + info: Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")] = Field(description="The model info") class ModelsList(BaseModel): models: dict[str, Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")]] - - @models_router.get( "/", operation_id="list_models", @@ -55,30 +55,28 @@ async def list_models() -> ModelsList: models = parse_obj_as(ModelsList, { "models": models_raw }) return models -# Kent wrote the below code. It is highly suspect, and should be reviewed. Once all issues have been identified and fixed, this comment can be removed. -# Seriously, check this. I have no idea how to test it. - -""" Add Model """ +#Update Model @models_router.post( "/", - operation_id="add_model", - responses={201: {"model": Union[CkptModelInfo, DiffusersModelInfo], "new_model_list": ModelsList}}, + operation_id="update_model", + responses={ + 201: { + "model": Union[CkptModelInfo, DiffusersModelInfo], + "new_model_list": ModelsList + }, + 202: { + "description": "Model submission is processing. Check back later." + }, + }, ) -async def add_model( - model_info: Union[CkptModelInfo, DiffusersModelInfo], -) -> Union[CkptModelInfo, DiffusersModelInfo]: - """Adds a new model""" +async def update_model( + model_request: CreateModelRequest +) -> CreateModelRequest: + #Adds a new model try: - model_name = model_info["model_name"] - del model_info["model_name"] - model_attributes = model_info - - if len(model_attributes.get("vae", [])) == 0: - del model_attributes["vae"] - ApiDependencies.invoker.services.model_manager.add_model( - model_name=model_name, - model_attributes=model_attributes, + model_name=model_request["name"], + model_attributes=model_request["info"], clobber=True, ) # How does Ckpt support deprecation change the above? @@ -88,7 +86,7 @@ async def add_model( # or raise the exception to be handled by the global exception handler raise HTTPException(status_code=500, detail=str(e)) - return model_info + return model_request """ Delete Model """ @models_router.delete( @@ -111,34 +109,6 @@ async def delete_model(model_name: str) -> None: raise HTTPException(status_code=500, detail=str(e)) -""" Load Model """ -models_router.post( - "/load/{model_name}", - operation_id="load_model", - responses={200: {"model": Union[CkptModelInfo, DiffusersModelInfo]}, 404: {"description": "Model not found"}}, -) -async def load_model(model_name: str) -> Union[CkptModelInfo, DiffusersModelInfo]: - """ - Load an existing model by name - """ - try: - # check if model exists - if model_name not in ApiDependencies.invoker.services.model_manager.models: - raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") - - # load model - model_info = ApiDependencies.invoker.services.model_manager.load_model(model_name) - print(f">> Model Loaded: {model_name}") - return model_info - - except Exception as e: - # Handle any exceptions thrown during the execution of the method - raise HTTPException(status_code=500, detail=str(e)) - - - - - # @socketio.on("requestSystemConfig") # def handle_request_capabilities(): # print(">> System config requested") From f53b125caaede3e56fad15dffcc60064e091baa6 Mon Sep 17 00:00:00 2001 From: Kent Keirsey <31807370+hipsterusername@users.noreply.github.com> Date: Mon, 27 Mar 2023 22:48:37 -0400 Subject: [PATCH 5/7] Second Round Fixes --- invokeai/app/api/routers/models.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 77b9b947f8..54424acaeb 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -41,6 +41,11 @@ class CreateModelRequest (BaseModel): name: str = Field(description="The name of the model") info: Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")] = Field(description="The model info") +class CreateModelResponse (BaseModel): + name: str = Field(description="The name of the new model") + info: Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")] = Field(description="The model info") + + class ModelsList(BaseModel): models: dict[str, Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")]] @@ -61,8 +66,7 @@ async def list_models() -> ModelsList: operation_id="update_model", responses={ 201: { - "model": Union[CkptModelInfo, DiffusersModelInfo], - "new_model_list": ModelsList + "model_response": Union[CkptModelInfo, DiffusersModelInfo], }, 202: { "description": "Model submission is processing. Check back later." @@ -71,22 +75,23 @@ async def list_models() -> ModelsList: ) async def update_model( model_request: CreateModelRequest -) -> CreateModelRequest: - #Adds a new model +) -> CreateModelResponse: + """Adds a new model to the active model configuration file.""" try: ApiDependencies.invoker.services.model_manager.add_model( - model_name=model_request["name"], - model_attributes=model_request["info"], + model_name=model_request.name, + model_attributes=model_request.info, clobber=True, ) - # How does Ckpt support deprecation change the above? + model_response = CreateModelResponse(status="success") except Exception as e: # Handle any exceptions thrown during the execution of the method # or raise the exception to be handled by the global exception handler raise HTTPException(status_code=500, detail=str(e)) - return model_request + + return model_response """ Delete Model """ @models_router.delete( @@ -95,6 +100,7 @@ async def update_model( responses={204: {"description": "Model deleted"}, 404: {"description": "Model not found"}}, ) async def delete_model(model_name: str) -> None: + """Deletes a model based on the model name.""" try: # check if model exists if model_name not in ApiDependencies.invoker.services.model_manager.models: From 5860b517a7c522155131c537930d8f16c0b5fab3 Mon Sep 17 00:00:00 2001 From: Kent Keirsey <31807370+hipsterusername@users.noreply.github.com> Date: Tue, 28 Mar 2023 10:39:20 -0400 Subject: [PATCH 6/7] Updated to fix Annotated pydantic errors on modelInfo --- invokeai/app/api/routers/models.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 54424acaeb..987fe3319c 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -15,11 +15,9 @@ class VaeRepo(BaseModel): path: Optional[str] = Field(description="The path to the VAE") subfolder: Optional[str] = Field(description="The subfolder to use for this VAE") - class ModelInfo(BaseModel): description: Optional[str] = Field(description="A description of the model") - class CkptModelInfo(ModelInfo): format: Literal['ckpt'] = 'ckpt' @@ -29,7 +27,6 @@ class CkptModelInfo(ModelInfo): width: Optional[int] = Field(description="The width of the model") height: Optional[int] = Field(description="The height of the model") - class DiffusersModelInfo(ModelInfo): format: Literal['diffusers'] = 'diffusers' @@ -37,14 +34,17 @@ class DiffusersModelInfo(ModelInfo): repo_id: Optional[str] = Field(description="The repo ID to use for this model") path: Optional[str] = Field(description="The path to the model") +class modelInfo(ModelInfo): + info: Annotated[Union[CkptModelInfo,DiffusersModelInfo], Field(discriminator="format")] + class CreateModelRequest (BaseModel): name: str = Field(description="The name of the model") - info: Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")] = Field(description="The model info") + info: modelInfo = Field(description="The model details and configuration") class CreateModelResponse (BaseModel): name: str = Field(description="The name of the new model") - info: Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")] = Field(description="The model info") - + info: modelInfo = Field(description="The model details and configuration") + status: str = Field(description="The status of the API response") class ModelsList(BaseModel): models: dict[str, Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")]] @@ -83,7 +83,7 @@ async def update_model( model_attributes=model_request.info, clobber=True, ) - model_response = CreateModelResponse(status="success") + model_response = CreateModelResponse(name=model_request.name, info=model_request.info, status="success") except Exception as e: # Handle any exceptions thrown during the execution of the method From ff287e6260188f6f0810fb9135752db3d6786d3c Mon Sep 17 00:00:00 2001 From: Kent Keirsey <31807370+hipsterusername@users.noreply.github.com> Date: Tue, 28 Mar 2023 23:55:02 -0400 Subject: [PATCH 7/7] Updated Fields --- invokeai/app/api/routers/models.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 987fe3319c..2e37ad2cbc 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -34,16 +34,13 @@ class DiffusersModelInfo(ModelInfo): repo_id: Optional[str] = Field(description="The repo ID to use for this model") path: Optional[str] = Field(description="The path to the model") -class modelInfo(ModelInfo): - info: Annotated[Union[CkptModelInfo,DiffusersModelInfo], Field(discriminator="format")] - class CreateModelRequest (BaseModel): name: str = Field(description="The name of the model") - info: modelInfo = Field(description="The model details and configuration") + info: Union[CkptModelInfo, DiffusersModelInfo] = Field(..., discriminator="format", description="The model details and configuration") class CreateModelResponse (BaseModel): name: str = Field(description="The name of the new model") - info: modelInfo = Field(description="The model details and configuration") + info: Union[CkptModelInfo, DiffusersModelInfo] = Field(..., discriminator="format", description="The model details and configuration") status: str = Field(description="The status of the API response") class ModelsList(BaseModel):