From 4fe94a9315911ba2c3a136a60e55acee2da09a6a Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 15 May 2023 23:44:08 -0400 Subject: [PATCH] list_models() now returns a dict of {type,{name: info}} --- invokeai/app/api/routers/models.py | 12 +++-- .../app/services/model_manager_service.py | 41 +++++++++------ .../backend/model_management/model_manager.py | 50 +++++++++++++------ 3 files changed, 69 insertions(+), 34 deletions(-) diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 97c0265c50..eb7daeb8dd 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -2,9 +2,11 @@ from typing import Annotated, Literal, Optional, Union +from fastapi import Query from fastapi.routing import APIRouter, HTTPException from pydantic import BaseModel, Field, parse_obj_as from ..dependencies import ApiDependencies +from invokeai.backend import SDModelType models_router = APIRouter(prefix="/v1/models", tags=["models"]) @@ -58,7 +60,7 @@ class ConvertedModelResponse(BaseModel): info: DiffusersModelInfo = Field(description="The converted model info") class ModelsList(BaseModel): - models: dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]] + models: dict[str, dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]] @models_router.get( @@ -66,9 +68,13 @@ class ModelsList(BaseModel): operation_id="list_models", responses={200: {"model": ModelsList }}, ) -async def list_models() -> ModelsList: +async def list_models( + model_type: SDModelType = Query( + default=SDModelType.Diffusers, description="The type of model to get" + ), +) -> ModelsList: """Gets a list of models""" - models_raw = ApiDependencies.invoker.services.model_manager.list_models() + models_raw = ApiDependencies.invoker.services.model_manager.list_models(model_type) models = parse_obj_as(ModelsList, { "models": models_raw }) return models diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index 0d140511e0..fbf5e09c94 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -99,16 +99,21 @@ class ModelManagerServiceBase(ABC): pass @abstractmethod - def list_models(self) -> dict: + def list_models(self, model_type: SDModelType=None) -> dict: """ Return a dict of models in the format: - { model_key1: {'status': 'active'|'cached'|'not loaded', - 'model_name' : name, - 'model_type' : SDModelType, - 'description': description, - 'format': 'folder'|'safetensors'|'ckpt' - }, - model_key2: { etc } + { model_type1: + { model_name1: {'status': 'active'|'cached'|'not loaded', + 'model_name' : name, + 'model_type' : SDModelType, + 'description': description, + 'format': 'folder'|'safetensors'|'ckpt' + }, + model_name2: { etc } + }, + model_type2: + { model_name_n: etc + } """ pass @@ -385,15 +390,21 @@ class ModelManagerService(ModelManagerServiceBase): """ return self.mgr.model_names() - def list_models(self) -> dict: + def list_models(self, model_type: SDModelType=None) -> dict: """ Return a dict of models in the format: - { model_key: {'status': 'active'|'cached'|'not loaded', - 'model_name' : name, - 'model_type' : SDModelType, - 'description': description, - 'format': 'folder'|'safetensors'|'ckpt' - }, + { model_type1: + { model_name1: {'status': 'active'|'cached'|'not loaded', + 'model_name' : name, + 'model_type' : SDModelType, + 'description': description, + 'format': 'folder'|'safetensors'|'ckpt' + }, + model_name2: { etc } + }, + model_type2: + { model_name_n: etc + } """ return self.mgr.list_models() diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index c45494386e..bed52a2fc5 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -467,9 +467,25 @@ class ModelManager(object): return True return False - def list_models(self) -> dict: + def list_models(self, model_type: SDModelType=SDModelType.Diffusers) -> dict[str,dict[str,str]]: """ - Return a dict of models + Return a dict of models, in format [model_type][model_name], with + following fields: + model_name + model_type + format + description + status + # for folders only + repo_id + path + subfolder + vae + # for ckpts only + config + weights + vae + Please use model_manager.models() to get all the model names, model_manager.model_info('model-name') to get the stanza for the model named 'model-name', and model_manager.config to get the full OmegaConf @@ -485,13 +501,15 @@ class ModelManager(object): if model_key == 'config_file_version': continue - model_name, model_type = self.parse_key(model_key) - models[model_key] = dict() - - # TODO: return all models in future - if model_type != SDModelType.Diffusers: + model_name, stanza_type = self.parse_key(model_key) + if model_type is not None and model_type != stanza_type: continue + if stanza_type not in models: + models[stanza_type] = dict() + + models[stanza_type][model_name] = dict() + model_format = stanza.get('format') # Common Attribs @@ -501,7 +519,7 @@ class ModelManager(object): subfolder=stanza.get('subfolder'), ) description = stanza.get("description", None) - models[model_key].update( + models[stanza_type][model_name].update( model_name=model_name, model_type=model_type, format=model_format, @@ -509,10 +527,9 @@ class ModelManager(object): status=status.value, ) - # Checkpoint Config Parse if model_format in ["ckpt","safetensors"]: - models[model_key].update( + models[stanza_type][model_name].update( config = str(stanza.get("config", None)), weights = str(stanza.get("weights", None)), vae = str(stanza.get("vae", None)), @@ -528,7 +545,7 @@ class ModelManager(object): subfolder = str(vae.get("subfolder", None)), ) - models[model_key].update( + models[stanza_type][model_name].update( vae = vae, repo_id = str(stanza.get("repo_id", None)), path = str(stanza.get("path", None)), @@ -540,11 +557,12 @@ class ModelManager(object): """ Print a table of models, their descriptions, and load status """ - for model_key, model_info in self.list_models().items(): - line = f'{model_info["model_name"]:25s} {model_info["status"]:>15s} {model_info["model_type"]:10s} {model_info["description"]}' - if model_info["status"] in ["in gpu","locked in gpu"]: - line = f"\033[1m{line}\033[0m" - print(line) + for model_type, model_dict in self.list_models().items(): + for model_name, model_info in model_dict.items(): + line = f'{model_info["model_name"]:25s} {model_info["status"]:>15s} {model_info["model_type"]:10s} {model_info["description"]}' + if model_info["status"] in ["in gpu","locked in gpu"]: + line = f"\033[1m{line}\033[0m" + print(line) def del_model( self,