mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add download manager to invoke services
This commit is contained in:
@ -28,6 +28,7 @@ from ..services.invoker import Invoker
|
||||
from ..services.processor import DefaultInvocationProcessor
|
||||
from ..services.sqlite import SqliteItemStorage
|
||||
from ..services.model_manager_service import ModelManagerService
|
||||
from ..services.download_manager import DownloadQueueService
|
||||
from ..services.invocation_stats import InvocationStatsService
|
||||
from .events import FastAPIEventService
|
||||
|
||||
@ -129,6 +130,7 @@ class ApiDependencies:
|
||||
processor=DefaultInvocationProcessor(),
|
||||
configuration=config,
|
||||
performance_statistics=InvocationStatsService(graph_execution_manager),
|
||||
download_manager=DownloadQueueService(event_bus=events),
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
|
@ -9,18 +9,19 @@ from typing import Optional, List
|
||||
from .events import EventServicesBase
|
||||
from invokeai.backend.model_manager.download import DownloadQueue, DownloadJobBase, DownloadEventHandler
|
||||
|
||||
|
||||
class DownloadQueueServiceBase(ABC):
|
||||
"""Multithreaded queue for downloading models via URL or repo_id."""
|
||||
|
||||
@abstractmethod
|
||||
def create_download_job(
|
||||
self,
|
||||
source: str,
|
||||
destdir: Path,
|
||||
filename: Optional[Path] = None,
|
||||
start: bool = True,
|
||||
access_token: Optional[str] = None,
|
||||
event_handlers: Optional[List[DownloadEventHandler]] = None,
|
||||
self,
|
||||
source: str,
|
||||
destdir: Path,
|
||||
filename: Optional[Path] = None,
|
||||
start: bool = True,
|
||||
access_token: Optional[str] = None,
|
||||
event_handlers: Optional[List[DownloadEventHandler]] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Create a download job.
|
||||
@ -136,7 +137,11 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
if self._event_bus:
|
||||
event_handlers.append([self._event_bus.emit_model_download_event])
|
||||
return self._queue.create_download_job(
|
||||
source, destdir, filename, start, access_token,
|
||||
source,
|
||||
destdir,
|
||||
filename,
|
||||
start,
|
||||
access_token,
|
||||
event_handlers=event_handlers,
|
||||
)
|
||||
|
||||
|
@ -16,6 +16,7 @@ if TYPE_CHECKING:
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
|
||||
from invokeai.app.services.invoker import InvocationProcessorABC
|
||||
from invokeai.app.services.download_manager import DownloadQueueServiceBase
|
||||
|
||||
|
||||
class InvocationServices:
|
||||
@ -34,6 +35,7 @@ class InvocationServices:
|
||||
model_manager: "ModelManagerServiceBase"
|
||||
processor: "InvocationProcessorABC"
|
||||
performance_statistics: "InvocationStatsServiceBase"
|
||||
download_manager: "DownloadQueueServiceBase"
|
||||
queue: "InvocationQueueABC"
|
||||
|
||||
def __init__(
|
||||
|
@ -6,7 +6,7 @@ from .base import ( # noqa F401
|
||||
DownloadEventHandler,
|
||||
UnknownJobIDException,
|
||||
CancelledJobException,
|
||||
DownloadJobBase
|
||||
DownloadJobBase,
|
||||
)
|
||||
|
||||
from .queue import DownloadQueue # noqa F401
|
||||
from .queue import DownloadQueue # noqa F401
|
||||
|
@ -15,12 +15,12 @@ from pydantic.networks import AnyHttpUrl
|
||||
class DownloadJobStatus(str, Enum):
|
||||
"""State of a download job."""
|
||||
|
||||
IDLE = "idle" # not enqueued, will not run
|
||||
ENQUEUED = "enqueued" # enqueued but not yet active
|
||||
RUNNING = "running" # actively downloading
|
||||
PAUSED = "paused" # previously started, now paused
|
||||
IDLE = "idle" # not enqueued, will not run
|
||||
ENQUEUED = "enqueued" # enqueued but not yet active
|
||||
RUNNING = "running" # actively downloading
|
||||
PAUSED = "paused" # previously started, now paused
|
||||
COMPLETED = "completed" # finished running
|
||||
ERROR = "error" # terminated with an error message
|
||||
ERROR = "error" # terminated with an error message
|
||||
|
||||
|
||||
class UnknownJobIDException(Exception):
|
||||
@ -46,13 +46,17 @@ class DownloadJobBase(BaseModel):
|
||||
status: DownloadJobStatus = Field(default=DownloadJobStatus.IDLE, description="Status of the download")
|
||||
bytes: int = Field(default=0, description="Bytes downloaded so far")
|
||||
total_bytes: int = Field(default=0, description="Total bytes to download")
|
||||
event_handlers: Optional[List[DownloadEventHandler]] = Field(description="Callables that will be called whenever job status changes")
|
||||
event_handlers: Optional[List[DownloadEventHandler]] = Field(
|
||||
description="Callables that will be called whenever job status changes"
|
||||
)
|
||||
job_started: Optional[float] = Field(description="Timestamp for when the download job started")
|
||||
job_ended: Optional[float] = Field(description="Timestamp for when the download job ended (completed or errored)")
|
||||
job_sequence: Optional[int] = Field(description="Counter that records order in which this job was dequeued (for debugging)")
|
||||
job_sequence: Optional[int] = Field(
|
||||
description="Counter that records order in which this job was dequeued (for debugging)"
|
||||
)
|
||||
error: Optional[Exception] = Field(default=None, description="Exception that caused an error")
|
||||
|
||||
class Config():
|
||||
class Config:
|
||||
"""Config object for this pydantic class."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
@ -74,14 +78,14 @@ class DownloadQueueBase(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def create_download_job(
|
||||
self,
|
||||
source: str,
|
||||
destdir: Path,
|
||||
filename: Optional[Path] = None,
|
||||
start: bool = True,
|
||||
variant: Optional[str] = None,
|
||||
access_token: Optional[str] = None,
|
||||
event_handlers: Optional[List[DownloadEventHandler]] = None,
|
||||
self,
|
||||
source: str,
|
||||
destdir: Path,
|
||||
filename: Optional[Path] = None,
|
||||
start: bool = True,
|
||||
variant: Optional[str] = None,
|
||||
access_token: Optional[str] = None,
|
||||
event_handlers: Optional[List[DownloadEventHandler]] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Create a download job.
|
||||
@ -128,7 +132,7 @@ class DownloadQueueBase(ABC):
|
||||
|
||||
Note that once a job is completed, id_to_job() may no longer
|
||||
recognize the job. Call id_to_job() before the job completes
|
||||
if you wish to keep the job object around after it has
|
||||
if you wish to keep the job object around after it has
|
||||
completed work.
|
||||
"""
|
||||
pass
|
||||
@ -186,5 +190,3 @@ class DownloadQueueBase(ABC):
|
||||
no longer recognize the job.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
|
@ -23,15 +23,11 @@ from .base import (
|
||||
DownloadEventHandler,
|
||||
UnknownJobIDException,
|
||||
CancelledJobException,
|
||||
DownloadJobBase
|
||||
DownloadJobBase,
|
||||
)
|
||||
|
||||
# marker that the queue is done and that thread should exit
|
||||
STOP_JOB = DownloadJobBase(
|
||||
id=-99,
|
||||
priority=-99,
|
||||
source='dummy',
|
||||
destination='/')
|
||||
STOP_JOB = DownloadJobBase(id=-99, priority=-99, source="dummy", destination="/")
|
||||
|
||||
|
||||
class DownloadJobURL(DownloadJobBase):
|
||||
@ -45,13 +41,14 @@ class DownloadJobRepoID(DownloadJobBase):
|
||||
|
||||
variant: Optional[str] = Field(description="Variant, such as 'fp16', to download")
|
||||
|
||||
@validator('source')
|
||||
@validator("source")
|
||||
@classmethod
|
||||
def _validate_source(cls, v: str) -> str:
|
||||
if not re.match(r'^[\w-]+/[\w-]+$', v):
|
||||
raise ValidationError(f'{v} invalid repo_id')
|
||||
if not re.match(r"^[\w-]+/[\w-]+$", v):
|
||||
raise ValidationError(f"{v} invalid repo_id")
|
||||
return v
|
||||
|
||||
|
||||
class DownloadQueue(DownloadQueueBase):
|
||||
"""Class for queued download of models."""
|
||||
|
||||
@ -65,11 +62,12 @@ class DownloadQueue(DownloadQueueBase):
|
||||
_sequence: int = 0 # This is for debugging and used to tag jobs in dequeueing order
|
||||
_requests: requests.sessions.Session
|
||||
|
||||
def __init__(self,
|
||||
max_parallel_dl: int = 5,
|
||||
event_handlers: Optional[List[DownloadEventHandler]] = None,
|
||||
requests_session: Optional[requests.sessions.Session] = None
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
max_parallel_dl: int = 5,
|
||||
event_handlers: Optional[List[DownloadEventHandler]] = None,
|
||||
requests_session: Optional[requests.sessions.Session] = None,
|
||||
):
|
||||
"""
|
||||
Initialize DownloadQueue.
|
||||
|
||||
@ -89,17 +87,17 @@ class DownloadQueue(DownloadQueueBase):
|
||||
self._start_workers(max_parallel_dl)
|
||||
|
||||
def create_download_job(
|
||||
self,
|
||||
source: str,
|
||||
destdir: Path,
|
||||
filename: Optional[Path] = None,
|
||||
start: bool = True,
|
||||
variant: Optional[str] = None,
|
||||
access_token: Optional[str] = None,
|
||||
event_handlers: Optional[List[DownloadEventHandler]] = None,
|
||||
self,
|
||||
source: str,
|
||||
destdir: Path,
|
||||
filename: Optional[Path] = None,
|
||||
start: bool = True,
|
||||
variant: Optional[str] = None,
|
||||
access_token: Optional[str] = None,
|
||||
event_handlers: Optional[List[DownloadEventHandler]] = None,
|
||||
) -> int:
|
||||
"""Create a download job and return its ID."""
|
||||
if re.match(r'^[\w-]+/[\w-]+$', source):
|
||||
if re.match(r"^[\w-]+/[\w-]+$", source):
|
||||
cls = DownloadJobRepoID
|
||||
kwargs = dict(variant=variant)
|
||||
else:
|
||||
@ -249,7 +247,7 @@ class DownloadQueue(DownloadQueueBase):
|
||||
|
||||
if job == STOP_JOB: # marker that queue is done
|
||||
break
|
||||
if job.status == DownloadJobStatus.ENQUEUED: # Don't do anything for cancelled or errored jobs
|
||||
if job.status == DownloadJobStatus.ENQUEUED: # Don't do anything for cancelled or errored jobs
|
||||
if isinstance(job, DownloadJobURL):
|
||||
self._download_with_resume(job)
|
||||
elif isinstance(job, DownloadJobRepoID):
|
||||
@ -322,10 +320,7 @@ class DownloadQueue(DownloadQueueBase):
|
||||
job.error = excp
|
||||
self._update_job_status(job, DownloadJobStatus.ERROR)
|
||||
|
||||
def _update_job_status(self,
|
||||
job: DownloadJobBase,
|
||||
new_status: Optional[DownloadJobStatus] = None
|
||||
):
|
||||
def _update_job_status(self, job: DownloadJobBase, new_status: Optional[DownloadJobStatus] = None):
|
||||
"""Optionally change the job status and send an event indicating a change of state."""
|
||||
if new_status:
|
||||
job.status = new_status
|
||||
@ -367,7 +362,7 @@ class DownloadQueue(DownloadQueueBase):
|
||||
destdir=job.destination / subdir,
|
||||
filename=file,
|
||||
variant=variant,
|
||||
access_token=job.access_token
|
||||
access_token=job.access_token,
|
||||
)
|
||||
except Exception as excp:
|
||||
job.status = DownloadJobStatus.ERROR
|
||||
@ -382,7 +377,7 @@ class DownloadQueue(DownloadQueueBase):
|
||||
def _get_download_size(self, url: AnyHttpUrl) -> int:
|
||||
resp = self._requests.get(url, stream=True)
|
||||
resp.raise_for_status()
|
||||
return int(resp.headers.get('content-length',0))
|
||||
return int(resp.headers.get("content-length", 0))
|
||||
|
||||
def _get_repo_urls(self, repo_id: str, variant: Optional[str] = None) -> List[Tuple[AnyHttpUrl, Path, Path]]:
|
||||
"""Given a repo_id and an optional variant, return list of URLs to download to get the model."""
|
||||
@ -392,11 +387,14 @@ class DownloadQueue(DownloadQueueBase):
|
||||
if "model_index.json" in paths:
|
||||
url = hf_hub_url(repo_id, filename="model_index.json")
|
||||
resp = self._requests.get(url)
|
||||
resp.raise_for_status() # will raise an HTTPError on non-200 status
|
||||
resp.raise_for_status() # will raise an HTTPError on non-200 status
|
||||
submodels = resp.json()
|
||||
paths = [x for x in paths if Path(x).parent.as_posix() in submodels]
|
||||
paths.insert(0, "model_index.json")
|
||||
return [(hf_hub_url(repo_id, filename=x.as_posix()), x.parent or Path('.'), x.name) for x in self._select_variants(paths, variant)]
|
||||
return [
|
||||
(hf_hub_url(repo_id, filename=x.as_posix()), x.parent or Path("."), x.name)
|
||||
for x in self._select_variants(paths, variant)
|
||||
]
|
||||
|
||||
def _select_variants(self, paths: List[str], variant: Optional[str] = None) -> Set[Path]:
|
||||
"""Select the proper variant files from a list of HuggingFace repo_id paths."""
|
||||
@ -404,7 +402,7 @@ class DownloadQueue(DownloadQueueBase):
|
||||
basenames = dict()
|
||||
for p in paths:
|
||||
path = Path(p)
|
||||
if path.suffix in ['.bin', '.safetensors', '.pt']:
|
||||
if path.suffix in [".bin", ".safetensors", ".pt"]:
|
||||
parent = path.parent
|
||||
suffixes = path.suffixes
|
||||
if len(suffixes) == 2:
|
||||
@ -416,9 +414,9 @@ class DownloadQueue(DownloadQueueBase):
|
||||
basename = parent / path.stem
|
||||
|
||||
if previous := basenames.get(basename):
|
||||
if previous.suffix != '.safetensors' and suffix == '.safetensors':
|
||||
if previous.suffix != ".safetensors" and suffix == ".safetensors":
|
||||
basenames[basename] = path
|
||||
if file_variant == f'.{variant}':
|
||||
if file_variant == f".{variant}":
|
||||
basenames[basename] = path
|
||||
elif not variant and not file_variant:
|
||||
basenames[basename] = path
|
||||
|
@ -17,35 +17,28 @@ from invokeai.backend.model_manager.download import (
|
||||
)
|
||||
|
||||
|
||||
SAFETENSORS1_CONTENT = b'I am a safetensors file (1)'
|
||||
SAFETENSORS1_CONTENT = b"I am a safetensors file (1)"
|
||||
SAFETENSORS1_HEADER = {
|
||||
'Content-Length' : len(SAFETENSORS1_CONTENT),
|
||||
'Content-Disposition': 'filename="mock1.safetensors"'
|
||||
"Content-Length": len(SAFETENSORS1_CONTENT),
|
||||
"Content-Disposition": 'filename="mock1.safetensors"',
|
||||
}
|
||||
SAFETENSORS2_CONTENT = b'I am a safetensors file (2)'
|
||||
SAFETENSORS2_CONTENT = b"I am a safetensors file (2)"
|
||||
SAFETENSORS2_HEADER = {
|
||||
'Content-Length' : len(SAFETENSORS2_CONTENT),
|
||||
'Content-Disposition': 'filename="mock2.safetensors"'
|
||||
"Content-Length": len(SAFETENSORS2_CONTENT),
|
||||
"Content-Disposition": 'filename="mock2.safetensors"',
|
||||
}
|
||||
session = requests.Session()
|
||||
session.mount('http://www.civitai.com/models/12345',
|
||||
TestAdapter(SAFETENSORS1_CONTENT,
|
||||
status=200,
|
||||
headers=SAFETENSORS1_HEADER)
|
||||
)
|
||||
session.mount('http://www.civitai.com/models/9999',
|
||||
TestAdapter(SAFETENSORS2_CONTENT,
|
||||
status=200,
|
||||
headers=SAFETENSORS2_HEADER)
|
||||
)
|
||||
session.mount(
|
||||
"http://www.civitai.com/models/12345", TestAdapter(SAFETENSORS1_CONTENT, status=200, headers=SAFETENSORS1_HEADER)
|
||||
)
|
||||
session.mount(
|
||||
"http://www.civitai.com/models/9999", TestAdapter(SAFETENSORS2_CONTENT, status=200, headers=SAFETENSORS2_HEADER)
|
||||
)
|
||||
|
||||
session.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404))
|
||||
|
||||
session.mount('http://www.civitai.com/models/broken',
|
||||
TestAdapter(b'Not found',
|
||||
status=404)
|
||||
)
|
||||
|
||||
def test_basic_queue():
|
||||
|
||||
events = list()
|
||||
|
||||
def event_handler(job: DownloadJobBase):
|
||||
@ -56,16 +49,12 @@ def test_basic_queue():
|
||||
event_handlers=[event_handler],
|
||||
)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
id1 = queue.create_download_job(
|
||||
source='http://www.civitai.com/models/12345',
|
||||
destdir=tmpdir,
|
||||
start=False
|
||||
)
|
||||
id1 = queue.create_download_job(source="http://www.civitai.com/models/12345", destdir=tmpdir, start=False)
|
||||
assert type(id1) == int, "expected first job id to be numeric"
|
||||
|
||||
job = queue.id_to_job(id1)
|
||||
assert isinstance(job, DownloadJobBase), "expected job to be a DownloadJobBase"
|
||||
assert job.status == 'idle', "expected job status to be idle"
|
||||
assert job.status == "idle", "expected job status to be idle"
|
||||
assert job.status == DownloadJobStatus.IDLE
|
||||
|
||||
queue.start_job(id1)
|
||||
@ -73,7 +62,8 @@ def test_basic_queue():
|
||||
assert events[0] == DownloadJobStatus.ENQUEUED
|
||||
assert events[-1] == DownloadJobStatus.COMPLETED
|
||||
assert DownloadJobStatus.RUNNING in events
|
||||
assert Path(tmpdir, 'mock1.safetensors').exists(), f"expected {tmpdir}/mock1.safetensors to exist"
|
||||
assert Path(tmpdir, "mock1.safetensors").exists(), f"expected {tmpdir}/mock1.safetensors to exist"
|
||||
|
||||
|
||||
def test_queue_priority():
|
||||
queue = DownloadQueue(
|
||||
@ -81,17 +71,9 @@ def test_queue_priority():
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
id1 = queue.create_download_job(
|
||||
source='http://www.civitai.com/models/12345',
|
||||
destdir=tmpdir,
|
||||
start=False
|
||||
)
|
||||
id1 = queue.create_download_job(source="http://www.civitai.com/models/12345", destdir=tmpdir, start=False)
|
||||
|
||||
id2 = queue.create_download_job(
|
||||
source='http://www.civitai.com/models/9999',
|
||||
destdir=tmpdir,
|
||||
start=False
|
||||
)
|
||||
id2 = queue.create_download_job(source="http://www.civitai.com/models/9999", destdir=tmpdir, start=False)
|
||||
|
||||
queue.change_priority(id1, -10) # make id1 run first
|
||||
job1 = queue.id_to_job(id1)
|
||||
@ -103,17 +85,9 @@ def test_queue_priority():
|
||||
queue.join()
|
||||
assert job1.job_sequence < job2.job_sequence
|
||||
|
||||
id1 = queue.create_download_job(
|
||||
source='http://www.civitai.com/models/12345',
|
||||
destdir=tmpdir,
|
||||
start=False
|
||||
)
|
||||
id1 = queue.create_download_job(source="http://www.civitai.com/models/12345", destdir=tmpdir, start=False)
|
||||
|
||||
id2 = queue.create_download_job(
|
||||
source='http://www.civitai.com/models/9999',
|
||||
destdir=tmpdir,
|
||||
start=False
|
||||
)
|
||||
id2 = queue.create_download_job(source="http://www.civitai.com/models/9999", destdir=tmpdir, start=False)
|
||||
|
||||
queue.change_priority(id2, -10) # make id2 run first
|
||||
job1 = queue.id_to_job(id1)
|
||||
@ -124,5 +98,5 @@ def test_queue_priority():
|
||||
queue.join()
|
||||
assert job2.job_sequence < job1.job_sequence
|
||||
|
||||
assert Path(tmpdir, 'mock1.safetensors').exists(), f"expected {tmpdir}/mock1.safetensors to exist"
|
||||
assert Path(tmpdir, 'mock2.safetensors').exists(), f"expected {tmpdir}/mock1.safetensors to exist"
|
||||
assert Path(tmpdir, "mock1.safetensors").exists(), f"expected {tmpdir}/mock1.safetensors to exist"
|
||||
assert Path(tmpdir, "mock2.safetensors").exists(), f"expected {tmpdir}/mock1.safetensors to exist"
|
||||
|
Reference in New Issue
Block a user