list_models() now returns a dict of {type,{name: info}}

This commit is contained in:
Lincoln Stein 2023-05-15 23:44:08 -04:00
parent c8f765cc06
commit 4fe94a9315
3 changed files with 69 additions and 34 deletions

View File

@ -2,9 +2,11 @@
from typing import Annotated, Literal, Optional, Union from typing import Annotated, Literal, Optional, Union
from fastapi import Query
from fastapi.routing import APIRouter, HTTPException from fastapi.routing import APIRouter, HTTPException
from pydantic import BaseModel, Field, parse_obj_as from pydantic import BaseModel, Field, parse_obj_as
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
from invokeai.backend import SDModelType
models_router = APIRouter(prefix="/v1/models", tags=["models"]) models_router = APIRouter(prefix="/v1/models", tags=["models"])
@ -58,7 +60,7 @@ class ConvertedModelResponse(BaseModel):
info: DiffusersModelInfo = Field(description="The converted model info") info: DiffusersModelInfo = Field(description="The converted model info")
class ModelsList(BaseModel): class ModelsList(BaseModel):
models: dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]] models: dict[str, dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]]
@models_router.get( @models_router.get(
@ -66,9 +68,13 @@ class ModelsList(BaseModel):
operation_id="list_models", operation_id="list_models",
responses={200: {"model": ModelsList }}, responses={200: {"model": ModelsList }},
) )
async def list_models() -> ModelsList: async def list_models(
model_type: SDModelType = Query(
default=SDModelType.Diffusers, description="The type of model to get"
),
) -> ModelsList:
"""Gets a list of models""" """Gets a list of models"""
models_raw = ApiDependencies.invoker.services.model_manager.list_models() models_raw = ApiDependencies.invoker.services.model_manager.list_models(model_type)
models = parse_obj_as(ModelsList, { "models": models_raw }) models = parse_obj_as(ModelsList, { "models": models_raw })
return models return models

View File

@ -99,16 +99,21 @@ class ModelManagerServiceBase(ABC):
pass pass
@abstractmethod @abstractmethod
def list_models(self) -> dict: def list_models(self, model_type: SDModelType=None) -> dict:
""" """
Return a dict of models in the format: Return a dict of models in the format:
{ model_key1: {'status': 'active'|'cached'|'not loaded', { model_type1:
{ model_name1: {'status': 'active'|'cached'|'not loaded',
'model_name' : name, 'model_name' : name,
'model_type' : SDModelType, 'model_type' : SDModelType,
'description': description, 'description': description,
'format': 'folder'|'safetensors'|'ckpt' 'format': 'folder'|'safetensors'|'ckpt'
}, },
model_key2: { etc } model_name2: { etc }
},
model_type2:
{ model_name_n: etc
}
""" """
pass pass
@ -385,15 +390,21 @@ class ModelManagerService(ModelManagerServiceBase):
""" """
return self.mgr.model_names() return self.mgr.model_names()
def list_models(self) -> dict: def list_models(self, model_type: SDModelType=None) -> dict:
""" """
Return a dict of models in the format: Return a dict of models in the format:
{ model_key: {'status': 'active'|'cached'|'not loaded', { model_type1:
{ model_name1: {'status': 'active'|'cached'|'not loaded',
'model_name' : name, 'model_name' : name,
'model_type' : SDModelType, 'model_type' : SDModelType,
'description': description, 'description': description,
'format': 'folder'|'safetensors'|'ckpt' 'format': 'folder'|'safetensors'|'ckpt'
}, },
model_name2: { etc }
},
model_type2:
{ model_name_n: etc
}
""" """
return self.mgr.list_models() return self.mgr.list_models()

View File

@ -467,9 +467,25 @@ class ModelManager(object):
return True return True
return False return False
def list_models(self) -> dict: def list_models(self, model_type: SDModelType=SDModelType.Diffusers) -> dict[str,dict[str,str]]:
""" """
Return a dict of models Return a dict of models, in format [model_type][model_name], with
following fields:
model_name
model_type
format
description
status
# for folders only
repo_id
path
subfolder
vae
# for ckpts only
config
weights
vae
Please use model_manager.models() to get all the model names, Please use model_manager.models() to get all the model names,
model_manager.model_info('model-name') to get the stanza for the model model_manager.model_info('model-name') to get the stanza for the model
named 'model-name', and model_manager.config to get the full OmegaConf named 'model-name', and model_manager.config to get the full OmegaConf
@ -485,13 +501,15 @@ class ModelManager(object):
if model_key == 'config_file_version': if model_key == 'config_file_version':
continue continue
model_name, model_type = self.parse_key(model_key) model_name, stanza_type = self.parse_key(model_key)
models[model_key] = dict() if model_type is not None and model_type != stanza_type:
# TODO: return all models in future
if model_type != SDModelType.Diffusers:
continue continue
if stanza_type not in models:
models[stanza_type] = dict()
models[stanza_type][model_name] = dict()
model_format = stanza.get('format') model_format = stanza.get('format')
# Common Attribs # Common Attribs
@ -501,7 +519,7 @@ class ModelManager(object):
subfolder=stanza.get('subfolder'), subfolder=stanza.get('subfolder'),
) )
description = stanza.get("description", None) description = stanza.get("description", None)
models[model_key].update( models[stanza_type][model_name].update(
model_name=model_name, model_name=model_name,
model_type=model_type, model_type=model_type,
format=model_format, format=model_format,
@ -509,10 +527,9 @@ class ModelManager(object):
status=status.value, status=status.value,
) )
# Checkpoint Config Parse # Checkpoint Config Parse
if model_format in ["ckpt","safetensors"]: if model_format in ["ckpt","safetensors"]:
models[model_key].update( models[stanza_type][model_name].update(
config = str(stanza.get("config", None)), config = str(stanza.get("config", None)),
weights = str(stanza.get("weights", None)), weights = str(stanza.get("weights", None)),
vae = str(stanza.get("vae", None)), vae = str(stanza.get("vae", None)),
@ -528,7 +545,7 @@ class ModelManager(object):
subfolder = str(vae.get("subfolder", None)), subfolder = str(vae.get("subfolder", None)),
) )
models[model_key].update( models[stanza_type][model_name].update(
vae = vae, vae = vae,
repo_id = str(stanza.get("repo_id", None)), repo_id = str(stanza.get("repo_id", None)),
path = str(stanza.get("path", None)), path = str(stanza.get("path", None)),
@ -540,7 +557,8 @@ class ModelManager(object):
""" """
Print a table of models, their descriptions, and load status Print a table of models, their descriptions, and load status
""" """
for model_key, model_info in self.list_models().items(): for model_type, model_dict in self.list_models().items():
for model_name, model_info in model_dict.items():
line = f'{model_info["model_name"]:25s} {model_info["status"]:>15s} {model_info["model_type"]:10s} {model_info["description"]}' line = f'{model_info["model_name"]:25s} {model_info["status"]:>15s} {model_info["model_type"]:10s} {model_info["description"]}'
if model_info["status"] in ["in gpu","locked in gpu"]: if model_info["status"] in ["in gpu","locked in gpu"]:
line = f"\033[1m{line}\033[0m" line = f"\033[1m{line}\033[0m"