mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
update_model working
This commit is contained in:
parent
752b4d50cf
commit
5d099f4a49
@ -8,79 +8,33 @@ from pydantic import BaseModel, Field, parse_obj_as
|
|||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
from invokeai.backend import BaseModelType, ModelType
|
from invokeai.backend import BaseModelType, ModelType
|
||||||
from invokeai.backend.model_management import AddModelResult
|
from invokeai.backend.model_management import AddModelResult
|
||||||
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS, SchedulerPredictionType
|
from invokeai.backend.model_management.models import MODEL_CONFIGS, OPENAPI_MODEL_CONFIGS, SchedulerPredictionType
|
||||||
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
|
||||||
|
|
||||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||||
|
|
||||||
class VaeRepo(BaseModel):
|
|
||||||
repo_id: str = Field(description="The repo ID to use for this VAE")
|
|
||||||
path: Optional[str] = Field(description="The path to the VAE")
|
|
||||||
subfolder: Optional[str] = Field(description="The subfolder to use for this VAE")
|
|
||||||
|
|
||||||
class ModelInfo(BaseModel):
|
|
||||||
description: Optional[str] = Field(description="A description of the model")
|
|
||||||
model_name: str = Field(description="The name of the model")
|
|
||||||
model_type: str = Field(description="The type of the model")
|
|
||||||
|
|
||||||
class DiffusersModelInfo(ModelInfo):
|
|
||||||
format: Literal['folder'] = 'folder'
|
|
||||||
|
|
||||||
vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
|
|
||||||
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
|
|
||||||
path: Optional[str] = Field(description="The path to the model")
|
|
||||||
|
|
||||||
class CkptModelInfo(ModelInfo):
|
|
||||||
format: Literal['ckpt'] = 'ckpt'
|
|
||||||
|
|
||||||
config: str = Field(description="The path to the model config")
|
|
||||||
weights: str = Field(description="The path to the model weights")
|
|
||||||
vae: str = Field(description="The path to the model VAE")
|
|
||||||
width: Optional[int] = Field(description="The width of the model")
|
|
||||||
height: Optional[int] = Field(description="The height of the model")
|
|
||||||
|
|
||||||
class SafetensorsModelInfo(CkptModelInfo):
|
|
||||||
format: Literal['safetensors'] = 'safetensors'
|
|
||||||
|
|
||||||
class CreateModelRequest(BaseModel):
|
|
||||||
name: str = Field(description="The name of the model")
|
|
||||||
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
|
|
||||||
|
|
||||||
class CreateModelResponse(BaseModel):
|
class CreateModelResponse(BaseModel):
|
||||||
name: str = Field(description="The name of the new model")
|
model_name: str = Field(description="The name of the new model")
|
||||||
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
|
info: Union[tuple(MODEL_CONFIGS)] = Field(description="The model info")
|
||||||
status: str = Field(description="The status of the API response")
|
status: str = Field(description="The status of the API response")
|
||||||
|
|
||||||
class ImportModelResponse(BaseModel):
|
class ImportModelResponse(BaseModel):
|
||||||
name: str = Field(description="The name of the imported model")
|
name: str = Field(description="The name of the imported model")
|
||||||
# base_model: str = Field(description="The base model")
|
|
||||||
# model_type: str = Field(description="The model type")
|
|
||||||
info: AddModelResult = Field(description="The model info")
|
info: AddModelResult = Field(description="The model info")
|
||||||
status: str = Field(description="The status of the API response")
|
status: str = Field(description="The status of the API response")
|
||||||
|
|
||||||
class ConversionRequest(BaseModel):
|
|
||||||
name: str = Field(description="The name of the new model")
|
|
||||||
info: CkptModelInfo = Field(description="The converted model info")
|
|
||||||
save_location: str = Field(description="The path to save the converted model weights")
|
|
||||||
|
|
||||||
class ConvertedModelResponse(BaseModel):
|
|
||||||
name: str = Field(description="The name of the new model")
|
|
||||||
info: DiffusersModelInfo = Field(description="The converted model info")
|
|
||||||
|
|
||||||
class ModelsList(BaseModel):
|
class ModelsList(BaseModel):
|
||||||
models: list[MODEL_CONFIGS]
|
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
|
||||||
|
|
||||||
|
|
||||||
@models_router.get(
|
@models_router.get(
|
||||||
"/",
|
"/{base_model}/{model_type}",
|
||||||
operation_id="list_models",
|
operation_id="list_models",
|
||||||
responses={200: {"model": ModelsList }},
|
responses={200: {"model": ModelsList }},
|
||||||
)
|
)
|
||||||
async def list_models(
|
async def list_models(
|
||||||
base_model: Optional[BaseModelType] = Query(
|
base_model: Optional[BaseModelType] = Path(
|
||||||
default=None, description="Base model"
|
default=None, description="Base model"
|
||||||
),
|
),
|
||||||
model_type: Optional[ModelType] = Query(
|
model_type: Optional[ModelType] = Path(
|
||||||
default=None, description="The type of model to get"
|
default=None, description="The type of model to get"
|
||||||
),
|
),
|
||||||
) -> ModelsList:
|
) -> ModelsList:
|
||||||
@ -90,23 +44,28 @@ async def list_models(
|
|||||||
return models
|
return models
|
||||||
|
|
||||||
@models_router.post(
|
@models_router.post(
|
||||||
"/",
|
"/{base_model}/{model_type}/{model_name}",
|
||||||
operation_id="update_model",
|
operation_id="update_model",
|
||||||
responses={200: {"status": "success"}},
|
responses={200: {"status": "success"}},
|
||||||
)
|
)
|
||||||
async def update_model(
|
async def update_model(
|
||||||
model_request: CreateModelRequest
|
base_model: BaseModelType = Path(default='sd-1', description="Base model"),
|
||||||
|
model_type: ModelType = Path(default='main', description="The type of model"),
|
||||||
|
model_name: str = Path(default=None, description="model name"),
|
||||||
|
info: Union[tuple(MODEL_CONFIGS)] = Body(description="Model configuration"),
|
||||||
) -> CreateModelResponse:
|
) -> CreateModelResponse:
|
||||||
""" Add Model """
|
""" Add Model """
|
||||||
model_request_info = model_request.info
|
|
||||||
info_dict = model_request_info.dict()
|
|
||||||
model_response = CreateModelResponse(name=model_request.name, info=model_request.info, status="success")
|
|
||||||
|
|
||||||
ApiDependencies.invoker.services.model_manager.add_model(
|
ApiDependencies.invoker.services.model_manager.add_model(
|
||||||
model_name=model_request.name,
|
model_name=model_name,
|
||||||
model_attributes=info_dict,
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
|
model_attributes=info.dict(),
|
||||||
clobber=True,
|
clobber=True,
|
||||||
)
|
)
|
||||||
|
model_response = CreateModelResponse(
|
||||||
|
model_name = model_name,
|
||||||
|
info = info,
|
||||||
|
status="success")
|
||||||
|
|
||||||
return model_response
|
return model_response
|
||||||
|
|
||||||
|
@ -318,6 +318,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
with an assertion error if provided attributes are incorrect or
|
with an assertion error if provided attributes are incorrect or
|
||||||
the model name is missing. Call commit() to write changes to disk.
|
the model name is missing. Call commit() to write changes to disk.
|
||||||
"""
|
"""
|
||||||
|
self.logger.debug(f'add/update model {model_name}')
|
||||||
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
|
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
|
||||||
|
|
||||||
|
|
||||||
@ -332,6 +333,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
then the underlying weight file or diffusers directory will be deleted
|
then the underlying weight file or diffusers directory will be deleted
|
||||||
as well. Call commit() to write to disk.
|
as well. Call commit() to write to disk.
|
||||||
"""
|
"""
|
||||||
|
self.logger.debug(f'delete model {model_name}')
|
||||||
self.mgr.del_model(model_name, base_model, model_type)
|
self.mgr.del_model(model_name, base_model, model_type)
|
||||||
|
|
||||||
|
|
||||||
|
@ -205,7 +205,7 @@ class ModelInstall(object):
|
|||||||
self.heuristic_import(child, models_installed=models_installed)
|
self.heuristic_import(child, models_installed=models_installed)
|
||||||
|
|
||||||
# huggingface repo
|
# huggingface repo
|
||||||
elif str(model_path_id_or_url).split('/') == 2:
|
elif len(str(model_path_id_or_url).split('/')) == 2:
|
||||||
models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))})
|
models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))})
|
||||||
|
|
||||||
# a URL
|
# a URL
|
||||||
|
@ -612,6 +612,7 @@ class ModelManager(object):
|
|||||||
self.cache.uncache_model(cache_id)
|
self.cache.uncache_model(cache_id)
|
||||||
|
|
||||||
self.models[model_key] = model_config
|
self.models[model_key] = model_config
|
||||||
|
self.commit()
|
||||||
return AddModelResult(
|
return AddModelResult(
|
||||||
name = model_name,
|
name = model_name,
|
||||||
model_type = model_type,
|
model_type = model_type,
|
||||||
|
Loading…
Reference in New Issue
Block a user