Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Add simplified model manager install API to InvocationContext (#6132)
## Summary

This adds three model manager-related methods to the `InvocationContext` uniform API. They are accessible via `context.models.*`:

1. **`load_local_model(model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None) -> LoadedModelWithoutConfig`**

   *Load the model located at the indicated path.*

   This loads a local model (.safetensors, .ckpt or diffusers directory) into the model manager RAM cache and returns its `LoadedModelWithoutConfig`. If the optional `loader` argument is provided, the loader is invoked to load the model into memory. Otherwise the method calls `safetensors.torch.load_file()`, `torch.load()` (with a pickle scan), or `from_pretrained()` as appropriate to the path type.

   Be aware that the `LoadedModelWithoutConfig` object differs from `LoadedModel` by having no `config` attribute.

   Here is an example of usage:

   ```
   def invoke(self, context: InvocationContext) -> ImageOutput:
       model_path = Path('/opt/models/RealESRGAN_x4plus.pth')
       loadnet = context.models.load_local_model(model_path)
       with loadnet as loadnet_model:
           upscaler = RealESRGAN(loadnet=loadnet_model,...)
   ```

---

2. **`load_remote_model(source: str | AnyHttpUrl, loader: Optional[Callable[[Path], AnyModel]] = None) -> LoadedModelWithoutConfig`**

   *Load the model located at the indicated URL or repo_id.*

   This is similar to `load_local_model()`, but it accepts either a HuggingFace repo_id (as a string) or a URL. The model's file(s) will be downloaded to `models/.download_cache` and then loaded, returning a `LoadedModelWithoutConfig`.

   ```
   def invoke(self, context: InvocationContext) -> ImageOutput:
       model_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'
       loadnet = context.models.load_remote_model(model_url)
       with loadnet as loadnet_model:
           upscaler = RealESRGAN(loadnet=loadnet_model,...)
   ```

---

3. **`download_and_cache_model(source: str | AnyHttpUrl, access_token: Optional[str] = None, timeout: Optional[int] = 0) -> Path`**

   *Download the model file located at `source` to the models cache and return its Path.*

   This checks `models/.download_cache` for the desired model file and downloads it from the indicated source if it is not already present. The local Path to the downloaded file is then returned.

---

## Other Changes

This PR performs a migration: it renames `models/.cache` to `models/.convert_cache`, and moves previously-downloaded ESRGAN, openpose, DepthAnything and LaMa inpaint models from the `models/core` directory into `models/.download_cache`.

There are a number of legacy model files in `models/core`, such as GFPGAN, which are no longer used. This PR deletes them and tidies up the `models/core` directory.

## Related Issues / Discussions

I have systematically replaced all the calls to `download_with_progress_bar()`. This function is no longer used elsewhere and has been removed.

## QA Instructions

I have added unit tests for the three new calls. You may test that the `download_and_cache_model()` call is working by running the upscaler within the web app. On the first try, you will see the model file being downloaded into the `models/.download_cache` directory. On subsequent tries, the model will either load from RAM (if it hasn't been displaced) or will be loaded from the filesystem.
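For reference when testing, here is a sketch of how an invocation might call the third method, `download_and_cache_model()`. The node body below is illustrative only (it mirrors the Real-ESRGAN examples above) and is not code from this PR:

```
def invoke(self, context: InvocationContext) -> ImageOutput:
    model_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'
    # Downloads on first use, then returns the cached copy's local path on later calls.
    model_path = context.models.download_and_cache_model(model_url)
    # model_path points into models/.download_cache; load the file however the node requires.
```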
## Merge Plan

Squash merge when approved.

## Checklist

- [X] _The PR has a short but descriptive title, suitable for a changelog_
- [X] _Tests added / updated (if applicable)_
- [X] _Documentation added / updated (if applicable)_
Commit: 9432336e2b
@@ -128,7 +128,8 @@ The queue operates on a series of download job objects. These objects
 specify the source and destination of the download, and keep track of
 the progress of the download.

-The only job type currently implemented is `DownloadJob`, a pydantic object with the
+Two job types are defined. `DownloadJob` and
+`MultiFileDownloadJob`. The former is a pydantic object with the
 following fields:

 | **Field**          | **Type**         | **Default**   | **Description** |
@@ -138,7 +139,7 @@ following fields:
 | `dest`             | Path             |               | Where to download to |
 | `access_token`     | str              |               | [optional] string containing authentication token for access |
 | `on_start`         | Callable         |               | [optional] callback when the download starts |
 | `on_progress`      | Callable         |               | [optional] callback called at intervals during download progress |
 | `on_complete`      | Callable         |               | [optional] callback called after successful download completion |
 | `on_error`         | Callable         |               | [optional] callback called after an error occurs |
 | `id`               | int              | auto assigned | Job ID, an integer >= 0 |
@@ -190,6 +191,33 @@ A cancelled job will have status `DownloadJobStatus.ERROR` and an
 `error_type` field of "DownloadJobCancelledException". In addition,
 the job's `cancelled` property will be set to True.

+The `MultiFileDownloadJob` is used for diffusers model downloads,
+which contain multiple files and directories under a common root:
+
+| **Field**          | **Type**         | **Default**   | **Description** |
+|----------------|-----------------|---------------|-----------------|
+| _Fields passed in at job creation time_ |
+| `download_parts`   | Set[DownloadJob] |               | Component download jobs |
+| `dest`             | Path             |               | Where to download to |
+| `on_start`         | Callable         |               | [optional] callback when the download starts |
+| `on_progress`      | Callable         |               | [optional] callback called at intervals during download progress |
+| `on_complete`      | Callable         |               | [optional] callback called after successful download completion |
+| `on_error`         | Callable         |               | [optional] callback called after an error occurs |
+| `id`               | int              | auto assigned | Job ID, an integer >= 0 |
+| _Fields updated over the course of the download task_
+| `status`           | DownloadJobStatus|               | Status code |
+| `download_path`    | Path             |               | Path to the root of the downloaded files |
+| `bytes`            | int              | 0             | Bytes downloaded so far |
+| `total_bytes`      | int              | 0             | Total size of the file at the remote site |
+| `error_type`       | str              |               | String version of the exception that caused an error during download |
+| `error`            | str              |               | String version of the traceback associated with an error |
+| `cancelled`        | bool             | False         | Set to true if the job was cancelled by the caller|
+
+Note that the MultiFileDownloadJob does not support the `priority`,
+`job_started`, `job_ended` or `content_type` attributes. You can get
+these from the individual download jobs in `download_parts`.
+
 ### Callbacks

 Download jobs can be associated with a series of callbacks, each with
@@ -251,11 +279,40 @@ jobs using `list_jobs()`, fetch a single job by its with
 running jobs with `cancel_all_jobs()`, and wait for all jobs to finish
 with `join()`.

-#### job = queue.download(source, dest, priority, access_token)
+#### job = queue.download(source, dest, priority, access_token, on_start, on_progress, on_complete, on_cancelled, on_error)

 Create a new download job and put it on the queue, returning the
 DownloadJob object.

+#### multifile_job = queue.multifile_download(parts, dest, access_token, on_start, on_progress, on_complete, on_cancelled, on_error)
+
+This is similar to download(), but instead of taking a single source,
+it accepts a `parts` argument consisting of a list of
+`RemoteModelFile` objects. Each part corresponds to a URL/Path pair,
+where the URL is the location of the remote file, and the Path is the
+destination.
+
+`RemoteModelFile` can be imported from `invokeai.backend.model_manager.metadata`, and
+consists of a url/path pair. Note that the path *must* be relative.
+
+The method returns a `MultiFileDownloadJob`.
+
+```
+from invokeai.backend.model_manager.metadata import RemoteModelFile
+remote_file_1 = RemoteModelFile(url='http://www.foo.bar/my/pytorch_model.safetensors',
+                                path='my_model/textencoder/pytorch_model.safetensors'
+                                )
+remote_file_2 = RemoteModelFile(url='http://www.bar.baz/vae.ckpt',
+                                path='my_model/vae/diffusers_model.safetensors'
+                                )
+job = queue.multifile_download(parts=[remote_file_1, remote_file_2],
+                               dest='/tmp/downloads',
+                               on_progress=TqdmProgress().update)
+queue.wait_for_job(job)
+print(f"The files were downloaded to {job.download_path}")
+```
+
 #### jobs = queue.list_jobs()

 Return a list of all active and inactive `DownloadJob`s.

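As an aside for reviewers: the documentation hunk above only shows a worked example for `multifile_download()`. A minimal single-file counterpart using `download()` would look roughly like the sketch below; the URL and destination are hypothetical and the snippet is illustrative rather than taken from the docs:

```
from pathlib import Path

from invokeai.app.services.download import DownloadQueueService, TqdmProgress

queue = DownloadQueueService()
queue.start()

job = queue.download(
    source='https://www.example.com/models/some_model.safetensors',  # hypothetical URL
    dest=Path('/tmp/downloads'),
    on_progress=TqdmProgress().update,
)
queue.join()  # block until all enqueued jobs reach a terminal state
print(f'The file was downloaded to {job.download_path}')
```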
@@ -397,26 +397,25 @@ In the event you wish to create a new installer, you may use the
 following initialization pattern:

 ```
-from invokeai.app.services.config import InvokeAIAppConfig
+from invokeai.app.services.config import get_config
 from invokeai.app.services.model_records import ModelRecordServiceSQL
 from invokeai.app.services.model_install import ModelInstallService
 from invokeai.app.services.download import DownloadQueueService
-from invokeai.app.services.shared.sqlite import SqliteDatabase
+from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
 from invokeai.backend.util.logging import InvokeAILogger

-config = InvokeAIAppConfig.get_config()
-config.parse_args()
+config = get_config()

 logger = InvokeAILogger.get_logger(config=config)
-db = SqliteDatabase(config, logger)
+db = SqliteDatabase(config.db_path, logger)
 record_store = ModelRecordServiceSQL(db)
 queue = DownloadQueueService()
 queue.start()

 installer = ModelInstallService(app_config=config,
                                 record_store=record_store,
                                 download_queue=queue
                                 )
 installer.start()
 ```

@@ -1602,3 +1601,59 @@ This method takes a model key, looks it up using the
 `ModelRecordServiceBase` object in `mm.store`, and passes the returned
 model configuration to `load_model_by_config()`. It may raise a
 `NotImplementedException`.
+
+## Invocation Context Model Manager API
+
+Within invocations, the following methods are available from the
+`InvocationContext` object:
+
+### context.download_and_cache_model(source) -> Path
+
+This method accepts a `source` of a remote model, downloads and caches
+it locally, and then returns a Path to the local model. The source can
+be a direct download URL or a HuggingFace repo_id.
+
+In the case of HuggingFace repo_id, the following variants are
+recognized:
+
+* stabilityai/stable-diffusion-v4 -- default model
+* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant
+* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder
+* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder
+
+You can also point at an arbitrary individual file within a repo_id
+directory using this syntax:
+
+* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
+
+### context.load_local_model(model_path, [loader]) -> LoadedModel
+
+This method loads a local model from the indicated path, returning a
+`LoadedModel`. The optional loader is a Callable that accepts a Path
+to the object, and returns a `AnyModel` object. If no loader is
+provided, then the method will use `torch.load()` for a .ckpt or .bin
+checkpoint file, `safetensors.torch.load_file()` for a safetensors
+checkpoint file, or `cls.from_pretrained()` for a directory that looks
+like a diffusers directory.
+
+### context.load_remote_model(source, [loader]) -> LoadedModel
+
+This method accepts a `source` of a remote model, downloads and caches
+it locally, loads it, and returns a `LoadedModel`. The source can be a
+direct download URL or a HuggingFace repo_id.
+
+In the case of HuggingFace repo_id, the following variants are
+recognized:
+
+* stabilityai/stable-diffusion-v4 -- default model
+* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant
+* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder
+* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder
+
+You can also point at an arbitrary individual file within a repo_id
+directory using this syntax:
+
+* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
+
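To connect this new documentation with the node changes later in the diff, here is a hypothetical invocation sketch exercising `download_and_cache_model()` and `load_remote_model()`. The URL reuses the Real-ESRGAN example from the PR summary, and the lambda loader is an assumption for a plain torch checkpoint, not repository code:

```
from pathlib import Path

import torch

def invoke(self, context: InvocationContext) -> ImageOutput:
    url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'

    # Fetch (or reuse) the cached file and get its local path only.
    weights_path: Path = context.models.download_and_cache_model(url)

    # Download and load in one step, optionally with a custom loader callable.
    loadnet = context.models.load_remote_model(
        source=url,
        loader=lambda p: torch.load(p, map_location='cpu'),  # assumed loader for a .pth checkpoint
    )
    with loadnet as model:
        ...  # use `model` while it is held in the model manager RAM cache
```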
@@ -93,7 +93,7 @@ class ApiDependencies:
         conditioning = ObjectSerializerForwardCache(
             ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
         )
-        download_queue_service = DownloadQueueService(event_bus=events)
+        download_queue_service = DownloadQueueService(app_config=configuration, event_bus=events)
         model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
         model_manager = ModelManagerService.build_model_manager(
             app_config=configuration,
@@ -2,6 +2,7 @@
 # initial implementation by Gregg Helt, 2023
 # heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
 from builtins import bool, float
+from pathlib import Path
 from typing import Dict, List, Literal, Union

 import cv2
@@ -36,12 +37,13 @@ from invokeai.app.invocations.util import validate_begin_end_step, validate_weig
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
 from invokeai.backend.image_util.canny import get_canny_edges
-from invokeai.backend.image_util.depth_anything import DepthAnythingDetector
+from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector
-from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
+from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
 from invokeai.backend.image_util.hed import HEDProcessor
 from invokeai.backend.image_util.lineart import LineartProcessor
 from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
 from invokeai.backend.image_util.util import np_to_pil, pil_to_np
+from invokeai.backend.util.devices import TorchDevice

 from .baseinvocation import BaseInvocation, BaseInvocationOutput, Classification, invocation, invocation_output

@@ -139,6 +141,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
         return context.images.get_pil(self.image.image_name, "RGB")

     def invoke(self, context: InvocationContext) -> ImageOutput:
+        self._context = context
         raw_image = self.load_image(context)
         # image type should be PIL.PngImagePlugin.PngImageFile ?
         processed_image = self.run_processor(raw_image)
@@ -284,7 +287,8 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
     # depth_and_normal not supported in controlnet_aux v0.0.3
     # depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")

-    def run_processor(self, image):
+    def run_processor(self, image: Image.Image) -> Image.Image:
+        # TODO: replace from_pretrained() calls with context.models.download_and_cache() (or similar)
         midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
         processed_image = midas_processor(
             image,
@@ -311,7 +315,7 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
     detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
     image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)

-    def run_processor(self, image):
+    def run_processor(self, image: Image.Image) -> Image.Image:
         normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
         processed_image = normalbae_processor(
             image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
@@ -330,7 +334,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
     thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
     thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")

-    def run_processor(self, image):
+    def run_processor(self, image: Image.Image) -> Image.Image:
         mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
         processed_image = mlsd_processor(
             image,
@@ -353,7 +357,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
     safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
     scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)

-    def run_processor(self, image):
+    def run_processor(self, image: Image.Image) -> Image.Image:
         pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
         processed_image = pidi_processor(
             image,
@@ -381,7 +385,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
     w: int = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
     f: int = InputField(default=256, ge=0, description="Content shuffle `f` parameter")

-    def run_processor(self, image):
+    def run_processor(self, image: Image.Image) -> Image.Image:
         content_shuffle_processor = ContentShuffleDetector()
         processed_image = content_shuffle_processor(
             image,
@@ -405,7 +409,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
 class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
     """Applies Zoe depth processing to image"""

-    def run_processor(self, image):
+    def run_processor(self, image: Image.Image) -> Image.Image:
         zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
         processed_image = zoe_depth_processor(image)
         return processed_image
@@ -426,7 +430,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
     detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
     image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)

-    def run_processor(self, image):
+    def run_processor(self, image: Image.Image) -> Image.Image:
         mediapipe_face_processor = MediapipeFaceDetector()
         processed_image = mediapipe_face_processor(
             image,
@@ -454,7 +458,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
     detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
     image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)

-    def run_processor(self, image):
+    def run_processor(self, image: Image.Image) -> Image.Image:
         leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
         processed_image = leres_processor(
             image,
@@ -496,8 +500,8 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
         np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
         return np_img

-    def run_processor(self, img):
-        np_img = np.array(img, dtype=np.uint8)
+    def run_processor(self, image: Image.Image) -> Image.Image:
+        np_img = np.array(image, dtype=np.uint8)
         processed_np_image = self.tile_resample(
             np_img,
             # res=self.tile_size,
@@ -520,7 +524,7 @@ class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
     detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
     image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)

-    def run_processor(self, image):
+    def run_processor(self, image: Image.Image) -> Image.Image:
         # segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
         segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
             "ybelkada/segment-anything", subfolder="checkpoints"
@@ -566,7 +570,7 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):

     color_map_tile_size: int = InputField(default=64, ge=1, description=FieldDescriptions.tile_size)

-    def run_processor(self, image: Image.Image):
+    def run_processor(self, image: Image.Image) -> Image.Image:
         np_image = np.array(image, dtype=np.uint8)
         height, width = np_image.shape[:2]

@@ -601,12 +605,18 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
     )
     resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)

-    def run_processor(self, image: Image.Image):
-        depth_anything_detector = DepthAnythingDetector()
-        depth_anything_detector.load_model(model_size=self.model_size)
-
-        processed_image = depth_anything_detector(image=image, resolution=self.resolution)
-        return processed_image
+    def run_processor(self, image: Image.Image) -> Image.Image:
+        def loader(model_path: Path):
+            return DepthAnythingDetector.load_model(
+                model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device()
+            )
+
+        with self._context.models.load_remote_model(
+            source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader
+        ) as model:
+            depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device())
+            processed_image = depth_anything_detector(image=image, resolution=self.resolution)
+            return processed_image


 @invocation(
@@ -624,8 +634,11 @@ class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):
     draw_hands: bool = InputField(default=False)
     image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)

-    def run_processor(self, image: Image.Image):
-        dw_openpose = DWOpenposeDetector()
+    def run_processor(self, image: Image.Image) -> Image.Image:
+        onnx_det = self._context.models.download_and_cache_model(DWPOSE_MODELS["yolox_l.onnx"])
+        onnx_pose = self._context.models.download_and_cache_model(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])
+
+        dw_openpose = DWOpenposeDetector(onnx_det=onnx_det, onnx_pose=onnx_pose)
         processed_image = dw_openpose(
             image,
             draw_face=self.draw_face,
@@ -42,15 +42,16 @@ class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
         """Infill the image with the specified method"""
         pass

-    def load_image(self, context: InvocationContext) -> tuple[Image.Image, bool]:
+    def load_image(self) -> tuple[Image.Image, bool]:
         """Process the image to have an alpha channel before being infilled"""
-        image = context.images.get_pil(self.image.image_name)
+        image = self._context.images.get_pil(self.image.image_name)
         has_alpha = True if image.mode == "RGBA" else False
         return image, has_alpha

     def invoke(self, context: InvocationContext) -> ImageOutput:
+        self._context = context
         # Retrieve and process image to be infilled
-        input_image, has_alpha = self.load_image(context)
+        input_image, has_alpha = self.load_image()

         # If the input image has no alpha channel, return it
         if has_alpha is False:
@@ -133,8 +134,12 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation):
     """Infills transparent areas of an image using the LaMa model"""

     def infill(self, image: Image.Image):
-        lama = LaMA()
-        return lama(image)
+        with self._context.models.load_remote_model(
+            source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
+            loader=LaMA.load_jit_model,
+        ) as model:
+            lama = LaMA(model)
+            return lama(image)


 @invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
@@ -1,5 +1,4 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
-from pathlib import Path
 from typing import Literal

 import cv2
@@ -10,10 +9,8 @@ from pydantic import ConfigDict
 from invokeai.app.invocations.fields import ImageField
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.app.util.download_with_progress import download_with_progress_bar
 from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
 from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
-from invokeai.backend.util.devices import TorchDevice

 from .baseinvocation import BaseInvocation, invocation
 from .fields import InputField, WithBoard, WithMetadata
@@ -52,7 +49,6 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):

         rrdbnet_model = None
         netscale = None
-        esrgan_model_path = None

         if self.model_name in [
             "RealESRGAN_x4plus.pth",
@@ -95,28 +91,25 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
             context.logger.error(msg)
             raise ValueError(msg)

-        esrgan_model_path = Path(context.config.get().models_path, f"core/upscaling/realesrgan/{self.model_name}")
-
-        # Downloads the ESRGAN model if it doesn't already exist
-        download_with_progress_bar(
-            name=self.model_name, url=ESRGAN_MODEL_URLS[self.model_name], dest_path=esrgan_model_path
-        )
-
-        upscaler = RealESRGAN(
-            scale=netscale,
-            model_path=esrgan_model_path,
-            model=rrdbnet_model,
-            half=False,
-            tile=self.tile_size,
-        )
+        loadnet = context.models.load_remote_model(
+            source=ESRGAN_MODEL_URLS[self.model_name],
+        )
+
+        with loadnet as loadnet_model:
+            upscaler = RealESRGAN(
+                scale=netscale,
+                loadnet=loadnet_model,
+                model=rrdbnet_model,
+                half=False,
+                tile=self.tile_size,
+            )

         # prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
         # TODO: This strips the alpha... is that okay?
         cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
         upscaled_image = upscaler.upscale(cv2_image)
-        pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")

-        TorchDevice.empty_cache()
+        pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")

         image_dto = context.images.save(image=pil_image)

@@ -86,6 +86,7 @@ class InvokeAIAppConfig(BaseSettings):
         patchmatch: Enable patchmatch inpaint code.
         models_dir: Path to the models directory.
         convert_cache_dir: Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.
+        download_cache_dir: Path to the directory that contains dynamically downloaded models.
         legacy_conf_dir: Path to directory of legacy checkpoint config files.
         db_dir: Path to InvokeAI databases directory.
         outputs_dir: Path to directory for outputs.
@@ -146,7 +147,8 @@ class InvokeAIAppConfig(BaseSettings):

     # PATHS
     models_dir: Path = Field(default=Path("models"), description="Path to the models directory.")
-    convert_cache_dir: Path = Field(default=Path("models/.cache"), description="Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.")
+    convert_cache_dir: Path = Field(default=Path("models/.convert_cache"), description="Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.")
+    download_cache_dir: Path = Field(default=Path("models/.download_cache"), description="Path to the directory that contains dynamically downloaded models.")
     legacy_conf_dir: Path = Field(default=Path("configs"), description="Path to directory of legacy checkpoint config files.")
     db_dir: Path = Field(default=Path("databases"), description="Path to InvokeAI databases directory.")
     outputs_dir: Path = Field(default=Path("outputs"), description="Path to directory for outputs.")
@@ -303,6 +305,11 @@ class InvokeAIAppConfig(BaseSettings):
         """Path to the converted cache models directory, resolved to an absolute path.."""
         return self._resolve(self.convert_cache_dir)

+    @property
+    def download_cache_path(self) -> Path:
+        """Path to the downloaded models directory, resolved to an absolute path.."""
+        return self._resolve(self.download_cache_dir)
+
     @property
     def custom_nodes_path(self) -> Path:
         """Path to the custom nodes directory, resolved to an absolute path.."""
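For reviewers unfamiliar with the config plumbing, the new setting is consumed roughly like this; the snippet is illustrative, with only `get_config()` and `download_cache_path` taken from this diff:

```
from invokeai.app.services.config import get_config

config = get_config()

# Dynamically downloaded models (ESRGAN, LaMa, DepthAnything, DW Openpose, ...) now live
# under models/.download_cache; the new property resolves it to an absolute path.
print(config.download_cache_path)
```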
@@ -1,10 +1,17 @@
 """Init file for download queue."""

-from .download_base import DownloadJob, DownloadJobStatus, DownloadQueueServiceBase, UnknownJobIDException
+from .download_base import (
+    DownloadJob,
+    DownloadJobStatus,
+    DownloadQueueServiceBase,
+    MultiFileDownloadJob,
+    UnknownJobIDException,
+)
 from .download_default import DownloadQueueService, TqdmProgress

 __all__ = [
     "DownloadJob",
+    "MultiFileDownloadJob",
     "DownloadQueueServiceBase",
     "DownloadQueueService",
     "TqdmProgress",
@@ -5,11 +5,13 @@ from abc import ABC, abstractmethod
 from enum import Enum
 from functools import total_ordering
 from pathlib import Path
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, List, Optional, Set, Union

 from pydantic import BaseModel, Field, PrivateAttr
 from pydantic.networks import AnyHttpUrl

+from invokeai.backend.model_manager.metadata import RemoteModelFile
+

 class DownloadJobStatus(str, Enum):
     """State of a download job."""
@@ -33,30 +35,23 @@ class ServiceInactiveException(Exception):
     """This exception is raised when user attempts to initiate a download before the service is started."""


-DownloadEventHandler = Callable[["DownloadJob"], None]
-DownloadExceptionHandler = Callable[["DownloadJob", Optional[Exception]], None]
+SingleFileDownloadEventHandler = Callable[["DownloadJob"], None]
+SingleFileDownloadExceptionHandler = Callable[["DownloadJob", Optional[Exception]], None]
+MultiFileDownloadEventHandler = Callable[["MultiFileDownloadJob"], None]
+MultiFileDownloadExceptionHandler = Callable[["MultiFileDownloadJob", Optional[Exception]], None]
+DownloadEventHandler = Union[SingleFileDownloadEventHandler, MultiFileDownloadEventHandler]
+DownloadExceptionHandler = Union[SingleFileDownloadExceptionHandler, MultiFileDownloadExceptionHandler]


-@total_ordering
-class DownloadJob(BaseModel):
-    """Class to monitor and control a model download request."""
-
-    # required variables to be passed in on creation
-    source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
-    dest: Path = Field(description="Destination of downloaded model on local disk; a directory or file path")
-    access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")
+class DownloadJobBase(BaseModel):
+    """Base of classes to monitor and control downloads."""
+
     # automatically assigned on creation
     id: int = Field(description="Numeric ID of this job", default=-1)  # default id is a sentinel
-    priority: int = Field(default=10, description="Queue priority; lower values are higher priority")

-    # set internally during download process
+    dest: Path = Field(description="Initial destination of downloaded model on local disk; a directory or file path")
+    download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file or directory")
     status: DownloadJobStatus = Field(default=DownloadJobStatus.WAITING, description="Status of the download")
-    download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file")
-    job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started")
-    job_ended: Optional[str] = Field(
-        default=None, description="Timestamp for when the download job ende1d (completed or errored)"
-    )
-    content_type: Optional[str] = Field(default=None, description="Content type of downloaded file")
     bytes: int = Field(default=0, description="Bytes downloaded so far")
     total_bytes: int = Field(default=0, description="Total file size (bytes)")

@@ -74,14 +69,6 @@ class DownloadJob(BaseModel):
     _on_cancelled: Optional[DownloadEventHandler] = PrivateAttr(default=None)
     _on_error: Optional[DownloadExceptionHandler] = PrivateAttr(default=None)

-    def __hash__(self) -> int:
-        """Return hash of the string representation of this object, for indexing."""
-        return hash(str(self))
-
-    def __le__(self, other: "DownloadJob") -> bool:
-        """Return True if this job's priority is less than another's."""
-        return self.priority <= other.priority
-
     def cancel(self) -> None:
         """Call to cancel the job."""
         self._cancelled = True
@@ -98,6 +85,11 @@ class DownloadJob(BaseModel):
         """Return true if job completed without errors."""
         return self.status == DownloadJobStatus.COMPLETED

+    @property
+    def waiting(self) -> bool:
+        """Return true if the job is waiting to run."""
+        return self.status == DownloadJobStatus.WAITING
+
     @property
     def running(self) -> bool:
         """Return true if the job is running."""
@@ -154,6 +146,37 @@ class DownloadJob(BaseModel):
         self._on_cancelled = on_cancelled


+@total_ordering
+class DownloadJob(DownloadJobBase):
+    """Class to monitor and control a model download request."""
+
+    # required variables to be passed in on creation
+    source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
+    access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")
+    priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
+
+    # set internally during download process
+    job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started")
+    job_ended: Optional[str] = Field(
+        default=None, description="Timestamp for when the download job ende1d (completed or errored)"
+    )
+    content_type: Optional[str] = Field(default=None, description="Content type of downloaded file")
+
+    def __hash__(self) -> int:
+        """Return hash of the string representation of this object, for indexing."""
+        return hash(str(self))
+
+    def __le__(self, other: "DownloadJob") -> bool:
+        """Return True if this job's priority is less than another's."""
+        return self.priority <= other.priority
+
+
+class MultiFileDownloadJob(DownloadJobBase):
+    """Class to monitor and control multifile downloads."""
+
+    download_parts: Set[DownloadJob] = Field(default_factory=set, description="List of download parts.")
+
+
 class DownloadQueueServiceBase(ABC):
     """Multithreaded queue for downloading models via URL."""

@@ -201,6 +224,48 @@ class DownloadQueueServiceBase(ABC):
         """
         pass

+    @abstractmethod
+    def multifile_download(
+        self,
+        parts: List[RemoteModelFile],
+        dest: Path,
+        access_token: Optional[str] = None,
+        submit_job: bool = True,
+        on_start: Optional[DownloadEventHandler] = None,
+        on_progress: Optional[DownloadEventHandler] = None,
+        on_complete: Optional[DownloadEventHandler] = None,
+        on_cancelled: Optional[DownloadEventHandler] = None,
+        on_error: Optional[DownloadExceptionHandler] = None,
+    ) -> MultiFileDownloadJob:
+        """
+        Create and enqueue a multifile download job.
+
+        :param parts: Set of URL / filename pairs
+        :param dest: Path to download to. See below.
+        :param access_token: Access token to download the indicated files. If not provided,
+        each file's URL may be matched to an access token using the config file matching
+        system.
+        :param submit_job: If true [default] then submit the job for execution. Otherwise,
+        you will need to pass the job to submit_multifile_download().
+        :param on_start, on_progress, on_complete, on_error: Callbacks for the indicated
+        events.
+        :returns: A MultiFileDownloadJob object for monitoring the state of the download.
+
+        The `dest` argument is a Path object pointing to a directory. All downloads
+        with be placed inside this directory. The callbacks will receive the
+        MultiFileDownloadJob.
+        """
+        pass
+
+    @abstractmethod
+    def submit_multifile_download(self, job: MultiFileDownloadJob) -> None:
+        """
+        Enqueue a previously-created multi-file download job.
+
+        :param job: A MultiFileDownloadJob created with multifile_download()
+        """
+        pass
+
     @abstractmethod
     def submit_download_job(
         self,
@@ -252,7 +317,7 @@ class DownloadQueueServiceBase(ABC):
         pass

     @abstractmethod
-    def cancel_job(self, job: DownloadJob) -> None:
+    def cancel_job(self, job: DownloadJobBase) -> None:
         """Cancel the job, clearing partial downloads and putting it into ERROR state."""
         pass

@@ -262,7 +327,7 @@ class DownloadQueueServiceBase(ABC):
         pass

     @abstractmethod
-    def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob:
+    def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase:
         """Wait until the indicated download job has reached a terminal state.

         This will block until the indicated install job has completed,
@@ -8,23 +8,28 @@ import time
 import traceback
 from pathlib import Path
 from queue import Empty, PriorityQueue
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set

 import requests
 from pydantic.networks import AnyHttpUrl
 from requests import HTTPError
 from tqdm import tqdm

+from invokeai.app.services.config import InvokeAIAppConfig, get_config
+from invokeai.app.services.events.events_base import EventServiceBase
 from invokeai.app.util.misc import get_iso_timestamp
+from invokeai.backend.model_manager.metadata import RemoteModelFile
 from invokeai.backend.util.logging import InvokeAILogger

 from .download_base import (
     DownloadEventHandler,
     DownloadExceptionHandler,
     DownloadJob,
+    DownloadJobBase,
     DownloadJobCancelledException,
     DownloadJobStatus,
     DownloadQueueServiceBase,
+    MultiFileDownloadJob,
     ServiceInactiveException,
     UnknownJobIDException,
 )
@@ -42,20 +47,24 @@ class DownloadQueueService(DownloadQueueServiceBase):
     def __init__(
         self,
         max_parallel_dl: int = 5,
+        app_config: Optional[InvokeAIAppConfig] = None,
         event_bus: Optional["EventServiceBase"] = None,
         requests_session: Optional[requests.sessions.Session] = None,
     ):
         """
         Initialize DownloadQueue.

+        :param app_config: InvokeAIAppConfig object
         :param max_parallel_dl: Number of simultaneous downloads allowed [5].
         :param requests_session: Optional requests.sessions.Session object, for unit tests.
         """
+        self._app_config = app_config or get_config()
         self._jobs: Dict[int, DownloadJob] = {}
+        self._download_part2parent: Dict[AnyHttpUrl, MultiFileDownloadJob] = {}
         self._next_job_id = 0
         self._queue: PriorityQueue[DownloadJob] = PriorityQueue()
         self._stop_event = threading.Event()
-        self._job_completed_event = threading.Event()
+        self._job_terminated_event = threading.Event()
         self._worker_pool: Set[threading.Thread] = set()
         self._lock = threading.Lock()
         self._logger = InvokeAILogger.get_logger("DownloadQueueService")
@@ -107,18 +116,16 @@ class DownloadQueueService(DownloadQueueServiceBase):
             raise ServiceInactiveException(
                 "The download service is not currently accepting requests. Please call start() to initialize the service."
             )
-        with self._lock:
-            job.id = self._next_job_id
-            self._next_job_id += 1
-            job.set_callbacks(
-                on_start=on_start,
-                on_progress=on_progress,
-                on_complete=on_complete,
-                on_cancelled=on_cancelled,
-                on_error=on_error,
-            )
-            self._jobs[job.id] = job
-            self._queue.put(job)
+        job.id = self._next_id()
+        job.set_callbacks(
+            on_start=on_start,
+            on_progress=on_progress,
+            on_complete=on_complete,
+            on_cancelled=on_cancelled,
+            on_error=on_error,
+        )
+        self._jobs[job.id] = job
+        self._queue.put(job)

     def download(
         self,
@@ -141,7 +148,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
             source=source,
             dest=dest,
             priority=priority,
-            access_token=access_token,
+            access_token=access_token or self._lookup_access_token(source),
         )
         self.submit_download_job(
             job,
@@ -153,10 +160,63 @@ class DownloadQueueService(DownloadQueueServiceBase):
         )
         return job

+    def multifile_download(
+        self,
+        parts: List[RemoteModelFile],
+        dest: Path,
+        access_token: Optional[str] = None,
+        submit_job: bool = True,
+        on_start: Optional[DownloadEventHandler] = None,
+        on_progress: Optional[DownloadEventHandler] = None,
+        on_complete: Optional[DownloadEventHandler] = None,
+        on_cancelled: Optional[DownloadEventHandler] = None,
+        on_error: Optional[DownloadExceptionHandler] = None,
+    ) -> MultiFileDownloadJob:
+        mfdj = MultiFileDownloadJob(dest=dest, id=self._next_id())
+        mfdj.set_callbacks(
+            on_start=on_start,
+            on_progress=on_progress,
+            on_complete=on_complete,
+            on_cancelled=on_cancelled,
+            on_error=on_error,
+        )
+
+        for part in parts:
+            url = part.url
+            path = dest / part.path
+            assert path.is_relative_to(dest), "only relative download paths accepted"
+            job = DownloadJob(
+                source=url,
+                dest=path,
+                access_token=access_token,
+            )
+            mfdj.download_parts.add(job)
+            self._download_part2parent[job.source] = mfdj
+        if submit_job:
+            self.submit_multifile_download(mfdj)
+        return mfdj
+
+    def submit_multifile_download(self, job: MultiFileDownloadJob) -> None:
+        for download_job in job.download_parts:
+            self.submit_download_job(
+                download_job,
+                on_start=self._mfd_started,
+                on_progress=self._mfd_progress,
+                on_complete=self._mfd_complete,
+                on_cancelled=self._mfd_cancelled,
+                on_error=self._mfd_error,
+            )
+
     def join(self) -> None:
         """Wait for all jobs to complete."""
         self._queue.join()

+    def _next_id(self) -> int:
+        with self._lock:
+            id = self._next_job_id
+            self._next_job_id += 1
+        return id
+
     def list_jobs(self) -> List[DownloadJob]:
         """List all the jobs."""
         return list(self._jobs.values())
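To make the new API concrete, here is a minimal usage sketch, assuming a `DownloadQueueService` that has already been started; the URLs, destination directory and file names are placeholders, not part of the change:

```
# Hedged sketch only: queue startup and the URLs/paths are illustrative.
from pathlib import Path

from invokeai.app.services.download import DownloadQueueService
from invokeai.backend.model_manager.metadata import RemoteModelFile

queue = DownloadQueueService()  # assumed to be started via the normal service lifecycle

parts = [
    RemoteModelFile(url="https://example.com/repo/model.safetensors", path=Path("model.safetensors")),
    RemoteModelFile(url="https://example.com/repo/config.json", path=Path("config.json")),
]
# Each part becomes a child DownloadJob; the _mfd_* callbacks roll progress up to the parent job.
job = queue.multifile_download(parts=parts, dest=Path("/tmp/example_model"))
queue.wait_for_job(job)  # blocks until the parent job reaches a terminal state
print(job.download_path, job.status)
```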
@@ -178,14 +238,14 @@ class DownloadQueueService(DownloadQueueServiceBase):
         except KeyError as excp:
             raise UnknownJobIDException("Unrecognized job") from excp

-    def cancel_job(self, job: DownloadJob) -> None:
+    def cancel_job(self, job: DownloadJobBase) -> None:
         """
         Cancel the indicated job.

         If it is running it will be stopped.
         job.status will be set to DownloadJobStatus.CANCELLED
         """
-        with self._lock:
+        if job.status in [DownloadJobStatus.WAITING, DownloadJobStatus.RUNNING]:
             job.cancel()

     def cancel_all_jobs(self) -> None:
@@ -194,12 +254,12 @@ class DownloadQueueService(DownloadQueueServiceBase):
             if not job.in_terminal_state:
                 self.cancel_job(job)

-    def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob:
+    def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase:
         """Block until the indicated job has reached terminal state, or when timeout limit reached."""
         start = time.time()
         while not job.in_terminal_state:
-            if self._job_completed_event.wait(timeout=0.25):  # in case we miss an event
-                self._job_completed_event.clear()
+            if self._job_terminated_event.wait(timeout=0.25):  # in case we miss an event
+                self._job_terminated_event.clear()
             if timeout > 0 and time.time() - start > timeout:
                 raise TimeoutError("Timeout exceeded")
         return job
@@ -228,22 +288,25 @@ class DownloadQueueService(DownloadQueueServiceBase):
                 job.job_started = get_iso_timestamp()
                 self._do_download(job)
                 self._signal_job_complete(job)
-            except (OSError, HTTPError) as excp:
-                job.error_type = excp.__class__.__name__ + f"({str(excp)})"
-                job.error = traceback.format_exc()
-                self._signal_job_error(job, excp)
             except DownloadJobCancelledException:
                 self._signal_job_cancelled(job)
                 self._cleanup_cancelled_job(job)
+            except Exception as excp:
+                job.error_type = excp.__class__.__name__ + f"({str(excp)})"
+                job.error = traceback.format_exc()
+                self._signal_job_error(job, excp)
             finally:
                 job.job_ended = get_iso_timestamp()
-                self._job_completed_event.set()  # signal a change to terminal state
+                self._job_terminated_event.set()  # signal a change to terminal state
+                self._download_part2parent.pop(job.source, None)  # if this is a subpart of a multipart job, remove it
+                self._job_terminated_event.set()
                 self._queue.task_done()

         self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.")

     def _do_download(self, job: DownloadJob) -> None:
         """Do the actual download."""

         url = job.source
         header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
         open_mode = "wb"
@@ -335,38 +398,29 @@ class DownloadQueueService(DownloadQueueServiceBase):
     def _in_progress_path(self, path: Path) -> Path:
         return path.with_name(path.name + ".downloading")

+    def _lookup_access_token(self, source: AnyHttpUrl) -> Optional[str]:
+        # Pull the token from config if it exists and matches the URL
+        token = None
+        for pair in self._app_config.remote_api_tokens or []:
+            if re.search(pair.url_regex, str(source)):
+                token = pair.token
+                break
+        return token
+
     def _signal_job_started(self, job: DownloadJob) -> None:
         job.status = DownloadJobStatus.RUNNING
-        if job.on_start:
-            try:
-                job.on_start(job)
-            except Exception as e:
-                self._logger.error(
-                    f"An error occurred while processing the on_start callback: {traceback.format_exception(e)}"
-                )
+        self._execute_cb(job, "on_start")
         if self._event_bus:
             self._event_bus.emit_download_started(job)

     def _signal_job_progress(self, job: DownloadJob) -> None:
-        if job.on_progress:
-            try:
-                job.on_progress(job)
-            except Exception as e:
-                self._logger.error(
-                    f"An error occurred while processing the on_progress callback: {traceback.format_exception(e)}"
-                )
+        self._execute_cb(job, "on_progress")
         if self._event_bus:
             self._event_bus.emit_download_progress(job)

     def _signal_job_complete(self, job: DownloadJob) -> None:
         job.status = DownloadJobStatus.COMPLETED
-        if job.on_complete:
-            try:
-                job.on_complete(job)
-            except Exception as e:
-                self._logger.error(
-                    f"An error occurred while processing the on_complete callback: {traceback.format_exception(e)}"
-                )
+        self._execute_cb(job, "on_complete")
         if self._event_bus:
             self._event_bus.emit_download_complete(job)
@@ -374,26 +428,21 @@ class DownloadQueueService(DownloadQueueServiceBase):
         if job.status not in [DownloadJobStatus.RUNNING, DownloadJobStatus.WAITING]:
             return
         job.status = DownloadJobStatus.CANCELLED
-        if job.on_cancelled:
-            try:
-                job.on_cancelled(job)
-            except Exception as e:
-                self._logger.error(
-                    f"An error occurred while processing the on_cancelled callback: {traceback.format_exception(e)}"
-                )
+        self._execute_cb(job, "on_cancelled")
         if self._event_bus:
             self._event_bus.emit_download_cancelled(job)

+        # if multifile download, then signal the parent
+        if parent_job := self._download_part2parent.get(job.source, None):
+            if not parent_job.in_terminal_state:
+                parent_job.status = DownloadJobStatus.CANCELLED
+                self._execute_cb(parent_job, "on_cancelled")
+
     def _signal_job_error(self, job: DownloadJob, excp: Optional[Exception] = None) -> None:
         job.status = DownloadJobStatus.ERROR
         self._logger.error(f"{str(job.source)}: {traceback.format_exception(excp)}")
-        if job.on_error:
-            try:
-                job.on_error(job, excp)
-            except Exception as e:
-                self._logger.error(
-                    f"An error occurred while processing the on_error callback: {traceback.format_exception(e)}"
-                )
+        self._execute_cb(job, "on_error", excp)
         if self._event_bus:
             self._event_bus.emit_download_error(job)

@@ -406,6 +455,97 @@ class DownloadQueueService(DownloadQueueServiceBase):
         except OSError as excp:
             self._logger.warning(excp)

+    ########################################
+    # callbacks used for multifile downloads
+    ########################################
+    def _mfd_started(self, download_job: DownloadJob) -> None:
+        self._logger.info(f"File download started: {download_job.source}")
+        with self._lock:
+            mf_job = self._download_part2parent[download_job.source]
+            if mf_job.waiting:
+                mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts)
+                mf_job.status = DownloadJobStatus.RUNNING
+                assert download_job.download_path is not None
+                path_relative_to_destdir = download_job.download_path.relative_to(mf_job.dest)
+                mf_job.download_path = (
+                    mf_job.dest / path_relative_to_destdir.parts[0]
+                )  # keep just the first component of the path
+                self._execute_cb(mf_job, "on_start")
+
+    def _mfd_progress(self, download_job: DownloadJob) -> None:
+        with self._lock:
+            mf_job = self._download_part2parent[download_job.source]
+            if mf_job.cancelled:
+                for part in mf_job.download_parts:
+                    self.cancel_job(part)
+            elif mf_job.running:
+                mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts)
+                mf_job.bytes = sum(x.total_bytes for x in mf_job.download_parts)
+                self._execute_cb(mf_job, "on_progress")
+
+    def _mfd_complete(self, download_job: DownloadJob) -> None:
+        self._logger.info(f"Download complete: {download_job.source}")
+        with self._lock:
+            mf_job = self._download_part2parent[download_job.source]
+
+            # are there any more active jobs left in this task?
+            if mf_job.running and all(x.complete for x in mf_job.download_parts):
+                mf_job.status = DownloadJobStatus.COMPLETED
+                self._execute_cb(mf_job, "on_complete")
+
+            # we're done with this sub-job
+            self._job_terminated_event.set()
+
+    def _mfd_cancelled(self, download_job: DownloadJob) -> None:
+        with self._lock:
+            mf_job = self._download_part2parent[download_job.source]
+            assert mf_job is not None
+
+            if not mf_job.in_terminal_state:
+                self._logger.warning(f"Download cancelled: {download_job.source}")
+                mf_job.cancel()
+
+            for s in mf_job.download_parts:
+                self.cancel_job(s)
+
+    def _mfd_error(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
+        with self._lock:
+            mf_job = self._download_part2parent[download_job.source]
+            assert mf_job is not None
+            if not mf_job.in_terminal_state:
+                mf_job.status = download_job.status
+                mf_job.error = download_job.error
+                mf_job.error_type = download_job.error_type
+                self._execute_cb(mf_job, "on_error", excp)
+                self._logger.error(
+                    f"Cancelling {mf_job.dest} due to an error while downloading {download_job.source}: {str(excp)}"
+                )
+                for s in [x for x in mf_job.download_parts if x.running]:
+                    self.cancel_job(s)
+                self._download_part2parent.pop(download_job.source)
+                self._job_terminated_event.set()
+
+    def _execute_cb(
+        self,
+        job: DownloadJob | MultiFileDownloadJob,
+        callback_name: Literal[
+            "on_start",
+            "on_progress",
+            "on_complete",
+            "on_cancelled",
+            "on_error",
+        ],
+        excp: Optional[Exception] = None,
+    ) -> None:
+        if callback := getattr(job, callback_name, None):
+            args = [job, excp] if excp else [job]
+            try:
+                callback(*args)
+            except Exception as e:
+                self._logger.error(
+                    f"An error occurred while processing the {callback_name} callback: {traceback.format_exception(e)}"
+                )
+
+
 def get_pc_name_max(directory: str) -> int:
     if hasattr(os, "pathconf"):

@@ -13,7 +13,7 @@ from invokeai.app.services.events.events_base import EventServiceBase
 from invokeai.app.services.invoker import Invoker
 from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
 from invokeai.app.services.model_records import ModelRecordServiceBase
-from invokeai.backend.model_manager.config import AnyModelConfig
+from invokeai.backend.model_manager import AnyModelConfig


 class ModelInstallServiceBase(ABC):
@@ -243,12 +243,11 @@ class ModelInstallServiceBase(ABC):
         """

     @abstractmethod
-    def download_and_cache(self, source: Union[str, AnyHttpUrl], access_token: Optional[str] = None) -> Path:
+    def download_and_cache_model(self, source: str | AnyHttpUrl) -> Path:
         """
         Download the model file located at source to the models cache and return its Path.

-        :param source: A Url or a string that can be converted into one.
-        :param access_token: Optional access token to access restricted resources.
+        :param source: A string representing a URL or repo_id.

         The model file will be downloaded into the system-wide model cache
         (`models/.cache`) if it isn't already there. Note that the model cache
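As a point of reference, a hedged sketch of how a caller might invoke this method once a concrete installer implements it; the `installer` handle and the URL are placeholders, not part of the diff:

```
# Illustrative only: `installer` stands for any concrete ModelInstallServiceBase implementation.
cached = installer.download_and_cache_model("https://example.com/models/some_model.safetensors")
# On a cache hit the previously downloaded file under models/.download_cache is returned without re-downloading.
print(cached)
```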
|
@@ -8,7 +8,7 @@ from pydantic import BaseModel, Field, PrivateAttr, field_validator
 from pydantic.networks import AnyHttpUrl
 from typing_extensions import Annotated

-from invokeai.app.services.download import DownloadJob
+from invokeai.app.services.download import DownloadJob, MultiFileDownloadJob
 from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
 from invokeai.backend.model_manager.config import ModelSourceType
 from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
@@ -26,13 +26,6 @@ class InstallStatus(str, Enum):
     CANCELLED = "cancelled"  # terminated with an error message


-class ModelInstallPart(BaseModel):
-    url: AnyHttpUrl
-    path: Path
-    bytes: int = 0
-    total_bytes: int = 0
-
-
 class UnknownInstallJobException(Exception):
     """Raised when the status of an unknown job is requested."""

@@ -169,6 +162,7 @@ class ModelInstallJob(BaseModel):
     )
     # internal flags and transitory settings
     _install_tmpdir: Optional[Path] = PrivateAttr(default=None)
+    _multifile_job: Optional[MultiFileDownloadJob] = PrivateAttr(default=None)
     _exception: Optional[Exception] = PrivateAttr(default=None)

     def set_error(self, e: Exception) -> None:

@@ -5,21 +5,22 @@ import os
 import re
 import threading
 import time
-from hashlib import sha256
 from pathlib import Path
 from queue import Empty, Queue
 from shutil import copyfile, copytree, move, rmtree
 from tempfile import mkdtemp
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union

 import torch
 import yaml
 from huggingface_hub import HfFolder
 from pydantic.networks import AnyHttpUrl
+from pydantic_core import Url
 from requests import Session

 from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress
+from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob
+from invokeai.app.services.events.events_base import EventServiceBase
 from invokeai.app.services.invoker import Invoker
 from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase
 from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
@@ -44,6 +45,7 @@ from invokeai.backend.model_manager.search import ModelSearch
 from invokeai.backend.util import InvokeAILogger
 from invokeai.backend.util.catch_sigint import catch_sigint
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.util import slugify

 from .model_install_common import (
     MODEL_SOURCE_TO_TYPE_MAP,
@@ -91,7 +93,7 @@ class ModelInstallService(ModelInstallServiceBase):
         self._downloads_changed_event = threading.Event()
         self._install_completed_event = threading.Event()
         self._download_queue = download_queue
-        self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {}
+        self._download_cache: Dict[int, ModelInstallJob] = {}
         self._running = False
         self._session = session
         self._install_thread: Optional[threading.Thread] = None
@@ -210,33 +212,12 @@ class ModelInstallService(ModelInstallServiceBase):
         access_token: Optional[str] = None,
         inplace: Optional[bool] = False,
     ) -> ModelInstallJob:
-        variants = "|".join(ModelRepoVariant.__members__.values())
-        hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
-        source_obj: Optional[StringLikeSource] = None
-
-        if Path(source).exists():  # A local file or directory
-            source_obj = LocalModelSource(path=Path(source), inplace=inplace)
-        elif match := re.match(hf_repoid_re, source):
-            source_obj = HFModelSource(
-                repo_id=match.group(1),
-                variant=match.group(2) if match.group(2) else None,  # pass None rather than ''
-                subfolder=Path(match.group(3)) if match.group(3) else None,
-                access_token=access_token,
-            )
-        elif re.match(r"^https?://[^/]+", source):
-            # Pull the token from config if it exists and matches the URL
-            _token = access_token
-            if _token is None:
-                for pair in self.app_config.remote_api_tokens or []:
-                    if re.search(pair.url_regex, source):
-                        _token = pair.token
-                        break
-            source_obj = URLModelSource(
-                url=AnyHttpUrl(source),
-                access_token=_token,
-            )
-        else:
-            raise ValueError(f"Unsupported model source: '{source}'")
+        """Install a model using pattern matching to infer the type of source."""
+        source_obj = self._guess_source(source)
+        if isinstance(source_obj, LocalModelSource):
+            source_obj.inplace = inplace
+        elif isinstance(source_obj, HFModelSource) or isinstance(source_obj, URLModelSource):
+            source_obj.access_token = access_token
         return self.import_model(source_obj, config)

     def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob:  # noqa D102
@@ -297,8 +278,9 @@ class ModelInstallService(ModelInstallServiceBase):
     def cancel_job(self, job: ModelInstallJob) -> None:
         """Cancel the indicated job."""
         job.cancel()
-        with self._lock:
-            self._cancel_download_parts(job)
+        self._logger.warning(f"Cancelling {job.source}")
+        if dj := job._multifile_job:
+            self._download_queue.cancel_job(dj)

     def prune_jobs(self) -> None:
         """Prune all completed and errored jobs."""
@@ -346,7 +328,7 @@ class ModelInstallService(ModelInstallServiceBase):
             legacy_config_path = stanza.get("config")
             if legacy_config_path:
                 # In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir.
-                legacy_config_path: Path = self._app_config.root_path / legacy_config_path
+                legacy_config_path = self._app_config.root_path / legacy_config_path
                 if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path):
                     legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path)
                 config["config_path"] = str(legacy_config_path)
@@ -386,38 +368,92 @@ class ModelInstallService(ModelInstallServiceBase):
             rmtree(model_path)
         self.unregister(key)

-    def download_and_cache(
+    @classmethod
+    def _download_cache_path(cls, source: Union[str, AnyHttpUrl], app_config: InvokeAIAppConfig) -> Path:
+        escaped_source = slugify(str(source))
+        return app_config.download_cache_path / escaped_source
+
+    def download_and_cache_model(
         self,
-        source: Union[str, AnyHttpUrl],
-        access_token: Optional[str] = None,
-        timeout: int = 0,
+        source: str | AnyHttpUrl,
     ) -> Path:
         """Download the model file located at source to the models cache and return its Path."""
-        model_hash = sha256(str(source).encode("utf-8")).hexdigest()[0:32]
-        model_path = self._app_config.convert_cache_path / model_hash
+        model_path = self._download_cache_path(str(source), self._app_config)

-        # We expect the cache directory to contain one and only one downloaded file.
+        # We expect the cache directory to contain one and only one downloaded file or directory.
         # We don't know the file's name in advance, as it is set by the download
         # content-disposition header.
         if model_path.exists():
-            contents = [x for x in model_path.iterdir() if x.is_file()]
+            contents: List[Path] = list(model_path.iterdir())
             if len(contents) > 0:
                 return contents[0]

         model_path.mkdir(parents=True, exist_ok=True)
-        job = self._download_queue.download(
-            source=AnyHttpUrl(str(source)),
+        model_source = self._guess_source(str(source))
+        remote_files, _ = self._remote_files_from_source(model_source)
+        job = self._multifile_download(
             dest=model_path,
-            access_token=access_token,
-            on_progress=TqdmProgress().update,
+            remote_files=remote_files,
+            subfolder=model_source.subfolder if isinstance(model_source, HFModelSource) else None,
         )
-        self._download_queue.wait_for_job(job, timeout)
+        files_string = "file" if len(remote_files) == 1 else "files"
+        self._logger.info(f"Queuing model download: {source} ({len(remote_files)} {files_string})")
+        self._download_queue.wait_for_job(job)
         if job.complete:
             assert job.download_path is not None
             return job.download_path
         else:
             raise Exception(job.error)

+    def _remote_files_from_source(
+        self, source: ModelSource
+    ) -> Tuple[List[RemoteModelFile], Optional[AnyModelRepoMetadata]]:
+        metadata = None
+        if isinstance(source, HFModelSource):
+            metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant)
+            assert isinstance(metadata, ModelMetadataWithFiles)
+            return metadata.download_urls(
+                variant=source.variant or self._guess_variant(),
+                subfolder=source.subfolder,
+                session=self._session,
+            ), metadata
+
+        if isinstance(source, URLModelSource):
+            try:
+                fetcher = self.get_fetcher_from_url(str(source.url))
+                kwargs: dict[str, Any] = {"session": self._session}
+                metadata = fetcher(**kwargs).from_url(source.url)
+                assert isinstance(metadata, ModelMetadataWithFiles)
+                return metadata.download_urls(session=self._session), metadata
+            except ValueError:
+                pass
+
+            return [RemoteModelFile(url=source.url, path=Path("."), size=0)], None
+
+        raise Exception(f"No files associated with {source}")
+
+    def _guess_source(self, source: str) -> ModelSource:
+        """Turn a source string into a ModelSource object."""
+        variants = "|".join(ModelRepoVariant.__members__.values())
+        hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
+        source_obj: Optional[StringLikeSource] = None
+
+        if Path(source).exists():  # A local file or directory
+            source_obj = LocalModelSource(path=Path(source))
+        elif match := re.match(hf_repoid_re, source):
+            source_obj = HFModelSource(
+                repo_id=match.group(1),
+                variant=ModelRepoVariant(match.group(2)) if match.group(2) else None,  # pass None rather than ''
+                subfolder=Path(match.group(3)) if match.group(3) else None,
+            )
+        elif re.match(r"^https?://[^/]+", source):
+            source_obj = URLModelSource(
+                url=Url(source),
+            )
+        else:
+            raise ValueError(f"Unsupported model source: '{source}'")
+        return source_obj
+
     # --------------------------------------------------------------------------------------------
     # Internal functions that manage the installer threads
     # --------------------------------------------------------------------------------------------
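For clarity, a short sketch of the mappings `_guess_source()` is written to produce, inferred from the regexes above; the example strings and the `installer` handle are hypothetical:

```
# Illustrative inputs -> expected ModelSource subclass (assumes the local path actually exists on disk).
installer._guess_source("/opt/models/some_local_model.safetensors")  # LocalModelSource
installer._guess_source("stabilityai/sdxl-turbo")                    # HFModelSource (repo_id only)
installer._guess_source("stabilityai/sdxl-turbo:fp16:vae")           # HFModelSource with variant and subfolder
installer._guess_source("https://example.com/foo.safetensors")       # URLModelSource
installer._guess_source("neither a path, repo_id nor URL")           # raises ValueError
```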
@@ -478,16 +514,19 @@ class ModelInstallService(ModelInstallServiceBase):
                 job.config_out = self.record_store.get_model(key)
                 self._signal_job_completed(job)

-    def _set_error(self, job: ModelInstallJob, excp: Exception) -> None:
-        if any(x.content_type is not None and "text/html" in x.content_type for x in job.download_parts):
-            job.set_error(
+    def _set_error(self, install_job: ModelInstallJob, excp: Exception) -> None:
+        multifile_download_job = install_job._multifile_job
+        if multifile_download_job and any(
+            x.content_type is not None and "text/html" in x.content_type for x in multifile_download_job.download_parts
+        ):
+            install_job.set_error(
                 InvalidModelConfigException(
-                    f"At least one file in {job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
+                    f"At least one file in {install_job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
                 )
             )
         else:
-            job.set_error(excp)
-        self._signal_job_errored(job)
+            install_job.set_error(excp)
+        self._signal_job_errored(install_job)

     # --------------------------------------------------------------------------------------------
     # Internal functions that manage the models directory
@@ -513,7 +552,6 @@ class ModelInstallService(ModelInstallServiceBase):
         This is typically only used during testing with a new DB or when using the memory DB, because those are the
         only situations in which we may have orphaned models in the models directory.
         """
-
         installed_model_paths = {
             (self._app_config.models_path / x.path).resolve() for x in self.record_store.all_models()
         }
@@ -525,8 +563,13 @@ class ModelInstallService(ModelInstallServiceBase):
             if resolved_path in installed_model_paths:
                 return True
             # Skip core models entirely - these aren't registered with the model manager.
-            if str(resolved_path).startswith(str(self.app_config.models_path / "core")):
-                return False
+            for special_directory in [
+                self.app_config.models_path / "core",
+                self.app_config.convert_cache_dir,
+                self.app_config.download_cache_dir,
+            ]:
+                if resolved_path.is_relative_to(special_directory):
+                    return False
             try:
                 model_id = self.register_path(model_path)
                 self._logger.info(f"Registered {model_path.name} with id {model_id}")
@@ -641,20 +684,15 @@ class ModelInstallService(ModelInstallServiceBase):
             inplace=source.inplace or False,
         )

-    def _import_from_hf(self, source: HFModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
+    def _import_from_hf(
+        self,
+        source: HFModelSource,
+        config: Optional[Dict[str, Any]] = None,
+    ) -> ModelInstallJob:
         # Add user's cached access token to HuggingFace requests
-        source.access_token = source.access_token or HfFolder.get_token()
-        if not source.access_token:
-            self._logger.info("No HuggingFace access token present; some models may not be downloadable.")
-
-        metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant)
-        assert isinstance(metadata, ModelMetadataWithFiles)
-        remote_files = metadata.download_urls(
-            variant=source.variant or self._guess_variant(),
-            subfolder=source.subfolder,
-            session=self._session,
-        )
+        if source.access_token is None:
+            source.access_token = HfFolder.get_token()
+        remote_files, metadata = self._remote_files_from_source(source)

         return self._import_remote_model(
             source=source,
             config=config,
@@ -662,22 +700,12 @@ class ModelInstallService(ModelInstallServiceBase):
             metadata=metadata,
         )

-    def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
-        # URLs from HuggingFace will be handled specially
-        metadata = None
-        fetcher = None
-        try:
-            fetcher = self.get_fetcher_from_url(str(source.url))
-        except ValueError:
-            pass
-        kwargs: dict[str, Any] = {"session": self._session}
-        if fetcher is not None:
-            metadata = fetcher(**kwargs).from_url(source.url)
-        self._logger.debug(f"metadata={metadata}")
-        if metadata and isinstance(metadata, ModelMetadataWithFiles):
-            remote_files = metadata.download_urls(session=self._session)
-        else:
-            remote_files = [RemoteModelFile(url=source.url, path=Path("."), size=0)]
+    def _import_from_url(
+        self,
+        source: URLModelSource,
+        config: Optional[Dict[str, Any]],
+    ) -> ModelInstallJob:
+        remote_files, metadata = self._remote_files_from_source(source)
         return self._import_remote_model(
             source=source,
             config=config,
@@ -692,12 +720,9 @@ class ModelInstallService(ModelInstallServiceBase):
         metadata: Optional[AnyModelRepoMetadata],
         config: Optional[Dict[str, Any]],
     ) -> ModelInstallJob:
-        # TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up.
-        # Currently the tmpdir isn't automatically removed at exit because it is
-        # being held in a daemon thread.
         if len(remote_files) == 0:
             raise ValueError(f"{source}: No downloadable files found")
-        tmpdir = Path(
+        destdir = Path(
             mkdtemp(
                 dir=self._app_config.models_path,
                 prefix=TMPDIR_PREFIX,
@@ -708,55 +733,28 @@ class ModelInstallService(ModelInstallServiceBase):
             source=source,
             config_in=config or {},
             source_metadata=metadata,
-            local_path=tmpdir,  # local path may change once the download has started due to content-disposition handling
+            local_path=destdir,  # local path may change once the download has started due to content-disposition handling
             bytes=0,
             total_bytes=0,
         )
-        # In the event that there is a subfolder specified in the source,
-        # we need to remove it from the destination path in order to avoid
-        # creating unwanted subfolders
-        if isinstance(source, HFModelSource) and source.subfolder:
-            root = Path(remote_files[0].path.parts[0])
-            subfolder = root / source.subfolder
-        else:
-            root = Path(".")
-            subfolder = Path(".")
-
-        # we remember the path up to the top of the tmpdir so that it may be
-        # removed safely at the end of the install process.
-        install_job._install_tmpdir = tmpdir
-        assert install_job.total_bytes is not None  # to avoid type checking complaints in the loop below
-
-        files_string = "file" if len(remote_files) == 1 else "file"
-        self._logger.info(f"Queuing model install: {source} ({len(remote_files)} {files_string})")
+        # remember the temporary directory for later removal
+        install_job._install_tmpdir = destdir
+        install_job.total_bytes = sum((x.size or 0) for x in remote_files)
+
+        multifile_job = self._multifile_download(
+            remote_files=remote_files,
+            dest=destdir,
+            subfolder=source.subfolder if isinstance(source, HFModelSource) else None,
+            access_token=source.access_token,
+            submit_job=False,  # Important! Don't submit the job until we have set our _download_cache dict
+        )
+        self._download_cache[multifile_job.id] = install_job
+        install_job._multifile_job = multifile_job
+
+        files_string = "file" if len(remote_files) == 1 else "files"
+        self._logger.info(f"Queueing model install: {source} ({len(remote_files)} {files_string})")
         self._logger.debug(f"remote_files={remote_files}")
-        for model_file in remote_files:
-            url = model_file.url
-            path = root / model_file.path.relative_to(subfolder)
-            self._logger.debug(f"Downloading {url} => {path}")
-            install_job.total_bytes += model_file.size
-            assert hasattr(source, "access_token")
-            dest = tmpdir / path.parent
-            dest.mkdir(parents=True, exist_ok=True)
-            download_job = DownloadJob(
-                source=url,
-                dest=dest,
-                access_token=source.access_token,
-            )
-            self._download_cache[download_job.source] = install_job  # matches a download job to an install job
-            install_job.download_parts.add(download_job)
-
-        # only start the jobs once install_job.download_parts is fully populated
-        for download_job in install_job.download_parts:
-            self._download_queue.submit_download_job(
-                download_job,
-                on_start=self._download_started_callback,
-                on_progress=self._download_progress_callback,
-                on_complete=self._download_complete_callback,
-                on_error=self._download_error_callback,
-                on_cancelled=self._download_cancelled_callback,
-            )
-
+        self._download_queue.submit_multifile_download(multifile_job)
         return install_job

     def _stat_size(self, path: Path) -> int:
@@ -768,87 +766,104 @@ class ModelInstallService(ModelInstallServiceBase):
                 size += sum(self._stat_size(Path(root, x)) for x in files)
         return size

+    def _multifile_download(
+        self,
+        remote_files: List[RemoteModelFile],
+        dest: Path,
+        subfolder: Optional[Path] = None,
+        access_token: Optional[str] = None,
+        submit_job: bool = True,
+    ) -> MultiFileDownloadJob:
+        # HuggingFace repo subfolders are a little tricky. If the name of the model is "sdxl-turbo", and
+        # we are installing the "vae" subfolder, we do not want to create an additional folder level, such
+        # as "sdxl-turbo/vae", nor do we want to put the contents of the vae folder directly into "sdxl-turbo".
+        # So what we do is to synthesize a folder named "sdxl-turbo_vae" here.
+        if subfolder:
+            top = Path(remote_files[0].path.parts[0])  # e.g. "sdxl-turbo/"
+            path_to_remove = top / subfolder.parts[-1]  # sdxl-turbo/vae/
+            path_to_add = Path(f"{top}_{subfolder}")
+        else:
+            path_to_remove = Path(".")
+            path_to_add = Path(".")
+
+        parts: List[RemoteModelFile] = []
+        for model_file in remote_files:
+            assert model_file.size is not None
+            parts.append(
+                RemoteModelFile(
+                    url=model_file.url,  # if a subfolder, then sdxl-turbo_vae/config.json
+                    path=path_to_add / model_file.path.relative_to(path_to_remove),
+                )
+            )
+
+        return self._download_queue.multifile_download(
+            parts=parts,
+            dest=dest,
+            access_token=access_token,
+            submit_job=submit_job,
+            on_start=self._download_started_callback,
+            on_progress=self._download_progress_callback,
+            on_complete=self._download_complete_callback,
+            on_error=self._download_error_callback,
+            on_cancelled=self._download_cancelled_callback,
+        )
+
     # ------------------------------------------------------------------
     # Callbacks are executed by the download queue in a separate thread
     # ------------------------------------------------------------------
-    def _download_started_callback(self, download_job: DownloadJob) -> None:
-        self._logger.info(f"Model download started: {download_job.source}")
-        with self._lock:
-            install_job = self._download_cache[download_job.source]
-            install_job.status = InstallStatus.DOWNLOADING
-
-            assert download_job.download_path
-            if install_job.local_path == install_job._install_tmpdir:
-                partial_path = download_job.download_path.relative_to(install_job._install_tmpdir)
-                dest_name = partial_path.parts[0]
-                install_job.local_path = install_job._install_tmpdir / dest_name
-
-            # Update the total bytes count for remote sources.
-            if not install_job.total_bytes:
-                install_job.total_bytes = sum(x.total_bytes for x in install_job.download_parts)
-
-    def _download_progress_callback(self, download_job: DownloadJob) -> None:
-        with self._lock:
-            install_job = self._download_cache[download_job.source]
-            if install_job.cancelled:  # This catches the case in which the caller directly calls job.cancel()
-                self._cancel_download_parts(install_job)
-            else:
-                # update sizes
-                install_job.bytes = sum(x.bytes for x in install_job.download_parts)
-                self._signal_job_downloading(install_job)
-
-    def _download_complete_callback(self, download_job: DownloadJob) -> None:
-        self._logger.info(f"Model download complete: {download_job.source}")
-        with self._lock:
-            install_job = self._download_cache[download_job.source]
-
-            # are there any more active jobs left in this task?
-            if install_job.downloading and all(x.complete for x in install_job.download_parts):
-                self._signal_job_downloads_done(install_job)
-                self._put_in_queue(install_job)
-
-            # Let other threads know that the number of downloads has changed
-            self._download_cache.pop(download_job.source, None)
-            self._downloads_changed_event.set()
-
-    def _download_error_callback(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
-        with self._lock:
-            install_job = self._download_cache.pop(download_job.source, None)
-            assert install_job is not None
-            assert excp is not None
-            install_job.set_error(excp)
-            self._logger.error(
-                f"Cancelling {install_job.source} due to an error while downloading {download_job.source}: {str(excp)}"
-            )
-            self._cancel_download_parts(install_job)
-
-            # Let other threads know that the number of downloads has changed
-            self._downloads_changed_event.set()
-
-    def _download_cancelled_callback(self, download_job: DownloadJob) -> None:
-        with self._lock:
-            install_job = self._download_cache.pop(download_job.source, None)
-            if not install_job:
-                return
-            self._downloads_changed_event.set()
-            self._logger.warning(f"Model download canceled: {download_job.source}")
-            # if install job has already registered an error, then do not replace its status with cancelled
-            if not install_job.errored:
-                install_job.cancel()
-            self._cancel_download_parts(install_job)
-
-            # Let other threads know that the number of downloads has changed
-            self._downloads_changed_event.set()
-
-    def _cancel_download_parts(self, install_job: ModelInstallJob) -> None:
-        # on multipart downloads, _cancel_components() will get called repeatedly from the download callbacks
-        # do not lock here because it gets called within a locked context
-        for s in install_job.download_parts:
-            self._download_queue.cancel_job(s)
-
-        if all(x.in_terminal_state for x in install_job.download_parts):
-            # When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources
-            self._put_in_queue(install_job)
+    def _download_started_callback(self, download_job: MultiFileDownloadJob) -> None:
+        with self._lock:
+            if install_job := self._download_cache.get(download_job.id, None):
+                install_job.status = InstallStatus.DOWNLOADING
+
+                if install_job.local_path == install_job._install_tmpdir:  # first time
+                    assert download_job.download_path
+                    install_job.local_path = download_job.download_path
+                install_job.download_parts = download_job.download_parts
+                install_job.bytes = sum(x.bytes for x in download_job.download_parts)
+                install_job.total_bytes = download_job.total_bytes
+                self._signal_job_downloading(install_job)
+
+    def _download_progress_callback(self, download_job: MultiFileDownloadJob) -> None:
+        with self._lock:
+            if install_job := self._download_cache.get(download_job.id, None):
+                if install_job.cancelled:  # This catches the case in which the caller directly calls job.cancel()
+                    self._download_queue.cancel_job(download_job)
+                else:
+                    # update sizes
+                    install_job.bytes = sum(x.bytes for x in download_job.download_parts)
+                    install_job.total_bytes = sum(x.total_bytes for x in download_job.download_parts)
+                    self._signal_job_downloading(install_job)
+
+    def _download_complete_callback(self, download_job: MultiFileDownloadJob) -> None:
+        with self._lock:
+            if install_job := self._download_cache.pop(download_job.id, None):
+                self._signal_job_downloads_done(install_job)
+                self._put_in_queue(install_job)  # this starts the installation and registration
+
+            # Let other threads know that the number of downloads has changed
+            self._downloads_changed_event.set()
+
+    def _download_error_callback(self, download_job: MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
+        with self._lock:
+            if install_job := self._download_cache.pop(download_job.id, None):
+                assert excp is not None
+                install_job.set_error(excp)
+                self._download_queue.cancel_job(download_job)
+
+            # Let other threads know that the number of downloads has changed
+            self._downloads_changed_event.set()
+
+    def _download_cancelled_callback(self, download_job: MultiFileDownloadJob) -> None:
+        with self._lock:
+            if install_job := self._download_cache.pop(download_job.id, None):
+                self._downloads_changed_event.set()
+                # if install job has already registered an error, then do not replace its status with cancelled
+                if not install_job.errored:
+                    install_job.cancel()
+
+            # Let other threads know that the number of downloads has changed
+            self._downloads_changed_event.set()

     # ------------------------------------------------------------------------------------------------
     # Internal methods that put events on the event bus
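A small worked example of the subfolder path rewriting above, using the `sdxl-turbo/vae` layout mentioned in the code comment; this is a standalone sketch for illustration, not part of the diff:

```
# Mirrors the path arithmetic in _multifile_download() for a subfolder install.
from pathlib import Path

top = Path("sdxl-turbo")
subfolder = Path("vae")
path_to_remove = top / subfolder.parts[-1]   # sdxl-turbo/vae
path_to_add = Path(f"{top}_{subfolder}")     # sdxl-turbo_vae

remote_path = Path("sdxl-turbo/vae/config.json")
local_path = path_to_add / remote_path.relative_to(path_to_remove)
print(local_path)  # sdxl-turbo_vae/config.json
```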
@@ -861,6 +876,9 @@ class ModelInstallService(ModelInstallServiceBase):

     def _signal_job_downloading(self, job: ModelInstallJob) -> None:
         if self._event_bus:
+            assert job._multifile_job is not None
+            assert job.bytes is not None
+            assert job.total_bytes is not None
             self._event_bus.emit_model_install_download_progress(job)

     def _signal_job_downloads_done(self, job: ModelInstallJob) -> None:
@@ -875,6 +893,8 @@ class ModelInstallService(ModelInstallServiceBase):
         self._logger.info(f"Model install complete: {job.source}")
         self._logger.debug(f"{job.local_path} registered key {job.config_out.key}")
         if self._event_bus:
+            assert job.local_path is not None
+            assert job.config_out is not None
             self._event_bus.emit_model_install_complete(job)

     def _signal_job_errored(self, job: ModelInstallJob) -> None:
@ -890,7 +910,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
self._event_bus.emit_model_install_cancelled(job)
|
self._event_bus.emit_model_install_cancelled(job)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_fetcher_from_url(url: str) -> ModelMetadataFetchBase:
|
def get_fetcher_from_url(url: str) -> Type[ModelMetadataFetchBase]:
|
||||||
|
"""
|
||||||
|
Return a metadata fetcher appropriate for provided url.
|
||||||
|
|
||||||
|
This used to be more useful, but the number of supported model
|
||||||
|
sources has been reduced to HuggingFace alone.
|
||||||
|
"""
|
||||||
if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
|
if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
|
||||||
return HuggingFaceMetadataFetch
|
return HuggingFaceMetadataFetch
|
||||||
raise ValueError(f"Unsupported model source: '{url}'")
|
raise ValueError(f"Unsupported model source: '{url}'")
|
||||||
|
```diff
@@ -2,10 +2,11 @@
 """Base class for model loader."""

 from abc import ABC, abstractmethod
-from typing import Optional
+from pathlib import Path
+from typing import Callable, Optional

 from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
-from invokeai.backend.model_manager.load import LoadedModel
+from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
 from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
 from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase

@@ -31,3 +32,26 @@ class ModelLoadServiceBase(ABC):
     @abstractmethod
     def convert_cache(self) -> ModelConvertCacheBase:
         """Return the checkpoint convert cache used by this loader."""
+
+    @abstractmethod
+    def load_model_from_path(
+        self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
+    ) -> LoadedModelWithoutConfig:
+        """
+        Load the model file or directory located at the indicated Path.
+
+        This will load an arbitrary model file into the RAM cache. If the optional loader
+        argument is provided, the loader will be invoked to load the model into
+        memory. Otherwise the method will call safetensors.torch.load_file() or
+        torch.load() as appropriate to the file suffix.
+
+        Be aware that this returns a LoadedModelWithoutConfig object, which is the same as
+        LoadedModel, but without the config attribute.
+
+        Args:
+          model_path: A pathlib.Path to a checkpoint-style models file
+          loader: A Callable that expects a Path and returns a Dict[str, Tensor]
+
+        Returns:
+          A LoadedModel object.
+        """
```
```diff
@@ -1,18 +1,26 @@
 # Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team
 """Implementation of model loader service."""

-from typing import Optional, Type
+from pathlib import Path
+from typing import Callable, Optional, Type
+
+from picklescan.scanner import scan_file_path
+from safetensors.torch import load_file as safetensors_load_file
+from torch import load as torch_load

 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.app.services.invoker import Invoker
 from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
 from invokeai.backend.model_manager.load import (
     LoadedModel,
+    LoadedModelWithoutConfig,
     ModelLoaderRegistry,
     ModelLoaderRegistryBase,
 )
 from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
 from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
+from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
+from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger

 from .model_load_base import ModelLoadServiceBase
@@ -75,3 +83,41 @@ class ModelLoadService(ModelLoadServiceBase):
             self._invoker.services.events.emit_model_load_complete(model_config, submodel_type)

         return loaded_model
+
+    def load_model_from_path(
+        self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
+    ) -> LoadedModelWithoutConfig:
+        cache_key = str(model_path)
+        ram_cache = self.ram_cache
+        try:
+            return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
+        except IndexError:
+            pass
+
+        def torch_load_file(checkpoint: Path) -> AnyModel:
+            scan_result = scan_file_path(checkpoint)
+            if scan_result.infected_files != 0:
+                raise Exception("The model at {checkpoint} is potentially infected by malware. Aborting load.")
+            result = torch_load(checkpoint, map_location="cpu")
+            return result
+
+        def diffusers_load_directory(directory: Path) -> AnyModel:
+            load_class = GenericDiffusersLoader(
+                app_config=self._app_config,
+                logger=self._logger,
+                ram_cache=self._ram_cache,
+                convert_cache=self.convert_cache,
+            ).get_hf_load_class(directory)
+            return load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype())
+
+        loader = loader or (
+            diffusers_load_directory
+            if model_path.is_dir()
+            else torch_load_file
+            if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin"))
+            else lambda path: safetensors_load_file(path, device="cpu")
+        )
+        assert loader is not None
+        raw_model = loader(model_path)
+        ram_cache.put(key=cache_key, model=raw_model)
+        return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
```
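To make the new entry point concrete, here is a minimal, non-authoritative sketch of calling `load_model_from_path()` with a custom loader. The `load_service` handle and the checkpoint path are hypothetical; in a node you would normally reach this through `context.models.load_local_model()` as shown in the summary above.

```python
from pathlib import Path

import torch


def state_dict_loader(checkpoint: Path):
    # Return just the raw state dict; AnyModel now also admits Dict[str, torch.Tensor].
    return torch.load(checkpoint, map_location="cpu")


# `load_service` is assumed to be the ModelLoadService instance owned by the model manager.
loaded = load_service.load_model_from_path(Path("/opt/models/custom_weights.pth"), loader=state_dict_loader)
with loaded as state_dict:
    print(len(state_dict))
```

Because the result is cached under the stringified path, a second call with the same path is served from the RAM cache rather than re-read from disk.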
```diff
@@ -12,15 +12,13 @@ from pydantic import BaseModel, Field

 from invokeai.app.services.shared.pagination import PaginatedResults
 from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
-from invokeai.backend.model_manager import (
+from invokeai.backend.model_manager.config import (
     AnyModelConfig,
     BaseModelType,
-    ModelFormat,
-    ModelType,
-)
-from invokeai.backend.model_manager.config import (
     ControlAdapterDefaultSettings,
     MainModelDefaultSettings,
+    ModelFormat,
+    ModelType,
     ModelVariantType,
     SchedulerPredictionType,
 )
```
```diff
@@ -3,6 +3,7 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Callable, Optional, Union

 from PIL.Image import Image
+from pydantic.networks import AnyHttpUrl
 from torch import Tensor

 from invokeai.app.invocations.constants import IMAGE_MODES
@@ -14,8 +15,15 @@ from invokeai.app.services.images.images_common import ImageDTO
 from invokeai.app.services.invocation_services import InvocationServices
 from invokeai.app.services.model_records.model_records_base import UnknownModelException
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
-from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
-from invokeai.backend.model_manager.load.load_base import LoadedModel
+from invokeai.backend.model_manager.config import (
+    AnyModel,
+    AnyModelConfig,
+    BaseModelType,
+    ModelFormat,
+    ModelType,
+    SubModelType,
+)
+from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData

@@ -320,8 +328,10 @@ class ConditioningInterface(InvocationContextInterface):


 class ModelsInterface(InvocationContextInterface):
+    """Common API for loading, downloading and managing models."""
+
     def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool:
-        """Checks if a model exists.
+        """Check if a model exists.

         Args:
             identifier: The key or ModelField representing the model.
@@ -331,13 +341,13 @@ class ModelsInterface(InvocationContextInterface):
         """
         if isinstance(identifier, str):
             return self._services.model_manager.store.exists(identifier)
+        else:
             return self._services.model_manager.store.exists(identifier.key)

     def load(
         self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None
     ) -> LoadedModel:
-        """Loads a model.
+        """Load a model.

         Args:
             identifier: The key or ModelField representing the model.
@@ -361,7 +371,7 @@ class ModelsInterface(InvocationContextInterface):
     def load_by_attrs(
         self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
     ) -> LoadedModel:
-        """Loads a model by its attributes.
+        """Load a model by its attributes.

         Args:
             name: Name of the model.
@@ -384,7 +394,7 @@ class ModelsInterface(InvocationContextInterface):
         return self._services.model_manager.load.load_model(configs[0], submodel_type)

     def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
-        """Gets a model's config.
+        """Get a model's config.

         Args:
             identifier: The key or ModelField representing the model.
@@ -394,11 +404,11 @@ class ModelsInterface(InvocationContextInterface):
         """
         if isinstance(identifier, str):
             return self._services.model_manager.store.get_model(identifier)
+        else:
             return self._services.model_manager.store.get_model(identifier.key)

     def search_by_path(self, path: Path) -> list[AnyModelConfig]:
-        """Searches for models by path.
+        """Search for models by path.

         Args:
             path: The path to search for.
@@ -415,7 +425,7 @@ class ModelsInterface(InvocationContextInterface):
         type: Optional[ModelType] = None,
         format: Optional[ModelFormat] = None,
     ) -> list[AnyModelConfig]:
-        """Searches for models by attributes.
+        """Search for models by attributes.

         Args:
             name: The name to search for (exact match).
@@ -434,6 +444,72 @@ class ModelsInterface(InvocationContextInterface):
             model_format=format,
         )

+    def download_and_cache_model(
+        self,
+        source: str | AnyHttpUrl,
+    ) -> Path:
+        """
+        Download the model file located at source to the models cache and return its Path.
+
+        This can be used to single-file install models and other resources of arbitrary types
+        which should not get registered with the database. If the model is already
+        installed, the cached path will be returned. Otherwise it will be downloaded.
+
+        Args:
+            source: A URL that points to the model, or a huggingface repo_id.
+
+        Returns:
+            Path to the downloaded model
+        """
+        return self._services.model_manager.install.download_and_cache_model(source=source)
+
+    def load_local_model(
+        self,
+        model_path: Path,
+        loader: Optional[Callable[[Path], AnyModel]] = None,
+    ) -> LoadedModelWithoutConfig:
+        """
+        Load the model file located at the indicated path.
+
+        If a loader callable is provided, it will be invoked to load the model. Otherwise,
+        `safetensors.torch.load_file()` or `torch.load()` will be called to load the model.
+
+        Be aware that the LoadedModelWithoutConfig object has no `config` attribute.
+
+        Args:
+            path: A model Path
+            loader: A Callable that expects a Path and returns a dict[str|int, Any]
+
+        Returns:
+            A LoadedModelWithoutConfig object.
+        """
+        return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
+
+    def load_remote_model(
+        self,
+        source: str | AnyHttpUrl,
+        loader: Optional[Callable[[Path], AnyModel]] = None,
+    ) -> LoadedModelWithoutConfig:
+        """
+        Download, cache, and load the model file located at the indicated URL or repo_id.
+
+        If the model is already downloaded, it will be loaded from the cache.
+
+        If a loader callable is provided, it will be invoked to load the model. Otherwise,
+        `safetensors.torch.load_file()` or `torch.load()` will be called to load the model.
+
+        Be aware that the LoadedModelWithoutConfig object has no `config` attribute.
+
+        Args:
+            source: A URL or huggingface repoid.
+            loader: A Callable that expects a Path and returns a dict[str|int, Any]
+
+        Returns:
+            A LoadedModelWithoutConfig object.
+        """
+        model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
+        return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
+
+
 class ConfigInterface(InvocationContextInterface):
     def get(self) -> InvokeAIAppConfig:
```
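A small, hedged usage sketch for the plain download call (the surrounding invocation class and output type are assumed; the URL is the Real-ESRGAN weight file referenced elsewhere in this PR):

```python
from pathlib import Path


def invoke(self, context: InvocationContext) -> ImageOutput:
    # Returns the local Path under models/.download_cache, downloading only if the
    # file is not already present there.
    weights_path: Path = context.models.download_and_cache_model(
        "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
    )
    ...
```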
```diff
@@ -13,6 +13,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import
 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_8 import build_migration_8
 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_9 import build_migration_9
 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_10 import build_migration_10
+from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import build_migration_11
 from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator


@@ -43,6 +44,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
     migrator.register_migration(build_migration_8(app_config=config))
     migrator.register_migration(build_migration_9())
     migrator.register_migration(build_migration_10())
+    migrator.register_migration(build_migration_11(app_config=config, logger=logger))
     migrator.run_migrations()

     return db
```
```diff
@@ -0,0 +1,75 @@
+import shutil
+import sqlite3
+from logging import Logger
+
+from invokeai.app.services.config import InvokeAIAppConfig
+from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
+
+LEGACY_CORE_MODELS = [
+    # OpenPose
+    "any/annotators/dwpose/yolox_l.onnx",
+    "any/annotators/dwpose/dw-ll_ucoco_384.onnx",
+    # DepthAnything
+    "any/annotators/depth_anything/depth_anything_vitl14.pth",
+    "any/annotators/depth_anything/depth_anything_vitb14.pth",
+    "any/annotators/depth_anything/depth_anything_vits14.pth",
+    # Lama inpaint
+    "core/misc/lama/lama.pt",
+    # RealESRGAN upscale
+    "core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
+    "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
+    "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
+    "core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
+]
+
+
+class Migration11Callback:
+    def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None:
+        self._app_config = app_config
+        self._logger = logger
+
+    def __call__(self, cursor: sqlite3.Cursor) -> None:
+        self._remove_convert_cache()
+        self._remove_downloaded_models()
+        self._remove_unused_core_models()
+
+    def _remove_convert_cache(self) -> None:
+        """Rename models/.cache to models/.convert_cache."""
+        self._logger.info("Removing .cache directory. Converted models will now be cached in .convert_cache.")
+        legacy_convert_path = self._app_config.root_path / "models" / ".cache"
+        shutil.rmtree(legacy_convert_path, ignore_errors=True)
+
+    def _remove_downloaded_models(self) -> None:
+        """Remove models from their old locations; they will re-download when needed."""
+        self._logger.info(
+            "Removing legacy just-in-time models. Downloaded models will now be cached in .download_cache."
+        )
+        for model_path in LEGACY_CORE_MODELS:
+            legacy_dest_path = self._app_config.models_path / model_path
+            legacy_dest_path.unlink(missing_ok=True)
+
+    def _remove_unused_core_models(self) -> None:
+        """Remove unused core models and their directories."""
+        self._logger.info("Removing defunct core models.")
+        for dir in ["face_restoration", "misc", "upscaling"]:
+            path_to_remove = self._app_config.models_path / "core" / dir
+            shutil.rmtree(path_to_remove, ignore_errors=True)
+        shutil.rmtree(self._app_config.models_path / "any" / "annotators", ignore_errors=True)
+
+
+def build_migration_11(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
+    """
+    Build the migration from database version 10 to 11.
+
+    This migration does the following:
+    - Moves "core" models previously downloaded with download_with_progress_bar() into new
+      "models/.download_cache" directory.
+    - Renames "models/.cache" to "models/.convert_cache".
+    """
+    migration_11 = Migration(
+        from_version=10,
+        to_version=11,
+        callback=Migration11Callback(app_config=app_config, logger=logger),
+    )
+
+    return migration_11
```
```diff
@@ -1,51 +0,0 @@
-from pathlib import Path
-from urllib import request
-
-from tqdm import tqdm
-
-from invokeai.backend.util.logging import InvokeAILogger
-
-
-class ProgressBar:
-    """Simple progress bar for urllib.request.urlretrieve using tqdm."""
-
-    def __init__(self, model_name: str = "file"):
-        self.pbar = None
-        self.name = model_name
-
-    def __call__(self, block_num: int, block_size: int, total_size: int):
-        if not self.pbar:
-            self.pbar = tqdm(
-                desc=self.name,
-                initial=0,
-                unit="iB",
-                unit_scale=True,
-                unit_divisor=1000,
-                total=total_size,
-            )
-        self.pbar.update(block_size)
-
-
-def download_with_progress_bar(name: str, url: str, dest_path: Path) -> bool:
-    """Download a file from a URL to a destination path, with a progress bar.
-    If the file already exists, it will not be downloaded again.
-
-    Exceptions are not caught.
-
-    Args:
-        name (str): Name of the file being downloaded.
-        url (str): URL to download the file from.
-        dest_path (Path): Destination path to save the file to.
-
-    Returns:
-        bool: True if the file was downloaded, False if it already existed.
-    """
-    if dest_path.exists():
-        return False  # already downloaded
-
-    InvokeAILogger.get_logger().info(f"Downloading {name}...")
-
-    dest_path.parent.mkdir(parents=True, exist_ok=True)
-    request.urlretrieve(url, dest_path, ProgressBar(name))
-
-    return True
```
```diff
@@ -1,5 +1,5 @@
-import pathlib
-from typing import Literal, Union
+from pathlib import Path
+from typing import Literal

 import cv2
 import numpy as np
@@ -10,28 +10,17 @@ from PIL import Image
 from torchvision.transforms import Compose

 from invokeai.app.services.config.config_default import get_config
-from invokeai.app.util.download_with_progress import download_with_progress_bar
 from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
 from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
-from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger

 config = get_config()
 logger = InvokeAILogger.get_logger(config=config)

 DEPTH_ANYTHING_MODELS = {
-    "large": {
-        "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
-        "local": "any/annotators/depth_anything/depth_anything_vitl14.pth",
-    },
-    "base": {
-        "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
-        "local": "any/annotators/depth_anything/depth_anything_vitb14.pth",
-    },
-    "small": {
-        "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
-        "local": "any/annotators/depth_anything/depth_anything_vits14.pth",
-    },
+    "large": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
+    "base": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
+    "small": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
 }


@@ -53,36 +42,27 @@ transform = Compose(


 class DepthAnythingDetector:
-    def __init__(self) -> None:
-        self.model = None
-        self.model_size: Union[Literal["large", "base", "small"], None] = None
-        self.device = TorchDevice.choose_torch_device()
-
-    def load_model(self, model_size: Literal["large", "base", "small"] = "small"):
-        DEPTH_ANYTHING_MODEL_PATH = config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"]
-        download_with_progress_bar(
-            pathlib.Path(DEPTH_ANYTHING_MODELS[model_size]["url"]).name,
-            DEPTH_ANYTHING_MODELS[model_size]["url"],
-            DEPTH_ANYTHING_MODEL_PATH,
-        )
-
-        if not self.model or model_size != self.model_size:
-            del self.model
-            self.model_size = model_size
-
-            match self.model_size:
-                case "small":
-                    self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
-                case "base":
-                    self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
-                case "large":
-                    self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
-
-            self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
-            self.model.eval()
-
-            self.model.to(self.device)
-            return self.model
+    def __init__(self, model: DPT_DINOv2, device: torch.device) -> None:
+        self.model = model
+        self.device = device
+
+    @staticmethod
+    def load_model(
+        model_path: Path, device: torch.device, model_size: Literal["large", "base", "small"] = "small"
+    ) -> DPT_DINOv2:
+        match model_size:
+            case "small":
+                model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
+            case "base":
+                model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
+            case "large":
+                model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
+
+        model.load_state_dict(torch.load(model_path.as_posix(), map_location="cpu"))
+        model.eval()
+
+        model.to(device)
+        return model

     def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:
         if not self.model:
```
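For orientation, a hedged sketch of how a node might combine the refactored detector with the new context API. The `context` and `image` variables are assumed to exist inside an invocation, and the model-size choice is illustrative:

```python
from invokeai.backend.util.devices import TorchDevice

device = TorchDevice.choose_torch_device()
model_url = DEPTH_ANYTHING_MODELS["small"]

# Download (if needed), cache, and load the checkpoint, building the DPT model in the loader.
loaded_model = context.models.load_remote_model(
    model_url,
    loader=lambda p: DepthAnythingDetector.load_model(p, device=device, model_size="small"),
)
with loaded_model as model:
    detector = DepthAnythingDetector(model, device)
    depth_map = detector(image, resolution=512)
```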
```diff
@@ -1,30 +1,53 @@
+from pathlib import Path
+from typing import Dict
+
 import numpy as np
 import torch
 from controlnet_aux.util import resize_image
 from PIL import Image

-from invokeai.backend.image_util.dw_openpose.utils import draw_bodypose, draw_facepose, draw_handpose
+from invokeai.backend.image_util.dw_openpose.utils import NDArrayInt, draw_bodypose, draw_facepose, draw_handpose
 from invokeai.backend.image_util.dw_openpose.wholebody import Wholebody

+DWPOSE_MODELS = {
+    "yolox_l.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true",
+    "dw-ll_ucoco_384.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true",
+}
+

-def draw_pose(pose, H, W, draw_face=True, draw_body=True, draw_hands=True, resolution=512):
+def draw_pose(
+    pose: Dict[str, NDArrayInt | Dict[str, NDArrayInt]],
+    H: int,
+    W: int,
+    draw_face: bool = True,
+    draw_body: bool = True,
+    draw_hands: bool = True,
+    resolution: int = 512,
+) -> Image.Image:
     bodies = pose["bodies"]
     faces = pose["faces"]
     hands = pose["hands"]

+    assert isinstance(bodies, dict)
     candidate = bodies["candidate"]

+    assert isinstance(bodies, dict)
     subset = bodies["subset"]

     canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)

     if draw_body:
         canvas = draw_bodypose(canvas, candidate, subset)

     if draw_hands:
+        assert isinstance(hands, np.ndarray)
         canvas = draw_handpose(canvas, hands)

     if draw_face:
-        canvas = draw_facepose(canvas, faces)
+        assert isinstance(hands, np.ndarray)
+        canvas = draw_facepose(canvas, faces)  # type: ignore

-    dwpose_image = resize_image(
+    dwpose_image: Image.Image = resize_image(
         canvas,
         resolution,
     )
@@ -39,11 +62,16 @@ class DWOpenposeDetector:
     Credits: https://github.com/IDEA-Research/DWPose
     """

-    def __init__(self) -> None:
-        self.pose_estimation = Wholebody()
+    def __init__(self, onnx_det: Path, onnx_pose: Path) -> None:
+        self.pose_estimation = Wholebody(onnx_det=onnx_det, onnx_pose=onnx_pose)

     def __call__(
-        self, image: Image.Image, draw_face=False, draw_body=True, draw_hands=False, resolution=512
+        self,
+        image: Image.Image,
+        draw_face: bool = False,
+        draw_body: bool = True,
+        draw_hands: bool = False,
+        resolution: int = 512,
     ) -> Image.Image:
         np_image = np.array(image)
         H, W, C = np_image.shape
@@ -79,3 +107,6 @@ class DWOpenposeDetector:
         return draw_pose(
             pose, H, W, draw_face=draw_face, draw_hands=draw_hands, draw_body=draw_body, resolution=resolution
         )
+
+
+__all__ = ["DWPOSE_MODELS", "DWOpenposeDetector"]
```
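A hedged sketch of how the new constructor is expected to be fed from a node, using the URLs in `DWPOSE_MODELS` above (`context` and `image` are assumed to exist inside an invocation):

```python
# Both ONNX files are fetched once and cached under models/.download_cache.
onnx_det = context.models.download_and_cache_model(DWPOSE_MODELS["yolox_l.onnx"])
onnx_pose = context.models.download_and_cache_model(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])

detector = DWOpenposeDetector(onnx_det=onnx_det, onnx_pose=onnx_pose)
pose_image = detector(image, draw_body=True, resolution=512)
```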
```diff
@@ -5,11 +5,13 @@ import math
 import cv2
 import matplotlib
 import numpy as np
+import numpy.typing as npt

 eps = 0.01
+NDArrayInt = npt.NDArray[np.uint8]


-def draw_bodypose(canvas, candidate, subset):
+def draw_bodypose(canvas: NDArrayInt, candidate: NDArrayInt, subset: NDArrayInt) -> NDArrayInt:
     H, W, C = canvas.shape
     candidate = np.array(candidate)
     subset = np.array(subset)
@@ -88,7 +90,7 @@ def draw_bodypose(canvas, candidate, subset):
     return canvas


-def draw_handpose(canvas, all_hand_peaks):
+def draw_handpose(canvas: NDArrayInt, all_hand_peaks: NDArrayInt) -> NDArrayInt:
     H, W, C = canvas.shape

     edges = [
@@ -142,7 +144,7 @@ def draw_handpose(canvas, all_hand_peaks):
     return canvas


-def draw_facepose(canvas, all_lmks):
+def draw_facepose(canvas: NDArrayInt, all_lmks: NDArrayInt) -> NDArrayInt:
     H, W, C = canvas.shape
     for lmks in all_lmks:
         lmks = np.array(lmks)
```
```diff
@@ -2,47 +2,26 @@
 # Modified pathing to suit Invoke


+from pathlib import Path
+
 import numpy as np
 import onnxruntime as ort

 from invokeai.app.services.config.config_default import get_config
-from invokeai.app.util.download_with_progress import download_with_progress_bar
 from invokeai.backend.util.devices import TorchDevice

 from .onnxdet import inference_detector
 from .onnxpose import inference_pose

-DWPOSE_MODELS = {
-    "yolox_l.onnx": {
-        "local": "any/annotators/dwpose/yolox_l.onnx",
-        "url": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true",
-    },
-    "dw-ll_ucoco_384.onnx": {
-        "local": "any/annotators/dwpose/dw-ll_ucoco_384.onnx",
-        "url": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true",
-    },
-}
-
 config = get_config()


 class Wholebody:
-    def __init__(self):
+    def __init__(self, onnx_det: Path, onnx_pose: Path):
         device = TorchDevice.choose_torch_device()

         providers = ["CUDAExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"]

-        DET_MODEL_PATH = config.models_path / DWPOSE_MODELS["yolox_l.onnx"]["local"]
-        download_with_progress_bar("yolox_l.onnx", DWPOSE_MODELS["yolox_l.onnx"]["url"], DET_MODEL_PATH)
-
-        POSE_MODEL_PATH = config.models_path / DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["local"]
-        download_with_progress_bar(
-            "dw-ll_ucoco_384.onnx", DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["url"], POSE_MODEL_PATH
-        )
-
-        onnx_det = DET_MODEL_PATH
-        onnx_pose = POSE_MODEL_PATH
-
         self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
         self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)
```
```diff
@@ -1,4 +1,4 @@
-import gc
+from pathlib import Path
 from typing import Any

 import numpy as np
@@ -6,9 +6,7 @@ import torch
 from PIL import Image

 import invokeai.backend.util.logging as logger
-from invokeai.app.services.config.config_default import get_config
-from invokeai.app.util.download_with_progress import download_with_progress_bar
-from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.model_manager.config import AnyModel


 def norm_img(np_img):
@@ -19,28 +17,11 @@ def norm_img(np_img):
     return np_img


-def load_jit_model(url_or_path, device):
-    model_path = url_or_path
-    logger.info(f"Loading model from: {model_path}")
-    model = torch.jit.load(model_path, map_location="cpu").to(device)
-    model.eval()
-    return model
-
-
 class LaMA:
+    def __init__(self, model: AnyModel):
+        self._model = model
+
     def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
-        device = TorchDevice.choose_torch_device()
-        model_location = get_config().models_path / "core/misc/lama/lama.pt"
-
-        if not model_location.exists():
-            download_with_progress_bar(
-                name="LaMa Inpainting Model",
-                url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
-                dest_path=model_location,
-            )
-
-        model = load_jit_model(model_location, device)
-
         image = np.asarray(input_image.convert("RGB"))
         image = norm_img(image)

@@ -48,20 +29,25 @@ class LaMA:
         mask = np.asarray(mask)
         mask = np.invert(mask)
         mask = norm_img(mask)

         mask = (mask > 0) * 1

+        device = next(self._model.buffers()).device
         image = torch.from_numpy(image).unsqueeze(0).to(device)
         mask = torch.from_numpy(mask).unsqueeze(0).to(device)

         with torch.inference_mode():
-            infilled_image = model(image, mask)
+            infilled_image = self._model(image, mask)

         infilled_image = infilled_image[0].permute(1, 2, 0).detach().cpu().numpy()
         infilled_image = np.clip(infilled_image * 255, 0, 255).astype("uint8")
         infilled_image = Image.fromarray(infilled_image)

-        del model
-        gc.collect()
-        torch.cuda.empty_cache()
-
         return infilled_image
+
+    @staticmethod
+    def load_jit_model(url_or_path: str | Path, device: torch.device | str = "cpu") -> torch.nn.Module:
+        model_path = url_or_path
+        logger.info(f"Loading model from: {model_path}")
+        model: torch.nn.Module = torch.jit.load(model_path, map_location="cpu").to(device)  # type: ignore
+        model.eval()
+        return model
```
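A hedged sketch of driving the refactored `LaMA` class through the new remote-model API; the URL is the one that appears in the removed just-in-time download code above, and `context` and `image` are assumed to exist inside an invocation:

```python
# The TorchScript checkpoint is downloaded once, cached, and loaded via the staticmethod loader.
loaded_model = context.models.load_remote_model(
    source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
    loader=LaMA.load_jit_model,
)
with loaded_model as model:
    lama = LaMA(model)
    infilled = lama(image)
```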
```diff
@@ -1,6 +1,5 @@
 import math
 from enum import Enum
-from pathlib import Path
 from typing import Any, Optional

 import cv2
@@ -11,6 +10,7 @@ from cv2.typing import MatLike
 from tqdm import tqdm

 from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
+from invokeai.backend.model_manager.config import AnyModel
 from invokeai.backend.util.devices import TorchDevice

 """
@@ -52,7 +52,7 @@ class RealESRGAN:
     def __init__(
         self,
         scale: int,
-        model_path: Path,
+        loadnet: AnyModel,
         model: RRDBNet,
         tile: int = 0,
         tile_pad: int = 10,
@@ -67,8 +67,6 @@ class RealESRGAN:
         self.half = half
         self.device = TorchDevice.choose_torch_device()

-        loadnet = torch.load(model_path, map_location=torch.device("cpu"))
-
         # prefer to use params_ema
         if "params_ema" in loadnet:
             keyname = "params_ema"
```
```diff
@@ -36,7 +36,7 @@ from ..raw_model import RawModel

 # ModelMixin is the base class for all diffusers and transformers models
 # RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
-AnyModel = Union[ModelMixin, RawModel, torch.nn.Module]
+AnyModel = Union[ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor]]


 class InvalidModelConfigException(Exception):
@@ -115,7 +115,7 @@ class SchedulerPredictionType(str, Enum):
 class ModelRepoVariant(str, Enum):
     """Various hugging face variants on the diffusers format."""

-    Default = ""  # model files without "fp16" or other qualifier - empty str
+    Default = ""  # model files without "fp16" or other qualifier
     FP16 = "fp16"
     FP32 = "fp32"
     ONNX = "onnx"
```
```diff
@@ -7,7 +7,7 @@ from importlib import import_module
 from pathlib import Path

 from .convert_cache.convert_cache_default import ModelConvertCache
-from .load_base import LoadedModel, ModelLoaderBase
+from .load_base import LoadedModel, LoadedModelWithoutConfig, ModelLoaderBase
 from .load_default import ModelLoader
 from .model_cache.model_cache_default import ModelCache
 from .model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase
@@ -19,6 +19,7 @@ for module in loaders:

 __all__ = [
     "LoadedModel",
+    "LoadedModelWithoutConfig",
     "ModelCache",
     "ModelConvertCache",
     "ModelLoaderBase",
```
```diff
@@ -7,6 +7,7 @@ from pathlib import Path

 from invokeai.backend.util import GIG, directory_size
 from invokeai.backend.util.logging import InvokeAILogger
+from invokeai.backend.util.util import safe_filename

 from .convert_cache_base import ModelConvertCacheBase

@@ -35,6 +36,7 @@ class ModelConvertCache(ModelConvertCacheBase):

     def cache_path(self, key: str) -> Path:
         """Return the path for a model with the indicated key."""
+        key = safe_filename(self._cache_path, key)
         return self._cache_path / key

     def make_room(self, size: float) -> None:
```
```diff
@@ -23,7 +23,7 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod


 @dataclass
-class LoadedModel:
+class LoadedModelWithoutConfig:
     """
     Context manager object that mediates transfer from RAM<->VRAM.

@@ -61,7 +61,6 @@ class LoadedModel:
     not have a state_dict, in which case this value will be None.
     """

-    config: AnyModelConfig
     _locker: ModelLockerBase

     def __enter__(self) -> AnyModel:
@@ -89,6 +88,13 @@ class LoadedModel:
         return self._locker.model


+@dataclass
+class LoadedModel(LoadedModelWithoutConfig):
+    """Context manager object that mediates transfer from RAM<->VRAM."""
+
+    config: Optional[AnyModelConfig] = None
+
+
 # TODO(MM2):
 # Some "intermediary" subclasses in the ModelLoaderBase class hierarchy define methods that their subclasses don't
 # know about. I think the problem may be related to this class being an ABC.
```
```diff
@@ -16,7 +16,7 @@ from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
 from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
 from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
 from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
-from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs
+from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
 from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
 from invokeai.backend.util.devices import TorchDevice

@@ -84,7 +84,7 @@ class ModelLoader(ModelLoaderBase):
         except IndexError:
             pass

-        cache_path: Path = self._convert_cache.cache_path(config.key)
+        cache_path: Path = self._convert_cache.cache_path(str(model_path))
         if self._needs_conversion(config, model_path, cache_path):
             loaded_model = self._do_convert(config, model_path, cache_path, submodel_type)
         else:
@@ -95,7 +95,6 @@ class ModelLoader(ModelLoaderBase):
             config.key,
             submodel_type=submodel_type,
             model=loaded_model,
-            size=calc_model_size_by_data(loaded_model),
         )

         return self._ram_cache.get(
@@ -126,9 +125,7 @@ class ModelLoader(ModelLoaderBase):
             if subtype == submodel_type:
                 continue
             if submodel := getattr(pipeline, subtype.value, None):
-                self._ram_cache.put(
-                    config.key, submodel_type=subtype, model=submodel, size=calc_model_size_by_data(submodel)
-                )
+                self._ram_cache.put(config.key, submodel_type=subtype, model=submodel)
         return getattr(pipeline, submodel_type.value) if submodel_type else pipeline

     def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
```
```diff
@@ -169,7 +169,6 @@ class ModelCacheBase(ABC, Generic[T]):
         self,
         key: str,
         model: T,
-        size: int,
         submodel_type: Optional[SubModelType] = None,
     ) -> None:
         """Store model under key and optional submodel_type."""
```
```diff
@@ -29,6 +29,7 @@ import torch

 from invokeai.backend.model_manager import AnyModel, SubModelType
 from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
+from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger

@@ -153,13 +154,13 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self,
         key: str,
         model: AnyModel,
-        size: int,
         submodel_type: Optional[SubModelType] = None,
     ) -> None:
         """Store model under key and optional submodel_type."""
         key = self._make_cache_key(key, submodel_type)
         if key in self._cached_models:
             return
+        size = calc_model_size_by_data(model)
         self.make_room(size)

         state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
@@ -252,12 +253,7 @@ class ModelCache(ModelCacheBase[AnyModel]):

         May raise a torch.cuda.OutOfMemoryError
         """
-        # These attributes are not in the base ModelMixin class but in various derived classes.
-        # Some models don't have these attributes, in which case they run in RAM/CPU.
         self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
-        if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")):
-            return

         source_device = cache_entry.device

         # Note: We compare device types only so that 'cuda' == 'cuda:0'.
@@ -265,6 +261,10 @@ class ModelCache(ModelCacheBase[AnyModel]):
         if torch.device(source_device).type == torch.device(target_device).type:
             return

+        # Some models don't have a `to` method, in which case they run in RAM/CPU.
+        if not hasattr(cache_entry.model, "to"):
+            return
+
         # This roundabout method for moving the model around is done to avoid
         # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
         # When moving to VRAM, we copy (not move) each element of the state dict from
```
@ -35,10 +35,6 @@ class ModelLocker(ModelLockerBase):
|
|||||||
|
|
||||||
def lock(self) -> AnyModel:
|
def lock(self) -> AnyModel:
|
||||||
"""Move the model into the execution device (GPU) and lock it."""
|
"""Move the model into the execution device (GPU) and lock it."""
|
||||||
if not hasattr(self.model, "to"):
|
|
||||||
return self.model
|
|
||||||
|
|
||||||
# NOTE that the model has to have the to() method in order for this code to move it into GPU!
|
|
||||||
self._cache_entry.lock()
|
self._cache_entry.lock()
|
||||||
try:
|
try:
|
||||||
if self._cache.lazy_offloading:
|
if self._cache.lazy_offloading:
|
||||||
@ -59,9 +55,6 @@ class ModelLocker(ModelLockerBase):
|
|||||||
|
|
||||||
def unlock(self) -> None:
|
def unlock(self) -> None:
|
||||||
"""Call upon exit from context."""
|
"""Call upon exit from context."""
|
||||||
if not hasattr(self.model, "to"):
|
|
||||||
return
|
|
||||||
|
|
||||||
self._cache_entry.unlock()
|
self._cache_entry.unlock()
|
||||||
if not self._cache.lazy_offloading:
|
if not self._cache.lazy_offloading:
|
||||||
self._cache.offload_unlocked_models(0)
|
self._cache.offload_unlocked_models(0)
|
||||||
@@ -65,14 +65,11 @@ class GenericDiffusersLoader(ModelLoader):
         else:
             try:
                 config = self._load_diffusers_config(model_path, config_name="config.json")
-                class_name = config.get("_class_name", None)
-                if class_name:
+                if class_name := config.get("_class_name"):
                     result = self._hf_definition_to_type(module="diffusers", class_name=class_name)
-                if config.get("model_type", None) == "clip_vision_model":
-                    class_name = config.get("architectures")
-                    assert class_name is not None
+                elif class_name := config.get("architectures"):
                     result = self._hf_definition_to_type(module="transformers", class_name=class_name[0])
-                if not class_name:
+                else:
                     raise InvalidModelConfigException("Unable to decipher Load Class based on given config.json")
             except KeyError as e:
                 raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
@@ -83,7 +83,7 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
             assert s.size is not None
             files.append(
                 RemoteModelFile(
-                    url=hf_hub_url(id, s.rfilename, revision=variant),
+                    url=hf_hub_url(id, s.rfilename, revision=variant or "main"),
                     path=Path(name, s.rfilename),
                     size=s.size,
                     sha256=s.lfs.get("sha256") if s.lfs else None,
@@ -37,9 +37,12 @@ class RemoteModelFile(BaseModel):

     url: AnyHttpUrl = Field(description="The url to download this model file")
     path: Path = Field(description="The path to the file, relative to the model root")
-    size: int = Field(description="The size of this file, in bytes")
+    size: Optional[int] = Field(description="The size of this file, in bytes", default=0)
     sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None)

+    def __hash__(self) -> int:
+        return hash(str(self))
+

 class ModelMetadataBase(BaseModel):
     """Base class for model metadata information."""
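Because `size` is now optional and `__hash__` is defined, `RemoteModelFile` entries can be built without a known size and de-duplicated in a set. A small sketch under those assumptions (the example URL and file name are made up):

```python
from pathlib import Path

from pydantic.networks import AnyHttpUrl

from invokeai.backend.model_manager.metadata import RemoteModelFile

# `size` may be omitted (it defaults to 0), and the new __hash__ lets
# duplicate entries collapse when collected into a set.
f1 = RemoteModelFile(url=AnyHttpUrl("https://example.com/model.safetensors"), path=Path("model.safetensors"))
f2 = RemoteModelFile(url=AnyHttpUrl("https://example.com/model.safetensors"), path=Path("model.safetensors"))
assert len({f1, f2}) == 1
```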
@@ -1,6 +1,8 @@
 import base64
 import io
 import os
+import re
+import unicodedata
 import warnings
 from pathlib import Path

@@ -12,6 +14,33 @@ from transformers import logging as transformers_logging
 GIG = 1073741824


+def slugify(value: str, allow_unicode: bool = False) -> str:
+    """
+    Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
+    dashes to single dashes. Remove characters that aren't alphanumerics,
+    underscores, or hyphens. Replace slashes with underscores.
+    Convert to lowercase. Also strip leading and
+    trailing whitespace, dashes, and underscores.
+
+    Adapted from Django: https://github.com/django/django/blob/main/django/utils/text.py
+    """
+    value = str(value)
+    if allow_unicode:
+        value = unicodedata.normalize("NFKC", value)
+    else:
+        value = unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii")
+    value = re.sub(r"[/]", "_", value.lower())
+    value = re.sub(r"[^.\w\s-]", "", value.lower())
+    return re.sub(r"[-\s]+", "-", value).strip("-_")
+
+
+def safe_filename(directory: Path, value: str) -> str:
+    """Make a string safe to use as a filename."""
+    escaped_string = slugify(value)
+    max_name_length = os.pathconf(directory, "PC_NAME_MAX") if hasattr(os, "pathconf") else 256
+    return escaped_string[len(escaped_string) - max_name_length :]
+
+
 def directory_size(directory: Path) -> int:
     """
     Return the aggregate size of all files in a directory (bytes).
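For orientation, here is an illustrative use of the two helpers added above; the model name is made up, and the expected output follows from tracing the regex steps in `slugify()`:

```python
from pathlib import Path

# Hypothetical input; the slug is lowercased, "/" becomes "_", parentheses are
# stripped, and runs of spaces/dashes collapse to a single dash.
print(slugify("Stable Diffusion/v1.5 (pruned)"))  # -> "stable-diffusion_v1.5-pruned"

# safe_filename() additionally truncates the slug to the filesystem's maximum
# name length for the target directory (or 256 if os.pathconf is unavailable).
print(safe_filename(Path("/tmp"), "Stable Diffusion/v1.5 (pruned)"))
```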
@@ -2,14 +2,18 @@

 import re
 import time
+from contextlib import contextmanager
 from pathlib import Path
+from typing import Any, Generator, Optional

 import pytest
 from pydantic.networks import AnyHttpUrl
 from requests.sessions import Session
-from requests_testadapter import TestAdapter, TestSession
+from requests_testadapter import TestAdapter

-from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService
+from invokeai.app.services.config import get_config
+from invokeai.app.services.config.config_default import URLRegexTokenPair
+from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService, MultiFileDownloadJob
 from invokeai.app.services.events.events_common import (
     DownloadCancelledEvent,
     DownloadCompleteEvent,

@@ -17,56 +21,23 @@ from invokeai.app.services.events.events_common import (
     DownloadProgressEvent,
     DownloadStartedEvent,
 )
+from invokeai.backend.model_manager.metadata import HuggingFaceMetadataFetch, ModelMetadataWithFiles, RemoteModelFile
+from tests.backend.model_manager.model_manager_fixtures import *  # noqa F403
 from tests.test_nodes import TestEventService

 # Prevent pytest deprecation warnings
-TestAdapter.__test__ = False  # type: ignore
+TestAdapter.__test__ = False


-@pytest.fixture
-def session() -> Session:
-    sess = TestSession()
-    for i in ["12345", "9999", "54321"]:
-        content = (
-            b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000)
-        )  # for pause tests, must make content large
-        sess.mount(
-            f"http://www.civitai.com/models/{i}",
-            TestAdapter(
-                content,
-                headers={
-                    "Content-Length": len(content),
-                    "Content-Disposition": f'filename="mock{i}.safetensors"',
-                },
-            ),
-        )
-
-    # here are some malformed URLs to test
-    # missing the content length
-    sess.mount(
-        "http://www.civitai.com/models/missing",
-        TestAdapter(
-            b"Missing content length",
-            headers={
-                "Content-Disposition": 'filename="missing.txt"',
-            },
-        ),
-    )
-    # not found test
-    sess.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404))
-
-    return sess
-
-
-@pytest.mark.timeout(timeout=20, method="thread")
-def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
+@pytest.mark.timeout(timeout=10, method="thread")
+def test_basic_queue_download(tmp_path: Path, mm2_session: Session) -> None:
     events = set()

-    def event_handler(job: DownloadJob) -> None:
+    def event_handler(job: DownloadJob, excp: Optional[Exception] = None) -> None:
         events.add(job.status)

     queue = DownloadQueueService(
-        requests_session=session,
+        requests_session=mm2_session,
     )
     queue.start()
     job = queue.download(
@@ -82,16 +53,17 @@ def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
     queue.join()

     assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
+    assert job.download_path == tmp_path / "mock12345.safetensors"
     assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist"

     assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
     queue.stop()


-@pytest.mark.timeout(timeout=20, method="thread")
-def test_errors(tmp_path: Path, session: Session) -> None:
+@pytest.mark.timeout(timeout=10, method="thread")
+def test_errors(tmp_path: Path, mm2_session: Session) -> None:
     queue = DownloadQueueService(
-        requests_session=session,
+        requests_session=mm2_session,
     )
     queue.start()

@@ -110,11 +82,11 @@ def test_errors(tmp_path: Path, session: Session) -> None:
     queue.stop()


-@pytest.mark.timeout(timeout=20, method="thread")
-def test_event_bus(tmp_path: Path, session: Session) -> None:
+@pytest.mark.timeout(timeout=10, method="thread")
+def test_event_bus(tmp_path: Path, mm2_session: Session) -> None:
     event_bus = TestEventService()

-    queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
+    queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
     queue.start()
     queue.download(
         source=AnyHttpUrl("http://www.civitai.com/models/12345"),

@@ -146,10 +118,10 @@ def test_event_bus(tmp_path: Path, session: Session) -> None:
     queue.stop()


-@pytest.mark.timeout(timeout=20, method="thread")
-def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
+@pytest.mark.timeout(timeout=10, method="thread")
+def test_broken_callbacks(tmp_path: Path, mm2_session: Session, capsys) -> None:
     queue = DownloadQueueService(
-        requests_session=session,
+        requests_session=mm2_session,
     )
     queue.start()

@@ -178,11 +150,11 @@ def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
     queue.stop()


-@pytest.mark.timeout(timeout=15, method="thread")
-def test_cancel(tmp_path: Path, session: Session) -> None:
+@pytest.mark.timeout(timeout=10, method="thread")
+def test_cancel(tmp_path: Path, mm2_session: Session) -> None:
     event_bus = TestEventService()

-    queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
+    queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
     queue.start()

     cancelled = False

@@ -194,9 +166,6 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
         nonlocal cancelled
         cancelled = True

-    def handler(signum, frame):
-        raise TimeoutError("Join took too long to return")
-
     job = queue.download(
         source=AnyHttpUrl("http://www.civitai.com/models/12345"),
         dest=tmp_path,
@@ -212,3 +181,178 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
     assert isinstance(events[-1], DownloadCancelledEvent)
     assert events[-1].source == "http://www.civitai.com/models/12345"
     queue.stop()
+
+
+@pytest.mark.timeout(timeout=10, method="thread")
+def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None:
+    fetcher = HuggingFaceMetadataFetch(mm2_session)
+    metadata = fetcher.from_id("stabilityai/sdxl-turbo")
+    assert isinstance(metadata, ModelMetadataWithFiles)
+    events = set()
+
+    def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
+        events.add(job.status)
+
+    queue = DownloadQueueService(
+        requests_session=mm2_session,
+    )
+    queue.start()
+    job = queue.multifile_download(
+        parts=metadata.download_urls(session=mm2_session),
+        dest=tmp_path,
+        on_start=event_handler,
+        on_progress=event_handler,
+        on_complete=event_handler,
+        on_error=event_handler,
+    )
+    assert isinstance(job, MultiFileDownloadJob), "expected the job to be of type MultiFileDownloadJobBase"
+    queue.join()
+
+    assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
+    assert job.bytes > 0, "expected download bytes to be positive"
+    assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes"
+    assert job.download_path == tmp_path / "sdxl-turbo"
+    assert Path(
+        tmp_path, "sdxl-turbo/model_index.json"
+    ).exists(), f"expected {tmp_path}/sdxl-turbo/model_inded.json to exist"
+    assert Path(
+        tmp_path, "sdxl-turbo/text_encoder/config.json"
+    ).exists(), f"expected {tmp_path}/sdxl-turbo/text_encoder/config.json to exist"
+
+    assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
+    queue.stop()
+
+
+@pytest.mark.timeout(timeout=10, method="thread")
+def test_multifile_download_error(tmp_path: Path, mm2_session: Session) -> None:
+    fetcher = HuggingFaceMetadataFetch(mm2_session)
+    metadata = fetcher.from_id("stabilityai/sdxl-turbo")
+    assert isinstance(metadata, ModelMetadataWithFiles)
+    events = set()
+
+    def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
+        events.add(job.status)
+
+    queue = DownloadQueueService(
+        requests_session=mm2_session,
+    )
+    queue.start()
+    files = metadata.download_urls(session=mm2_session)
+    # this will give a 404 error
+    files.append(RemoteModelFile(url="https://test.com/missing_model.safetensors", path=Path("sdxl-turbo/broken")))
+    job = queue.multifile_download(
+        parts=files,
+        dest=tmp_path,
+        on_start=event_handler,
+        on_progress=event_handler,
+        on_complete=event_handler,
+        on_error=event_handler,
+    )
+    queue.join()
+
+    assert job.status == DownloadJobStatus("error"), "expected job status to be errored"
+    assert job.error_type is not None
+    assert "HTTPError(NOT FOUND)" in job.error_type
+    assert DownloadJobStatus.ERROR in events
+    queue.stop()
+
+
+@pytest.mark.timeout(timeout=10, method="thread")
+def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch: Any) -> None:
+    event_bus = TestEventService()
+
+    queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
+    queue.start()
+
+    cancelled = False
+
+    def cancelled_callback(job: DownloadJob) -> None:
+        nonlocal cancelled
+        cancelled = True
+
+    fetcher = HuggingFaceMetadataFetch(mm2_session)
+    metadata = fetcher.from_id("stabilityai/sdxl-turbo")
+    assert isinstance(metadata, ModelMetadataWithFiles)
+
+    job = queue.multifile_download(
+        parts=metadata.download_urls(session=mm2_session),
+        dest=tmp_path,
+        on_cancelled=cancelled_callback,
+    )
+    queue.cancel_job(job)
+    queue.join()
+
+    assert job.status == DownloadJobStatus.CANCELLED
+    assert cancelled
+    events = event_bus.events
+    assert DownloadCancelledEvent in [type(x) for x in events]
+    queue.stop()
+
+
+def test_multifile_onefile(tmp_path: Path, mm2_session: Session) -> None:
+    queue = DownloadQueueService(
+        requests_session=mm2_session,
+    )
+    queue.start()
+    job = queue.multifile_download(
+        parts=[
+            RemoteModelFile(url=AnyHttpUrl("http://www.civitai.com/models/12345"), path=Path("mock12345.safetensors"))
+        ],
+        dest=tmp_path,
+    )
+    assert isinstance(job, MultiFileDownloadJob), "expected the job to be of type MultiFileDownloadJobBase"
+    queue.join()
+
+    assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
+    assert job.bytes > 0, "expected download bytes to be positive"
+    assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes"
+    assert job.download_path == tmp_path / "mock12345.safetensors"
+    assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist"
+    queue.stop()
+
+
+def test_multifile_no_rel_paths(tmp_path: Path, mm2_session: Session) -> None:
+    queue = DownloadQueueService(
+        requests_session=mm2_session,
+    )
+
+    with pytest.raises(AssertionError) as error:
+        queue.multifile_download(
+            parts=[RemoteModelFile(url=AnyHttpUrl("http://www.civitai.com/models/12345"), path=Path("/etc/passwd"))],
+            dest=tmp_path,
+        )
+    assert str(error.value) == "only relative download paths accepted"
+
+
+@contextmanager
+def clear_config() -> Generator[None, None, None]:
+    try:
+        yield None
+    finally:
+        get_config.cache_clear()
+
+
+def test_tokens(tmp_path: Path, mm2_session: Session):
+    with clear_config():
+        config = get_config()
+        config.remote_api_tokens = [URLRegexTokenPair(url_regex="civitai", token="cv_12345")]
+        queue = DownloadQueueService(requests_session=mm2_session)
+        queue.start()
+        # this one has an access token assigned
+        job1 = queue.download(
+            source=AnyHttpUrl("http://www.civitai.com/models/12345"),
+            dest=tmp_path,
+        )
+        # this one doesn't
+        job2 = queue.download(
+            source=AnyHttpUrl(
+                "http://www.huggingface.co/foo.txt",
+            ),
+            dest=tmp_path,
+        )
+        queue.join()
+        # this token is defined in the temporary root invokeai.yaml
+        # see tests/backend/model_manager/data/invokeai_root/invokeai.yaml
+        assert job1.access_token == "cv_12345"
+        assert job2.access_token is None
+        queue.stop()
@@ -20,6 +20,7 @@ from invokeai.app.services.events.events_common import (
     ModelInstallStartedEvent,
 )
 from invokeai.app.services.model_install import (
+    HFModelSource,
     ModelInstallServiceBase,
 )
 from invokeai.app.services.model_install.model_install_common import (

@@ -29,7 +30,14 @@ from invokeai.app.services.model_install.model_install_common import (
     URLModelSource,
 )
 from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException
-from invokeai.backend.model_manager.config import BaseModelType, InvalidModelConfigException, ModelFormat, ModelType
+from invokeai.backend.model_manager.config import (
+    BaseModelType,
+    InvalidModelConfigException,
+    ModelFormat,
+    ModelRepoVariant,
+    ModelType,
+)
+from tests.backend.model_manager.model_manager_fixtures import *  # noqa F403
 from tests.test_nodes import TestEventService

 OS = platform.uname().system

@@ -222,7 +230,7 @@ def test_delete_register(
     store.get_model(key)


-@pytest.mark.timeout(timeout=20, method="thread")
+@pytest.mark.timeout(timeout=10, method="thread")
 def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
     source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors"))

@@ -243,15 +251,16 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
     model_record = store.get_model(key)
     assert (mm2_app_config.models_path / model_record.path).exists()

-    assert len(bus.events) == 4
-    assert isinstance(bus.events[0], ModelInstallDownloadProgressEvent)
-    assert isinstance(bus.events[1], ModelInstallDownloadsCompleteEvent)
-    assert isinstance(bus.events[2], ModelInstallStartedEvent)
-    assert isinstance(bus.events[3], ModelInstallCompleteEvent)
+    assert len(bus.events) == 5
+    assert isinstance(bus.events[0], ModelInstallDownloadProgressEvent)  # download starts
+    assert isinstance(bus.events[1], ModelInstallDownloadProgressEvent)  # download progresses
+    assert isinstance(bus.events[2], ModelInstallDownloadsCompleteEvent)  # download completed
+    assert isinstance(bus.events[3], ModelInstallStartedEvent)  # install started
+    assert isinstance(bus.events[4], ModelInstallCompleteEvent)  # install completed


-@pytest.mark.timeout(timeout=20, method="thread")
-def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
+@pytest.mark.timeout(timeout=10, method="thread")
+def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
     source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))

     bus: TestEventService = mm2_installer.event_bus

@@ -277,6 +286,49 @@ def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_co
     assert len(bus.events) >= 3


+@pytest.mark.timeout(timeout=10, method="thread")
+def test_huggingface_repo_id(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
+    source = HFModelSource(repo_id="stabilityai/sdxl-turbo", variant=ModelRepoVariant.Default)
+
+    bus = mm2_installer.event_bus
+    store = mm2_installer.record_store
+    assert isinstance(bus, EventServiceBase)
+    assert store is not None
+
+    job = mm2_installer.import_model(source)
+    job_list = mm2_installer.wait_for_installs(timeout=10)
+    assert len(job_list) == 1
+    assert job.complete
+    assert job.config_out
+
+    key = job.config_out.key
+    model_record = store.get_model(key)
+    assert (mm2_app_config.models_path / model_record.path).exists()
+    assert model_record.type == ModelType.Main
+    assert model_record.format == ModelFormat.Diffusers
+
+    assert hasattr(bus, "events")  # the dummyeventservice has this
+    assert len(bus.events) >= 3
+    event_types = [type(x) for x in bus.events]
+    assert all(
+        x in event_types
+        for x in [
+            ModelInstallDownloadProgressEvent,
+            ModelInstallDownloadsCompleteEvent,
+            ModelInstallStartedEvent,
+            ModelInstallCompleteEvent,
+        ]
+    )
+
+    completed_events = [x for x in bus.events if isinstance(x, ModelInstallCompleteEvent)]
+    downloading_events = [x for x in bus.events if isinstance(x, ModelInstallDownloadProgressEvent)]
+    assert completed_events[0].total_bytes == downloading_events[-1].bytes
+    assert job.total_bytes == completed_events[0].total_bytes
+    print(downloading_events[-1])
+    print(job.download_parts)
+    assert job.total_bytes == sum(x["total_bytes"] for x in downloading_events[-1].parts)
+
+
 def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
     source = URLModelSource(url=Url("https://test.com/missing_model.safetensors"))
     job = mm2_installer.import_model(source)

@@ -308,7 +360,6 @@ def test_other_error_during_install(
     assert job.error == "Test error"


-# TODO: Fix bug in model install causing jobs to get installed multiple times then uncomment this test
 @pytest.mark.parametrize(
     "model_params",
     [

@@ -326,7 +377,7 @@ def test_other_error_during_install(
         },
     ],
 )
-@pytest.mark.timeout(timeout=40, method="thread")
+@pytest.mark.timeout(timeout=10, method="thread")
 def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
     """Test whether or not type is respected on configs when passed to heuristic import."""
     assert "name" in model_params and "type" in model_params

@@ -342,7 +393,7 @@ def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, mode
     }
     assert "repo_id" in model_params
     install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1)
-    mm2_installer.wait_for_job(install_job1, timeout=20)
+    mm2_installer.wait_for_job(install_job1, timeout=10)
     if model_params["type"] != "embedding":
         assert install_job1.errored
         assert install_job1.error_type == "InvalidModelConfigException"

@@ -351,6 +402,6 @@ def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, mode
     assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out

     install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2)
-    mm2_installer.wait_for_job(install_job2, timeout=20)
+    mm2_installer.wait_for_job(install_job2, timeout=10)
     assert install_job2.complete
     assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out
tests/app/services/model_load/test_load_api.py (new file, 88 lines)
@@ -0,0 +1,88 @@
+from pathlib import Path
+
+import pytest
+import torch
+from diffusers import AutoencoderTiny
+
+from invokeai.app.services.invocation_services import InvocationServices
+from invokeai.app.services.model_manager import ModelManagerServiceBase
+from invokeai.app.services.shared.invocation_context import InvocationContext, build_invocation_context
+from invokeai.backend.model_manager.load.load_base import LoadedModelWithoutConfig
+from tests.backend.model_manager.model_manager_fixtures import *  # noqa F403
+
+
+@pytest.fixture()
+def mock_context(
+    mock_services: InvocationServices,
+    mm2_model_manager: ModelManagerServiceBase,
+) -> InvocationContext:
+    mock_services.model_manager = mm2_model_manager
+    return build_invocation_context(
+        services=mock_services,
+        data=None,  # type: ignore
+        is_canceled=None,  # type: ignore
+    )
+
+
+def test_download_and_cache(mock_context: InvocationContext, mm2_root_dir: Path) -> None:
+    downloaded_path = mock_context.models.download_and_cache_model(
+        "https://www.test.foo/download/test_embedding.safetensors"
+    )
+    assert downloaded_path.is_file()
+    assert downloaded_path.exists()
+    assert downloaded_path.name == "test_embedding.safetensors"
+    assert downloaded_path.parent.parent == mm2_root_dir / "models/.download_cache"
+
+    downloaded_path_2 = mock_context.models.download_and_cache_model(
+        "https://www.test.foo/download/test_embedding.safetensors"
+    )
+    assert downloaded_path == downloaded_path_2
+
+
+def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) -> None:
+    downloaded_path = mock_context.models.download_and_cache_model(
+        "https://www.test.foo/download/test_embedding.safetensors"
+    )
+    loaded_model_1 = mock_context.models.load_local_model(downloaded_path)
+    assert isinstance(loaded_model_1, LoadedModelWithoutConfig)
+
+    loaded_model_2 = mock_context.models.load_local_model(downloaded_path)
+    assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
+    assert loaded_model_1.model is loaded_model_2.model
+
+    loaded_model_3 = mock_context.models.load_local_model(embedding_file)
+    assert isinstance(loaded_model_3, LoadedModelWithoutConfig)
+    assert loaded_model_1.model is not loaded_model_3.model
+    assert isinstance(loaded_model_1.model, dict)
+    assert isinstance(loaded_model_3.model, dict)
+    assert torch.equal(loaded_model_1.model["emb_params"], loaded_model_3.model["emb_params"])
+
+
+@pytest.mark.skip(reason="This requires a test model to load")
+def test_load_from_dir(mock_context: InvocationContext, vae_directory: Path) -> None:
+    loaded_model = mock_context.models.load_local_model(vae_directory)
+    assert isinstance(loaded_model, LoadedModelWithoutConfig)
+    assert isinstance(loaded_model.model, AutoencoderTiny)
+
+
+def test_download_and_load(mock_context: InvocationContext) -> None:
+    loaded_model_1 = mock_context.models.load_remote_model("https://www.test.foo/download/test_embedding.safetensors")
+    assert isinstance(loaded_model_1, LoadedModelWithoutConfig)
+
+    loaded_model_2 = mock_context.models.load_remote_model("https://www.test.foo/download/test_embedding.safetensors")
+    assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
+    assert loaded_model_1.model is loaded_model_2.model  # should be cached copy
+
+
+def test_download_diffusers(mock_context: InvocationContext) -> None:
+    model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo")
+    assert (model_path / "model_index.json").exists()
+    assert (model_path / "vae").is_dir()
+
+
+def test_download_diffusers_subfolder(mock_context: InvocationContext) -> None:
+    model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo::vae")
+    assert model_path.is_dir()
+    assert (model_path / "diffusion_pytorch_model.fp16.safetensors").exists() or (
+        model_path / "diffusion_pytorch_model.safetensors"
+    ).exists()

@@ -61,6 +61,13 @@ def embedding_file(mm2_model_files: Path) -> Path:
     return mm2_model_files / "test_embedding.safetensors"


+# Can be used to test diffusers model directory loading, but
+# the test file adds ~10MB of space.
+# @pytest.fixture
+# def vae_directory(mm2_model_files: Path) -> Path:
+#     return mm2_model_files / "taesdxl"
+
+
 @pytest.fixture
 def diffusers_dir(mm2_model_files: Path) -> Path:
     return mm2_model_files / "test-diffusers-main"

@@ -294,4 +301,45 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
             },
         ),
     )
+
+    for i in ["12345", "9999", "54321"]:
+        content = (
+            b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000)
+        )  # for pause tests, must make content large
+        sess.mount(
+            f"http://www.civitai.com/models/{i}",
+            TestAdapter(
+                content,
+                headers={
+                    "Content-Length": len(content),
+                    "Content-Disposition": f'filename="mock{i}.safetensors"',
+                },
+            ),
+        )
+
+    sess.mount(
+        "http://www.huggingface.co/foo.txt",
+        TestAdapter(
+            content,
+            headers={
+                "Content-Length": len(content),
+                "Content-Disposition": 'filename="foo.safetensors"',
+            },
+        ),
+    )
+
+    # here are some malformed URLs to test
+    # missing the content length
+    sess.mount(
+        "http://www.civitai.com/models/missing",
+        TestAdapter(
+            b"Missing content length",
+            headers={
+                "Content-Disposition": 'filename="missing.txt"',
+            },
+        ),
+    )
+    # not found test
+    sess.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404))
+
     return sess