mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(models): update model manager service & route to return list of models
This commit is contained in:
parent
21245a0fb2
commit
b937b7da01
@ -62,8 +62,7 @@ class ConvertedModelResponse(BaseModel):
|
|||||||
info: DiffusersModelInfo = Field(description="The converted model info")
|
info: DiffusersModelInfo = Field(description="The converted model info")
|
||||||
|
|
||||||
class ModelsList(BaseModel):
|
class ModelsList(BaseModel):
|
||||||
models: Dict[BaseModelType, Dict[ModelType, Dict[str, MODEL_CONFIGS]]] # TODO: debug/discuss with frontend
|
models: list[MODEL_CONFIGS]
|
||||||
#models: dict[SDModelType, dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]]
|
|
||||||
|
|
||||||
|
|
||||||
@models_router.get(
|
@models_router.get(
|
||||||
@ -72,10 +71,10 @@ class ModelsList(BaseModel):
|
|||||||
responses={200: {"model": ModelsList }},
|
responses={200: {"model": ModelsList }},
|
||||||
)
|
)
|
||||||
async def list_models(
|
async def list_models(
|
||||||
base_model: BaseModelType = Query(
|
base_model: Optional[BaseModelType] = Query(
|
||||||
default=None, description="Base model"
|
default=None, description="Base model"
|
||||||
),
|
),
|
||||||
model_type: ModelType = Query(
|
model_type: Optional[ModelType] = Query(
|
||||||
default=None, description="The type of model to get"
|
default=None, description="The type of model to get"
|
||||||
),
|
),
|
||||||
) -> ModelsList:
|
) -> ModelsList:
|
||||||
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
|||||||
import torch
|
import torch
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union, Callable, List, Tuple, types, TYPE_CHECKING
|
from typing import Optional, Union, Callable, List, Tuple, types, TYPE_CHECKING
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from invokeai.backend.model_management.model_manager import (
|
from invokeai.backend.model_management.model_manager import (
|
||||||
@ -273,21 +273,10 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
self,
|
self,
|
||||||
base_model: Optional[BaseModelType] = None,
|
base_model: Optional[BaseModelType] = None,
|
||||||
model_type: Optional[ModelType] = None
|
model_type: Optional[ModelType] = None
|
||||||
) -> dict:
|
) -> list[dict]:
|
||||||
|
# ) -> dict:
|
||||||
"""
|
"""
|
||||||
Return a dict of models in the format:
|
Return a list of models.
|
||||||
{ model_type1:
|
|
||||||
{ model_name1: {'status': 'active'|'cached'|'not loaded',
|
|
||||||
'model_name' : name,
|
|
||||||
'model_type' : SDModelType,
|
|
||||||
'description': description,
|
|
||||||
'format': 'folder'|'safetensors'|'ckpt'
|
|
||||||
},
|
|
||||||
model_name2: { etc }
|
|
||||||
},
|
|
||||||
model_type2:
|
|
||||||
{ model_name_n: etc
|
|
||||||
}
|
|
||||||
"""
|
"""
|
||||||
return self.mgr.list_models(base_model, model_type)
|
return self.mgr.list_models(base_model, model_type)
|
||||||
|
|
||||||
|
@ -473,9 +473,9 @@ class ModelManager(object):
|
|||||||
self,
|
self,
|
||||||
base_model: Optional[BaseModelType] = None,
|
base_model: Optional[BaseModelType] = None,
|
||||||
model_type: Optional[ModelType] = None,
|
model_type: Optional[ModelType] = None,
|
||||||
) -> Dict[str, Dict[str, str]]:
|
) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
Return a dict of models, in format [base_model][model_type][model_name]
|
Return a list of models.
|
||||||
|
|
||||||
Please use model_manager.models() to get all the model names,
|
Please use model_manager.models() to get all the model names,
|
||||||
model_manager.model_info('model-name') to get the stanza for the model
|
model_manager.model_info('model-name') to get the stanza for the model
|
||||||
@ -483,7 +483,7 @@ class ModelManager(object):
|
|||||||
object derived from models.yaml
|
object derived from models.yaml
|
||||||
"""
|
"""
|
||||||
|
|
||||||
models = dict()
|
models = []
|
||||||
for model_key in sorted(self.models, key=str.casefold):
|
for model_key in sorted(self.models, key=str.casefold):
|
||||||
model_config = self.models[model_key]
|
model_config = self.models[model_key]
|
||||||
|
|
||||||
@ -493,20 +493,16 @@ class ModelManager(object):
|
|||||||
if model_type is not None and cur_model_type != model_type:
|
if model_type is not None and cur_model_type != model_type:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if cur_base_model not in models:
|
model_dict = dict(
|
||||||
models[cur_base_model] = dict()
|
|
||||||
if cur_model_type not in models[cur_base_model]:
|
|
||||||
models[cur_base_model][cur_model_type] = dict()
|
|
||||||
|
|
||||||
models[cur_base_model][cur_model_type][cur_model_name] = dict(
|
|
||||||
**model_config.dict(exclude_defaults=True),
|
**model_config.dict(exclude_defaults=True),
|
||||||
|
|
||||||
# OpenAPIModelInfoBase
|
# OpenAPIModelInfoBase
|
||||||
name=cur_model_name,
|
name=cur_model_name,
|
||||||
base_model=cur_base_model,
|
base_model=cur_base_model,
|
||||||
type=cur_model_type,
|
type=cur_model_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
models.append(model_dict)
|
||||||
|
|
||||||
return models
|
return models
|
||||||
|
|
||||||
def print_models(self) -> None:
|
def print_models(self) -> None:
|
||||||
|
Loading…
Reference in New Issue
Block a user