mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add simplified model manager install API to InvocationContext (#6132)
## Summary This three two model manager-related methods to the InvocationContext uniform API. They are accessible via `context.models.*`: 1. **`load_local_model(model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None) -> LoadedModelWithoutConfig`** *Load the model located at the indicated path.* This will load a local model (.safetensors, .ckpt or diffusers directory) into the model manager RAM cache and return its `LoadedModelWithoutConfig`. If the optional loader argument is provided, the loader will be invoked to load the model into memory. Otherwise the method will call `safetensors.torch.load_file()` `torch.load()` (with a pickle scan), or `from_pretrained()` as appropriate to the path type. Be aware that the `LoadedModelWithoutConfig` object differs from `LoadedModel` by having no `config` attribute. Here is an example of usage: ``` def invoke(self, context: InvocatinContext) -> ImageOutput: model_path = Path('/opt/models/RealESRGAN_x4plus.pth') loadnet = context.models.load_local_model(model_path) with loadnet as loadnet_model: upscaler = RealESRGAN(loadnet=loadnet_model,...) ``` --- 2. **`load_remote_model(source: str | AnyHttpUrl, loader: Optional[Callable[[Path], AnyModel]] = None) -> LoadedModelWithoutConfig`** *Load the model located at the indicated URL or repo_id.* This is similar to `load_local_model()` but it accepts either a HugginFace repo_id (as a string), or a URL. The model's file(s) will be downloaded to `models/.download_cache` and then loaded, returning a ``` def invoke(self, context: InvocatinContext) -> ImageOutput: model_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth' loadnet = context.models.load_remote_model(model_url) with loadnet as loadnet_model: upscaler = RealESRGAN(loadnet=loadnet_model,...) ``` --- 3. **`download_and_cache_model( source: str | AnyHttpUrl, access_token: Optional[str] = None, timeout: Optional[int] = 0) -> Path`** Download the model file located at source to the models cache and return its Path. This will check `models/.download_cache` for the desired model file and download it from the indicated source if not already present. The local Path to the downloaded file is then returned. --- ## Other Changes This PR performs a migration, in which it renames `models/.cache` to `models/.convert_cache`, and migrates previously-downloaded ESRGAN, openpose, DepthAnything and Lama inpaint models from the `models/core` directory into `models/.download_cache`. There are a number of legacy model files in `models/core`, such as GFPGAN, which are no longer used. This PR deletes them and tidies up the `models/core` directory. ## Related Issues / Discussions I have systematically replaced all the calls to `download_with_progress_bar()`. This function is no longer used elsewhere and has been removed. <!--WHEN APPLICABLE: List any related issues or discussions on github or discord. If this PR closes an issue, please use the "Closes #1234" format, so that the issue will be automatically closed when the PR merges.--> ## QA Instructions I have added unit tests for the three new calls. You may test that the `load_and_cache_model()` call is working by running the upscaler within the web app. On first try, you will see the model file being downloaded into the models `.cache` directory. On subsequent tries, the model will either load from RAM (if it hasn't been displaced) or will be loaded from the filesystem. <!--WHEN APPLICABLE: Describe how we can test the changes in this PR.--> ## Merge Plan Squash merge when approved. <!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like DB schemas, may need some care when merging. For example, a careful rebase by the change author, timing to not interfere with a pending release, or a message to contributors on discord after merging.--> ## Checklist - [X] _The PR has a short but descriptive title, suitable for a changelog_ - [X] _Tests added / updated (if applicable)_ - [X] _Documentation added / updated (if applicable)_
This commit is contained in:
commit
9432336e2b
@ -128,7 +128,8 @@ The queue operates on a series of download job objects. These objects
|
||||
specify the source and destination of the download, and keep track of
|
||||
the progress of the download.
|
||||
|
||||
The only job type currently implemented is `DownloadJob`, a pydantic object with the
|
||||
Two job types are defined. `DownloadJob` and
|
||||
`MultiFileDownloadJob`. The former is a pydantic object with the
|
||||
following fields:
|
||||
|
||||
| **Field** | **Type** | **Default** | **Description** |
|
||||
@ -138,7 +139,7 @@ following fields:
|
||||
| `dest` | Path | | Where to download to |
|
||||
| `access_token` | str | | [optional] string containing authentication token for access |
|
||||
| `on_start` | Callable | | [optional] callback when the download starts |
|
||||
| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
|
||||
| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
|
||||
| `on_complete` | Callable | | [optional] callback called after successful download completion |
|
||||
| `on_error` | Callable | | [optional] callback called after an error occurs |
|
||||
| `id` | int | auto assigned | Job ID, an integer >= 0 |
|
||||
@ -190,6 +191,33 @@ A cancelled job will have status `DownloadJobStatus.ERROR` and an
|
||||
`error_type` field of "DownloadJobCancelledException". In addition,
|
||||
the job's `cancelled` property will be set to True.
|
||||
|
||||
The `MultiFileDownloadJob` is used for diffusers model downloads,
|
||||
which contain multiple files and directories under a common root:
|
||||
|
||||
| **Field** | **Type** | **Default** | **Description** |
|
||||
|----------------|-----------------|---------------|-----------------|
|
||||
| _Fields passed in at job creation time_ |
|
||||
| `download_parts` | Set[DownloadJob]| | Component download jobs |
|
||||
| `dest` | Path | | Where to download to |
|
||||
| `on_start` | Callable | | [optional] callback when the download starts |
|
||||
| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
|
||||
| `on_complete` | Callable | | [optional] callback called after successful download completion |
|
||||
| `on_error` | Callable | | [optional] callback called after an error occurs |
|
||||
| `id` | int | auto assigned | Job ID, an integer >= 0 |
|
||||
| _Fields updated over the course of the download task_
|
||||
| `status` | DownloadJobStatus| | Status code |
|
||||
| `download_path` | Path | | Path to the root of the downloaded files |
|
||||
| `bytes` | int | 0 | Bytes downloaded so far |
|
||||
| `total_bytes` | int | 0 | Total size of the file at the remote site |
|
||||
| `error_type` | str | | String version of the exception that caused an error during download |
|
||||
| `error` | str | | String version of the traceback associated with an error |
|
||||
| `cancelled` | bool | False | Set to true if the job was cancelled by the caller|
|
||||
|
||||
Note that the MultiFileDownloadJob does not support the `priority`,
|
||||
`job_started`, `job_ended` or `content_type` attributes. You can get
|
||||
these from the individual download jobs in `download_parts`.
|
||||
|
||||
|
||||
### Callbacks
|
||||
|
||||
Download jobs can be associated with a series of callbacks, each with
|
||||
@ -251,11 +279,40 @@ jobs using `list_jobs()`, fetch a single job by its with
|
||||
running jobs with `cancel_all_jobs()`, and wait for all jobs to finish
|
||||
with `join()`.
|
||||
|
||||
#### job = queue.download(source, dest, priority, access_token)
|
||||
#### job = queue.download(source, dest, priority, access_token, on_start, on_progress, on_complete, on_cancelled, on_error)
|
||||
|
||||
Create a new download job and put it on the queue, returning the
|
||||
DownloadJob object.
|
||||
|
||||
#### multifile_job = queue.multifile_download(parts, dest, access_token, on_start, on_progress, on_complete, on_cancelled, on_error)
|
||||
|
||||
This is similar to download(), but instead of taking a single source,
|
||||
it accepts a `parts` argument consisting of a list of
|
||||
`RemoteModelFile` objects. Each part corresponds to a URL/Path pair,
|
||||
where the URL is the location of the remote file, and the Path is the
|
||||
destination.
|
||||
|
||||
`RemoteModelFile` can be imported from `invokeai.backend.model_manager.metadata`, and
|
||||
consists of a url/path pair. Note that the path *must* be relative.
|
||||
|
||||
The method returns a `MultiFileDownloadJob`.
|
||||
|
||||
|
||||
```
|
||||
from invokeai.backend.model_manager.metadata import RemoteModelFile
|
||||
remote_file_1 = RemoteModelFile(url='http://www.foo.bar/my/pytorch_model.safetensors'',
|
||||
path='my_model/textencoder/pytorch_model.safetensors'
|
||||
)
|
||||
remote_file_2 = RemoteModelFile(url='http://www.bar.baz/vae.ckpt',
|
||||
path='my_model/vae/diffusers_model.safetensors'
|
||||
)
|
||||
job = queue.multifile_download(parts=[remote_file_1, remote_file_2],
|
||||
dest='/tmp/downloads',
|
||||
on_progress=TqdmProgress().update)
|
||||
queue.wait_for_job(job)
|
||||
print(f"The files were downloaded to {job.download_path}")
|
||||
```
|
||||
|
||||
#### jobs = queue.list_jobs()
|
||||
|
||||
Return a list of all active and inactive `DownloadJob`s.
|
||||
|
@ -397,26 +397,25 @@ In the event you wish to create a new installer, you may use the
|
||||
following initialization pattern:
|
||||
|
||||
```
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.config import get_config
|
||||
from invokeai.app.services.model_records import ModelRecordServiceSQL
|
||||
from invokeai.app.services.model_install import ModelInstallService
|
||||
from invokeai.app.services.download import DownloadQueueService
|
||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
config.parse_args()
|
||||
config = get_config()
|
||||
|
||||
logger = InvokeAILogger.get_logger(config=config)
|
||||
db = SqliteDatabase(config, logger)
|
||||
db = SqliteDatabase(config.db_path, logger)
|
||||
record_store = ModelRecordServiceSQL(db)
|
||||
queue = DownloadQueueService()
|
||||
queue.start()
|
||||
|
||||
installer = ModelInstallService(app_config=config,
|
||||
record_store=record_store,
|
||||
download_queue=queue
|
||||
)
|
||||
download_queue=queue
|
||||
)
|
||||
installer.start()
|
||||
```
|
||||
|
||||
@ -1602,3 +1601,59 @@ This method takes a model key, looks it up using the
|
||||
`ModelRecordServiceBase` object in `mm.store`, and passes the returned
|
||||
model configuration to `load_model_by_config()`. It may raise a
|
||||
`NotImplementedException`.
|
||||
|
||||
## Invocation Context Model Manager API
|
||||
|
||||
Within invocations, the following methods are available from the
|
||||
`InvocationContext` object:
|
||||
|
||||
### context.download_and_cache_model(source) -> Path
|
||||
|
||||
This method accepts a `source` of a remote model, downloads and caches
|
||||
it locally, and then returns a Path to the local model. The source can
|
||||
be a direct download URL or a HuggingFace repo_id.
|
||||
|
||||
In the case of HuggingFace repo_id, the following variants are
|
||||
recognized:
|
||||
|
||||
* stabilityai/stable-diffusion-v4 -- default model
|
||||
* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant
|
||||
* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder
|
||||
* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder
|
||||
|
||||
You can also point at an arbitrary individual file within a repo_id
|
||||
directory using this syntax:
|
||||
|
||||
* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
|
||||
|
||||
### context.load_local_model(model_path, [loader]) -> LoadedModel
|
||||
|
||||
This method loads a local model from the indicated path, returning a
|
||||
`LoadedModel`. The optional loader is a Callable that accepts a Path
|
||||
to the object, and returns a `AnyModel` object. If no loader is
|
||||
provided, then the method will use `torch.load()` for a .ckpt or .bin
|
||||
checkpoint file, `safetensors.torch.load_file()` for a safetensors
|
||||
checkpoint file, or `cls.from_pretrained()` for a directory that looks
|
||||
like a diffusers directory.
|
||||
|
||||
### context.load_remote_model(source, [loader]) -> LoadedModel
|
||||
|
||||
This method accepts a `source` of a remote model, downloads and caches
|
||||
it locally, loads it, and returns a `LoadedModel`. The source can be a
|
||||
direct download URL or a HuggingFace repo_id.
|
||||
|
||||
In the case of HuggingFace repo_id, the following variants are
|
||||
recognized:
|
||||
|
||||
* stabilityai/stable-diffusion-v4 -- default model
|
||||
* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant
|
||||
* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder
|
||||
* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder
|
||||
|
||||
You can also point at an arbitrary individual file within a repo_id
|
||||
directory using this syntax:
|
||||
|
||||
* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
|
||||
|
||||
|
||||
|
||||
|
@ -93,7 +93,7 @@ class ApiDependencies:
|
||||
conditioning = ObjectSerializerForwardCache(
|
||||
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
|
||||
)
|
||||
download_queue_service = DownloadQueueService(event_bus=events)
|
||||
download_queue_service = DownloadQueueService(app_config=configuration, event_bus=events)
|
||||
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
|
||||
model_manager = ModelManagerService.build_model_manager(
|
||||
app_config=configuration,
|
||||
|
@ -2,6 +2,7 @@
|
||||
# initial implementation by Gregg Helt, 2023
|
||||
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
|
||||
from builtins import bool, float
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Literal, Union
|
||||
|
||||
import cv2
|
||||
@ -36,12 +37,13 @@ from invokeai.app.invocations.util import validate_begin_end_step, validate_weig
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
|
||||
from invokeai.backend.image_util.canny import get_canny_edges
|
||||
from invokeai.backend.image_util.depth_anything import DepthAnythingDetector
|
||||
from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
|
||||
from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector
|
||||
from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
|
||||
from invokeai.backend.image_util.hed import HEDProcessor
|
||||
from invokeai.backend.image_util.lineart import LineartProcessor
|
||||
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
|
||||
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, Classification, invocation, invocation_output
|
||||
|
||||
@ -139,6 +141,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
return context.images.get_pil(self.image.image_name, "RGB")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
self._context = context
|
||||
raw_image = self.load_image(context)
|
||||
# image type should be PIL.PngImagePlugin.PngImageFile ?
|
||||
processed_image = self.run_processor(raw_image)
|
||||
@ -284,7 +287,8 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
# depth_and_normal not supported in controlnet_aux v0.0.3
|
||||
# depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
|
||||
|
||||
def run_processor(self, image):
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
# TODO: replace from_pretrained() calls with context.models.download_and_cache() (or similar)
|
||||
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = midas_processor(
|
||||
image,
|
||||
@ -311,7 +315,7 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image):
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = normalbae_processor(
|
||||
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
|
||||
@ -330,7 +334,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
||||
thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`")
|
||||
thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")
|
||||
|
||||
def run_processor(self, image):
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = mlsd_processor(
|
||||
image,
|
||||
@ -353,7 +357,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
||||
safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
|
||||
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
|
||||
|
||||
def run_processor(self, image):
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = pidi_processor(
|
||||
image,
|
||||
@ -381,7 +385,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
||||
w: int = InputField(default=512, ge=0, description="Content shuffle `w` parameter")
|
||||
f: int = InputField(default=256, ge=0, description="Content shuffle `f` parameter")
|
||||
|
||||
def run_processor(self, image):
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
content_shuffle_processor = ContentShuffleDetector()
|
||||
processed_image = content_shuffle_processor(
|
||||
image,
|
||||
@ -405,7 +409,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
||||
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Zoe depth processing to image"""
|
||||
|
||||
def run_processor(self, image):
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = zoe_depth_processor(image)
|
||||
return processed_image
|
||||
@ -426,7 +430,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image):
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
mediapipe_face_processor = MediapipeFaceDetector()
|
||||
processed_image = mediapipe_face_processor(
|
||||
image,
|
||||
@ -454,7 +458,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image):
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = leres_processor(
|
||||
image,
|
||||
@ -496,8 +500,8 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
|
||||
np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA)
|
||||
return np_img
|
||||
|
||||
def run_processor(self, img):
|
||||
np_img = np.array(img, dtype=np.uint8)
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
np_img = np.array(image, dtype=np.uint8)
|
||||
processed_np_image = self.tile_resample(
|
||||
np_img,
|
||||
# res=self.tile_size,
|
||||
@ -520,7 +524,7 @@ class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
|
||||
detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image):
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
|
||||
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
|
||||
"ybelkada/segment-anything", subfolder="checkpoints"
|
||||
@ -566,7 +570,7 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
|
||||
|
||||
color_map_tile_size: int = InputField(default=64, ge=1, description=FieldDescriptions.tile_size)
|
||||
|
||||
def run_processor(self, image: Image.Image):
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
np_image = np.array(image, dtype=np.uint8)
|
||||
height, width = np_image.shape[:2]
|
||||
|
||||
@ -601,12 +605,18 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
|
||||
)
|
||||
resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image):
|
||||
depth_anything_detector = DepthAnythingDetector()
|
||||
depth_anything_detector.load_model(model_size=self.model_size)
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
def loader(model_path: Path):
|
||||
return DepthAnythingDetector.load_model(
|
||||
model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device()
|
||||
)
|
||||
|
||||
processed_image = depth_anything_detector(image=image, resolution=self.resolution)
|
||||
return processed_image
|
||||
with self._context.models.load_remote_model(
|
||||
source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader
|
||||
) as model:
|
||||
depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device())
|
||||
processed_image = depth_anything_detector(image=image, resolution=self.resolution)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
@ -624,8 +634,11 @@ class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
draw_hands: bool = InputField(default=False)
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image):
|
||||
dw_openpose = DWOpenposeDetector()
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
onnx_det = self._context.models.download_and_cache_model(DWPOSE_MODELS["yolox_l.onnx"])
|
||||
onnx_pose = self._context.models.download_and_cache_model(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])
|
||||
|
||||
dw_openpose = DWOpenposeDetector(onnx_det=onnx_det, onnx_pose=onnx_pose)
|
||||
processed_image = dw_openpose(
|
||||
image,
|
||||
draw_face=self.draw_face,
|
||||
|
@ -42,15 +42,16 @@ class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Infill the image with the specified method"""
|
||||
pass
|
||||
|
||||
def load_image(self, context: InvocationContext) -> tuple[Image.Image, bool]:
|
||||
def load_image(self) -> tuple[Image.Image, bool]:
|
||||
"""Process the image to have an alpha channel before being infilled"""
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
image = self._context.images.get_pil(self.image.image_name)
|
||||
has_alpha = True if image.mode == "RGBA" else False
|
||||
return image, has_alpha
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
self._context = context
|
||||
# Retrieve and process image to be infilled
|
||||
input_image, has_alpha = self.load_image(context)
|
||||
input_image, has_alpha = self.load_image()
|
||||
|
||||
# If the input image has no alpha channel, return it
|
||||
if has_alpha is False:
|
||||
@ -133,8 +134,12 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation):
|
||||
"""Infills transparent areas of an image using the LaMa model"""
|
||||
|
||||
def infill(self, image: Image.Image):
|
||||
lama = LaMA()
|
||||
return lama(image)
|
||||
with self._context.models.load_remote_model(
|
||||
source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
|
||||
loader=LaMA.load_jit_model,
|
||||
) as model:
|
||||
lama = LaMA(model)
|
||||
return lama(image)
|
||||
|
||||
|
||||
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
|
||||
|
@ -1,5 +1,4 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import cv2
|
||||
@ -10,10 +9,8 @@ from pydantic import ConfigDict
|
||||
from invokeai.app.invocations.fields import ImageField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
|
||||
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
from .baseinvocation import BaseInvocation, invocation
|
||||
from .fields import InputField, WithBoard, WithMetadata
|
||||
@ -52,7 +49,6 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
rrdbnet_model = None
|
||||
netscale = None
|
||||
esrgan_model_path = None
|
||||
|
||||
if self.model_name in [
|
||||
"RealESRGAN_x4plus.pth",
|
||||
@ -95,28 +91,25 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
context.logger.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
esrgan_model_path = Path(context.config.get().models_path, f"core/upscaling/realesrgan/{self.model_name}")
|
||||
|
||||
# Downloads the ESRGAN model if it doesn't already exist
|
||||
download_with_progress_bar(
|
||||
name=self.model_name, url=ESRGAN_MODEL_URLS[self.model_name], dest_path=esrgan_model_path
|
||||
loadnet = context.models.load_remote_model(
|
||||
source=ESRGAN_MODEL_URLS[self.model_name],
|
||||
)
|
||||
|
||||
upscaler = RealESRGAN(
|
||||
scale=netscale,
|
||||
model_path=esrgan_model_path,
|
||||
model=rrdbnet_model,
|
||||
half=False,
|
||||
tile=self.tile_size,
|
||||
)
|
||||
with loadnet as loadnet_model:
|
||||
upscaler = RealESRGAN(
|
||||
scale=netscale,
|
||||
loadnet=loadnet_model,
|
||||
model=rrdbnet_model,
|
||||
half=False,
|
||||
tile=self.tile_size,
|
||||
)
|
||||
|
||||
# prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
|
||||
# TODO: This strips the alpha... is that okay?
|
||||
cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
|
||||
upscaled_image = upscaler.upscale(cv2_image)
|
||||
pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
|
||||
# prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
|
||||
# TODO: This strips the alpha... is that okay?
|
||||
cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
|
||||
upscaled_image = upscaler.upscale(cv2_image)
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
|
||||
|
||||
image_dto = context.images.save(image=pil_image)
|
||||
|
||||
|
@ -86,6 +86,7 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
patchmatch: Enable patchmatch inpaint code.
|
||||
models_dir: Path to the models directory.
|
||||
convert_cache_dir: Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.
|
||||
download_cache_dir: Path to the directory that contains dynamically downloaded models.
|
||||
legacy_conf_dir: Path to directory of legacy checkpoint config files.
|
||||
db_dir: Path to InvokeAI databases directory.
|
||||
outputs_dir: Path to directory for outputs.
|
||||
@ -146,7 +147,8 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
|
||||
# PATHS
|
||||
models_dir: Path = Field(default=Path("models"), description="Path to the models directory.")
|
||||
convert_cache_dir: Path = Field(default=Path("models/.cache"), description="Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.")
|
||||
convert_cache_dir: Path = Field(default=Path("models/.convert_cache"), description="Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.")
|
||||
download_cache_dir: Path = Field(default=Path("models/.download_cache"), description="Path to the directory that contains dynamically downloaded models.")
|
||||
legacy_conf_dir: Path = Field(default=Path("configs"), description="Path to directory of legacy checkpoint config files.")
|
||||
db_dir: Path = Field(default=Path("databases"), description="Path to InvokeAI databases directory.")
|
||||
outputs_dir: Path = Field(default=Path("outputs"), description="Path to directory for outputs.")
|
||||
@ -303,6 +305,11 @@ class InvokeAIAppConfig(BaseSettings):
|
||||
"""Path to the converted cache models directory, resolved to an absolute path.."""
|
||||
return self._resolve(self.convert_cache_dir)
|
||||
|
||||
@property
|
||||
def download_cache_path(self) -> Path:
|
||||
"""Path to the downloaded models directory, resolved to an absolute path.."""
|
||||
return self._resolve(self.download_cache_dir)
|
||||
|
||||
@property
|
||||
def custom_nodes_path(self) -> Path:
|
||||
"""Path to the custom nodes directory, resolved to an absolute path.."""
|
||||
|
@ -1,10 +1,17 @@
|
||||
"""Init file for download queue."""
|
||||
|
||||
from .download_base import DownloadJob, DownloadJobStatus, DownloadQueueServiceBase, UnknownJobIDException
|
||||
from .download_base import (
|
||||
DownloadJob,
|
||||
DownloadJobStatus,
|
||||
DownloadQueueServiceBase,
|
||||
MultiFileDownloadJob,
|
||||
UnknownJobIDException,
|
||||
)
|
||||
from .download_default import DownloadQueueService, TqdmProgress
|
||||
|
||||
__all__ = [
|
||||
"DownloadJob",
|
||||
"MultiFileDownloadJob",
|
||||
"DownloadQueueServiceBase",
|
||||
"DownloadQueueService",
|
||||
"TqdmProgress",
|
||||
|
@ -5,11 +5,13 @@ from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from functools import total_ordering
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, List, Optional
|
||||
from typing import Any, Callable, List, Optional, Set, Union
|
||||
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
|
||||
from invokeai.backend.model_manager.metadata import RemoteModelFile
|
||||
|
||||
|
||||
class DownloadJobStatus(str, Enum):
|
||||
"""State of a download job."""
|
||||
@ -33,30 +35,23 @@ class ServiceInactiveException(Exception):
|
||||
"""This exception is raised when user attempts to initiate a download before the service is started."""
|
||||
|
||||
|
||||
DownloadEventHandler = Callable[["DownloadJob"], None]
|
||||
DownloadExceptionHandler = Callable[["DownloadJob", Optional[Exception]], None]
|
||||
SingleFileDownloadEventHandler = Callable[["DownloadJob"], None]
|
||||
SingleFileDownloadExceptionHandler = Callable[["DownloadJob", Optional[Exception]], None]
|
||||
MultiFileDownloadEventHandler = Callable[["MultiFileDownloadJob"], None]
|
||||
MultiFileDownloadExceptionHandler = Callable[["MultiFileDownloadJob", Optional[Exception]], None]
|
||||
DownloadEventHandler = Union[SingleFileDownloadEventHandler, MultiFileDownloadEventHandler]
|
||||
DownloadExceptionHandler = Union[SingleFileDownloadExceptionHandler, MultiFileDownloadExceptionHandler]
|
||||
|
||||
|
||||
@total_ordering
|
||||
class DownloadJob(BaseModel):
|
||||
"""Class to monitor and control a model download request."""
|
||||
class DownloadJobBase(BaseModel):
|
||||
"""Base of classes to monitor and control downloads."""
|
||||
|
||||
# required variables to be passed in on creation
|
||||
source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
|
||||
dest: Path = Field(description="Destination of downloaded model on local disk; a directory or file path")
|
||||
access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")
|
||||
# automatically assigned on creation
|
||||
id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel
|
||||
priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
|
||||
|
||||
# set internally during download process
|
||||
dest: Path = Field(description="Initial destination of downloaded model on local disk; a directory or file path")
|
||||
download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file or directory")
|
||||
status: DownloadJobStatus = Field(default=DownloadJobStatus.WAITING, description="Status of the download")
|
||||
download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file")
|
||||
job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started")
|
||||
job_ended: Optional[str] = Field(
|
||||
default=None, description="Timestamp for when the download job ende1d (completed or errored)"
|
||||
)
|
||||
content_type: Optional[str] = Field(default=None, description="Content type of downloaded file")
|
||||
bytes: int = Field(default=0, description="Bytes downloaded so far")
|
||||
total_bytes: int = Field(default=0, description="Total file size (bytes)")
|
||||
|
||||
@ -74,14 +69,6 @@ class DownloadJob(BaseModel):
|
||||
_on_cancelled: Optional[DownloadEventHandler] = PrivateAttr(default=None)
|
||||
_on_error: Optional[DownloadExceptionHandler] = PrivateAttr(default=None)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Return hash of the string representation of this object, for indexing."""
|
||||
return hash(str(self))
|
||||
|
||||
def __le__(self, other: "DownloadJob") -> bool:
|
||||
"""Return True if this job's priority is less than another's."""
|
||||
return self.priority <= other.priority
|
||||
|
||||
def cancel(self) -> None:
|
||||
"""Call to cancel the job."""
|
||||
self._cancelled = True
|
||||
@ -98,6 +85,11 @@ class DownloadJob(BaseModel):
|
||||
"""Return true if job completed without errors."""
|
||||
return self.status == DownloadJobStatus.COMPLETED
|
||||
|
||||
@property
|
||||
def waiting(self) -> bool:
|
||||
"""Return true if the job is waiting to run."""
|
||||
return self.status == DownloadJobStatus.WAITING
|
||||
|
||||
@property
|
||||
def running(self) -> bool:
|
||||
"""Return true if the job is running."""
|
||||
@ -154,6 +146,37 @@ class DownloadJob(BaseModel):
|
||||
self._on_cancelled = on_cancelled
|
||||
|
||||
|
||||
@total_ordering
|
||||
class DownloadJob(DownloadJobBase):
|
||||
"""Class to monitor and control a model download request."""
|
||||
|
||||
# required variables to be passed in on creation
|
||||
source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
|
||||
access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")
|
||||
priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
|
||||
|
||||
# set internally during download process
|
||||
job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started")
|
||||
job_ended: Optional[str] = Field(
|
||||
default=None, description="Timestamp for when the download job ende1d (completed or errored)"
|
||||
)
|
||||
content_type: Optional[str] = Field(default=None, description="Content type of downloaded file")
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Return hash of the string representation of this object, for indexing."""
|
||||
return hash(str(self))
|
||||
|
||||
def __le__(self, other: "DownloadJob") -> bool:
|
||||
"""Return True if this job's priority is less than another's."""
|
||||
return self.priority <= other.priority
|
||||
|
||||
|
||||
class MultiFileDownloadJob(DownloadJobBase):
|
||||
"""Class to monitor and control multifile downloads."""
|
||||
|
||||
download_parts: Set[DownloadJob] = Field(default_factory=set, description="List of download parts.")
|
||||
|
||||
|
||||
class DownloadQueueServiceBase(ABC):
|
||||
"""Multithreaded queue for downloading models via URL."""
|
||||
|
||||
@ -201,6 +224,48 @@ class DownloadQueueServiceBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def multifile_download(
|
||||
self,
|
||||
parts: List[RemoteModelFile],
|
||||
dest: Path,
|
||||
access_token: Optional[str] = None,
|
||||
submit_job: bool = True,
|
||||
on_start: Optional[DownloadEventHandler] = None,
|
||||
on_progress: Optional[DownloadEventHandler] = None,
|
||||
on_complete: Optional[DownloadEventHandler] = None,
|
||||
on_cancelled: Optional[DownloadEventHandler] = None,
|
||||
on_error: Optional[DownloadExceptionHandler] = None,
|
||||
) -> MultiFileDownloadJob:
|
||||
"""
|
||||
Create and enqueue a multifile download job.
|
||||
|
||||
:param parts: Set of URL / filename pairs
|
||||
:param dest: Path to download to. See below.
|
||||
:param access_token: Access token to download the indicated files. If not provided,
|
||||
each file's URL may be matched to an access token using the config file matching
|
||||
system.
|
||||
:param submit_job: If true [default] then submit the job for execution. Otherwise,
|
||||
you will need to pass the job to submit_multifile_download().
|
||||
:param on_start, on_progress, on_complete, on_error: Callbacks for the indicated
|
||||
events.
|
||||
:returns: A MultiFileDownloadJob object for monitoring the state of the download.
|
||||
|
||||
The `dest` argument is a Path object pointing to a directory. All downloads
|
||||
with be placed inside this directory. The callbacks will receive the
|
||||
MultiFileDownloadJob.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def submit_multifile_download(self, job: MultiFileDownloadJob) -> None:
|
||||
"""
|
||||
Enqueue a previously-created multi-file download job.
|
||||
|
||||
:param job: A MultiFileDownloadJob created with multifile_download()
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def submit_download_job(
|
||||
self,
|
||||
@ -252,7 +317,7 @@ class DownloadQueueServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_job(self, job: DownloadJob) -> None:
|
||||
def cancel_job(self, job: DownloadJobBase) -> None:
|
||||
"""Cancel the job, clearing partial downloads and putting it into ERROR state."""
|
||||
pass
|
||||
|
||||
@ -262,7 +327,7 @@ class DownloadQueueServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob:
|
||||
def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase:
|
||||
"""Wait until the indicated download job has reached a terminal state.
|
||||
|
||||
This will block until the indicated install job has completed,
|
||||
|
@ -8,23 +8,28 @@ import time
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from queue import Empty, PriorityQueue
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set
|
||||
|
||||
import requests
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests import HTTPError
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig, get_config
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.util.misc import get_iso_timestamp
|
||||
from invokeai.backend.model_manager.metadata import RemoteModelFile
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from .download_base import (
|
||||
DownloadEventHandler,
|
||||
DownloadExceptionHandler,
|
||||
DownloadJob,
|
||||
DownloadJobBase,
|
||||
DownloadJobCancelledException,
|
||||
DownloadJobStatus,
|
||||
DownloadQueueServiceBase,
|
||||
MultiFileDownloadJob,
|
||||
ServiceInactiveException,
|
||||
UnknownJobIDException,
|
||||
)
|
||||
@ -42,20 +47,24 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
def __init__(
|
||||
self,
|
||||
max_parallel_dl: int = 5,
|
||||
app_config: Optional[InvokeAIAppConfig] = None,
|
||||
event_bus: Optional["EventServiceBase"] = None,
|
||||
requests_session: Optional[requests.sessions.Session] = None,
|
||||
):
|
||||
"""
|
||||
Initialize DownloadQueue.
|
||||
|
||||
:param app_config: InvokeAIAppConfig object
|
||||
:param max_parallel_dl: Number of simultaneous downloads allowed [5].
|
||||
:param requests_session: Optional requests.sessions.Session object, for unit tests.
|
||||
"""
|
||||
self._app_config = app_config or get_config()
|
||||
self._jobs: Dict[int, DownloadJob] = {}
|
||||
self._download_part2parent: Dict[AnyHttpUrl, MultiFileDownloadJob] = {}
|
||||
self._next_job_id = 0
|
||||
self._queue: PriorityQueue[DownloadJob] = PriorityQueue()
|
||||
self._stop_event = threading.Event()
|
||||
self._job_completed_event = threading.Event()
|
||||
self._job_terminated_event = threading.Event()
|
||||
self._worker_pool: Set[threading.Thread] = set()
|
||||
self._lock = threading.Lock()
|
||||
self._logger = InvokeAILogger.get_logger("DownloadQueueService")
|
||||
@ -107,18 +116,16 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
raise ServiceInactiveException(
|
||||
"The download service is not currently accepting requests. Please call start() to initialize the service."
|
||||
)
|
||||
with self._lock:
|
||||
job.id = self._next_job_id
|
||||
self._next_job_id += 1
|
||||
job.set_callbacks(
|
||||
on_start=on_start,
|
||||
on_progress=on_progress,
|
||||
on_complete=on_complete,
|
||||
on_cancelled=on_cancelled,
|
||||
on_error=on_error,
|
||||
)
|
||||
self._jobs[job.id] = job
|
||||
self._queue.put(job)
|
||||
job.id = self._next_id()
|
||||
job.set_callbacks(
|
||||
on_start=on_start,
|
||||
on_progress=on_progress,
|
||||
on_complete=on_complete,
|
||||
on_cancelled=on_cancelled,
|
||||
on_error=on_error,
|
||||
)
|
||||
self._jobs[job.id] = job
|
||||
self._queue.put(job)
|
||||
|
||||
def download(
|
||||
self,
|
||||
@ -141,7 +148,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
source=source,
|
||||
dest=dest,
|
||||
priority=priority,
|
||||
access_token=access_token,
|
||||
access_token=access_token or self._lookup_access_token(source),
|
||||
)
|
||||
self.submit_download_job(
|
||||
job,
|
||||
@ -153,10 +160,63 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
)
|
||||
return job
|
||||
|
||||
def multifile_download(
|
||||
self,
|
||||
parts: List[RemoteModelFile],
|
||||
dest: Path,
|
||||
access_token: Optional[str] = None,
|
||||
submit_job: bool = True,
|
||||
on_start: Optional[DownloadEventHandler] = None,
|
||||
on_progress: Optional[DownloadEventHandler] = None,
|
||||
on_complete: Optional[DownloadEventHandler] = None,
|
||||
on_cancelled: Optional[DownloadEventHandler] = None,
|
||||
on_error: Optional[DownloadExceptionHandler] = None,
|
||||
) -> MultiFileDownloadJob:
|
||||
mfdj = MultiFileDownloadJob(dest=dest, id=self._next_id())
|
||||
mfdj.set_callbacks(
|
||||
on_start=on_start,
|
||||
on_progress=on_progress,
|
||||
on_complete=on_complete,
|
||||
on_cancelled=on_cancelled,
|
||||
on_error=on_error,
|
||||
)
|
||||
|
||||
for part in parts:
|
||||
url = part.url
|
||||
path = dest / part.path
|
||||
assert path.is_relative_to(dest), "only relative download paths accepted"
|
||||
job = DownloadJob(
|
||||
source=url,
|
||||
dest=path,
|
||||
access_token=access_token,
|
||||
)
|
||||
mfdj.download_parts.add(job)
|
||||
self._download_part2parent[job.source] = mfdj
|
||||
if submit_job:
|
||||
self.submit_multifile_download(mfdj)
|
||||
return mfdj
|
||||
|
||||
def submit_multifile_download(self, job: MultiFileDownloadJob) -> None:
|
||||
for download_job in job.download_parts:
|
||||
self.submit_download_job(
|
||||
download_job,
|
||||
on_start=self._mfd_started,
|
||||
on_progress=self._mfd_progress,
|
||||
on_complete=self._mfd_complete,
|
||||
on_cancelled=self._mfd_cancelled,
|
||||
on_error=self._mfd_error,
|
||||
)
|
||||
|
||||
def join(self) -> None:
|
||||
"""Wait for all jobs to complete."""
|
||||
self._queue.join()
|
||||
|
||||
def _next_id(self) -> int:
|
||||
with self._lock:
|
||||
id = self._next_job_id
|
||||
self._next_job_id += 1
|
||||
return id
|
||||
|
||||
def list_jobs(self) -> List[DownloadJob]:
|
||||
"""List all the jobs."""
|
||||
return list(self._jobs.values())
|
||||
@ -178,14 +238,14 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
except KeyError as excp:
|
||||
raise UnknownJobIDException("Unrecognized job") from excp
|
||||
|
||||
def cancel_job(self, job: DownloadJob) -> None:
|
||||
def cancel_job(self, job: DownloadJobBase) -> None:
|
||||
"""
|
||||
Cancel the indicated job.
|
||||
|
||||
If it is running it will be stopped.
|
||||
job.status will be set to DownloadJobStatus.CANCELLED
|
||||
"""
|
||||
with self._lock:
|
||||
if job.status in [DownloadJobStatus.WAITING, DownloadJobStatus.RUNNING]:
|
||||
job.cancel()
|
||||
|
||||
def cancel_all_jobs(self) -> None:
|
||||
@ -194,12 +254,12 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
if not job.in_terminal_state:
|
||||
self.cancel_job(job)
|
||||
|
||||
def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob:
|
||||
def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase:
|
||||
"""Block until the indicated job has reached terminal state, or when timeout limit reached."""
|
||||
start = time.time()
|
||||
while not job.in_terminal_state:
|
||||
if self._job_completed_event.wait(timeout=0.25): # in case we miss an event
|
||||
self._job_completed_event.clear()
|
||||
if self._job_terminated_event.wait(timeout=0.25): # in case we miss an event
|
||||
self._job_terminated_event.clear()
|
||||
if timeout > 0 and time.time() - start > timeout:
|
||||
raise TimeoutError("Timeout exceeded")
|
||||
return job
|
||||
@ -228,22 +288,25 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
job.job_started = get_iso_timestamp()
|
||||
self._do_download(job)
|
||||
self._signal_job_complete(job)
|
||||
except (OSError, HTTPError) as excp:
|
||||
job.error_type = excp.__class__.__name__ + f"({str(excp)})"
|
||||
job.error = traceback.format_exc()
|
||||
self._signal_job_error(job, excp)
|
||||
except DownloadJobCancelledException:
|
||||
self._signal_job_cancelled(job)
|
||||
self._cleanup_cancelled_job(job)
|
||||
|
||||
except Exception as excp:
|
||||
job.error_type = excp.__class__.__name__ + f"({str(excp)})"
|
||||
job.error = traceback.format_exc()
|
||||
self._signal_job_error(job, excp)
|
||||
finally:
|
||||
job.job_ended = get_iso_timestamp()
|
||||
self._job_completed_event.set() # signal a change to terminal state
|
||||
self._job_terminated_event.set() # signal a change to terminal state
|
||||
self._download_part2parent.pop(job.source, None) # if this is a subpart of a multipart job, remove it
|
||||
self._job_terminated_event.set()
|
||||
self._queue.task_done()
|
||||
|
||||
self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.")
|
||||
|
||||
def _do_download(self, job: DownloadJob) -> None:
|
||||
"""Do the actual download."""
|
||||
|
||||
url = job.source
|
||||
header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
|
||||
open_mode = "wb"
|
||||
@ -335,38 +398,29 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
def _in_progress_path(self, path: Path) -> Path:
|
||||
return path.with_name(path.name + ".downloading")
|
||||
|
||||
def _lookup_access_token(self, source: AnyHttpUrl) -> Optional[str]:
|
||||
# Pull the token from config if it exists and matches the URL
|
||||
token = None
|
||||
for pair in self._app_config.remote_api_tokens or []:
|
||||
if re.search(pair.url_regex, str(source)):
|
||||
token = pair.token
|
||||
break
|
||||
return token
|
||||
|
||||
def _signal_job_started(self, job: DownloadJob) -> None:
|
||||
job.status = DownloadJobStatus.RUNNING
|
||||
if job.on_start:
|
||||
try:
|
||||
job.on_start(job)
|
||||
except Exception as e:
|
||||
self._logger.error(
|
||||
f"An error occurred while processing the on_start callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
self._execute_cb(job, "on_start")
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_download_started(job)
|
||||
|
||||
def _signal_job_progress(self, job: DownloadJob) -> None:
|
||||
if job.on_progress:
|
||||
try:
|
||||
job.on_progress(job)
|
||||
except Exception as e:
|
||||
self._logger.error(
|
||||
f"An error occurred while processing the on_progress callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
self._execute_cb(job, "on_progress")
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_download_progress(job)
|
||||
|
||||
def _signal_job_complete(self, job: DownloadJob) -> None:
|
||||
job.status = DownloadJobStatus.COMPLETED
|
||||
if job.on_complete:
|
||||
try:
|
||||
job.on_complete(job)
|
||||
except Exception as e:
|
||||
self._logger.error(
|
||||
f"An error occurred while processing the on_complete callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
self._execute_cb(job, "on_complete")
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_download_complete(job)
|
||||
|
||||
@ -374,26 +428,21 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
if job.status not in [DownloadJobStatus.RUNNING, DownloadJobStatus.WAITING]:
|
||||
return
|
||||
job.status = DownloadJobStatus.CANCELLED
|
||||
if job.on_cancelled:
|
||||
try:
|
||||
job.on_cancelled(job)
|
||||
except Exception as e:
|
||||
self._logger.error(
|
||||
f"An error occurred while processing the on_cancelled callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
self._execute_cb(job, "on_cancelled")
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_download_cancelled(job)
|
||||
|
||||
# if multifile download, then signal the parent
|
||||
if parent_job := self._download_part2parent.get(job.source, None):
|
||||
if not parent_job.in_terminal_state:
|
||||
parent_job.status = DownloadJobStatus.CANCELLED
|
||||
self._execute_cb(parent_job, "on_cancelled")
|
||||
|
||||
def _signal_job_error(self, job: DownloadJob, excp: Optional[Exception] = None) -> None:
|
||||
job.status = DownloadJobStatus.ERROR
|
||||
self._logger.error(f"{str(job.source)}: {traceback.format_exception(excp)}")
|
||||
if job.on_error:
|
||||
try:
|
||||
job.on_error(job, excp)
|
||||
except Exception as e:
|
||||
self._logger.error(
|
||||
f"An error occurred while processing the on_error callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
self._execute_cb(job, "on_error", excp)
|
||||
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_download_error(job)
|
||||
|
||||
@ -406,6 +455,97 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
except OSError as excp:
|
||||
self._logger.warning(excp)
|
||||
|
||||
########################################
|
||||
# callbacks used for multifile downloads
|
||||
########################################
|
||||
def _mfd_started(self, download_job: DownloadJob) -> None:
|
||||
self._logger.info(f"File download started: {download_job.source}")
|
||||
with self._lock:
|
||||
mf_job = self._download_part2parent[download_job.source]
|
||||
if mf_job.waiting:
|
||||
mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts)
|
||||
mf_job.status = DownloadJobStatus.RUNNING
|
||||
assert download_job.download_path is not None
|
||||
path_relative_to_destdir = download_job.download_path.relative_to(mf_job.dest)
|
||||
mf_job.download_path = (
|
||||
mf_job.dest / path_relative_to_destdir.parts[0]
|
||||
) # keep just the first component of the path
|
||||
self._execute_cb(mf_job, "on_start")
|
||||
|
||||
def _mfd_progress(self, download_job: DownloadJob) -> None:
|
||||
with self._lock:
|
||||
mf_job = self._download_part2parent[download_job.source]
|
||||
if mf_job.cancelled:
|
||||
for part in mf_job.download_parts:
|
||||
self.cancel_job(part)
|
||||
elif mf_job.running:
|
||||
mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts)
|
||||
mf_job.bytes = sum(x.total_bytes for x in mf_job.download_parts)
|
||||
self._execute_cb(mf_job, "on_progress")
|
||||
|
||||
def _mfd_complete(self, download_job: DownloadJob) -> None:
|
||||
self._logger.info(f"Download complete: {download_job.source}")
|
||||
with self._lock:
|
||||
mf_job = self._download_part2parent[download_job.source]
|
||||
|
||||
# are there any more active jobs left in this task?
|
||||
if mf_job.running and all(x.complete for x in mf_job.download_parts):
|
||||
mf_job.status = DownloadJobStatus.COMPLETED
|
||||
self._execute_cb(mf_job, "on_complete")
|
||||
|
||||
# we're done with this sub-job
|
||||
self._job_terminated_event.set()
|
||||
|
||||
def _mfd_cancelled(self, download_job: DownloadJob) -> None:
|
||||
with self._lock:
|
||||
mf_job = self._download_part2parent[download_job.source]
|
||||
assert mf_job is not None
|
||||
|
||||
if not mf_job.in_terminal_state:
|
||||
self._logger.warning(f"Download cancelled: {download_job.source}")
|
||||
mf_job.cancel()
|
||||
|
||||
for s in mf_job.download_parts:
|
||||
self.cancel_job(s)
|
||||
|
||||
def _mfd_error(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
|
||||
with self._lock:
|
||||
mf_job = self._download_part2parent[download_job.source]
|
||||
assert mf_job is not None
|
||||
if not mf_job.in_terminal_state:
|
||||
mf_job.status = download_job.status
|
||||
mf_job.error = download_job.error
|
||||
mf_job.error_type = download_job.error_type
|
||||
self._execute_cb(mf_job, "on_error", excp)
|
||||
self._logger.error(
|
||||
f"Cancelling {mf_job.dest} due to an error while downloading {download_job.source}: {str(excp)}"
|
||||
)
|
||||
for s in [x for x in mf_job.download_parts if x.running]:
|
||||
self.cancel_job(s)
|
||||
self._download_part2parent.pop(download_job.source)
|
||||
self._job_terminated_event.set()
|
||||
|
||||
def _execute_cb(
|
||||
self,
|
||||
job: DownloadJob | MultiFileDownloadJob,
|
||||
callback_name: Literal[
|
||||
"on_start",
|
||||
"on_progress",
|
||||
"on_complete",
|
||||
"on_cancelled",
|
||||
"on_error",
|
||||
],
|
||||
excp: Optional[Exception] = None,
|
||||
) -> None:
|
||||
if callback := getattr(job, callback_name, None):
|
||||
args = [job, excp] if excp else [job]
|
||||
try:
|
||||
callback(*args)
|
||||
except Exception as e:
|
||||
self._logger.error(
|
||||
f"An error occurred while processing the {callback_name} callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
|
||||
|
||||
def get_pc_name_max(directory: str) -> int:
|
||||
if hasattr(os, "pathconf"):
|
||||
|
@ -13,7 +13,7 @@ from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig
|
||||
from invokeai.backend.model_manager import AnyModelConfig
|
||||
|
||||
|
||||
class ModelInstallServiceBase(ABC):
|
||||
@ -243,12 +243,11 @@ class ModelInstallServiceBase(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def download_and_cache(self, source: Union[str, AnyHttpUrl], access_token: Optional[str] = None) -> Path:
|
||||
def download_and_cache_model(self, source: str | AnyHttpUrl) -> Path:
|
||||
"""
|
||||
Download the model file located at source to the models cache and return its Path.
|
||||
|
||||
:param source: A Url or a string that can be converted into one.
|
||||
:param access_token: Optional access token to access restricted resources.
|
||||
:param source: A string representing a URL or repo_id.
|
||||
|
||||
The model file will be downloaded into the system-wide model cache
|
||||
(`models/.cache`) if it isn't already there. Note that the model cache
|
||||
|
@ -8,7 +8,7 @@ from pydantic import BaseModel, Field, PrivateAttr, field_validator
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.download import DownloadJob
|
||||
from invokeai.app.services.download import DownloadJob, MultiFileDownloadJob
|
||||
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
||||
from invokeai.backend.model_manager.config import ModelSourceType
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
@ -26,13 +26,6 @@ class InstallStatus(str, Enum):
|
||||
CANCELLED = "cancelled" # terminated with an error message
|
||||
|
||||
|
||||
class ModelInstallPart(BaseModel):
|
||||
url: AnyHttpUrl
|
||||
path: Path
|
||||
bytes: int = 0
|
||||
total_bytes: int = 0
|
||||
|
||||
|
||||
class UnknownInstallJobException(Exception):
|
||||
"""Raised when the status of an unknown job is requested."""
|
||||
|
||||
@ -169,6 +162,7 @@ class ModelInstallJob(BaseModel):
|
||||
)
|
||||
# internal flags and transitory settings
|
||||
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
|
||||
_multifile_job: Optional[MultiFileDownloadJob] = PrivateAttr(default=None)
|
||||
_exception: Optional[Exception] = PrivateAttr(default=None)
|
||||
|
||||
def set_error(self, e: Exception) -> None:
|
||||
|
@ -5,21 +5,22 @@ import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from hashlib import sha256
|
||||
from pathlib import Path
|
||||
from queue import Empty, Queue
|
||||
from shutil import copyfile, copytree, move, rmtree
|
||||
from tempfile import mkdtemp
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
from huggingface_hub import HfFolder
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from pydantic_core import Url
|
||||
from requests import Session
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress
|
||||
from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase
|
||||
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
|
||||
@ -44,6 +45,7 @@ from invokeai.backend.model_manager.search import ModelSearch
|
||||
from invokeai.backend.util import InvokeAILogger
|
||||
from invokeai.backend.util.catch_sigint import catch_sigint
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.util import slugify
|
||||
|
||||
from .model_install_common import (
|
||||
MODEL_SOURCE_TO_TYPE_MAP,
|
||||
@ -91,7 +93,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._downloads_changed_event = threading.Event()
|
||||
self._install_completed_event = threading.Event()
|
||||
self._download_queue = download_queue
|
||||
self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {}
|
||||
self._download_cache: Dict[int, ModelInstallJob] = {}
|
||||
self._running = False
|
||||
self._session = session
|
||||
self._install_thread: Optional[threading.Thread] = None
|
||||
@ -210,33 +212,12 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
access_token: Optional[str] = None,
|
||||
inplace: Optional[bool] = False,
|
||||
) -> ModelInstallJob:
|
||||
variants = "|".join(ModelRepoVariant.__members__.values())
|
||||
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
|
||||
source_obj: Optional[StringLikeSource] = None
|
||||
|
||||
if Path(source).exists(): # A local file or directory
|
||||
source_obj = LocalModelSource(path=Path(source), inplace=inplace)
|
||||
elif match := re.match(hf_repoid_re, source):
|
||||
source_obj = HFModelSource(
|
||||
repo_id=match.group(1),
|
||||
variant=match.group(2) if match.group(2) else None, # pass None rather than ''
|
||||
subfolder=Path(match.group(3)) if match.group(3) else None,
|
||||
access_token=access_token,
|
||||
)
|
||||
elif re.match(r"^https?://[^/]+", source):
|
||||
# Pull the token from config if it exists and matches the URL
|
||||
_token = access_token
|
||||
if _token is None:
|
||||
for pair in self.app_config.remote_api_tokens or []:
|
||||
if re.search(pair.url_regex, source):
|
||||
_token = pair.token
|
||||
break
|
||||
source_obj = URLModelSource(
|
||||
url=AnyHttpUrl(source),
|
||||
access_token=_token,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model source: '{source}'")
|
||||
"""Install a model using pattern matching to infer the type of source."""
|
||||
source_obj = self._guess_source(source)
|
||||
if isinstance(source_obj, LocalModelSource):
|
||||
source_obj.inplace = inplace
|
||||
elif isinstance(source_obj, HFModelSource) or isinstance(source_obj, URLModelSource):
|
||||
source_obj.access_token = access_token
|
||||
return self.import_model(source_obj, config)
|
||||
|
||||
def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102
|
||||
@ -297,8 +278,9 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def cancel_job(self, job: ModelInstallJob) -> None:
|
||||
"""Cancel the indicated job."""
|
||||
job.cancel()
|
||||
with self._lock:
|
||||
self._cancel_download_parts(job)
|
||||
self._logger.warning(f"Cancelling {job.source}")
|
||||
if dj := job._multifile_job:
|
||||
self._download_queue.cancel_job(dj)
|
||||
|
||||
def prune_jobs(self) -> None:
|
||||
"""Prune all completed and errored jobs."""
|
||||
@ -346,7 +328,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
legacy_config_path = stanza.get("config")
|
||||
if legacy_config_path:
|
||||
# In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir.
|
||||
legacy_config_path: Path = self._app_config.root_path / legacy_config_path
|
||||
legacy_config_path = self._app_config.root_path / legacy_config_path
|
||||
if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path):
|
||||
legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path)
|
||||
config["config_path"] = str(legacy_config_path)
|
||||
@ -386,38 +368,92 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
rmtree(model_path)
|
||||
self.unregister(key)
|
||||
|
||||
def download_and_cache(
|
||||
@classmethod
|
||||
def _download_cache_path(cls, source: Union[str, AnyHttpUrl], app_config: InvokeAIAppConfig) -> Path:
|
||||
escaped_source = slugify(str(source))
|
||||
return app_config.download_cache_path / escaped_source
|
||||
|
||||
def download_and_cache_model(
|
||||
self,
|
||||
source: Union[str, AnyHttpUrl],
|
||||
access_token: Optional[str] = None,
|
||||
timeout: int = 0,
|
||||
source: str | AnyHttpUrl,
|
||||
) -> Path:
|
||||
"""Download the model file located at source to the models cache and return its Path."""
|
||||
model_hash = sha256(str(source).encode("utf-8")).hexdigest()[0:32]
|
||||
model_path = self._app_config.convert_cache_path / model_hash
|
||||
model_path = self._download_cache_path(str(source), self._app_config)
|
||||
|
||||
# We expect the cache directory to contain one and only one downloaded file.
|
||||
# We expect the cache directory to contain one and only one downloaded file or directory.
|
||||
# We don't know the file's name in advance, as it is set by the download
|
||||
# content-disposition header.
|
||||
if model_path.exists():
|
||||
contents = [x for x in model_path.iterdir() if x.is_file()]
|
||||
contents: List[Path] = list(model_path.iterdir())
|
||||
if len(contents) > 0:
|
||||
return contents[0]
|
||||
|
||||
model_path.mkdir(parents=True, exist_ok=True)
|
||||
job = self._download_queue.download(
|
||||
source=AnyHttpUrl(str(source)),
|
||||
model_source = self._guess_source(str(source))
|
||||
remote_files, _ = self._remote_files_from_source(model_source)
|
||||
job = self._multifile_download(
|
||||
dest=model_path,
|
||||
access_token=access_token,
|
||||
on_progress=TqdmProgress().update,
|
||||
remote_files=remote_files,
|
||||
subfolder=model_source.subfolder if isinstance(model_source, HFModelSource) else None,
|
||||
)
|
||||
self._download_queue.wait_for_job(job, timeout)
|
||||
files_string = "file" if len(remote_files) == 1 else "files"
|
||||
self._logger.info(f"Queuing model download: {source} ({len(remote_files)} {files_string})")
|
||||
self._download_queue.wait_for_job(job)
|
||||
if job.complete:
|
||||
assert job.download_path is not None
|
||||
return job.download_path
|
||||
else:
|
||||
raise Exception(job.error)
|
||||
|
||||
def _remote_files_from_source(
|
||||
self, source: ModelSource
|
||||
) -> Tuple[List[RemoteModelFile], Optional[AnyModelRepoMetadata]]:
|
||||
metadata = None
|
||||
if isinstance(source, HFModelSource):
|
||||
metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant)
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
return metadata.download_urls(
|
||||
variant=source.variant or self._guess_variant(),
|
||||
subfolder=source.subfolder,
|
||||
session=self._session,
|
||||
), metadata
|
||||
|
||||
if isinstance(source, URLModelSource):
|
||||
try:
|
||||
fetcher = self.get_fetcher_from_url(str(source.url))
|
||||
kwargs: dict[str, Any] = {"session": self._session}
|
||||
metadata = fetcher(**kwargs).from_url(source.url)
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
return metadata.download_urls(session=self._session), metadata
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return [RemoteModelFile(url=source.url, path=Path("."), size=0)], None
|
||||
|
||||
raise Exception(f"No files associated with {source}")
|
||||
|
||||
def _guess_source(self, source: str) -> ModelSource:
|
||||
"""Turn a source string into a ModelSource object."""
|
||||
variants = "|".join(ModelRepoVariant.__members__.values())
|
||||
hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$"
|
||||
source_obj: Optional[StringLikeSource] = None
|
||||
|
||||
if Path(source).exists(): # A local file or directory
|
||||
source_obj = LocalModelSource(path=Path(source))
|
||||
elif match := re.match(hf_repoid_re, source):
|
||||
source_obj = HFModelSource(
|
||||
repo_id=match.group(1),
|
||||
variant=ModelRepoVariant(match.group(2)) if match.group(2) else None, # pass None rather than ''
|
||||
subfolder=Path(match.group(3)) if match.group(3) else None,
|
||||
)
|
||||
elif re.match(r"^https?://[^/]+", source):
|
||||
source_obj = URLModelSource(
|
||||
url=Url(source),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model source: '{source}'")
|
||||
return source_obj
|
||||
|
||||
# --------------------------------------------------------------------------------------------
|
||||
# Internal functions that manage the installer threads
|
||||
# --------------------------------------------------------------------------------------------
|
||||
@ -478,16 +514,19 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
job.config_out = self.record_store.get_model(key)
|
||||
self._signal_job_completed(job)
|
||||
|
||||
def _set_error(self, job: ModelInstallJob, excp: Exception) -> None:
|
||||
if any(x.content_type is not None and "text/html" in x.content_type for x in job.download_parts):
|
||||
job.set_error(
|
||||
def _set_error(self, install_job: ModelInstallJob, excp: Exception) -> None:
|
||||
multifile_download_job = install_job._multifile_job
|
||||
if multifile_download_job and any(
|
||||
x.content_type is not None and "text/html" in x.content_type for x in multifile_download_job.download_parts
|
||||
):
|
||||
install_job.set_error(
|
||||
InvalidModelConfigException(
|
||||
f"At least one file in {job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
|
||||
f"At least one file in {install_job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
|
||||
)
|
||||
)
|
||||
else:
|
||||
job.set_error(excp)
|
||||
self._signal_job_errored(job)
|
||||
install_job.set_error(excp)
|
||||
self._signal_job_errored(install_job)
|
||||
|
||||
# --------------------------------------------------------------------------------------------
|
||||
# Internal functions that manage the models directory
|
||||
@ -513,7 +552,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
This is typically only used during testing with a new DB or when using the memory DB, because those are the
|
||||
only situations in which we may have orphaned models in the models directory.
|
||||
"""
|
||||
|
||||
installed_model_paths = {
|
||||
(self._app_config.models_path / x.path).resolve() for x in self.record_store.all_models()
|
||||
}
|
||||
@ -525,8 +563,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
if resolved_path in installed_model_paths:
|
||||
return True
|
||||
# Skip core models entirely - these aren't registered with the model manager.
|
||||
if str(resolved_path).startswith(str(self.app_config.models_path / "core")):
|
||||
return False
|
||||
for special_directory in [
|
||||
self.app_config.models_path / "core",
|
||||
self.app_config.convert_cache_dir,
|
||||
self.app_config.download_cache_dir,
|
||||
]:
|
||||
if resolved_path.is_relative_to(special_directory):
|
||||
return False
|
||||
try:
|
||||
model_id = self.register_path(model_path)
|
||||
self._logger.info(f"Registered {model_path.name} with id {model_id}")
|
||||
@ -641,20 +684,15 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
inplace=source.inplace or False,
|
||||
)
|
||||
|
||||
def _import_from_hf(self, source: HFModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||
def _import_from_hf(
|
||||
self,
|
||||
source: HFModelSource,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
) -> ModelInstallJob:
|
||||
# Add user's cached access token to HuggingFace requests
|
||||
source.access_token = source.access_token or HfFolder.get_token()
|
||||
if not source.access_token:
|
||||
self._logger.info("No HuggingFace access token present; some models may not be downloadable.")
|
||||
|
||||
metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant)
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
remote_files = metadata.download_urls(
|
||||
variant=source.variant or self._guess_variant(),
|
||||
subfolder=source.subfolder,
|
||||
session=self._session,
|
||||
)
|
||||
|
||||
if source.access_token is None:
|
||||
source.access_token = HfFolder.get_token()
|
||||
remote_files, metadata = self._remote_files_from_source(source)
|
||||
return self._import_remote_model(
|
||||
source=source,
|
||||
config=config,
|
||||
@ -662,22 +700,12 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||
# URLs from HuggingFace will be handled specially
|
||||
metadata = None
|
||||
fetcher = None
|
||||
try:
|
||||
fetcher = self.get_fetcher_from_url(str(source.url))
|
||||
except ValueError:
|
||||
pass
|
||||
kwargs: dict[str, Any] = {"session": self._session}
|
||||
if fetcher is not None:
|
||||
metadata = fetcher(**kwargs).from_url(source.url)
|
||||
self._logger.debug(f"metadata={metadata}")
|
||||
if metadata and isinstance(metadata, ModelMetadataWithFiles):
|
||||
remote_files = metadata.download_urls(session=self._session)
|
||||
else:
|
||||
remote_files = [RemoteModelFile(url=source.url, path=Path("."), size=0)]
|
||||
def _import_from_url(
|
||||
self,
|
||||
source: URLModelSource,
|
||||
config: Optional[Dict[str, Any]],
|
||||
) -> ModelInstallJob:
|
||||
remote_files, metadata = self._remote_files_from_source(source)
|
||||
return self._import_remote_model(
|
||||
source=source,
|
||||
config=config,
|
||||
@ -692,12 +720,9 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
metadata: Optional[AnyModelRepoMetadata],
|
||||
config: Optional[Dict[str, Any]],
|
||||
) -> ModelInstallJob:
|
||||
# TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up.
|
||||
# Currently the tmpdir isn't automatically removed at exit because it is
|
||||
# being held in a daemon thread.
|
||||
if len(remote_files) == 0:
|
||||
raise ValueError(f"{source}: No downloadable files found")
|
||||
tmpdir = Path(
|
||||
destdir = Path(
|
||||
mkdtemp(
|
||||
dir=self._app_config.models_path,
|
||||
prefix=TMPDIR_PREFIX,
|
||||
@ -708,55 +733,28 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
source=source,
|
||||
config_in=config or {},
|
||||
source_metadata=metadata,
|
||||
local_path=tmpdir, # local path may change once the download has started due to content-disposition handling
|
||||
local_path=destdir, # local path may change once the download has started due to content-disposition handling
|
||||
bytes=0,
|
||||
total_bytes=0,
|
||||
)
|
||||
# In the event that there is a subfolder specified in the source,
|
||||
# we need to remove it from the destination path in order to avoid
|
||||
# creating unwanted subfolders
|
||||
if isinstance(source, HFModelSource) and source.subfolder:
|
||||
root = Path(remote_files[0].path.parts[0])
|
||||
subfolder = root / source.subfolder
|
||||
else:
|
||||
root = Path(".")
|
||||
subfolder = Path(".")
|
||||
# remember the temporary directory for later removal
|
||||
install_job._install_tmpdir = destdir
|
||||
install_job.total_bytes = sum((x.size or 0) for x in remote_files)
|
||||
|
||||
# we remember the path up to the top of the tmpdir so that it may be
|
||||
# removed safely at the end of the install process.
|
||||
install_job._install_tmpdir = tmpdir
|
||||
assert install_job.total_bytes is not None # to avoid type checking complaints in the loop below
|
||||
multifile_job = self._multifile_download(
|
||||
remote_files=remote_files,
|
||||
dest=destdir,
|
||||
subfolder=source.subfolder if isinstance(source, HFModelSource) else None,
|
||||
access_token=source.access_token,
|
||||
submit_job=False, # Important! Don't submit the job until we have set our _download_cache dict
|
||||
)
|
||||
self._download_cache[multifile_job.id] = install_job
|
||||
install_job._multifile_job = multifile_job
|
||||
|
||||
files_string = "file" if len(remote_files) == 1 else "file"
|
||||
self._logger.info(f"Queuing model install: {source} ({len(remote_files)} {files_string})")
|
||||
files_string = "file" if len(remote_files) == 1 else "files"
|
||||
self._logger.info(f"Queueing model install: {source} ({len(remote_files)} {files_string})")
|
||||
self._logger.debug(f"remote_files={remote_files}")
|
||||
for model_file in remote_files:
|
||||
url = model_file.url
|
||||
path = root / model_file.path.relative_to(subfolder)
|
||||
self._logger.debug(f"Downloading {url} => {path}")
|
||||
install_job.total_bytes += model_file.size
|
||||
assert hasattr(source, "access_token")
|
||||
dest = tmpdir / path.parent
|
||||
dest.mkdir(parents=True, exist_ok=True)
|
||||
download_job = DownloadJob(
|
||||
source=url,
|
||||
dest=dest,
|
||||
access_token=source.access_token,
|
||||
)
|
||||
self._download_cache[download_job.source] = install_job # matches a download job to an install job
|
||||
install_job.download_parts.add(download_job)
|
||||
|
||||
# only start the jobs once install_job.download_parts is fully populated
|
||||
for download_job in install_job.download_parts:
|
||||
self._download_queue.submit_download_job(
|
||||
download_job,
|
||||
on_start=self._download_started_callback,
|
||||
on_progress=self._download_progress_callback,
|
||||
on_complete=self._download_complete_callback,
|
||||
on_error=self._download_error_callback,
|
||||
on_cancelled=self._download_cancelled_callback,
|
||||
)
|
||||
|
||||
self._download_queue.submit_multifile_download(multifile_job)
|
||||
return install_job
|
||||
|
||||
def _stat_size(self, path: Path) -> int:
|
||||
@ -768,87 +766,104 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
size += sum(self._stat_size(Path(root, x)) for x in files)
|
||||
return size
|
||||
|
||||
def _multifile_download(
|
||||
self,
|
||||
remote_files: List[RemoteModelFile],
|
||||
dest: Path,
|
||||
subfolder: Optional[Path] = None,
|
||||
access_token: Optional[str] = None,
|
||||
submit_job: bool = True,
|
||||
) -> MultiFileDownloadJob:
|
||||
# HuggingFace repo subfolders are a little tricky. If the name of the model is "sdxl-turbo", and
|
||||
# we are installing the "vae" subfolder, we do not want to create an additional folder level, such
|
||||
# as "sdxl-turbo/vae", nor do we want to put the contents of the vae folder directly into "sdxl-turbo".
|
||||
# So what we do is to synthesize a folder named "sdxl-turbo_vae" here.
|
||||
if subfolder:
|
||||
top = Path(remote_files[0].path.parts[0]) # e.g. "sdxl-turbo/"
|
||||
path_to_remove = top / subfolder.parts[-1] # sdxl-turbo/vae/
|
||||
path_to_add = Path(f"{top}_{subfolder}")
|
||||
else:
|
||||
path_to_remove = Path(".")
|
||||
path_to_add = Path(".")
|
||||
|
||||
parts: List[RemoteModelFile] = []
|
||||
for model_file in remote_files:
|
||||
assert model_file.size is not None
|
||||
parts.append(
|
||||
RemoteModelFile(
|
||||
url=model_file.url, # if a subfolder, then sdxl-turbo_vae/config.json
|
||||
path=path_to_add / model_file.path.relative_to(path_to_remove),
|
||||
)
|
||||
)
|
||||
|
||||
return self._download_queue.multifile_download(
|
||||
parts=parts,
|
||||
dest=dest,
|
||||
access_token=access_token,
|
||||
submit_job=submit_job,
|
||||
on_start=self._download_started_callback,
|
||||
on_progress=self._download_progress_callback,
|
||||
on_complete=self._download_complete_callback,
|
||||
on_error=self._download_error_callback,
|
||||
on_cancelled=self._download_cancelled_callback,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Callbacks are executed by the download queue in a separate thread
|
||||
# ------------------------------------------------------------------
|
||||
def _download_started_callback(self, download_job: DownloadJob) -> None:
|
||||
self._logger.info(f"Model download started: {download_job.source}")
|
||||
def _download_started_callback(self, download_job: MultiFileDownloadJob) -> None:
|
||||
with self._lock:
|
||||
install_job = self._download_cache[download_job.source]
|
||||
install_job.status = InstallStatus.DOWNLOADING
|
||||
if install_job := self._download_cache.get(download_job.id, None):
|
||||
install_job.status = InstallStatus.DOWNLOADING
|
||||
|
||||
assert download_job.download_path
|
||||
if install_job.local_path == install_job._install_tmpdir:
|
||||
partial_path = download_job.download_path.relative_to(install_job._install_tmpdir)
|
||||
dest_name = partial_path.parts[0]
|
||||
install_job.local_path = install_job._install_tmpdir / dest_name
|
||||
|
||||
# Update the total bytes count for remote sources.
|
||||
if not install_job.total_bytes:
|
||||
install_job.total_bytes = sum(x.total_bytes for x in install_job.download_parts)
|
||||
|
||||
def _download_progress_callback(self, download_job: DownloadJob) -> None:
|
||||
with self._lock:
|
||||
install_job = self._download_cache[download_job.source]
|
||||
if install_job.cancelled: # This catches the case in which the caller directly calls job.cancel()
|
||||
self._cancel_download_parts(install_job)
|
||||
else:
|
||||
# update sizes
|
||||
install_job.bytes = sum(x.bytes for x in install_job.download_parts)
|
||||
if install_job.local_path == install_job._install_tmpdir: # first time
|
||||
assert download_job.download_path
|
||||
install_job.local_path = download_job.download_path
|
||||
install_job.download_parts = download_job.download_parts
|
||||
install_job.bytes = sum(x.bytes for x in download_job.download_parts)
|
||||
install_job.total_bytes = download_job.total_bytes
|
||||
self._signal_job_downloading(install_job)
|
||||
|
||||
def _download_complete_callback(self, download_job: DownloadJob) -> None:
|
||||
self._logger.info(f"Model download complete: {download_job.source}")
|
||||
def _download_progress_callback(self, download_job: MultiFileDownloadJob) -> None:
|
||||
with self._lock:
|
||||
install_job = self._download_cache[download_job.source]
|
||||
if install_job := self._download_cache.get(download_job.id, None):
|
||||
if install_job.cancelled: # This catches the case in which the caller directly calls job.cancel()
|
||||
self._download_queue.cancel_job(download_job)
|
||||
else:
|
||||
# update sizes
|
||||
install_job.bytes = sum(x.bytes for x in download_job.download_parts)
|
||||
install_job.total_bytes = sum(x.total_bytes for x in download_job.download_parts)
|
||||
self._signal_job_downloading(install_job)
|
||||
|
||||
# are there any more active jobs left in this task?
|
||||
if install_job.downloading and all(x.complete for x in install_job.download_parts):
|
||||
def _download_complete_callback(self, download_job: MultiFileDownloadJob) -> None:
|
||||
with self._lock:
|
||||
if install_job := self._download_cache.pop(download_job.id, None):
|
||||
self._signal_job_downloads_done(install_job)
|
||||
self._put_in_queue(install_job)
|
||||
self._put_in_queue(install_job) # this starts the installation and registration
|
||||
|
||||
# Let other threads know that the number of downloads has changed
|
||||
self._download_cache.pop(download_job.source, None)
|
||||
self._downloads_changed_event.set()
|
||||
# Let other threads know that the number of downloads has changed
|
||||
self._downloads_changed_event.set()
|
||||
|
||||
def _download_error_callback(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
|
||||
def _download_error_callback(self, download_job: MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
|
||||
with self._lock:
|
||||
install_job = self._download_cache.pop(download_job.source, None)
|
||||
assert install_job is not None
|
||||
assert excp is not None
|
||||
install_job.set_error(excp)
|
||||
self._logger.error(
|
||||
f"Cancelling {install_job.source} due to an error while downloading {download_job.source}: {str(excp)}"
|
||||
)
|
||||
self._cancel_download_parts(install_job)
|
||||
if install_job := self._download_cache.pop(download_job.id, None):
|
||||
assert excp is not None
|
||||
install_job.set_error(excp)
|
||||
self._download_queue.cancel_job(download_job)
|
||||
|
||||
# Let other threads know that the number of downloads has changed
|
||||
self._downloads_changed_event.set()
|
||||
# Let other threads know that the number of downloads has changed
|
||||
self._downloads_changed_event.set()
|
||||
|
||||
def _download_cancelled_callback(self, download_job: DownloadJob) -> None:
|
||||
def _download_cancelled_callback(self, download_job: MultiFileDownloadJob) -> None:
|
||||
with self._lock:
|
||||
install_job = self._download_cache.pop(download_job.source, None)
|
||||
if not install_job:
|
||||
return
|
||||
self._downloads_changed_event.set()
|
||||
self._logger.warning(f"Model download canceled: {download_job.source}")
|
||||
# if install job has already registered an error, then do not replace its status with cancelled
|
||||
if not install_job.errored:
|
||||
install_job.cancel()
|
||||
self._cancel_download_parts(install_job)
|
||||
if install_job := self._download_cache.pop(download_job.id, None):
|
||||
self._downloads_changed_event.set()
|
||||
# if install job has already registered an error, then do not replace its status with cancelled
|
||||
if not install_job.errored:
|
||||
install_job.cancel()
|
||||
|
||||
# Let other threads know that the number of downloads has changed
|
||||
self._downloads_changed_event.set()
|
||||
|
||||
def _cancel_download_parts(self, install_job: ModelInstallJob) -> None:
|
||||
# on multipart downloads, _cancel_components() will get called repeatedly from the download callbacks
|
||||
# do not lock here because it gets called within a locked context
|
||||
for s in install_job.download_parts:
|
||||
self._download_queue.cancel_job(s)
|
||||
|
||||
if all(x.in_terminal_state for x in install_job.download_parts):
|
||||
# When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources
|
||||
self._put_in_queue(install_job)
|
||||
# Let other threads know that the number of downloads has changed
|
||||
self._downloads_changed_event.set()
|
||||
|
||||
# ------------------------------------------------------------------------------------------------
|
||||
# Internal methods that put events on the event bus
|
||||
@ -861,6 +876,9 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
|
||||
def _signal_job_downloading(self, job: ModelInstallJob) -> None:
|
||||
if self._event_bus:
|
||||
assert job._multifile_job is not None
|
||||
assert job.bytes is not None
|
||||
assert job.total_bytes is not None
|
||||
self._event_bus.emit_model_install_download_progress(job)
|
||||
|
||||
def _signal_job_downloads_done(self, job: ModelInstallJob) -> None:
|
||||
@ -875,6 +893,8 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._logger.info(f"Model install complete: {job.source}")
|
||||
self._logger.debug(f"{job.local_path} registered key {job.config_out.key}")
|
||||
if self._event_bus:
|
||||
assert job.local_path is not None
|
||||
assert job.config_out is not None
|
||||
self._event_bus.emit_model_install_complete(job)
|
||||
|
||||
def _signal_job_errored(self, job: ModelInstallJob) -> None:
|
||||
@ -890,7 +910,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._event_bus.emit_model_install_cancelled(job)
|
||||
|
||||
@staticmethod
|
||||
def get_fetcher_from_url(url: str) -> ModelMetadataFetchBase:
|
||||
def get_fetcher_from_url(url: str) -> Type[ModelMetadataFetchBase]:
|
||||
"""
|
||||
Return a metadata fetcher appropriate for provided url.
|
||||
|
||||
This used to be more useful, but the number of supported model
|
||||
sources has been reduced to HuggingFace alone.
|
||||
"""
|
||||
if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
|
||||
return HuggingFaceMetadataFetch
|
||||
raise ValueError(f"Unsupported model source: '{url}'")
|
||||
|
@ -2,10 +2,11 @@
|
||||
"""Base class for model loader."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.load import LoadedModel
|
||||
from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
|
||||
@ -31,3 +32,26 @@ class ModelLoadServiceBase(ABC):
|
||||
@abstractmethod
|
||||
def convert_cache(self) -> ModelConvertCacheBase:
|
||||
"""Return the checkpoint convert cache used by this loader."""
|
||||
|
||||
@abstractmethod
|
||||
def load_model_from_path(
|
||||
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
|
||||
) -> LoadedModelWithoutConfig:
|
||||
"""
|
||||
Load the model file or directory located at the indicated Path.
|
||||
|
||||
This will load an arbitrary model file into the RAM cache. If the optional loader
|
||||
argument is provided, the loader will be invoked to load the model into
|
||||
memory. Otherwise the method will call safetensors.torch.load_file() or
|
||||
torch.load() as appropriate to the file suffix.
|
||||
|
||||
Be aware that this returns a LoadedModelWithoutConfig object, which is the same as
|
||||
LoadedModel, but without the config attribute.
|
||||
|
||||
Args:
|
||||
model_path: A pathlib.Path to a checkpoint-style models file
|
||||
loader: A Callable that expects a Path and returns a Dict[str, Tensor]
|
||||
|
||||
Returns:
|
||||
A LoadedModel object.
|
||||
"""
|
||||
|
@ -1,18 +1,26 @@
|
||||
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Team
|
||||
"""Implementation of model loader service."""
|
||||
|
||||
from typing import Optional, Type
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional, Type
|
||||
|
||||
from picklescan.scanner import scan_file_path
|
||||
from safetensors.torch import load_file as safetensors_load_file
|
||||
from torch import load as torch_load
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.load import (
|
||||
LoadedModel,
|
||||
LoadedModelWithoutConfig,
|
||||
ModelLoaderRegistry,
|
||||
ModelLoaderRegistryBase,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from .model_load_base import ModelLoadServiceBase
|
||||
@ -75,3 +83,41 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
self._invoker.services.events.emit_model_load_complete(model_config, submodel_type)
|
||||
|
||||
return loaded_model
|
||||
|
||||
def load_model_from_path(
|
||||
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
|
||||
) -> LoadedModelWithoutConfig:
|
||||
cache_key = str(model_path)
|
||||
ram_cache = self.ram_cache
|
||||
try:
|
||||
return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
def torch_load_file(checkpoint: Path) -> AnyModel:
|
||||
scan_result = scan_file_path(checkpoint)
|
||||
if scan_result.infected_files != 0:
|
||||
raise Exception("The model at {checkpoint} is potentially infected by malware. Aborting load.")
|
||||
result = torch_load(checkpoint, map_location="cpu")
|
||||
return result
|
||||
|
||||
def diffusers_load_directory(directory: Path) -> AnyModel:
|
||||
load_class = GenericDiffusersLoader(
|
||||
app_config=self._app_config,
|
||||
logger=self._logger,
|
||||
ram_cache=self._ram_cache,
|
||||
convert_cache=self.convert_cache,
|
||||
).get_hf_load_class(directory)
|
||||
return load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype())
|
||||
|
||||
loader = loader or (
|
||||
diffusers_load_directory
|
||||
if model_path.is_dir()
|
||||
else torch_load_file
|
||||
if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin"))
|
||||
else lambda path: safetensors_load_file(path, device="cpu")
|
||||
)
|
||||
assert loader is not None
|
||||
raw_model = loader(model_path)
|
||||
ram_cache.put(key=cache_key, model=raw_model)
|
||||
return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
|
||||
|
@ -12,15 +12,13 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
from invokeai.backend.model_manager import (
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import (
|
||||
ControlAdapterDefaultSettings,
|
||||
MainModelDefaultSettings,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
)
|
||||
|
@ -3,6 +3,7 @@ from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Callable, Optional, Union
|
||||
|
||||
from PIL.Image import Image
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from torch import Tensor
|
||||
|
||||
from invokeai.app.invocations.constants import IMAGE_MODES
|
||||
@ -14,8 +15,15 @@ from invokeai.app.services.images.images_common import ImageDTO
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.model_records.model_records_base import UnknownModelException
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||
|
||||
@ -320,8 +328,10 @@ class ConditioningInterface(InvocationContextInterface):
|
||||
|
||||
|
||||
class ModelsInterface(InvocationContextInterface):
|
||||
"""Common API for loading, downloading and managing models."""
|
||||
|
||||
def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool:
|
||||
"""Checks if a model exists.
|
||||
"""Check if a model exists.
|
||||
|
||||
Args:
|
||||
identifier: The key or ModelField representing the model.
|
||||
@ -331,13 +341,13 @@ class ModelsInterface(InvocationContextInterface):
|
||||
"""
|
||||
if isinstance(identifier, str):
|
||||
return self._services.model_manager.store.exists(identifier)
|
||||
|
||||
return self._services.model_manager.store.exists(identifier.key)
|
||||
else:
|
||||
return self._services.model_manager.store.exists(identifier.key)
|
||||
|
||||
def load(
|
||||
self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None
|
||||
) -> LoadedModel:
|
||||
"""Loads a model.
|
||||
"""Load a model.
|
||||
|
||||
Args:
|
||||
identifier: The key or ModelField representing the model.
|
||||
@ -361,7 +371,7 @@ class ModelsInterface(InvocationContextInterface):
|
||||
def load_by_attrs(
|
||||
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
|
||||
) -> LoadedModel:
|
||||
"""Loads a model by its attributes.
|
||||
"""Load a model by its attributes.
|
||||
|
||||
Args:
|
||||
name: Name of the model.
|
||||
@ -384,7 +394,7 @@ class ModelsInterface(InvocationContextInterface):
|
||||
return self._services.model_manager.load.load_model(configs[0], submodel_type)
|
||||
|
||||
def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
|
||||
"""Gets a model's config.
|
||||
"""Get a model's config.
|
||||
|
||||
Args:
|
||||
identifier: The key or ModelField representing the model.
|
||||
@ -394,11 +404,11 @@ class ModelsInterface(InvocationContextInterface):
|
||||
"""
|
||||
if isinstance(identifier, str):
|
||||
return self._services.model_manager.store.get_model(identifier)
|
||||
|
||||
return self._services.model_manager.store.get_model(identifier.key)
|
||||
else:
|
||||
return self._services.model_manager.store.get_model(identifier.key)
|
||||
|
||||
def search_by_path(self, path: Path) -> list[AnyModelConfig]:
|
||||
"""Searches for models by path.
|
||||
"""Search for models by path.
|
||||
|
||||
Args:
|
||||
path: The path to search for.
|
||||
@ -415,7 +425,7 @@ class ModelsInterface(InvocationContextInterface):
|
||||
type: Optional[ModelType] = None,
|
||||
format: Optional[ModelFormat] = None,
|
||||
) -> list[AnyModelConfig]:
|
||||
"""Searches for models by attributes.
|
||||
"""Search for models by attributes.
|
||||
|
||||
Args:
|
||||
name: The name to search for (exact match).
|
||||
@ -434,6 +444,72 @@ class ModelsInterface(InvocationContextInterface):
|
||||
model_format=format,
|
||||
)
|
||||
|
||||
def download_and_cache_model(
|
||||
self,
|
||||
source: str | AnyHttpUrl,
|
||||
) -> Path:
|
||||
"""
|
||||
Download the model file located at source to the models cache and return its Path.
|
||||
|
||||
This can be used to single-file install models and other resources of arbitrary types
|
||||
which should not get registered with the database. If the model is already
|
||||
installed, the cached path will be returned. Otherwise it will be downloaded.
|
||||
|
||||
Args:
|
||||
source: A URL that points to the model, or a huggingface repo_id.
|
||||
|
||||
Returns:
|
||||
Path to the downloaded model
|
||||
"""
|
||||
return self._services.model_manager.install.download_and_cache_model(source=source)
|
||||
|
||||
def load_local_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
loader: Optional[Callable[[Path], AnyModel]] = None,
|
||||
) -> LoadedModelWithoutConfig:
|
||||
"""
|
||||
Load the model file located at the indicated path
|
||||
|
||||
If a loader callable is provided, it will be invoked to load the model. Otherwise,
|
||||
`safetensors.torch.load_file()` or `torch.load()` will be called to load the model.
|
||||
|
||||
Be aware that the LoadedModelWithoutConfig object has no `config` attribute
|
||||
|
||||
Args:
|
||||
path: A model Path
|
||||
loader: A Callable that expects a Path and returns a dict[str|int, Any]
|
||||
|
||||
Returns:
|
||||
A LoadedModelWithoutConfig object.
|
||||
"""
|
||||
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
|
||||
|
||||
def load_remote_model(
|
||||
self,
|
||||
source: str | AnyHttpUrl,
|
||||
loader: Optional[Callable[[Path], AnyModel]] = None,
|
||||
) -> LoadedModelWithoutConfig:
|
||||
"""
|
||||
Download, cache, and load the model file located at the indicated URL or repo_id.
|
||||
|
||||
If the model is already downloaded, it will be loaded from the cache.
|
||||
|
||||
If the a loader callable is provided, it will be invoked to load the model. Otherwise,
|
||||
`safetensors.torch.load_file()` or `torch.load()` will be called to load the model.
|
||||
|
||||
Be aware that the LoadedModelWithoutConfig object has no `config` attribute
|
||||
|
||||
Args:
|
||||
source: A URL or huggingface repoid.
|
||||
loader: A Callable that expects a Path and returns a dict[str|int, Any]
|
||||
|
||||
Returns:
|
||||
A LoadedModelWithoutConfig object.
|
||||
"""
|
||||
model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
|
||||
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
|
||||
|
||||
|
||||
class ConfigInterface(InvocationContextInterface):
|
||||
def get(self) -> InvokeAIAppConfig:
|
||||
|
@ -13,6 +13,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_8 import build_migration_8
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_9 import build_migration_9
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_10 import build_migration_10
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import build_migration_11
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
|
||||
|
||||
|
||||
@ -43,6 +44,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
|
||||
migrator.register_migration(build_migration_8(app_config=config))
|
||||
migrator.register_migration(build_migration_9())
|
||||
migrator.register_migration(build_migration_10())
|
||||
migrator.register_migration(build_migration_11(app_config=config, logger=logger))
|
||||
migrator.run_migrations()
|
||||
|
||||
return db
|
||||
|
@ -0,0 +1,75 @@
|
||||
import shutil
|
||||
import sqlite3
|
||||
from logging import Logger
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
LEGACY_CORE_MODELS = [
|
||||
# OpenPose
|
||||
"any/annotators/dwpose/yolox_l.onnx",
|
||||
"any/annotators/dwpose/dw-ll_ucoco_384.onnx",
|
||||
# DepthAnything
|
||||
"any/annotators/depth_anything/depth_anything_vitl14.pth",
|
||||
"any/annotators/depth_anything/depth_anything_vitb14.pth",
|
||||
"any/annotators/depth_anything/depth_anything_vits14.pth",
|
||||
# Lama inpaint
|
||||
"core/misc/lama/lama.pt",
|
||||
# RealESRGAN upscale
|
||||
"core/upscaling/realesrgan/RealESRGAN_x4plus.pth",
|
||||
"core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth",
|
||||
"core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||
"core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
|
||||
]
|
||||
|
||||
|
||||
class Migration11Callback:
|
||||
def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None:
|
||||
self._app_config = app_config
|
||||
self._logger = logger
|
||||
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
self._remove_convert_cache()
|
||||
self._remove_downloaded_models()
|
||||
self._remove_unused_core_models()
|
||||
|
||||
def _remove_convert_cache(self) -> None:
|
||||
"""Rename models/.cache to models/.convert_cache."""
|
||||
self._logger.info("Removing .cache directory. Converted models will now be cached in .convert_cache.")
|
||||
legacy_convert_path = self._app_config.root_path / "models" / ".cache"
|
||||
shutil.rmtree(legacy_convert_path, ignore_errors=True)
|
||||
|
||||
def _remove_downloaded_models(self) -> None:
|
||||
"""Remove models from their old locations; they will re-download when needed."""
|
||||
self._logger.info(
|
||||
"Removing legacy just-in-time models. Downloaded models will now be cached in .download_cache."
|
||||
)
|
||||
for model_path in LEGACY_CORE_MODELS:
|
||||
legacy_dest_path = self._app_config.models_path / model_path
|
||||
legacy_dest_path.unlink(missing_ok=True)
|
||||
|
||||
def _remove_unused_core_models(self) -> None:
|
||||
"""Remove unused core models and their directories."""
|
||||
self._logger.info("Removing defunct core models.")
|
||||
for dir in ["face_restoration", "misc", "upscaling"]:
|
||||
path_to_remove = self._app_config.models_path / "core" / dir
|
||||
shutil.rmtree(path_to_remove, ignore_errors=True)
|
||||
shutil.rmtree(self._app_config.models_path / "any" / "annotators", ignore_errors=True)
|
||||
|
||||
|
||||
def build_migration_11(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
|
||||
"""
|
||||
Build the migration from database version 10 to 11.
|
||||
|
||||
This migration does the following:
|
||||
- Moves "core" models previously downloaded with download_with_progress_bar() into new
|
||||
"models/.download_cache" directory.
|
||||
- Renames "models/.cache" to "models/.convert_cache".
|
||||
"""
|
||||
migration_11 = Migration(
|
||||
from_version=10,
|
||||
to_version=11,
|
||||
callback=Migration11Callback(app_config=app_config, logger=logger),
|
||||
)
|
||||
|
||||
return migration_11
|
@ -1,51 +0,0 @@
|
||||
from pathlib import Path
|
||||
from urllib import request
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
|
||||
class ProgressBar:
|
||||
"""Simple progress bar for urllib.request.urlretrieve using tqdm."""
|
||||
|
||||
def __init__(self, model_name: str = "file"):
|
||||
self.pbar = None
|
||||
self.name = model_name
|
||||
|
||||
def __call__(self, block_num: int, block_size: int, total_size: int):
|
||||
if not self.pbar:
|
||||
self.pbar = tqdm(
|
||||
desc=self.name,
|
||||
initial=0,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1000,
|
||||
total=total_size,
|
||||
)
|
||||
self.pbar.update(block_size)
|
||||
|
||||
|
||||
def download_with_progress_bar(name: str, url: str, dest_path: Path) -> bool:
|
||||
"""Download a file from a URL to a destination path, with a progress bar.
|
||||
If the file already exists, it will not be downloaded again.
|
||||
|
||||
Exceptions are not caught.
|
||||
|
||||
Args:
|
||||
name (str): Name of the file being downloaded.
|
||||
url (str): URL to download the file from.
|
||||
dest_path (Path): Destination path to save the file to.
|
||||
|
||||
Returns:
|
||||
bool: True if the file was downloaded, False if it already existed.
|
||||
"""
|
||||
if dest_path.exists():
|
||||
return False # already downloaded
|
||||
|
||||
InvokeAILogger.get_logger().info(f"Downloading {name}...")
|
||||
|
||||
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
request.urlretrieve(url, dest_path, ProgressBar(name))
|
||||
|
||||
return True
|
@ -1,5 +1,5 @@
|
||||
import pathlib
|
||||
from typing import Literal, Union
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
@ -10,28 +10,17 @@ from PIL import Image
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
|
||||
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
config = get_config()
|
||||
logger = InvokeAILogger.get_logger(config=config)
|
||||
|
||||
DEPTH_ANYTHING_MODELS = {
|
||||
"large": {
|
||||
"url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
|
||||
"local": "any/annotators/depth_anything/depth_anything_vitl14.pth",
|
||||
},
|
||||
"base": {
|
||||
"url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
|
||||
"local": "any/annotators/depth_anything/depth_anything_vitb14.pth",
|
||||
},
|
||||
"small": {
|
||||
"url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
|
||||
"local": "any/annotators/depth_anything/depth_anything_vits14.pth",
|
||||
},
|
||||
"large": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
|
||||
"base": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true",
|
||||
"small": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true",
|
||||
}
|
||||
|
||||
|
||||
@ -53,36 +42,27 @@ transform = Compose(
|
||||
|
||||
|
||||
class DepthAnythingDetector:
|
||||
def __init__(self) -> None:
|
||||
self.model = None
|
||||
self.model_size: Union[Literal["large", "base", "small"], None] = None
|
||||
self.device = TorchDevice.choose_torch_device()
|
||||
def __init__(self, model: DPT_DINOv2, device: torch.device) -> None:
|
||||
self.model = model
|
||||
self.device = device
|
||||
|
||||
def load_model(self, model_size: Literal["large", "base", "small"] = "small"):
|
||||
DEPTH_ANYTHING_MODEL_PATH = config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"]
|
||||
download_with_progress_bar(
|
||||
pathlib.Path(DEPTH_ANYTHING_MODELS[model_size]["url"]).name,
|
||||
DEPTH_ANYTHING_MODELS[model_size]["url"],
|
||||
DEPTH_ANYTHING_MODEL_PATH,
|
||||
)
|
||||
@staticmethod
|
||||
def load_model(
|
||||
model_path: Path, device: torch.device, model_size: Literal["large", "base", "small"] = "small"
|
||||
) -> DPT_DINOv2:
|
||||
match model_size:
|
||||
case "small":
|
||||
model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
|
||||
case "base":
|
||||
model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
|
||||
case "large":
|
||||
model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
|
||||
|
||||
if not self.model or model_size != self.model_size:
|
||||
del self.model
|
||||
self.model_size = model_size
|
||||
model.load_state_dict(torch.load(model_path.as_posix(), map_location="cpu"))
|
||||
model.eval()
|
||||
|
||||
match self.model_size:
|
||||
case "small":
|
||||
self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
|
||||
case "base":
|
||||
self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
|
||||
case "large":
|
||||
self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
|
||||
|
||||
self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
|
||||
self.model.eval()
|
||||
|
||||
self.model.to(self.device)
|
||||
return self.model
|
||||
model.to(device)
|
||||
return model
|
||||
|
||||
def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:
|
||||
if not self.model:
|
||||
|
@ -1,30 +1,53 @@
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from controlnet_aux.util import resize_image
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.backend.image_util.dw_openpose.utils import draw_bodypose, draw_facepose, draw_handpose
|
||||
from invokeai.backend.image_util.dw_openpose.utils import NDArrayInt, draw_bodypose, draw_facepose, draw_handpose
|
||||
from invokeai.backend.image_util.dw_openpose.wholebody import Wholebody
|
||||
|
||||
DWPOSE_MODELS = {
|
||||
"yolox_l.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true",
|
||||
"dw-ll_ucoco_384.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true",
|
||||
}
|
||||
|
||||
def draw_pose(pose, H, W, draw_face=True, draw_body=True, draw_hands=True, resolution=512):
|
||||
|
||||
def draw_pose(
|
||||
pose: Dict[str, NDArrayInt | Dict[str, NDArrayInt]],
|
||||
H: int,
|
||||
W: int,
|
||||
draw_face: bool = True,
|
||||
draw_body: bool = True,
|
||||
draw_hands: bool = True,
|
||||
resolution: int = 512,
|
||||
) -> Image.Image:
|
||||
bodies = pose["bodies"]
|
||||
faces = pose["faces"]
|
||||
hands = pose["hands"]
|
||||
|
||||
assert isinstance(bodies, dict)
|
||||
candidate = bodies["candidate"]
|
||||
|
||||
assert isinstance(bodies, dict)
|
||||
subset = bodies["subset"]
|
||||
|
||||
canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
|
||||
|
||||
if draw_body:
|
||||
canvas = draw_bodypose(canvas, candidate, subset)
|
||||
|
||||
if draw_hands:
|
||||
assert isinstance(hands, np.ndarray)
|
||||
canvas = draw_handpose(canvas, hands)
|
||||
|
||||
if draw_face:
|
||||
canvas = draw_facepose(canvas, faces)
|
||||
assert isinstance(hands, np.ndarray)
|
||||
canvas = draw_facepose(canvas, faces) # type: ignore
|
||||
|
||||
dwpose_image = resize_image(
|
||||
dwpose_image: Image.Image = resize_image(
|
||||
canvas,
|
||||
resolution,
|
||||
)
|
||||
@ -39,11 +62,16 @@ class DWOpenposeDetector:
|
||||
Credits: https://github.com/IDEA-Research/DWPose
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.pose_estimation = Wholebody()
|
||||
def __init__(self, onnx_det: Path, onnx_pose: Path) -> None:
|
||||
self.pose_estimation = Wholebody(onnx_det=onnx_det, onnx_pose=onnx_pose)
|
||||
|
||||
def __call__(
|
||||
self, image: Image.Image, draw_face=False, draw_body=True, draw_hands=False, resolution=512
|
||||
self,
|
||||
image: Image.Image,
|
||||
draw_face: bool = False,
|
||||
draw_body: bool = True,
|
||||
draw_hands: bool = False,
|
||||
resolution: int = 512,
|
||||
) -> Image.Image:
|
||||
np_image = np.array(image)
|
||||
H, W, C = np_image.shape
|
||||
@ -79,3 +107,6 @@ class DWOpenposeDetector:
|
||||
return draw_pose(
|
||||
pose, H, W, draw_face=draw_face, draw_hands=draw_hands, draw_body=draw_body, resolution=resolution
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["DWPOSE_MODELS", "DWOpenposeDetector"]
|
||||
|
@ -5,11 +5,13 @@ import math
|
||||
import cv2
|
||||
import matplotlib
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
eps = 0.01
|
||||
NDArrayInt = npt.NDArray[np.uint8]
|
||||
|
||||
|
||||
def draw_bodypose(canvas, candidate, subset):
|
||||
def draw_bodypose(canvas: NDArrayInt, candidate: NDArrayInt, subset: NDArrayInt) -> NDArrayInt:
|
||||
H, W, C = canvas.shape
|
||||
candidate = np.array(candidate)
|
||||
subset = np.array(subset)
|
||||
@ -88,7 +90,7 @@ def draw_bodypose(canvas, candidate, subset):
|
||||
return canvas
|
||||
|
||||
|
||||
def draw_handpose(canvas, all_hand_peaks):
|
||||
def draw_handpose(canvas: NDArrayInt, all_hand_peaks: NDArrayInt) -> NDArrayInt:
|
||||
H, W, C = canvas.shape
|
||||
|
||||
edges = [
|
||||
@ -142,7 +144,7 @@ def draw_handpose(canvas, all_hand_peaks):
|
||||
return canvas
|
||||
|
||||
|
||||
def draw_facepose(canvas, all_lmks):
|
||||
def draw_facepose(canvas: NDArrayInt, all_lmks: NDArrayInt) -> NDArrayInt:
|
||||
H, W, C = canvas.shape
|
||||
for lmks in all_lmks:
|
||||
lmks = np.array(lmks)
|
||||
|
@ -2,47 +2,26 @@
|
||||
# Modified pathing to suit Invoke
|
||||
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
from .onnxdet import inference_detector
|
||||
from .onnxpose import inference_pose
|
||||
|
||||
DWPOSE_MODELS = {
|
||||
"yolox_l.onnx": {
|
||||
"local": "any/annotators/dwpose/yolox_l.onnx",
|
||||
"url": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true",
|
||||
},
|
||||
"dw-ll_ucoco_384.onnx": {
|
||||
"local": "any/annotators/dwpose/dw-ll_ucoco_384.onnx",
|
||||
"url": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true",
|
||||
},
|
||||
}
|
||||
|
||||
config = get_config()
|
||||
|
||||
|
||||
class Wholebody:
|
||||
def __init__(self):
|
||||
def __init__(self, onnx_det: Path, onnx_pose: Path):
|
||||
device = TorchDevice.choose_torch_device()
|
||||
|
||||
providers = ["CUDAExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"]
|
||||
|
||||
DET_MODEL_PATH = config.models_path / DWPOSE_MODELS["yolox_l.onnx"]["local"]
|
||||
download_with_progress_bar("yolox_l.onnx", DWPOSE_MODELS["yolox_l.onnx"]["url"], DET_MODEL_PATH)
|
||||
|
||||
POSE_MODEL_PATH = config.models_path / DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["local"]
|
||||
download_with_progress_bar(
|
||||
"dw-ll_ucoco_384.onnx", DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["url"], POSE_MODEL_PATH
|
||||
)
|
||||
|
||||
onnx_det = DET_MODEL_PATH
|
||||
onnx_pose = POSE_MODEL_PATH
|
||||
|
||||
self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
|
||||
self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
import gc
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
@ -6,9 +6,7 @@ import torch
|
||||
from PIL import Image
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.model_manager.config import AnyModel
|
||||
|
||||
|
||||
def norm_img(np_img):
|
||||
@ -19,28 +17,11 @@ def norm_img(np_img):
|
||||
return np_img
|
||||
|
||||
|
||||
def load_jit_model(url_or_path, device):
|
||||
model_path = url_or_path
|
||||
logger.info(f"Loading model from: {model_path}")
|
||||
model = torch.jit.load(model_path, map_location="cpu").to(device)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
class LaMA:
|
||||
def __init__(self, model: AnyModel):
|
||||
self._model = model
|
||||
|
||||
def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
|
||||
device = TorchDevice.choose_torch_device()
|
||||
model_location = get_config().models_path / "core/misc/lama/lama.pt"
|
||||
|
||||
if not model_location.exists():
|
||||
download_with_progress_bar(
|
||||
name="LaMa Inpainting Model",
|
||||
url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
|
||||
dest_path=model_location,
|
||||
)
|
||||
|
||||
model = load_jit_model(model_location, device)
|
||||
|
||||
image = np.asarray(input_image.convert("RGB"))
|
||||
image = norm_img(image)
|
||||
|
||||
@ -48,20 +29,25 @@ class LaMA:
|
||||
mask = np.asarray(mask)
|
||||
mask = np.invert(mask)
|
||||
mask = norm_img(mask)
|
||||
|
||||
mask = (mask > 0) * 1
|
||||
|
||||
device = next(self._model.buffers()).device
|
||||
image = torch.from_numpy(image).unsqueeze(0).to(device)
|
||||
mask = torch.from_numpy(mask).unsqueeze(0).to(device)
|
||||
|
||||
with torch.inference_mode():
|
||||
infilled_image = model(image, mask)
|
||||
infilled_image = self._model(image, mask)
|
||||
|
||||
infilled_image = infilled_image[0].permute(1, 2, 0).detach().cpu().numpy()
|
||||
infilled_image = np.clip(infilled_image * 255, 0, 255).astype("uint8")
|
||||
infilled_image = Image.fromarray(infilled_image)
|
||||
|
||||
del model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return infilled_image
|
||||
|
||||
@staticmethod
|
||||
def load_jit_model(url_or_path: str | Path, device: torch.device | str = "cpu") -> torch.nn.Module:
|
||||
model_path = url_or_path
|
||||
logger.info(f"Loading model from: {model_path}")
|
||||
model: torch.nn.Module = torch.jit.load(model_path, map_location="cpu").to(device) # type: ignore
|
||||
model.eval()
|
||||
return model
|
||||
|
@ -1,6 +1,5 @@
|
||||
import math
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import cv2
|
||||
@ -11,6 +10,7 @@ from cv2.typing import MatLike
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
|
||||
from invokeai.backend.model_manager.config import AnyModel
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
"""
|
||||
@ -52,7 +52,7 @@ class RealESRGAN:
|
||||
def __init__(
|
||||
self,
|
||||
scale: int,
|
||||
model_path: Path,
|
||||
loadnet: AnyModel,
|
||||
model: RRDBNet,
|
||||
tile: int = 0,
|
||||
tile_pad: int = 10,
|
||||
@ -67,8 +67,6 @@ class RealESRGAN:
|
||||
self.half = half
|
||||
self.device = TorchDevice.choose_torch_device()
|
||||
|
||||
loadnet = torch.load(model_path, map_location=torch.device("cpu"))
|
||||
|
||||
# prefer to use params_ema
|
||||
if "params_ema" in loadnet:
|
||||
keyname = "params_ema"
|
||||
|
@ -36,7 +36,7 @@ from ..raw_model import RawModel
|
||||
|
||||
# ModelMixin is the base class for all diffusers and transformers models
|
||||
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
|
||||
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module]
|
||||
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor]]
|
||||
|
||||
|
||||
class InvalidModelConfigException(Exception):
|
||||
@ -115,7 +115,7 @@ class SchedulerPredictionType(str, Enum):
|
||||
class ModelRepoVariant(str, Enum):
|
||||
"""Various hugging face variants on the diffusers format."""
|
||||
|
||||
Default = "" # model files without "fp16" or other qualifier - empty str
|
||||
Default = "" # model files without "fp16" or other qualifier
|
||||
FP16 = "fp16"
|
||||
FP32 = "fp32"
|
||||
ONNX = "onnx"
|
||||
|
@ -7,7 +7,7 @@ from importlib import import_module
|
||||
from pathlib import Path
|
||||
|
||||
from .convert_cache.convert_cache_default import ModelConvertCache
|
||||
from .load_base import LoadedModel, ModelLoaderBase
|
||||
from .load_base import LoadedModel, LoadedModelWithoutConfig, ModelLoaderBase
|
||||
from .load_default import ModelLoader
|
||||
from .model_cache.model_cache_default import ModelCache
|
||||
from .model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase
|
||||
@ -19,6 +19,7 @@ for module in loaders:
|
||||
|
||||
__all__ = [
|
||||
"LoadedModel",
|
||||
"LoadedModelWithoutConfig",
|
||||
"ModelCache",
|
||||
"ModelConvertCache",
|
||||
"ModelLoaderBase",
|
||||
|
@ -7,6 +7,7 @@ from pathlib import Path
|
||||
|
||||
from invokeai.backend.util import GIG, directory_size
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.backend.util.util import safe_filename
|
||||
|
||||
from .convert_cache_base import ModelConvertCacheBase
|
||||
|
||||
@ -35,6 +36,7 @@ class ModelConvertCache(ModelConvertCacheBase):
|
||||
|
||||
def cache_path(self, key: str) -> Path:
|
||||
"""Return the path for a model with the indicated key."""
|
||||
key = safe_filename(self._cache_path, key)
|
||||
return self._cache_path / key
|
||||
|
||||
def make_room(self, size: float) -> None:
|
||||
|
@ -23,7 +23,7 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadedModel:
|
||||
class LoadedModelWithoutConfig:
|
||||
"""
|
||||
Context manager object that mediates transfer from RAM<->VRAM.
|
||||
|
||||
@ -61,7 +61,6 @@ class LoadedModel:
|
||||
not have a state_dict, in which case this value will be None.
|
||||
"""
|
||||
|
||||
config: AnyModelConfig
|
||||
_locker: ModelLockerBase
|
||||
|
||||
def __enter__(self) -> AnyModel:
|
||||
@ -89,6 +88,13 @@ class LoadedModel:
|
||||
return self._locker.model
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadedModel(LoadedModelWithoutConfig):
|
||||
"""Context manager object that mediates transfer from RAM<->VRAM."""
|
||||
|
||||
config: Optional[AnyModelConfig] = None
|
||||
|
||||
|
||||
# TODO(MM2):
|
||||
# Some "intermediary" subclasses in the ModelLoaderBase class hierarchy define methods that their subclasses don't
|
||||
# know about. I think the problem may be related to this class being an ABC.
|
||||
|
@ -16,7 +16,7 @@ from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs
|
||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
|
||||
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
@ -84,7 +84,7 @@ class ModelLoader(ModelLoaderBase):
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
cache_path: Path = self._convert_cache.cache_path(config.key)
|
||||
cache_path: Path = self._convert_cache.cache_path(str(model_path))
|
||||
if self._needs_conversion(config, model_path, cache_path):
|
||||
loaded_model = self._do_convert(config, model_path, cache_path, submodel_type)
|
||||
else:
|
||||
@ -95,7 +95,6 @@ class ModelLoader(ModelLoaderBase):
|
||||
config.key,
|
||||
submodel_type=submodel_type,
|
||||
model=loaded_model,
|
||||
size=calc_model_size_by_data(loaded_model),
|
||||
)
|
||||
|
||||
return self._ram_cache.get(
|
||||
@ -126,9 +125,7 @@ class ModelLoader(ModelLoaderBase):
|
||||
if subtype == submodel_type:
|
||||
continue
|
||||
if submodel := getattr(pipeline, subtype.value, None):
|
||||
self._ram_cache.put(
|
||||
config.key, submodel_type=subtype, model=submodel, size=calc_model_size_by_data(submodel)
|
||||
)
|
||||
self._ram_cache.put(config.key, submodel_type=subtype, model=submodel)
|
||||
return getattr(pipeline, submodel_type.value) if submodel_type else pipeline
|
||||
|
||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||
|
@ -169,7 +169,6 @@ class ModelCacheBase(ABC, Generic[T]):
|
||||
self,
|
||||
key: str,
|
||||
model: T,
|
||||
size: int,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> None:
|
||||
"""Store model under key and optional submodel_type."""
|
||||
|
@ -29,6 +29,7 @@ import torch
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel, SubModelType
|
||||
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
@ -153,13 +154,13 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
self,
|
||||
key: str,
|
||||
model: AnyModel,
|
||||
size: int,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> None:
|
||||
"""Store model under key and optional submodel_type."""
|
||||
key = self._make_cache_key(key, submodel_type)
|
||||
if key in self._cached_models:
|
||||
return
|
||||
size = calc_model_size_by_data(model)
|
||||
self.make_room(size)
|
||||
|
||||
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
|
||||
@ -252,12 +253,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
|
||||
May raise a torch.cuda.OutOfMemoryError
|
||||
"""
|
||||
# These attributes are not in the base ModelMixin class but in various derived classes.
|
||||
# Some models don't have these attributes, in which case they run in RAM/CPU.
|
||||
self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
|
||||
if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")):
|
||||
return
|
||||
|
||||
source_device = cache_entry.device
|
||||
|
||||
# Note: We compare device types only so that 'cuda' == 'cuda:0'.
|
||||
@ -265,6 +261,10 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
if torch.device(source_device).type == torch.device(target_device).type:
|
||||
return
|
||||
|
||||
# Some models don't have a `to` method, in which case they run in RAM/CPU.
|
||||
if not hasattr(cache_entry.model, "to"):
|
||||
return
|
||||
|
||||
# This roundabout method for moving the model around is done to avoid
|
||||
# the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
|
||||
# When moving to VRAM, we copy (not move) each element of the state dict from
|
||||
|
@ -35,10 +35,6 @@ class ModelLocker(ModelLockerBase):
|
||||
|
||||
def lock(self) -> AnyModel:
|
||||
"""Move the model into the execution device (GPU) and lock it."""
|
||||
if not hasattr(self.model, "to"):
|
||||
return self.model
|
||||
|
||||
# NOTE that the model has to have the to() method in order for this code to move it into GPU!
|
||||
self._cache_entry.lock()
|
||||
try:
|
||||
if self._cache.lazy_offloading:
|
||||
@ -59,9 +55,6 @@ class ModelLocker(ModelLockerBase):
|
||||
|
||||
def unlock(self) -> None:
|
||||
"""Call upon exit from context."""
|
||||
if not hasattr(self.model, "to"):
|
||||
return
|
||||
|
||||
self._cache_entry.unlock()
|
||||
if not self._cache.lazy_offloading:
|
||||
self._cache.offload_unlocked_models(0)
|
||||
|
@ -65,14 +65,11 @@ class GenericDiffusersLoader(ModelLoader):
|
||||
else:
|
||||
try:
|
||||
config = self._load_diffusers_config(model_path, config_name="config.json")
|
||||
class_name = config.get("_class_name", None)
|
||||
if class_name:
|
||||
if class_name := config.get("_class_name"):
|
||||
result = self._hf_definition_to_type(module="diffusers", class_name=class_name)
|
||||
if config.get("model_type", None) == "clip_vision_model":
|
||||
class_name = config.get("architectures")
|
||||
assert class_name is not None
|
||||
elif class_name := config.get("architectures"):
|
||||
result = self._hf_definition_to_type(module="transformers", class_name=class_name[0])
|
||||
if not class_name:
|
||||
else:
|
||||
raise InvalidModelConfigException("Unable to decipher Load Class based on given config.json")
|
||||
except KeyError as e:
|
||||
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
|
||||
|
@ -83,7 +83,7 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
|
||||
assert s.size is not None
|
||||
files.append(
|
||||
RemoteModelFile(
|
||||
url=hf_hub_url(id, s.rfilename, revision=variant),
|
||||
url=hf_hub_url(id, s.rfilename, revision=variant or "main"),
|
||||
path=Path(name, s.rfilename),
|
||||
size=s.size,
|
||||
sha256=s.lfs.get("sha256") if s.lfs else None,
|
||||
|
@ -37,9 +37,12 @@ class RemoteModelFile(BaseModel):
|
||||
|
||||
url: AnyHttpUrl = Field(description="The url to download this model file")
|
||||
path: Path = Field(description="The path to the file, relative to the model root")
|
||||
size: int = Field(description="The size of this file, in bytes")
|
||||
size: Optional[int] = Field(description="The size of this file, in bytes", default=0)
|
||||
sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(str(self))
|
||||
|
||||
|
||||
class ModelMetadataBase(BaseModel):
|
||||
"""Base class for model metadata information."""
|
||||
|
@ -1,6 +1,8 @@
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
import unicodedata
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
@ -12,6 +14,33 @@ from transformers import logging as transformers_logging
|
||||
GIG = 1073741824
|
||||
|
||||
|
||||
def slugify(value: str, allow_unicode: bool = False) -> str:
|
||||
"""
|
||||
Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
|
||||
dashes to single dashes. Remove characters that aren't alphanumerics,
|
||||
underscores, or hyphens. Replace slashes with underscores.
|
||||
Convert to lowercase. Also strip leading and
|
||||
trailing whitespace, dashes, and underscores.
|
||||
|
||||
Adapted from Django: https://github.com/django/django/blob/main/django/utils/text.py
|
||||
"""
|
||||
value = str(value)
|
||||
if allow_unicode:
|
||||
value = unicodedata.normalize("NFKC", value)
|
||||
else:
|
||||
value = unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii")
|
||||
value = re.sub(r"[/]", "_", value.lower())
|
||||
value = re.sub(r"[^.\w\s-]", "", value.lower())
|
||||
return re.sub(r"[-\s]+", "-", value).strip("-_")
|
||||
|
||||
|
||||
def safe_filename(directory: Path, value: str) -> str:
|
||||
"""Make a string safe to use as a filename."""
|
||||
escaped_string = slugify(value)
|
||||
max_name_length = os.pathconf(directory, "PC_NAME_MAX") if hasattr(os, "pathconf") else 256
|
||||
return escaped_string[len(escaped_string) - max_name_length :]
|
||||
|
||||
|
||||
def directory_size(directory: Path) -> int:
|
||||
"""
|
||||
Return the aggregate size of all files in a directory (bytes).
|
||||
|
@ -2,14 +2,18 @@
|
||||
|
||||
import re
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator, Optional
|
||||
|
||||
import pytest
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
from requests_testadapter import TestAdapter, TestSession
|
||||
from requests_testadapter import TestAdapter
|
||||
|
||||
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService
|
||||
from invokeai.app.services.config import get_config
|
||||
from invokeai.app.services.config.config_default import URLRegexTokenPair
|
||||
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService, MultiFileDownloadJob
|
||||
from invokeai.app.services.events.events_common import (
|
||||
DownloadCancelledEvent,
|
||||
DownloadCompleteEvent,
|
||||
@ -17,56 +21,23 @@ from invokeai.app.services.events.events_common import (
|
||||
DownloadProgressEvent,
|
||||
DownloadStartedEvent,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import HuggingFaceMetadataFetch, ModelMetadataWithFiles, RemoteModelFile
|
||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||
from tests.test_nodes import TestEventService
|
||||
|
||||
# Prevent pytest deprecation warnings
|
||||
TestAdapter.__test__ = False # type: ignore
|
||||
TestAdapter.__test__ = False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session() -> Session:
|
||||
sess = TestSession()
|
||||
for i in ["12345", "9999", "54321"]:
|
||||
content = (
|
||||
b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000)
|
||||
) # for pause tests, must make content large
|
||||
sess.mount(
|
||||
f"http://www.civitai.com/models/{i}",
|
||||
TestAdapter(
|
||||
content,
|
||||
headers={
|
||||
"Content-Length": len(content),
|
||||
"Content-Disposition": f'filename="mock{i}.safetensors"',
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# here are some malformed URLs to test
|
||||
# missing the content length
|
||||
sess.mount(
|
||||
"http://www.civitai.com/models/missing",
|
||||
TestAdapter(
|
||||
b"Missing content length",
|
||||
headers={
|
||||
"Content-Disposition": 'filename="missing.txt"',
|
||||
},
|
||||
),
|
||||
)
|
||||
# not found test
|
||||
sess.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404))
|
||||
|
||||
return sess
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=20, method="thread")
|
||||
def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_basic_queue_download(tmp_path: Path, mm2_session: Session) -> None:
|
||||
events = set()
|
||||
|
||||
def event_handler(job: DownloadJob) -> None:
|
||||
def event_handler(job: DownloadJob, excp: Optional[Exception] = None) -> None:
|
||||
events.add(job.status)
|
||||
|
||||
queue = DownloadQueueService(
|
||||
requests_session=session,
|
||||
requests_session=mm2_session,
|
||||
)
|
||||
queue.start()
|
||||
job = queue.download(
|
||||
@ -82,16 +53,17 @@ def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
|
||||
queue.join()
|
||||
|
||||
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
|
||||
assert job.download_path == tmp_path / "mock12345.safetensors"
|
||||
assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist"
|
||||
|
||||
assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
|
||||
queue.stop()
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=20, method="thread")
|
||||
def test_errors(tmp_path: Path, session: Session) -> None:
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_errors(tmp_path: Path, mm2_session: Session) -> None:
|
||||
queue = DownloadQueueService(
|
||||
requests_session=session,
|
||||
requests_session=mm2_session,
|
||||
)
|
||||
queue.start()
|
||||
|
||||
@ -110,11 +82,11 @@ def test_errors(tmp_path: Path, session: Session) -> None:
|
||||
queue.stop()
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=20, method="thread")
|
||||
def test_event_bus(tmp_path: Path, session: Session) -> None:
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_event_bus(tmp_path: Path, mm2_session: Session) -> None:
|
||||
event_bus = TestEventService()
|
||||
|
||||
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
|
||||
queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
|
||||
queue.start()
|
||||
queue.download(
|
||||
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
|
||||
@ -146,10 +118,10 @@ def test_event_bus(tmp_path: Path, session: Session) -> None:
|
||||
queue.stop()
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=20, method="thread")
|
||||
def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_broken_callbacks(tmp_path: Path, mm2_session: Session, capsys) -> None:
|
||||
queue = DownloadQueueService(
|
||||
requests_session=session,
|
||||
requests_session=mm2_session,
|
||||
)
|
||||
queue.start()
|
||||
|
||||
@ -178,11 +150,11 @@ def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
|
||||
queue.stop()
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=15, method="thread")
|
||||
def test_cancel(tmp_path: Path, session: Session) -> None:
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_cancel(tmp_path: Path, mm2_session: Session) -> None:
|
||||
event_bus = TestEventService()
|
||||
|
||||
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
|
||||
queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
|
||||
queue.start()
|
||||
|
||||
cancelled = False
|
||||
@ -194,9 +166,6 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
|
||||
nonlocal cancelled
|
||||
cancelled = True
|
||||
|
||||
def handler(signum, frame):
|
||||
raise TimeoutError("Join took too long to return")
|
||||
|
||||
job = queue.download(
|
||||
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
|
||||
dest=tmp_path,
|
||||
@ -212,3 +181,178 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
|
||||
assert isinstance(events[-1], DownloadCancelledEvent)
|
||||
assert events[-1].source == "http://www.civitai.com/models/12345"
|
||||
queue.stop()
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None:
|
||||
fetcher = HuggingFaceMetadataFetch(mm2_session)
|
||||
metadata = fetcher.from_id("stabilityai/sdxl-turbo")
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
events = set()
|
||||
|
||||
def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
|
||||
events.add(job.status)
|
||||
|
||||
queue = DownloadQueueService(
|
||||
requests_session=mm2_session,
|
||||
)
|
||||
queue.start()
|
||||
job = queue.multifile_download(
|
||||
parts=metadata.download_urls(session=mm2_session),
|
||||
dest=tmp_path,
|
||||
on_start=event_handler,
|
||||
on_progress=event_handler,
|
||||
on_complete=event_handler,
|
||||
on_error=event_handler,
|
||||
)
|
||||
assert isinstance(job, MultiFileDownloadJob), "expected the job to be of type MultiFileDownloadJobBase"
|
||||
queue.join()
|
||||
|
||||
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
|
||||
assert job.bytes > 0, "expected download bytes to be positive"
|
||||
assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes"
|
||||
assert job.download_path == tmp_path / "sdxl-turbo"
|
||||
assert Path(
|
||||
tmp_path, "sdxl-turbo/model_index.json"
|
||||
).exists(), f"expected {tmp_path}/sdxl-turbo/model_inded.json to exist"
|
||||
assert Path(
|
||||
tmp_path, "sdxl-turbo/text_encoder/config.json"
|
||||
).exists(), f"expected {tmp_path}/sdxl-turbo/text_encoder/config.json to exist"
|
||||
|
||||
assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
|
||||
queue.stop()
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_multifile_download_error(tmp_path: Path, mm2_session: Session) -> None:
|
||||
fetcher = HuggingFaceMetadataFetch(mm2_session)
|
||||
metadata = fetcher.from_id("stabilityai/sdxl-turbo")
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
events = set()
|
||||
|
||||
def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
|
||||
events.add(job.status)
|
||||
|
||||
queue = DownloadQueueService(
|
||||
requests_session=mm2_session,
|
||||
)
|
||||
queue.start()
|
||||
files = metadata.download_urls(session=mm2_session)
|
||||
# this will give a 404 error
|
||||
files.append(RemoteModelFile(url="https://test.com/missing_model.safetensors", path=Path("sdxl-turbo/broken")))
|
||||
job = queue.multifile_download(
|
||||
parts=files,
|
||||
dest=tmp_path,
|
||||
on_start=event_handler,
|
||||
on_progress=event_handler,
|
||||
on_complete=event_handler,
|
||||
on_error=event_handler,
|
||||
)
|
||||
queue.join()
|
||||
|
||||
assert job.status == DownloadJobStatus("error"), "expected job status to be errored"
|
||||
assert job.error_type is not None
|
||||
assert "HTTPError(NOT FOUND)" in job.error_type
|
||||
assert DownloadJobStatus.ERROR in events
|
||||
queue.stop()
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch: Any) -> None:
|
||||
event_bus = TestEventService()
|
||||
|
||||
queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
|
||||
queue.start()
|
||||
|
||||
cancelled = False
|
||||
|
||||
def cancelled_callback(job: DownloadJob) -> None:
|
||||
nonlocal cancelled
|
||||
cancelled = True
|
||||
|
||||
fetcher = HuggingFaceMetadataFetch(mm2_session)
|
||||
metadata = fetcher.from_id("stabilityai/sdxl-turbo")
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
|
||||
job = queue.multifile_download(
|
||||
parts=metadata.download_urls(session=mm2_session),
|
||||
dest=tmp_path,
|
||||
on_cancelled=cancelled_callback,
|
||||
)
|
||||
queue.cancel_job(job)
|
||||
queue.join()
|
||||
|
||||
assert job.status == DownloadJobStatus.CANCELLED
|
||||
assert cancelled
|
||||
events = event_bus.events
|
||||
assert DownloadCancelledEvent in [type(x) for x in events]
|
||||
queue.stop()
|
||||
|
||||
|
||||
def test_multifile_onefile(tmp_path: Path, mm2_session: Session) -> None:
|
||||
queue = DownloadQueueService(
|
||||
requests_session=mm2_session,
|
||||
)
|
||||
queue.start()
|
||||
job = queue.multifile_download(
|
||||
parts=[
|
||||
RemoteModelFile(url=AnyHttpUrl("http://www.civitai.com/models/12345"), path=Path("mock12345.safetensors"))
|
||||
],
|
||||
dest=tmp_path,
|
||||
)
|
||||
assert isinstance(job, MultiFileDownloadJob), "expected the job to be of type MultiFileDownloadJobBase"
|
||||
queue.join()
|
||||
|
||||
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
|
||||
assert job.bytes > 0, "expected download bytes to be positive"
|
||||
assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes"
|
||||
assert job.download_path == tmp_path / "mock12345.safetensors"
|
||||
assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist"
|
||||
queue.stop()
|
||||
|
||||
|
||||
def test_multifile_no_rel_paths(tmp_path: Path, mm2_session: Session) -> None:
|
||||
queue = DownloadQueueService(
|
||||
requests_session=mm2_session,
|
||||
)
|
||||
|
||||
with pytest.raises(AssertionError) as error:
|
||||
queue.multifile_download(
|
||||
parts=[RemoteModelFile(url=AnyHttpUrl("http://www.civitai.com/models/12345"), path=Path("/etc/passwd"))],
|
||||
dest=tmp_path,
|
||||
)
|
||||
assert str(error.value) == "only relative download paths accepted"
|
||||
|
||||
|
||||
@contextmanager
|
||||
def clear_config() -> Generator[None, None, None]:
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
get_config.cache_clear()
|
||||
|
||||
|
||||
def test_tokens(tmp_path: Path, mm2_session: Session):
|
||||
with clear_config():
|
||||
config = get_config()
|
||||
config.remote_api_tokens = [URLRegexTokenPair(url_regex="civitai", token="cv_12345")]
|
||||
queue = DownloadQueueService(requests_session=mm2_session)
|
||||
queue.start()
|
||||
# this one has an access token assigned
|
||||
job1 = queue.download(
|
||||
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
|
||||
dest=tmp_path,
|
||||
)
|
||||
# this one doesn't
|
||||
job2 = queue.download(
|
||||
source=AnyHttpUrl(
|
||||
"http://www.huggingface.co/foo.txt",
|
||||
),
|
||||
dest=tmp_path,
|
||||
)
|
||||
queue.join()
|
||||
# this token is defined in the temporary root invokeai.yaml
|
||||
# see tests/backend/model_manager/data/invokeai_root/invokeai.yaml
|
||||
assert job1.access_token == "cv_12345"
|
||||
assert job2.access_token is None
|
||||
queue.stop()
|
||||
|
@ -20,6 +20,7 @@ from invokeai.app.services.events.events_common import (
|
||||
ModelInstallStartedEvent,
|
||||
)
|
||||
from invokeai.app.services.model_install import (
|
||||
HFModelSource,
|
||||
ModelInstallServiceBase,
|
||||
)
|
||||
from invokeai.app.services.model_install.model_install_common import (
|
||||
@ -29,7 +30,14 @@ from invokeai.app.services.model_install.model_install_common import (
|
||||
URLModelSource,
|
||||
)
|
||||
from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException
|
||||
from invokeai.backend.model_manager.config import BaseModelType, InvalidModelConfigException, ModelFormat, ModelType
|
||||
from invokeai.backend.model_manager.config import (
|
||||
BaseModelType,
|
||||
InvalidModelConfigException,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
)
|
||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||
from tests.test_nodes import TestEventService
|
||||
|
||||
OS = platform.uname().system
|
||||
@ -222,7 +230,7 @@ def test_delete_register(
|
||||
store.get_model(key)
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=20, method="thread")
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors"))
|
||||
|
||||
@ -243,15 +251,16 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
|
||||
model_record = store.get_model(key)
|
||||
assert (mm2_app_config.models_path / model_record.path).exists()
|
||||
|
||||
assert len(bus.events) == 4
|
||||
assert isinstance(bus.events[0], ModelInstallDownloadProgressEvent)
|
||||
assert isinstance(bus.events[1], ModelInstallDownloadsCompleteEvent)
|
||||
assert isinstance(bus.events[2], ModelInstallStartedEvent)
|
||||
assert isinstance(bus.events[3], ModelInstallCompleteEvent)
|
||||
assert len(bus.events) == 5
|
||||
assert isinstance(bus.events[0], ModelInstallDownloadProgressEvent) # download starts
|
||||
assert isinstance(bus.events[1], ModelInstallDownloadProgressEvent) # download progresses
|
||||
assert isinstance(bus.events[2], ModelInstallDownloadsCompleteEvent) # download completed
|
||||
assert isinstance(bus.events[3], ModelInstallStartedEvent) # install started
|
||||
assert isinstance(bus.events[4], ModelInstallCompleteEvent) # install completed
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=20, method="thread")
|
||||
def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))
|
||||
|
||||
bus: TestEventService = mm2_installer.event_bus
|
||||
@ -277,6 +286,49 @@ def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_co
|
||||
assert len(bus.events) >= 3
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_huggingface_repo_id(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
source = HFModelSource(repo_id="stabilityai/sdxl-turbo", variant=ModelRepoVariant.Default)
|
||||
|
||||
bus = mm2_installer.event_bus
|
||||
store = mm2_installer.record_store
|
||||
assert isinstance(bus, EventServiceBase)
|
||||
assert store is not None
|
||||
|
||||
job = mm2_installer.import_model(source)
|
||||
job_list = mm2_installer.wait_for_installs(timeout=10)
|
||||
assert len(job_list) == 1
|
||||
assert job.complete
|
||||
assert job.config_out
|
||||
|
||||
key = job.config_out.key
|
||||
model_record = store.get_model(key)
|
||||
assert (mm2_app_config.models_path / model_record.path).exists()
|
||||
assert model_record.type == ModelType.Main
|
||||
assert model_record.format == ModelFormat.Diffusers
|
||||
|
||||
assert hasattr(bus, "events") # the dummyeventservice has this
|
||||
assert len(bus.events) >= 3
|
||||
event_types = [type(x) for x in bus.events]
|
||||
assert all(
|
||||
x in event_types
|
||||
for x in [
|
||||
ModelInstallDownloadProgressEvent,
|
||||
ModelInstallDownloadsCompleteEvent,
|
||||
ModelInstallStartedEvent,
|
||||
ModelInstallCompleteEvent,
|
||||
]
|
||||
)
|
||||
|
||||
completed_events = [x for x in bus.events if isinstance(x, ModelInstallCompleteEvent)]
|
||||
downloading_events = [x for x in bus.events if isinstance(x, ModelInstallDownloadProgressEvent)]
|
||||
assert completed_events[0].total_bytes == downloading_events[-1].bytes
|
||||
assert job.total_bytes == completed_events[0].total_bytes
|
||||
print(downloading_events[-1])
|
||||
print(job.download_parts)
|
||||
assert job.total_bytes == sum(x["total_bytes"] for x in downloading_events[-1].parts)
|
||||
|
||||
|
||||
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
source = URLModelSource(url=Url("https://test.com/missing_model.safetensors"))
|
||||
job = mm2_installer.import_model(source)
|
||||
@ -308,7 +360,6 @@ def test_other_error_during_install(
|
||||
assert job.error == "Test error"
|
||||
|
||||
|
||||
# TODO: Fix bug in model install causing jobs to get installed multiple times then uncomment this test
|
||||
@pytest.mark.parametrize(
|
||||
"model_params",
|
||||
[
|
||||
@ -326,7 +377,7 @@ def test_other_error_during_install(
|
||||
},
|
||||
],
|
||||
)
|
||||
@pytest.mark.timeout(timeout=40, method="thread")
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
|
||||
"""Test whether or not type is respected on configs when passed to heuristic import."""
|
||||
assert "name" in model_params and "type" in model_params
|
||||
@ -342,7 +393,7 @@ def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, mode
|
||||
}
|
||||
assert "repo_id" in model_params
|
||||
install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1)
|
||||
mm2_installer.wait_for_job(install_job1, timeout=20)
|
||||
mm2_installer.wait_for_job(install_job1, timeout=10)
|
||||
if model_params["type"] != "embedding":
|
||||
assert install_job1.errored
|
||||
assert install_job1.error_type == "InvalidModelConfigException"
|
||||
@ -351,6 +402,6 @@ def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, mode
|
||||
assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out
|
||||
|
||||
install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2)
|
||||
mm2_installer.wait_for_job(install_job2, timeout=20)
|
||||
mm2_installer.wait_for_job(install_job2, timeout=10)
|
||||
assert install_job2.complete
|
||||
assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out
|
||||
|
88
tests/app/services/model_load/test_load_api.py
Normal file
88
tests/app/services/model_load/test_load_api.py
Normal file
@ -0,0 +1,88 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from diffusers import AutoencoderTiny
|
||||
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.model_manager import ModelManagerServiceBase
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext, build_invocation_context
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModelWithoutConfig
|
||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def mock_context(
|
||||
mock_services: InvocationServices,
|
||||
mm2_model_manager: ModelManagerServiceBase,
|
||||
) -> InvocationContext:
|
||||
mock_services.model_manager = mm2_model_manager
|
||||
return build_invocation_context(
|
||||
services=mock_services,
|
||||
data=None, # type: ignore
|
||||
is_canceled=None, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
def test_download_and_cache(mock_context: InvocationContext, mm2_root_dir: Path) -> None:
|
||||
downloaded_path = mock_context.models.download_and_cache_model(
|
||||
"https://www.test.foo/download/test_embedding.safetensors"
|
||||
)
|
||||
assert downloaded_path.is_file()
|
||||
assert downloaded_path.exists()
|
||||
assert downloaded_path.name == "test_embedding.safetensors"
|
||||
assert downloaded_path.parent.parent == mm2_root_dir / "models/.download_cache"
|
||||
|
||||
downloaded_path_2 = mock_context.models.download_and_cache_model(
|
||||
"https://www.test.foo/download/test_embedding.safetensors"
|
||||
)
|
||||
assert downloaded_path == downloaded_path_2
|
||||
|
||||
|
||||
def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) -> None:
|
||||
downloaded_path = mock_context.models.download_and_cache_model(
|
||||
"https://www.test.foo/download/test_embedding.safetensors"
|
||||
)
|
||||
loaded_model_1 = mock_context.models.load_local_model(downloaded_path)
|
||||
assert isinstance(loaded_model_1, LoadedModelWithoutConfig)
|
||||
|
||||
loaded_model_2 = mock_context.models.load_local_model(downloaded_path)
|
||||
assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
|
||||
assert loaded_model_1.model is loaded_model_2.model
|
||||
|
||||
loaded_model_3 = mock_context.models.load_local_model(embedding_file)
|
||||
assert isinstance(loaded_model_3, LoadedModelWithoutConfig)
|
||||
assert loaded_model_1.model is not loaded_model_3.model
|
||||
assert isinstance(loaded_model_1.model, dict)
|
||||
assert isinstance(loaded_model_3.model, dict)
|
||||
assert torch.equal(loaded_model_1.model["emb_params"], loaded_model_3.model["emb_params"])
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="This requires a test model to load")
|
||||
def test_load_from_dir(mock_context: InvocationContext, vae_directory: Path) -> None:
|
||||
loaded_model = mock_context.models.load_local_model(vae_directory)
|
||||
assert isinstance(loaded_model, LoadedModelWithoutConfig)
|
||||
assert isinstance(loaded_model.model, AutoencoderTiny)
|
||||
|
||||
|
||||
def test_download_and_load(mock_context: InvocationContext) -> None:
|
||||
loaded_model_1 = mock_context.models.load_remote_model("https://www.test.foo/download/test_embedding.safetensors")
|
||||
assert isinstance(loaded_model_1, LoadedModelWithoutConfig)
|
||||
|
||||
loaded_model_2 = mock_context.models.load_remote_model("https://www.test.foo/download/test_embedding.safetensors")
|
||||
assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
|
||||
assert loaded_model_1.model is loaded_model_2.model # should be cached copy
|
||||
|
||||
|
||||
def test_download_diffusers(mock_context: InvocationContext) -> None:
|
||||
model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo")
|
||||
assert (model_path / "model_index.json").exists()
|
||||
assert (model_path / "vae").is_dir()
|
||||
|
||||
|
||||
def test_download_diffusers_subfolder(mock_context: InvocationContext) -> None:
|
||||
model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo::vae")
|
||||
assert model_path.is_dir()
|
||||
assert (model_path / "diffusion_pytorch_model.fp16.safetensors").exists() or (
|
||||
model_path / "diffusion_pytorch_model.safetensors"
|
||||
).exists()
|
@ -61,6 +61,13 @@ def embedding_file(mm2_model_files: Path) -> Path:
|
||||
return mm2_model_files / "test_embedding.safetensors"
|
||||
|
||||
|
||||
# Can be used to test diffusers model directory loading, but
|
||||
# the test file adds ~10MB of space.
|
||||
# @pytest.fixture
|
||||
# def vae_directory(mm2_model_files: Path) -> Path:
|
||||
# return mm2_model_files / "taesdxl"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def diffusers_dir(mm2_model_files: Path) -> Path:
|
||||
return mm2_model_files / "test-diffusers-main"
|
||||
@ -294,4 +301,45 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
for i in ["12345", "9999", "54321"]:
|
||||
content = (
|
||||
b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000)
|
||||
) # for pause tests, must make content large
|
||||
sess.mount(
|
||||
f"http://www.civitai.com/models/{i}",
|
||||
TestAdapter(
|
||||
content,
|
||||
headers={
|
||||
"Content-Length": len(content),
|
||||
"Content-Disposition": f'filename="mock{i}.safetensors"',
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
sess.mount(
|
||||
"http://www.huggingface.co/foo.txt",
|
||||
TestAdapter(
|
||||
content,
|
||||
headers={
|
||||
"Content-Length": len(content),
|
||||
"Content-Disposition": 'filename="foo.safetensors"',
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# here are some malformed URLs to test
|
||||
# missing the content length
|
||||
sess.mount(
|
||||
"http://www.civitai.com/models/missing",
|
||||
TestAdapter(
|
||||
b"Missing content length",
|
||||
headers={
|
||||
"Content-Disposition": 'filename="missing.txt"',
|
||||
},
|
||||
),
|
||||
)
|
||||
# not found test
|
||||
sess.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404))
|
||||
|
||||
return sess
|
||||
|
Loading…
Reference in New Issue
Block a user