allow priority to be set at install job submission time

This commit is contained in:
Lincoln Stein
2023-09-24 10:08:21 -04:00
parent 6edee2d22b
commit 8bc1ca046c
5 changed files with 33 additions and 8 deletions

View File

@ -134,6 +134,10 @@ async def import_model(
prediction_type: Optional[Literal["v_prediction", "epsilon", "sample"]] = Body(
description="Prediction type for SDv2 checkpoint files", default="v_prediction"
),
priority: Optional[int] = Body(
description="Which import jobs run first. Lower values run before higher ones.",
default=10,
),
) -> ModelImportStatus:
"""
Add a model using its local path, repo_id, or remote URL.
@ -142,6 +146,13 @@ async def import_model(
series of background threads. The return object has a `job_id` property
that can be used to control the download job.
The priority controls which import jobs run first. Lower values run before
higher ones.
The prediction_type applies to SDv2 models only and can be one of
"v_prediction", "epsilon", or "sample". Default if not provided is
"v_prediction".
Listen on the event bus for a series of `model_event` events with an `id`
matching the returned job id to get the progress, completion status, errors,
and information on the model that was installed.
@ -149,7 +160,9 @@ async def import_model(
logger = ApiDependencies.invoker.services.logger
try:
result = ApiDependencies.invoker.services.model_manager.install_model(
location, model_attributes={"prediction_type": SchedulerPredictionType(prediction_type)}
location,
model_attributes={"prediction_type": SchedulerPredictionType(prediction_type)},
priority=priority,
)
return ModelImportStatus(
job_id=result.id,

View File

@ -198,11 +198,13 @@ class ModelManagerServiceBase(ABC):
self,
source: Union[str, Path, AnyHttpUrl],
model_attributes: Optional[Dict[str, Any]] = None,
priority: Optional[int] = 10,
) -> ModelInstallJob:
"""Import a path, repo_id or URL. Returns a ModelInstallJob.
:param model_attributes: Additional attributes to supplement/override
the model information gained from automated probing.
:param priority: Queue priority. Lower values have higher priority.
Typical usage:
job = model_manager.install(
@ -413,6 +415,7 @@ class ModelManagerService(ModelManagerServiceBase):
self,
source: Union[str, Path, AnyHttpUrl],
model_attributes: Optional[Dict[str, Any]] = None,
priority: Optional[int] = 10,
) -> ModelInstallJob:
"""
Add a model using a path, repo_id or URL.
@ -420,6 +423,8 @@ class ModelManagerService(ModelManagerServiceBase):
:param model_attributes: Dictionary of ModelConfigBase fields to
attach to the model. When installing a URL or repo_id, some metadata
values, such as `tags`, will be automagically added.
:param priority: Queue priority for this install job. Lower value jobs
will run before higher value ones.
"""
self.logger.debug(f"add model {source}")
variant = "fp16" if self._loader.precision == "float16" else None
@ -427,6 +432,7 @@ class ModelManagerService(ModelManagerServiceBase):
source,
probe_override=model_attributes,
variant=variant,
priority=priority,
)
def list_install_jobs(self) -> List[ModelInstallJob]:

View File

@ -135,8 +135,9 @@ class DownloadQueueBase(ABC):
self,
source: Union[str, Path, AnyHttpUrl],
destdir: Path,
filename: Optional[Path] = None,
start: bool = True,
priority: int = 10,
filename: Optional[Path] = None,
variant: Optional[str] = None,
access_token: Optional[str] = None,
event_handlers: Optional[List[DownloadEventHandler]] = None,
@ -146,6 +147,7 @@ class DownloadQueueBase(ABC):
:param source: Source of the download - URL, repo_id or Path
:param destdir: Directory to download into.
:param priority: Initial priority for this job [10]
:param filename: Optional name of file, if not provided
will use the content-disposition field to assign the name.
:param start: Immediately start job [True]

View File

@ -116,15 +116,14 @@ class DownloadQueue(DownloadQueueBase):
self,
source: Union[str, Path, AnyHttpUrl],
destdir: Path,
filename: Optional[Path] = None,
start: bool = True,
priority: int = 10,
filename: Optional[Path] = None,
variant: Optional[str] = None,
access_token: Optional[str] = None,
event_handlers: Optional[List[DownloadEventHandler]] = None,
) -> DownloadJobBase:
"""
Create a download job and return its ID.
"""
"""Create a download job and return its ID."""
kwargs = dict()
if Path(source).exists():
@ -142,6 +141,7 @@ class DownloadQueue(DownloadQueueBase):
destination=Path(destdir) / (filename or "."),
access_token=access_token,
event_handlers=(event_handlers or self._event_handlers),
priority=priority,
**kwargs,
)

View File

@ -492,10 +492,11 @@ class ModelInstall(ModelInstallBase):
probe_override: Optional[Dict[str, Any]] = None,
metadata: Optional[ModelSourceMetadata] = None,
access_token: Optional[str] = None,
priority: Optional[int] = 10,
) -> DownloadJobBase: # noqa D102
queue = self._download_queue
job = self._make_download_job(source, variant, access_token)
job = self._make_download_job(source, variant=variant, access_token=access_token, priority=priority)
handler = (
self._complete_registration_handler
if inplace and Path(source).exists()
@ -581,6 +582,7 @@ class ModelInstall(ModelInstallBase):
source: Union[str, Path, AnyHttpUrl],
variant: Optional[str] = None,
access_token: Optional[str] = None,
priority: Optional[int] = 10,
) -> DownloadJobBase:
# Clean up a common source of error. Doesn't work with Paths.
if isinstance(source, str):
@ -606,7 +608,9 @@ class ModelInstall(ModelInstallBase):
kwargs = {}
else:
raise ValueError(f"'{source}' is not recognized as a local file, directory, repo_id or URL")
return cls(source=source, destination=Path(self._tmpdir.name), access_token=access_token, **kwargs)
return cls(
source=source, destination=Path(self._tmpdir.name), access_token=access_token, priority=priority, **kwargs
)
def wait_for_installs(self) -> Dict[str, str]: # noqa D102
self._download_queue.join()