mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add install job control to web API
This commit is contained in:
@ -2,6 +2,7 @@
|
||||
|
||||
|
||||
import pathlib
|
||||
from enum import Enum
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from fastapi import Body, Path, Query, Response
|
||||
@ -15,11 +16,10 @@ from invokeai.backend.model_manager import (
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
ModelConfigBase,
|
||||
ModelInstallJob,
|
||||
SchedulerPredictionType,
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.backend.model_manager.download import DownloadJobStatus
|
||||
from invokeai.backend.model_manager.download import DownloadJobStatus, UnknownJobIDException
|
||||
from invokeai.backend.model_manager.merge import MergeInterpolationMethod
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
@ -47,9 +47,18 @@ class ModelImportStatus(BaseModel):
|
||||
job_id: int
|
||||
source: str
|
||||
priority: int
|
||||
bytes: int
|
||||
total_bytes: int
|
||||
status: DownloadJobStatus
|
||||
|
||||
|
||||
class JobControlOperation(str, Enum):
|
||||
START = "Start"
|
||||
PAUSE = "Pause"
|
||||
CANCEL = "Cancel"
|
||||
CHANGE_PRIORITY = "Change Priority"
|
||||
|
||||
|
||||
@models_router.get(
|
||||
"/",
|
||||
operation_id="list_models",
|
||||
@ -129,20 +138,27 @@ async def import_model(
|
||||
"""
|
||||
Add a model using its local path, repo_id, or remote URL.
|
||||
|
||||
Model characteristics will be probed and configured automatically.
|
||||
The return object is a ModelInstallJob job ID. The work will be
|
||||
performed in the background. Listen on the event bus for a series of
|
||||
`model_event` events with an `id` matching the returned job id to get
|
||||
the progress, completion status, errors, and information on the
|
||||
model that was installed.
|
||||
"""
|
||||
Models will be downloaded, probed, configured and installed in a
|
||||
series of background threads. The return object has a `job_id` property
|
||||
that can be used to control the download job.
|
||||
|
||||
Listen on the event bus for a series of `model_event` events with an `id`
|
||||
matching the returned job id to get the progress, completion status, errors,
|
||||
and information on the model that was installed.
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.model_manager.install_model(
|
||||
location, model_attributes={"prediction_type": SchedulerPredictionType(prediction_type)}
|
||||
)
|
||||
return ModelImportStatus(job_id=result.id, source=result.source, priority=result.priority, status=result.status)
|
||||
return ModelImportStatus(
|
||||
job_id=result.id,
|
||||
source=result.source,
|
||||
priority=result.priority,
|
||||
bytes=result.bytes,
|
||||
total_bytes=result.total_bytes,
|
||||
status=result.status,
|
||||
)
|
||||
except UnknownModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
@ -169,19 +185,24 @@ async def import_model(
|
||||
async def add_model(
|
||||
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
||||
) -> ImportModelResponse:
|
||||
"""Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
|
||||
"""
|
||||
Add a model using the configuration information appropriate for its type. Only local models can be added by path.
|
||||
This call will block until the model is installed.
|
||||
"""
|
||||
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.model_manager.add_model(
|
||||
info.model_name, info.base_model, info.model_type, model_attributes=info.dict()
|
||||
)
|
||||
logger.info(f"Successfully added {info.model_name}")
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=info.model_name, base_model=info.base_model, model_type=info.model_type
|
||||
)
|
||||
return parse_obj_as(ImportModelResponse, model_raw)
|
||||
path = info.path
|
||||
job = ApiDependencies.invoker.services.model_manager.add_model(path)
|
||||
ApiDependencies.invoker.services.model_manager.wait_for_installs()
|
||||
key = job.model_key
|
||||
logger.info(f"Created model {key} for {path}")
|
||||
|
||||
# update with the provided info
|
||||
info_dict = info.dict()
|
||||
info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}
|
||||
new_config = ApiDependencies.invoker.services.model_manager.update_model(key, new_config=info_dict)
|
||||
return parse_obj_as(ImportModelResponse, new_config.dict())
|
||||
except UnknownModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
@ -344,3 +365,92 @@ async def merge_models(
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return response
|
||||
|
||||
|
||||
@models_router.get(
|
||||
"/jobs",
|
||||
operation_id="list_install_jobs",
|
||||
responses={
|
||||
200: {"description": "The control job was updated successfully"},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=List[ModelImportStatus],
|
||||
)
|
||||
async def list_install_jobs() -> List[ModelImportStatus]:
|
||||
"""List active and pending model installation jobs."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
mgr = ApiDependencies.invoker.services.model_manager
|
||||
try:
|
||||
jobs = mgr.list_install_jobs()
|
||||
return [
|
||||
ModelImportStatus(
|
||||
job_id=x.id,
|
||||
source=x.source,
|
||||
priority=x.priority,
|
||||
bytes=x.bytes,
|
||||
total_bytes=x.total_bytes,
|
||||
status=x.status,
|
||||
)
|
||||
for x in jobs
|
||||
]
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@models_router.patch(
|
||||
"/jobs/{job_id}",
|
||||
operation_id="control_install_jobs",
|
||||
responses={
|
||||
200: {"description": "The control job was updated successfully"},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "The job could not be found"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=ModelImportStatus,
|
||||
)
|
||||
async def control_install_jobs(
|
||||
job_id: int = Path(description="Install job_id for start, pause and cancel operations"),
|
||||
operation: JobControlOperation = Body(description="The operation to perform on the job."),
|
||||
priority_delta: Optional[int] = Body(
|
||||
description="Change in job priority for priority operations only. Negative numbers increase priority.",
|
||||
default=None,
|
||||
),
|
||||
) -> ModelImportStatus:
|
||||
"""Start, pause, cancel, or change the run priority of a running model install job."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
mgr = ApiDependencies.invoker.services.model_manager
|
||||
try:
|
||||
job = mgr.id_to_job(job_id)
|
||||
|
||||
if operation == JobControlOperation.START:
|
||||
mgr.start_job(job_id)
|
||||
|
||||
elif operation == JobControlOperation.PAUSE:
|
||||
mgr.pause_job(job_id)
|
||||
|
||||
elif operation == JobControlOperation.CANCEL:
|
||||
mgr.cancel_job(job_id)
|
||||
|
||||
elif operation == JobControlOperation.CHANGE_PRIORITY:
|
||||
mgr.change_job_priority(job_id, priority_delta)
|
||||
else:
|
||||
raise ValueError(f"Unknown operation {JobControlOperation.value}")
|
||||
|
||||
return ModelImportStatus(
|
||||
job_id=job_id,
|
||||
source=job.source,
|
||||
priority=job.priority,
|
||||
status=job.status,
|
||||
bytes=job.bytes,
|
||||
total_bytes=job.total_bytes,
|
||||
)
|
||||
except UnknownJobIDException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
@ -6,7 +6,6 @@ from invokeai.app.models.image import ProgressImage
|
||||
from invokeai.app.services.model_manager_service import BaseModelType, ModelInfo, ModelType, SubModelType
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
from invokeai.backend.model_manager.download import DownloadJobBase
|
||||
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
|
||||
|
@ -7,7 +7,7 @@ from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import Field, parse_obj_as
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
@ -24,6 +24,7 @@ from invokeai.backend.model_manager import (
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.backend.model_manager.cache import CacheStats
|
||||
from invokeai.backend.model_manager.download import DownloadJobBase
|
||||
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
|
||||
|
||||
from .config import InvokeAIAppConfig
|
||||
@ -32,11 +33,16 @@ if TYPE_CHECKING:
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
|
||||
|
||||
# "forward declaration" because of circular import issues
|
||||
class EventServiceBase:
|
||||
pass
|
||||
|
||||
|
||||
class ModelManagerServiceBase(ABC):
|
||||
"""Responsible for managing models on disk and in memory."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, config: InvokeAIAppConfig, event_bus: Optional["EventServiceBase"] = None):
|
||||
def __init__(self, config: InvokeAIAppConfig, event_bus: Optional[EventServiceBase] = None):
|
||||
"""
|
||||
Initialize a ModelManagerService.
|
||||
|
||||
@ -211,6 +217,60 @@ class ModelManagerServiceBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_install_jobs(self) -> List[ModelInstallJob]:
|
||||
"""Return a series of active or enqueued ModelInstallJobs."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def id_to_job(self, id: int) -> ModelInstallJob:
|
||||
"""Return the ModelInstallJob instance corresponding to the given job ID."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def wait_for_installs(self) -> Dict[str, str]:
|
||||
"""
|
||||
Wait for all pending installs to complete.
|
||||
|
||||
This will block until all pending downloads have
|
||||
completed, been cancelled, or errored out. It will
|
||||
block indefinitely if one or more jobs are in the
|
||||
paused state.
|
||||
|
||||
It will return a dict that maps the source model
|
||||
path, URL or repo_id to the ID of the installed model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def start_job(self, job_id: int):
|
||||
"""Start the given install job if it is paused or idle."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pause_job(self, job_id: int):
|
||||
"""Pause the given install job if it is paused or idle."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_job(self, job_id: int):
|
||||
"""Cancel the given install job."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def change_job_priority(self, job_id: int, delta: int):
|
||||
"""
|
||||
Change an install job's priority.
|
||||
|
||||
:param job_id: Job to change
|
||||
:param delta: Value to increment or decrement priority.
|
||||
|
||||
Lower values are higher priority. The default starting value is 10.
|
||||
Thus to make this a really high priority job:
|
||||
manager.change_job_priority(-10).
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def merge_models(
|
||||
self,
|
||||
@ -256,35 +316,6 @@ class ModelManagerServiceBase(ABC):
|
||||
"""Reset model cache statistics for graph with graph_id."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_job(self, job: ModelInstallJob):
|
||||
"""Cancel this job."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pause_job(self, job: ModelInstallJob):
|
||||
"""Pause this job."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def start_job(self, job: ModelInstallJob):
|
||||
"""(re)start this job."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def change_priority(self, job: ModelInstallJob, delta: int):
|
||||
"""
|
||||
Raise or lower the priority of the job.
|
||||
|
||||
:param job: Job to apply change to
|
||||
:param delta: Value to increment or decrement priority.
|
||||
|
||||
Lower values are higher priority. The default starting value is 10.
|
||||
Thus to make my_job a really high priority job:
|
||||
manager.change_priority(my_job, -10).
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
# implementation
|
||||
class ModelManagerService(ModelManagerServiceBase):
|
||||
@ -390,7 +421,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
attach to the model. When installing a URL or repo_id, some metadata
|
||||
values, such as `tags` will be automagically added.
|
||||
"""
|
||||
self.logger.debug(f"add/update model {source}")
|
||||
self.logger.debug(f"add model {source}")
|
||||
variant = "fp16" if self._loader.precision == "float16" else None
|
||||
return self._loader.installer.install(
|
||||
source,
|
||||
@ -398,6 +429,59 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
variant=variant,
|
||||
)
|
||||
|
||||
def list_install_jobs(self) -> List[ModelInstallJob]:
|
||||
"""Return a series of active or enqueued ModelInstallJobs."""
|
||||
queue = self._loader.queue
|
||||
jobs: List[DownloadJobBase] = queue.list_jobs()
|
||||
return [parse_obj_as(ModelInstallJob, x) for x in jobs] # downcast to proper type
|
||||
|
||||
def id_to_job(self, id: int) -> ModelInstallJob:
|
||||
"""Return the ModelInstallJob instance corresponding to the given job ID."""
|
||||
return self._loader.queue.id_to_job(id)
|
||||
|
||||
def wait_for_installs(self) -> Dict[str, str]:
|
||||
"""
|
||||
Wait for all pending installs to complete.
|
||||
|
||||
This will block until all pending downloads have
|
||||
completed, been cancelled, or errored out. It will
|
||||
block indefinitely if one or more jobs are in the
|
||||
paused state.
|
||||
|
||||
It will return a dict that maps the source model
|
||||
path, URL or repo_id to the ID of the installed model.
|
||||
"""
|
||||
return self._loader.installer.wait_for_installs()
|
||||
|
||||
def start_job(self, job_id: int):
|
||||
"""Start the given install job if it is paused or idle."""
|
||||
queue = self._loader.queue
|
||||
queue.start_job(queue.id_to_job(job_id))
|
||||
|
||||
def pause_job(self, job_id: int):
|
||||
"""Pause the given install job if it is paused or idle."""
|
||||
queue = self._loader.queue
|
||||
queue.pause_job(queue.id_to_job(job_id))
|
||||
|
||||
def cancel_job(self, job_id: int):
|
||||
"""Cancel the given install job."""
|
||||
queue = self._loader.queue
|
||||
queue.cancel_job(queue.id_to_job(job_id))
|
||||
|
||||
def change_job_priority(self, job_id: int, delta: int):
|
||||
"""
|
||||
Change an install job's priority.
|
||||
|
||||
:param job_id: Job to change
|
||||
:param delta: Value to increment or decrement priority.
|
||||
|
||||
Lower values are higher priority. The default starting value is 10.
|
||||
Thus to make this a really high priority job:
|
||||
manager.change_job_priority(-10).
|
||||
"""
|
||||
queue = self._loader.queue
|
||||
queue.change_priority(queue.id_to_job(job_id), delta)
|
||||
|
||||
def update_model(
|
||||
self,
|
||||
key: str,
|
||||
@ -415,7 +499,8 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
"""
|
||||
self.logger.debug(f"update model {key}")
|
||||
new_info = self._loader.store.update_model(key, new_config)
|
||||
return self._loader.installer.sync_model_path(new_info.key)
|
||||
self._loader.installer.sync_model_path(new_info.key)
|
||||
return new_info
|
||||
|
||||
def del_model(
|
||||
self,
|
||||
@ -570,28 +655,3 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
:param new_name: New name for the model
|
||||
"""
|
||||
return self.update_model(key, {"name": new_name})
|
||||
|
||||
def cancel_job(self, job: ModelInstallJob):
|
||||
"""Cancel this job."""
|
||||
self._loader.queue.cancel_job(job)
|
||||
|
||||
def pause_job(self, job: ModelInstallJob):
|
||||
"""Pause this job."""
|
||||
self._loader.queue.pause_job(job)
|
||||
|
||||
def start_job(self, job: ModelInstallJob):
|
||||
"""(re)start this job."""
|
||||
self._loader.queue.start_job(job)
|
||||
|
||||
def change_priority(self, job: ModelInstallJob, delta: int):
|
||||
"""
|
||||
Raise or lower the priority of the job.
|
||||
|
||||
:param job: Job to apply change to
|
||||
:param delta: Value to increment or decrement priority.
|
||||
|
||||
Lower values are higher priority. The default starting value is 10.
|
||||
Thus to make my_job a really high priority job:
|
||||
manager.change_priority(my_job, -10).
|
||||
"""
|
||||
self._loader.queue.change_priority(job, delta)
|
||||
|
@ -17,7 +17,7 @@ from .install import ModelInstall, ModelInstallJob # noqa F401
|
||||
from .loader import ModelInfo, ModelLoad # noqa F401
|
||||
from .lora import ModelPatcher, ONNXModelPatcher
|
||||
from .models import OPENAPI_MODEL_CONFIGS, read_checkpoint_meta # noqa F401
|
||||
from .probe import InvalidModelException, ModelProbeInfo # noqa F401
|
||||
from .probe import InvalidModelException, ModelProbe, ModelProbeInfo # noqa F401
|
||||
from .search import ModelSearch # noqa F401
|
||||
from .storage import ( # noqa F401
|
||||
DuplicateModelException,
|
||||
|
@ -383,22 +383,26 @@ class ModelInstall(ModelInstallBase):
|
||||
model_format=info.format,
|
||||
)
|
||||
# add 'main' specific fields
|
||||
if info.model_type == ModelType.Main and info.format == ModelFormat.Checkpoint:
|
||||
try:
|
||||
config_file = self._legacy_configs[info.base_type][info.variant_type]
|
||||
if isinstance(config_file, dict): # need another tier for sd-2.x models
|
||||
if prediction_type := info.prediction_type:
|
||||
config_file = config_file[prediction_type]
|
||||
else:
|
||||
self._logger.warning(
|
||||
f"Could not infer prediction type for {model_path.stem}. Guessing 'v_prediction' for a SD-2 768 pixel model"
|
||||
)
|
||||
config_file = config_file[SchedulerPredictionType.VPrediction]
|
||||
registration_data.update(
|
||||
config=Path(self._config.legacy_conf_dir, config_file).as_posix(),
|
||||
)
|
||||
except KeyError as exc:
|
||||
raise InvalidModelException("Configuration file for this checkpoint could not be determined") from exc
|
||||
if info.model_type == ModelType.Main:
|
||||
registration_data.update(variant=info.variant_type)
|
||||
if info.format == ModelFormat.Checkpoint:
|
||||
try:
|
||||
config_file = self._legacy_configs[info.base_type][info.variant_type]
|
||||
if isinstance(config_file, dict): # need another tier for sd-2.x models
|
||||
if prediction_type := info.prediction_type:
|
||||
config_file = config_file[prediction_type]
|
||||
else:
|
||||
self._logger.warning(
|
||||
f"Could not infer prediction type for {model_path.stem}. Guessing 'v_prediction' for a SD-2 768 pixel model"
|
||||
)
|
||||
config_file = config_file[SchedulerPredictionType.VPrediction]
|
||||
registration_data.update(
|
||||
config=Path(self._config.legacy_conf_dir, config_file).as_posix(),
|
||||
)
|
||||
except KeyError as exc:
|
||||
raise InvalidModelException(
|
||||
"Configuration file for this checkpoint could not be determined"
|
||||
) from exc
|
||||
self._store.add_model(key, registration_data)
|
||||
return key
|
||||
|
||||
|
@ -456,12 +456,8 @@ class PipelineFolderProbe(FolderProbeBase):
|
||||
"""Return the SchedulerPredictionType of a diffusers-style sd-2 model."""
|
||||
with open(self.model / "scheduler" / "scheduler_config.json", "r") as file:
|
||||
scheduler_conf = json.load(file)
|
||||
if scheduler_conf["prediction_type"] == "v_prediction":
|
||||
return SchedulerPredictionType.VPrediction
|
||||
elif scheduler_conf["prediction_type"] == "epsilon":
|
||||
return SchedulerPredictionType.Epsilon
|
||||
else:
|
||||
return None
|
||||
prediction_type = scheduler_conf.get("prediction_type", "epsilon")
|
||||
return SchedulerPredictionType(prediction_type)
|
||||
|
||||
def get_variant_type(self) -> ModelVariantType:
|
||||
"""Return the ModelVariantType for diffusers-style main models."""
|
||||
|
Reference in New Issue
Block a user