refactor model install job class hierarchy

This commit is contained in:
Lincoln Stein
2023-10-04 14:51:59 -04:00
parent a180c0f241
commit cb0fdf3394
7 changed files with 71 additions and 68 deletions

View File

@ -293,9 +293,6 @@ class EventServiceBase:
def emit_model_event(self, job: DownloadJobBase) -> None:
"""Emit event when the status of a download/install job changes."""
# NOTE(review): per the hunk header (-293,9 +293,6) three of the lines below
# are being deleted by this commit; the stripped diff markers make it
# impossible to tell which from this rendering alone.
logger = InvokeAILogger.get_logger()
# Guard the division: total_bytes may legitimately be 0 before the size is known.
progress = 100 * (job.bytes / job.total_bytes) if job.total_bytes > 0 else 0
logger.debug(f"Dispatch model_event for job {job.id}, status={job.status.value}, progress={progress:5.2f}%")
self.dispatch( # use dispatch() directly here because we are not a session event.
event_name="model_event", payload=dict(job=job)
)

View File

@ -27,11 +27,12 @@ from invokeai.backend.model_manager.cache import CacheStats
from invokeai.backend.model_manager.download import DownloadJobBase
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
from .config import InvokeAIAppConfig
from .events import EventServiceBase
# processor is giving circular import errors
# from .processor import Invoker
from .config import InvokeAIAppConfig
from .events import EventServiceBase
if TYPE_CHECKING:
from ..invocations.baseinvocation import InvocationContext

View File

