Add simplified model manager install API to InvocationContext (#6132)

## Summary

This adds three model manager-related methods to the InvocationContext
uniform API. They are accessible via `context.models.*`:

1. **`load_local_model(model_path: Path, loader:
Optional[Callable[[Path], AnyModel]] = None) ->
LoadedModelWithoutConfig`**

*Load the model located at the indicated path.*

This will load a local model (.safetensors, .ckpt or diffusers
directory) into the model manager RAM cache and return its
`LoadedModelWithoutConfig`. If the optional loader argument is provided,
the loader will be invoked to load the model into memory. Otherwise the
method will call `safetensors.torch.load_file()`, `torch.load()` (with a
pickle scan), or `from_pretrained()` as appropriate to the path type.

Be aware that the `LoadedModelWithoutConfig` object differs from
`LoadedModel` by having no `config` attribute.

Here is an example of usage:

```
def invoke(self, context: InvocationContext) -> ImageOutput:
       model_path = Path('/opt/models/RealESRGAN_x4plus.pth')
       loadnet = context.models.load_local_model(model_path)
       with loadnet as loadnet_model:
             upscaler = RealESRGAN(loadnet=loadnet_model,...)
```
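
The default loader selection described above can be pictured roughly as follows. This is a minimal sketch, not the code in this PR: `pick_default_loader` and `load_local_model_sketch` are hypothetical names, and the function only reports which loader would be chosen instead of importing `torch` or `safetensors`.

```python
from pathlib import Path
from typing import Callable, Optional


def pick_default_loader(model_path: Path) -> str:
    """Name the default loader matching the path type (illustrative only)."""
    if model_path.is_dir():
        # A directory is assumed to be a diffusers-style model folder.
        return "from_pretrained"
    if model_path.suffix == ".safetensors":
        return "safetensors.torch.load_file"
    # Anything else is treated as a pickled checkpoint (after a pickle scan).
    return "torch.load"


def load_local_model_sketch(
    model_path: Path,
    loader: Optional[Callable[[Path], object]] = None,
) -> object:
    """Use the caller's loader when one is given, else report the default."""
    if loader is not None:
        return loader(model_path)
    return pick_default_loader(model_path)
```

Passing a `loader` bypasses the path-type check entirely, which is what lets callers load formats the defaults do not cover.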

---

2. **`load_remote_model(source: str | AnyHttpUrl, loader:
Optional[Callable[[Path], AnyModel]] = None) ->
LoadedModelWithoutConfig`**

*Load the model located at the indicated URL or repo_id.*

This is similar to `load_local_model()` but it accepts either a
HuggingFace repo_id (as a string), or a URL. The model's file(s) will be
downloaded to `models/.download_cache` and then loaded, returning a
`LoadedModelWithoutConfig`.

```
def invoke(self, context: InvocationContext) -> ImageOutput:
       model_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'
       loadnet = context.models.load_remote_model(model_url)
       with loadnet as loadnet_model:
             upscaler = RealESRGAN(loadnet=loadnet_model,...)
```
---

3. **`download_and_cache_model(source: str | AnyHttpUrl, access_token:
Optional[str] = None, timeout: Optional[int] = 0) -> Path`**

Download the model file located at source to the models cache and return
its Path. This will check `models/.download_cache` for the desired model
file and download it from the indicated source if not already present.
The local Path to the downloaded file is then returned.
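
The check-then-download behaviour can be sketched as below. This is an illustration under stated assumptions, not the implementation in this PR: `download_and_cache_sketch` and the `fetch` callback are hypothetical, and keying the cache on the file name from the URL is a simplification chosen for the example.

```python
from pathlib import Path
from typing import Callable
from urllib.parse import urlparse


def download_and_cache_sketch(
    source: str,
    cache_dir: Path,
    fetch: Callable[[str, Path], None],
) -> Path:
    """Return the cached path for `source`, calling `fetch` only on a miss."""
    # Assumption: the cache entry is named after the last URL path component.
    filename = Path(urlparse(source).path).name
    target = cache_dir / filename
    if not target.exists():
        cache_dir.mkdir(parents=True, exist_ok=True)
        fetch(source, target)  # hypothetical downloader supplied by the caller
    return target
```

A second call with the same source returns the same path without invoking `fetch` again.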

---

## Other Changes

This PR performs a migration, in which it renames `models/.cache` to
`models/.convert_cache`, and migrates previously-downloaded ESRGAN,
openpose, DepthAnything and Lama inpaint models from the `models/core`
directory into `models/.download_cache`.

There are a number of legacy model files in `models/core`, such as
GFPGAN, which are no longer used. This PR deletes them and tidies up the
`models/core` directory.
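
For orientation, the directory changes amount to something like the sketch below. It is not the migration code from this PR: `migrate_model_dirs_sketch`, the suffix filter, and the flat layout inside `.download_cache` are assumptions made for the example.

```python
import shutil
from pathlib import Path

# Assumption: model files are recognised by these suffixes.
MODEL_SUFFIXES = {".pth", ".pt", ".onnx"}


def migrate_model_dirs_sketch(models_dir: Path) -> None:
    """Rename the convert cache and move downloaded core models (illustrative)."""
    old_cache = models_dir / ".cache"
    new_cache = models_dir / ".convert_cache"
    if old_cache.exists() and not new_cache.exists():
        old_cache.rename(new_cache)

    download_cache = models_dir / ".download_cache"
    download_cache.mkdir(parents=True, exist_ok=True)

    core = models_dir / "core"
    if core.exists():
        # Collect first so files are not moved while the tree is being walked.
        candidates = [
            p for p in core.rglob("*") if p.is_file() and p.suffix in MODEL_SUFFIXES
        ]
        for model_file in candidates:
            shutil.move(str(model_file), str(download_cache / model_file.name))
```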

## Related Issues / Discussions

I have systematically replaced all the calls to
`download_with_progress_bar()`. This function is no longer used
elsewhere and has been removed.


## QA Instructions

I have added unit tests for the three new calls. You may test that the
`load_remote_model()` call is working by running the upscaler within
the web app. On the first try, you will see the model file being
downloaded into the `models/.download_cache` directory. On subsequent
tries, the model will either load from RAM (if it hasn't been displaced)
or will be loaded from the filesystem.


## Merge Plan

Squash merge when approved.


## Checklist

- [X] _The PR has a short but descriptive title, suitable for a
changelog_
- [X] _Tests added / updated (if applicable)_
- [X] _Documentation added / updated (if applicable)_
Kent Keirsey 2024-06-08 16:24:31 -07:00 committed by GitHub
commit 9432336e2b
42 changed files with 1510 additions and 645 deletions

View File

@ -128,7 +128,8 @@ The queue operates on a series of download job objects. These objects
specify the source and destination of the download, and keep track of
the progress of the download.
The only job type currently implemented is `DownloadJob`, a pydantic object with the
Two job types are defined. `DownloadJob` and
`MultiFileDownloadJob`. The former is a pydantic object with the
following fields:
| **Field** | **Type** | **Default** | **Description** |
@ -138,7 +139,7 @@ following fields:
| `dest` | Path | | Where to download to |
| `access_token` | str | | [optional] string containing authentication token for access |
| `on_start` | Callable | | [optional] callback when the download starts |
| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
| `on_complete` | Callable | | [optional] callback called after successful download completion |
| `on_error` | Callable | | [optional] callback called after an error occurs |
| `id` | int | auto assigned | Job ID, an integer >= 0 |
@ -190,6 +191,33 @@ A cancelled job will have status `DownloadJobStatus.ERROR` and an
`error_type` field of "DownloadJobCancelledException". In addition,
the job's `cancelled` property will be set to True.
The `MultiFileDownloadJob` is used for diffusers model downloads,
which contain multiple files and directories under a common root:
| **Field** | **Type** | **Default** | **Description** |
|----------------|-----------------|---------------|-----------------|
| _Fields passed in at job creation time_ |
| `download_parts` | Set[DownloadJob]| | Component download jobs |
| `dest` | Path | | Where to download to |
| `on_start` | Callable | | [optional] callback when the download starts |
| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
| `on_complete` | Callable | | [optional] callback called after successful download completion |
| `on_error` | Callable | | [optional] callback called after an error occurs |
| `id` | int | auto assigned | Job ID, an integer >= 0 |
| _Fields updated over the course of the download task_
| `status` | DownloadJobStatus| | Status code |
| `download_path` | Path | | Path to the root of the downloaded files |
| `bytes` | int | 0 | Bytes downloaded so far |
| `total_bytes` | int | 0 | Total size of the file at the remote site |
| `error_type` | str | | String version of the exception that caused an error during download |
| `error` | str | | String version of the traceback associated with an error |
| `cancelled` | bool | False | Set to true if the job was cancelled by the caller|
Note that the MultiFileDownloadJob does not support the `priority`,
`job_started`, `job_ended` or `content_type` attributes. You can get
these from the individual download jobs in `download_parts`.
### Callbacks
Download jobs can be associated with a series of callbacks, each with
@ -251,11 +279,40 @@ jobs using `list_jobs()`, fetch a single job by its with
running jobs with `cancel_all_jobs()`, and wait for all jobs to finish
with `join()`.
#### job = queue.download(source, dest, priority, access_token)
#### job = queue.download(source, dest, priority, access_token, on_start, on_progress, on_complete, on_cancelled, on_error)
Create a new download job and put it on the queue, returning the
DownloadJob object.
#### multifile_job = queue.multifile_download(parts, dest, access_token, on_start, on_progress, on_complete, on_cancelled, on_error)
This is similar to download(), but instead of taking a single source,
it accepts a `parts` argument consisting of a list of
`RemoteModelFile` objects. Each part corresponds to a URL/Path pair,
where the URL is the location of the remote file, and the Path is the
destination.
`RemoteModelFile` can be imported from `invokeai.backend.model_manager.metadata`, and
consists of a url/path pair. Note that the path *must* be relative.
The method returns a `MultiFileDownloadJob`.
```
from invokeai.backend.model_manager.metadata import RemoteModelFile
remote_file_1 = RemoteModelFile(url='http://www.foo.bar/my/pytorch_model.safetensors',
path='my_model/textencoder/pytorch_model.safetensors'
)
remote_file_2 = RemoteModelFile(url='http://www.bar.baz/vae.ckpt',
path='my_model/vae/diffusers_model.safetensors'
)
job = queue.multifile_download(parts=[remote_file_1, remote_file_2],
dest='/tmp/downloads',
on_progress=TqdmProgress().update)
queue.wait_for_job(job)
print(f"The files were downloaded to {job.download_path}")
```
#### jobs = queue.list_jobs()
Return a list of all active and inactive `DownloadJob`s.

View File

@ -397,26 +397,25 @@ In the event you wish to create a new installer, you may use the
following initialization pattern:
```
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.config import get_config
from invokeai.app.services.model_records import ModelRecordServiceSQL
from invokeai.app.services.model_install import ModelInstallService
from invokeai.app.services.download import DownloadQueueService
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.util.logging import InvokeAILogger
config = InvokeAIAppConfig.get_config()
config.parse_args()
config = get_config()
logger = InvokeAILogger.get_logger(config=config)
db = SqliteDatabase(config, logger)
db = SqliteDatabase(config.db_path, logger)
record_store = ModelRecordServiceSQL(db)
queue = DownloadQueueService()
queue.start()
installer = ModelInstallService(app_config=config,
installer = ModelInstallService(app_config=config,
record_store=record_store,
download_queue=queue
)
download_queue=queue
)
installer.start()
```
@ -1602,3 +1601,59 @@ This method takes a model key, looks it up using the
`ModelRecordServiceBase` object in `mm.store`, and passes the returned
model configuration to `load_model_by_config()`. It may raise a
`NotImplementedException`.
## Invocation Context Model Manager API
Within invocations, the following methods are available from the
`InvocationContext` object:
### context.download_and_cache_model(source) -> Path
This method accepts a `source` of a remote model, downloads and caches
it locally, and then returns a Path to the local model. The source can
be a direct download URL or a HuggingFace repo_id.
In the case of HuggingFace repo_id, the following variants are
recognized:
* stabilityai/stable-diffusion-v4 -- default model
* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant
* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder
* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder
You can also point at an arbitrary individual file within a repo_id
directory using this syntax:
* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
### context.load_local_model(model_path, [loader]) -> LoadedModel
This method loads a local model from the indicated path, returning a
`LoadedModel`. The optional loader is a Callable that accepts a Path
to the object, and returns a `AnyModel` object. If no loader is
provided, then the method will use `torch.load()` for a .ckpt or .bin
checkpoint file, `safetensors.torch.load_file()` for a safetensors
checkpoint file, or `cls.from_pretrained()` for a directory that looks
like a diffusers directory.
### context.load_remote_model(source, [loader]) -> LoadedModel
This method accepts a `source` of a remote model, downloads and caches
it locally, loads it, and returns a `LoadedModel`. The source can be a
direct download URL or a HuggingFace repo_id.
In the case of HuggingFace repo_id, the following variants are
recognized:
* stabilityai/stable-diffusion-v4 -- default model
* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant
* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder
* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder
You can also point at an arbitrary individual file within a repo_id
directory using this syntax:
* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors

View File

@ -93,7 +93,7 @@ class ApiDependencies:
conditioning = ObjectSerializerForwardCache(
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
)
download_queue_service = DownloadQueueService(event_bus=events)
download_queue_service = DownloadQueueService(app_config=configuration, event_bus=events)
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
model_manager = ModelManagerService.build_model_manager(
app_config=configuration,

View File

@ -2,6 +2,7 @@
# initial implementation by Gregg Helt, 2023
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
from builtins import bool, float
from pathlib import Path
from typing import Dict, List, Literal, Union
import cv2
@ -36,12 +37,13 @@ from invokeai.app.invocations.util import validate_begin_end_step, validate_weig
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
from invokeai.backend.image_util.canny import get_canny_edges
from invokeai.backend.image_util.depth_anything import DepthAnythingDetector
from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector
from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
from invokeai.backend.image_util.hed import HEDProcessor
from invokeai.backend.image_util.lineart import LineartProcessor
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
from invokeai.backend.util.devices import TorchDevice
from .baseinvocation import BaseInvocation, BaseInvocationOutput, Classification, invocation, invocation_output
@ -139,6 +141,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
return context.images.get_pil(self.image.image_name, "RGB")
def invoke(self, context: InvocationContext) -> ImageOutput:
self._context = context
raw_image = self.load_image(context)
# image type should be PIL.PngImagePlugin.PngImageFile ?
processed_image = self.run_processor(raw_image)
@ -284,7 +287,8 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
# depth_and_normal not supported in controlnet_aux v0.0.3
# depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
def run_processor(self, image):
def run_processor(self, image: Image.Image) -> Image.Image:
# TODO: replace from_pretrained() calls with context.models.download_and_cache() (or similar)
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
processed_image = midas_processor(
image,
@ -311,7 +315,7 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image):
def run_processor(self, image: Image.Image) -> Image.Image:
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = normalbae_processor(
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
@ -330,7 +334,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")
def run_processor(self, image):
def run_processor(self, image: Image.Image) -> Image.Image:
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
processed_image = mlsd_processor(
image,
@ -353,7 +357,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
def run_processor(self, image):
def run_processor(self, image: Image.Image) -> Image.Image:
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
processed_image = pidi_processor(
image,
@ -381,7 +385,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
w: int = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
f: int = InputField(default=256, ge=0, description="Content shuffle `f` parameter")
def run_processor(self, image):
def run_processor(self, image: Image.Image) -> Image.Image:
content_shuffle_processor = ContentShuffleDetector()
processed_image = content_shuffle_processor(
image,
@ -405,7 +409,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Zoe depth processing to image"""
def run_processor(self, image):
def run_processor(self, image: Image.Image) -> Image.Image:
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = zoe_depth_processor(image)
return processed_image
@ -426,7 +430,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image):
def run_processor(self, image: Image.Image) -> Image.Image:
mediapipe_face_processor = MediapipeFaceDetector()
processed_image = mediapipe_face_processor(
image,
@ -454,7 +458,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image):
def run_processor(self, image: Image.Image) -> Image.Image:
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
processed_image = leres_processor(
image,
@ -496,8 +500,8 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
return np_img
def run_processor(self, img):
np_img = np.array(img, dtype=np.uint8)
def run_processor(self, image: Image.Image) -> Image.Image:
np_img = np.array(image, dtype=np.uint8)
processed_np_image = self.tile_resample(
np_img,
# res=self.tile_size,
@ -520,7 +524,7 @@ class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image):
def run_processor(self, image: Image.Image) -> Image.Image:
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
"ybelkada/segment-anything", subfolder="checkpoints"
@ -566,7 +570,7 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
color_map_tile_size: int = InputField(default=64, ge=1, description=FieldDescriptions.tile_size)
def run_processor(self, image: Image.Image):
def run_processor(self, image: Image.Image) -> Image.Image:
np_image = np.array(image, dtype=np.uint8)
height, width = np_image.shape[:2]
@ -601,12 +605,18 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
)
resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image):
depth_anything_detector = DepthAnythingDetector()
depth_anything_detector.load_model(model_size=self.model_size)
def run_processor(self, image: Image.Image) -> Image.Image:
def loader(model_path: Path):
return DepthAnythingDetector.load_model(
model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device()
)
processed_image = depth_anything_detector(image=image, resolution=self.resolution)
return processed_image
with self._context.models.load_remote_model(
source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader
) as model:
depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device())
processed_image = depth_anything_detector(image=image, resolution=self.resolution)
return processed_image
@invocation(
@ -624,8 +634,11 @@ class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):
draw_hands: bool = InputField(default=False)
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image):
dw_openpose = DWOpenposeDetector()
def run_processor(self, image: Image.Image) -> Image.Image:
onnx_det = self._context.models.download_and_cache_model(DWPOSE_MODELS["yolox_l.onnx"])
onnx_pose = self._context.models.download_and_cache_model(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])
dw_openpose = DWOpenposeDetector(onnx_det=onnx_det, onnx_pose=onnx_pose)
processed_image = dw_openpose(
image,
draw_face=self.draw_face,

View File

@ -42,15 +42,16 @@ class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Infill the image with the specified method"""
pass
def load_image(self, context: InvocationContext) -> tuple[Image.Image, bool]:
def load_image(self) -> tuple[Image.Image, bool]:
"""Process the image to have an alpha channel before being infilled"""
image = context.images.get_pil(self.image.image_name)
image = self._context.images.get_pil(self.image.image_name)
has_alpha = True if image.mode == "RGBA" else False
return image, has_alpha
def invoke(self, context: InvocationContext) -> ImageOutput:
self._context = context
# Retrieve and process image to be infilled
input_image, has_alpha = self.load_image(context)
input_image, has_alpha = self.load_image()
# If the input image has no alpha channel, return it
if has_alpha is False:
@ -133,8 +134,12 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation):
"""Infills transparent areas of an image using the LaMa model"""
def infill(self, image: Image.Image):
lama = LaMA()
return lama(image)
with self._context.models.load_remote_model(
source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
loader=LaMA.load_jit_model,
) as model:
lama = LaMA(model)
return lama(image)
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")

View File

@ -1,5 +1,4 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
from pathlib import Path
from typing import Literal
import cv2
@ -10,10 +9,8 @@ from pydantic import ConfigDict
from invokeai.app.invocations.fields import ImageField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
from invokeai.backend.util.devices import TorchDevice
from .baseinvocation import BaseInvocation, invocation
from .fields import InputField, WithBoard, WithMetadata
@ -52,7 +49,6 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
rrdbnet_model = None
netscale = None
esrgan_model_path = None
if self.model_name in [
"RealESRGAN_x4plus.pth",
@ -95,28 +91,25 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
context.logger.error(msg)
raise ValueError(msg)
esrgan_model_path = Path(context.config.get().models_path, f"core/upscaling/realesrgan/{self.model_name}")
# Downloads the ESRGAN model if it doesn't already exist
download_with_progress_bar(
name=self.model_name, url=ESRGAN_MODEL_URLS[self.model_name], dest_path=esrgan_model_path
loadnet = context.models.load_remote_model(
source=ESRGAN_MODEL_URLS[self.model_name],
)
upscaler = RealESRGAN(
scale=netscale,
model_path=esrgan_model_path,
model=rrdbnet_model,
half=False,
tile=self.tile_size,
)
with loadnet as loadnet_model:
upscaler = RealESRGAN(
scale=netscale,
loadnet=loadnet_model,
model=rrdbnet_model,
half=False,
tile=self.tile_size,
)
# prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
# TODO: This strips the alpha... is that okay?
cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
upscaled_image = upscaler.upscale(cv2_image)
pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
# prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
# TODO: This strips the alpha... is that okay?
cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
upscaled_image = upscaler.upscale(cv2_image)
TorchDevice.empty_cache()
pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
image_dto = context.images.save(image=pil_image)

View File

@ -86,6 +86,7 @@ class InvokeAIAppConfig(BaseSettings):
patchmatch: Enable patchmatch inpaint code.
models_dir: Path to the models directory.
convert_cache_dir: Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.
download_cache_dir: Path to the directory that contains dynamically downloaded models.
legacy_conf_dir: Path to directory of legacy checkpoint config files.
db_dir: Path to InvokeAI databases directory.
outputs_dir: Path to directory for outputs.
@ -146,7 +147,8 @@ class InvokeAIAppConfig(BaseSettings):
# PATHS
models_dir: Path = Field(default=Path("models"), description="Path to the models directory.")
convert_cache_dir: Path = Field(default=Path("models/.cache"), description="Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.")
convert_cache_dir: Path = Field(default=Path("models/.convert_cache"), description="Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.")
download_cache_dir: Path = Field(default=Path("models/.download_cache"), description="Path to the directory that contains dynamically downloaded models.")
legacy_conf_dir: Path = Field(default=Path("configs"), description="Path to directory of legacy checkpoint config files.")
db_dir: Path = Field(default=Path("databases"), description="Path to InvokeAI databases directory.")
outputs_dir: Path = Field(default=Path("outputs"), description="Path to directory for outputs.")
@ -303,6 +305,11 @@ class InvokeAIAppConfig(BaseSettings):
"""Path to the converted cache models directory, resolved to an absolute path.."""
return self._resolve(self.convert_cache_dir)
@property
def download_cache_path(self) -> Path:
"""Path to the downloaded models directory, resolved to an absolute path.."""
return self._resolve(self.download_cache_dir)
@property
def custom_nodes_path(self) -> Path:
"""Path to the custom nodes directory, resolved to an absolute path.."""

View File

@ -1,10 +1,17 @@
"""Init file for download queue."""
from .download_base import DownloadJob, DownloadJobStatus, DownloadQueueServiceBase, UnknownJobIDException
from .download_base import (
DownloadJob,
DownloadJobStatus,
DownloadQueueServiceBase,
MultiFileDownloadJob,
UnknownJobIDException,
)
from .download_default import DownloadQueueService, TqdmProgress
__all__ = [
"DownloadJob",
"MultiFileDownloadJob",
"DownloadQueueServiceBase",
"DownloadQueueService",
"TqdmProgress",

View File

@ -5,11 +5,13 @@ from abc import ABC, abstractmethod
from enum import Enum
from functools import total_ordering
from pathlib import Path
from typing import Any, Callable, List, Optional
from typing import Any, Callable, List, Optional, Set, Union
from pydantic import BaseModel, Field, PrivateAttr
from pydantic.networks import AnyHttpUrl
from invokeai.backend.model_manager.metadata import RemoteModelFile
class DownloadJobStatus(str, Enum):
"""State of a download job."""
@ -33,30 +35,23 @@ class ServiceInactiveException(Exception):
"""This exception is raised when user attempts to initiate a download before the service is started."""
DownloadEventHandler = Callable[["DownloadJob"], None]
DownloadExceptionHandler = Callable[["DownloadJob", Optional[Exception]], None]
SingleFileDownloadEventHandler = Callable[["DownloadJob"], None]
SingleFileDownloadExceptionHandler = Callable[["DownloadJob", Optional[Exception]], None]
MultiFileDownloadEventHandler = Callable[["MultiFileDownloadJob"], None]
MultiFileDownloadExceptionHandler = Callable[["MultiFileDownloadJob", Optional[Exception]], None]
DownloadEventHandler = Union[SingleFileDownloadEventHandler, MultiFileDownloadEventHandler]
DownloadExceptionHandler = Union[SingleFileDownloadExceptionHandler, MultiFileDownloadExceptionHandler]
@total_ordering
class DownloadJob(BaseModel):
"""Class to monitor and control a model download request."""
class DownloadJobBase(BaseModel):
"""Base of classes to monitor and control downloads."""
# required variables to be passed in on creation
source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
dest: Path = Field(description="Destination of downloaded model on local disk; a directory or file path")
access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")
# automatically assigned on creation
id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel
priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
# set internally during download process
dest: Path = Field(description="Initial destination of downloaded model on local disk; a directory or file path")
download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file or directory")
status: DownloadJobStatus = Field(default=DownloadJobStatus.WAITING, description="Status of the download")
download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file")
job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started")
job_ended: Optional[str] = Field(
default=None, description="Timestamp for when the download job ende1d (completed or errored)"
)
content_type: Optional[str] = Field(default=None, description="Content type of downloaded file")
bytes: int = Field(default=0, description="Bytes downloaded so far")
total_bytes: int = Field(default=0, description="Total file size (bytes)")
@ -74,14 +69,6 @@ class DownloadJob(BaseModel):
_on_cancelled: Optional[DownloadEventHandler] = PrivateAttr(default=None)
_on_error: Optional[DownloadExceptionHandler] = PrivateAttr(default=None)
def __hash__(self) -> int:
"""Return hash of the string representation of this object, for indexing."""
return hash(str(self))
def __le__(self, other: "DownloadJob") -> bool:
"""Return True if this job's priority is less than another's."""
return self.priority <= other.priority
def cancel(self) -> None:
"""Call to cancel the job."""
self._cancelled = True
@ -98,6 +85,11 @@ class DownloadJob(BaseModel):
"""Return true if job completed without errors."""
return self.status == DownloadJobStatus.COMPLETED
@property
def waiting(self) -> bool:
"""Return true if the job is waiting to run."""
return self.status == DownloadJobStatus.WAITING
@property
def running(self) -> bool:
"""Return true if the job is running."""
@ -154,6 +146,37 @@ class DownloadJob(BaseModel):
self._on_cancelled = on_cancelled
@total_ordering
class DownloadJob(DownloadJobBase):
"""Class to monitor and control a model download request."""
# required variables to be passed in on creation
source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")
priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
# set internally during download process
job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started")
job_ended: Optional[str] = Field(
default=None, description="Timestamp for when the download job ended (completed or errored)"
)
content_type: Optional[str] = Field(default=None, description="Content type of downloaded file")
def __hash__(self) -> int:
"""Return hash of the string representation of this object, for indexing."""
return hash(str(self))
def __le__(self, other: "DownloadJob") -> bool:
"""Return True if this job's priority is less than or equal to another's."""
return self.priority <= other.priority
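
The `__le__` method combined with `@total_ordering` is what allows a `PriorityQueue` to order jobs by priority. A minimal standalone sketch of that mechanism follows. It uses a hypothetical `Job` dataclass rather than the pydantic `DownloadJob`, and `drain()` is an illustrative helper that does not exist in this PR.

```python
from dataclasses import dataclass
from functools import total_ordering
from queue import PriorityQueue
from typing import List


@total_ordering
@dataclass(eq=True)
class Job:
    """Hypothetical stand-in for DownloadJob; not the InvokeAI class."""

    name: str
    priority: int = 10  # lower values are served first

    def __le__(self, other: "Job") -> bool:
        # total_ordering derives <, > and >= from this method plus __eq__
        return self.priority <= other.priority


def drain(jobs: List[Job]) -> List[str]:
    """Push jobs into a PriorityQueue and return their names in pop order."""
    queue = PriorityQueue()
    for job in jobs:
        queue.put(job)
    names = []
    while not queue.empty():
        names.append(queue.get().name)
    return names
```

Jobs with equal priority have no guaranteed relative order in this sketch.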
class MultiFileDownloadJob(DownloadJobBase):
"""Class to monitor and control multifile downloads."""
download_parts: Set[DownloadJob] = Field(default_factory=set, description="List of download parts.")
class DownloadQueueServiceBase(ABC):
"""Multithreaded queue for downloading models via URL."""
@ -201,6 +224,48 @@ class DownloadQueueServiceBase(ABC):
"""
pass
@abstractmethod
def multifile_download(
self,
parts: List[RemoteModelFile],
dest: Path,
access_token: Optional[str] = None,
submit_job: bool = True,
on_start: Optional[DownloadEventHandler] = None,
on_progress: Optional[DownloadEventHandler] = None,
on_complete: Optional[DownloadEventHandler] = None,
on_cancelled: Optional[DownloadEventHandler] = None,
on_error: Optional[DownloadExceptionHandler] = None,
) -> MultiFileDownloadJob:
"""
Create and enqueue a multifile download job.
:param parts: List of URL / filename pairs
:param dest: Path to download to. See below.
:param access_token: Access token to download the indicated files. If not provided,
each file's URL may be matched to an access token using the config file matching
system.
:param submit_job: If true [default] then submit the job for execution. Otherwise,
you will need to pass the job to submit_multifile_download().
:param on_start, on_progress, on_complete, on_cancelled, on_error: Callbacks for the indicated
events.
:returns: A MultiFileDownloadJob object for monitoring the state of the download.
The `dest` argument is a Path object pointing to a directory. All downloads
will be placed inside this directory. The callbacks will receive the
MultiFileDownloadJob.
"""
pass
@abstractmethod
def submit_multifile_download(self, job: MultiFileDownloadJob) -> None:
"""
Enqueue a previously-created multi-file download job.
:param job: A MultiFileDownloadJob created with multifile_download()
"""
pass
@abstractmethod
def submit_download_job(
self,
@ -252,7 +317,7 @@ class DownloadQueueServiceBase(ABC):
pass
@abstractmethod
def cancel_job(self, job: DownloadJob) -> None:
def cancel_job(self, job: DownloadJobBase) -> None:
"""Cancel the job, clearing partial downloads and putting it into CANCELLED state."""
pass
@ -262,7 +327,7 @@ class DownloadQueueServiceBase(ABC):
pass
@abstractmethod
def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob:
def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase:
"""Wait until the indicated download job has reached a terminal state.
This will block until the indicated download job has completed,

@ -8,23 +8,28 @@ import time
import traceback
from pathlib import Path
from queue import Empty, PriorityQueue
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set
import requests
from pydantic.networks import AnyHttpUrl
from requests import HTTPError
from tqdm import tqdm
from invokeai.app.services.config import InvokeAIAppConfig, get_config
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.util.misc import get_iso_timestamp
from invokeai.backend.model_manager.metadata import RemoteModelFile
from invokeai.backend.util.logging import InvokeAILogger
from .download_base import (
DownloadEventHandler,
DownloadExceptionHandler,
DownloadJob,
DownloadJobBase,
DownloadJobCancelledException,
DownloadJobStatus,
DownloadQueueServiceBase,
MultiFileDownloadJob,
ServiceInactiveException,
UnknownJobIDException,
)
@ -42,20 +47,24 @@ class DownloadQueueService(DownloadQueueServiceBase):
def __init__(
self,
max_parallel_dl: int = 5,
app_config: Optional[InvokeAIAppConfig] = None,
event_bus: Optional["EventServiceBase"] = None,
requests_session: Optional[requests.sessions.Session] = None,
):
"""
Initialize DownloadQueue.
:param app_config: InvokeAIAppConfig object
:param max_parallel_dl: Number of simultaneous downloads allowed [5].
:param requests_session: Optional requests.sessions.Session object, for unit tests.
"""
self._app_config = app_config or get_config()
self._jobs: Dict[int, DownloadJob] = {}
self._download_part2parent: Dict[AnyHttpUrl, MultiFileDownloadJob] = {}
self._next_job_id = 0
self._queue: PriorityQueue[DownloadJob] = PriorityQueue()
self._stop_event = threading.Event()
self._job_completed_event = threading.Event()
self._job_terminated_event = threading.Event()
self._worker_pool: Set[threading.Thread] = set()
self._lock = threading.Lock()
self._logger = InvokeAILogger.get_logger("DownloadQueueService")
@ -107,18 +116,16 @@ class DownloadQueueService(DownloadQueueServiceBase):
raise ServiceInactiveException(
"The download service is not currently accepting requests. Please call start() to initialize the service."
)
with self._lock:
job.id = self._next_job_id
self._next_job_id += 1
job.set_callbacks(
on_start=on_start,
on_progress=on_progress,
on_complete=on_complete,
on_cancelled=on_cancelled,
on_error=on_error,
)
self._jobs[job.id] = job
self._queue.put(job)
job.id = self._next_id()
job.set_callbacks(
on_start=on_start,
on_progress=on_progress,
on_complete=on_complete,
on_cancelled=on_cancelled,
on_error=on_error,
)
self._jobs[job.id] = job
self._queue.put(job)
def download(
self,
@ -141,7 +148,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
source=source,
dest=dest,
priority=priority,
access_token=access_token,
access_token=access_token or self._lookup_access_token(source),
)
self.submit_download_job(
job,
@ -153,10 +160,63 @@ class DownloadQueueService(DownloadQueueServiceBase):
)
return job
def multifile_download(
self,
parts: List[RemoteModelFile],
dest: Path,
access_token: Optional[str] = None,
submit_job: bool = True,
on_start: Optional[DownloadEventHandler] = None,
on_progress: Optional[DownloadEventHandler] = None,
on_complete: Optional[DownloadEventHandler] = None,
on_cancelled: Optional[DownloadEventHandler] = None,
on_error: Optional[DownloadExceptionHandler] = None,
) -> MultiFileDownloadJob:
mfdj = MultiFileDownloadJob(dest=dest, id=self._next_id())
mfdj.set_callbacks(
on_start=on_start,
on_progress=on_progress,
on_complete=on_complete,
on_cancelled=on_cancelled,
on_error=on_error,
)
for part in parts:
url = part.url
path = dest / part.path
assert path.is_relative_to(dest), "only relative download paths accepted"
job = DownloadJob(
source=url,
dest=path,
access_token=access_token,
)
mfdj.download_parts.add(job)
self._download_part2parent[job.source] = mfdj
if submit_job:
self.submit_multifile_download(mfdj)
return mfdj
def submit_multifile_download(self, job: MultiFileDownloadJob) -> None:
for download_job in job.download_parts:
self.submit_download_job(
download_job,
on_start=self._mfd_started,
on_progress=self._mfd_progress,
on_complete=self._mfd_complete,
on_cancelled=self._mfd_cancelled,
on_error=self._mfd_error,
)
def join(self) -> None:
"""Wait for all jobs to complete."""
self._queue.join()
def _next_id(self) -> int:
with self._lock:
id = self._next_job_id
self._next_job_id += 1
return id
def list_jobs(self) -> List[DownloadJob]:
"""List all the jobs."""
return list(self._jobs.values())
@ -178,14 +238,14 @@ class DownloadQueueService(DownloadQueueServiceBase):
except KeyError as excp:
raise UnknownJobIDException("Unrecognized job") from excp
def cancel_job(self, job: DownloadJob) -> None:
def cancel_job(self, job: DownloadJobBase) -> None:
"""
Cancel the indicated job.
If it is running it will be stopped.
job.status will be set to DownloadJobStatus.CANCELLED
"""
with self._lock:
if job.status in [DownloadJobStatus.WAITING, DownloadJobStatus.RUNNING]:
job.cancel()
def cancel_all_jobs(self) -> None:
@ -194,12 +254,12 @@ class DownloadQueueService(DownloadQueueServiceBase):
if not job.in_terminal_state:
self.cancel_job(job)
def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob:
def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase:
"""Block until the indicated job has reached a terminal state or the timeout limit is reached."""
start = time.time()
while not job.in_terminal_state:
if self._job_completed_event.wait(timeout=0.25): # in case we miss an event
self._job_completed_event.clear()
if self._job_terminated_event.wait(timeout=0.25): # in case we miss an event
self._job_terminated_event.clear()
if timeout > 0 and time.time() - start > timeout:
raise TimeoutError("Timeout exceeded")
return job
@ -228,22 +288,25 @@ class DownloadQueueService(DownloadQueueServiceBase):
job.job_started = get_iso_timestamp()
self._do_download(job)
self._signal_job_complete(job)
except (OSError, HTTPError) as excp:
job.error_type = excp.__class__.__name__ + f"({str(excp)})"
job.error = traceback.format_exc()
self._signal_job_error(job, excp)
except DownloadJobCancelledException:
self._signal_job_cancelled(job)
self._cleanup_cancelled_job(job)
except Exception as excp:
job.error_type = excp.__class__.__name__ + f"({str(excp)})"
job.error = traceback.format_exc()
self._signal_job_error(job, excp)
finally:
job.job_ended = get_iso_timestamp()
self._job_completed_event.set() # signal a change to terminal state
self._job_terminated_event.set() # signal a change to terminal state
self._download_part2parent.pop(job.source, None) # if this is a subpart of a multipart job, remove it
self._job_terminated_event.set()
self._queue.task_done()
self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.")
def _do_download(self, job: DownloadJob) -> None:
"""Do the actual download."""
url = job.source
header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
open_mode = "wb"
@ -335,38 +398,29 @@ class DownloadQueueService(DownloadQueueServiceBase):
def _in_progress_path(self, path: Path) -> Path:
return path.with_name(path.name + ".downloading")
def _lookup_access_token(self, source: AnyHttpUrl) -> Optional[str]:
# Pull the token from config if it exists and matches the URL
token = None
for pair in self._app_config.remote_api_tokens or []:
if re.search(pair.url_regex, str(source)):
token = pair.token
break
return token
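
The token lookup above returns the token of the first configured pair whose regex matches the download URL. A minimal standalone sketch of the same idea, where `TokenPair` and `lookup_access_token()` are hypothetical names standing in for the entries of `remote_api_tokens` and the private method:

```python
import re
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class TokenPair:
    """Hypothetical stand-in for one entry of the remote_api_tokens setting."""

    url_regex: str
    token: str


def lookup_access_token(source: str, pairs: Optional[List[TokenPair]]) -> Optional[str]:
    """Return the token of the first pair whose regex matches the URL, else None."""
    for pair in pairs or []:
        # re.search matches anywhere in the URL, so patterns need not be anchored
        if re.search(pair.url_regex, source):
            return pair.token
    return None
```

Because the first match wins, more specific patterns should probably be listed before more general ones.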
def _signal_job_started(self, job: DownloadJob) -> None:
job.status = DownloadJobStatus.RUNNING
if job.on_start:
try:
job.on_start(job)
except Exception as e:
self._logger.error(
f"An error occurred while processing the on_start callback: {traceback.format_exception(e)}"
)
self._execute_cb(job, "on_start")
if self._event_bus:
self._event_bus.emit_download_started(job)
def _signal_job_progress(self, job: DownloadJob) -> None:
if job.on_progress:
try:
job.on_progress(job)
except Exception as e:
self._logger.error(
f"An error occurred while processing the on_progress callback: {traceback.format_exception(e)}"
)
self._execute_cb(job, "on_progress")
if self._event_bus:
self._event_bus.emit_download_progress(job)
def _signal_job_complete(self, job: DownloadJob) -> None:
job.status = DownloadJobStatus.COMPLETED
if job.on_complete:
try:
job.on_complete(job)
except Exception as e:
self._logger.error(
f"An error occurred while processing the on_complete callback: {traceback.format_exception(e)}"
)
self._execute_cb(job, "on_complete")
if self._event_bus:
self._event_bus.emit_download_complete(job)
@ -374,26 +428,21 @@ class DownloadQueueService(DownloadQueueServiceBase):
if job.status not in [DownloadJobStatus.RUNNING, DownloadJobStatus.WAITING]:
return
job.status = DownloadJobStatus.CANCELLED
if job.on_cancelled:
try:
job.on_cancelled(job)
except Exception as e:
self._logger.error(
f"An error occurred while processing the on_cancelled callback: {traceback.format_exception(e)}"
)
self._execute_cb(job, "on_cancelled")
if self._event_bus:
self._event_bus.emit_download_cancelled(job)
# if multifile download, then signal the parent
if parent_job := self._download_part2parent.get(job.source, None):
if not parent_job.in_terminal_state:
parent_job.status = DownloadJobStatus.CANCELLED
self._execute_cb(parent_job, "on_cancelled")
def _signal_job_error(self, job: DownloadJob, excp: Optional[Exception] = None) -> None:
job.status = DownloadJobStatus.ERROR
self._logger.error(f"{str(job.source)}: {traceback.format_exception(excp)}")
if job.on_error:
try:
job.on_error(job, excp)
except Exception as e:
self._logger.error(
f"An error occurred while processing the on_error callback: {traceback.format_exception(e)}"
)
self._execute_cb(job, "on_error", excp)
if self._event_bus:
self._event_bus.emit_download_error(job)
@ -406,6 +455,97 @@ class DownloadQueueService(DownloadQueueServiceBase):
except OSError as excp:
self._logger.warning(excp)
########################################
# callbacks used for multifile downloads
########################################
def _mfd_started(self, download_job: DownloadJob) -> None:
self._logger.info(f"File download started: {download_job.source}")
with self._lock:
mf_job = self._download_part2parent[download_job.source]
if mf_job.waiting:
mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts)
mf_job.status = DownloadJobStatus.RUNNING
assert download_job.download_path is not None
path_relative_to_destdir = download_job.download_path.relative_to(mf_job.dest)
mf_job.download_path = (
mf_job.dest / path_relative_to_destdir.parts[0]
) # keep just the first component of the path
self._execute_cb(mf_job, "on_start")
def _mfd_progress(self, download_job: DownloadJob) -> None:
with self._lock:
mf_job = self._download_part2parent[download_job.source]
if mf_job.cancelled:
for part in mf_job.download_parts:
self.cancel_job(part)
elif mf_job.running:
mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts)
mf_job.bytes = sum(x.bytes for x in mf_job.download_parts)
self._execute_cb(mf_job, "on_progress")
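
A standalone sketch of how a parent job's progress could be aggregated from its parts. It assumes the parent's `bytes` is meant to be the sum of bytes received so far across all parts. `Part`, `Parent` and `refresh_progress()` are hypothetical names used only for illustration.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Part:
    """Hypothetical stand-in for one file of a multi-file download."""

    bytes: int = 0        # bytes received so far
    total_bytes: int = 0  # expected size of this file


@dataclass
class Parent:
    """Hypothetical stand-in for the aggregate multi-file job."""

    parts: List[Part] = field(default_factory=list)
    bytes: int = 0
    total_bytes: int = 0


def refresh_progress(parent: Parent) -> float:
    """Recompute the parent's counters from its parts; return the fraction done."""
    parent.total_bytes = sum(p.total_bytes for p in parent.parts)
    parent.bytes = sum(p.bytes for p in parent.parts)
    # Guard against division by zero before any size is known
    return parent.bytes / parent.total_bytes if parent.total_bytes else 0.0
```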
def _mfd_complete(self, download_job: DownloadJob) -> None:
self._logger.info(f"Download complete: {download_job.source}")
with self._lock:
mf_job = self._download_part2parent[download_job.source]
# are there any more active jobs left in this task?
if mf_job.running and all(x.complete for x in mf_job.download_parts):
mf_job.status = DownloadJobStatus.COMPLETED
self._execute_cb(mf_job, "on_complete")
# we're done with this sub-job
self._job_terminated_event.set()
def _mfd_cancelled(self, download_job: DownloadJob) -> None:
with self._lock:
mf_job = self._download_part2parent[download_job.source]
assert mf_job is not None
if not mf_job.in_terminal_state:
self._logger.warning(f"Download cancelled: {download_job.source}")
mf_job.cancel()
for s in mf_job.download_parts:
self.cancel_job(s)
def _mfd_error(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
with self._lock:
mf_job = self._download_part2parent[download_job.source]
assert mf_job is not None
if not mf_job.in_terminal_state:
mf_job.status = download_job.status
mf_job.error = download_job.error
mf_job.error_type = download_job.error_type
self._execute_cb(mf_job, "on_error", excp)
self._logger.error(
f"Cancelling {mf_job.dest} due to an error while downloading {download_job.source}: {str(excp)}"
)
for s in [x for x in mf_job.download_parts if x.running]:
self.cancel_job(s)
self._download_part2parent.pop(download_job.source)
self._job_terminated_event.set()
def _execute_cb(
self,
job: DownloadJob | MultiFileDownloadJob,
callback_name: Literal[
"on_start",
"on_progress",
"on_complete",
"on_cancelled",
"on_error",
],
excp: Optional[Exception] = None,
) -> None:
if callback := getattr(job, callback_name, None):
args = [job, excp] if excp else [job]
try:
callback(*args)
except Exception as e:
self._logger.error(
f"An error occurred while processing the {callback_name} callback: {traceback.format_exception(e)}"
)
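
The helper above looks a callback up by name and logs, rather than propagates, any exception it raises. A standalone sketch of that dispatch pattern follows. The boolean return value and the `callback_sketch` logger name are additions for this example and are not part of the PR.

```python
import logging
from types import SimpleNamespace
from typing import Any, Optional

logger = logging.getLogger("callback_sketch")


def execute_cb(job: Any, callback_name: str, excp: Optional[Exception] = None) -> bool:
    """Call job.<callback_name> if it is set; log callback errors instead of raising.

    Returns True only if a callback ran without raising.
    """
    callback = getattr(job, callback_name, None)
    if callback is None:
        return False
    # Error callbacks receive the exception as a second argument
    args = [job, excp] if excp else [job]
    try:
        callback(*args)
        return True
    except Exception:
        logger.exception("An error occurred while processing the %s callback", callback_name)
        return False
```

A job here can be any object with optional callback attributes, for example `SimpleNamespace(on_start=print)`.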
def get_pc_name_max(directory: str) -> int:
if hasattr(os, "pathconf"):

@ -13,7 +13,7 @@ from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager.config import AnyModelConfig
from invokeai.backend.model_manager import AnyModelConfig
class ModelInstallServiceBase(ABC):
@ -243,12 +243,11 @@ class ModelInstallServiceBase(ABC):
"""
@abstractmethod
def download_and_cache(self, source: Union[str, AnyHttpUrl], access_token: Optional[str] = None) -> Path:
def download_and_cache_model(self, source: str | AnyHttpUrl) -> Path:
"""
Download the model file located at source to the models cache and return its Path.
:param source: A Url or a string that can be converted into one.
:param access_token: Optional access token to access restricted resources.
:param source: A string representing a URL or repo_id.
The model file will be downloaded into the system-wide model cache
(`models/.download_cache`) if it isn't already there. Note that the model cache

@ -8,7 +8,7 @@ from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic.networks import AnyHttpUrl
from typing_extensions import Annotated
from invokeai.app.services.download import DownloadJob
from invokeai.app.services.download import DownloadJob, MultiFileDownloadJob
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
from invokeai.backend.model_manager.config import ModelSourceType
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
@ -26,13 +26,6 @@ class InstallStatus(str, Enum):
CANCELLED = "cancelled" # terminated by cancellation
class ModelInstallPart(BaseModel):
url: AnyHttpUrl
path: Path
bytes: int = 0
total_bytes: int = 0
class UnknownInstallJobException(Exception):
"""Raised when the status of an unknown job is requested."""
@ -169,6 +162,7 @@ class ModelInstallJob(BaseModel):
)
# internal flags and transitory settings
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
_multifile_job: Optional[MultiFileDownloadJob] = PrivateAttr(default=None)
_exception: Optional[Exception] = PrivateAttr(default=None)
def set_error(self, e: Exception) -> None:

@ -5,21 +5,22 @@ import os
import re
import threading
import time
from hashlib import sha256
from pathlib import Path
from queue import Empty, Queue
from shutil import copyfile, copytree, move, rmtree
from tempfile import mkdtemp
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
import torch
import yaml
from huggingface_hub import HfFolder
from pydantic.networks import AnyHttpUrl
from pydantic_core import Url
from requests import Session
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress
from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
@ -44,6 +45,7 @@ from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.util import InvokeAILogger
from invokeai.backend.util.catch_sigint import catch_sigint
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.util import slugify
from .model_install_common import (
MODEL_SOURCE_TO_TYPE_MAP,
@ -91,7 +93,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._downloads_changed_event = threading.Event()
self._install_completed_event = threading.Event()
self._download_queue = download_queue
self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {}
self._download_cache: Dict[int, ModelInstallJob] = {}
self._running = False
self._session = session
self._install_thread: Optional[threading.Thread] = None
@ -210,33 +212,12 @@ class ModelInstallService(ModelInstallServiceBase):
access_token: Optional[str] = None,
inplace: Optional[bool] = False,
) -> ModelInstallJob:
variants = "|".join(ModelRepoVariant.__members__.values())
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
source_obj: Optional[StringLikeSource] = None
if Path(source).exists(): # A local file or directory
source_obj = LocalModelSource(path=Path(source), inplace=inplace)
elif match := re.match(hf_repoid_re, source):
source_obj = HFModelSource(
repo_id=match.group(1),
variant=match.group(2) if match.group(2) else None, # pass None rather than ''
subfolder=Path(match.group(3)) if match.group(3) else None,
access_token=access_token,
)
elif re.match(r"^https?://[^/]+", source):
# Pull the token from config if it exists and matches the URL
_token = access_token
if _token is None:
for pair in self.app_config.remote_api_tokens or []:
if re.search(pair.url_regex, source):
_token = pair.token
break
source_obj = URLModelSource(
url=AnyHttpUrl(source),
access_token=_token,
)
else:
raise ValueError(f"Unsupported model source: '{source}'")
"""Install a model using pattern matching to infer the type of source."""
source_obj = self._guess_source(source)
if isinstance(source_obj, LocalModelSource):
source_obj.inplace = inplace
elif isinstance(source_obj, HFModelSource) or isinstance(source_obj, URLModelSource):
source_obj.access_token = access_token
return self.import_model(source_obj, config)
def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102
@ -297,8 +278,9 @@ class ModelInstallService(ModelInstallServiceBase):
def cancel_job(self, job: ModelInstallJob) -> None:
"""Cancel the indicated job."""
job.cancel()
with self._lock:
self._cancel_download_parts(job)
self._logger.warning(f"Cancelling {job.source}")
if dj := job._multifile_job:
self._download_queue.cancel_job(dj)
def prune_jobs(self) -> None:
"""Prune all completed and errored jobs."""
@ -346,7 +328,7 @@ class ModelInstallService(ModelInstallServiceBase):
legacy_config_path = stanza.get("config")
if legacy_config_path:
# In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir.
legacy_config_path: Path = self._app_config.root_path / legacy_config_path
legacy_config_path = self._app_config.root_path / legacy_config_path
if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path):
legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path)
config["config_path"] = str(legacy_config_path)
@ -386,38 +368,92 @@ class ModelInstallService(ModelInstallServiceBase):
rmtree(model_path)
self.unregister(key)
def download_and_cache(
@classmethod
def _download_cache_path(cls, source: Union[str, AnyHttpUrl], app_config: InvokeAIAppConfig) -> Path:
escaped_source = slugify(str(source))
return app_config.download_cache_path / escaped_source
def download_and_cache_model(
self,
source: Union[str, AnyHttpUrl],
access_token: Optional[str] = None,
timeout: int = 0,
source: str | AnyHttpUrl,
) -> Path:
"""Download the model file located at source to the models cache and return its Path."""
model_hash = sha256(str(source).encode("utf-8")).hexdigest()[0:32]
model_path = self._app_config.convert_cache_path / model_hash
model_path = self._download_cache_path(str(source), self._app_config)
# We expect the cache directory to contain one and only one downloaded file.
# We expect the cache directory to contain one and only one downloaded file or directory.
# We don't know the file's name in advance, as it is set by the download
# content-disposition header.
if model_path.exists():
contents = [x for x in model_path.iterdir() if x.is_file()]
contents: List[Path] = list(model_path.iterdir())
if len(contents) > 0:
return contents[0]
model_path.mkdir(parents=True, exist_ok=True)
job = self._download_queue.download(
source=AnyHttpUrl(str(source)),
model_source = self._guess_source(str(source))
remote_files, _ = self._remote_files_from_source(model_source)
job = self._multifile_download(
dest=model_path,
access_token=access_token,
on_progress=TqdmProgress().update,
remote_files=remote_files,
subfolder=model_source.subfolder if isinstance(model_source, HFModelSource) else None,
)
self._download_queue.wait_for_job(job, timeout)
files_string = "file" if len(remote_files) == 1 else "files"
self._logger.info(f"Queuing model download: {source} ({len(remote_files)} {files_string})")
self._download_queue.wait_for_job(job)
if job.complete:
assert job.download_path is not None
return job.download_path
else:
raise Exception(job.error)
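
The cache check in `download_and_cache_model()` can be sketched in isolation as follows. `cached_fetch()` and its `fetch` callable are hypothetical names. The sketch sorts the directory listing for a deterministic result, which the method above does not do.

```python
from pathlib import Path
from typing import Callable


def cached_fetch(cache_dir: Path, fetch: Callable[[Path], Path]) -> Path:
    """Return the first entry in cache_dir, calling fetch(cache_dir) only on a miss."""
    if cache_dir.exists():
        contents = sorted(cache_dir.iterdir())
        if contents:
            # Cache hit: the directory already holds a downloaded file or folder
            return contents[0]
    cache_dir.mkdir(parents=True, exist_ok=True)
    # Cache miss: the caller-supplied function downloads into cache_dir
    return fetch(cache_dir)
```

The second call for the same source returns the cached path without invoking `fetch` again.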
def _remote_files_from_source(
self, source: ModelSource
) -> Tuple[List[RemoteModelFile], Optional[AnyModelRepoMetadata]]:
metadata = None
if isinstance(source, HFModelSource):
metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant)
assert isinstance(metadata, ModelMetadataWithFiles)
return metadata.download_urls(
variant=source.variant or self._guess_variant(),
subfolder=source.subfolder,
session=self._session,
), metadata
if isinstance(source, URLModelSource):
try:
fetcher = self.get_fetcher_from_url(str(source.url))
kwargs: dict[str, Any] = {"session": self._session}
metadata = fetcher(**kwargs).from_url(source.url)
assert isinstance(metadata, ModelMetadataWithFiles)
return metadata.download_urls(session=self._session), metadata
except ValueError:
pass
return [RemoteModelFile(url=source.url, path=Path("."), size=0)], None
raise Exception(f"No files associated with {source}")
def _guess_source(self, source: str) -> ModelSource:
"""Turn a source string into a ModelSource object."""
variants = "|".join(ModelRepoVariant.__members__.values())
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
source_obj: Optional[StringLikeSource] = None
if Path(source).exists(): # A local file or directory
source_obj = LocalModelSource(path=Path(source))
elif match := re.match(hf_repoid_re, source):
source_obj = HFModelSource(
repo_id=match.group(1),
variant=ModelRepoVariant(match.group(2)) if match.group(2) else None, # pass None rather than ''
subfolder=Path(match.group(3)) if match.group(3) else None,
)
elif re.match(r"^https?://[^/]+", source):
source_obj = URLModelSource(
url=Url(source),
)
else:
raise ValueError(f"Unsupported model source: '{source}'")
return source_obj
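
A standalone sketch of the pattern matching used to tell a HuggingFace `repo_id[:variant[:subfolder]]` string from a URL. It omits the local-path check, and `VARIANTS` is a hypothetical subset of names, since the real list is built from `ModelRepoVariant`. `classify_source()` is an illustrative name that does not appear in the PR.

```python
import re
from typing import Optional, Tuple

# Hypothetical subset of variant names, for illustration only
VARIANTS = ["fp16", "fp32", "onnx"]
HF_REPO_ID_RE = re.compile(
    "^([^/:]+/[^/:]+)(?::(" + "|".join(VARIANTS) + ")?(?::/?([^:]+))?)?$"
)
URL_RE = re.compile(r"^https?://[^/]+")


def classify_source(source: str) -> Tuple[str, str, Optional[str], Optional[str]]:
    """Return (kind, repo_id_or_url, variant, subfolder) for a source string."""
    if match := HF_REPO_ID_RE.match(source):
        # Empty or missing groups are normalized to None
        return ("hf", match.group(1), match.group(2) or None, match.group(3) or None)
    if URL_RE.match(source):
        return ("url", source, None, None)
    raise ValueError(f"Unsupported model source: '{source}'")
```

A URL never matches the repo_id pattern because the first path component may not contain a colon.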
# --------------------------------------------------------------------------------------------
# Internal functions that manage the installer threads
# --------------------------------------------------------------------------------------------
@ -478,16 +514,19 @@ class ModelInstallService(ModelInstallServiceBase):
job.config_out = self.record_store.get_model(key)
self._signal_job_completed(job)
def _set_error(self, job: ModelInstallJob, excp: Exception) -> None:
if any(x.content_type is not None and "text/html" in x.content_type for x in job.download_parts):
job.set_error(
def _set_error(self, install_job: ModelInstallJob, excp: Exception) -> None:
multifile_download_job = install_job._multifile_job
if multifile_download_job and any(
x.content_type is not None and "text/html" in x.content_type for x in multifile_download_job.download_parts
):
install_job.set_error(
InvalidModelConfigException(
f"At least one file in {job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
f"At least one file in {install_job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
)
)
else:
job.set_error(excp)
self._signal_job_errored(job)
install_job.set_error(excp)
self._signal_job_errored(install_job)
# --------------------------------------------------------------------------------------------
# Internal functions that manage the models directory
@ -513,7 +552,6 @@ class ModelInstallService(ModelInstallServiceBase):
This is typically only used during testing with a new DB or when using the memory DB, because those are the
only situations in which we may have orphaned models in the models directory.
"""
installed_model_paths = {
(self._app_config.models_path / x.path).resolve() for x in self.record_store.all_models()
}
@ -525,8 +563,13 @@ class ModelInstallService(ModelInstallServiceBase):
if resolved_path in installed_model_paths:
return True
# Skip core models entirely - these aren't registered with the model manager.
if str(resolved_path).startswith(str(self.app_config.models_path / "core")):
return False
for special_directory in [
self.app_config.models_path / "core",
self.app_config.convert_cache_dir,
self.app_config.download_cache_dir,
]:
if resolved_path.is_relative_to(special_directory):
return False
try:
model_id = self.register_path(model_path)
self._logger.info(f"Registered {model_path.name} with id {model_id}")
@ -641,20 +684,15 @@ class ModelInstallService(ModelInstallServiceBase):
inplace=source.inplace or False,
)
def _import_from_hf(self, source: HFModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
def _import_from_hf(
self,
source: HFModelSource,
config: Optional[Dict[str, Any]] = None,
) -> ModelInstallJob:
# Add user's cached access token to HuggingFace requests
source.access_token = source.access_token or HfFolder.get_token()
if not source.access_token:
self._logger.info("No HuggingFace access token present; some models may not be downloadable.")
metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant)
assert isinstance(metadata, ModelMetadataWithFiles)
remote_files = metadata.download_urls(
variant=source.variant or self._guess_variant(),
subfolder=source.subfolder,
session=self._session,
)
if source.access_token is None:
source.access_token = HfFolder.get_token()
remote_files, metadata = self._remote_files_from_source(source)
return self._import_remote_model(
source=source,
config=config,
@ -662,22 +700,12 @@ class ModelInstallService(ModelInstallServiceBase):
metadata=metadata,
)
def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
# URLs from HuggingFace will be handled specially
metadata = None
fetcher = None
try:
fetcher = self.get_fetcher_from_url(str(source.url))
except ValueError:
pass
kwargs: dict[str, Any] = {"session": self._session}
if fetcher is not None:
metadata = fetcher(**kwargs).from_url(source.url)
self._logger.debug(f"metadata={metadata}")
if metadata and isinstance(metadata, ModelMetadataWithFiles):
remote_files = metadata.download_urls(session=self._session)
else:
remote_files = [RemoteModelFile(url=source.url, path=Path("."), size=0)]
def _import_from_url(
self,
source: URLModelSource,
config: Optional[Dict[str, Any]],
) -> ModelInstallJob:
remote_files, metadata = self._remote_files_from_source(source)
return self._import_remote_model(
source=source,
config=config,
@ -692,12 +720,9 @@ class ModelInstallService(ModelInstallServiceBase):
metadata: Optional[AnyModelRepoMetadata],
config: Optional[Dict[str, Any]],
) -> ModelInstallJob:
# TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up.
# Currently the tmpdir isn't automatically removed at exit because it is
# being held in a daemon thread.
if len(remote_files) == 0:
raise ValueError(f"{source}: No downloadable files found")
tmpdir = Path(
destdir = Path(
mkdtemp(
dir=self._app_config.models_path,
prefix=TMPDIR_PREFIX,
@ -708,55 +733,28 @@ class ModelInstallService(ModelInstallServiceBase):
source=source,
config_in=config or {},
source_metadata=metadata,
local_path=tmpdir, # local path may change once the download has started due to content-disposition handling
local_path=destdir, # local path may change once the download has started due to content-disposition handling
bytes=0,
total_bytes=0,
)
# In the event that there is a subfolder specified in the source,
# we need to remove it from the destination path in order to avoid
# creating unwanted subfolders
if isinstance(source, HFModelSource) and source.subfolder:
root = Path(remote_files[0].path.parts[0])
subfolder = root / source.subfolder
else:
root = Path(".")
subfolder = Path(".")
# remember the temporary directory for later removal
install_job._install_tmpdir = destdir
install_job.total_bytes = sum((x.size or 0) for x in remote_files)
# we remember the path up to the top of the tmpdir so that it may be
# removed safely at the end of the install process.
install_job._install_tmpdir = tmpdir
assert install_job.total_bytes is not None # to avoid type checking complaints in the loop below
multifile_job = self._multifile_download(
remote_files=remote_files,
dest=destdir,
subfolder=source.subfolder if isinstance(source, HFModelSource) else None,
access_token=source.access_token,
submit_job=False, # Important! Don't submit the job until we have set our _download_cache dict
)
self._download_cache[multifile_job.id] = install_job
install_job._multifile_job = multifile_job
files_string = "file" if len(remote_files) == 1 else "file"
self._logger.info(f"Queuing model install: {source} ({len(remote_files)} {files_string})")
files_string = "file" if len(remote_files) == 1 else "files"
self._logger.info(f"Queueing model install: {source} ({len(remote_files)} {files_string})")
self._logger.debug(f"remote_files={remote_files}")
for model_file in remote_files:
url = model_file.url
path = root / model_file.path.relative_to(subfolder)
self._logger.debug(f"Downloading {url} => {path}")
install_job.total_bytes += model_file.size
assert hasattr(source, "access_token")
dest = tmpdir / path.parent
dest.mkdir(parents=True, exist_ok=True)
download_job = DownloadJob(
source=url,
dest=dest,
access_token=source.access_token,
)
self._download_cache[download_job.source] = install_job # matches a download job to an install job
install_job.download_parts.add(download_job)
# only start the jobs once install_job.download_parts is fully populated
for download_job in install_job.download_parts:
self._download_queue.submit_download_job(
download_job,
on_start=self._download_started_callback,
on_progress=self._download_progress_callback,
on_complete=self._download_complete_callback,
on_error=self._download_error_callback,
on_cancelled=self._download_cancelled_callback,
)
self._download_queue.submit_multifile_download(multifile_job)
return install_job
def _stat_size(self, path: Path) -> int:
@ -768,87 +766,104 @@ class ModelInstallService(ModelInstallServiceBase):
size += sum(self._stat_size(Path(root, x)) for x in files)
return size
def _multifile_download(
self,
remote_files: List[RemoteModelFile],
dest: Path,
subfolder: Optional[Path] = None,
access_token: Optional[str] = None,
submit_job: bool = True,
) -> MultiFileDownloadJob:
# HuggingFace repo subfolders are a little tricky. If the name of the model is "sdxl-turbo", and
# we are installing the "vae" subfolder, we do not want to create an additional folder level, such
# as "sdxl-turbo/vae", nor do we want to put the contents of the vae folder directly into "sdxl-turbo".
# So what we do is to synthesize a folder named "sdxl-turbo_vae" here.
if subfolder:
top = Path(remote_files[0].path.parts[0]) # e.g. "sdxl-turbo/"
path_to_remove = top / subfolder.parts[-1] # sdxl-turbo/vae/
path_to_add = Path(f"{top}_{subfolder}")
else:
path_to_remove = Path(".")
path_to_add = Path(".")
parts: List[RemoteModelFile] = []
for model_file in remote_files:
assert model_file.size is not None
parts.append(
RemoteModelFile(
url=model_file.url,
# if a subfolder, then the path becomes e.g. sdxl-turbo_vae/config.json
path=path_to_add / model_file.path.relative_to(path_to_remove),
)
)
return self._download_queue.multifile_download(
parts=parts,
dest=dest,
access_token=access_token,
submit_job=submit_job,
on_start=self._download_started_callback,
on_progress=self._download_progress_callback,
on_complete=self._download_complete_callback,
on_error=self._download_error_callback,
on_cancelled=self._download_cancelled_callback,
)
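The subfolder handling described in the comment at the top of `_multifile_download()` can be tried in isolation. What follows is a minimal sketch rather than the service's own code: `rewrite_path` is a hypothetical helper, the example paths are assumptions chosen to match the "sdxl-turbo"/"vae" case in the comment, and it derives the top-level folder from each path instead of from `remote_files[0]`.

```python
from pathlib import Path
from typing import Optional


def rewrite_path(remote_path: Path, subfolder: Optional[Path] = None) -> Path:
    """Hypothetical helper mirroring the path arithmetic in _multifile_download()."""
    if not subfolder:
        # No subfolder requested: the remote layout is kept as-is.
        return remote_path
    top = Path(remote_path.parts[0])            # e.g. "sdxl-turbo"
    path_to_remove = top / subfolder.parts[-1]  # e.g. "sdxl-turbo/vae"
    path_to_add = Path(f"{top}_{subfolder}")    # e.g. "sdxl-turbo_vae"
    return path_to_add / remote_path.relative_to(path_to_remove)


print(rewrite_path(Path("sdxl-turbo/vae/config.json"), Path("vae")))
print(rewrite_path(Path("sdxl-turbo/model_index.json")))
```

The sketch only covers a single-level subfolder, which is the case the comment describes.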
# ------------------------------------------------------------------
# Callbacks are executed by the download queue in a separate thread
# ------------------------------------------------------------------
def _download_started_callback(self, download_job: DownloadJob) -> None:
self._logger.info(f"Model download started: {download_job.source}")
def _download_started_callback(self, download_job: MultiFileDownloadJob) -> None:
with self._lock:
install_job = self._download_cache[download_job.source]
install_job.status = InstallStatus.DOWNLOADING
if install_job := self._download_cache.get(download_job.id, None):
install_job.status = InstallStatus.DOWNLOADING
assert download_job.download_path
if install_job.local_path == install_job._install_tmpdir:
partial_path = download_job.download_path.relative_to(install_job._install_tmpdir)
dest_name = partial_path.parts[0]
install_job.local_path = install_job._install_tmpdir / dest_name
# Update the total bytes count for remote sources.
if not install_job.total_bytes:
install_job.total_bytes = sum(x.total_bytes for x in install_job.download_parts)
def _download_progress_callback(self, download_job: DownloadJob) -> None:
with self._lock:
install_job = self._download_cache[download_job.source]
if install_job.cancelled: # This catches the case in which the caller directly calls job.cancel()
self._cancel_download_parts(install_job)
else:
# update sizes
install_job.bytes = sum(x.bytes for x in install_job.download_parts)
if install_job.local_path == install_job._install_tmpdir: # first time
assert download_job.download_path
install_job.local_path = download_job.download_path
install_job.download_parts = download_job.download_parts
install_job.bytes = sum(x.bytes for x in download_job.download_parts)
install_job.total_bytes = download_job.total_bytes
self._signal_job_downloading(install_job)
def _download_complete_callback(self, download_job: DownloadJob) -> None:
self._logger.info(f"Model download complete: {download_job.source}")
def _download_progress_callback(self, download_job: MultiFileDownloadJob) -> None:
with self._lock:
install_job = self._download_cache[download_job.source]
if install_job := self._download_cache.get(download_job.id, None):
if install_job.cancelled: # This catches the case in which the caller directly calls job.cancel()
self._download_queue.cancel_job(download_job)
else:
# update sizes
install_job.bytes = sum(x.bytes for x in download_job.download_parts)
install_job.total_bytes = sum(x.total_bytes for x in download_job.download_parts)
self._signal_job_downloading(install_job)
# are there any more active jobs left in this task?
if install_job.downloading and all(x.complete for x in install_job.download_parts):
def _download_complete_callback(self, download_job: MultiFileDownloadJob) -> None:
with self._lock:
if install_job := self._download_cache.pop(download_job.id, None):
self._signal_job_downloads_done(install_job)
self._put_in_queue(install_job)
self._put_in_queue(install_job) # this starts the installation and registration
# Let other threads know that the number of downloads has changed
self._download_cache.pop(download_job.source, None)
self._downloads_changed_event.set()
# Let other threads know that the number of downloads has changed
self._downloads_changed_event.set()
def _download_error_callback(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
def _download_error_callback(self, download_job: MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
with self._lock:
install_job = self._download_cache.pop(download_job.source, None)
assert install_job is not None
assert excp is not None
install_job.set_error(excp)
self._logger.error(
f"Cancelling {install_job.source} due to an error while downloading {download_job.source}: {str(excp)}"
)
self._cancel_download_parts(install_job)
if install_job := self._download_cache.pop(download_job.id, None):
assert excp is not None
install_job.set_error(excp)
self._download_queue.cancel_job(download_job)
# Let other threads know that the number of downloads has changed
self._downloads_changed_event.set()
# Let other threads know that the number of downloads has changed
self._downloads_changed_event.set()
def _download_cancelled_callback(self, download_job: DownloadJob) -> None:
def _download_cancelled_callback(self, download_job: MultiFileDownloadJob) -> None:
with self._lock:
install_job = self._download_cache.pop(download_job.source, None)
if not install_job:
return
self._downloads_changed_event.set()
self._logger.warning(f"Model download canceled: {download_job.source}")
# if install job has already registered an error, then do not replace its status with cancelled
if not install_job.errored:
install_job.cancel()
self._cancel_download_parts(install_job)
if install_job := self._download_cache.pop(download_job.id, None):
self._downloads_changed_event.set()
# if install job has already registered an error, then do not replace its status with cancelled
if not install_job.errored:
install_job.cancel()
# Let other threads know that the number of downloads has changed
self._downloads_changed_event.set()
def _cancel_download_parts(self, install_job: ModelInstallJob) -> None:
# on multipart downloads, _cancel_download_parts() will get called repeatedly from the download callbacks
# do not lock here because it gets called within a locked context
for s in install_job.download_parts:
self._download_queue.cancel_job(s)
if all(x.in_terminal_state for x in install_job.download_parts):
# When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources
self._put_in_queue(install_job)
# Let other threads know that the number of downloads has changed
self._downloads_changed_event.set()
# ------------------------------------------------------------------------------------------------
# Internal methods that put events on the event bus
@ -861,6 +876,9 @@ class ModelInstallService(ModelInstallServiceBase):
def _signal_job_downloading(self, job: ModelInstallJob) -> None:
if self._event_bus:
assert job._multifile_job is not None
assert job.bytes is not None
assert job.total_bytes is not None
self._event_bus.emit_model_install_download_progress(job)
def _signal_job_downloads_done(self, job: ModelInstallJob) -> None:
@ -875,6 +893,8 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.info(f"Model install complete: {job.source}")
self._logger.debug(f"{job.local_path} registered key {job.config_out.key}")
if self._event_bus:
assert job.local_path is not None
assert job.config_out is not None
self._event_bus.emit_model_install_complete(job)
def _signal_job_errored(self, job: ModelInstallJob) -> None:
@ -890,7 +910,13 @@ class ModelInstallService(ModelInstallServiceBase):
self._event_bus.emit_model_install_cancelled(job)
@staticmethod
def get_fetcher_from_url(url: str) -> ModelMetadataFetchBase:
def get_fetcher_from_url(url: str) -> Type[ModelMetadataFetchBase]:
"""
Return a metadata fetcher appropriate for the provided URL.
This used to be more useful, but the number of supported model
sources has been reduced to HuggingFace alone.
"""
if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
return HuggingFaceMetadataFetch
raise ValueError(f"Unsupported model source: '{url}'")
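To see which URLs `get_fetcher_from_url()` treats as HuggingFace repositories, the pattern can be exercised on its own. This is a small sketch under stated assumptions: `is_hf_repo_url` is a hypothetical helper, and the URLs are examples rather than values taken from this change.

```python
import re

# Same pattern as in get_fetcher_from_url(); note the unescaped "." matches any character.
HF_REPO_URL = re.compile(r"^https?://huggingface.co/[^/]+/[^/]+$")


def is_hf_repo_url(url: str) -> bool:
    """Hypothetical helper: True if the URL has the bare owner/repo form."""
    return HF_REPO_URL.match(url.lower()) is not None


for candidate in (
    "https://huggingface.co/stabilityai/sdxl-turbo",
    "https://huggingface.co/stabilityai/sdxl-turbo/resolve/main/model.safetensors",
    "https://github.com/xinntao/Real-ESRGAN",
):
    print(candidate, is_hf_repo_url(candidate))
```

Only the bare `owner/repo` form matches. A direct file URL under `resolve/` has extra path segments, so the function would raise `ValueError` for it.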


@ -2,10 +2,11 @@
"""Base class for model loader."""
from abc import ABC, abstractmethod
from typing import Optional
from pathlib import Path
from typing import Callable, Optional
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import LoadedModel
from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
@ -31,3 +32,26 @@ class ModelLoadServiceBase(ABC):
@abstractmethod
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the checkpoint convert cache used by this loader."""
@abstractmethod
def load_model_from_path(
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
) -> LoadedModelWithoutConfig:
"""
Load the model file or directory located at the indicated Path.
This will load an arbitrary model file into the RAM cache. If the optional loader
argument is provided, the loader will be invoked to load the model into
memory. Otherwise the method will call safetensors.torch.load_file() or
torch.load() as appropriate to the file suffix.
Be aware that this returns a LoadedModelWithoutConfig object, which is the same as
LoadedModel, but without the config attribute.
Args:
model_path: A pathlib.Path to a checkpoint-style model file or a diffusers directory
loader: A Callable that expects a Path and returns an AnyModel
Returns:
A LoadedModelWithoutConfig object.
"""


@ -1,18 +1,26 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team
"""Implementation of model loader service."""
from typing import Optional, Type
from pathlib import Path
from typing import Callable, Optional, Type
from picklescan.scanner import scan_file_path
from safetensors.torch import load_file as safetensors_load_file
from torch import load as torch_load
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import (
LoadedModel,
LoadedModelWithoutConfig,
ModelLoaderRegistry,
ModelLoaderRegistryBase,
)
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
from .model_load_base import ModelLoadServiceBase
@ -75,3 +83,41 @@ class ModelLoadService(ModelLoadServiceBase):
self._invoker.services.events.emit_model_load_complete(model_config, submodel_type)
return loaded_model
def load_model_from_path(
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
) -> LoadedModelWithoutConfig:
cache_key = str(model_path)
ram_cache = self.ram_cache
try:
return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
except IndexError:
pass
def torch_load_file(checkpoint: Path) -> AnyModel:
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0:
raise Exception(f"The model at {checkpoint} is potentially infected by malware. Aborting load.")
result = torch_load(checkpoint, map_location="cpu")
return result
def diffusers_load_directory(directory: Path) -> AnyModel:
load_class = GenericDiffusersLoader(
app_config=self._app_config,
logger=self._logger,
ram_cache=self._ram_cache,
convert_cache=self.convert_cache,
).get_hf_load_class(directory)
return load_class.from_pretrained(directory, torch_dtype=TorchDevice.choose_torch_dtype())
loader = loader or (
diffusers_load_directory
if model_path.is_dir()
else torch_load_file
if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin"))
else lambda path: safetensors_load_file(path, device="cpu")
)
assert loader is not None
raw_model = loader(model_path)
ram_cache.put(key=cache_key, model=raw_model)
return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
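The default-loader selection in `load_model_from_path()` depends only on the path, so it can be sketched without torch or safetensors installed. `pick_loader_name` is a hypothetical helper that returns a label instead of a callable, and the file names are examples.

```python
import tempfile
from pathlib import Path


def pick_loader_name(model_path: Path) -> str:
    """Hypothetical helper: name the default loader the dispatch would select."""
    if model_path.is_dir():
        return "diffusers from_pretrained()"
    if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
        return "torch.load() after pickle scan"
    # Everything else, including unknown suffixes, falls through to safetensors.
    return "safetensors load_file()"


with tempfile.TemporaryDirectory() as tmp:
    print(pick_loader_name(Path(tmp)))
print(pick_loader_name(Path("RealESRGAN_x4plus.pth")))
print(pick_loader_name(Path("model.safetensors")))
```

Any suffix outside the pickle-style list is handed to the safetensors loader, so a caller with a different format would presumably want to pass its own `loader`.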


@ -12,15 +12,13 @@ from pydantic import BaseModel, Field
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from invokeai.backend.model_manager import (
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.config import (
ControlAdapterDefaultSettings,
MainModelDefaultSettings,
ModelFormat,
ModelType,
ModelVariantType,
SchedulerPredictionType,
)


@ -3,6 +3,7 @@ from pathlib import Path
from typing import TYPE_CHECKING, Callable, Optional, Union
from PIL.Image import Image
from pydantic.networks import AnyHttpUrl
from torch import Tensor
from invokeai.app.invocations.constants import IMAGE_MODES
@ -14,8 +15,15 @@ from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.model_records.model_records_base import UnknownModelException
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.model_manager.config import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
@ -320,8 +328,10 @@ class ConditioningInterface(InvocationContextInterface):
class ModelsInterface(InvocationContextInterface):
"""Common API for loading, downloading and managing models."""
def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool:
"""Checks if a model exists.
"""Check if a model exists.
Args:
identifier: The key or ModelField representing the model.
@ -331,13 +341,13 @@ class ModelsInterface(InvocationContextInterface):
"""
if isinstance(identifier, str):
return self._services.model_manager.store.exists(identifier)
return self._services.model_manager.store.exists(identifier.key)
else:
return self._services.model_manager.store.exists(identifier.key)
def load(
self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None
) -> LoadedModel:
"""Loads a model.
"""Load a model.
Args:
identifier: The key or ModelField representing the model.
@ -361,7 +371,7 @@ class ModelsInterface(InvocationContextInterface):
def load_by_attrs(
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
) -> LoadedModel:
"""Loads a model by its attributes.
"""Load a model by its attributes.
Args:
name: Name of the model.
@ -384,7 +394,7 @@ class ModelsInterface(InvocationContextInterface):
return self._services.model_manager.load.load_model(configs[0], submodel_type)
def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
"""Gets a model's config.
"""Get a model's config.
Args:
identifier: The key or ModelField representing the model.
@ -394,11 +404,11 @@ class ModelsInterface(InvocationContextInterface):
"""
if isinstance(identifier, str):
return self._services.model_manager.store.get_model(identifier)
return self._services.model_manager.store.get_model(identifier.key)
else:
return self._services.model_manager.store.get_model(identifier.key)
def search_by_path(self, path: Path) -> list[AnyModelConfig]:
"""Searches for models by path.
"""Search for models by path.
Args:
path: The path to search for.
@ -415,7 +425,7 @@ class ModelsInterface(InvocationContextInterface):
type: Optional[ModelType] = None,
format: Optional[ModelFormat] = None,
) -> list[AnyModelConfig]:
"""Searches for models by attributes.
"""Search for models by attributes.
Args:
name: The name to search for (exact match).
@ -434,6 +444,72 @@ class ModelsInterface(InvocationContextInterface):
model_format=format,
)
def download_and_cache_model(
self,
source: str | AnyHttpUrl,
) -> Path:
"""
Download the model file located at source to the models cache and return its Path.
This can be used to install single-file models and other resources of arbitrary types
that should not be registered with the database. If the model is already
installed, the cached path will be returned. Otherwise it will be downloaded.
Args:
source: A URL that points to the model, or a HuggingFace repo_id.
Returns:
Path to the downloaded model
"""
return self._services.model_manager.install.download_and_cache_model(source=source)
def load_local_model(
self,
model_path: Path,
loader: Optional[Callable[[Path], AnyModel]] = None,
) -> LoadedModelWithoutConfig:
"""
Load the model file located at the indicated path.
If a loader callable is provided, it will be invoked to load the model. Otherwise,
`safetensors.torch.load_file()`, `torch.load()` or `from_pretrained()` will be called
to load the model, as appropriate to the path type.
Be aware that the LoadedModelWithoutConfig object has no `config` attribute.
Args:
model_path: A model Path
loader: A Callable that expects a Path and returns an AnyModel
Returns:
A LoadedModelWithoutConfig object.
"""
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
def load_remote_model(
self,
source: str | AnyHttpUrl,
loader: Optional[Callable[[Path], AnyModel]] = None,
) -> LoadedModelWithoutConfig:
"""
Download, cache, and load the model file located at the indicated URL or repo_id.
If the model is already downloaded, it will be loaded from the cache.
If a loader callable is provided, it will be invoked to load the model. Otherwise,
`safetensors.torch.load_file()` or `torch.load()` will be called to load the model.
Be aware that the LoadedModelWithoutConfig object has no `config` attribute.
Args:
source: A URL or HuggingFace repo_id.
loader: A Callable that expects a Path and returns an AnyModel
Returns:
A LoadedModelWithoutConfig object.
"""
model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
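The cache-first flow behind `load_local_model()` and `load_remote_model()`, and the role of the optional `loader` callable, can be illustrated with plain Python. This is a sketch under loud assumptions: `load_from_path`, `_cache` and `json_loader` are hypothetical stand-ins, a dict replaces the model manager RAM cache, and a JSON file replaces real model weights.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

_cache: Dict[str, Any] = {}  # stand-in for the model manager RAM cache


def load_from_path(model_path: Path, loader: Optional[Callable[[Path], Any]] = None) -> Any:
    """Hypothetical stand-in for the cache-first flow of load_model_from_path()."""
    key = str(model_path)
    if key in _cache:
        return _cache[key]  # cache hit: the loader is not invoked again
    loader = loader or (lambda p: p.read_bytes())  # placeholder default loader
    _cache[key] = loader(model_path)
    return _cache[key]


calls: List[Path] = []


def json_loader(path: Path) -> Any:
    """Example custom loader; records each call so caching can be observed."""
    calls.append(path)
    return json.loads(path.read_text())


with tempfile.TemporaryDirectory() as tmp:
    weights = Path(tmp) / "weights.json"
    weights.write_text('{"scale": 4}')
    first = load_from_path(weights, loader=json_loader)
    second = load_from_path(weights, loader=json_loader)

print(first, len(calls))
```

Because the cache key is the path string, a second load of the same path returns the cached object even if a different `loader` is passed.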
class ConfigInterface(InvocationContextInterface):
def get(self) -> InvokeAIAppConfig:


@ -13,6 +13,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_8 import build_migration_8
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_9 import build_migration_9
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_10 import build_migration_10
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import build_migration_11
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@ -43,6 +44,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_8(app_config=config))
migrator.register_migration(build_migration_9())
migrator.register_migration(build_migration_10())
migrator.register_migration(build_migration_11(app_config=config, logger=logger))
migrator.run_migrations()
return db


@ -0,0 +1,75 @@
import shutil
import sqlite3
from logging import Logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
LEGACY_CORE_MODELS = [
# OpenPose
"any/annotators/dwpose/yolox_l.onnx",
"any/annotators/dwpose/dw-ll_ucoco_384.onnx",
# DepthAnything
"any/annotators/depth_anything/depth_anything_vitl14.pth",
"any/annotators/depth_anything/depth_anything_vitb14.pth",
"any/annotators/depth_anything/depth_anything_vits14.pth",
# Lama inpaint
"core/misc/lama/lama.pt",
# RealESRGAN upscale
"core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
"core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
"core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
"core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
]
class Migration11Callback:
def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None:
self._app_config = app_config
self._logger = logger
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._remove_convert_cache()
self._remove_downloaded_models()
self._remove_unused_core_models()
def _remove_convert_cache(self) -> None:
"""Remove the legacy models/.cache directory; converted models are now cached in models/.convert_cache."""
self._logger.info("Removing .cache directory. Converted models will now be cached in .convert_cache.")
legacy_convert_path = self._app_config.root_path / "models" / ".cache"
shutil.rmtree(legacy_convert_path, ignore_errors=True)
def _remove_downloaded_models(self) -> None:
"""Remove models from their old locations; they will re-download when needed."""
self._logger.info(
"Removing legacy just-in-time models. Downloaded models will now be cached in .download_cache."
)
for model_path in LEGACY_CORE_MODELS:
legacy_dest_path = self._app_config.models_path / model_path
legacy_dest_path.unlink(missing_ok=True)
def _remove_unused_core_models(self) -> None:
"""Remove unused core models and their directories."""
self._logger.info("Removing defunct core models.")
for dir in ["face_restoration", "misc", "upscaling"]:
path_to_remove = self._app_config.models_path / "core" / dir
shutil.rmtree(path_to_remove, ignore_errors=True)
shutil.rmtree(self._app_config.models_path / "any" / "annotators", ignore_errors=True)
def build_migration_11(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
"""
Build the migration from database version 10 to 11.
This migration does the following:
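- Removes "core" models previously downloaded with download_with_progress_bar(); they will be
re-downloaded into the new "models/.download_cache" directory when needed.
- Removes "models/.cache"; converted models are now cached in "models/.convert_cache".
- Removes unused core model directories.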
"""
migration_11 = Migration(
from_version=10,
to_version=11,
callback=Migration11Callback(app_config=app_config, logger=logger),
)
return migration_11
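The migration callback relies on `Path.unlink(missing_ok=True)` and `shutil.rmtree(..., ignore_errors=True)`, which makes the cleanup safe to run when the legacy files are already gone. The sketch below demonstrates that property on a temporary directory. `remove_legacy` and the file name `old.pth` are hypothetical and do not touch a real InvokeAI install.

```python
import shutil
import tempfile
from pathlib import Path


def remove_legacy(models_path: Path) -> None:
    """Hypothetical cleanup using the same calls as Migration11Callback."""
    (models_path / "core" / "upscaling" / "old.pth").unlink(missing_ok=True)
    shutil.rmtree(models_path / ".cache", ignore_errors=True)


with tempfile.TemporaryDirectory() as tmp:
    models = Path(tmp)
    (models / ".cache").mkdir()
    (models / "core" / "upscaling").mkdir(parents=True)
    (models / "core" / "upscaling" / "old.pth").write_bytes(b"")
    remove_legacy(models)
    remove_legacy(models)  # second run finds nothing to delete and does not raise
    leftover = sorted(p.name for p in models.iterdir())

print(leftover)
```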


@ -1,51 +0,0 @@
from pathlib import Path
from urllib import request
from tqdm import tqdm
from invokeai.backend.util.logging import InvokeAILogger
class ProgressBar:
"""Simple progress bar for urllib.request.urlretrieve using tqdm."""
def __init__(self, model_name: str = "file"):
self.pbar = None
self.name = model_name
def __call__(self, block_num: int, block_size: int, total_size: int):
if not self.pbar:
self.pbar = tqdm(
desc=self.name,
initial=0,
unit="iB",
unit_scale=True,
unit_divisor=1000,
total=total_size,
)
self.pbar.update(block_size)
def download_with_progress_bar(name: str, url: str, dest_path: Path) -> bool:
"""Download a file from a URL to a destination path, with a progress bar.
If the file already exists, it will not be downloaded again.
Exceptions are not caught.
Args:
name (str): Name of the file being downloaded.
url (str): URL to download the file from.
dest_path (Path): Destination path to save the file to.
Returns:
bool: True if the file was downloaded, False if it already existed.
"""
if dest_path.exists():
return False # already downloaded
InvokeAILogger.get_logger().info(f"Downloading {name}...")
dest_path.parent.mkdir(parents=True, exist_ok=True)
request.urlretrieve(url, dest_path, ProgressBar(name))
return True


@ -1,5 +1,5 @@
import pathlib
from typing import Literal, Union
from pathlib import Path
from typing import Literal
import cv2
import numpy as np
@ -10,28 +10,17 @@ from PIL import Image
from torchvision.transforms import Compose
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
config = get_config()
logger = InvokeAILogger.get_logger(config=config)
DEPTH_ANYTHING_MODELS = {
"large": {
"url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
"local": "any/annotators/depth_anything/depth_anything_vitl14.pth",
},
"base": {
"url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
"local": "any/annotators/depth_anything/depth_anything_vitb14.pth",
},
"small": {
"url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
"local": "any/annotators/depth_anything/depth_anything_vits14.pth",
},
"large": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
"base": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
"small": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
}
@ -53,36 +42,27 @@ transform = Compose(
class DepthAnythingDetector:
def __init__(self) -> None:
self.model = None
self.model_size: Union[Literal["large", "base", "small"], None] = None
self.device = TorchDevice.choose_torch_device()
def __init__(self, model: DPT_DINOv2, device: torch.device) -> None:
self.model = model
self.device = device
def load_model(self, model_size: Literal["large", "base", "small"] = "small"):
DEPTH_ANYTHING_MODEL_PATH = config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"]
download_with_progress_bar(
pathlib.Path(DEPTH_ANYTHING_MODELS[model_size]["url"]).name,
DEPTH_ANYTHING_MODELS[model_size]["url"],
DEPTH_ANYTHING_MODEL_PATH,
)
@staticmethod
def load_model(
model_path: Path, device: torch.device, model_size: Literal["large", "base", "small"] = "small"
) -> DPT_DINOv2:
match model_size:
case "small":
model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
case "base":
model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
case "large":
model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
if not self.model or model_size != self.model_size:
del self.model
self.model_size = model_size
model.load_state_dict(torch.load(model_path.as_posix(), map_location="cpu"))
model.eval()
match self.model_size:
case "small":
self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
case "base":
self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
case "large":
self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
self.model.eval()
self.model.to(self.device)
return self.model
model.to(device)
return model
def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:
if not self.model:


@ -1,30 +1,53 @@
from pathlib import Path
from typing import Dict
import numpy as np
import torch
from controlnet_aux.util import resize_image
from PIL import Image
from invokeai.backend.image_util.dw_openpose.utils import draw_bodypose, draw_facepose, draw_handpose
from invokeai.backend.image_util.dw_openpose.utils import NDArrayInt, draw_bodypose, draw_facepose, draw_handpose
from invokeai.backend.image_util.dw_openpose.wholebody import Wholebody
DWPOSE_MODELS = {
"yolox_l.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true",
"dw-ll_ucoco_384.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true",
}
def draw_pose(pose, H, W, draw_face=True, draw_body=True, draw_hands=True, resolution=512):
def draw_pose(
pose: Dict[str, NDArrayInt | Dict[str, NDArrayInt]],
H: int,
W: int,
draw_face: bool = True,
draw_body: bool = True,
draw_hands: bool = True,
resolution: int = 512,
) -> Image.Image:
bodies = pose["bodies"]
faces = pose["faces"]
hands = pose["hands"]
assert isinstance(bodies, dict)
candidate = bodies["candidate"]
assert isinstance(bodies, dict)
subset = bodies["subset"]
canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
if draw_body:
canvas = draw_bodypose(canvas, candidate, subset)
if draw_hands:
assert isinstance(hands, np.ndarray)
canvas = draw_handpose(canvas, hands)
if draw_face:
canvas = draw_facepose(canvas, faces)
assert isinstance(faces, np.ndarray)
canvas = draw_facepose(canvas, faces) # type: ignore
dwpose_image = resize_image(
dwpose_image: Image.Image = resize_image(
canvas,
resolution,
)
@ -39,11 +62,16 @@ class DWOpenposeDetector:
Credits: https://github.com/IDEA-Research/DWPose
"""
def __init__(self) -> None:
self.pose_estimation = Wholebody()
def __init__(self, onnx_det: Path, onnx_pose: Path) -> None:
self.pose_estimation = Wholebody(onnx_det=onnx_det, onnx_pose=onnx_pose)
def __call__(
self, image: Image.Image, draw_face=False, draw_body=True, draw_hands=False, resolution=512
self,
image: Image.Image,
draw_face: bool = False,
draw_body: bool = True,
draw_hands: bool = False,
resolution: int = 512,
) -> Image.Image:
np_image = np.array(image)
H, W, C = np_image.shape
@ -79,3 +107,6 @@ class DWOpenposeDetector:
return draw_pose(
pose, H, W, draw_face=draw_face, draw_hands=draw_hands, draw_body=draw_body, resolution=resolution
)
__all__ = ["DWPOSE_MODELS", "DWOpenposeDetector"]

@ -5,11 +5,13 @@ import math
import cv2
import matplotlib
import numpy as np
import numpy.typing as npt
eps = 0.01
NDArrayInt = npt.NDArray[np.uint8]
def draw_bodypose(canvas, candidate, subset):
def draw_bodypose(canvas: NDArrayInt, candidate: NDArrayInt, subset: NDArrayInt) -> NDArrayInt:
H, W, C = canvas.shape
candidate = np.array(candidate)
subset = np.array(subset)
@ -88,7 +90,7 @@ def draw_bodypose(canvas, candidate, subset):
return canvas
def draw_handpose(canvas, all_hand_peaks):
def draw_handpose(canvas: NDArrayInt, all_hand_peaks: NDArrayInt) -> NDArrayInt:
H, W, C = canvas.shape
edges = [
@ -142,7 +144,7 @@ def draw_handpose(canvas, all_hand_peaks):
return canvas
def draw_facepose(canvas, all_lmks):
def draw_facepose(canvas: NDArrayInt, all_lmks: NDArrayInt) -> NDArrayInt:
H, W, C = canvas.shape
for lmks in all_lmks:
lmks = np.array(lmks)

@ -2,47 +2,26 @@
# Modified pathing to suit Invoke
from pathlib import Path
import numpy as np
import onnxruntime as ort
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.util.devices import TorchDevice
from .onnxdet import inference_detector
from .onnxpose import inference_pose
DWPOSE_MODELS = {
"yolox_l.onnx": {
"local": "any/annotators/dwpose/yolox_l.onnx",
"url": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true",
},
"dw-ll_ucoco_384.onnx": {
"local": "any/annotators/dwpose/dw-ll_ucoco_384.onnx",
"url": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true",
},
}
config = get_config()
class Wholebody:
def __init__(self):
def __init__(self, onnx_det: Path, onnx_pose: Path):
device = TorchDevice.choose_torch_device()
providers = ["CUDAExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"]
DET_MODEL_PATH = config.models_path / DWPOSE_MODELS["yolox_l.onnx"]["local"]
download_with_progress_bar("yolox_l.onnx", DWPOSE_MODELS["yolox_l.onnx"]["url"], DET_MODEL_PATH)
POSE_MODEL_PATH = config.models_path / DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["local"]
download_with_progress_bar(
"dw-ll_ucoco_384.onnx", DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["url"], POSE_MODEL_PATH
)
onnx_det = DET_MODEL_PATH
onnx_pose = POSE_MODEL_PATH
self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)

@ -1,4 +1,4 @@
import gc
from pathlib import Path
from typing import Any
import numpy as np
@ -6,9 +6,7 @@ import torch
from PIL import Image
import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.model_manager.config import AnyModel
def norm_img(np_img):
@ -19,28 +17,11 @@ def norm_img(np_img):
return np_img
def load_jit_model(url_or_path, device):
model_path = url_or_path
logger.info(f"Loading model from: {model_path}")
model = torch.jit.load(model_path, map_location="cpu").to(device)
model.eval()
return model
class LaMA:
def __init__(self, model: AnyModel):
self._model = model
def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
device = TorchDevice.choose_torch_device()
model_location = get_config().models_path / "core/misc/lama/lama.pt"
if not model_location.exists():
download_with_progress_bar(
name="LaMa Inpainting Model",
url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
dest_path=model_location,
)
model = load_jit_model(model_location, device)
image = np.asarray(input_image.convert("RGB"))
image = norm_img(image)
@ -48,20 +29,25 @@ class LaMA:
mask = np.asarray(mask)
mask = np.invert(mask)
mask = norm_img(mask)
mask = (mask > 0) * 1
device = next(self._model.buffers()).device
image = torch.from_numpy(image).unsqueeze(0).to(device)
mask = torch.from_numpy(mask).unsqueeze(0).to(device)
with torch.inference_mode():
infilled_image = model(image, mask)
infilled_image = self._model(image, mask)
infilled_image = infilled_image[0].permute(1, 2, 0).detach().cpu().numpy()
infilled_image = np.clip(infilled_image * 255, 0, 255).astype("uint8")
infilled_image = Image.fromarray(infilled_image)
del model
gc.collect()
torch.cuda.empty_cache()
return infilled_image
@staticmethod
def load_jit_model(url_or_path: str | Path, device: torch.device | str = "cpu") -> torch.nn.Module:
model_path = url_or_path
logger.info(f"Loading model from: {model_path}")
model: torch.nn.Module = torch.jit.load(model_path, map_location="cpu").to(device) # type: ignore
model.eval()
return model

@ -1,6 +1,5 @@
import math
from enum import Enum
from pathlib import Path
from typing import Any, Optional
import cv2
@ -11,6 +10,7 @@ from cv2.typing import MatLike
from tqdm import tqdm
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
from invokeai.backend.model_manager.config import AnyModel
from invokeai.backend.util.devices import TorchDevice
"""
@ -52,7 +52,7 @@ class RealESRGAN:
def __init__(
self,
scale: int,
model_path: Path,
loadnet: AnyModel,
model: RRDBNet,
tile: int = 0,
tile_pad: int = 10,
@ -67,8 +67,6 @@ class RealESRGAN:
self.half = half
self.device = TorchDevice.choose_torch_device()
loadnet = torch.load(model_path, map_location=torch.device("cpu"))
# prefer to use params_ema
if "params_ema" in loadnet:
keyname = "params_ema"

@ -36,7 +36,7 @@ from ..raw_model import RawModel
# ModelMixin is the base class for all diffusers and transformers models
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module]
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor]]
class InvalidModelConfigException(Exception):
@ -115,7 +115,7 @@ class SchedulerPredictionType(str, Enum):
class ModelRepoVariant(str, Enum):
"""Various hugging face variants on the diffusers format."""
Default = "" # model files without "fp16" or other qualifier - empty str
Default = "" # model files without "fp16" or other qualifier
FP16 = "fp16"
FP32 = "fp32"
ONNX = "onnx"

@ -7,7 +7,7 @@ from importlib import import_module
from pathlib import Path
from .convert_cache.convert_cache_default import ModelConvertCache
from .load_base import LoadedModel, ModelLoaderBase
from .load_base import LoadedModel, LoadedModelWithoutConfig, ModelLoaderBase
from .load_default import ModelLoader
from .model_cache.model_cache_default import ModelCache
from .model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase
@ -19,6 +19,7 @@ for module in loaders:
__all__ = [
"LoadedModel",
"LoadedModelWithoutConfig",
"ModelCache",
"ModelConvertCache",
"ModelLoaderBase",

@ -7,6 +7,7 @@ from pathlib import Path
from invokeai.backend.util import GIG, directory_size
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.util.util import safe_filename
from .convert_cache_base import ModelConvertCacheBase
@ -35,6 +36,7 @@ class ModelConvertCache(ModelConvertCacheBase):
def cache_path(self, key: str) -> Path:
"""Return the path for a model with the indicated key."""
key = safe_filename(self._cache_path, key)
return self._cache_path / key
def make_room(self, size: float) -> None:

@ -23,7 +23,7 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod
@dataclass
class LoadedModel:
class LoadedModelWithoutConfig:
"""
Context manager object that mediates transfer from RAM<->VRAM.
@ -61,7 +61,6 @@ class LoadedModel:
not have a state_dict, in which case this value will be None.
"""
config: AnyModelConfig
_locker: ModelLockerBase
def __enter__(self) -> AnyModel:
@ -89,6 +88,13 @@ class LoadedModel:
return self._locker.model
@dataclass
class LoadedModel(LoadedModelWithoutConfig):
"""Context manager object that mediates transfer from RAM<->VRAM."""
config: Optional[AnyModelConfig] = None
# TODO(MM2):
# Some "intermediary" subclasses in the ModelLoaderBase class hierarchy define methods that their subclasses don't
# know about. I think the problem may be related to this class being an ABC.
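
Review note on the `LoadedModel` / `LoadedModelWithoutConfig` split in this hunk: the context-manager behaviour now lives in the base class, and `LoadedModel` only adds the `config` field. The following is a minimal sketch of that shape, not the actual InvokeAI implementation. `FakeLocker` and the plain `dict` config are hypothetical stand-ins for `ModelLockerBase` and `AnyModelConfig`, and the `__enter__`/`__exit__` bodies are assumed from the method names visible in the diff.

```python
from dataclasses import dataclass
from typing import Any, Optional


class FakeLocker:
    """Hypothetical stand-in for ModelLockerBase: hands out the wrapped model."""

    def __init__(self, model: Any) -> None:
        self.model = model
        self.locked = False

    def lock(self) -> Any:
        self.locked = True
        return self.model

    def unlock(self) -> None:
        self.locked = False


@dataclass
class LoadedModelWithoutConfig:
    """Context manager around a locker; carries no `config` attribute."""

    _locker: FakeLocker

    def __enter__(self) -> Any:
        # Assumed behaviour: lock the model and hand it to the caller.
        return self._locker.lock()

    def __exit__(self, *args: Any) -> None:
        self._locker.unlock()

    @property
    def model(self) -> Any:
        return self._locker.model


@dataclass
class LoadedModel(LoadedModelWithoutConfig):
    """Same context manager, plus an optional config record."""

    config: Optional[dict] = None  # the real type is AnyModelConfig


bare = LoadedModelWithoutConfig(FakeLocker("weights"))
full = LoadedModel(FakeLocker("weights"), config={"key": "abc"})

with full as model:
    inside_locked = full._locker.locked
    entered_model = model
```

Because `LoadedModel` still subclasses the base class, code that only needs the context manager can accept either type, while code that reads `.config` should keep asking for `LoadedModel`.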

@ -16,7 +16,7 @@ from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.util.devices import TorchDevice
@ -84,7 +84,7 @@ class ModelLoader(ModelLoaderBase):
except IndexError:
pass
cache_path: Path = self._convert_cache.cache_path(config.key)
cache_path: Path = self._convert_cache.cache_path(str(model_path))
if self._needs_conversion(config, model_path, cache_path):
loaded_model = self._do_convert(config, model_path, cache_path, submodel_type)
else:
@ -95,7 +95,6 @@ class ModelLoader(ModelLoaderBase):
config.key,
submodel_type=submodel_type,
model=loaded_model,
size=calc_model_size_by_data(loaded_model),
)
return self._ram_cache.get(
@ -126,9 +125,7 @@ class ModelLoader(ModelLoaderBase):
if subtype == submodel_type:
continue
if submodel := getattr(pipeline, subtype.value, None):
self._ram_cache.put(
config.key, submodel_type=subtype, model=submodel, size=calc_model_size_by_data(submodel)
)
self._ram_cache.put(config.key, submodel_type=subtype, model=submodel)
return getattr(pipeline, submodel_type.value) if submodel_type else pipeline
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:

@ -169,7 +169,6 @@ class ModelCacheBase(ABC, Generic[T]):
self,
key: str,
model: T,
size: int,
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Store model under key and optional submodel_type."""

@ -29,6 +29,7 @@ import torch
from invokeai.backend.model_manager import AnyModel, SubModelType
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
@ -153,13 +154,13 @@ class ModelCache(ModelCacheBase[AnyModel]):
self,
key: str,
model: AnyModel,
size: int,
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Store model under key and optional submodel_type."""
key = self._make_cache_key(key, submodel_type)
if key in self._cached_models:
return
size = calc_model_size_by_data(model)
self.make_room(size)
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
@ -252,12 +253,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
May raise a torch.cuda.OutOfMemoryError
"""
# These attributes are not in the base ModelMixin class but in various derived classes.
# Some models don't have these attributes, in which case they run in RAM/CPU.
self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")):
return
source_device = cache_entry.device
# Note: We compare device types only so that 'cuda' == 'cuda:0'.
@ -265,6 +261,10 @@ class ModelCache(ModelCacheBase[AnyModel]):
if torch.device(source_device).type == torch.device(target_device).type:
return
# Some models don't have a `to` method, in which case they run in RAM/CPU.
if not hasattr(cache_entry.model, "to"):
return
# This roundabout method for moving the model around is done to avoid
# the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
# When moving to VRAM, we copy (not move) each element of the state dict from

@ -35,10 +35,6 @@ class ModelLocker(ModelLockerBase):
def lock(self) -> AnyModel:
"""Move the model into the execution device (GPU) and lock it."""
if not hasattr(self.model, "to"):
return self.model
# NOTE that the model has to have the to() method in order for this code to move it into GPU!
self._cache_entry.lock()
try:
if self._cache.lazy_offloading:
@ -59,9 +55,6 @@ class ModelLocker(ModelLockerBase):
def unlock(self) -> None:
"""Call upon exit from context."""
if not hasattr(self.model, "to"):
return
self._cache_entry.unlock()
if not self._cache.lazy_offloading:
self._cache.offload_unlocked_models(0)

@ -65,14 +65,11 @@ class GenericDiffusersLoader(ModelLoader):
else:
try:
config = self._load_diffusers_config(model_path, config_name="config.json")
class_name = config.get("_class_name", None)
if class_name:
if class_name := config.get("_class_name"):
result = self._hf_definition_to_type(module="diffusers", class_name=class_name)
if config.get("model_type", None) == "clip_vision_model":
class_name = config.get("architectures")
assert class_name is not None
elif class_name := config.get("architectures"):
result = self._hf_definition_to_type(module="transformers", class_name=class_name[0])
if not class_name:
else:
raise InvalidModelConfigException("Unable to decipher Load Class based on given config.json")
except KeyError as e:
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
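
Review note on the walrus refactor in this hunk: the new `if` / `elif` / `else` chain tries `_class_name` first, falls back to `architectures`, and raises otherwise. Note that the fallback no longer checks `model_type == "clip_vision_model"`. The sketch below isolates that control flow on a plain dict. `pick_load_class` is a hypothetical helper, and it returns a `(module, class_name)` tuple instead of calling `_hf_definition_to_type`, which is not reproduced here.

```python
from typing import Any, Dict, Tuple


class InvalidModelConfigException(Exception):
    """Local stand-in for the exception of the same name in the diff."""


def pick_load_class(config: Dict[str, Any]) -> Tuple[str, str]:
    """Return (module, class_name) following the order used in the new code."""
    if class_name := config.get("_class_name"):
        # diffusers-style config.json
        return ("diffusers", class_name)
    elif class_name := config.get("architectures"):
        # transformers-style config.json: `architectures` is a list
        return ("transformers", class_name[0])
    else:
        raise InvalidModelConfigException("Unable to decipher Load Class based on given config.json")


diffusers_choice = pick_load_class({"_class_name": "UNet2DConditionModel"})
transformers_choice = pick_load_class({"architectures": ["CLIPVisionModelWithProjection"]})
```

An empty string or empty list is falsy, so such values fall through to the next branch just as a missing key does.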

@ -83,7 +83,7 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
assert s.size is not None
files.append(
RemoteModelFile(
url=hf_hub_url(id, s.rfilename, revision=variant),
url=hf_hub_url(id, s.rfilename, revision=variant or "main"),
path=Path(name, s.rfilename),
size=s.size,
sha256=s.lfs.get("sha256") if s.lfs else None,

@ -37,9 +37,12 @@ class RemoteModelFile(BaseModel):
url: AnyHttpUrl = Field(description="The url to download this model file")
path: Path = Field(description="The path to the file, relative to the model root")
size: int = Field(description="The size of this file, in bytes")
size: Optional[int] = Field(description="The size of this file, in bytes", default=0)
sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None)
def __hash__(self) -> int:
return hash(str(self))
class ModelMetadataBase(BaseModel):
"""Base class for model metadata information."""

@ -1,6 +1,8 @@
import base64
import io
import os
import re
import unicodedata
import warnings
from pathlib import Path
@ -12,6 +14,33 @@ from transformers import logging as transformers_logging
GIG = 1073741824
def slugify(value: str, allow_unicode: bool = False) -> str:
"""
Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
dashes to single dashes. Remove characters that aren't alphanumerics,
underscores, or hyphens. Replace slashes with underscores.
Convert to lowercase. Also strip leading and
trailing whitespace, dashes, and underscores.
Adapted from Django: https://github.com/django/django/blob/main/django/utils/text.py
"""
value = str(value)
if allow_unicode:
value = unicodedata.normalize("NFKC", value)
else:
value = unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii")
value = re.sub(r"[/]", "_", value.lower())
value = re.sub(r"[^.\w\s-]", "", value.lower())
return re.sub(r"[-\s]+", "-", value).strip("-_")
def safe_filename(directory: Path, value: str) -> str:
"""Make a string safe to use as a filename."""
escaped_string = slugify(value)
max_name_length = os.pathconf(directory, "PC_NAME_MAX") if hasattr(os, "pathconf") else 256
return escaped_string[len(escaped_string) - max_name_length :]
def directory_size(directory: Path) -> int:
"""
Return the aggregate size of all files in a directory (bytes).
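
Review note on `slugify()` and `safe_filename()`: elsewhere in this diff the convert cache is keyed by `str(model_path)` rather than `config.key`, so the key can now contain slashes and spaces. The sketch below copies the two functions as they appear in this hunk (with indentation restored) and runs them on sample inputs. The inputs are made up for illustration.

```python
import os
import re
import unicodedata
from pathlib import Path


def slugify(value: str, allow_unicode: bool = False) -> str:
    """Copy of the function added in this hunk."""
    value = str(value)
    if allow_unicode:
        value = unicodedata.normalize("NFKC", value)
    else:
        value = unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii")
    value = re.sub(r"[/]", "_", value.lower())       # slashes become underscores
    value = re.sub(r"[^.\w\s-]", "", value.lower())  # drop everything but . word chars, whitespace, -
    return re.sub(r"[-\s]+", "-", value).strip("-_")  # collapse runs of dashes/whitespace


def safe_filename(directory: Path, value: str) -> str:
    """Copy of the function added in this hunk."""
    escaped_string = slugify(value)
    max_name_length = os.pathconf(directory, "PC_NAME_MAX") if hasattr(os, "pathconf") else 256
    # Keep the *tail* of the name when it is too long for the filesystem.
    return escaped_string[len(escaped_string) - max_name_length :]


repo_slug = slugify("stabilityai/sdxl-turbo")       # "stabilityai_sdxl-turbo"
path_slug = slugify("/opt/models/My Model.ckpt")    # "opt_models_my-model.ckpt"
long_name = safe_filename(Path("."), "a" * 3000 + "/tail.safetensors")
```

Truncating from the front keeps the file extension and the most specific part of a long path, at the cost of possible collisions between paths that share a long common suffix.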

@ -2,14 +2,18 @@
import re
import time
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Generator, Optional
import pytest
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from requests_testadapter import TestAdapter, TestSession
from requests_testadapter import TestAdapter
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService
from invokeai.app.services.config import get_config
from invokeai.app.services.config.config_default import URLRegexTokenPair
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService, MultiFileDownloadJob
from invokeai.app.services.events.events_common import (
DownloadCancelledEvent,
DownloadCompleteEvent,
@ -17,56 +21,23 @@ from invokeai.app.services.events.events_common import (
DownloadProgressEvent,
DownloadStartedEvent,
)
from invokeai.backend.model_manager.metadata import HuggingFaceMetadataFetch, ModelMetadataWithFiles, RemoteModelFile
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
from tests.test_nodes import TestEventService
# Prevent pytest deprecation warnings
TestAdapter.__test__ = False # type: ignore
TestAdapter.__test__ = False
@pytest.fixture
def session() -> Session:
sess = TestSession()
for i in ["12345", "9999", "54321"]:
content = (
b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000)
) # for pause tests, must make content large
sess.mount(
f"http://www.civitai.com/models/{i}",
TestAdapter(
content,
headers={
"Content-Length": len(content),
"Content-Disposition": f'filename="mock{i}.safetensors"',
},
),
)
# here are some malformed URLs to test
# missing the content length
sess.mount(
"http://www.civitai.com/models/missing",
TestAdapter(
b"Missing content length",
headers={
"Content-Disposition": 'filename="missing.txt"',
},
),
)
# not found test
sess.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404))
return sess
@pytest.mark.timeout(timeout=20, method="thread")
def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
@pytest.mark.timeout(timeout=10, method="thread")
def test_basic_queue_download(tmp_path: Path, mm2_session: Session) -> None:
events = set()
def event_handler(job: DownloadJob) -> None:
def event_handler(job: DownloadJob, excp: Optional[Exception] = None) -> None:
events.add(job.status)
queue = DownloadQueueService(
requests_session=session,
requests_session=mm2_session,
)
queue.start()
job = queue.download(
@ -82,16 +53,17 @@ def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
queue.join()
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
assert job.download_path == tmp_path / "mock12345.safetensors"
assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist"
assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
queue.stop()
@pytest.mark.timeout(timeout=20, method="thread")
def test_errors(tmp_path: Path, session: Session) -> None:
@pytest.mark.timeout(timeout=10, method="thread")
def test_errors(tmp_path: Path, mm2_session: Session) -> None:
queue = DownloadQueueService(
requests_session=session,
requests_session=mm2_session,
)
queue.start()
@ -110,11 +82,11 @@ def test_errors(tmp_path: Path, session: Session) -> None:
queue.stop()
@pytest.mark.timeout(timeout=20, method="thread")
def test_event_bus(tmp_path: Path, session: Session) -> None:
@pytest.mark.timeout(timeout=10, method="thread")
def test_event_bus(tmp_path: Path, mm2_session: Session) -> None:
event_bus = TestEventService()
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
queue.start()
queue.download(
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
@ -146,10 +118,10 @@ def test_event_bus(tmp_path: Path, session: Session) -> None:
queue.stop()
@pytest.mark.timeout(timeout=20, method="thread")
def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
@pytest.mark.timeout(timeout=10, method="thread")
def test_broken_callbacks(tmp_path: Path, mm2_session: Session, capsys) -> None:
queue = DownloadQueueService(
requests_session=session,
requests_session=mm2_session,
)
queue.start()
@ -178,11 +150,11 @@ def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
queue.stop()
@pytest.mark.timeout(timeout=15, method="thread")
def test_cancel(tmp_path: Path, session: Session) -> None:
@pytest.mark.timeout(timeout=10, method="thread")
def test_cancel(tmp_path: Path, mm2_session: Session) -> None:
event_bus = TestEventService()
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
queue.start()
cancelled = False
@ -194,9 +166,6 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
nonlocal cancelled
cancelled = True
def handler(signum, frame):
raise TimeoutError("Join took too long to return")
job = queue.download(
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
dest=tmp_path,
@ -212,3 +181,178 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
assert isinstance(events[-1], DownloadCancelledEvent)
assert events[-1].source == "http://www.civitai.com/models/12345"
queue.stop()
@pytest.mark.timeout(timeout=10, method="thread")
def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None:
fetcher = HuggingFaceMetadataFetch(mm2_session)
metadata = fetcher.from_id("stabilityai/sdxl-turbo")
assert isinstance(metadata, ModelMetadataWithFiles)
events = set()
def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
events.add(job.status)
queue = DownloadQueueService(
requests_session=mm2_session,
)
queue.start()
job = queue.multifile_download(
parts=metadata.download_urls(session=mm2_session),
dest=tmp_path,
on_start=event_handler,
on_progress=event_handler,
on_complete=event_handler,
on_error=event_handler,
)
assert isinstance(job, MultiFileDownloadJob), "expected the job to be of type MultiFileDownloadJob"
queue.join()
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
assert job.bytes > 0, "expected download bytes to be positive"
assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes"
assert job.download_path == tmp_path / "sdxl-turbo"
assert Path(
tmp_path, "sdxl-turbo/model_index.json"
).exists(), f"expected {tmp_path}/sdxl-turbo/model_index.json to exist"
assert Path(
tmp_path, "sdxl-turbo/text_encoder/config.json"
).exists(), f"expected {tmp_path}/sdxl-turbo/text_encoder/config.json to exist"
assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
queue.stop()
@pytest.mark.timeout(timeout=10, method="thread")
def test_multifile_download_error(tmp_path: Path, mm2_session: Session) -> None:
fetcher = HuggingFaceMetadataFetch(mm2_session)
metadata = fetcher.from_id("stabilityai/sdxl-turbo")
assert isinstance(metadata, ModelMetadataWithFiles)
events = set()
def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
events.add(job.status)
queue = DownloadQueueService(
requests_session=mm2_session,
)
queue.start()
files = metadata.download_urls(session=mm2_session)
# this will give a 404 error
files.append(RemoteModelFile(url="https://test.com/missing_model.safetensors", path=Path("sdxl-turbo/broken")))
job = queue.multifile_download(
parts=files,
dest=tmp_path,
on_start=event_handler,
on_progress=event_handler,
on_complete=event_handler,
on_error=event_handler,
)
queue.join()
assert job.status == DownloadJobStatus("error"), "expected job status to be errored"
assert job.error_type is not None
assert "HTTPError(NOT FOUND)" in job.error_type
assert DownloadJobStatus.ERROR in events
queue.stop()
@pytest.mark.timeout(timeout=10, method="thread")
def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch: Any) -> None:
event_bus = TestEventService()
queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
queue.start()
cancelled = False
def cancelled_callback(job: DownloadJob) -> None:
nonlocal cancelled
cancelled = True
fetcher = HuggingFaceMetadataFetch(mm2_session)
metadata = fetcher.from_id("stabilityai/sdxl-turbo")
assert isinstance(metadata, ModelMetadataWithFiles)
job = queue.multifile_download(
parts=metadata.download_urls(session=mm2_session),
dest=tmp_path,
on_cancelled=cancelled_callback,
)
queue.cancel_job(job)
queue.join()
assert job.status == DownloadJobStatus.CANCELLED
assert cancelled
events = event_bus.events
assert DownloadCancelledEvent in [type(x) for x in events]
queue.stop()
def test_multifile_onefile(tmp_path: Path, mm2_session: Session) -> None:
queue = DownloadQueueService(
requests_session=mm2_session,
)
queue.start()
job = queue.multifile_download(
parts=[
RemoteModelFile(url=AnyHttpUrl("http://www.civitai.com/models/12345"), path=Path("mock12345.safetensors"))
],
dest=tmp_path,
)
assert isinstance(job, MultiFileDownloadJob), "expected the job to be of type MultiFileDownloadJob"
queue.join()
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
assert job.bytes > 0, "expected download bytes to be positive"
assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes"
assert job.download_path == tmp_path / "mock12345.safetensors"
assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist"
queue.stop()
def test_multifile_no_rel_paths(tmp_path: Path, mm2_session: Session) -> None:
queue = DownloadQueueService(
requests_session=mm2_session,
)
with pytest.raises(AssertionError) as error:
queue.multifile_download(
parts=[RemoteModelFile(url=AnyHttpUrl("http://www.civitai.com/models/12345"), path=Path("/etc/passwd"))],
dest=tmp_path,
)
assert str(error.value) == "only relative download paths accepted"
@contextmanager
def clear_config() -> Generator[None, None, None]:
try:
yield None
finally:
get_config.cache_clear()
def test_tokens(tmp_path: Path, mm2_session: Session):
with clear_config():
config = get_config()
config.remote_api_tokens = [URLRegexTokenPair(url_regex="civitai", token="cv_12345")]
queue = DownloadQueueService(requests_session=mm2_session)
queue.start()
# this one has an access token assigned
job1 = queue.download(
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
dest=tmp_path,
)
# this one doesn't
job2 = queue.download(
source=AnyHttpUrl(
"http://www.huggingface.co/foo.txt",
),
dest=tmp_path,
)
queue.join()
# this token is defined in the temporary root invokeai.yaml
# see tests/backend/model_manager/data/invokeai_root/invokeai.yaml
assert job1.access_token == "cv_12345"
assert job2.access_token is None
queue.stop()
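
Review note on `clear_config()` in the tests above: it relies on `get_config` exposing `cache_clear()`, which suggests a cached singleton accessor. The sketch below shows the pattern in isolation, assuming an `functools.lru_cache`-style cache. The `Config` class and this `get_config` are hypothetical stand-ins, not the InvokeAI ones.

```python
from contextlib import contextmanager
from functools import lru_cache
from typing import Generator, List


class Config:
    """Hypothetical stand-in for the application config object."""

    def __init__(self) -> None:
        self.remote_api_tokens: List[str] = []


@lru_cache(maxsize=1)
def get_config() -> Config:
    """Return a cached singleton; assumed to mirror how the real accessor is cached."""
    return Config()


@contextmanager
def clear_config() -> Generator[None, None, None]:
    """Drop the cached config on exit so test mutations do not leak."""
    try:
        yield None
    finally:
        get_config.cache_clear()


with clear_config():
    config = get_config()
    config.remote_api_tokens = ["cv_12345"]
    tokens_inside = list(get_config().remote_api_tokens)

tokens_after = list(get_config().remote_api_tokens)
```

The `finally` clause makes the cleanup run even when an assertion inside the `with` block fails, so a failing test does not leave a mutated config behind for the next one.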

@ -20,6 +20,7 @@ from invokeai.app.services.events.events_common import (
ModelInstallStartedEvent,
)
from invokeai.app.services.model_install import (
HFModelSource,
ModelInstallServiceBase,
)
from invokeai.app.services.model_install.model_install_common import (
@ -29,7 +30,14 @@ from invokeai.app.services.model_install.model_install_common import (
URLModelSource,
)
from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException
from invokeai.backend.model_manager.config import BaseModelType, InvalidModelConfigException, ModelFormat, ModelType
from invokeai.backend.model_manager.config import (
BaseModelType,
InvalidModelConfigException,
ModelFormat,
ModelRepoVariant,
ModelType,
)
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
from tests.test_nodes import TestEventService
OS = platform.uname().system
@ -222,7 +230,7 @@ def test_delete_register(
store.get_model(key)
@pytest.mark.timeout(timeout=20, method="thread")
@pytest.mark.timeout(timeout=10, method="thread")
def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors"))
@ -243,15 +251,16 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
model_record = store.get_model(key)
assert (mm2_app_config.models_path / model_record.path).exists()
assert len(bus.events) == 4
assert isinstance(bus.events[0], ModelInstallDownloadProgressEvent)
assert isinstance(bus.events[1], ModelInstallDownloadsCompleteEvent)
assert isinstance(bus.events[2], ModelInstallStartedEvent)
assert isinstance(bus.events[3], ModelInstallCompleteEvent)
assert len(bus.events) == 5
assert isinstance(bus.events[0], ModelInstallDownloadProgressEvent) # download starts
assert isinstance(bus.events[1], ModelInstallDownloadProgressEvent) # download progresses
assert isinstance(bus.events[2], ModelInstallDownloadsCompleteEvent) # download completed
assert isinstance(bus.events[3], ModelInstallStartedEvent) # install started
assert isinstance(bus.events[4], ModelInstallCompleteEvent) # install completed
@pytest.mark.timeout(timeout=20, method="thread")
def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
@pytest.mark.timeout(timeout=10, method="thread")
def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))
bus: TestEventService = mm2_installer.event_bus
@ -277,6 +286,49 @@ def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_co
assert len(bus.events) >= 3
@pytest.mark.timeout(timeout=10, method="thread")
def test_huggingface_repo_id(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = HFModelSource(repo_id="stabilityai/sdxl-turbo", variant=ModelRepoVariant.Default)
bus = mm2_installer.event_bus
store = mm2_installer.record_store
assert isinstance(bus, EventServiceBase)
assert store is not None
job = mm2_installer.import_model(source)
job_list = mm2_installer.wait_for_installs(timeout=10)
assert len(job_list) == 1
assert job.complete
assert job.config_out
key = job.config_out.key
model_record = store.get_model(key)
assert (mm2_app_config.models_path / model_record.path).exists()
assert model_record.type == ModelType.Main
assert model_record.format == ModelFormat.Diffusers
assert hasattr(bus, "events") # the dummyeventservice has this
assert len(bus.events) >= 3
event_types = [type(x) for x in bus.events]
assert all(
x in event_types
for x in [
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallStartedEvent,
ModelInstallCompleteEvent,
]
)
completed_events = [x for x in bus.events if isinstance(x, ModelInstallCompleteEvent)]
downloading_events = [x for x in bus.events if isinstance(x, ModelInstallDownloadProgressEvent)]
assert completed_events[0].total_bytes == downloading_events[-1].bytes
assert job.total_bytes == completed_events[0].total_bytes
print(downloading_events[-1])
print(job.download_parts)
assert job.total_bytes == sum(x["total_bytes"] for x in downloading_events[-1].parts)
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = URLModelSource(url=Url("https://test.com/missing_model.safetensors"))
job = mm2_installer.import_model(source)
@@ -308,7 +360,6 @@ def test_other_error_during_install(
assert job.error == "Test error"
# TODO: Fix bug in model install causing jobs to get installed multiple times then uncomment this test
@pytest.mark.parametrize(
"model_params",
[
@@ -326,7 +377,7 @@ def test_other_error_during_install(
},
],
)
@pytest.mark.timeout(timeout=40, method="thread")
@pytest.mark.timeout(timeout=10, method="thread")
def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
"""Test whether or not type is respected on configs when passed to heuristic import."""
assert "name" in model_params and "type" in model_params
@@ -342,7 +393,7 @@ def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, mode
}
assert "repo_id" in model_params
install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1)
mm2_installer.wait_for_job(install_job1, timeout=20)
mm2_installer.wait_for_job(install_job1, timeout=10)
if model_params["type"] != "embedding":
assert install_job1.errored
assert install_job1.error_type == "InvalidModelConfigException"
@@ -351,6 +402,6 @@ def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, mode
assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out
install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2)
mm2_installer.wait_for_job(install_job2, timeout=20)
mm2_installer.wait_for_job(install_job2, timeout=10)
assert install_job2.complete
assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out
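
The event assertions in `test_huggingface_repo_id` above check that each required event type appears at least once, without fixing the exact order or count, and then compare the last progress event against the completion event. A minimal standalone sketch of that pattern follows. The event classes and the `missing_event_types` helper are hypothetical stand-ins written for illustration; they are not InvokeAI's real event types or API.

```python
from dataclasses import dataclass


# Hypothetical stand-ins for the event classes named in the tests above.
@dataclass
class DownloadProgress:
    bytes: int
    total_bytes: int


@dataclass
class DownloadsComplete:
    pass


@dataclass
class InstallStarted:
    pass


@dataclass
class InstallComplete:
    total_bytes: int


def missing_event_types(events, required):
    """Return the required event types that never appear in `events`.

    Order and repeat counts are ignored, so extra progress events do not
    break the check.
    """
    seen = {type(e) for e in events}
    return [t for t in required if t not in seen]


events = [
    DownloadProgress(bytes=0, total_bytes=100),
    DownloadProgress(bytes=100, total_bytes=100),
    DownloadsComplete(),
    InstallStarted(),
    InstallComplete(total_bytes=100),
]
required = [DownloadProgress, DownloadsComplete, InstallStarted, InstallComplete]

assert missing_event_types(events, required) == []

# The final progress event should account for every byte reported at completion.
progress = [e for e in events if isinstance(e, DownloadProgress)]
completed = [e for e in events if isinstance(e, InstallComplete)]
assert progress[-1].bytes == completed[0].total_bytes
```

Checking membership rather than exact positions keeps such a test stable when the number of progress events varies with download size or timing, which may be why the repo_id test uses `len(bus.events) >= 3` instead of an exact count.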


@@ -0,0 +1,88 @@
from pathlib import Path
import pytest
import torch
from diffusers import AutoencoderTiny
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.model_manager import ModelManagerServiceBase
from invokeai.app.services.shared.invocation_context import InvocationContext, build_invocation_context
from invokeai.backend.model_manager.load.load_base import LoadedModelWithoutConfig
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
@pytest.fixture()
def mock_context(
mock_services: InvocationServices,
mm2_model_manager: ModelManagerServiceBase,
) -> InvocationContext:
mock_services.model_manager = mm2_model_manager
return build_invocation_context(
services=mock_services,
data=None, # type: ignore
is_canceled=None, # type: ignore
)
def test_download_and_cache(mock_context: InvocationContext, mm2_root_dir: Path) -> None:
downloaded_path = mock_context.models.download_and_cache_model(
"https://www.test.foo/download/test_embedding.safetensors"
)
assert downloaded_path.is_file()
assert downloaded_path.exists()
assert downloaded_path.name == "test_embedding.safetensors"
assert downloaded_path.parent.parent == mm2_root_dir / "models/.download_cache"
downloaded_path_2 = mock_context.models.download_and_cache_model(
"https://www.test.foo/download/test_embedding.safetensors"
)
assert downloaded_path == downloaded_path_2
def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) -> None:
downloaded_path = mock_context.models.download_and_cache_model(
"https://www.test.foo/download/test_embedding.safetensors"
)
loaded_model_1 = mock_context.models.load_local_model(downloaded_path)
assert isinstance(loaded_model_1, LoadedModelWithoutConfig)
loaded_model_2 = mock_context.models.load_local_model(downloaded_path)
assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
assert loaded_model_1.model is loaded_model_2.model
loaded_model_3 = mock_context.models.load_local_model(embedding_file)
assert isinstance(loaded_model_3, LoadedModelWithoutConfig)
assert loaded_model_1.model is not loaded_model_3.model
assert isinstance(loaded_model_1.model, dict)
assert isinstance(loaded_model_3.model, dict)
assert torch.equal(loaded_model_1.model["emb_params"], loaded_model_3.model["emb_params"])
@pytest.mark.skip(reason="This requires a test model to load")
def test_load_from_dir(mock_context: InvocationContext, vae_directory: Path) -> None:
loaded_model = mock_context.models.load_local_model(vae_directory)
assert isinstance(loaded_model, LoadedModelWithoutConfig)
assert isinstance(loaded_model.model, AutoencoderTiny)
def test_download_and_load(mock_context: InvocationContext) -> None:
loaded_model_1 = mock_context.models.load_remote_model("https://www.test.foo/download/test_embedding.safetensors")
assert isinstance(loaded_model_1, LoadedModelWithoutConfig)
loaded_model_2 = mock_context.models.load_remote_model("https://www.test.foo/download/test_embedding.safetensors")
assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
assert loaded_model_1.model is loaded_model_2.model # should be cached copy
def test_download_diffusers(mock_context: InvocationContext) -> None:
model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo")
assert (model_path / "model_index.json").exists()
assert (model_path / "vae").is_dir()
def test_download_diffusers_subfolder(mock_context: InvocationContext) -> None:
model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo::vae")
assert model_path.is_dir()
assert (model_path / "diffusion_pytorch_model.fp16.safetensors").exists() or (
model_path / "diffusion_pytorch_model.safetensors"
).exists()
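
`test_download_and_cache` above relies on two properties: the cached file sits two levels below `models/.download_cache`, and a second request for the same source returns the same path. The sketch below shows one way a cache could provide both. It is a hypothetical illustration: the `download_and_cache` helper, the injected `fetch` callable, and the sha256-derived directory name are all assumptions made for this example, not InvokeAI's actual implementation or cache layout.

```python
import hashlib
import tempfile
from pathlib import Path
from typing import Callable


def download_and_cache(source: str, cache_dir: Path, fetch: Callable[[str], bytes]) -> Path:
    """Return a cached copy of `source`, calling `fetch` only on a cache miss.

    Hypothetical helper: the hash-based subdirectory name is an assumption
    made for this sketch.
    """
    subdir = cache_dir / hashlib.sha256(source.encode("utf-8")).hexdigest()[:16]
    target = subdir / source.rsplit("/", 1)[-1]
    if not target.exists():
        subdir.mkdir(parents=True, exist_ok=True)
        target.write_bytes(fetch(source))
    return target


calls = []


def fake_fetch(url: str) -> bytes:
    # Stand-in for a real HTTP download; records each call so fetches can be counted.
    calls.append(url)
    return b"fake safetensors payload"


cache_root = Path(tempfile.mkdtemp()) / "models" / ".download_cache"
url = "https://www.test.foo/download/test_embedding.safetensors"

first = download_and_cache(url, cache_root, fake_fetch)
second = download_and_cache(url, cache_root, fake_fetch)

assert first == second
assert first.name == "test_embedding.safetensors"
assert first.parent.parent == cache_root
assert len(calls) == 1  # the second call was served from the cache
```

Deriving the subdirectory from the source string means two different URLs that end in the same filename do not collide, while repeated requests for one URL resolve to a single location.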


@@ -61,6 +61,13 @@ def embedding_file(mm2_model_files: Path) -> Path:
return mm2_model_files / "test_embedding.safetensors"
# Can be used to test diffusers model directory loading, but
# the test file adds ~10MB of space.
# @pytest.fixture
# def vae_directory(mm2_model_files: Path) -> Path:
# return mm2_model_files / "taesdxl"
@pytest.fixture
def diffusers_dir(mm2_model_files: Path) -> Path:
return mm2_model_files / "test-diffusers-main"
@@ -294,4 +301,45 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
},
),
)
for i in ["12345", "9999", "54321"]:
content = (
b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000)
) # for pause tests, must make content large
sess.mount(
f"http://www.civitai.com/models/{i}",
TestAdapter(
content,
headers={
"Content-Length": len(content),
"Content-Disposition": f'filename="mock{i}.safetensors"',
},
),
)
sess.mount(
"http://www.huggingface.co/foo.txt",
TestAdapter(
content,
headers={
"Content-Length": len(content),
"Content-Disposition": 'filename="foo.safetensors"',
},
),
)
# here are some malformed URLs to test
# missing the content length
sess.mount(
"http://www.civitai.com/models/missing",
TestAdapter(
b"Missing content length",
headers={
"Content-Disposition": 'filename="missing.txt"',
},
),
)
# not found test
sess.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404))
return sess