added download manager service and began repo_id download

Lincoln Stein
2023-09-04 18:26:28 -04:00
parent 869f310ae7
commit 8fc20925b5
4 changed files with 446 additions and 189 deletions

View File

@@ -0,0 +1,178 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Model download service.
"""
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional, List
from .events import EventServicesBase
from invokeai.backend.model_manager.download import DownloadQueue, DownloadJobBase, DownloadEventHandler
class DownloadQueueServiceBase(ABC):
"""Multithreaded queue for downloading models via URL or repo_id."""
@abstractmethod
def create_download_job(
self,
source: str,
destdir: Path,
filename: Optional[Path] = None,
start: bool = True,
access_token: Optional[str] = None,
) -> int:
"""
Create a download job.
:param source: Source of the download - URL or repo_id
:param destdir: Directory to download into.
:param filename: Optional name for the downloaded file; if not provided,
the name is taken from the Content-Disposition header.
:param start: Immediately start the job [default True].
:param access_token: Optional access token for authenticated downloads.
:returns job id: The numeric ID of the DownloadJobBase object for this task.
"""
pass
@abstractmethod
def list_jobs(self) -> List[DownloadJobBase]:
"""
List active DownloadJobBases.
:returns List[DownloadJobBase]: List of download jobs whose state is not "completed."
"""
pass
@abstractmethod
def id_to_job(self, id: int) -> DownloadJobBase:
"""
Return the DownloadJobBase corresponding to the numeric job ID.
:param id: ID of the DownloadJobBase.
Exceptions:
* UnknownJobIDException
"""
pass
@abstractmethod
def start_all_jobs(self):
"""Enqueue all stopped jobs."""
pass
@abstractmethod
def pause_all_jobs(self):
"""Pause and dequeue all active jobs."""
pass
@abstractmethod
def cancel_all_jobs(self):
"""Cancel all active and enquedjobs."""
pass
@abstractmethod
def start_job(self, id: int):
"""Start the job putting it into ENQUEUED state."""
pass
@abstractmethod
def pause_job(self, id: int):
"""Pause the job, putting it into PAUSED state."""
pass
@abstractmethod
def cancel_job(self, id: int):
"""Cancel the job, clearing partial downloads and putting it into ERROR state."""
pass
@abstractmethod
def change_priority(self, id: int, delta: int):
"""
Change the job's priority.
:param id: ID of the job
:param delta: Amount to add to the job's current priority.
Lower values are higher priority. The default starting value is 10,
so queue.change_priority(id, -10) makes a job very high priority.
"""
pass
@abstractmethod
def join(self):
"""Wait until all jobs are off the queue."""
pass
class DownloadQueueService(DownloadQueueServiceBase):
"""Multithreaded queue for downloading models via URL or repo_id."""
_event_bus: EventServicesBase
_queue: DownloadQueue
def __init__(self, event_bus: EventServicesBase, **kwargs):
"""
Initialize new DownloadQueueService object.
:param event_bus: EventServicesBase object for reporting progress.
:param **kwargs: Any of the arguments taken by invokeai.backend.model_manager.download.DownloadQueue.
e.g. `max_parallel_dl`.
"""
self._event_bus = event_bus
self._queue = DownloadQueue(event_handler=self._forward_event, **kwargs)  # pass through DownloadQueue options such as max_parallel_dl
def _forward_event(self, job: DownloadJobBase):
if self._event_bus:
self._event_bus.emit_model_download_event(job)
def _wrap_handler(self, event_handler: DownloadEventHandler) -> DownloadEventHandler:
def __wrapper(job: DownloadJobBase):
self._forward_event(job)
event_handler(job)
return __wrapper
def create_download_job(
self,
source: str,
destdir: Path,
filename: Optional[Path] = None,
start: bool = True,
access_token: Optional[str] = None,
event_handler: Optional[DownloadEventHandler] = None,
) -> int:
if event_handler:
event_handler = self._wrap_handler(event_handler)
return self._queue.create_download_job(
source, destdir, filename=filename, start=start, access_token=access_token,
event_handler=event_handler
)
def list_jobs(self) -> List[DownloadJobBase]:
return self._queue.list_jobs()
def id_to_job(self, id: int) -> DownloadJobBase:
return self._queue.id_to_job(id)
def start_all_jobs(self):
return self._queue.start_all_jobs()
def pause_all_jobs(self):
return self._queue.pause_all_jobs()
def cancel_all_jobs(self):
return self._queue.cancel_all_jobs()
def start_job(self, id: int):
return self._queue.start_job(id)
def pause_job(self, id: int):
return self._queue.pause_job(id)
def cancel_job(self, id: int):
return self._queue.cancel_job(id)
def change_priority(self, id: int, delta: int):
return self._queue.change_priority(id, delta)
def join(self):
return self._queue.join()
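For orientation, here is a minimal sketch of driving the new service from application code. It assumes the module shown above is importable as invokeai.app.services.download_manager (a guess; the file path is not shown in the diff), stubs out the event bus, and uses a purely illustrative URL and destination; it is not part of the commit.

from pathlib import Path

from invokeai.app.services.download_manager import DownloadQueueService  # assumed module path

class NullEventBus:
    """Stand-in for EventServicesBase; just prints download events."""

    def emit_model_download_event(self, job):
        print(f"job {job.id}: {job.status} {job.bytes}/{job.total_bytes}")

queue = DownloadQueueService(event_bus=NullEventBus(), max_parallel_dl=2)
job_id = queue.create_download_job(
    source="https://example.com/models/model.safetensors",  # illustrative URL
    destdir=Path("./models"),
)
queue.join()  # block until all queued downloads have finished
print(queue.id_to_job(job_id).status)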

View File

@@ -0,0 +1,12 @@
"""Initialization file for threaded download manager."""
from .base import ( # noqa F401
DownloadQueueBase,
DownloadJobStatus,
DownloadEventHandler,
UnknownJobIDException,
CancelledJobException,
DownloadJobBase
)
from .queue import DownloadQueue # noqa F401
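The re-exported DownloadEventHandler type is simply Callable[[DownloadJobBase], None]. Below is a hedged sketch of a per-job progress callback wired straight into the backend queue, using the import path shown above; the URL and destination are illustrative and the snippet is not part of the commit.

from pathlib import Path

from invokeai.backend.model_manager.download import (
    DownloadJobBase,
    DownloadJobStatus,
    DownloadQueue,
)

def on_progress(job: DownloadJobBase) -> None:
    # Invoked by a worker thread on every status or byte-count change.
    if job.status == DownloadJobStatus.RUNNING and job.total_bytes:
        print(f"{job.source}: {100 * job.bytes / job.total_bytes:.1f}%")
    elif job.status == DownloadJobStatus.ERROR:
        print(f"{job.source} failed: {job.error}")

queue = DownloadQueue(max_parallel_dl=2)
queue.create_download_job(
    source="https://example.com/models/model.safetensors",  # illustrative URL
    destdir=Path("./models"),
    event_handler=on_progress,
)
queue.join()
queue.release()  # stop the worker threads once the queue is drained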

View File

@@ -0,0 +1,177 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Abstract base class for a multithreaded model download queue.
"""
from abc import ABC, abstractmethod
from enum import Enum
from functools import total_ordering
from pathlib import Path
from typing import Set, List, Optional, Dict, Callable
from pydantic import BaseModel, Field, validator, ValidationError
from pydantic.networks import AnyHttpUrl
class DownloadJobStatus(str, Enum):
"""State of a download job."""
IDLE = "idle" # not enqueued, will not run
ENQUEUED = "enqueued" # enqueued but not yet active
RUNNING = "running" # actively downloading
PAUSED = "paused" # previously started, now paused
COMPLETED = "completed" # finished running
ERROR = "error" # terminated with an error message
class UnknownJobIDException(Exception):
"""Raised when an invalid Job ID is requested."""
class CancelledJobException(Exception):
"""Raised when a job is cancelled."""
DownloadEventHandler = Callable[["DownloadJobBase"], None]
@total_ordering
class DownloadJobBase(BaseModel):
"""Class to monitor and control a model download request."""
priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
id: int = Field(description="Numeric ID of this job")
source: str = Field(description="URL or repo_id to download")
destination: Path = Field(description="Destination of URL on local disk")
access_token: Optional[str] = Field(description="access token needed to access this resource")
status: DownloadJobStatus = Field(default=DownloadJobStatus.IDLE, description="Status of the download")
bytes: int = Field(default=0, description="Bytes downloaded so far")
total_bytes: int = Field(default=0, description="Total bytes to download")
event_handler: Optional[DownloadEventHandler] = Field(description="Callable that will be called whenever the job status changes")
error: Optional[Exception] = Field(default=None, description="Exception that caused an error")
class Config():
"""Config object for this pydantic class."""
arbitrary_types_allowed = True
validate_assignment = True
def __lt__(self, other: "DownloadJobBase") -> bool:
"""
Return True if this job should be dequeued before `other`.
Jobs are ordered by priority (lower values are higher priority), with the job ID as a tie-breaker.
:param other: The DownloadJobBase that this will be compared against.
"""
if not hasattr(other, "priority"):
return NotImplemented
return (self.priority, self.id) < (other.priority, other.id)
class DownloadQueueBase(ABC):
"""Abstract base class for managing model downloads."""
@abstractmethod
def create_download_job(
self,
source: str,
destdir: Path,
filename: Optional[Path] = None,
start: bool = True,
variant: Optional[str] = None,
access_token: Optional[str] = None,
event_handler: Optional[DownloadEventHandler] = None,
) -> int:
"""
Create a download job.
:param source: Source of the download - URL or repo_id
:param destdir: Directory to download into.
:param filename: Optional name for the downloaded file; if not provided,
the name is taken from the Content-Disposition header.
:param start: Immediately start the job [default True].
:param variant: Variant to download, such as "fp16" (repo_ids only).
:param access_token: Optional access token for authenticated downloads.
:param event_handler: Optional callable that will be called whenever job status changes.
:returns job id: The numeric ID of the DownloadJobBase object for this task.
"""
pass
@abstractmethod
def release(self) -> int:
"""
Release resources used by the queue.
If threaded downloads are used, this stops the worker threads.
"""
pass
@abstractmethod
def list_jobs(self) -> List[DownloadJobBase]:
"""
List active DownloadJobBases.
:returns List[DownloadJobBase]: List of download jobs whose state is not "completed."
"""
pass
@abstractmethod
def id_to_job(self, id: int) -> DownloadJobBase:
"""
Return the DownloadJobBase corresponding to the numeric job ID.
:param id: ID of the DownloadJobBase.
Exceptions:
* UnknownJobIDException
"""
pass
@abstractmethod
def start_all_jobs(self):
"""Enqueue all stopped jobs."""
pass
@abstractmethod
def pause_all_jobs(self):
"""Pause and dequeue all active jobs."""
pass
@abstractmethod
def cancel_all_jobs(self):
"""Cancel all active and enquedjobs."""
pass
@abstractmethod
def start_job(self, id: int):
"""Start the job putting it into ENQUEUED state."""
pass
@abstractmethod
def pause_job(self, id: int):
"""Pause the job, putting it into PAUSED state."""
pass
@abstractmethod
def cancel_job(self, id: int):
"""Cancel the job, clearing partial downloads and putting it into ERROR state."""
pass
@abstractmethod
def change_priority(self, id: int, delta: int):
"""
Change the job's priority.
:param id: ID of the job
:param delta: Amount to add to the job's current priority.
Lower values are higher priority. The default starting value is 10,
so queue.change_priority(id, -10) makes a job very high priority.
"""
pass
@abstractmethod
def join(self):
"""Wait until all jobs are off the queue."""
pass
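To make the job model concrete, here is a small hedged example of DownloadJobBase and its ordering semantics; the IDs, URLs, and destinations are purely illustrative and the snippet is not part of the commit.

from pathlib import Path

from invokeai.backend.model_manager.download import DownloadJobBase, DownloadJobStatus

# Lower priority values are served first; the default is 10.
urgent = DownloadJobBase(id=1, priority=0, source="https://example.com/a.bin", destination=Path("/tmp/a.bin"))
routine = DownloadJobBase(id=2, source="https://example.com/b.bin", destination=Path("/tmp/b.bin"))

assert urgent < routine                          # total_ordering derives the remaining comparisons
assert routine.status == DownloadJobStatus.IDLE  # jobs start idle until enqueued
routine.status = DownloadJobStatus.ENQUEUED      # validate_assignment re-validates on mutation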

View File

@@ -1,241 +1,113 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Abstract base class for a multithreaded model download queue.
"""
# Copyright (c) 2023, Lincoln D. Stein
"""Implementation of multithreaded download queue for invokeai."""
import os
import re
import os
import requests
import threading
from abc import ABC, abstractmethod
from enum import Enum
from functools import total_ordering
from queue import PriorityQueue
from pathlib import Path
from threading import Thread
from typing import Set, List, Optional, Dict, Callable
from typing import Dict, Optional, Set, List
from pydantic import BaseModel, Field, validator, PrivateAttr
from pydantic import Field, validator, ValidationError
from pydantic.networks import AnyHttpUrl
from queue import PriorityQueue
from invokeai.backend.util.logging import InvokeAILogger
from .base import (
DownloadQueueBase,
DownloadJobStatus,
DownloadEventHandler,
UnknownJobIDException,
CancelledJobException,
DownloadJobBase
)
# marker that the queue is done and that thread should exit
STOP_JOB = DownloadJobBase(
id=-99,
priority=-99,
source='dummy',
destination='/')
class EventServicesBase: # forward declaration
pass
class DownloadJobURL(DownloadJobBase):
"""Job declaration for downloading individual URLs."""
class DownloadJobStatus(str, Enum):
"""State of a download job."""
IDLE = "idle" # not enqueued, will not run
ENQUEUED = "enqueued" # enqueued but not yet active
RUNNING = "running" # actively downloading
PAUSED = "paused" # previously started, now paused
COMPLETED = "completed" # finished running
ERROR = "error" # terminated with an error message
source: AnyHttpUrl = Field(description="URL to download")
class UnknownJobIDException(Exception):
"""Raised when an invalid Job ID is requested."""
class DownloadJobRepoID(DownloadJobBase):
"""Download repo ids."""
variant: Optional[str] = Field(description="Variant, such as 'fp16', to download")
class CancelledJobException(Exception):
"""Raised when a job is cancelled."""
DownloadEventHandler = Callable[["DownloadJobBase"], None]
@total_ordering
class DownloadJob(BaseModel):
"""Class to monitor and control a model download request."""
priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
id: int = Field(description="Numeric ID of this job")
url: AnyHttpUrl = Field(description="URL to download")
destination: Path = Field(description="Destination of URL on local disk")
access_token: Optional[str] = Field(description="access token needed to access this resource")
status: DownloadJobStatus = Field(default=DownloadJobStatus.IDLE, description="Status of the download")
bytes: int = Field(default=0, description="Bytes downloaded so far")
total_bytes: int = Field(default=0, description="Total bytes to download")
event_handler: Optional[DownloadEventHandler] = Field(description="Callable will be called whenever job status changes")
error: Exception = Field(default=None, description="Exception that caused an error")
class Config():
"""Config object for this pydantic class."""
arbitrary_types_allowed = True
validate_assignment = True
# @validator('destination')
# def path_doesnt_exist(cls, v):
# """Don't allow a destination to clobber an existing file."""
# if v.exists():
# raise ValueError(f"{v} already exists")
# return v
def __lt__(self, other: "DownloadJob") -> bool:
"""
Return True if self.priority < other.priority.
:param other: The DownloadJob that this will be compared against.
"""
if not hasattr(other, "id"):
return NotImplemented
return self.id < other.id
class DownloadQueueBase(ABC):
"""Abstract base class for managing model downloads."""
@abstractmethod
def create_download_job(
self,
url: str,
destdir: Path,
filename: Optional[Path] = None,
start: bool = True,
access_token: Optional[str] = None,
event_handler: Optional[DownloadEventHandler] = None,
) -> int:
"""
Create a download job.
:param url: URL to download.
:param destdir: Directory to download into.
:param filename: Optional name of file, if not provided
will use the content-disposition field to assign the name.
:param start: Immediately start job [True]
:param event_handler: Optional callable that will be called whenever job status changes.
:returns job id: The numeric ID of the DownloadJob object for this task.
"""
pass
@abstractmethod
def list_jobs(self) -> List[DownloadJob]:
"""
List active DownloadJobs.
:returns List[DownloadJob]: List of download jobs whose state is not "completed."
"""
pass
@abstractmethod
def id_to_job(self, id: int) -> DownloadJob:
"""
Return the DownloadJob corresponding to the string ID.
:param id: ID of the DownloadJob.
Exceptions:
* UnknownJobIDException
"""
pass
@abstractmethod
def start_all_jobs(self):
"""Enqueue all stopped jobs."""
pass
@abstractmethod
def pause_all_jobs(self):
"""Pause and dequeue all active jobs."""
pass
@abstractmethod
def cancel_all_jobs(self):
"""Cancel all active and enquedjobs."""
pass
@abstractmethod
def start_job(self, id: int):
"""Start the job putting it into ENQUEUED state."""
pass
@abstractmethod
def pause_job(self, id: int):
"""Pause the job, putting it into PAUSED state."""
pass
@abstractmethod
def cancel_job(self, id: int):
"""Cancel the job, clearing partial downloads and putting it into ERROR state."""
pass
@abstractmethod
def change_priority(self, id: int, delta: int):
"""
Change the job's priority.
:param id: ID of the job
:param delta: Value to increment or decrement priority.
Lower values are higher priority. The default starting value is 10.
Thus to make this a really high priority job:
job.change_priority(-10).
"""
pass
@abstractmethod
def join(self):
"""Wait until all jobs are off the queue."""
pass
@validator('source')
@classmethod
def _validate_source(cls, v: str) -> str:
if not re.match(r'^\w+/\w+$', v):
raise ValueError(f"{v} is not a valid repo_id")  # pydantic validators signal bad values with ValueError
return v
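A quick illustration of how the new code decides between a URL job and a repo_id job; the pattern is copied verbatim from the diff, and the hyphen caveat is an observation about that pattern rather than something this commit addresses.

import re

REPO_ID_RE = r"^\w+/\w+$"  # same test used in create_download_job and the validator above

assert re.match(REPO_ID_RE, "author/model_name")              # dispatched to DownloadJobRepoID
assert not re.match(REPO_ID_RE, "https://example.com/m.bin")  # dispatched to DownloadJobURL
assert not re.match(REPO_ID_RE, "org/model-v1.5")             # \w excludes '-' and '.', so this form is rejected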
class DownloadQueue(DownloadQueueBase):
"""Class for queued download of models."""
_jobs: Dict[int, DownloadJob]
_worker_pool: Set[Thread]
_jobs: Dict[int, DownloadJobBase]
_worker_pool: Set[threading.Thread]
_queue: PriorityQueue
_lock: threading.Lock
_logger: InvokeAILogger
_event_bus: Optional[EventServicesBase]
_event_handler: Optional[DownloadEventHandler]
_next_job_id: int = 0
def __init__(self,
max_parallel_dl: int = 5,
events: Optional["EventServicesBase"] = None,
event_handler: Optional[DownloadEventHandler] = None,
):
"""
Initialize DownloadQueue.
:param max_parallel_dl: Number of simultaneous downloads allowed [5].
:param events: Optional EventServices bus for reporting events.
:param event_handler: Optional callable that will be called each time a job status changes.
"""
print('IN __INIT__')
self._jobs = dict()
self._next_job_id = 0
self._queue = PriorityQueue()
self._worker_pool = set()
self._lock = threading.RLock()
self._logger = InvokeAILogger.getLogger()
self._event_bus = events
self._event_handler = event_handler
self._start_workers(max_parallel_dl)
def create_download_job(
self,
url: str,
source: str,
destdir: Path,
filename: Optional[Path] = None,
start: bool = True,
variant: Optional[str] = None,
access_token: Optional[str] = None,
event_handler: Optional[DownloadEventHandler] = None,
) -> int:
if re.match(r'^\w+/\w+$', source):
cls = DownloadJobRepoID
kwargs = dict(variant=variant)
else:
cls = DownloadJobURL
kwargs = dict()
try:
self._lock.acquire()
id = self._next_job_id
self._jobs[id] = DownloadJob(
self._jobs[id] = cls(
id=id,
url=url,
source=source,
destination=Path(destdir) / (filename or "."),
access_token=access_token,
event_handler=(event_handler or self._event_handler),
**kwargs,
)
self._next_job_id += 1
job = self._jobs[id]
@@ -245,13 +117,18 @@ class DownloadQueue(DownloadQueueBase):
self.start_job(id)
return job.id
def release(self):
"""Signal our threads to exit when queue done."""
for thread in self._worker_pool:
if thread.is_alive():
self._queue.put(STOP_JOB)
def join(self):
self._queue.join()
def list_jobs(self) -> List[DownloadJob]:
def list_jobs(self) -> List[DownloadJobBase]:
return list(self._jobs.values())
def change_priority(self, id: int, delta: int):
try:
self._lock.acquire()
@@ -262,7 +139,7 @@ class DownloadQueue(DownloadQueueBase):
finally:
self._lock.release()
def cancel_job(self, job: DownloadJob):
def cancel_job(self, job: DownloadJobBase):
try:
self._lock.acquire()
job.status = DownloadJobStatus.ERROR
@@ -272,7 +149,7 @@ class DownloadQueue(DownloadQueueBase):
finally:
self._lock.release()
def id_to_job(self, id: int) -> DownloadJob:
def id_to_job(self, id: int) -> DownloadJobBase:
try:
return self._jobs[id]
except KeyError as excp:
@@ -322,7 +199,7 @@ class DownloadQueue(DownloadQueueBase):
def _start_workers(self, max_workers: int):
for i in range(0, max_workers):
worker = Thread(target=self._download_next_item, daemon=True)
worker = threading.Thread(target=self._download_next_item, daemon=True)
worker.start()
self._worker_pool.add(worker)
@@ -330,17 +207,25 @@ class DownloadQueue(DownloadQueueBase):
"""Worker thread gets next job on priority queue."""
while True:
job = self._queue.get()
if job == STOP_JOB: # marker that queue is done
print(f'DEBUG: thread {threading.current_thread().native_id} exiting')
break
if job.status == DownloadJobStatus.ENQUEUED: # Don't do anything for cancelled or errored jobs
self._download_with_resume(job)
if isinstance(job, DownloadJobURL):
self._download_with_resume(job)
elif isinstance(job, DownloadJobRepoID):
self._download_repoid(job)
else:
raise NotImplementedError(f"Don't know what to do with this job: {job}")
self._queue.task_done()
def _download_with_resume(self, job: DownloadJob):
def _download_with_resume(self, job: DownloadJobBase):
"""Do the actual download."""
header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
open_mode = "wb"
exist_size = 0
resp = requests.get(job.url, header, stream=True)
resp = requests.get(job.source, headers=header, stream=True)
content_length = int(resp.headers.get("content-length", 0))
job.total_bytes = content_length
@@ -348,8 +233,9 @@ class DownloadQueue(DownloadQueueBase):
try:
file_name = re.search('filename="(.+)"', resp.headers.get("Content-Disposition")).group(1)
except AttributeError:
file_name = os.path.basename(job.url)
dest = job.destination / file_name
file_name = os.path.basename(job.source)
job.destination = job.destination / file_name
dest = job.destination
else:
dest = job.destination
dest.parent.mkdir(parents=True, exist_ok=True)
@@ -358,7 +244,7 @@ class DownloadQueue(DownloadQueueBase):
job.bytes = dest.stat().st_size
header["Range"] = f"bytes={job.bytes}-"
open_mode = "ab"
resp = requests.get(job.url, headers=header, stream=True) # new request with range
resp = requests.get(job.source, headers=header, stream=True) # new request with range
if exist_size > content_length:
self._logger.warning("corrupt existing file found. re-downloading")
@@ -398,14 +284,18 @@ class DownloadQueue(DownloadQueueBase):
self._update_job_status(job, DownloadJobStatus.ERROR)
def _update_job_status(self,
job: DownloadJob,
job: DownloadJobBase,
new_status: Optional[DownloadJobStatus] = None
):
"""Optionally change the job status and send an event indicating a change of state."""
if new_status:
job.status = new_status
if bus := self._event_bus:
bus.emit_model_download_event(job)
self._logger.debug(f"Status update for download job {job.id}: {job}")
if job.event_handler:
job.event_handler(job)
def _download_repoid(self, job: DownloadJobBase):
"""Download a job that holds a huggingface repoid."""
repo_id = job.source
variant = job.variant
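Since the diff above is truncated partway through the new repo_id path, here is a self-contained, hedged sketch of the resume-with-Range pattern that _download_with_resume follows, written against plain requests so it runs without InvokeAI; the function name, URL handling, and chunk size are illustrative simplifications, not the commit's code.

import os
import re
from pathlib import Path
from typing import Optional

import requests

def download_with_resume(url: str, destdir: Path, token: Optional[str] = None) -> Path:
    headers = {"Authorization": f"Bearer {token}"} if token else {}
    resp = requests.get(url, headers=headers, stream=True)
    total = int(resp.headers.get("content-length", 0))

    # Prefer the server-supplied filename; fall back to the last URL segment.
    match = re.search('filename="(.+)"', resp.headers.get("Content-Disposition", ""))
    dest = destdir / (match.group(1) if match else os.path.basename(url))
    destdir.mkdir(parents=True, exist_ok=True)

    mode = "wb"
    if dest.exists() and dest.stat().st_size < total:
        # Partial file on disk: request only the remaining bytes.
        headers["Range"] = f"bytes={dest.stat().st_size}-"
        resp = requests.get(url, headers=headers, stream=True)
        mode = "ab"

    with open(dest, mode) as f:
        for chunk in resp.iter_content(chunk_size=1024 * 1024):
            f.write(chunk)
    return dest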