mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
allow priority to be set at install job submission time
This commit is contained in:
@ -134,6 +134,10 @@ async def import_model(
|
||||
prediction_type: Optional[Literal["v_prediction", "epsilon", "sample"]] = Body(
|
||||
description="Prediction type for SDv2 checkpoint files", default="v_prediction"
|
||||
),
|
||||
priority: Optional[int] = Body(
|
||||
description="Which import jobs run first. Lower values run before higher ones.",
|
||||
default=10,
|
||||
),
|
||||
) -> ModelImportStatus:
|
||||
"""
|
||||
Add a model using its local path, repo_id, or remote URL.
|
||||
@ -142,6 +146,13 @@ async def import_model(
|
||||
series of background threads. The return object has a `job_id` property
|
||||
that can be used to control the download job.
|
||||
|
||||
The priority controls which import jobs run first. Lower values run before
|
||||
higher ones.
|
||||
|
||||
The prediction_type applies to SDv2 models only and can be one of
|
||||
"v_prediction", "epsilon", or "sample". Default if not provided is
|
||||
"v_prediction".
|
||||
|
||||
Listen on the event bus for a series of `model_event` events with an `id`
|
||||
matching the returned job id to get the progress, completion status, errors,
|
||||
and information on the model that was installed.
|
||||
@ -149,7 +160,9 @@ async def import_model(
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.model_manager.install_model(
|
||||
location, model_attributes={"prediction_type": SchedulerPredictionType(prediction_type)}
|
||||
location,
|
||||
model_attributes={"prediction_type": SchedulerPredictionType(prediction_type)},
|
||||
priority=priority,
|
||||
)
|
||||
return ModelImportStatus(
|
||||
job_id=result.id,
|
||||
|
@ -198,11 +198,13 @@ class ModelManagerServiceBase(ABC):
|
||||
self,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
model_attributes: Optional[Dict[str, Any]] = None,
|
||||
priority: Optional[int] = 10,
|
||||
) -> ModelInstallJob:
|
||||
"""Import a path, repo_id or URL. Returns an ModelInstallJob.
|
||||
|
||||
:param model_attributes: Additional attributes to supplement/override
|
||||
the model information gained from automated probing.
|
||||
:param priority: Queue priority. Lower values have higher priority.
|
||||
|
||||
Typical usage:
|
||||
job = model_manager.install(
|
||||
@ -413,6 +415,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
self,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
model_attributes: Optional[Dict[str, Any]] = None,
|
||||
priority: Optional[int] = 10,
|
||||
) -> ModelInstallJob:
|
||||
"""
|
||||
Add a model using a path, repo_id or URL.
|
||||
@ -420,6 +423,8 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
:param model_attributes: Dictionary of ModelConfigBase fields to
|
||||
attach to the model. When installing a URL or repo_id, some metadata
|
||||
values, such as `tags` will be automagically added.
|
||||
:param priority: Queue priority for this install job. Lower value jobs
|
||||
will run before higher value ones.
|
||||
"""
|
||||
self.logger.debug(f"add model {source}")
|
||||
variant = "fp16" if self._loader.precision == "float16" else None
|
||||
@ -427,6 +432,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
source,
|
||||
probe_override=model_attributes,
|
||||
variant=variant,
|
||||
priority=priority,
|
||||
)
|
||||
|
||||
def list_install_jobs(self) -> List[ModelInstallJob]:
|
||||
|
@ -135,8 +135,9 @@ class DownloadQueueBase(ABC):
|
||||
self,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
destdir: Path,
|
||||
filename: Optional[Path] = None,
|
||||
start: bool = True,
|
||||
priority: int = 10,
|
||||
filename: Optional[Path] = None,
|
||||
variant: Optional[str] = None,
|
||||
access_token: Optional[str] = None,
|
||||
event_handlers: Optional[List[DownloadEventHandler]] = None,
|
||||
@ -146,6 +147,7 @@ class DownloadQueueBase(ABC):
|
||||
|
||||
:param source: Source of the download - URL, repo_id or Path
|
||||
:param destdir: Directory to download into.
|
||||
:param priority: Initial priority for this job [10]
|
||||
:param filename: Optional name of file, if not provided
|
||||
will use the content-disposition field to assign the name.
|
||||
:param start: Immediately start job [True]
|
||||
|
@ -116,15 +116,14 @@ class DownloadQueue(DownloadQueueBase):
|
||||
self,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
destdir: Path,
|
||||
filename: Optional[Path] = None,
|
||||
start: bool = True,
|
||||
priority: int = 10,
|
||||
filename: Optional[Path] = None,
|
||||
variant: Optional[str] = None,
|
||||
access_token: Optional[str] = None,
|
||||
event_handlers: Optional[List[DownloadEventHandler]] = None,
|
||||
) -> DownloadJobBase:
|
||||
"""
|
||||
Create a download job and return its ID.
|
||||
"""
|
||||
"""Create a download job and return its ID."""
|
||||
kwargs = dict()
|
||||
|
||||
if Path(source).exists():
|
||||
@ -142,6 +141,7 @@ class DownloadQueue(DownloadQueueBase):
|
||||
destination=Path(destdir) / (filename or "."),
|
||||
access_token=access_token,
|
||||
event_handlers=(event_handlers or self._event_handlers),
|
||||
priority=priority,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
@ -492,10 +492,11 @@ class ModelInstall(ModelInstallBase):
|
||||
probe_override: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[ModelSourceMetadata] = None,
|
||||
access_token: Optional[str] = None,
|
||||
priority: Optional[int] = 10,
|
||||
) -> DownloadJobBase: # noqa D102
|
||||
queue = self._download_queue
|
||||
|
||||
job = self._make_download_job(source, variant, access_token)
|
||||
job = self._make_download_job(source, variant=variant, access_token=access_token, priority=priority)
|
||||
handler = (
|
||||
self._complete_registration_handler
|
||||
if inplace and Path(source).exists()
|
||||
@ -581,6 +582,7 @@ class ModelInstall(ModelInstallBase):
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
variant: Optional[str] = None,
|
||||
access_token: Optional[str] = None,
|
||||
priority: Optional[int] = 10,
|
||||
) -> DownloadJobBase:
|
||||
# Clean up a common source of error. Doesn't work with Paths.
|
||||
if isinstance(source, str):
|
||||
@ -606,7 +608,9 @@ class ModelInstall(ModelInstallBase):
|
||||
kwargs = {}
|
||||
else:
|
||||
raise ValueError(f"'{source}' is not recognized as a local file, directory, repo_id or URL")
|
||||
return cls(source=source, destination=Path(self._tmpdir.name), access_token=access_token, **kwargs)
|
||||
return cls(
|
||||
source=source, destination=Path(self._tmpdir.name), access_token=access_token, priority=priority, **kwargs
|
||||
)
|
||||
|
||||
def wait_for_installs(self) -> Dict[str, str]: # noqa D102
|
||||
self._download_queue.join()
|
||||
|
Reference in New Issue
Block a user