list_models() now returns a dict of {type,{name: info}}

This commit is contained in:
Lincoln Stein
2023-05-15 23:44:08 -04:00
parent c8f765cc06
commit 4fe94a9315
3 changed files with 69 additions and 34 deletions

View File

@ -2,9 +2,11 @@
from typing import Annotated, Literal, Optional, Union
from fastapi import Query
from fastapi.routing import APIRouter, HTTPException
from pydantic import BaseModel, Field, parse_obj_as
from ..dependencies import ApiDependencies
from invokeai.backend import SDModelType
models_router = APIRouter(prefix="/v1/models", tags=["models"])
@ -58,7 +60,7 @@ class ConvertedModelResponse(BaseModel):
info: DiffusersModelInfo = Field(description="The converted model info")
class ModelsList(BaseModel):
models: dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]
models: dict[str, dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]]
@models_router.get(
@ -66,9 +68,13 @@ class ModelsList(BaseModel):
operation_id="list_models",
responses={200: {"model": ModelsList }},
)
async def list_models() -> ModelsList:
async def list_models(
model_type: SDModelType = Query(
default=SDModelType.Diffusers, description="The type of model to get"
),
) -> ModelsList:
"""Gets a list of models"""
models_raw = ApiDependencies.invoker.services.model_manager.list_models()
models_raw = ApiDependencies.invoker.services.model_manager.list_models(model_type)
models = parse_obj_as(ModelsList, { "models": models_raw })
return models

View File

@ -99,16 +99,21 @@ class ModelManagerServiceBase(ABC):
pass
@abstractmethod
def list_models(self) -> dict:
def list_models(self, model_type: SDModelType=None) -> dict:
"""
Return a dict of models in the format:
{ model_key1: {'status': 'active'|'cached'|'not loaded',
'model_name' : name,
'model_type' : SDModelType,
'description': description,
'format': 'folder'|'safetensors'|'ckpt'
},
model_key2: { etc }
{ model_type1:
{ model_name1: {'status': 'active'|'cached'|'not loaded',
'model_name' : name,
'model_type' : SDModelType,
'description': description,
'format': 'folder'|'safetensors'|'ckpt'
},
model_name2: { etc }
},
model_type2:
{ model_name_n: etc
}
"""
pass
@ -385,15 +390,21 @@ class ModelManagerService(ModelManagerServiceBase):
"""
return self.mgr.model_names()
def list_models(self) -> dict:
def list_models(self, model_type: SDModelType=None) -> dict:
"""
Return a dict of models in the format:
{ model_key: {'status': 'active'|'cached'|'not loaded',
'model_name' : name,
'model_type' : SDModelType,
'description': description,
'format': 'folder'|'safetensors'|'ckpt'
},
{ model_type1:
{ model_name1: {'status': 'active'|'cached'|'not loaded',
'model_name' : name,
'model_type' : SDModelType,
'description': description,
'format': 'folder'|'safetensors'|'ckpt'
},
model_name2: { etc }
},
model_type2:
{ model_name_n: etc
}
"""
return self.mgr.list_models()