refactor create_download_job; override probe info in install call

This commit is contained in:
Lincoln Stein
2023-09-13 18:53:33 -05:00
parent 6d8b2a7385
commit 4b932b275d
7 changed files with 168 additions and 80 deletions

View File

@ -46,12 +46,9 @@ class DownloadJobBase(BaseModel):
"""Class to monitor and control a model download request."""
priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
id: int = Field(description="Numeric ID of this job")
id: int = Field(description="Numeric ID of this job", default=-1) # default id is a placeholder
source: str = Field(description="URL or repo_id to download")
destination: Path = Field(description="Destination of URL on local disk")
model_key: Optional[str] = Field(
description="After model installation, this field will hold its primary key", default=None
)
metadata: Optional[ModelSourceMetadata] = Field(description="Model metadata (source-specific)", default=None)
access_token: Optional[str] = Field(description="access token needed to access this resource")
status: DownloadJobStatus = Field(default=DownloadJobStatus.IDLE, description="Status of the download")
@ -114,7 +111,7 @@ class DownloadQueueBase(ABC):
event_handlers: Optional[List[DownloadEventHandler]] = None,
) -> DownloadJobBase:
"""
Create a download job.
Create and submit a download job.
:param source: Source of the download - URL, repo_id or Path
:param destdir: Directory to download into.
@ -123,7 +120,22 @@ class DownloadQueueBase(ABC):
:param start: Immediately start job [True]
:param variant: Variant to download, such as "fp16" (repo_ids only).
:param event_handlers: Optional callables that will be called whenever job status changes.
:returns job id: The numeric ID of the DownloadJobBase object for this task.
:returns the job: job.id will be a non-negative value after execution
"""
pass
def submit_download_job(
self,
job: DownloadJobBase,
start: bool = True,
):
"""
Submit a download job.
:param job: A DownloadJobBase
:param start: Immediately start job [True]
After execution, `job.id` will be set to a non-negative value.
"""
pass

View File

@ -78,7 +78,7 @@ class DownloadQueue(DownloadQueueBase):
_queue: PriorityQueue
_lock: threading.Lock
_logger: InvokeAILogger
_event_handlers: Optional[List[DownloadEventHandler]]
_event_handlers: List[DownloadEventHandler] = Field(default_factory=list)
_next_job_id: int = 0
_sequence: int = 0 # This is for debugging and used to tag jobs in dequeueing order
_requests: requests.sessions.Session
@ -90,7 +90,7 @@ class DownloadQueue(DownloadQueueBase):
def __init__(
self,
max_parallel_dl: int = 5,
event_handlers: Optional[List[DownloadEventHandler]] = None,
event_handlers: List[DownloadEventHandler] = [],
requests_session: Optional[requests.sessions.Session] = None,
config: Optional[InvokeAIAppConfig] = None,
):
@ -139,19 +139,30 @@ class DownloadQueue(DownloadQueueBase):
else:
raise NotImplementedError(f"Don't know what to do with this type of source: {source}")
job = cls(
source=source,
destination=Path(destdir) / (filename or "."),
access_token=access_token,
event_handlers=(event_handlers or self._event_handlers),
**kwargs,
)
return self.submit_download_job(job, start)
def submit_download_job(
self,
job: DownloadJobBase,
start: bool = True,
):
"""Submit a job."""
# add the queue's handlers
for handler in self._event_handlers:
job.add_event_handler(handler)
try:
self._lock.acquire()
id = self._next_job_id
self._jobs[id] = cls(
id=id,
source=source,
destination=Path(destdir) / (filename or "."),
access_token=access_token,
event_handlers=(event_handlers or self._event_handlers),
**kwargs,
)
job.id = self._next_job_id
self._jobs[job.id] = job
self._next_job_id += 1
job = self._jobs[id]
finally:
self._lock.release()
if start:

View File

