mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor model install job class hierarchy
This commit is contained in:
@ -293,9 +293,6 @@ class EventServiceBase:
|
||||
|
||||
def emit_model_event(self, job: DownloadJobBase) -> None:
|
||||
"""Emit event when the status of a download/install job changes."""
|
||||
logger = InvokeAILogger.get_logger()
|
||||
progress = 100 * (job.bytes / job.total_bytes) if job.total_bytes > 0 else 0
|
||||
logger.debug(f"Dispatch model_event for job {job.id}, status={job.status.value}, progress={progress:5.2f}%")
|
||||
self.dispatch( # use dispatch() directly here because we are not a session event.
|
||||
event_name="model_event", payload=dict(job=job)
|
||||
)
|
||||
|
@ -27,11 +27,12 @@ from invokeai.backend.model_manager.cache import CacheStats
|
||||
from invokeai.backend.model_manager.download import DownloadJobBase
|
||||
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
|
||||
|
||||
from .config import InvokeAIAppConfig
|
||||
from .events import EventServiceBase
|
||||
|
||||
# processor is giving circular import errors
|
||||
# from .processor import Invoker
|
||||
|
||||
from .config import InvokeAIAppConfig
|
||||
from .events import EventServiceBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
|
@ -1,8 +1,6 @@
|
||||
"""Initialization file for threaded download manager."""
|
||||
|
||||
from .base import ( # noqa F401
|
||||
HTTP_RE,
|
||||
REPO_ID_RE,
|
||||
DownloadEventHandler,
|
||||
DownloadJobBase,
|
||||
DownloadJobStatus,
|
||||
|
@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from functools import total_ordering
|
||||
from pathlib import Path
|
||||
from typing import Callable, List, Optional, Union
|
||||
from typing import Any, Callable, List, Optional, Union
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel, Field
|
||||
@ -13,10 +13,6 @@ from pydantic.networks import AnyHttpUrl
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
# Used to distinguish between repo_id sources and URL sources
|
||||
REPO_ID_RE = r"^[\w-]+/[.\w-]+$"
|
||||
HTTP_RE = r"^https?://"
|
||||
|
||||
|
||||
class DownloadJobStatus(str, Enum):
|
||||
"""State of a download job."""
|
||||
@ -53,16 +49,11 @@ class DownloadJobBase(BaseModel):
|
||||
"""Class to monitor and control a model download request."""
|
||||
|
||||
priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
|
||||
id: int = Field(description="Numeric ID of this job", default=-1) # default id is a placeholder
|
||||
source: Union[str, Path] = Field(description="URL or repo_id to download")
|
||||
destination: Path = Field(description="Destination of URL on local disk")
|
||||
metadata: ModelSourceMetadata = Field(
|
||||
description="Model metadata (source-specific)", default_factory=ModelSourceMetadata
|
||||
)
|
||||
id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel
|
||||
source: Any = Field(description="Where to download from. Specific types specified in child classes.")
|
||||
destination: Path = Field(description="Destination of downloaded model on local disk")
|
||||
access_token: Optional[str] = Field(description="access token needed to access this resource")
|
||||
status: DownloadJobStatus = Field(default=DownloadJobStatus.IDLE, description="Status of the download")
|
||||
bytes: int = Field(default=0, description="Bytes downloaded so far")
|
||||
total_bytes: int = Field(default=0, description="Total bytes to download")
|
||||
event_handlers: Optional[List[DownloadEventHandler]] = Field(
|
||||
description="Callables that will be called whenever job status changes",
|
||||
default_factory=list,
|
||||
@ -70,15 +61,15 @@ class DownloadJobBase(BaseModel):
|
||||
job_started: Optional[float] = Field(description="Timestamp for when the download job started")
|
||||
job_ended: Optional[float] = Field(description="Timestamp for when the download job ended (completed or errored)")
|
||||
job_sequence: Optional[int] = Field(
|
||||
description="Counter that records order in which this job was dequeued (for debugging)"
|
||||
)
|
||||
subqueue: Optional["DownloadQueueBase"] = Field(
|
||||
description="a subqueue used for downloading repo_ids", default=None
|
||||
description="Counter that records order in which this job was dequeued (used in unit testing)"
|
||||
)
|
||||
preserve_partial_downloads: bool = Field(
|
||||
description="if true, then preserve partial downloads when cancelled or errored", default=False
|
||||
)
|
||||
error: Optional[Exception] = Field(default=None, description="Exception that caused an error")
|
||||
metadata: ModelSourceMetadata = Field(
|
||||
description="Metadata describing download contents", default_factory=ModelSourceMetadata
|
||||
)
|
||||
|
||||
def add_event_handler(self, handler: DownloadEventHandler):
|
||||
"""Add an event handler to the end of the handlers list."""
|
||||
@ -89,6 +80,10 @@ class DownloadJobBase(BaseModel):
|
||||
"""Clear all event handlers."""
|
||||
self.event_handlers = list()
|
||||
|
||||
def cleanup(self, preserve_partial_downloads: bool = False):
|
||||
"""Possibly do some action when work is finished."""
|
||||
pass
|
||||
|
||||
class Config:
|
||||
"""Config object for this pydantic class."""
|
||||
|
||||
|
@ -13,7 +13,7 @@ from typing import Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import requests
|
||||
from huggingface_hub import HfApi, hf_hub_url
|
||||
from pydantic import Field, ValidationError, parse_obj_as, validator
|
||||
from pydantic import Field, parse_obj_as, validator
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests import HTTPError
|
||||
|
||||
@ -21,7 +21,6 @@ from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.util import InvokeAILogger, Logger
|
||||
|
||||
from ..storage import DuplicateModelException
|
||||
from . import HTTP_RE, REPO_ID_RE
|
||||
from .base import (
|
||||
DownloadEventHandler,
|
||||
DownloadJobBase,
|
||||
@ -44,31 +43,53 @@ CIVITAI_MODEL_PAGE_WITH_VERSION = r"https://civitai.com/models/(\d+)\?modelVersi
|
||||
CIVITAI_MODELS_ENDPOINT = "https://civitai.com/api/v1/models/"
|
||||
CIVITAI_VERSIONS_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
|
||||
|
||||
# Regular expressions to describe repo_ids and http urls
|
||||
HTTP_RE = r"^https?://"
|
||||
REPO_ID_RE = r"^[\w-]+/[.\w-]+$"
|
||||
REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE = r"^[\w-]+/[.\w-]+(?::\w+)?$"
|
||||
|
||||
class DownloadJobURL(DownloadJobBase):
|
||||
|
||||
class DownloadJobPath(DownloadJobBase):
|
||||
"""Download from a local Path."""
|
||||
|
||||
source: Path = Field(description="Local filesystem Path where model can be found")
|
||||
|
||||
|
||||
class DownloadJobRemoteSource(DownloadJobBase):
|
||||
"""A DownloadJob from a remote source that provides progress info."""
|
||||
|
||||
bytes: int = Field(default=0, description="Bytes downloaded so far")
|
||||
total_bytes: int = Field(default=0, description="Total bytes to download")
|
||||
access_token: Optional[str] = Field(description="access token needed to access this resource")
|
||||
|
||||
|
||||
class DownloadJobURL(DownloadJobRemoteSource):
|
||||
"""Job declaration for downloading individual URLs."""
|
||||
|
||||
source: AnyHttpUrl = Field(description="URL to download")
|
||||
|
||||
|
||||
class DownloadJobRepoID(DownloadJobBase):
|
||||
class DownloadJobRepoID(DownloadJobRemoteSource):
|
||||
"""Download repo ids."""
|
||||
|
||||
source: str = Field(description="A repo_id (foo/bar), or a repo_id with a subfolder (foo/far:v2)")
|
||||
variant: Optional[str] = Field(description="Variant, such as 'fp16', to download")
|
||||
source: str = Field(description="URL or repo_id to download")
|
||||
subqueue: Optional["DownloadQueueBase"] = Field(
|
||||
description="a subqueue used for downloading the individual files in the repo_id", default=None
|
||||
)
|
||||
|
||||
@validator("source")
|
||||
@classmethod
|
||||
def _validate_source(cls, v: str) -> str:
|
||||
if not re.match(REPO_ID_RE, v):
|
||||
raise ValidationError(f"{v} invalid repo_id", cls)
|
||||
def proper_repo_id(cls, v: str) -> str: # noqa D102
|
||||
if not re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, v):
|
||||
raise ValueError(f"{v}: invalid repo_id format")
|
||||
return v
|
||||
|
||||
|
||||
class DownloadJobPath(DownloadJobBase):
|
||||
"""Handle file paths."""
|
||||
|
||||
source: Union[str, Path] = Field(description="Path to a file or directory to install")
|
||||
def cleanup(self, preserve_partial_downloads: bool = False):
|
||||
"""Perform action when job is completed."""
|
||||
if self.subqueue:
|
||||
self.subqueue.cancel_all_jobs(preserve_partial=preserve_partial_downloads)
|
||||
self.subqueue.release()
|
||||
|
||||
|
||||
class DownloadQueue(DownloadQueueBase):
|
||||
@ -124,23 +145,26 @@ class DownloadQueue(DownloadQueueBase):
|
||||
event_handlers: Optional[List[DownloadEventHandler]] = None,
|
||||
) -> DownloadJobBase:
|
||||
"""Create a download job and return its ID."""
|
||||
kwargs = dict()
|
||||
kwargs: Dict[str, Optional[str]] = dict()
|
||||
|
||||
cls = DownloadJobBase
|
||||
if Path(source).exists():
|
||||
cls = DownloadJobPath
|
||||
elif re.match(REPO_ID_RE, str(source)):
|
||||
elif re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, str(source)):
|
||||
cls = DownloadJobRepoID
|
||||
kwargs = dict(variant=variant)
|
||||
kwargs.update(
|
||||
variant=variant,
|
||||
access_token=access_token,
|
||||
)
|
||||
elif re.match(HTTP_RE, str(source)):
|
||||
cls = DownloadJobURL
|
||||
kwargs.update(access_token=access_token)
|
||||
else:
|
||||
raise NotImplementedError(f"Don't know what to do with this type of source: {source}")
|
||||
|
||||
job = cls(
|
||||
source=str(source),
|
||||
source=source,
|
||||
destination=Path(destdir) / (filename or "."),
|
||||
access_token=access_token,
|
||||
event_handlers=(event_handlers or self._event_handlers),
|
||||
priority=priority,
|
||||
**kwargs,
|
||||
@ -220,8 +244,7 @@ class DownloadQueue(DownloadQueueBase):
|
||||
assert isinstance(self._jobs[job.id], DownloadJobBase)
|
||||
job.preserve_partial_downloads = preserve_partial
|
||||
self._update_job_status(job, DownloadJobStatus.CANCELLED)
|
||||
if job.subqueue:
|
||||
job.subqueue.cancel_all_jobs(preserve_partial=preserve_partial)
|
||||
job.cleanup()
|
||||
except (AssertionError, KeyError) as excp:
|
||||
raise UnknownJobIDException("Unrecognized job") from excp
|
||||
finally:
|
||||
@ -254,9 +277,7 @@ class DownloadQueue(DownloadQueueBase):
|
||||
self._lock.acquire()
|
||||
assert isinstance(self._jobs[job.id], DownloadJobBase)
|
||||
self._update_job_status(job, DownloadJobStatus.PAUSED)
|
||||
if job.subqueue:
|
||||
job.subqueue.cancel_all_jobs(preserve_partial=True)
|
||||
job.subqueue.release()
|
||||
job.cleanup()
|
||||
except (AssertionError, KeyError) as excp:
|
||||
raise UnknownJobIDException("Unrecognized job") from excp
|
||||
finally:
|
||||
@ -323,7 +344,6 @@ class DownloadQueue(DownloadQueueBase):
|
||||
done = True
|
||||
|
||||
if job.status == DownloadJobStatus.ENQUEUED: # Don't do anything for non-enqueued jobs (shouldn't happen)
|
||||
# There should be a better way to dispatch on the job type
|
||||
if not self._quiet:
|
||||
self._logger.info(f"{job.source}: Downloading to {job.destination}")
|
||||
if isinstance(job, DownloadJobURL):
|
||||
@ -340,7 +360,7 @@ class DownloadQueue(DownloadQueueBase):
|
||||
|
||||
self._queue.task_done()
|
||||
|
||||
def _get_metadata_and_url(self, job: DownloadJobBase) -> AnyHttpUrl:
|
||||
def _get_metadata_and_url(self, job: DownloadJobRemoteSource) -> AnyHttpUrl:
|
||||
"""
|
||||
Fetch metadata from certain well-known URLs.
|
||||
|
||||
@ -406,6 +426,7 @@ class DownloadQueue(DownloadQueueBase):
|
||||
"""Do the actual download."""
|
||||
dest = None
|
||||
try:
|
||||
assert isinstance(job, DownloadJobRemoteSource)
|
||||
url = self._get_metadata_and_url(job)
|
||||
|
||||
header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
|
||||
@ -517,10 +538,11 @@ class DownloadQueue(DownloadQueueBase):
|
||||
job.error = excp
|
||||
self._update_job_status(job, DownloadJobStatus.ERROR)
|
||||
|
||||
def _download_repoid(self, job: DownloadJobBase):
|
||||
def _download_repoid(self, job: DownloadJobRepoID):
|
||||
"""Download a job that holds a huggingface repoid."""
|
||||
|
||||
def subdownload_event(subjob: DownloadJobBase):
|
||||
assert isinstance(subjob, DownloadJobRemoteSource)
|
||||
if subjob.status == DownloadJobStatus.RUNNING:
|
||||
bytes_downloaded[subjob.id] = subjob.bytes
|
||||
job.bytes = sum(bytes_downloaded.values())
|
||||
@ -529,8 +551,7 @@ class DownloadQueue(DownloadQueueBase):
|
||||
|
||||
if subjob.status == DownloadJobStatus.ERROR:
|
||||
job.error = subjob.error
|
||||
if subjob.subqueue:
|
||||
subjob.subqueue.cancel_all_jobs()
|
||||
job.cleanup()
|
||||
self._update_job_status(job, DownloadJobStatus.ERROR)
|
||||
return
|
||||
|
||||
@ -546,7 +567,7 @@ class DownloadQueue(DownloadQueueBase):
|
||||
quiet=True,
|
||||
)
|
||||
try:
|
||||
job = DownloadJobRepoID.parse_obj(job)
|
||||
assert isinstance(job, DownloadJobRepoID)
|
||||
repo_id = job.source
|
||||
variant = job.variant
|
||||
if not job.metadata:
|
||||
@ -668,11 +689,9 @@ class DownloadQueue(DownloadQueueBase):
|
||||
source = Path(job.source).resolve()
|
||||
destination = Path(job.destination).resolve()
|
||||
try:
|
||||
job.total_bytes = source.stat().st_size
|
||||
self._update_job_status(job, DownloadJobStatus.RUNNING)
|
||||
if source != destination:
|
||||
shutil.move(source, destination)
|
||||
job.bytes = destination.stat().st_size
|
||||
self._update_job_status(job, DownloadJobStatus.COMPLETED)
|
||||
except OSError as excp:
|
||||
job.error = excp
|
||||
|
@ -71,16 +71,8 @@ from .config import (
|
||||
SchedulerPredictionType,
|
||||
SubModelType,
|
||||
)
|
||||
from .download import (
|
||||
HTTP_RE,
|
||||
REPO_ID_RE,
|
||||
DownloadEventHandler,
|
||||
DownloadJobBase,
|
||||
DownloadQueue,
|
||||
DownloadQueueBase,
|
||||
ModelSourceMetadata,
|
||||
)
|
||||
from .download.queue import DownloadJobPath, DownloadJobRepoID, DownloadJobURL
|
||||
from .download import DownloadEventHandler, DownloadJobBase, DownloadQueue, DownloadQueueBase, ModelSourceMetadata
|
||||
from .download.queue import HTTP_RE, REPO_ID_RE, DownloadJobPath, DownloadJobRepoID, DownloadJobURL
|
||||
from .hash import FastModelHash
|
||||
from .models import InvalidModelException
|
||||
from .probe import ModelProbe, ModelProbeInfo
|
||||
@ -542,7 +534,7 @@ class ModelInstall(ModelInstallBase):
|
||||
return job
|
||||
|
||||
def _complete_installation_handler(self, job: DownloadJobBase):
|
||||
job = ModelInstallJob.parse_obj(job) # this upcast should succeed
|
||||
assert isinstance(job, ModelInstallJob)
|
||||
if job.status == "completed":
|
||||
self._logger.info(f"{job.source}: Download finished with status {job.status}. Installing.")
|
||||
model_id = self.install_path(job.destination, job.probe_override)
|
||||
@ -568,7 +560,7 @@ class ModelInstall(ModelInstallBase):
|
||||
self._tmpdir = None
|
||||
|
||||
def _complete_registration_handler(self, job: DownloadJobBase):
|
||||
job = ModelInstallJob.parse_obj(job) # upcast should succeed
|
||||
assert isinstance(job, ModelInstallJob)
|
||||
if job.status == "completed":
|
||||
self._logger.info(f"{job.source}: Installing in place.")
|
||||
model_id = self.register_path(job.destination, job.probe_override)
|
||||
|
@ -8,6 +8,7 @@ from abc import ABCMeta, abstractmethod
|
||||
from contextlib import suppress
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Any, Callable, Dict, Generic, List, Literal, Optional, Type, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
@ -78,7 +79,7 @@ class ModelBase(metaclass=ABCMeta):
|
||||
self.base_model = base_model
|
||||
self.model_type = model_type
|
||||
|
||||
def _hf_definition_to_type(self, subtypes: List[str]) -> Type:
|
||||
def _hf_definition_to_type(self, subtypes: List[str]) -> Optional[ModuleType]:
|
||||
if len(subtypes) < 2:
|
||||
raise Exception("Invalid subfolder definition!")
|
||||
if all(t is None for t in subtypes):
|
||||
|
Reference in New Issue
Block a user