From cb0fdf3394058220d3d44bab1b84ccba747b176a Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 4 Oct 2023 14:51:59 -0400 Subject: [PATCH] refactor model install job class hierarchy --- invokeai/app/services/events.py | 3 - .../app/services/model_manager_service.py | 5 +- .../model_manager/download/__init__.py | 2 - .../backend/model_manager/download/base.py | 29 +++---- .../backend/model_manager/download/queue.py | 81 ++++++++++++------- invokeai/backend/model_manager/install.py | 16 +--- invokeai/backend/model_manager/models/base.py | 3 +- 7 files changed, 71 insertions(+), 68 deletions(-) diff --git a/invokeai/app/services/events.py b/invokeai/app/services/events.py index 5d65aa2d5e..b07dc79237 100644 --- a/invokeai/app/services/events.py +++ b/invokeai/app/services/events.py @@ -293,9 +293,6 @@ class EventServiceBase: def emit_model_event(self, job: DownloadJobBase) -> None: """Emit event when the status of a download/install job changes.""" - logger = InvokeAILogger.get_logger() - progress = 100 * (job.bytes / job.total_bytes) if job.total_bytes > 0 else 0 - logger.debug(f"Dispatch model_event for job {job.id}, status={job.status.value}, progress={progress:5.2f}%") self.dispatch( # use dispatch() directly here because we are not a session event. event_name="model_event", payload=dict(job=job) ) diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index 8b815350b7..5a809da866 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -27,11 +27,12 @@ from invokeai.backend.model_manager.cache import CacheStats from invokeai.backend.model_manager.download import DownloadJobBase from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger +from .config import InvokeAIAppConfig +from .events import EventServiceBase + # processor is giving circular import errors # from .processor import Invoker -from .config import InvokeAIAppConfig -from .events import EventServiceBase if TYPE_CHECKING: from ..invocations.baseinvocation import InvocationContext diff --git a/invokeai/backend/model_manager/download/__init__.py b/invokeai/backend/model_manager/download/__init__.py index ef4756be13..370f9127b4 100644 --- a/invokeai/backend/model_manager/download/__init__.py +++ b/invokeai/backend/model_manager/download/__init__.py @@ -1,8 +1,6 @@ """Initialization file for threaded download manager.""" from .base import ( # noqa F401 - HTTP_RE, - REPO_ID_RE, DownloadEventHandler, DownloadJobBase, DownloadJobStatus, diff --git a/invokeai/backend/model_manager/download/base.py b/invokeai/backend/model_manager/download/base.py index 4465bd49d3..5240c68f6c 100644 --- a/invokeai/backend/model_manager/download/base.py +++ b/invokeai/backend/model_manager/download/base.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from enum import Enum from functools import total_ordering from pathlib import Path -from typing import Callable, List, Optional, Union +from typing import Any, Callable, List, Optional, Union import requests from pydantic import BaseModel, Field @@ -13,10 +13,6 @@ from pydantic.networks import AnyHttpUrl from invokeai.app.services.config import InvokeAIAppConfig -# Used to distinguish between repo_id sources and URL sources -REPO_ID_RE = r"^[\w-]+/[.\w-]+$" -HTTP_RE = r"^https?://" - class DownloadJobStatus(str, Enum): """State of a download job.""" @@ -53,16 +49,11 @@ class DownloadJobBase(BaseModel): """Class to monitor and control a model download 
request.""" priority: int = Field(default=10, description="Queue priority; lower values are higher priority") - id: int = Field(description="Numeric ID of this job", default=-1) # default id is a placeholder - source: Union[str, Path] = Field(description="URL or repo_id to download") - destination: Path = Field(description="Destination of URL on local disk") - metadata: ModelSourceMetadata = Field( - description="Model metadata (source-specific)", default_factory=ModelSourceMetadata - ) + id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel + source: Any = Field(description="Where to download from. Specific types specified in child classes.") + destination: Path = Field(description="Destination of downloaded model on local disk") access_token: Optional[str] = Field(description="access token needed to access this resource") status: DownloadJobStatus = Field(default=DownloadJobStatus.IDLE, description="Status of the download") - bytes: int = Field(default=0, description="Bytes downloaded so far") - total_bytes: int = Field(default=0, description="Total bytes to download") event_handlers: Optional[List[DownloadEventHandler]] = Field( description="Callables that will be called whenever job status changes", default_factory=list, @@ -70,15 +61,15 @@ class DownloadJobBase(BaseModel): job_started: Optional[float] = Field(description="Timestamp for when the download job started") job_ended: Optional[float] = Field(description="Timestamp for when the download job ended (completed or errored)") job_sequence: Optional[int] = Field( - description="Counter that records order in which this job was dequeued (for debugging)" - ) - subqueue: Optional["DownloadQueueBase"] = Field( - description="a subqueue used for downloading repo_ids", default=None + description="Counter that records order in which this job was dequeued (used in unit testing)" ) preserve_partial_downloads: bool = Field( description="if true, then preserve partial downloads when cancelled or errored", default=False ) error: Optional[Exception] = Field(default=None, description="Exception that caused an error") + metadata: ModelSourceMetadata = Field( + description="Metadata describing download contents", default_factory=ModelSourceMetadata + ) def add_event_handler(self, handler: DownloadEventHandler): """Add an event handler to the end of the handlers list.""" @@ -89,6 +80,10 @@ class DownloadJobBase(BaseModel): """Clear all event handlers.""" self.event_handlers = list() + def cleanup(self, preserve_partial_downloads: bool = False): + """Possibly do some action when work is finished.""" + pass + class Config: """Config object for this pydantic class.""" diff --git a/invokeai/backend/model_manager/download/queue.py b/invokeai/backend/model_manager/download/queue.py index bd65a706de..67d972c6db 100644 --- a/invokeai/backend/model_manager/download/queue.py +++ b/invokeai/backend/model_manager/download/queue.py @@ -13,7 +13,7 @@ from typing import Dict, List, Optional, Set, Tuple, Union import requests from huggingface_hub import HfApi, hf_hub_url -from pydantic import Field, ValidationError, parse_obj_as, validator +from pydantic import Field, parse_obj_as, validator from pydantic.networks import AnyHttpUrl from requests import HTTPError @@ -21,7 +21,6 @@ from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.util import InvokeAILogger, Logger from ..storage import DuplicateModelException -from . 
import HTTP_RE, REPO_ID_RE from .base import ( DownloadEventHandler, DownloadJobBase, @@ -44,31 +43,53 @@ CIVITAI_MODEL_PAGE_WITH_VERSION = r"https://civitai.com/models/(\d+)\?modelVersi CIVITAI_MODELS_ENDPOINT = "https://civitai.com/api/v1/models/" CIVITAI_VERSIONS_ENDPOINT = "https://civitai.com/api/v1/model-versions/" +# Regular expressions to describe repo_ids and http urls +HTTP_RE = r"^https?://" +REPO_ID_RE = r"^[\w-]+/[.\w-]+$" +REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE = r"^[\w-]+/[.\w-]+(?::\w+)?$" -class DownloadJobURL(DownloadJobBase): + +class DownloadJobPath(DownloadJobBase): + """Download from a local Path.""" + + source: Path = Field(description="Local filesystem Path where model can be found") + + +class DownloadJobRemoteSource(DownloadJobBase): + """A DownloadJob from a remote source that provides progress info.""" + + bytes: int = Field(default=0, description="Bytes downloaded so far") + total_bytes: int = Field(default=0, description="Total bytes to download") + access_token: Optional[str] = Field(description="access token needed to access this resource") + + +class DownloadJobURL(DownloadJobRemoteSource): """Job declaration for downloading individual URLs.""" source: AnyHttpUrl = Field(description="URL to download") -class DownloadJobRepoID(DownloadJobBase): +class DownloadJobRepoID(DownloadJobRemoteSource): """Download repo ids.""" + source: str = Field(description="A repo_id (foo/bar), or a repo_id with a subfolder (foo/far:v2)") variant: Optional[str] = Field(description="Variant, such as 'fp16', to download") - source: str = Field(description="URL or repo_id to download") + subqueue: Optional["DownloadQueueBase"] = Field( + description="a subqueue used for downloading the individual files in the repo_id", default=None + ) @validator("source") @classmethod - def _validate_source(cls, v: str) -> str: - if not re.match(REPO_ID_RE, v): - raise ValidationError(f"{v} invalid repo_id", cls) + def proper_repo_id(cls, v: str) -> str: # noqa D102 + if not re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, v): + raise ValueError(f"{v}: invalid repo_id format") return v - -class DownloadJobPath(DownloadJobBase): - """Handle file paths.""" - - source: Union[str, Path] = Field(description="Path to a file or directory to install") + def cleanup(self, preserve_partial_downloads: bool = False): + """Perform action when job is completed.""" + if self.subqueue: + self.subqueue.cancel_all_jobs(preserve_partial=preserve_partial_downloads) + self.subqueue.release() class DownloadQueue(DownloadQueueBase): @@ -124,23 +145,26 @@ class DownloadQueue(DownloadQueueBase): event_handlers: Optional[List[DownloadEventHandler]] = None, ) -> DownloadJobBase: """Create a download job and return its ID.""" - kwargs = dict() + kwargs: Dict[str, Optional[str]] = dict() cls = DownloadJobBase if Path(source).exists(): cls = DownloadJobPath - elif re.match(REPO_ID_RE, str(source)): + elif re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, str(source)): cls = DownloadJobRepoID - kwargs = dict(variant=variant) + kwargs.update( + variant=variant, + access_token=access_token, + ) elif re.match(HTTP_RE, str(source)): cls = DownloadJobURL + kwargs.update(access_token=access_token) else: raise NotImplementedError(f"Don't know what to do with this type of source: {source}") job = cls( - source=str(source), + source=source, destination=Path(destdir) / (filename or "."), - access_token=access_token, event_handlers=(event_handlers or self._event_handlers), priority=priority, **kwargs, @@ -220,8 +244,7 @@ class 
DownloadQueue(DownloadQueueBase): assert isinstance(self._jobs[job.id], DownloadJobBase) job.preserve_partial_downloads = preserve_partial self._update_job_status(job, DownloadJobStatus.CANCELLED) - if job.subqueue: - job.subqueue.cancel_all_jobs(preserve_partial=preserve_partial) + job.cleanup() except (AssertionError, KeyError) as excp: raise UnknownJobIDException("Unrecognized job") from excp finally: @@ -254,9 +277,7 @@ class DownloadQueue(DownloadQueueBase): self._lock.acquire() assert isinstance(self._jobs[job.id], DownloadJobBase) self._update_job_status(job, DownloadJobStatus.PAUSED) - if job.subqueue: - job.subqueue.cancel_all_jobs(preserve_partial=True) - job.subqueue.release() + job.cleanup() except (AssertionError, KeyError) as excp: raise UnknownJobIDException("Unrecognized job") from excp finally: @@ -323,7 +344,6 @@ class DownloadQueue(DownloadQueueBase): done = True if job.status == DownloadJobStatus.ENQUEUED: # Don't do anything for non-enqueued jobs (shouldn't happen) - # There should be a better way to dispatch on the job type if not self._quiet: self._logger.info(f"{job.source}: Downloading to {job.destination}") if isinstance(job, DownloadJobURL): @@ -340,7 +360,7 @@ class DownloadQueue(DownloadQueueBase): self._queue.task_done() - def _get_metadata_and_url(self, job: DownloadJobBase) -> AnyHttpUrl: + def _get_metadata_and_url(self, job: DownloadJobRemoteSource) -> AnyHttpUrl: """ Fetch metadata from certain well-known URLs. @@ -406,6 +426,7 @@ class DownloadQueue(DownloadQueueBase): """Do the actual download.""" dest = None try: + assert isinstance(job, DownloadJobRemoteSource) url = self._get_metadata_and_url(job) header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {} @@ -517,10 +538,11 @@ class DownloadQueue(DownloadQueueBase): job.error = excp self._update_job_status(job, DownloadJobStatus.ERROR) - def _download_repoid(self, job: DownloadJobBase): + def _download_repoid(self, job: DownloadJobRepoID): """Download a job that holds a huggingface repoid.""" def subdownload_event(subjob: DownloadJobBase): + assert isinstance(subjob, DownloadJobRemoteSource) if subjob.status == DownloadJobStatus.RUNNING: bytes_downloaded[subjob.id] = subjob.bytes job.bytes = sum(bytes_downloaded.values()) @@ -529,8 +551,7 @@ class DownloadQueue(DownloadQueueBase): if subjob.status == DownloadJobStatus.ERROR: job.error = subjob.error - if subjob.subqueue: - subjob.subqueue.cancel_all_jobs() + job.cleanup() self._update_job_status(job, DownloadJobStatus.ERROR) return @@ -546,7 +567,7 @@ class DownloadQueue(DownloadQueueBase): quiet=True, ) try: - job = DownloadJobRepoID.parse_obj(job) + assert isinstance(job, DownloadJobRepoID) repo_id = job.source variant = job.variant if not job.metadata: @@ -668,11 +689,9 @@ class DownloadQueue(DownloadQueueBase): source = Path(job.source).resolve() destination = Path(job.destination).resolve() try: - job.total_bytes = source.stat().st_size self._update_job_status(job, DownloadJobStatus.RUNNING) if source != destination: shutil.move(source, destination) - job.bytes = destination.stat().st_size self._update_job_status(job, DownloadJobStatus.COMPLETED) except OSError as excp: job.error = excp diff --git a/invokeai/backend/model_manager/install.py b/invokeai/backend/model_manager/install.py index 07b4aefd5e..4110f941c9 100644 --- a/invokeai/backend/model_manager/install.py +++ b/invokeai/backend/model_manager/install.py @@ -71,16 +71,8 @@ from .config import ( SchedulerPredictionType, SubModelType, ) -from .download import ( - 
HTTP_RE, - REPO_ID_RE, - DownloadEventHandler, - DownloadJobBase, - DownloadQueue, - DownloadQueueBase, - ModelSourceMetadata, -) -from .download.queue import DownloadJobPath, DownloadJobRepoID, DownloadJobURL +from .download import DownloadEventHandler, DownloadJobBase, DownloadQueue, DownloadQueueBase, ModelSourceMetadata +from .download.queue import HTTP_RE, REPO_ID_RE, DownloadJobPath, DownloadJobRepoID, DownloadJobURL from .hash import FastModelHash from .models import InvalidModelException from .probe import ModelProbe, ModelProbeInfo @@ -542,7 +534,7 @@ class ModelInstall(ModelInstallBase): return job def _complete_installation_handler(self, job: DownloadJobBase): - job = ModelInstallJob.parse_obj(job) # this upcast should succeed + assert isinstance(job, ModelInstallJob) if job.status == "completed": self._logger.info(f"{job.source}: Download finished with status {job.status}. Installing.") model_id = self.install_path(job.destination, job.probe_override) @@ -568,7 +560,7 @@ class ModelInstall(ModelInstallBase): self._tmpdir = None def _complete_registration_handler(self, job: DownloadJobBase): - job = ModelInstallJob.parse_obj(job) # upcast should succeed + assert isinstance(job, ModelInstallJob) if job.status == "completed": self._logger.info(f"{job.source}: Installing in place.") model_id = self.register_path(job.destination, job.probe_override) diff --git a/invokeai/backend/model_manager/models/base.py b/invokeai/backend/model_manager/models/base.py index e9c2e54fae..63da894830 100644 --- a/invokeai/backend/model_manager/models/base.py +++ b/invokeai/backend/model_manager/models/base.py @@ -8,6 +8,7 @@ from abc import ABCMeta, abstractmethod from contextlib import suppress from enum import Enum from pathlib import Path +from types import ModuleType from typing import Any, Callable, Dict, Generic, List, Literal, Optional, Type, TypeVar, Union import numpy as np @@ -78,7 +79,7 @@ class ModelBase(metaclass=ABCMeta): self.base_model = base_model self.model_type = model_type - def _hf_definition_to_type(self, subtypes: List[str]) -> Type: + def _hf_definition_to_type(self, subtypes: List[str]) -> Optional[ModuleType]: if len(subtypes) < 2: raise Exception("Invalid subfolder definition!") if all(t is None for t in subtypes):
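
The heart of the refactor above is that per-transfer bookkeeping (bytes, total_bytes, access_token) and the repo_id subqueue move out of DownloadJobBase into dedicated subclasses (DownloadJobRemoteSource, DownloadJobURL, DownloadJobRepoID, DownloadJobPath), and teardown is routed through a polymorphic cleanup() hook instead of "if job.subqueue:" checks scattered through DownloadQueue. The sketch below is a simplified illustration of that hierarchy, not the shipped code: it uses plain dataclasses in place of pydantic models, omits event handlers and queue mechanics, and job_class_for() is a hypothetical stand-in for the dispatch inside DownloadQueue.create_download_job().

    # Minimal sketch of the job hierarchy introduced by this patch.
    # Validation, event handlers, and threading are omitted; names mirror
    # the patch, but this is illustrative rather than the real implementation.
    import re
    from dataclasses import dataclass
    from pathlib import Path
    from typing import Optional, Union

    # Same regexes the patch moves into download/queue.py
    HTTP_RE = r"^https?://"
    REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE = r"^[\w-]+/[.\w-]+(?::\w+)?$"


    @dataclass
    class DownloadJobBase:
        """Base job: knows its source and destination, nothing about transport."""
        source: Union[str, Path]
        destination: Path

        def cleanup(self, preserve_partial_downloads: bool = False) -> None:
            """Hook called on cancel/error; no-op by default."""


    @dataclass
    class DownloadJobRemoteSource(DownloadJobBase):
        """Remote jobs add progress bookkeeping and an optional access token."""
        bytes: int = 0
        total_bytes: int = 0
        access_token: Optional[str] = None


    @dataclass
    class DownloadJobURL(DownloadJobRemoteSource):
        """Download of a single HTTP(S) URL."""


    @dataclass
    class DownloadJobRepoID(DownloadJobRemoteSource):
        """Download of a HuggingFace repo_id; owns a subqueue for its files."""
        variant: Optional[str] = None
        subqueue: Optional[object] = None  # DownloadQueueBase in the real code

        def cleanup(self, preserve_partial_downloads: bool = False) -> None:
            # Mirrors the patch: cancelling a repo_id job tears down its subqueue.
            if self.subqueue is not None:
                self.subqueue.cancel_all_jobs(preserve_partial=preserve_partial_downloads)
                self.subqueue.release()


    @dataclass
    class DownloadJobPath(DownloadJobBase):
        """Install from a local filesystem path; no progress fields needed."""


    def job_class_for(source: Union[str, Path]) -> type:
        """Source-type dispatch, as in DownloadQueue.create_download_job()."""
        if Path(source).exists():
            return DownloadJobPath
        if re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, str(source)):
            return DownloadJobRepoID
        if re.match(HTTP_RE, str(source)):
            return DownloadJobURL
        raise NotImplementedError(f"Don't know what to do with this type of source: {source}")


    if __name__ == "__main__":
        for src in ("stabilityai/stable-diffusion-2-1", "https://example.com/model.safetensors"):
            print(src, "->", job_class_for(src).__name__)

Keeping cleanup() on the job classes lets DownloadQueue cancel or pause any job type without knowing whether it owns a subqueue, which is why the queue-side "if job.subqueue:" branches disappear in this patch.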