list_models() now returns a dict of {type,{name: info}}

This commit is contained in:
Lincoln Stein
2023-05-15 23:44:08 -04:00
parent c8f765cc06
commit 4fe94a9315
3 changed files with 69 additions and 34 deletions

View File

@ -467,9 +467,25 @@ class ModelManager(object):
return True
return False
def list_models(self) -> dict:
def list_models(self, model_type: SDModelType=SDModelType.Diffusers) -> dict[str,dict[str,str]]:
"""
Return a dict of models
Return a dict of models, in format [model_type][model_name], with
following fields:
model_name
model_type
format
description
status
# for folders only
repo_id
path
subfolder
vae
# for ckpts only
config
weights
vae
Please use model_manager.models() to get all the model names,
model_manager.model_info('model-name') to get the stanza for the model
named 'model-name', and model_manager.config to get the full OmegaConf
@ -485,13 +501,15 @@ class ModelManager(object):
if model_key == 'config_file_version':
continue
model_name, model_type = self.parse_key(model_key)
models[model_key] = dict()
# TODO: return all models in future
if model_type != SDModelType.Diffusers:
model_name, stanza_type = self.parse_key(model_key)
if model_type is not None and model_type != stanza_type:
continue
if stanza_type not in models:
models[stanza_type] = dict()
models[stanza_type][model_name] = dict()
model_format = stanza.get('format')
# Common Attribs
@ -501,7 +519,7 @@ class ModelManager(object):
subfolder=stanza.get('subfolder'),
)
description = stanza.get("description", None)
models[model_key].update(
models[stanza_type][model_name].update(
model_name=model_name,
model_type=model_type,
format=model_format,
@ -509,10 +527,9 @@ class ModelManager(object):
status=status.value,
)
# Checkpoint Config Parse
if model_format in ["ckpt","safetensors"]:
models[model_key].update(
models[stanza_type][model_name].update(
config = str(stanza.get("config", None)),
weights = str(stanza.get("weights", None)),
vae = str(stanza.get("vae", None)),
@ -528,7 +545,7 @@ class ModelManager(object):
subfolder = str(vae.get("subfolder", None)),
)
models[model_key].update(
models[stanza_type][model_name].update(
vae = vae,
repo_id = str(stanza.get("repo_id", None)),
path = str(stanza.get("path", None)),
@ -540,11 +557,12 @@ class ModelManager(object):
"""
Print a table of models, their descriptions, and load status
"""
for model_key, model_info in self.list_models().items():
line = f'{model_info["model_name"]:25s} {model_info["status"]:>15s} {model_info["model_type"]:10s} {model_info["description"]}'
if model_info["status"] in ["in gpu","locked in gpu"]:
line = f"\033[1m{line}\033[0m"
print(line)
for model_type, model_dict in self.list_models().items():
for model_name, model_info in model_dict.items():
line = f'{model_info["model_name"]:25s} {model_info["status"]:>15s} {model_info["model_type"]:10s} {model_info["description"]}'
if model_info["status"] in ["in gpu","locked in gpu"]:
line = f"\033[1m{line}\033[0m"
print(line)
def del_model(
self,