From b937b7da011a34d9d0d5fbfc8f3d77f2e85afc73 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 22 Jun 2023 17:34:12 +1000 Subject: [PATCH] feat(models): update model manager service & route to return list of models --- invokeai/app/api/routers/models.py | 7 +++---- .../app/services/model_manager_service.py | 19 ++++--------------- .../backend/model_management/model_manager.py | 16 ++++++---------- 3 files changed, 13 insertions(+), 29 deletions(-) diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 0abcc19dcf..50d645eb57 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -62,8 +62,7 @@ class ConvertedModelResponse(BaseModel): info: DiffusersModelInfo = Field(description="The converted model info") class ModelsList(BaseModel): - models: Dict[BaseModelType, Dict[ModelType, Dict[str, MODEL_CONFIGS]]] # TODO: debug/discuss with frontend - #models: dict[SDModelType, dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]] + models: list[MODEL_CONFIGS] @models_router.get( @@ -72,10 +71,10 @@ class ModelsList(BaseModel): responses={200: {"model": ModelsList }}, ) async def list_models( - base_model: BaseModelType = Query( + base_model: Optional[BaseModelType] = Query( default=None, description="Base model" ), - model_type: ModelType = Query( + model_type: Optional[ModelType] = Query( default=None, description="The type of model to get" ), ) -> ModelsList: diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index 8956b55139..8b46b17ad0 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -5,7 +5,7 @@ from __future__ import annotations import torch from abc import ABC, abstractmethod from pathlib import Path -from typing import Union, Callable, List, Tuple, types, TYPE_CHECKING +from typing import Optional, Union, Callable, List, Tuple, types, TYPE_CHECKING from dataclasses import dataclass from invokeai.backend.model_management.model_manager import ( @@ -273,21 +273,10 @@ class ModelManagerService(ModelManagerServiceBase): self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None - ) -> dict: + ) -> list[dict]: + # ) -> dict: """ - Return a dict of models in the format: - { model_type1: - { model_name1: {'status': 'active'|'cached'|'not loaded', - 'model_name' : name, - 'model_type' : SDModelType, - 'description': description, - 'format': 'folder'|'safetensors'|'ckpt' - }, - model_name2: { etc } - }, - model_type2: - { model_name_n: etc - } + Return a list of models. """ return self.mgr.list_models(base_model, model_type) diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 9a8c7e64c6..f9a66a87dd 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -473,9 +473,9 @@ class ModelManager(object): self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None, - ) -> Dict[str, Dict[str, str]]: + ) -> list[dict]: """ - Return a dict of models, in format [base_model][model_type][model_name] + Return a list of models. Please use model_manager.models() to get all the model names, model_manager.model_info('model-name') to get the stanza for the model @@ -483,7 +483,7 @@ class ModelManager(object): object derived from models.yaml """ - models = dict() + models = [] for model_key in sorted(self.models, key=str.casefold): model_config = self.models[model_key] @@ -493,20 +493,16 @@ class ModelManager(object): if model_type is not None and cur_model_type != model_type: continue - if cur_base_model not in models: - models[cur_base_model] = dict() - if cur_model_type not in models[cur_base_model]: - models[cur_base_model][cur_model_type] = dict() - - models[cur_base_model][cur_model_type][cur_model_name] = dict( + model_dict = dict( **model_config.dict(exclude_defaults=True), - # OpenAPIModelInfoBase name=cur_model_name, base_model=cur_base_model, type=cur_model_type, ) + models.append(model_dict) + return models def print_models(self) -> None: