mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor create_download_job; override probe info in install call
This commit is contained in:
@ -46,12 +46,9 @@ class DownloadJobBase(BaseModel):
|
||||
"""Class to monitor and control a model download request."""
|
||||
|
||||
priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
|
||||
id: int = Field(description="Numeric ID of this job")
|
||||
id: int = Field(description="Numeric ID of this job", default=-1) # default id is a placeholder
|
||||
source: str = Field(description="URL or repo_id to download")
|
||||
destination: Path = Field(description="Destination of URL on local disk")
|
||||
model_key: Optional[str] = Field(
|
||||
description="After model installation, this field will hold its primary key", default=None
|
||||
)
|
||||
metadata: Optional[ModelSourceMetadata] = Field(description="Model metadata (source-specific)", default=None)
|
||||
access_token: Optional[str] = Field(description="access token needed to access this resource")
|
||||
status: DownloadJobStatus = Field(default=DownloadJobStatus.IDLE, description="Status of the download")
|
||||
@ -114,7 +111,7 @@ class DownloadQueueBase(ABC):
|
||||
event_handlers: Optional[List[DownloadEventHandler]] = None,
|
||||
) -> DownloadJobBase:
|
||||
"""
|
||||
Create a download job.
|
||||
Create and submit a download job.
|
||||
|
||||
:param source: Source of the download - URL, repo_id or Path
|
||||
:param destdir: Directory to download into.
|
||||
@ -123,7 +120,22 @@ class DownloadQueueBase(ABC):
|
||||
:param start: Immediately start job [True]
|
||||
:param variant: Variant to download, such as "fp16" (repo_ids only).
|
||||
:param event_handlers: Optional callables that will be called whenever job status changes.
|
||||
:returns job id: The numeric ID of the DownloadJobBase object for this task.
|
||||
:returns the job: job.id will be a non-negative value after execution
|
||||
"""
|
||||
pass
|
||||
|
||||
def submit_download_job(
|
||||
self,
|
||||
job: DownloadJobBase,
|
||||
start: bool = True,
|
||||
):
|
||||
"""
|
||||
Submit a download job.
|
||||
|
||||
:param job: A DownloadJobBase
|
||||
:param start: Immediately start job [True]
|
||||
|
||||
After execution, `job.id` will be set to a non-negative value.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
@ -78,7 +78,7 @@ class DownloadQueue(DownloadQueueBase):
|
||||
_queue: PriorityQueue
|
||||
_lock: threading.Lock
|
||||
_logger: InvokeAILogger
|
||||
_event_handlers: Optional[List[DownloadEventHandler]]
|
||||
_event_handlers: List[DownloadEventHandler] = Field(default_factory=list)
|
||||
_next_job_id: int = 0
|
||||
_sequence: int = 0 # This is for debugging and used to tag jobs in dequeueing order
|
||||
_requests: requests.sessions.Session
|
||||
@ -90,7 +90,7 @@ class DownloadQueue(DownloadQueueBase):
|
||||
def __init__(
|
||||
self,
|
||||
max_parallel_dl: int = 5,
|
||||
event_handlers: Optional[List[DownloadEventHandler]] = None,
|
||||
event_handlers: List[DownloadEventHandler] = [],
|
||||
requests_session: Optional[requests.sessions.Session] = None,
|
||||
config: Optional[InvokeAIAppConfig] = None,
|
||||
):
|
||||
@ -139,19 +139,30 @@ class DownloadQueue(DownloadQueueBase):
|
||||
else:
|
||||
raise NotImplementedError(f"Don't know what to do with this type of source: {source}")
|
||||
|
||||
job = cls(
|
||||
source=source,
|
||||
destination=Path(destdir) / (filename or "."),
|
||||
access_token=access_token,
|
||||
event_handlers=(event_handlers or self._event_handlers),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return self.submit_download_job(job, start)
|
||||
|
||||
def submit_download_job(
|
||||
self,
|
||||
job: DownloadJobBase,
|
||||
start: bool = True,
|
||||
):
|
||||
"""Submit a job."""
|
||||
# add the queue's handlers
|
||||
for handler in self._event_handlers:
|
||||
job.add_event_handler(handler)
|
||||
try:
|
||||
self._lock.acquire()
|
||||
id = self._next_job_id
|
||||
self._jobs[id] = cls(
|
||||
id=id,
|
||||
source=source,
|
||||
destination=Path(destdir) / (filename or "."),
|
||||
access_token=access_token,
|
||||
event_handlers=(event_handlers or self._event_handlers),
|
||||
**kwargs,
|
||||
)
|
||||
job.id = self._next_job_id
|
||||
self._jobs[job.id] = job
|
||||
self._next_job_id += 1
|
||||
job = self._jobs[id]
|
||||
finally:
|
||||
self._lock.release()
|
||||
if start:
|
||||
|
@ -20,7 +20,7 @@ Typical usage:
|
||||
# register config, and install model in `models`
|
||||
id: str = installer.install_path('/path/to/model')
|
||||
|
||||
1 # download some remote models and install them in the background
|
||||
# download some remote models and install them in the background
|
||||
installer.install('stabilityai/stable-diffusion-2-1')
|
||||
installer.install('https://civitai.com/api/download/models/154208')
|
||||
installer.install('runwayml/stable-diffusion-v1-5')
|
||||
@ -48,6 +48,7 @@ The following exceptions may be raised:
|
||||
DuplicateModelException
|
||||
UnknownModelTypeException
|
||||
"""
|
||||
import re
|
||||
import tempfile
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
@ -60,6 +61,7 @@ from invokeai.backend.util.logging import InvokeAILogger
|
||||
from .search import ModelSearch
|
||||
from .storage import ModelConfigStore, DuplicateModelException, get_config_store
|
||||
from .download import DownloadQueueBase, DownloadQueue, DownloadJobBase, ModelSourceMetadata
|
||||
from .download.queue import DownloadJobURL, DownloadJobRepoID, DownloadJobPath
|
||||
from .hash import FastModelHash
|
||||
from .probe import ModelProbe, ModelProbeInfo, InvalidModelException
|
||||
from .config import (
|
||||
@ -71,6 +73,30 @@ from .config import (
|
||||
)
|
||||
|
||||
|
||||
class ModelInstallJob(DownloadJobBase):
|
||||
"""This is a version of DownloadJobBase that has an additional slot for the model key and probe info."""
|
||||
|
||||
model_key: Optional[str] = Field(
|
||||
description="After model installation, this field will hold its primary key", default=None
|
||||
)
|
||||
probe_info: Optional[ModelProbeInfo] = Field(
|
||||
description="If provided, information here will be used instead of probing the model.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
class ModelInstallURLJob(DownloadJobURL, ModelInstallJob):
|
||||
"""Job for installing URLs."""
|
||||
|
||||
|
||||
class ModelInstallRepoIDJob(DownloadJobRepoID, ModelInstallJob):
|
||||
"""Job for installing repo ids."""
|
||||
|
||||
|
||||
class ModelInstallPathJob(DownloadJobPath, ModelInstallJob):
|
||||
"""Job for installing local paths."""
|
||||
|
||||
|
||||
class ModelInstallBase(ABC):
|
||||
"""Abstract base class for InvokeAI model installation"""
|
||||
|
||||
@ -103,17 +129,18 @@ class ModelInstallBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def register_path(self, model_path: Union[Path, str]) -> str:
|
||||
def register_path(self, model_path: Union[Path, str], info: Optional[ModelProbeInfo] = None) -> str:
|
||||
"""
|
||||
Probe and register the model at model_path.
|
||||
|
||||
:param model_path: Filesystem Path to the model.
|
||||
:param info: Optional ModelProbeInfo object. If not provided, model will be probed.
|
||||
:returns id: The string ID of the registered model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def install_path(self, model_path: Union[Path, str]) -> str:
|
||||
def install_path(self, model_path: Union[Path, str], info: Optional[ModelProbeInfo] = None) -> str:
|
||||
"""
|
||||
Probe, register and install the model in the models directory.
|
||||
|
||||
@ -121,13 +148,18 @@ class ModelInstallBase(ABC):
|
||||
the models directory handled by InvokeAI.
|
||||
|
||||
:param model_path: Filesystem Path to the model.
|
||||
:param info: Optional ModelProbeInfo object. If not provided, model will be probed.
|
||||
:returns id: The string ID of the installed model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def install(
|
||||
self, source: Union[str, Path, AnyHttpUrl], inplace: bool = True, variant: Optional[str] = None
|
||||
self,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
inplace: bool = True,
|
||||
variant: Optional[str] = None,
|
||||
info: Optional[ModelProbeInfo] = None,
|
||||
) -> DownloadJobBase:
|
||||
"""
|
||||
Download and install the indicated model.
|
||||
@ -146,6 +178,7 @@ class ModelInstallBase(ABC):
|
||||
the models directory, but registered in place (the default).
|
||||
:param variant: For HuggingFace models, this optional parameter
|
||||
specifies which variant to download (e.g. 'fp16')
|
||||
:param info: Optional ModelProbeInfo object. If not provided, model will be probed.
|
||||
:returns DownloadQueueBase object.
|
||||
|
||||
The `inplace` flag does not affect the behavior of downloaded
|
||||
@ -283,9 +316,9 @@ class ModelInstall(ModelInstallBase):
|
||||
"""Return the queue."""
|
||||
return self._download_queue
|
||||
|
||||
def register_path(self, model_path: Union[Path, str]) -> str: # noqa D102
|
||||
def register_path(self, model_path: Union[Path, str], info: Optional[ModelProbeInfo] = None) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
info: ModelProbeInfo = ModelProbe.probe(model_path)
|
||||
info: ModelProbeInfo = info or ModelProbe.probe(model_path)
|
||||
return self._register(model_path, info)
|
||||
|
||||
def _register(self, model_path: Path, info: ModelProbeInfo) -> str:
|
||||
@ -317,9 +350,13 @@ class ModelInstall(ModelInstallBase):
|
||||
self._store.add_model(key, registration_data)
|
||||
return key
|
||||
|
||||
def install_path(self, model_path: Union[Path, str]) -> str: # noqa D102
|
||||
def install_path(
|
||||
self,
|
||||
model_path: Union[Path, str],
|
||||
info: Optional[ModelProbeInfo] = None,
|
||||
) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
info: ModelProbeInfo = ModelProbe.probe(model_path)
|
||||
info: ModelProbeInfo = info or ModelProbe.probe(model_path)
|
||||
dest_path = self._config.models_path / info.base_type.value / info.model_type.value / model_path.name
|
||||
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@ -343,66 +380,94 @@ class ModelInstall(ModelInstallBase):
|
||||
self.unregister(key)
|
||||
|
||||
def install(
|
||||
self, source: Union[str, Path, AnyHttpUrl], inplace: bool = True, variant: Optional[str] = None
|
||||
self,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
info: Optional[ModelProbeInfo] = None,
|
||||
inplace: bool = True,
|
||||
variant: Optional[str] = None,
|
||||
access_token: Optional[str] = None,
|
||||
) -> DownloadJobBase: # noqa D102
|
||||
# choose a temporary directory inside the models directory
|
||||
models_dir = self._config.models_path
|
||||
queue = self._download_queue
|
||||
|
||||
def complete_installation(job: DownloadJobBase):
|
||||
if job.status == "completed":
|
||||
self._logger.info(f"{job.source}: Download finished with status {job.status}. Installing.")
|
||||
model_id = self.install_path(job.destination)
|
||||
info = self._store.get_model(model_id)
|
||||
info.source = str(job.source)
|
||||
metadata: ModelSourceMetadata = job.metadata
|
||||
info.description = metadata.description or f"Imported model {info.name}"
|
||||
info.author = metadata.author
|
||||
info.tags = metadata.tags
|
||||
info.license = metadata.license
|
||||
info.thumbnail_url = metadata.thumbnail_url
|
||||
self._store.update_model(model_id, info)
|
||||
self._async_installs[job.source] = model_id
|
||||
elif job.status == "error":
|
||||
self._logger.warning(f"{job.source}: Model installation error: {job.error}")
|
||||
elif job.status == "cancelled":
|
||||
self._logger.warning(f"{job.source}: Model installation cancelled at caller's request.")
|
||||
jobs = queue.list_jobs()
|
||||
if self._tmpdir and len(jobs) <= 1 and job.status in ["completed", "error", "cancelled"]:
|
||||
self._tmpdir.cleanup()
|
||||
self._tmpdir = None
|
||||
job = self._make_download_job(source, variant, access_token)
|
||||
handler = (
|
||||
self._complete_registration_handler
|
||||
if inplace and Path(source).exists()
|
||||
else self._complete_installation_handler
|
||||
)
|
||||
job.add_event_handler(handler)
|
||||
job.probe_info = info
|
||||
|
||||
def complete_registration(job: DownloadJobBase):
|
||||
if job.status == "completed":
|
||||
self._logger.info(f"{job.source}: Installing in place.")
|
||||
model_id = self.register_path(job.destination)
|
||||
info = self._store.get_model(model_id)
|
||||
info.source = str(job.source)
|
||||
info.description = f"Imported model {info.name}"
|
||||
self._store.update_model(model_id, info)
|
||||
self._async_installs[job.source] = model_id
|
||||
job.model_key = model_id
|
||||
elif job.status == "error":
|
||||
self._logger.warning(f"{job.source}: Model installation error: {job.error}")
|
||||
elif job.status == "cancelled":
|
||||
self._logger.warning(f"{job.source}: Model installation cancelled at caller's request.")
|
||||
self._async_installs[source] = None
|
||||
queue.submit_download_job(job, True)
|
||||
return job
|
||||
|
||||
def _complete_installation_handler(self, job: DownloadJobBase):
|
||||
if job.status == "completed":
|
||||
self._logger.info(f"{job.source}: Download finished with status {job.status}. Installing.")
|
||||
model_id = self.install_path(job.destination, job.probe_info)
|
||||
info = self._store.get_model(model_id)
|
||||
info.source = str(job.source)
|
||||
metadata: ModelSourceMetadata = job.metadata
|
||||
info.description = metadata.description or f"Imported model {info.name}"
|
||||
info.author = metadata.author
|
||||
info.tags = metadata.tags
|
||||
info.license = metadata.license
|
||||
info.thumbnail_url = metadata.thumbnail_url
|
||||
self._store.update_model(model_id, info)
|
||||
self._async_installs[job.source] = model_id
|
||||
job.model_key = model_id
|
||||
elif job.status == "error":
|
||||
self._logger.warning(f"{job.source}: Model installation error: {job.error}")
|
||||
elif job.status == "cancelled":
|
||||
self._logger.warning(f"{job.source}: Model installation cancelled at caller's request.")
|
||||
jobs = self._download_queue.list_jobs()
|
||||
if self._tmpdir and len(jobs) <= 1 and job.status in ["completed", "error", "cancelled"]:
|
||||
self._tmpdir.cleanup()
|
||||
self._tmpdir = None
|
||||
|
||||
def _complete_registration_handler(self, job: DownloadJobBase):
|
||||
if job.status == "completed":
|
||||
self._logger.info(f"{job.source}: Installing in place.")
|
||||
model_id = self.register_path(job.destination, job.probe_info)
|
||||
info = self._store.get_model(model_id)
|
||||
info.source = str(job.source)
|
||||
info.description = f"Imported model {info.name}"
|
||||
self._store.update_model(model_id, info)
|
||||
self._async_installs[job.source] = model_id
|
||||
job.model_key = model_id
|
||||
elif job.status == "error":
|
||||
self._logger.warning(f"{job.source}: Model installation error: {job.error}")
|
||||
elif job.status == "cancelled":
|
||||
self._logger.warning(f"{job.source}: Model installation cancelled at caller's request.")
|
||||
|
||||
def _make_download_job(
|
||||
self,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
variant: Optional[str] = None,
|
||||
access_token: Optional[str] = None,
|
||||
) -> DownloadJobBase:
|
||||
# In the event that we are being asked to install a path that is already on disk,
|
||||
# we simply probe and register/install it. The job does not actually do anything, but we
|
||||
# create one anyway in order to have similar behavior for local files, URLs and repo_ids.
|
||||
if Path(source).exists(): # a path that is already on disk
|
||||
source = Path(source)
|
||||
destdir = source
|
||||
job = queue.create_download_job(source=source, destdir=destdir, start=False, variant=variant)
|
||||
job.add_event_handler(complete_registration if inplace else complete_installation)
|
||||
else:
|
||||
self._tmpdir = self._tmpdir or tempfile.TemporaryDirectory(dir=models_dir)
|
||||
job = queue.create_download_job(source=source, destdir=self._tmpdir.name, start=False, variant=variant)
|
||||
job.add_event_handler(complete_installation)
|
||||
return ModelInstallPathJob(source=source, destination=Path(destdir))
|
||||
|
||||
self._async_installs[source] = None
|
||||
queue.start_job(job)
|
||||
return job
|
||||
# choose a temporary directory inside the models directory
|
||||
models_dir = self._config.models_path
|
||||
self._tmpdir = self._tmpdir or tempfile.TemporaryDirectory(dir=models_dir)
|
||||
|
||||
if re.match(r"^[\w-]+/[\w-]+$", str(source)):
|
||||
cls = ModelInstallRepoIDJob
|
||||
kwargs = dict(variant=variant)
|
||||
elif re.match(r"^https?://", str(source)):
|
||||
cls = ModelInstallURLJob
|
||||
kwargs = {}
|
||||
else:
|
||||
raise NotImplementedError(f"Don't know what to do with this type of source: {source}")
|
||||
return cls(source=source, destination=Path(self._tmpdir.name), access_token=access_token, **kwargs)
|
||||
|
||||
def wait_for_installs(self) -> Dict[str, str]: # noqa D102
|
||||
self._download_queue.join()
|
||||
|
@ -20,7 +20,6 @@ from onnxruntime import (
|
||||
SessionOptions,
|
||||
get_available_providers,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
|
||||
from ..config import ( # noqa F401
|
||||
BaseModelType,
|
||||
|
@ -14,6 +14,7 @@ from .base import (
|
||||
read_checkpoint_meta,
|
||||
classproperty,
|
||||
InvalidModelException,
|
||||
ModelNotFoundException,
|
||||
)
|
||||
from ..config import SilenceWarnings
|
||||
from .sdxl import StableDiffusionXLModel
|
||||
|
@ -37,11 +37,11 @@ class ModelProbeInfo(object):
|
||||
|
||||
model_type: ModelType
|
||||
base_type: BaseModelType
|
||||
variant_type: ModelVariantType
|
||||
prediction_type: SchedulerPredictionType
|
||||
upcast_attention: bool
|
||||
format: ModelFormat
|
||||
image_size: int
|
||||
variant_type: ModelVariantType = "normal"
|
||||
prediction_type: SchedulerPredictionType = "v_prediction"
|
||||
upcast_attention: bool = False
|
||||
image_size: int = None
|
||||
|
||||
|
||||
class ModelProbeBase(ABC):
|
||||
|
@ -18,4 +18,4 @@ from .util import ( # noqa: F401
|
||||
Chdir,
|
||||
)
|
||||
from .attention import auto_detect_slice_size # noqa: F401
|
||||
from .logging import InvokeAILogger
|
||||
from .logging import InvokeAILogger # noqa: F401
|
||||
|
Reference in New Issue
Block a user