@ -20,7 +20,7 @@ Typical usage:
# register config, and install model in `models`
id: str = installer.install_path('/path/to/model')
1 # download some remote models and install them in the background
# download some remote models and install them in the background
installer.install('stabilityai/stable-diffusion-2-1')
installer.install('https://civitai.com/api/download/models/154208')
installer.install('runwayml/stable-diffusion-v1-5')
@ -48,6 +48,7 @@ The following exceptions may be raised:
DuplicateModelException
UnknownModelTypeException
"""
import re
import tempfile
from abc import ABC, abstractmethod
from pathlib import Path
@ -60,6 +61,7 @@ from invokeai.backend.util.logging import InvokeAILogger
from .search import ModelSearch
from .storage import ModelConfigStore, DuplicateModelException, get_config_store
from .download import DownloadQueueBase, DownloadQueue, DownloadJobBase, ModelSourceMetadata
from .download.queue import DownloadJobURL, DownloadJobRepoID, DownloadJobPath
from .hash import FastModelHash
from .probe import ModelProbe, ModelProbeInfo, InvalidModelException
from .config import (
@ -71,6 +73,30 @@ from .config import (
)
class ModelInstallJob(DownloadJobBase):
"""This is a version of DownloadJobBase that has an additional slot for the model key and probe info."""
model_key: Optional[str] = Field(
description="After model installation, this field will hold its primary key", default=None
)
probe_info: Optional[ModelProbeInfo] = Field(
description="If provided, information here will be used instead of probing the model.",
default=None,
)
class ModelInstallURLJob(DownloadJobURL, ModelInstallJob):
"""Job for installing URLs."""
class ModelInstallRepoIDJob(DownloadJobRepoID, ModelInstallJob):
"""Job for installing repo ids."""
class ModelInstallPathJob(DownloadJobPath, ModelInstallJob):
"""Job for installing local paths."""
class ModelInstallBase(ABC):
"""Abstract base class for InvokeAI model installation"""
@ -103,17 +129,18 @@ class ModelInstallBase(ABC):
pass
@abstractmethod
def register_path(self, model_path: Union[Path, str]) -> str:
def register_path(self, model_path: Union[Path, str], info: Optional[ModelProbeInfo] = None) -> str:
"""
Probe and register the model at model_path.
:param model_path: Filesystem Path to the model.
:param info: Optional ModelProbeInfo object. If not provided, model will be probed.
:returns id: The string ID of the registered model.
"""
pass
@abstractmethod
def install_path(self, model_path: Union[Path, str]) -> str:
def install_path(self, model_path: Union[Path, str], info: Optional[ModelProbeInfo] = None) -> str:
"""
Probe, register and install the model in the models directory.
@ -121,13 +148,18 @@ class ModelInstallBase(ABC):
the models directory handled by InvokeAI.
:param model_path: Filesystem Path to the model.
:param info: Optional ModelProbeInfo object. If not provided, model will be probed.
:returns id: The string ID of the installed model.
"""
pass
@abstractmethod
def install(
self, source: Union[str, Path, AnyHttpUrl], inplace: bool = True, variant: Optional[str] = None
self,
source: Union[str, Path, AnyHttpUrl],
inplace: bool = True,
variant: Optional[str] = None,
info: Optional[ModelProbeInfo] = None,
) -> DownloadJobBase:
"""
Download and install the indicated model.
@ -146,6 +178,7 @@ class ModelInstallBase(ABC):
the models directory, but registered in place (the default).
:param variant: For HuggingFace models, this optional parameter
specifies which variant to download (e.g. 'fp16')
:param info: Optional ModelProbeInfo object. If not provided, model will be probed.
:returns DownloadQueueBase object.
The `inplace` flag does not affect the behavior of downloaded
@ -283,9 +316,9 @@ class ModelInstall(ModelInstallBase):
"""Return the queue."""
return self._download_queue
def register_path(self, model_path: Union[Path, str]) -> str: # noqa D102
def register_path(self, model_path: Union[Path, str], info: Optional[ModelProbeInfo] = None) -> str: # noqa D102
model_path = Path(model_path)
info: ModelProbeInfo = ModelProbe.probe(model_path)
info: ModelProbeInfo = info or ModelProbe.probe(model_path)
return self._register(model_path, info)
def _register(self, model_path: Path, info: ModelProbeInfo) -> str:
@ -317,9 +350,13 @@ class ModelInstall(ModelInstallBase):
self._store.add_model(key, registration_data)
return key
def install_path(self, model_path: Union[Path, str]) -> str: # noqa D102
def install_path(
self,
model_path: Union[Path, str],
info: Optional[ModelProbeInfo] = None,
) -> str: # noqa D102
model_path = Path(model_path)
info: ModelProbeInfo = ModelProbe.probe(model_path)
info: ModelProbeInfo = info or ModelProbe.probe(model_path)
dest_path = self._config.models_path / info.base_type.value / info.model_type.value / model_path.name
dest_path.parent.mkdir(parents=True, exist_ok=True)
@ -343,66 +380,94 @@ class ModelInstall(ModelInstallBase):
self.unregister(key)
def install(
self, source: Union[str, Path, AnyHttpUrl], inplace: bool = True, variant: Optional[str] = None
self,
source: Union[str, Path, AnyHttpUrl],
info: Optional[ModelProbeInfo] = None,
inplace: bool = True,
variant: Optional[str] = None,
access_token: Optional[str] = None,
) -> DownloadJobBase: # noqa D102
# choose a temporary directory inside the models directory
models_dir = self._config.models_path
queue = self._download_queue
def complete_installation(job: DownloadJobBase):
if job.status == "completed":
self._logger.info(f"{job.source}: Download finished with status {job.status}. Installing.")
model_id = self.install_path(job.destination)
info = self._store.get_model(model_id)
info.source = str(job.source)
metadata: ModelSourceMetadata = job.metadata
info.description = metadata.description or f"Imported model {info.name}"
info.author = metadata.author
info.tags = metadata.tags
info.license = metadata.license
info.thumbnail_url = metadata.thumbnail_url
self._store.update_model(model_id, info)
self._async_installs[job.source] = model_id
elif job.status == "error":
self._logger.warning(f"{job.source}: Model installation error: {job.error}")
elif job.status == "cancelled":
self._logger.warning(f"{job.source}: Model installation cancelled at caller's request.")
jobs = queue.list_jobs()
if self._tmpdir and len(jobs) <= 1 and job.status in ["completed", "error", "cancelled"]:
self._tmpdir.cleanup()
self._tmpdir = None
job = self._make_download_job(source, variant, access_token)
handler = (
self._complete_registration_handler
if inplace and Path(source).exists()
else self._complete_installation_handler
)
job.add_event_handler(handler)
job.probe_info = info
def complete_registration(job: DownloadJobBase):
if job.status == "completed":
self._logger.info(f"{job.source}: Installing in place.")
model_id = self.register_path(job.destination)
info = self._store.get_model(model_id)
info.source = str(job.source)
info.description = f"Imported model {info.name}"
self._store.update_model(model_id, info)
self._async_installs[job.source] = model_id
job.model_key = model_id
elif job.status == "error":
self._logger.warning(f"{job.source}: Model installation error: {job.error}")
elif job.status == "cancelled":
self._logger.warning(f"{job.source}: Model installation cancelled at caller's request.")
self._async_installs[source] = None
queue.submit_download_job(job, True)
return job
def _complete_installation_handler(self, job: DownloadJobBase):
if job.status == "completed":
self._logger.info(f"{job.source}: Download finished with status {job.status}. Installing.")
model_id = self.install_path(job.destination, job.probe_info)
info = self._store.get_model(model_id)
info.source = str(job.source)
metadata: ModelSourceMetadata = job.metadata
info.description = metadata.description or f"Imported model {info.name}"
info.author = metadata.author
info.tags = metadata.tags
info.license = metadata.license
info.thumbnail_url = metadata.thumbnail_url
self._store.update_model(model_id, info)
self._async_installs[job.source] = model_id
job.model_key = model_id
elif job.status == "error":
self._logger.warning(f"{job.source}: Model installation error: {job.error}")
elif job.status == "cancelled":
self._logger.warning(f"{job.source}: Model installation cancelled at caller's request.")
jobs = self._download_queue.list_jobs()
if self._tmpdir and len(jobs) <= 1 and job.status in ["completed", "error", "cancelled"]:
self._tmpdir.cleanup()
self._tmpdir = None
def _complete_registration_handler(self, job: DownloadJobBase):
if job.status == "completed":
self._logger.info(f"{job.source}: Installing in place.")
model_id = self.register_path(job.destination, job.probe_info)
info = self._store.get_model(model_id)
info.source = str(job.source)
info.description = f"Imported model {info.name}"
self._store.update_model(model_id, info)
self._async_installs[job.source] = model_id
job.model_key = model_id
elif job.status == "error":
self._logger.warning(f"{job.source}: Model installation error: {job.error}")
elif job.status == "cancelled":
self._logger.warning(f"{job.source}: Model installation cancelled at caller's request.")
def _make_download_job(
self,
source: Union[str, Path, AnyHttpUrl],
variant: Optional[str] = None,
access_token: Optional[str] = None,
) -> DownloadJobBase:
# In the event that we are being asked to install a path that is already on disk,
# we simply probe and register/install it. The job does not actually do anything, but we
# create one anyway in order to have similar behavior for local files, URLs and repo_ids.
if Path(source).exists(): # a path that is already on disk
source = Path(source)
destdir = source
job = queue.create_download_job(source=source, destdir=destdir, start=False, variant=variant)
job.add_event_handler(complete_registration if inplace else complete_installation)
else:
self._tmpdir = self._tmpdir or tempfile.TemporaryDirectory(dir=models_dir)
job = queue.create_download_job(source=source, destdir=self._tmpdir.name, start=False, variant=variant)
job.add_event_handler(complete_installation)
return ModelInstallPathJob(source=source, destination=Path(destdir))
self._async_installs[source] = None
queue.start_job(job)
return job
# choose a temporary directory inside the models directory
models_dir = self._config.models_path
self._tmpdir = self._tmpdir or tempfile.TemporaryDirectory(dir=models_dir)
if re.match(r"^[\w-]+/[\w-]+$", str(source)):
cls = ModelInstallRepoIDJob
kwargs = dict(variant=variant)
elif re.match(r"^https?://", str(source)):
cls = ModelInstallURLJob
kwargs = {}
else:
raise NotImplementedError(f"Don't know what to do with this type of source: {source}")
return cls(source=source, destination=Path(self._tmpdir.name), access_token=access_token, **kwargs)
def wait_for_installs(self) -> Dict[str, str]: # noqa D102
self._download_queue.join()

View File

@ -20,7 +20,6 @@ from onnxruntime import (
SessionOptions,
get_available_providers,
)
from pydantic import BaseModel
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
from ..config import ( # noqa F401
BaseModelType,

View File

@ -14,6 +14,7 @@ from .base import (
read_checkpoint_meta,
classproperty,
InvalidModelException,
ModelNotFoundException,
)
from ..config import SilenceWarnings
from .sdxl import StableDiffusionXLModel

View File

@ -37,11 +37,11 @@ class ModelProbeInfo(object):
model_type: ModelType
base_type: BaseModelType
variant_type: ModelVariantType
prediction_type: SchedulerPredictionType
upcast_attention: bool
format: ModelFormat
image_size: int
variant_type: ModelVariantType = "normal"
prediction_type: SchedulerPredictionType = "v_prediction"
upcast_attention: bool = False
image_size: int = None
class ModelProbeBase(ABC):

View File

@ -18,4 +18,4 @@ from .util import ( # noqa: F401
Chdir,
)
from .attention import auto_detect_slice_size # noqa: F401
from .logging import InvokeAILogger
from .logging import InvokeAILogger # noqa: F401