add back the heuristic_import() method and extend repo_ids to arbitrary file paths

This commit is contained in:
Lincoln Stein
2024-02-11 23:37:49 -05:00
committed by psychedelicious
parent 94e8d1b6d5
commit 13a9ea35b5
6 changed files with 199 additions and 12 deletions

View File

@ -251,9 +251,75 @@ async def add_model_record(
return result
@model_manager_v2_router.post(
"/heuristic_import",
operation_id="heuristic_import_model",
responses={
201: {"description": "The model imported successfully"},
415: {"description": "Unrecognized file/folder format"},
424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
409: {"description": "There is already a model corresponding to this path or repo_id"},
},
status_code=201,
)
async def heuristic_import(
source: str,
config: Optional[Dict[str, Any]] = Body(
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
default=None,
),
access_token: Optional[str] = None,
) -> ModelInstallJob:
"""Install a model using a string identifier.
`source` can be any of the following.
1. A path on the local filesystem ('C:\\users\\fred\\model.safetensors')
2. A Url pointing to a single downloadable model file
3. A HuggingFace repo_id with any of the following formats:
- model/name
- model/name:fp16:vae
- model/name::vae -- use default precision
- model/name:fp16:path/to/model.safetensors
- model/name::path/to/model.safetensors
`config` is an optional dict containing model configuration values that will override
the ones that are probed automatically.
`access_token` is an optional access token for use with Urls that require
authentication.
Models will be downloaded, probed, configured and installed in a
series of background threads. The return object has `status` attribute
that can be used to monitor progress.
See the documentation for `import_model_record` for more information on
interpreting the job information returned by this route.
"""
logger = ApiDependencies.invoker.services.logger
try:
installer = ApiDependencies.invoker.services.model_manager.install
result: ModelInstallJob = installer.heuristic_import(
source=source,
config=config,
)
logger.info(f"Started installation of {source}")
except UnknownModelException as e:
logger.error(str(e))
raise HTTPException(status_code=424, detail=str(e))
except InvalidModelException as e:
logger.error(str(e))
raise HTTPException(status_code=415)
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
return result
@model_manager_v2_router.post(
"/import",
operation_id="import_model_record",
operation_id="import_model",
responses={
201: {"description": "The model imported successfully"},
415: {"description": "Unrecognized file/folder format"},
@ -269,7 +335,7 @@ async def import_model(
default=None,
),
) -> ModelInstallJob:
"""Add a model using its local path, repo_id, or remote URL.
"""Install a model using its local path, repo_id, or remote URL.
Models will be downloaded, probed, configured and installed in a
series of background threads. The return object has `status` attribute