mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
make it possible to pause/resume repo_id downloads
This commit is contained in:
@ -5,7 +5,8 @@ Model download service.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
from typing import Optional, List, Union
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from .events import EventServiceBase
|
||||
from invokeai.backend.model_manager.download import DownloadQueue, DownloadJobBase, DownloadEventHandler
|
||||
|
||||
@ -16,7 +17,7 @@ class DownloadQueueServiceBase(ABC):
|
||||
@abstractmethod
|
||||
def create_download_job(
|
||||
self,
|
||||
source: str,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
destdir: Path,
|
||||
filename: Optional[Path] = None,
|
||||
start: bool = True,
|
||||
@ -26,7 +27,7 @@ class DownloadQueueServiceBase(ABC):
|
||||
"""
|
||||
Create a download job.
|
||||
|
||||
:param source: Source of the download - URL or repo_id
|
||||
:param source: Source of the download - URL, repo_id or local Path
|
||||
:param destdir: Directory to download into.
|
||||
:param filename: Optional name of file, if not provided
|
||||
will use the content-disposition field to assign the name.
|
||||
@ -126,13 +127,13 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
|
||||
def create_download_job(
|
||||
self,
|
||||
source: str,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
destdir: Path,
|
||||
filename: Optional[Path] = None,
|
||||
start: bool = True,
|
||||
access_token: Optional[str] = None,
|
||||
event_handlers: Optional[List[DownloadEventHandler]] = None,
|
||||
) -> DownloadJobBase:
|
||||
) -> DownloadJobBase: # noqa D102
|
||||
event_handlers = event_handlers or []
|
||||
if self._event_bus:
|
||||
event_handlers.append([self._event_bus.emit_model_download_event])
|
||||
@ -145,32 +146,32 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
event_handlers=event_handlers,
|
||||
)
|
||||
|
||||
def list_jobs(self) -> List[DownloadJobBase]:
|
||||
def list_jobs(self) -> List[DownloadJobBase]: # noqa D102
|
||||
return self._queue.list_jobs()
|
||||
|
||||
def id_to_job(self, id: int) -> DownloadJobBase:
|
||||
def id_to_job(self, id: int) -> DownloadJobBase: # noqa D102
|
||||
return self._queue.id_to_job(id)
|
||||
|
||||
def start_all_jobs(self):
|
||||
def start_all_jobs(self): # noqa D102
|
||||
return self._queue.start_all_jobs()
|
||||
|
||||
def pause_all_jobs(self):
|
||||
def pause_all_jobs(self): # noqa D102
|
||||
return self._queue.pause_all_jobs()
|
||||
|
||||
def cancel_all_jobs(self):
|
||||
def cancel_all_jobs(self): # noqa D102
|
||||
return self._queue.cancel_all_jobs()
|
||||
|
||||
def start_job(self, job: DownloadJobBase):
|
||||
def start_job(self, job: DownloadJobBase): # noqa D102
|
||||
return self._queue.start_job(id)
|
||||
|
||||
def pause_job(self, job: DownloadJobBase):
|
||||
def pause_job(self, job: DownloadJobBase): # noqa D102
|
||||
return self._queue.pause_job(id)
|
||||
|
||||
def cancel_job(self, job: DownloadJobBase):
|
||||
def cancel_job(self, job: DownloadJobBase): # noqa D102
|
||||
return self._queue.cancel_job(id)
|
||||
|
||||
def change_priority(self, job: DownloadJobBase, delta: int):
|
||||
def change_priority(self, job: DownloadJobBase, delta: int): # noqa D102
|
||||
return self._queue.change_priority(id, delta)
|
||||
|
||||
def join(self):
|
||||
def join(self): # noqa D102
|
||||
return self._queue.join()
|
||||
|
@ -7,7 +7,7 @@ from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from functools import total_ordering
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Callable
|
||||
from typing import List, Optional, Callable, Union
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
|
||||
@ -63,6 +63,12 @@ class DownloadJobBase(BaseModel):
|
||||
job_sequence: Optional[int] = Field(
|
||||
description="Counter that records order in which this job was dequeued (for debugging)"
|
||||
)
|
||||
subqueue: Optional["DownloadQueueBase"] = Field(
|
||||
description="a subqueue used for downloading repo_ids", default=None
|
||||
)
|
||||
preserve_partial_downloads: bool = Field(
|
||||
description="if true, then preserve partial downloads when cancelled or errored", default=False
|
||||
)
|
||||
error: Optional[Exception] = Field(default=None, description="Exception that caused an error")
|
||||
|
||||
def add_event_handler(self, handler: DownloadEventHandler):
|
||||
@ -96,7 +102,7 @@ class DownloadQueueBase(ABC):
|
||||
@abstractmethod
|
||||
def create_download_job(
|
||||
self,
|
||||
source: str,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
destdir: Path,
|
||||
filename: Optional[Path] = None,
|
||||
start: bool = True,
|
||||
@ -107,7 +113,7 @@ class DownloadQueueBase(ABC):
|
||||
"""
|
||||
Create a download job.
|
||||
|
||||
:param source: Source of the download - URL or repo_id
|
||||
:param source: Source of the download - URL, repo_id or Path
|
||||
:param destdir: Directory to download into.
|
||||
:param filename: Optional name of file, if not provided
|
||||
will use the content-disposition field to assign the name.
|
||||
@ -165,8 +171,12 @@ class DownloadQueueBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_all_jobs(self):
|
||||
"""Cancel all active and enquedjobs."""
|
||||
def cancel_all_jobs(self, preserve_partial: bool = False):
|
||||
"""
|
||||
Cancel all active and enquedjobs.
|
||||
|
||||
:param preserve_partial: Keep partially downloaded files [False].
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@ -180,8 +190,12 @@ class DownloadQueueBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_job(self, job: DownloadJobBase):
|
||||
"""Cancel the job, clearing partial downloads and putting it into ERROR state."""
|
||||
def cancel_job(self, job: DownloadJobBase, preserve_partial: bool = False):
|
||||
"""
|
||||
Cancel the job, clearing partial downloads and putting it into CANCELLED state.
|
||||
|
||||
:param preserve_partial: Keep partial downloads [False]
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
@ -12,7 +12,7 @@ import traceback
|
||||
from json import JSONDecodeError
|
||||
from pathlib import Path
|
||||
from requests import HTTPError
|
||||
from typing import Dict, Optional, Set, List, Tuple
|
||||
from typing import Dict, Optional, Set, List, Tuple, Union
|
||||
|
||||
from pydantic import Field, validator, ValidationError
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
@ -32,6 +32,9 @@ from .base import (
|
||||
)
|
||||
from ..storage import DuplicateModelException
|
||||
|
||||
# Maximum number of bytes to download during each call to requests.iter_content()
|
||||
DOWNLOAD_CHUNK_SIZE = 100000
|
||||
|
||||
# marker that the queue is done and that thread should exit
|
||||
STOP_JOB = DownloadJobBase(id=-99, priority=-99, source="dummy", destination="/")
|
||||
|
||||
@ -61,6 +64,12 @@ class DownloadJobRepoID(DownloadJobBase):
|
||||
return v
|
||||
|
||||
|
||||
class DownloadJobPath(DownloadJobBase):
|
||||
"""Handle file paths."""
|
||||
|
||||
source: Path = Field(description="Path to a file or directory to install")
|
||||
|
||||
|
||||
class DownloadQueue(DownloadQueueBase):
|
||||
"""Class for queued download of models."""
|
||||
|
||||
@ -74,6 +83,10 @@ class DownloadQueue(DownloadQueueBase):
|
||||
_sequence: int = 0 # This is for debugging and used to tag jobs in dequeueing order
|
||||
_requests: requests.sessions.Session
|
||||
|
||||
# for debugging
|
||||
_gets: int = 0
|
||||
_dones: int = 0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_parallel_dl: int = 5,
|
||||
@ -99,9 +112,13 @@ class DownloadQueue(DownloadQueueBase):
|
||||
|
||||
self._start_workers(max_parallel_dl)
|
||||
|
||||
# debugging - get rid of this
|
||||
self._gets = 0
|
||||
self._dones = 0
|
||||
|
||||
def create_download_job(
|
||||
self,
|
||||
source: str,
|
||||
source: Union[str, Path, AnyHttpUrl],
|
||||
destdir: Path,
|
||||
filename: Optional[Path] = None,
|
||||
start: bool = True,
|
||||
@ -110,12 +127,18 @@ class DownloadQueue(DownloadQueueBase):
|
||||
event_handlers: Optional[List[DownloadEventHandler]] = None,
|
||||
) -> DownloadJobBase:
|
||||
"""Create a download job and return its ID."""
|
||||
if re.match(r"^[\w-]+/[\w-]+$", source):
|
||||
kwargs = dict()
|
||||
|
||||
if Path(source).exists():
|
||||
cls = DownloadJobPath
|
||||
elif re.match(r"^[\w-]+/[\w-]+$", str(source)):
|
||||
cls = DownloadJobRepoID
|
||||
kwargs = dict(variant=variant)
|
||||
else:
|
||||
elif re.match(r"^https?://", str(source)):
|
||||
cls = DownloadJobURL
|
||||
kwargs = dict()
|
||||
else:
|
||||
raise NotImplementedError(f"Don't know what to do with this type of source: {source}")
|
||||
|
||||
try:
|
||||
self._lock.acquire()
|
||||
id = self._next_job_id
|
||||
@ -160,7 +183,7 @@ class DownloadQueue(DownloadQueueBase):
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def cancel_job(self, job: DownloadJobBase):
|
||||
def cancel_job(self, job: DownloadJobBase, preserve_partial: bool = False):
|
||||
"""
|
||||
Cancel the indicated job.
|
||||
|
||||
@ -170,7 +193,10 @@ class DownloadQueue(DownloadQueueBase):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
assert isinstance(self._jobs[job.id], DownloadJobBase)
|
||||
job.preserve_partial_downloads = preserve_partial
|
||||
self._update_job_status(job, DownloadJobStatus.CANCELLED)
|
||||
if job.subqueue:
|
||||
job.subqueue.cancel_all_jobs(preserve_partial=preserve_partial)
|
||||
except (AssertionError, KeyError) as excp:
|
||||
raise UnknownJobIDException("Unrecognized job") from excp
|
||||
finally:
|
||||
@ -196,13 +222,16 @@ class DownloadQueue(DownloadQueueBase):
|
||||
"""
|
||||
Pause (dequeue) the indicated job.
|
||||
|
||||
In theory the job can be restarted and the download will pick up
|
||||
The job can be restarted with start_job() and the download will pick up
|
||||
from where it left off.
|
||||
"""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
assert isinstance(self._jobs[job.id], DownloadJobBase)
|
||||
self._update_job_status(job, DownloadJobStatus.PAUSED)
|
||||
if job.subqueue:
|
||||
job.subqueue.cancel_all_jobs(preserve_partial=True)
|
||||
job.subqueue.release()
|
||||
except (AssertionError, KeyError) as excp:
|
||||
raise UnknownJobIDException("Unrecognized job") from excp
|
||||
finally:
|
||||
@ -213,7 +242,7 @@ class DownloadQueue(DownloadQueueBase):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
for job in self._jobs.values():
|
||||
if job.status in [DownloadJobStatus.IDLE or DownloadJobStatus.PAUSED]:
|
||||
if job.status in [DownloadJobStatus.IDLE, DownloadJobStatus.PAUSED]:
|
||||
self.start_job(job)
|
||||
finally:
|
||||
self._lock.release()
|
||||
@ -222,19 +251,19 @@ class DownloadQueue(DownloadQueueBase):
|
||||
"""Pause all running jobs."""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
for id, job in self._jobs.items():
|
||||
if job.stats == DownloadJobStatus.RUNNING:
|
||||
self.pause_job(id)
|
||||
for job in self._jobs.values():
|
||||
if job.status == DownloadJobStatus.RUNNING:
|
||||
self.pause_job(job)
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def cancel_all_jobs(self):
|
||||
def cancel_all_jobs(self, preserve_partial: bool = False):
|
||||
"""Cancel all running jobs."""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
for id, job in self._jobs.items():
|
||||
for job in self._jobs.values():
|
||||
if not self._in_terminal_state(job):
|
||||
self.cancel_job(id)
|
||||
self.cancel_job(job, preserve_partial)
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
@ -254,10 +283,12 @@ class DownloadQueue(DownloadQueueBase):
|
||||
|
||||
def _download_next_item(self):
|
||||
"""Worker thread gets next job on priority queue."""
|
||||
while True:
|
||||
done = False
|
||||
while not done:
|
||||
job = self._queue.get()
|
||||
self._gets += 1
|
||||
|
||||
try:
|
||||
try: # this is for debugging priority
|
||||
self._lock.acquire()
|
||||
job.job_sequence = self._sequence
|
||||
self._sequence += 1
|
||||
@ -265,13 +296,16 @@ class DownloadQueue(DownloadQueueBase):
|
||||
self._lock.release()
|
||||
|
||||
if job == STOP_JOB: # marker that queue is done
|
||||
break
|
||||
done = True
|
||||
|
||||
if job.status == DownloadJobStatus.ENQUEUED: # Don't do anything for non-enqueued jobs (shouldn't happen)
|
||||
# There should be a better way to dispatch on the job type
|
||||
if isinstance(job, DownloadJobURL):
|
||||
self._download_with_resume(job)
|
||||
elif isinstance(job, DownloadJobRepoID):
|
||||
self._download_repoid(job)
|
||||
elif isinstance(job, DownloadJobPath):
|
||||
self._download_path(job)
|
||||
else:
|
||||
raise NotImplementedError(f"Don't know what to do with this job: {job}")
|
||||
|
||||
@ -280,6 +314,8 @@ class DownloadQueue(DownloadQueueBase):
|
||||
|
||||
if self._in_terminal_state(job):
|
||||
del self._jobs[job.id]
|
||||
|
||||
self._dones += 1
|
||||
self._queue.task_done()
|
||||
|
||||
def _fetch_metadata(self, job: DownloadJobBase) -> Tuple[AnyHttpUrl, ModelSourceMetadata]:
|
||||
@ -375,7 +411,7 @@ class DownloadQueue(DownloadQueueBase):
|
||||
if resp.status_code == 206 or exist_size > 0:
|
||||
self._logger.warning(f"{dest}: partial file found. Resuming")
|
||||
elif resp.status_code != 200:
|
||||
raise HTTPError(f"status code {resp.status_code}: {resp.reason}")
|
||||
raise HTTPError(resp.reason)
|
||||
else:
|
||||
self._logger.info(f"{job.source}: Downloading {job.destination}")
|
||||
|
||||
@ -384,7 +420,7 @@ class DownloadQueue(DownloadQueueBase):
|
||||
|
||||
self._update_job_status(job, DownloadJobStatus.RUNNING)
|
||||
with open(dest, open_mode) as file:
|
||||
for data in resp.iter_content(chunk_size=100000):
|
||||
for data in resp.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
|
||||
if job.status != DownloadJobStatus.RUNNING: # cancelled, paused or errored
|
||||
return
|
||||
job.bytes += file.write(data)
|
||||
@ -393,6 +429,8 @@ class DownloadQueue(DownloadQueueBase):
|
||||
self._update_job_status(job)
|
||||
|
||||
self._update_job_status(job, DownloadJobStatus.COMPLETED)
|
||||
except KeyboardInterrupt as excp:
|
||||
raise excp
|
||||
except DuplicateModelException as excp:
|
||||
self._logger.error(f"A model with the same hash as {dest} is already installed.")
|
||||
job.error = excp
|
||||
@ -426,6 +464,8 @@ class DownloadQueue(DownloadQueueBase):
|
||||
for handler in job.event_handlers:
|
||||
try:
|
||||
handler(job)
|
||||
except KeyboardInterrupt as excp:
|
||||
raise excp
|
||||
except Exception as excp:
|
||||
job.error = excp
|
||||
self._update_job_status(job, DownloadJobStatus.ERROR)
|
||||
@ -442,7 +482,7 @@ class DownloadQueue(DownloadQueueBase):
|
||||
|
||||
if subjob.status == DownloadJobStatus.ERROR:
|
||||
job.error = subjob.error
|
||||
subqueue.cancel_all_jobs()
|
||||
subjob.subqueue.cancel_all_jobs()
|
||||
self._update_job_status(job, DownloadJobStatus.ERROR)
|
||||
return
|
||||
|
||||
@ -452,7 +492,7 @@ class DownloadQueue(DownloadQueueBase):
|
||||
self._update_job_status(job, DownloadJobStatus.RUNNING)
|
||||
return
|
||||
|
||||
subqueue = self.__class__(
|
||||
job.subqueue = self.__class__(
|
||||
event_handlers=[subdownload_event],
|
||||
requests_session=self._requests,
|
||||
)
|
||||
@ -460,28 +500,32 @@ class DownloadQueue(DownloadQueueBase):
|
||||
repo_id = job.source
|
||||
variant = job.variant
|
||||
urls_to_download, metadata = self._get_repo_info(repo_id, variant)
|
||||
job.destination = job.destination / Path(repo_id).name
|
||||
if job.destination.stem != Path(repo_id).stem:
|
||||
job.destination = job.destination / Path(repo_id).stem
|
||||
job.metadata = metadata
|
||||
bytes_downloaded = dict()
|
||||
job.total_bytes = 0
|
||||
|
||||
for url, subdir, file, size in urls_to_download:
|
||||
job.total_bytes += size
|
||||
subqueue.create_download_job(
|
||||
job.subqueue.create_download_job(
|
||||
source=url,
|
||||
destdir=job.destination / subdir,
|
||||
filename=file,
|
||||
variant=variant,
|
||||
access_token=job.access_token,
|
||||
)
|
||||
except KeyboardInterrupt as excp:
|
||||
raise excp
|
||||
except Exception as excp:
|
||||
job.error = excp
|
||||
self._update_job_status(job, DownloadJobStatus.ERROR)
|
||||
self._logger.error(job.error)
|
||||
finally:
|
||||
subqueue.join()
|
||||
if not job.status == DownloadJobStatus.ERROR:
|
||||
job.subqueue.join()
|
||||
if job.status == DownloadJobStatus.RUNNING:
|
||||
self._update_job_status(job, DownloadJobStatus.COMPLETED)
|
||||
subqueue.release() # get rid of the subqueue
|
||||
job.subqueue.release() # get rid of the subqueue
|
||||
|
||||
def _get_repo_info(
|
||||
self,
|
||||
@ -543,7 +587,22 @@ class DownloadQueue(DownloadQueueBase):
|
||||
result.add(v)
|
||||
return result
|
||||
|
||||
def _download_path(self, job: DownloadJobBase):
|
||||
"""Call when the source is a Path or pathlike object."""
|
||||
source = Path(job.source).resolve()
|
||||
destination = Path(job.destination).resolve()
|
||||
job.metadata = ModelSourceMetadata()
|
||||
try:
|
||||
if source != destination:
|
||||
shutil.move(source, destination)
|
||||
self._update_job_status(job, DownloadJobStatus.COMPLETED)
|
||||
except OSError as excp:
|
||||
job.error = excp
|
||||
self._update_job_status(job, DownloadJobStatus.ERROR)
|
||||
|
||||
def _cleanup_cancelled_job(self, job: DownloadJobBase):
|
||||
if job.preserve_partial_downloads:
|
||||
return
|
||||
self._logger.warning("Cleaning up leftover files from cancelled download job {job.destination}")
|
||||
dest = Path(job.destination)
|
||||
if dest.is_file():
|
||||
|
@ -15,17 +15,18 @@ Typical usage:
|
||||
installer = ModelInstall(store=store, config=config, download=download)
|
||||
|
||||
# register config, don't move path
|
||||
id: str = installer.register_model('/path/to/model')
|
||||
id: str = installer.register_path('/path/to/model')
|
||||
|
||||
# register config, and install model in `models`
|
||||
id: str = installer.install_model('/path/to/model')
|
||||
id: str = installer.install_path('/path/to/model')
|
||||
|
||||
# download some remote models and install them in the background
|
||||
installer.download('stabilityai/stable-diffusion-2-1')
|
||||
installer.download('https://civitai.com/api/download/models/154208')
|
||||
installer.download('runwayml/stable-diffusion-v1-5')
|
||||
installer.install('stabilityai/stable-diffusion-2-1')
|
||||
installer.install('https://civitai.com/api/download/models/154208')
|
||||
installer.install('runwayml/stable-diffusion-v1-5')
|
||||
installer.install('/home/user/models/stable-diffusion-v1-5', inplace=True)
|
||||
|
||||
installed_ids = installer.wait_for_downloads()
|
||||
installed_ids = installer.wait_for_installs()
|
||||
id1 = installed_ids['stabilityai/stable-diffusion-2-1']
|
||||
id2 = installed_ids['https://civitai.com/api/download/models/154208']
|
||||
|
||||
@ -94,8 +95,14 @@ class ModelInstallBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def register(self, model_path: Union[Path, str]) -> str:
|
||||
def queue(self) -> DownloadQueueBase:
|
||||
"""Return the download queue used by the installer."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def register_path(self, model_path: Union[Path, str]) -> str:
|
||||
"""
|
||||
Probe and register the model at model_path.
|
||||
|
||||
@ -105,7 +112,7 @@ class ModelInstallBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def install(self, model_path: Union[Path, str]) -> str:
|
||||
def install_path(self, model_path: Union[Path, str]) -> str:
|
||||
"""
|
||||
Probe, register and install the model in the models directory.
|
||||
|
||||
@ -118,9 +125,11 @@ class ModelInstallBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def download(self, source: Union[str, AnyHttpUrl]) -> DownloadJobBase:
|
||||
def install(
|
||||
self, source: Union[str, Path, AnyHttpUrl], inplace: bool = True, variant: Optional[str] = None
|
||||
) -> DownloadJobBase:
|
||||
"""
|
||||
Download and install the model located at remote site.
|
||||
Download and install the indicated model.
|
||||
|
||||
This will download the model located at `source`,
|
||||
probe it, and install it into the models directory.
|
||||
@ -128,18 +137,25 @@ class ModelInstallBase(ABC):
|
||||
thread, and the returned object is a
|
||||
invokeai.backend.model_manager.download.DownloadJobBase
|
||||
object which can be interrogated to get the status of
|
||||
the download and install process. Call our `wait_for_downloads()`
|
||||
method to wait for all downloads to complete.
|
||||
the download and install process. Call our `wait_for_installs()`
|
||||
method to wait for all downloads and installations to complete.
|
||||
|
||||
:param source: Either a URL or a HuggingFace repo_id.
|
||||
:returns queue: DownloadQueueBase object.
|
||||
:param inplace: If True, local paths will not be moved into
|
||||
the models directory, but registered in place (the default).
|
||||
:param variant: For HuggingFace models, this optional parameter
|
||||
specifies which variant to download (e.g. 'fp16')
|
||||
:returns DownloadQueueBase object.
|
||||
|
||||
The `inplace` flag does not affect the behavior of downloaded
|
||||
models, which are always moved into the `models` directory.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def wait_for_downloads(self) -> Dict[str, str]:
|
||||
def wait_for_installs(self) -> Dict[str, str]:
|
||||
"""
|
||||
Wait for all pending downloads to complete.
|
||||
Wait for all pending installs to complete.
|
||||
|
||||
This will block until all pending downloads have
|
||||
completed, been cancelled, or errored out. It will
|
||||
@ -147,7 +163,7 @@ class ModelInstallBase(ABC):
|
||||
paused state.
|
||||
|
||||
It will return a dict that maps the source model
|
||||
URL or repo_id to the ID of the installed model.
|
||||
path, URL or repo_id to the ID of the installed model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -259,7 +275,12 @@ class ModelInstall(ModelInstallBase):
|
||||
self._async_installs = dict()
|
||||
self._tmpdir = None
|
||||
|
||||
def register(self, model_path: Union[Path, str]) -> str: # noqa D102
|
||||
@property
|
||||
def queue(self) -> DownloadQueueBase:
|
||||
"""Return the queue."""
|
||||
return self._download_queue
|
||||
|
||||
def register_path(self, model_path: Union[Path, str]) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
info: ModelProbeInfo = ModelProbe.probe(model_path)
|
||||
return self._register(model_path, info)
|
||||
@ -293,7 +314,7 @@ class ModelInstall(ModelInstallBase):
|
||||
self._store.add_model(id, registration_data)
|
||||
return id
|
||||
|
||||
def install(self, model_path: Union[Path, str]) -> str: # noqa D102
|
||||
def install_path(self, model_path: Union[Path, str]) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
info: ModelProbeInfo = ModelProbe.probe(model_path)
|
||||
dest_path = self._config.models_path / info.base_type.value / info.model_type.value / model_path.name
|
||||
@ -318,20 +339,21 @@ class ModelInstall(ModelInstallBase):
|
||||
rmtree(model.path)
|
||||
self.unregister(id)
|
||||
|
||||
def download(self, source: Union[str, AnyHttpUrl]) -> DownloadJobBase: # noqa D102
|
||||
def install(
|
||||
self, source: Union[str, Path, AnyHttpUrl], inplace: bool = True, variant: Optional[str] = None
|
||||
) -> DownloadJobBase: # noqa D102
|
||||
# choose a temporary directory inside the models directory
|
||||
models_dir = self._config.models_path
|
||||
queue = self._download_queue
|
||||
self._async_installs[source] = None
|
||||
|
||||
def complete_installation(job: DownloadJobBase):
|
||||
if job.status == "completed":
|
||||
self._logger.info(f"{job.source}: Download finished with status {job.status}. Installing.")
|
||||
model_id = self.install(job.destination)
|
||||
model_id = self.install_path(job.destination)
|
||||
info = self._store.get_model(model_id)
|
||||
info.source = str(job.source)
|
||||
metadata: ModelSourceMetadata = job.metadata
|
||||
info.description = metadata.description or f"Downloaded model {info.name}"
|
||||
info.description = metadata.description or f"Imported model {info.name}"
|
||||
info.author = metadata.author
|
||||
info.tags = metadata.tags
|
||||
info.license = metadata.license
|
||||
@ -339,22 +361,46 @@ class ModelInstall(ModelInstallBase):
|
||||
self._store.update_model(model_id, info)
|
||||
self._async_installs[job.source] = model_id
|
||||
elif job.status == "error":
|
||||
self._logger.warning(f"{job.source}: Download finished with error: {job.error}")
|
||||
self._logger.warning(f"{job.source}: Model installation error: {job.error}")
|
||||
elif job.status == "cancelled":
|
||||
self._logger.warning(f"{job.source}: Download cancelled at caller's request.")
|
||||
self._logger.warning(f"{job.source}: Model installation cancelled at caller's request.")
|
||||
jobs = queue.list_jobs()
|
||||
if len(jobs) <= 1 and job.status in ["completed", "error", "cancelled"]:
|
||||
if self._tmpdir and len(jobs) <= 1 and job.status in ["completed", "error", "cancelled"]:
|
||||
self._tmpdir.cleanup()
|
||||
self._tmpdir = None
|
||||
|
||||
# note - this is probably not going to work. The tmpdir
|
||||
# will be deleted before the job actually runs.
|
||||
# Better to do the cleanup in the callback
|
||||
self._tmpdir = self._tmpdir or tempfile.TemporaryDirectory(dir=models_dir)
|
||||
job = queue.create_download_job(source=source, destdir=self._tmpdir.name)
|
||||
job.add_event_handler(complete_installation)
|
||||
def complete_registration(job: DownloadJobBase):
|
||||
if job.status == "completed":
|
||||
self._logger.info(f"{job.source}: Installing in place.")
|
||||
model_id = self.register_path(job.destination)
|
||||
info = self._store.get_model(model_id)
|
||||
info.source = str(job.source)
|
||||
info.description = f"Imported model {info.name}"
|
||||
self._store.update_model(model_id, info)
|
||||
self._async_installs[job.source] = model_id
|
||||
elif job.status == "error":
|
||||
self._logger.warning(f"{job.source}: Model installation error: {job.error}")
|
||||
elif job.status == "cancelled":
|
||||
self._logger.warning(f"{job.source}: Model installation cancelled at caller's request.")
|
||||
|
||||
def wait_for_downloads(self) -> Dict[str, str]: # noqa D102
|
||||
# In the event that we are being asked to install a path that is already on disk,
|
||||
# we simply probe and register/install it. The job does not actually do anything, but we
|
||||
# create one anyway in order to have similar behavior for local files, URLs and repo_ids.
|
||||
if Path(source).exists(): # a path that is already on disk
|
||||
source = Path(source)
|
||||
destdir = source
|
||||
job = queue.create_download_job(source=source, destdir=destdir, start=False, variant=variant)
|
||||
job.add_event_handler(complete_registration if inplace else complete_installation)
|
||||
else:
|
||||
self._tmpdir = self._tmpdir or tempfile.TemporaryDirectory(dir=models_dir)
|
||||
job = queue.create_download_job(source=source, destdir=self._tmpdir.name, start=False, variant=variant)
|
||||
job.add_event_handler(complete_installation)
|
||||
|
||||
self._async_installs[source] = None
|
||||
queue.start_job(job)
|
||||
return job
|
||||
|
||||
def wait_for_installs(self) -> Dict[str, str]: # noqa D102
|
||||
self._download_queue.join()
|
||||
id_map = self._async_installs
|
||||
self._async_installs = dict()
|
||||
|
@ -13,9 +13,17 @@ from invokeai.backend.model_manager.download import (
|
||||
DownloadJobBase,
|
||||
UnknownJobIDException,
|
||||
)
|
||||
import invokeai.backend.model_manager.download.queue as download_queue
|
||||
|
||||
# Allow for at least one chunk to be fetched during the pause/unpause test.
|
||||
# Otherwise pause test doesn't work because whole file contents are read
|
||||
# before pause is received.
|
||||
download_queue.DOWNLOAD_CHUNK_SIZE = 16500
|
||||
|
||||
# Prevent pytest deprecation warnings
|
||||
TestAdapter.__test__ = False
|
||||
|
||||
# Disable some tests that require the internet.
|
||||
INTERNET_AVAILABLE = requests.get("http://www.google.com/").status_code == 200
|
||||
|
||||
########################################################################################
|
||||
@ -264,7 +272,6 @@ def test_pause_cancel_url(): # this one is tricky because of potential race con
|
||||
time.sleep(0.5) # slow down the thread by blocking it just a bit at every step
|
||||
|
||||
queue = DownloadQueue(requests_session=session, event_handlers=[event_handler])
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
job1 = queue.create_download_job(source="http://www.civitai.com/models/12345", destdir=tmpdir, start=False)
|
||||
job2 = queue.create_download_job(source="http://www.civitai.com/models/9999", destdir=tmpdir, start=False)
|
||||
|
Reference in New Issue
Block a user