update_model working

This commit is contained in:
Lincoln Stein 2023-07-04 17:26:57 -04:00
parent 752b4d50cf
commit 5d099f4a49
4 changed files with 24 additions and 62 deletions

View File

@ -8,79 +8,33 @@ from pydantic import BaseModel, Field, parse_obj_as
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
from invokeai.backend import BaseModelType, ModelType from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management import AddModelResult from invokeai.backend.model_management import AddModelResult
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS, SchedulerPredictionType from invokeai.backend.model_management.models import MODEL_CONFIGS, OPENAPI_MODEL_CONFIGS, SchedulerPredictionType
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
models_router = APIRouter(prefix="/v1/models", tags=["models"]) models_router = APIRouter(prefix="/v1/models", tags=["models"])
class VaeRepo(BaseModel):
repo_id: str = Field(description="The repo ID to use for this VAE")
path: Optional[str] = Field(description="The path to the VAE")
subfolder: Optional[str] = Field(description="The subfolder to use for this VAE")
class ModelInfo(BaseModel):
description: Optional[str] = Field(description="A description of the model")
model_name: str = Field(description="The name of the model")
model_type: str = Field(description="The type of the model")
class DiffusersModelInfo(ModelInfo):
format: Literal['folder'] = 'folder'
vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
path: Optional[str] = Field(description="The path to the model")
class CkptModelInfo(ModelInfo):
format: Literal['ckpt'] = 'ckpt'
config: str = Field(description="The path to the model config")
weights: str = Field(description="The path to the model weights")
vae: str = Field(description="The path to the model VAE")
width: Optional[int] = Field(description="The width of the model")
height: Optional[int] = Field(description="The height of the model")
class SafetensorsModelInfo(CkptModelInfo):
format: Literal['safetensors'] = 'safetensors'
class CreateModelRequest(BaseModel):
name: str = Field(description="The name of the model")
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
class CreateModelResponse(BaseModel): class CreateModelResponse(BaseModel):
name: str = Field(description="The name of the new model") model_name: str = Field(description="The name of the new model")
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info") info: Union[tuple(MODEL_CONFIGS)] = Field(description="The model info")
status: str = Field(description="The status of the API response") status: str = Field(description="The status of the API response")
class ImportModelResponse(BaseModel): class ImportModelResponse(BaseModel):
name: str = Field(description="The name of the imported model") name: str = Field(description="The name of the imported model")
# base_model: str = Field(description="The base model")
# model_type: str = Field(description="The model type")
info: AddModelResult = Field(description="The model info") info: AddModelResult = Field(description="The model info")
status: str = Field(description="The status of the API response") status: str = Field(description="The status of the API response")
class ConversionRequest(BaseModel):
name: str = Field(description="The name of the new model")
info: CkptModelInfo = Field(description="The converted model info")
save_location: str = Field(description="The path to save the converted model weights")
class ConvertedModelResponse(BaseModel):
name: str = Field(description="The name of the new model")
info: DiffusersModelInfo = Field(description="The converted model info")
class ModelsList(BaseModel): class ModelsList(BaseModel):
models: list[MODEL_CONFIGS] models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
@models_router.get( @models_router.get(
"/", "/{base_model}/{model_type}",
operation_id="list_models", operation_id="list_models",
responses={200: {"model": ModelsList }}, responses={200: {"model": ModelsList }},
) )
async def list_models( async def list_models(
base_model: Optional[BaseModelType] = Query( base_model: Optional[BaseModelType] = Path(
default=None, description="Base model" default=None, description="Base model"
), ),
model_type: Optional[ModelType] = Query( model_type: Optional[ModelType] = Path(
default=None, description="The type of model to get" default=None, description="The type of model to get"
), ),
) -> ModelsList: ) -> ModelsList:
@ -90,23 +44,28 @@ async def list_models(
return models return models
@models_router.post( @models_router.post(
"/", "/{base_model}/{model_type}/{model_name}",
operation_id="update_model", operation_id="update_model",
responses={200: {"status": "success"}}, responses={200: {"status": "success"}},
) )
async def update_model( async def update_model(
model_request: CreateModelRequest base_model: BaseModelType = Path(default='sd-1', description="Base model"),
model_type: ModelType = Path(default='main', description="The type of model"),
model_name: str = Path(default=None, description="model name"),
info: Union[tuple(MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> CreateModelResponse: ) -> CreateModelResponse:
""" Add Model """ """ Add Model """
model_request_info = model_request.info
info_dict = model_request_info.dict()
model_response = CreateModelResponse(name=model_request.name, info=model_request.info, status="success")
ApiDependencies.invoker.services.model_manager.add_model( ApiDependencies.invoker.services.model_manager.add_model(
model_name=model_request.name, model_name=model_name,
model_attributes=info_dict, base_model=base_model,
model_type=model_type,
model_attributes=info.dict(),
clobber=True, clobber=True,
) )
model_response = CreateModelResponse(
model_name = model_name,
info = info,
status="success")
return model_response return model_response

View File

@ -318,6 +318,7 @@ class ModelManagerService(ModelManagerServiceBase):
with an assertion error if provided attributes are incorrect or with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk. the model name is missing. Call commit() to write changes to disk.
""" """
self.logger.debug(f'add/update model {model_name}')
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber) return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
@ -332,6 +333,7 @@ class ModelManagerService(ModelManagerServiceBase):
then the underlying weight file or diffusers directory will be deleted then the underlying weight file or diffusers directory will be deleted
as well. Call commit() to write to disk. as well. Call commit() to write to disk.
""" """
self.logger.debug(f'delete model {model_name}')
self.mgr.del_model(model_name, base_model, model_type) self.mgr.del_model(model_name, base_model, model_type)

View File

@ -205,7 +205,7 @@ class ModelInstall(object):
self.heuristic_import(child, models_installed=models_installed) self.heuristic_import(child, models_installed=models_installed)
# huggingface repo # huggingface repo
elif str(model_path_id_or_url).split('/') == 2: elif len(str(model_path_id_or_url).split('/')) == 2:
models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))}) models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))})
# a URL # a URL

View File

@ -612,6 +612,7 @@ class ModelManager(object):
self.cache.uncache_model(cache_id) self.cache.uncache_model(cache_id)
self.models[model_key] = model_config self.models[model_key] = model_config
self.commit()
return AddModelResult( return AddModelResult(
name = model_name, name = model_name,
model_type = model_type, model_type = model_type,