add download manager to invoke services

Lincoln Stein
2023-09-06 18:47:30 -04:00
parent e9074176bd
commit 404cfe0eb9
7 changed files with 97 additions and 114 deletions

View File

@@ -28,6 +28,7 @@ from ..services.invoker import Invoker
from ..services.processor import DefaultInvocationProcessor
from ..services.sqlite import SqliteItemStorage
from ..services.model_manager_service import ModelManagerService
from ..services.download_manager import DownloadQueueService
from ..services.invocation_stats import InvocationStatsService
from .events import FastAPIEventService
@@ -129,6 +130,7 @@ class ApiDependencies:
processor=DefaultInvocationProcessor(),
configuration=config,
performance_statistics=InvocationStatsService(graph_execution_manager),
download_manager=DownloadQueueService(event_bus=events),
logger=logger,
)
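With the queue registered alongside the other services, an invocation can reach it the same way it reaches the model manager. A minimal sketch of the intended call site; the helper name, source URL, and destination directory are illustrative, not part of this commit:

from pathlib import Path


def queue_model_download(services, destdir: Path) -> int:
    """Sketch: enqueue a download through the new service and return its job id."""
    # services is expected to be the InvocationServices instance built above,
    # which now carries the download_manager attribute
    return services.download_manager.create_download_job(
        source="https://example.com/model.safetensors",  # placeholder URL or repo_id
        destdir=destdir,
    )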

View File

@@ -9,18 +9,19 @@ from typing import Optional, List
from .events import EventServicesBase
from invokeai.backend.model_manager.download import DownloadQueue, DownloadJobBase, DownloadEventHandler
class DownloadQueueServiceBase(ABC):
"""Multithreaded queue for downloading models via URL or repo_id."""
@abstractmethod
def create_download_job(
self,
source: str,
destdir: Path,
filename: Optional[Path] = None,
start: bool = True,
access_token: Optional[str] = None,
event_handlers: Optional[List[DownloadEventHandler]] = None,
self,
source: str,
destdir: Path,
filename: Optional[Path] = None,
start: bool = True,
access_token: Optional[str] = None,
event_handlers: Optional[List[DownloadEventHandler]] = None,
) -> int:
"""
Create a download job.
@@ -136,7 +137,11 @@ class DownloadQueueService(DownloadQueueServiceBase):
if self._event_bus:
event_handlers.append(self._event_bus.emit_model_download_event)
return self._queue.create_download_job(
source, destdir, filename, start, access_token,
source,
destdir,
filename,
start,
access_token,
event_handlers=event_handlers,
)
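Outside the FastAPI process the wrapper can also be driven directly. A rough sketch, assuming the constructor tolerates event_bus=None (only the event_bus=events form appears in this commit); the handler, URL, and destination are illustrative:

from pathlib import Path

from invokeai.app.services.download_manager import DownloadQueueService
from invokeai.backend.model_manager.download import DownloadJobBase


def log_progress(job: DownloadJobBase) -> None:
    # per-job handler: called whenever the job's status changes
    print(job.status, job.bytes, job.total_bytes)


service = DownloadQueueService(event_bus=None)  # assumption: the event bus may be omitted
job_id = service.create_download_job(
    source="https://example.com/model.safetensors",  # placeholder URL
    destdir=Path("/tmp/models"),  # placeholder destination
    event_handlers=[log_progress],
)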

View File

@@ -16,6 +16,7 @@ if TYPE_CHECKING:
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
from invokeai.app.services.invoker import InvocationProcessorABC
from invokeai.app.services.download_manager import DownloadQueueServiceBase
class InvocationServices:
@@ -34,6 +35,7 @@ class InvocationServices:
model_manager: "ModelManagerServiceBase"
processor: "InvocationProcessorABC"
performance_statistics: "InvocationStatsServiceBase"
download_manager: "DownloadQueueServiceBase"
queue: "InvocationQueueABC"
def __init__(

View File

@@ -6,7 +6,7 @@ from .base import ( # noqa F401
DownloadEventHandler,
UnknownJobIDException,
CancelledJobException,
DownloadJobBase
DownloadJobBase,
)
from .queue import DownloadQueue # noqa F401
from .queue import DownloadQueue # noqa F401
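This cleanup leaves the package's public surface unchanged; downstream code continues to import everything from the package root, for example:

from invokeai.backend.model_manager.download import (
    CancelledJobException,
    DownloadEventHandler,
    DownloadJobBase,
    DownloadQueue,
    UnknownJobIDException,
)

queue = DownloadQueue()  # defaults to at most 5 parallel downloads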

View File

@@ -15,12 +15,12 @@ from pydantic.networks import AnyHttpUrl
class DownloadJobStatus(str, Enum):
"""State of a download job."""
IDLE = "idle" # not enqueued, will not run
ENQUEUED = "enqueued" # enqueued but not yet active
RUNNING = "running" # actively downloading
PAUSED = "paused" # previously started, now paused
IDLE = "idle" # not enqueued, will not run
ENQUEUED = "enqueued" # enqueued but not yet active
RUNNING = "running" # actively downloading
PAUSED = "paused" # previously started, now paused
COMPLETED = "completed" # finished running
ERROR = "error" # terminated with an error message
ERROR = "error" # terminated with an error message
class UnknownJobIDException(Exception):
@@ -46,13 +46,17 @@ class DownloadJobBase(BaseModel):
status: DownloadJobStatus = Field(default=DownloadJobStatus.IDLE, description="Status of the download")
bytes: int = Field(default=0, description="Bytes downloaded so far")
total_bytes: int = Field(default=0, description="Total bytes to download")
event_handlers: Optional[List[DownloadEventHandler]] = Field(description="Callables that will be called whenever job status changes")
event_handlers: Optional[List[DownloadEventHandler]] = Field(
description="Callables that will be called whenever job status changes"
)
job_started: Optional[float] = Field(description="Timestamp for when the download job started")
job_ended: Optional[float] = Field(description="Timestamp for when the download job ended (completed or errored)")
job_sequence: Optional[int] = Field(description="Counter that records order in which this job was dequeued (for debugging)")
job_sequence: Optional[int] = Field(
description="Counter that records order in which this job was dequeued (for debugging)"
)
error: Optional[Exception] = Field(default=None, description="Exception that caused an error")
class Config():
class Config:
"""Config object for this pydantic class."""
arbitrary_types_allowed = True
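Event handlers receive the job object itself whenever its status changes, so progress and error reporting can be built directly from the fields declared above. An illustrative callback (the formatting below is not part of the commit):

from invokeai.backend.model_manager.download.base import DownloadJobBase, DownloadJobStatus


def report(job: DownloadJobBase) -> None:
    # error, job_started, job_ended, and total_bytes are the fields declared above
    if job.status == DownloadJobStatus.ERROR:
        print(f"{job.source} failed: {job.error}")
    elif job.status == DownloadJobStatus.COMPLETED and job.job_started and job.job_ended:
        elapsed = job.job_ended - job.job_started
        print(f"{job.source}: {job.total_bytes} bytes in {elapsed:.1f}s")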
@@ -74,14 +78,14 @@ class DownloadQueueBase(ABC):
@abstractmethod
def create_download_job(
self,
source: str,
destdir: Path,
filename: Optional[Path] = None,
start: bool = True,
variant: Optional[str] = None,
access_token: Optional[str] = None,
event_handlers: Optional[List[DownloadEventHandler]] = None,
self,
source: str,
destdir: Path,
filename: Optional[Path] = None,
start: bool = True,
variant: Optional[str] = None,
access_token: Optional[str] = None,
event_handlers: Optional[List[DownloadEventHandler]] = None,
) -> int:
"""
Create a download job.
@@ -128,7 +132,7 @@ class DownloadQueueBase(ABC):
Note that once a job is completed, id_to_job() may no longer
recognize the job. Call id_to_job() before the job completes
if you wish to keep the job object around after it has
if you wish to keep the job object around after it has
completed work.
"""
pass
@@ -186,5 +190,3 @@ class DownloadQueueBase(ABC):
no longer recognize the job.
"""
pass
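Taken together, the abstract interface suggests the following lifecycle: create the job (optionally deferred), look it up with id_to_job() before it runs as the docstring above advises, then start and join. A sketch with placeholder URL and directory:

from pathlib import Path

from invokeai.backend.model_manager.download import DownloadQueue

queue = DownloadQueue()
job_id = queue.create_download_job(
    source="https://example.com/model.safetensors",  # placeholder URL
    destdir=Path("/tmp/models"),  # placeholder destination
    start=False,
)
job = queue.id_to_job(job_id)  # keep a reference now; lookups may fail after completion
queue.start_job(job_id)
queue.join()  # block until all enqueued jobs have finished
print(job.status, job.bytes, job.total_bytes)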

View File

@@ -23,15 +23,11 @@ from .base import (
DownloadEventHandler,
UnknownJobIDException,
CancelledJobException,
DownloadJobBase
DownloadJobBase,
)
# marker that the queue is done and that thread should exit
STOP_JOB = DownloadJobBase(
id=-99,
priority=-99,
source='dummy',
destination='/')
STOP_JOB = DownloadJobBase(id=-99, priority=-99, source="dummy", destination="/")
class DownloadJobURL(DownloadJobBase):
@@ -45,13 +41,14 @@ class DownloadJobRepoID(DownloadJobBase):
variant: Optional[str] = Field(description="Variant, such as 'fp16', to download")
@validator('source')
@validator("source")
@classmethod
def _validate_source(cls, v: str) -> str:
if not re.match(r'^[\w-]+/[\w-]+$', v):
raise ValidationError(f'{v} invalid repo_id')
if not re.match(r"^[\w-]+/[\w-]+$", v):
raise ValidationError(f"{v} invalid repo_id")
return v
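The same expression reappears in create_download_job below, where it decides whether a source string becomes a DownloadJobRepoID or is treated as a plain URL. A quick illustration of what it accepts:

import re

REPO_ID_RE = re.compile(r"^[\w-]+/[\w-]+$")  # same pattern as the validator above

assert REPO_ID_RE.match("runwayml/stable-diffusion-v1-5")  # owner/name is a repo_id
assert not REPO_ID_RE.match("https://civitai.com/models/12345")  # URLs fall through
assert not REPO_ID_RE.match("owner/name/extra")  # nested paths are rejected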
class DownloadQueue(DownloadQueueBase):
"""Class for queued download of models."""
@@ -65,11 +62,12 @@ class DownloadQueue(DownloadQueueBase):
_sequence: int = 0 # This is for debugging and used to tag jobs in dequeueing order
_requests: requests.sessions.Session
def __init__(self,
max_parallel_dl: int = 5,
event_handlers: Optional[List[DownloadEventHandler]] = None,
requests_session: Optional[requests.sessions.Session] = None
):
def __init__(
self,
max_parallel_dl: int = 5,
event_handlers: Optional[List[DownloadEventHandler]] = None,
requests_session: Optional[requests.sessions.Session] = None,
):
"""
Initialize DownloadQueue.
@@ -89,17 +87,17 @@ class DownloadQueue(DownloadQueueBase):
self._start_workers(max_parallel_dl)
def create_download_job(
self,
source: str,
destdir: Path,
filename: Optional[Path] = None,
start: bool = True,
variant: Optional[str] = None,
access_token: Optional[str] = None,
event_handlers: Optional[List[DownloadEventHandler]] = None,
self,
source: str,
destdir: Path,
filename: Optional[Path] = None,
start: bool = True,
variant: Optional[str] = None,
access_token: Optional[str] = None,
event_handlers: Optional[List[DownloadEventHandler]] = None,
) -> int:
"""Create a download job and return its ID."""
if re.match(r'^[\w-]+/[\w-]+$', source):
if re.match(r"^[\w-]+/[\w-]+$", source):
cls = DownloadJobRepoID
kwargs = dict(variant=variant)
else:
@@ -249,7 +247,7 @@ class DownloadQueue(DownloadQueueBase):
if job == STOP_JOB: # marker that queue is done
break
if job.status == DownloadJobStatus.ENQUEUED: # Don't do anything for cancelled or errored jobs
if job.status == DownloadJobStatus.ENQUEUED: # Don't do anything for cancelled or errored jobs
if isinstance(job, DownloadJobURL):
self._download_with_resume(job)
elif isinstance(job, DownloadJobRepoID):
@@ -322,10 +320,7 @@ class DownloadQueue(DownloadQueueBase):
job.error = excp
self._update_job_status(job, DownloadJobStatus.ERROR)
def _update_job_status(self,
job: DownloadJobBase,
new_status: Optional[DownloadJobStatus] = None
):
def _update_job_status(self, job: DownloadJobBase, new_status: Optional[DownloadJobStatus] = None):
"""Optionally change the job status and send an event indicating a change of state."""
if new_status:
job.status = new_status
@@ -367,7 +362,7 @@ class DownloadQueue(DownloadQueueBase):
destdir=job.destination / subdir,
filename=file,
variant=variant,
access_token=job.access_token
access_token=job.access_token,
)
except Exception as excp:
job.status = DownloadJobStatus.ERROR
@@ -382,7 +377,7 @@ class DownloadQueue(DownloadQueueBase):
def _get_download_size(self, url: AnyHttpUrl) -> int:
resp = self._requests.get(url, stream=True)
resp.raise_for_status()
return int(resp.headers.get('content-length',0))
return int(resp.headers.get("content-length", 0))
def _get_repo_urls(self, repo_id: str, variant: Optional[str] = None) -> List[Tuple[AnyHttpUrl, Path, Path]]:
"""Given a repo_id and an optional variant, return list of URLs to download to get the model."""
@@ -392,11 +387,14 @@ class DownloadQueue(DownloadQueueBase):
if "model_index.json" in paths:
url = hf_hub_url(repo_id, filename="model_index.json")
resp = self._requests.get(url)
resp.raise_for_status() # will raise an HTTPError on non-200 status
resp.raise_for_status() # will raise an HTTPError on non-200 status
submodels = resp.json()
paths = [x for x in paths if Path(x).parent.as_posix() in submodels]
paths.insert(0, "model_index.json")
return [(hf_hub_url(repo_id, filename=x.as_posix()), x.parent or Path('.'), x.name) for x in self._select_variants(paths, variant)]
return [
(hf_hub_url(repo_id, filename=x.as_posix()), x.parent or Path("."), x.name)
for x in self._select_variants(paths, variant)
]
def _select_variants(self, paths: List[str], variant: Optional[str] = None) -> Set[Path]:
"""Select the proper variant files from a list of HuggingFace repo_id paths."""
@@ -404,7 +402,7 @@ class DownloadQueue(DownloadQueueBase):
basenames = dict()
for p in paths:
path = Path(p)
if path.suffix in ['.bin', '.safetensors', '.pt']:
if path.suffix in [".bin", ".safetensors", ".pt"]:
parent = path.parent
suffixes = path.suffixes
if len(suffixes) == 2:
@@ -416,9 +414,9 @@ class DownloadQueue(DownloadQueueBase):
basename = parent / path.stem
if previous := basenames.get(basename):
if previous.suffix != '.safetensors' and suffix == '.safetensors':
if previous.suffix != ".safetensors" and suffix == ".safetensors":
basenames[basename] = path
if file_variant == f'.{variant}':
if file_variant == f".{variant}":
basenames[basename] = path
elif not variant and not file_variant:
basenames[basename] = path

View File

@@ -17,35 +17,28 @@ from invokeai.backend.model_manager.download import (
)
SAFETENSORS1_CONTENT = b'I am a safetensors file (1)'
SAFETENSORS1_CONTENT = b"I am a safetensors file (1)"
SAFETENSORS1_HEADER = {
'Content-Length' : len(SAFETENSORS1_CONTENT),
'Content-Disposition': 'filename="mock1.safetensors"'
"Content-Length": len(SAFETENSORS1_CONTENT),
"Content-Disposition": 'filename="mock1.safetensors"',
}
SAFETENSORS2_CONTENT = b'I am a safetensors file (2)'
SAFETENSORS2_CONTENT = b"I am a safetensors file (2)"
SAFETENSORS2_HEADER = {
'Content-Length' : len(SAFETENSORS2_CONTENT),
'Content-Disposition': 'filename="mock2.safetensors"'
"Content-Length": len(SAFETENSORS2_CONTENT),
"Content-Disposition": 'filename="mock2.safetensors"',
}
session = requests.Session()
session.mount('http://www.civitai.com/models/12345',
TestAdapter(SAFETENSORS1_CONTENT,
status=200,
headers=SAFETENSORS1_HEADER)
)
session.mount('http://www.civitai.com/models/9999',
TestAdapter(SAFETENSORS2_CONTENT,
status=200,
headers=SAFETENSORS2_HEADER)
)
session.mount(
"http://www.civitai.com/models/12345", TestAdapter(SAFETENSORS1_CONTENT, status=200, headers=SAFETENSORS1_HEADER)
)
session.mount(
"http://www.civitai.com/models/9999", TestAdapter(SAFETENSORS2_CONTENT, status=200, headers=SAFETENSORS2_HEADER)
)
session.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404))
session.mount('http://www.civitai.com/models/broken',
TestAdapter(b'Not found',
status=404)
)
def test_basic_queue():
events = list()
def event_handler(job: DownloadJobBase):
@@ -56,16 +49,12 @@ def test_basic_queue():
event_handlers=[event_handler],
)
with tempfile.TemporaryDirectory() as tmpdir:
id1 = queue.create_download_job(
source='http://www.civitai.com/models/12345',
destdir=tmpdir,
start=False
)
id1 = queue.create_download_job(source="http://www.civitai.com/models/12345", destdir=tmpdir, start=False)
assert type(id1) == int, "expected first job id to be numeric"
job = queue.id_to_job(id1)
assert isinstance(job, DownloadJobBase), "expected job to be a DownloadJobBase"
assert job.status == 'idle', "expected job status to be idle"
assert job.status == "idle", "expected job status to be idle"
assert job.status == DownloadJobStatus.IDLE
queue.start_job(id1)
@@ -73,7 +62,8 @@ def test_basic_queue():
assert events[0] == DownloadJobStatus.ENQUEUED
assert events[-1] == DownloadJobStatus.COMPLETED
assert DownloadJobStatus.RUNNING in events
assert Path(tmpdir, 'mock1.safetensors').exists(), f"expected {tmpdir}/mock1.safetensors to exist"
assert Path(tmpdir, "mock1.safetensors").exists(), f"expected {tmpdir}/mock1.safetensors to exist"
def test_queue_priority():
queue = DownloadQueue(
@@ -81,17 +71,9 @@ def test_queue_priority():
)
with tempfile.TemporaryDirectory() as tmpdir:
id1 = queue.create_download_job(
source='http://www.civitai.com/models/12345',
destdir=tmpdir,
start=False
)
id1 = queue.create_download_job(source="http://www.civitai.com/models/12345", destdir=tmpdir, start=False)
id2 = queue.create_download_job(
source='http://www.civitai.com/models/9999',
destdir=tmpdir,
start=False
)
id2 = queue.create_download_job(source="http://www.civitai.com/models/9999", destdir=tmpdir, start=False)
queue.change_priority(id1, -10) # make id1 run first
job1 = queue.id_to_job(id1)
@@ -103,17 +85,9 @@ def test_queue_priority():
queue.join()
assert job1.job_sequence < job2.job_sequence
id1 = queue.create_download_job(
source='http://www.civitai.com/models/12345',
destdir=tmpdir,
start=False
)
id1 = queue.create_download_job(source="http://www.civitai.com/models/12345", destdir=tmpdir, start=False)
id2 = queue.create_download_job(
source='http://www.civitai.com/models/9999',
destdir=tmpdir,
start=False
)
id2 = queue.create_download_job(source="http://www.civitai.com/models/9999", destdir=tmpdir, start=False)
queue.change_priority(id2, -10) # make id2 run first
job1 = queue.id_to_job(id1)
@@ -124,5 +98,5 @@ def test_queue_priority():
queue.join()
assert job2.job_sequence < job1.job_sequence
assert Path(tmpdir, 'mock1.safetensors').exists(), f"expected {tmpdir}/mock1.safetensors to exist"
assert Path(tmpdir, 'mock2.safetensors').exists(), f"expected {tmpdir}/mock1.safetensors to exist"
assert Path(tmpdir, "mock1.safetensors").exists(), f"expected {tmpdir}/mock1.safetensors to exist"
assert Path(tmpdir, "mock2.safetensors").exists(), f"expected {tmpdir}/mock1.safetensors to exist"