update_model working

This commit is contained in:
Lincoln Stein 2023-07-04 17:26:57 -04:00
parent 752b4d50cf
commit 5d099f4a49
4 changed files with 24 additions and 62 deletions

View File

@ -8,79 +8,33 @@ from pydantic import BaseModel, Field, parse_obj_as
from ..dependencies import ApiDependencies
from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management import AddModelResult
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS, SchedulerPredictionType
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
from invokeai.backend.model_management.models import MODEL_CONFIGS, OPENAPI_MODEL_CONFIGS, SchedulerPredictionType
models_router = APIRouter(prefix="/v1/models", tags=["models"])
class VaeRepo(BaseModel):
repo_id: str = Field(description="The repo ID to use for this VAE")
path: Optional[str] = Field(description="The path to the VAE")
subfolder: Optional[str] = Field(description="The subfolder to use for this VAE")
class ModelInfo(BaseModel):
description: Optional[str] = Field(description="A description of the model")
model_name: str = Field(description="The name of the model")
model_type: str = Field(description="The type of the model")
class DiffusersModelInfo(ModelInfo):
format: Literal['folder'] = 'folder'
vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
path: Optional[str] = Field(description="The path to the model")
class CkptModelInfo(ModelInfo):
format: Literal['ckpt'] = 'ckpt'
config: str = Field(description="The path to the model config")
weights: str = Field(description="The path to the model weights")
vae: str = Field(description="The path to the model VAE")
width: Optional[int] = Field(description="The width of the model")
height: Optional[int] = Field(description="The height of the model")
class SafetensorsModelInfo(CkptModelInfo):
format: Literal['safetensors'] = 'safetensors'
class CreateModelRequest(BaseModel):
name: str = Field(description="The name of the model")
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
class CreateModelResponse(BaseModel):
name: str = Field(description="The name of the new model")
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
model_name: str = Field(description="The name of the new model")
info: Union[tuple(MODEL_CONFIGS)] = Field(description="The model info")
status: str = Field(description="The status of the API response")
class ImportModelResponse(BaseModel):
name: str = Field(description="The name of the imported model")
# base_model: str = Field(description="The base model")
# model_type: str = Field(description="The model type")
info: AddModelResult = Field(description="The model info")
status: str = Field(description="The status of the API response")
class ConversionRequest(BaseModel):
name: str = Field(description="The name of the new model")
info: CkptModelInfo = Field(description="The converted model info")
save_location: str = Field(description="The path to save the converted model weights")
class ConvertedModelResponse(BaseModel):
name: str = Field(description="The name of the new model")
info: DiffusersModelInfo = Field(description="The converted model info")
class ModelsList(BaseModel):
models: list[MODEL_CONFIGS]
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
@models_router.get(
"/",
"/{base_model}/{model_type}",
operation_id="list_models",
responses={200: {"model": ModelsList }},
)
async def list_models(
base_model: Optional[BaseModelType] = Query(
base_model: Optional[BaseModelType] = Path(
default=None, description="Base model"
),
model_type: Optional[ModelType] = Query(
model_type: Optional[ModelType] = Path(
default=None, description="The type of model to get"
),
) -> ModelsList:
@ -90,23 +44,28 @@ async def list_models(
return models
@models_router.post(
"/",
"/{base_model}/{model_type}/{model_name}",
operation_id="update_model",
responses={200: {"status": "success"}},
)
async def update_model(
model_request: CreateModelRequest
base_model: BaseModelType = Path(default='sd-1', description="Base model"),
model_type: ModelType = Path(default='main', description="The type of model"),
model_name: str = Path(default=None, description="model name"),
info: Union[tuple(MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> CreateModelResponse:
""" Add Model """
model_request_info = model_request.info
info_dict = model_request_info.dict()
model_response = CreateModelResponse(name=model_request.name, info=model_request.info, status="success")
ApiDependencies.invoker.services.model_manager.add_model(
model_name=model_request.name,
model_attributes=info_dict,
model_name=model_name,
base_model=base_model,
model_type=model_type,
model_attributes=info.dict(),
clobber=True,
)
model_response = CreateModelResponse(
model_name = model_name,
info = info,
status="success")
return model_response

View File

@ -318,6 +318,7 @@ class ModelManagerService(ModelManagerServiceBase):
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
self.logger.debug(f'add/update model {model_name}')
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
@ -332,6 +333,7 @@ class ModelManagerService(ModelManagerServiceBase):
then the underlying weight file or diffusers directory will be deleted
as well. Call commit() to write to disk.
"""
self.logger.debug(f'delete model {model_name}')
self.mgr.del_model(model_name, base_model, model_type)

View File

@ -205,7 +205,7 @@ class ModelInstall(object):
self.heuristic_import(child, models_installed=models_installed)
# huggingface repo
elif str(model_path_id_or_url).split('/') == 2:
elif len(str(model_path_id_or_url).split('/')) == 2:
models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))})
# a URL

View File

@ -612,6 +612,7 @@ class ModelManager(object):
self.cache.uncache_model(cache_id)
self.models[model_key] = model_config
self.commit()
return AddModelResult(
name = model_name,
model_type = model_type,