@ -1,8 +1,6 @@
"""Initialization file for threaded download manager."""
from .base import ( # noqa F401
HTTP_RE,
REPO_ID_RE,
DownloadEventHandler,
DownloadJobBase,
DownloadJobStatus,

View File

@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from enum import Enum
from functools import total_ordering
from pathlib import Path
from typing import Callable, List, Optional, Union
from typing import Any, Callable, List, Optional, Union
import requests
from pydantic import BaseModel, Field
@ -13,10 +13,6 @@ from pydantic.networks import AnyHttpUrl
from invokeai.app.services.config import InvokeAIAppConfig
# Used to distinguish between repo_id sources and URL sources
REPO_ID_RE = r"^[\w-]+/[.\w-]+$"
HTTP_RE = r"^https?://"
class DownloadJobStatus(str, Enum):
"""State of a download job."""
@ -53,16 +49,11 @@ class DownloadJobBase(BaseModel):
"""Class to monitor and control a model download request."""
priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
id: int = Field(description="Numeric ID of this job", default=-1) # default id is a placeholder
source: Union[str, Path] = Field(description="URL or repo_id to download")
destination: Path = Field(description="Destination of URL on local disk")
metadata: ModelSourceMetadata = Field(
description="Model metadata (source-specific)", default_factory=ModelSourceMetadata
)
id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel
source: Any = Field(description="Where to download from. Specific types specified in child classes.")
destination: Path = Field(description="Destination of downloaded model on local disk")
access_token: Optional[str] = Field(description="access token needed to access this resource")
status: DownloadJobStatus = Field(default=DownloadJobStatus.IDLE, description="Status of the download")
bytes: int = Field(default=0, description="Bytes downloaded so far")
total_bytes: int = Field(default=0, description="Total bytes to download")
event_handlers: Optional[List[DownloadEventHandler]] = Field(
description="Callables that will be called whenever job status changes",
default_factory=list,
@ -70,15 +61,15 @@ class DownloadJobBase(BaseModel):
job_started: Optional[float] = Field(description="Timestamp for when the download job started")
job_ended: Optional[float] = Field(description="Timestamp for when the download job ended (completed or errored)")
job_sequence: Optional[int] = Field(
description="Counter that records order in which this job was dequeued (for debugging)"
)
subqueue: Optional["DownloadQueueBase"] = Field(
description="a subqueue used for downloading repo_ids", default=None
description="Counter that records order in which this job was dequeued (used in unit testing)"
)
preserve_partial_downloads: bool = Field(
description="if true, then preserve partial downloads when cancelled or errored", default=False
)
error: Optional[Exception] = Field(default=None, description="Exception that caused an error")
metadata: ModelSourceMetadata = Field(
description="Metadata describing download contents", default_factory=ModelSourceMetadata
)
def add_event_handler(self, handler: DownloadEventHandler):
"""Add an event handler to the end of the handlers list."""
@ -89,6 +80,10 @@ class DownloadJobBase(BaseModel):
"""Clear all event handlers."""
self.event_handlers = list()
def cleanup(self, preserve_partial_downloads: bool = False) -> None:
"""Possibly do some action when work is finished."""
# Base-class hook: intentionally a no-op. Subclasses that own extra
# resources (e.g. the repo_id job's file subqueue) override this to cancel
# outstanding work, optionally preserving partially downloaded files.
pass
class Config:
"""Config object for this pydantic class."""

View File

@ -13,7 +13,7 @@ from typing import Dict, List, Optional, Set, Tuple, Union
import requests
from huggingface_hub import HfApi, hf_hub_url
from pydantic import Field, ValidationError, parse_obj_as, validator
from pydantic import Field, parse_obj_as, validator
from pydantic.networks import AnyHttpUrl
from requests import HTTPError
@ -21,7 +21,6 @@ from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util import InvokeAILogger, Logger
from ..storage import DuplicateModelException
from . import HTTP_RE, REPO_ID_RE
from .base import (
DownloadEventHandler,
DownloadJobBase,
@ -44,31 +43,53 @@ CIVITAI_MODEL_PAGE_WITH_VERSION = r"https://civitai.com/models/(\d+)\?modelVersi
CIVITAI_MODELS_ENDPOINT = "https://civitai.com/api/v1/models/"
CIVITAI_VERSIONS_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
# Regular expressions to describe repo_ids and http urls
HTTP_RE = r"^https?://"
REPO_ID_RE = r"^[\w-]+/[.\w-]+$"
REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE = r"^[\w-]+/[.\w-]+(?::\w+)?$"
class DownloadJobURL(DownloadJobBase):
class DownloadJobPath(DownloadJobBase):
"""Download from a local Path."""
# Narrows DownloadJobBase.source (typed Any on the base) to a filesystem
# Path; local jobs carry no byte-progress fields, unlike remote sources.
source: Path = Field(description="Local filesystem Path where model can be found")
class DownloadJobRemoteSource(DownloadJobBase):
"""A DownloadJob from a remote source that provides progress info."""
# Bytes transferred so far; updated by the queue worker while running.
bytes: int = Field(default=0, description="Bytes downloaded so far")
# Expected total size; remains 0 until the server reports a length.
total_bytes: int = Field(default=0, description="Total bytes to download")
# Sent as an Authorization: Bearer header when accessing protected resources.
access_token: Optional[str] = Field(description="access token needed to access this resource")
class DownloadJobURL(DownloadJobRemoteSource):
"""Job declaration for downloading individual URLs."""
# source is narrowed to a pydantic-validated http(s) URL.
source: AnyHttpUrl = Field(description="URL to download")
class DownloadJobRepoID(DownloadJobBase):
class DownloadJobRepoID(DownloadJobRemoteSource):
"""Download repo ids."""
source: str = Field(description="A repo_id (foo/bar), or a repo_id with a subfolder (foo/far:v2)")
variant: Optional[str] = Field(description="Variant, such as 'fp16', to download")
source: str = Field(description="URL or repo_id to download")
subqueue: Optional["DownloadQueueBase"] = Field(
description="a subqueue used for downloading the individual files in the repo_id", default=None
)
@validator("source")
@classmethod
def _validate_source(cls, v: str) -> str:
if not re.match(REPO_ID_RE, v):
raise ValidationError(f"{v} invalid repo_id", cls)
def proper_repo_id(cls, v: str) -> str: # noqa D102
if not re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, v):
raise ValueError(f"{v}: invalid repo_id format")
return v
class DownloadJobPath(DownloadJobBase):
"""Handle file paths."""
source: Union[str, Path] = Field(description="Path to a file or directory to install")
def cleanup(self, preserve_partial_downloads: bool = False) -> None:
"""Perform action when job is completed."""
# A repo_id download fans out into a private subqueue of per-file jobs.
# Cancel any still outstanding (optionally keeping partial files), then
# release() the subqueue — presumably freeing its worker threads; confirm
# against DownloadQueueBase.release().
if self.subqueue:
self.subqueue.cancel_all_jobs(preserve_partial=preserve_partial_downloads)
self.subqueue.release()
class DownloadQueue(DownloadQueueBase):
@ -124,23 +145,26 @@ class DownloadQueue(DownloadQueueBase):
event_handlers: Optional[List[DownloadEventHandler]] = None,
) -> DownloadJobBase:
"""Create a download job and return its ID."""
kwargs = dict()
kwargs: Dict[str, Optional[str]] = dict()
cls = DownloadJobBase
if Path(source).exists():
cls = DownloadJobPath
elif re.match(REPO_ID_RE, str(source)):
elif re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, str(source)):
cls = DownloadJobRepoID
kwargs = dict(variant=variant)
kwargs.update(
variant=variant,
access_token=access_token,
)
elif re.match(HTTP_RE, str(source)):
cls = DownloadJobURL
kwargs.update(access_token=access_token)
else:
raise NotImplementedError(f"Don't know what to do with this type of source: {source}")
job = cls(
source=str(source),
source=source,
destination=Path(destdir) / (filename or "."),
access_token=access_token,
event_handlers=(event_handlers or self._event_handlers),
priority=priority,
**kwargs,
@ -220,8 +244,7 @@ class DownloadQueue(DownloadQueueBase):
assert isinstance(self._jobs[job.id], DownloadJobBase)
job.preserve_partial_downloads = preserve_partial
self._update_job_status(job, DownloadJobStatus.CANCELLED)
if job.subqueue:
job.subqueue.cancel_all_jobs(preserve_partial=preserve_partial)
job.cleanup()
except (AssertionError, KeyError) as excp:
raise UnknownJobIDException("Unrecognized job") from excp
finally:
@ -254,9 +277,7 @@ class DownloadQueue(DownloadQueueBase):
self._lock.acquire()
assert isinstance(self._jobs[job.id], DownloadJobBase)
self._update_job_status(job, DownloadJobStatus.PAUSED)
if job.subqueue:
job.subqueue.cancel_all_jobs(preserve_partial=True)
job.subqueue.release()
job.cleanup()
except (AssertionError, KeyError) as excp:
raise UnknownJobIDException("Unrecognized job") from excp
finally:
@ -323,7 +344,6 @@ class DownloadQueue(DownloadQueueBase):
done = True
if job.status == DownloadJobStatus.ENQUEUED: # Don't do anything for non-enqueued jobs (shouldn't happen)
# There should be a better way to dispatch on the job type
if not self._quiet:
self._logger.info(f"{job.source}: Downloading to {job.destination}")
if isinstance(job, DownloadJobURL):
@ -340,7 +360,7 @@ class DownloadQueue(DownloadQueueBase):
self._queue.task_done()
def _get_metadata_and_url(self, job: DownloadJobBase) -> AnyHttpUrl:
def _get_metadata_and_url(self, job: DownloadJobRemoteSource) -> AnyHttpUrl:
"""
Fetch metadata from certain well-known URLs.
@ -406,6 +426,7 @@ class DownloadQueue(DownloadQueueBase):
"""Do the actual download."""
dest = None
try:
assert isinstance(job, DownloadJobRemoteSource)
url = self._get_metadata_and_url(job)
header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
@ -517,10 +538,11 @@ class DownloadQueue(DownloadQueueBase):
job.error = excp
self._update_job_status(job, DownloadJobStatus.ERROR)
def _download_repoid(self, job: DownloadJobBase):
def _download_repoid(self, job: DownloadJobRepoID):
"""Download a job that holds a huggingface repoid."""
def subdownload_event(subjob: DownloadJobBase):
assert isinstance(subjob, DownloadJobRemoteSource)
if subjob.status == DownloadJobStatus.RUNNING:
bytes_downloaded[subjob.id] = subjob.bytes
job.bytes = sum(bytes_downloaded.values())
@ -529,8 +551,7 @@ class DownloadQueue(DownloadQueueBase):
if subjob.status == DownloadJobStatus.ERROR:
job.error = subjob.error
if subjob.subqueue:
subjob.subqueue.cancel_all_jobs()
job.cleanup()
self._update_job_status(job, DownloadJobStatus.ERROR)
return
@ -546,7 +567,7 @@ class DownloadQueue(DownloadQueueBase):
quiet=True,
)
try:
job = DownloadJobRepoID.parse_obj(job)
assert isinstance(job, DownloadJobRepoID)
repo_id = job.source
variant = job.variant
if not job.metadata:
@ -668,11 +689,9 @@ class DownloadQueue(DownloadQueueBase):
source = Path(job.source).resolve()
destination = Path(job.destination).resolve()
try:
job.total_bytes = source.stat().st_size
self._update_job_status(job, DownloadJobStatus.RUNNING)
if source != destination:
shutil.move(source, destination)
job.bytes = destination.stat().st_size
self._update_job_status(job, DownloadJobStatus.COMPLETED)
except OSError as excp:
job.error = excp

View File

@ -71,16 +71,8 @@ from .config import (
SchedulerPredictionType,
SubModelType,
)
from .download import (
HTTP_RE,
REPO_ID_RE,
DownloadEventHandler,
DownloadJobBase,
DownloadQueue,
DownloadQueueBase,
ModelSourceMetadata,
)
from .download.queue import DownloadJobPath, DownloadJobRepoID, DownloadJobURL
from .download import DownloadEventHandler, DownloadJobBase, DownloadQueue, DownloadQueueBase, ModelSourceMetadata
from .download.queue import HTTP_RE, REPO_ID_RE, DownloadJobPath, DownloadJobRepoID, DownloadJobURL
from .hash import FastModelHash
from .models import InvalidModelException
from .probe import ModelProbe, ModelProbeInfo
@ -542,7 +534,7 @@ class ModelInstall(ModelInstallBase):
return job
def _complete_installation_handler(self, job: DownloadJobBase):
job = ModelInstallJob.parse_obj(job) # this upcast should succeed
assert isinstance(job, ModelInstallJob)
if job.status == "completed":
self._logger.info(f"{job.source}: Download finished with status {job.status}. Installing.")
model_id = self.install_path(job.destination, job.probe_override)
@ -568,7 +560,7 @@ class ModelInstall(ModelInstallBase):
self._tmpdir = None
def _complete_registration_handler(self, job: DownloadJobBase):
job = ModelInstallJob.parse_obj(job) # upcast should succeed
assert isinstance(job, ModelInstallJob)
if job.status == "completed":
self._logger.info(f"{job.source}: Installing in place.")
model_id = self.register_path(job.destination, job.probe_override)

View File

@ -8,6 +8,7 @@ from abc import ABCMeta, abstractmethod
from contextlib import suppress
from enum import Enum
from pathlib import Path
from types import ModuleType
from typing import Any, Callable, Dict, Generic, List, Literal, Optional, Type, TypeVar, Union
import numpy as np
@ -78,7 +79,7 @@ class ModelBase(metaclass=ABCMeta):
self.base_model = base_model
self.model_type = model_type
def _hf_definition_to_type(self, subtypes: List[str]) -> Type:
def _hf_definition_to_type(self, subtypes: List[str]) -> Optional[ModuleType]:
if len(subtypes) < 2:
raise Exception("Invalid subfolder definition!")
if all(t is None for t in subtypes):