adjust for change in list_models() API

This commit is contained in:
Lincoln Stein 2023-06-23 14:13:37 -04:00
parent 58d1857ab6
commit 54b74427f4

View File

@ -121,21 +121,21 @@ class ModelInstall(object):
# supplement with entries in models.yaml # supplement with entries in models.yaml
installed_models = self.mgr.list_models() installed_models = self.mgr.list_models()
for base in installed_models.keys(): for md in installed_models:
for model_type in installed_models[base].keys(): base = md['base_model']
for name, value in installed_models[base][model_type].items(): model_type = md['type']
key = ModelManager.create_key(name, base, model_type) name = md['name']
if key in model_dict: key = ModelManager.create_key(name, base, model_type)
model_dict[key].installed = True if key in model_dict:
else: model_dict[key].installed = True
model_dict[key] = ModelLoadInfo( else:
name = name, model_dict[key] = ModelLoadInfo(
base_type = base, name = name,
model_type = model_type, base_type = base,
# description = value.get('description'), model_type = model_type,
path = value.get('path'), path = value.get('path'),
installed = True, installed = True,
) )
return {x : model_dict[x] for x in sorted(model_dict.keys(),key=lambda y: model_dict[y].name.lower())} return {x : model_dict[x] for x in sorted(model_dict.keys(),key=lambda y: model_dict[y].name.lower())}
def starter_models(self)->Set[str]: def starter_models(self)->Set[str]:
@ -316,7 +316,7 @@ class ModelInstall(object):
attributes = dict( attributes = dict(
path = str(path), path = str(path),
description = str(description), description = str(description),
format = info.format, model_format = info.format,
) )
if info.model_type == ModelType.Pipeline: if info.model_type == ModelType.Pipeline:
attributes.update( attributes.update(