mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
list_models() now returns a dict of {type,{name: info}}
This commit is contained in:
parent
c8f765cc06
commit
4fe94a9315
@ -2,9 +2,11 @@
|
|||||||
|
|
||||||
from typing import Annotated, Literal, Optional, Union
|
from typing import Annotated, Literal, Optional, Union
|
||||||
|
|
||||||
|
from fastapi import Query
|
||||||
from fastapi.routing import APIRouter, HTTPException
|
from fastapi.routing import APIRouter, HTTPException
|
||||||
from pydantic import BaseModel, Field, parse_obj_as
|
from pydantic import BaseModel, Field, parse_obj_as
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
|
from invokeai.backend import SDModelType
|
||||||
|
|
||||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||||
|
|
||||||
@ -58,7 +60,7 @@ class ConvertedModelResponse(BaseModel):
|
|||||||
info: DiffusersModelInfo = Field(description="The converted model info")
|
info: DiffusersModelInfo = Field(description="The converted model info")
|
||||||
|
|
||||||
class ModelsList(BaseModel):
|
class ModelsList(BaseModel):
|
||||||
models: dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]
|
models: dict[str, dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]]
|
||||||
|
|
||||||
|
|
||||||
@models_router.get(
|
@models_router.get(
|
||||||
@ -66,9 +68,13 @@ class ModelsList(BaseModel):
|
|||||||
operation_id="list_models",
|
operation_id="list_models",
|
||||||
responses={200: {"model": ModelsList }},
|
responses={200: {"model": ModelsList }},
|
||||||
)
|
)
|
||||||
async def list_models() -> ModelsList:
|
async def list_models(
|
||||||
|
model_type: SDModelType = Query(
|
||||||
|
default=SDModelType.Diffusers, description="The type of model to get"
|
||||||
|
),
|
||||||
|
) -> ModelsList:
|
||||||
"""Gets a list of models"""
|
"""Gets a list of models"""
|
||||||
models_raw = ApiDependencies.invoker.services.model_manager.list_models()
|
models_raw = ApiDependencies.invoker.services.model_manager.list_models(model_type)
|
||||||
models = parse_obj_as(ModelsList, { "models": models_raw })
|
models = parse_obj_as(ModelsList, { "models": models_raw })
|
||||||
return models
|
return models
|
||||||
|
|
||||||
|
@ -99,16 +99,21 @@ class ModelManagerServiceBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def list_models(self) -> dict:
|
def list_models(self, model_type: SDModelType=None) -> dict:
|
||||||
"""
|
"""
|
||||||
Return a dict of models in the format:
|
Return a dict of models in the format:
|
||||||
{ model_key1: {'status': 'active'|'cached'|'not loaded',
|
{ model_type1:
|
||||||
|
{ model_name1: {'status': 'active'|'cached'|'not loaded',
|
||||||
'model_name' : name,
|
'model_name' : name,
|
||||||
'model_type' : SDModelType,
|
'model_type' : SDModelType,
|
||||||
'description': description,
|
'description': description,
|
||||||
'format': 'folder'|'safetensors'|'ckpt'
|
'format': 'folder'|'safetensors'|'ckpt'
|
||||||
},
|
},
|
||||||
model_key2: { etc }
|
model_name2: { etc }
|
||||||
|
},
|
||||||
|
model_type2:
|
||||||
|
{ model_name_n: etc
|
||||||
|
}
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -385,15 +390,21 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
"""
|
"""
|
||||||
return self.mgr.model_names()
|
return self.mgr.model_names()
|
||||||
|
|
||||||
def list_models(self) -> dict:
|
def list_models(self, model_type: SDModelType=None) -> dict:
|
||||||
"""
|
"""
|
||||||
Return a dict of models in the format:
|
Return a dict of models in the format:
|
||||||
{ model_key: {'status': 'active'|'cached'|'not loaded',
|
{ model_type1:
|
||||||
|
{ model_name1: {'status': 'active'|'cached'|'not loaded',
|
||||||
'model_name' : name,
|
'model_name' : name,
|
||||||
'model_type' : SDModelType,
|
'model_type' : SDModelType,
|
||||||
'description': description,
|
'description': description,
|
||||||
'format': 'folder'|'safetensors'|'ckpt'
|
'format': 'folder'|'safetensors'|'ckpt'
|
||||||
},
|
},
|
||||||
|
model_name2: { etc }
|
||||||
|
},
|
||||||
|
model_type2:
|
||||||
|
{ model_name_n: etc
|
||||||
|
}
|
||||||
"""
|
"""
|
||||||
return self.mgr.list_models()
|
return self.mgr.list_models()
|
||||||
|
|
||||||
|
@ -467,9 +467,25 @@ class ModelManager(object):
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def list_models(self) -> dict:
|
def list_models(self, model_type: SDModelType=SDModelType.Diffusers) -> dict[str,dict[str,str]]:
|
||||||
"""
|
"""
|
||||||
Return a dict of models
|
Return a dict of models, in format [model_type][model_name], with
|
||||||
|
following fields:
|
||||||
|
model_name
|
||||||
|
model_type
|
||||||
|
format
|
||||||
|
description
|
||||||
|
status
|
||||||
|
# for folders only
|
||||||
|
repo_id
|
||||||
|
path
|
||||||
|
subfolder
|
||||||
|
vae
|
||||||
|
# for ckpts only
|
||||||
|
config
|
||||||
|
weights
|
||||||
|
vae
|
||||||
|
|
||||||
Please use model_manager.models() to get all the model names,
|
Please use model_manager.models() to get all the model names,
|
||||||
model_manager.model_info('model-name') to get the stanza for the model
|
model_manager.model_info('model-name') to get the stanza for the model
|
||||||
named 'model-name', and model_manager.config to get the full OmegaConf
|
named 'model-name', and model_manager.config to get the full OmegaConf
|
||||||
@ -485,13 +501,15 @@ class ModelManager(object):
|
|||||||
if model_key == 'config_file_version':
|
if model_key == 'config_file_version':
|
||||||
continue
|
continue
|
||||||
|
|
||||||
model_name, model_type = self.parse_key(model_key)
|
model_name, stanza_type = self.parse_key(model_key)
|
||||||
models[model_key] = dict()
|
if model_type is not None and model_type != stanza_type:
|
||||||
|
|
||||||
# TODO: return all models in future
|
|
||||||
if model_type != SDModelType.Diffusers:
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if stanza_type not in models:
|
||||||
|
models[stanza_type] = dict()
|
||||||
|
|
||||||
|
models[stanza_type][model_name] = dict()
|
||||||
|
|
||||||
model_format = stanza.get('format')
|
model_format = stanza.get('format')
|
||||||
|
|
||||||
# Common Attribs
|
# Common Attribs
|
||||||
@ -501,7 +519,7 @@ class ModelManager(object):
|
|||||||
subfolder=stanza.get('subfolder'),
|
subfolder=stanza.get('subfolder'),
|
||||||
)
|
)
|
||||||
description = stanza.get("description", None)
|
description = stanza.get("description", None)
|
||||||
models[model_key].update(
|
models[stanza_type][model_name].update(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
format=model_format,
|
format=model_format,
|
||||||
@ -509,10 +527,9 @@ class ModelManager(object):
|
|||||||
status=status.value,
|
status=status.value,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Checkpoint Config Parse
|
# Checkpoint Config Parse
|
||||||
if model_format in ["ckpt","safetensors"]:
|
if model_format in ["ckpt","safetensors"]:
|
||||||
models[model_key].update(
|
models[stanza_type][model_name].update(
|
||||||
config = str(stanza.get("config", None)),
|
config = str(stanza.get("config", None)),
|
||||||
weights = str(stanza.get("weights", None)),
|
weights = str(stanza.get("weights", None)),
|
||||||
vae = str(stanza.get("vae", None)),
|
vae = str(stanza.get("vae", None)),
|
||||||
@ -528,7 +545,7 @@ class ModelManager(object):
|
|||||||
subfolder = str(vae.get("subfolder", None)),
|
subfolder = str(vae.get("subfolder", None)),
|
||||||
)
|
)
|
||||||
|
|
||||||
models[model_key].update(
|
models[stanza_type][model_name].update(
|
||||||
vae = vae,
|
vae = vae,
|
||||||
repo_id = str(stanza.get("repo_id", None)),
|
repo_id = str(stanza.get("repo_id", None)),
|
||||||
path = str(stanza.get("path", None)),
|
path = str(stanza.get("path", None)),
|
||||||
@ -540,7 +557,8 @@ class ModelManager(object):
|
|||||||
"""
|
"""
|
||||||
Print a table of models, their descriptions, and load status
|
Print a table of models, their descriptions, and load status
|
||||||
"""
|
"""
|
||||||
for model_key, model_info in self.list_models().items():
|
for model_type, model_dict in self.list_models().items():
|
||||||
|
for model_name, model_info in model_dict.items():
|
||||||
line = f'{model_info["model_name"]:25s} {model_info["status"]:>15s} {model_info["model_type"]:10s} {model_info["description"]}'
|
line = f'{model_info["model_name"]:25s} {model_info["status"]:>15s} {model_info["model_type"]:10s} {model_info["description"]}'
|
||||||
if model_info["status"] in ["in gpu","locked in gpu"]:
|
if model_info["status"] in ["in gpu","locked in gpu"]:
|
||||||
line = f"\033[1m{line}\033[0m"
|
line = f"\033[1m{line}\033[0m"
|
||||||
|
Loading…
Reference in New Issue
Block a user