feat(models): update model manager service & route to return list of models

This commit is contained in:
psychedelicious 2023-06-22 17:34:12 +10:00
parent 21245a0fb2
commit b937b7da01
3 changed files with 13 additions and 29 deletions

View File

@ -62,8 +62,7 @@ class ConvertedModelResponse(BaseModel):
info: DiffusersModelInfo = Field(description="The converted model info") info: DiffusersModelInfo = Field(description="The converted model info")
class ModelsList(BaseModel): class ModelsList(BaseModel):
models: Dict[BaseModelType, Dict[ModelType, Dict[str, MODEL_CONFIGS]]] # TODO: debug/discuss with frontend models: list[MODEL_CONFIGS]
#models: dict[SDModelType, dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]]
@models_router.get( @models_router.get(
@ -72,10 +71,10 @@ class ModelsList(BaseModel):
responses={200: {"model": ModelsList }}, responses={200: {"model": ModelsList }},
) )
async def list_models( async def list_models(
base_model: BaseModelType = Query( base_model: Optional[BaseModelType] = Query(
default=None, description="Base model" default=None, description="Base model"
), ),
model_type: ModelType = Query( model_type: Optional[ModelType] = Query(
default=None, description="The type of model to get" default=None, description="The type of model to get"
), ),
) -> ModelsList: ) -> ModelsList:

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import torch import torch
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import Union, Callable, List, Tuple, types, TYPE_CHECKING from typing import Optional, Union, Callable, List, Tuple, types, TYPE_CHECKING
from dataclasses import dataclass from dataclasses import dataclass
from invokeai.backend.model_management.model_manager import ( from invokeai.backend.model_management.model_manager import (
@ -273,21 +273,10 @@ class ModelManagerService(ModelManagerServiceBase):
self, self,
base_model: Optional[BaseModelType] = None, base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None model_type: Optional[ModelType] = None
) -> dict: ) -> list[dict]:
# ) -> dict:
""" """
Return a dict of models in the format: Return a list of models.
{ model_type1:
{ model_name1: {'status': 'active'|'cached'|'not loaded',
'model_name' : name,
'model_type' : SDModelType,
'description': description,
'format': 'folder'|'safetensors'|'ckpt'
},
model_name2: { etc }
},
model_type2:
{ model_name_n: etc
}
""" """
return self.mgr.list_models(base_model, model_type) return self.mgr.list_models(base_model, model_type)

View File

@ -473,9 +473,9 @@ class ModelManager(object):
self, self,
base_model: Optional[BaseModelType] = None, base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None, model_type: Optional[ModelType] = None,
) -> Dict[str, Dict[str, str]]: ) -> list[dict]:
""" """
Return a dict of models, in format [base_model][model_type][model_name] Return a list of models.
Please use model_manager.models() to get all the model names, Please use model_manager.models() to get all the model names,
model_manager.model_info('model-name') to get the stanza for the model model_manager.model_info('model-name') to get the stanza for the model
@ -483,7 +483,7 @@ class ModelManager(object):
object derived from models.yaml object derived from models.yaml
""" """
models = dict() models = []
for model_key in sorted(self.models, key=str.casefold): for model_key in sorted(self.models, key=str.casefold):
model_config = self.models[model_key] model_config = self.models[model_key]
@ -493,20 +493,16 @@ class ModelManager(object):
if model_type is not None and cur_model_type != model_type: if model_type is not None and cur_model_type != model_type:
continue continue
if cur_base_model not in models: model_dict = dict(
models[cur_base_model] = dict()
if cur_model_type not in models[cur_base_model]:
models[cur_base_model][cur_model_type] = dict()
models[cur_base_model][cur_model_type][cur_model_name] = dict(
**model_config.dict(exclude_defaults=True), **model_config.dict(exclude_defaults=True),
# OpenAPIModelInfoBase # OpenAPIModelInfoBase
name=cur_model_name, name=cur_model_name,
base_model=cur_base_model, base_model=cur_base_model,
type=cur_model_type, type=cur_model_type,
) )
models.append(model_dict)
return models return models
def print_models(self) -> None: def print_models(self) -> None: