fix(api): make list models params querys, make path /, remove defaults

The list models route should just be the base route path, and should use query parameters as opposed to path parameters (which cannot be optional)

Removed defaults for update model route - for the purposes of the API, we should always be explicit with this
This commit is contained in:
psychedelicious 2023-07-06 15:34:50 +10:00
parent 8f5fcb188c
commit c21245f590

View File

@ -26,17 +26,13 @@ class ModelsList(BaseModel):
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]] models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
@models_router.get( @models_router.get(
"/{base_model}/{model_type}", "/",
operation_id="list_models", operation_id="list_models",
responses={200: {"model": ModelsList }}, responses={200: {"model": ModelsList }},
) )
async def list_models( async def list_models(
base_model: Optional[BaseModelType] = Path( base_model: Optional[BaseModelType] = Query(default=None, description="Base model"),
default=None, description="Base model" model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
),
model_type: Optional[ModelType] = Path(
default=None, description="The type of model to get"
),
) -> ModelsList: ) -> ModelsList:
"""Gets a list of models""" """Gets a list of models"""
models_raw = ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type) models_raw = ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type)
@ -54,10 +50,10 @@ async def list_models(
response_model = UpdateModelResponse, response_model = UpdateModelResponse,
) )
async def update_model( async def update_model(
base_model: BaseModelType = Path(default='sd-1', description="Base model"), base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(default='main', description="The type of model"), model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(default=None, description="model name"), model_name: str = Path(description="model name"),
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"), info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> UpdateModelResponse: ) -> UpdateModelResponse:
""" Add Model """ """ Add Model """
try: try:
@ -194,56 +190,3 @@ async def convert_model(
except ValueError as e: except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
return response return response
# @socketio.on("mergeDiffusersModels")
# def merge_diffusers_models(model_merge_info: dict):
# try:
# models_to_merge = model_merge_info["models_to_merge"]
# model_ids_or_paths = [
# self.generate.model_manager.model_name_or_path(x)
# for x in models_to_merge
# ]
# merged_pipe = merge_diffusion_models(
# model_ids_or_paths,
# model_merge_info["alpha"],
# model_merge_info["interp"],
# model_merge_info["force"],
# )
# dump_path = global_models_dir() / "merged_models"
# if model_merge_info["model_merge_save_path"] is not None:
# dump_path = Path(model_merge_info["model_merge_save_path"])
# os.makedirs(dump_path, exist_ok=True)
# dump_path = dump_path / model_merge_info["merged_model_name"]
# merged_pipe.save_pretrained(dump_path, safe_serialization=1)
# merged_model_config = dict(
# model_name=model_merge_info["merged_model_name"],
# description=f'Merge of models {", ".join(models_to_merge)}',
# commit_to_conf=opt.conf,
# )
# if vae := self.generate.model_manager.config[models_to_merge[0]].get(
# "vae", None
# ):
# print(f">> Using configured VAE assigned to {models_to_merge[0]}")
# merged_model_config.update(vae=vae)
# self.generate.model_manager.import_diffuser_model(
# dump_path, **merged_model_config
# )
# new_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "modelsMerged",
# {
# "merged_models": models_to_merge,
# "merged_model_name": model_merge_info["merged_model_name"],
# "model_list": new_model_list,
# "update": True,
# },
# )
# print(f">> Models Merged: {models_to_merge}")
# print(f">> New Model Added: {model_merge_info['merged_model_name']}")
# except Exception as e: