InvokeAI/invokeai/app/services/download/download_base.py
Lincoln Stein 4536e4a8b6
Model Manager Refactor: Install remote models and store their tags and other metadata (#5361)
* add basic functionality for model metadata fetching from hf and civitai

* add storage

* start unit tests

* add unit tests and documentation

* add missing dependency for pytests

* remove redundant fetch; add modified/published dates; updated docs

* add code to select diffusers files based on the variant type

* implement Civitai installs

* make huggingface parallel downloading work

* add unit tests for model installation manager

- Fixed race condition on selection of download destination path
- Add fixtures common to several model_manager_2 unit tests
- Added dummy model files for testing diffusers and safetensors downloading/probing
- Refactored code for selecting proper variant from list of huggingface repo files
- Regrouped ordering of methods in model_install_default.py

* improve Civitai model downloading

- Provide a better error message when Civitai requires an access token (doesn't give a 403 forbidden, but redirects
  to the HTML of an authorization page -- arrgh)
- Handle case of Civitai providing a primary download link plus additional links for VAEs, config files, etc

* add routes for retrieving metadata and tags

* code tidying and documentation

* fix ruff errors

* add file needed to maintain test root directory in repo for unit tests

* fix self->cls in classmethod

* add pydantic plugin for mypy

* use TestSession instead of requests.Session to prevent any internet activity

improve logging

fix error message formatting

fix logging again

fix forward vs reverse slash issue in Windows install tests

* Several fixes of problems detected during PR review:

- Implement cancel_model_install_job and get_model_install_job routes
  to allow for better control of model download and install.
- Fix thread deadlock that occurred after cancelling an install.
- Remove unneeded pytest_plugins section from tests/conftest.py
- Remove unused _in_terminal_state() from model_install_default.
- Remove outdated documentation from several spots.
- Add workaround for Civitai API results which don't return correct
  URL for the default model.

* fix docs and tests to match get_job_by_source() rather than get_job()

* Update invokeai/backend/model_manager/metadata/fetch/huggingface.py

Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>

* Call CivitaiMetadata.model_validate_json() directly

Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>

* Second round of revisions suggested by @ryanjdick:

- Fix type mismatch in `list_all_metadata()` route.
- Do not have a default value for the model install job id
- Remove static class variable declarations from non Pydantic classes
- Change `id` field to `model_id` for the sqlite3 `model_tags` table.
- Changed AFTER DELETE triggers to ON DELETE CASCADE for the metadata and tags tables.
- Made the `id` field of the `model_metadata` table into a primary key to achieve uniqueness.

* Code cleanup suggested in PR review:

- Narrowed the declaration of the `parts` attribute of the download progress event
- Removed auto-conversion of str to Url in Url-containing sources
- Fixed handling of `InvalidModelConfigException`
- Made unknown sources raise `NotImplementedError` rather than `Exception`
- Improved status reporting on cached HuggingFace access tokens

* Multiple fixes:

- `job.total_size` returns a valid size for locally installed models
- new route `list_models` returns a paged summary of model, name,
  description, tags and other essential info
- fix a few type errors

* consolidated all invokeai root pytest fixtures into a single location

* Update invokeai/backend/model_manager/metadata/metadata_store.py

Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>

* Small tweaks in response to review comments:

- Remove flake8 configuration from pyproject.toml
- Use `id` rather than `modelId` for huggingface `ModelInfo` object
- Use `last_modified` rather than `LastModified` for huggingface `ModelInfo` object
- Add `sha256` field to file metadata downloaded from huggingface
- Add `Invoker` argument to the model installer `start()` and `stop()` routines
  (but made it optional in order to facilitate use of the service outside the API)
- Removed redundant `PRAGMA foreign_keys` from metadata store initialization code.

* Additional tweaks and minor bug fixes

- Fix calculation of aggregate diffusers model size to only count the
  size of files, not files + directories (which gives different unit test
  results on different filesystems).
- Refactor _get_metadata() and _get_download_urls() to have distinct code paths
  for Civitai, HuggingFace and URL sources.
- Forward the `inplace` flag from the source to the job and added unit test for this.
- Attach cached model metadata to the job rather than to the model install service.

* fix unit test that was breaking on windows due to CR/LF changing size of test json files

* fix ruff formatting

* a few last minor fixes before merging:

- Turn job `error` and `error_type` into properties derived from the exception.
- Add TODO comment about the reason for handling temporary directory destruction
  manually rather than using tempfile.tmpdir().

* add unit tests for reporting HTTP download errors

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2024-01-14 19:54:53 +00:00


# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""Model download service."""

from abc import ABC, abstractmethod
from enum import Enum
from functools import total_ordering
from pathlib import Path
from typing import Any, Callable, List, Optional

from pydantic import BaseModel, Field, PrivateAttr
from pydantic.networks import AnyHttpUrl


class DownloadJobStatus(str, Enum):
    """State of a download job."""

    WAITING = "waiting"  # not enqueued, will not run
    RUNNING = "running"  # actively downloading
    COMPLETED = "completed"  # finished running
    CANCELLED = "cancelled"  # user cancelled
    ERROR = "error"  # terminated with an error message


class DownloadJobCancelledException(Exception):
    """This exception is raised when a download job is cancelled."""


class UnknownJobIDException(Exception):
    """This exception is raised when an invalid job id is referenced."""


class ServiceInactiveException(Exception):
    """This exception is raised when user attempts to initiate a download before the service is started."""


DownloadEventHandler = Callable[["DownloadJob"], None]
DownloadExceptionHandler = Callable[["DownloadJob", Optional[Exception]], None]


@total_ordering
class DownloadJob(BaseModel):
    """Class to monitor and control a model download request."""

    # required variables to be passed in on creation
    source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
    dest: Path = Field(description="Destination of downloaded model on local disk; a directory or file path")
    access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")

    # automatically assigned on creation
    id: int = Field(description="Numeric ID of this job", default=-1)  # default id is a sentinel
    priority: int = Field(default=10, description="Queue priority; lower values are higher priority")

    # set internally during download process
    status: DownloadJobStatus = Field(default=DownloadJobStatus.WAITING, description="Status of the download")
    download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file")
    job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started")
    job_ended: Optional[str] = Field(
        default=None, description="Timestamp for when the download job ended (completed or errored)"
    )
    content_type: Optional[str] = Field(default=None, description="Content type of downloaded file")
    bytes: int = Field(default=0, description="Bytes downloaded so far")
    total_bytes: int = Field(default=0, description="Total file size (bytes)")

    # set when an error occurs
    error_type: Optional[str] = Field(default=None, description="Name of exception that caused an error")
    error: Optional[str] = Field(default=None, description="Traceback of the exception that caused an error")

    # internal flag
    _cancelled: bool = PrivateAttr(default=False)

    # optional event handlers passed in on creation
    _on_start: Optional[DownloadEventHandler] = PrivateAttr(default=None)
    _on_progress: Optional[DownloadEventHandler] = PrivateAttr(default=None)
    _on_complete: Optional[DownloadEventHandler] = PrivateAttr(default=None)
    _on_cancelled: Optional[DownloadEventHandler] = PrivateAttr(default=None)
    _on_error: Optional[DownloadExceptionHandler] = PrivateAttr(default=None)

    def __hash__(self) -> int:
        """Return hash of the string representation of this object, for indexing."""
        return hash(str(self))

    def __le__(self, other: "DownloadJob") -> bool:
        """Return True if this job's priority value is less than or equal to another's."""
        return self.priority <= other.priority

    def cancel(self) -> None:
        """Call to cancel the job."""
        self._cancelled = True

    # cancelled and the callbacks are private attributes in order to prevent
    # them from being serialized and/or used in the Json Schema
    @property
    def cancelled(self) -> bool:
        """Return true if the job has been cancelled."""
        return self._cancelled

    @property
    def complete(self) -> bool:
        """Return true if job completed without errors."""
        return self.status == DownloadJobStatus.COMPLETED

    @property
    def running(self) -> bool:
        """Return true if the job is running."""
        return self.status == DownloadJobStatus.RUNNING

    @property
    def errored(self) -> bool:
        """Return true if the job is errored."""
        return self.status == DownloadJobStatus.ERROR

    @property
    def in_terminal_state(self) -> bool:
        """Return true if job has finished, one way or another."""
        return self.status not in [DownloadJobStatus.WAITING, DownloadJobStatus.RUNNING]

    @property
    def on_start(self) -> Optional[DownloadEventHandler]:
        """Return the on_start event handler."""
        return self._on_start

    @property
    def on_progress(self) -> Optional[DownloadEventHandler]:
        """Return the on_progress event handler."""
        return self._on_progress

    @property
    def on_complete(self) -> Optional[DownloadEventHandler]:
        """Return the on_complete event handler."""
        return self._on_complete

    @property
    def on_error(self) -> Optional[DownloadExceptionHandler]:
        """Return the on_error event handler."""
        return self._on_error

    @property
    def on_cancelled(self) -> Optional[DownloadEventHandler]:
        """Return the on_cancelled event handler."""
        return self._on_cancelled

    def set_callbacks(
        self,
        on_start: Optional[DownloadEventHandler] = None,
        on_progress: Optional[DownloadEventHandler] = None,
        on_complete: Optional[DownloadEventHandler] = None,
        on_cancelled: Optional[DownloadEventHandler] = None,
        on_error: Optional[DownloadExceptionHandler] = None,
    ) -> None:
        """Set the callbacks for download events."""
        self._on_start = on_start
        self._on_progress = on_progress
        self._on_complete = on_complete
        self._on_error = on_error
        self._on_cancelled = on_cancelled
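

# Illustrative sketch, not part of the original module: how a class that
# defines only `__le__` under `@total_ordering` can be drained from a priority
# queue with the lowest `priority` value first. `SketchJob` and
# `drain_in_priority_order` are hypothetical stand-ins built on the standard
# library; that the real service stores jobs in a `queue.PriorityQueue` is an
# assumption, not something this file states.

```python
from dataclasses import dataclass
from functools import total_ordering
from queue import PriorityQueue
from typing import List


@total_ordering
@dataclass
class SketchJob:
    """Hypothetical stand-in for DownloadJob with only the ordering fields."""

    id: int
    priority: int = 10  # lower values are higher priority

    def __le__(self, other: "SketchJob") -> bool:
        # Same comparison as DownloadJob.__le__; total_ordering derives
        # __lt__, __gt__ and __ge__ from this and the dataclass __eq__.
        return self.priority <= other.priority


def drain_in_priority_order(jobs: List[SketchJob]) -> List[int]:
    """Enqueue every job, then return job ids in the order they are dequeued."""
    queue: "PriorityQueue[SketchJob]" = PriorityQueue()
    for job in jobs:
        queue.put(job)
    order: List[int] = []
    while not queue.empty():
        order.append(queue.get().id)
    return order


if __name__ == "__main__":
    jobs = [SketchJob(id=1, priority=10), SketchJob(id=2, priority=1), SketchJob(id=3, priority=5)]
    print(drain_in_priority_order(jobs))
```

# In this sketch, jobs with equal priority have no guaranteed relative order,
# because the comparison looks at `priority` alone.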


class DownloadQueueServiceBase(ABC):
    """Multithreaded queue for downloading models via URL."""

    @abstractmethod
    def start(self, *args: Any, **kwargs: Any) -> None:
        """Start the download worker threads."""

    @abstractmethod
    def stop(self, *args: Any, **kwargs: Any) -> None:
        """Stop the download worker threads."""

    @abstractmethod
    def download(
        self,
        source: AnyHttpUrl,
        dest: Path,
        priority: int = 10,
        access_token: Optional[str] = None,
        on_start: Optional[DownloadEventHandler] = None,
        on_progress: Optional[DownloadEventHandler] = None,
        on_complete: Optional[DownloadEventHandler] = None,
        on_cancelled: Optional[DownloadEventHandler] = None,
        on_error: Optional[DownloadExceptionHandler] = None,
    ) -> DownloadJob:
        """
        Create and enqueue download job.

        :param source: Source of the download as a URL.
        :param dest: Path to download to. See below.
        :param priority: Queue priority; lower values are higher priority.
        :param access_token: Authorization token for protected resources.
        :param on_start, on_progress, on_complete, on_cancelled, on_error: Callbacks for the indicated
         events.
        :returns: A DownloadJob object for monitoring the state of the download.

        The `dest` argument is a Path object. Its behavior is:

        1. If the path exists and is a directory, then the URL contents will be downloaded
           into that directory using the filename indicated in the response's `Content-Disposition` field.
           If no content-disposition is present, then the last component of the URL will be used (similar to
           wget's behavior).
        2. If the path does not exist, then it is taken as the name of a new file to create with the downloaded
           content.
        3. If the path exists and is an existing file, then the downloader will try to resume the download from
           the end of the existing file.
        """
        pass

    @abstractmethod
    def submit_download_job(
        self,
        job: DownloadJob,
        on_start: Optional[DownloadEventHandler] = None,
        on_progress: Optional[DownloadEventHandler] = None,
        on_complete: Optional[DownloadEventHandler] = None,
        on_cancelled: Optional[DownloadEventHandler] = None,
        on_error: Optional[DownloadExceptionHandler] = None,
    ) -> None:
        """
        Enqueue a download job.

        :param job: The DownloadJob
        :param on_start, on_progress, on_complete, on_cancelled, on_error: Callbacks for the indicated
         events.
        """
        pass

    @abstractmethod
    def list_jobs(self) -> List[DownloadJob]:
        """
        List active download jobs.

        :returns List[DownloadJob]: List of download jobs whose state is not "completed."
        """
        pass

    @abstractmethod
    def id_to_job(self, id: int) -> DownloadJob:
        """
        Return the DownloadJob corresponding to the integer ID.

        :param id: ID of the DownloadJob.

        Exceptions:
        * UnknownJobIDException
        """
        pass

    @abstractmethod
    def cancel_all_jobs(self) -> None:
        """Cancel all active and enqueued jobs."""
        pass

    @abstractmethod
    def prune_jobs(self) -> None:
        """Prune completed and errored queue items from the job list."""
        pass

    @abstractmethod
    def cancel_job(self, job: DownloadJob) -> None:
        """Cancel the job, clearing partial downloads and putting it into ERROR state."""
        pass

    @abstractmethod
    def join(self) -> None:
        """Wait until all jobs are off the queue."""
        pass
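

# Illustrative sketch, not part of the original module: one way the three
# `dest` rules from the `download()` docstring could be applied before any
# bytes are fetched. `resolve_dest` and `demo` are hypothetical names, and the
# concrete implementation of this service may resolve paths differently. The
# sketch returns the target file plus the byte offset to resume from, and
# uses only the standard library.

```python
import tempfile
from email.message import Message
from pathlib import Path, PurePosixPath
from typing import Dict, Optional, Tuple
from urllib.parse import urlparse


def resolve_dest(dest: Path, url: str, content_disposition: Optional[str] = None) -> Tuple[Path, int]:
    """Return (target file, resume offset) following the three `dest` rules."""
    if dest.is_dir():
        # Rule 1: existing directory, so choose a filename inside it.
        filename = None
        if content_disposition:
            header = Message()
            header["Content-Disposition"] = content_disposition
            filename = header.get_filename()
        if not filename:
            # No usable header: fall back to the last URL path component.
            filename = PurePosixPath(urlparse(url).path).name
        # Keep only the basename so a header cannot point outside `dest`.
        filename = PurePosixPath(filename).name
        if not filename:
            raise ValueError(f"cannot derive a filename from {url}")
        return dest / filename, 0
    if not dest.exists():
        # Rule 2: the path names a new file to create.
        return dest, 0
    # Rule 3: existing file, so resume from its current end.
    return dest, dest.stat().st_size


def demo() -> Dict[str, object]:
    """Exercise each rule inside a throwaway directory."""
    url = "https://example.com/models/demo.safetensors"
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        partial = root / "partial.bin"
        partial.write_bytes(b"12345")
        return {
            "dir_header": resolve_dest(root, url, 'attachment; filename="named.ckpt"')[0].name,
            "dir_url": resolve_dest(root, url)[0].name,
            "new_file": resolve_dest(root / "new.bin", url) == (root / "new.bin", 0),
            "resume_offset": resolve_dest(partial, url)[1],
        }


if __name__ == "__main__":
    print(demo())
```

# A caller could pass the returned offset in an HTTP `Range` request header to
# resume; whether the server honors it is outside the scope of this sketch.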