diff --git a/docs/contributing/DOWNLOAD_QUEUE.md b/docs/contributing/DOWNLOAD_QUEUE.md
index d43c670d2c..960180961e 100644
--- a/docs/contributing/DOWNLOAD_QUEUE.md
+++ b/docs/contributing/DOWNLOAD_QUEUE.md
@@ -128,7 +128,8 @@ The queue operates on a series of download job objects. These objects specify the source and destination of the download, and keep track of the progress of the download.
-The only job type currently implemented is `DownloadJob`, a pydantic object with the
+Two job types are defined: `DownloadJob` and
+`MultiFileDownloadJob`. The former is a pydantic object with the
following fields:

| **Field** | **Type** | **Default** | **Description** |
@@ -138,7 +139,7 @@ following fields:
| `dest` | Path | | Where to download to |
| `access_token` | str | | [optional] string containing authentication token for access |
| `on_start` | Callable | | [optional] callback when the download starts |
-| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
+| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
| `on_complete` | Callable | | [optional] callback called after successful download completion |
| `on_error` | Callable | | [optional] callback called after an error occurs |
| `id` | int | auto assigned | Job ID, an integer >= 0 |
@@ -190,6 +191,33 @@ A cancelled job will have status `DownloadJobStatus.ERROR` and an `error_type` field of "DownloadJobCancelledException". In addition, the job's `cancelled` property will be set to True.
+The `MultiFileDownloadJob` is used for diffusers model downloads,
+which contain multiple files and directories under a common root:
+
+| **Field** | **Type** | **Default** | **Description** |
+|----------------|-----------------|---------------|-----------------|
+| _Fields passed in at job creation time_ |
+| `download_parts` | Set[DownloadJob]| | Component download jobs |
+| `dest` | Path | | Where to download to |
+| `on_start` | Callable | | [optional] callback when the download starts |
+| `on_progress` | Callable | | [optional] callback called at intervals during download progress |
+| `on_complete` | Callable | | [optional] callback called after successful download completion |
+| `on_error` | Callable | | [optional] callback called after an error occurs |
+| `id` | int | auto assigned | Job ID, an integer >= 0 |
+| _Fields updated over the course of the download task_ |
+| `status` | DownloadJobStatus| | Status code |
+| `download_path` | Path | | Path to the root of the downloaded files |
+| `bytes` | int | 0 | Bytes downloaded so far |
+| `total_bytes` | int | 0 | Total size of the file at the remote site |
+| `error_type` | str | | String version of the exception that caused an error during download |
+| `error` | str | | String version of the traceback associated with an error |
+| `cancelled` | bool | False | Set to true if the job was cancelled by the caller|
+
+Note that the `MultiFileDownloadJob` does not support the `priority`,
+`job_started`, `job_ended`, or `content_type` attributes. You can get
+these from the individual download jobs in `download_parts`.
+
+
### Callbacks

Download jobs can be associated with a series of callbacks, each with
@@ -251,11 +279,40 @@ jobs using `list_jobs()`, fetch a single job by its with running jobs with `cancel_all_jobs()`, and wait for all jobs to finish with `join()`.
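+
+For orientation, here is a minimal sketch of the single-file API; the
+URL and destination below are placeholders rather than real resources:
+
+```
+from invokeai.app.services.download import DownloadQueueService, TqdmProgress
+
+queue = DownloadQueueService()
+queue.start()
+
+job = queue.download(
+    source="http://www.example.com/models/my_model.safetensors",  # placeholder URL
+    dest="/tmp/downloads",
+    on_progress=TqdmProgress().update,
+)
+queue.join()  # block until all queued jobs reach a terminal state
+print(f"Downloaded to {job.download_path} with status {job.status}")
+```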
-#### job = queue.download(source, dest, priority, access_token)
+#### job = queue.download(source, dest, priority, access_token, on_start, on_progress, on_complete, on_cancelled, on_error)

Create a new download job and put it on the queue, returning the
DownloadJob object.

+#### multifile_job = queue.multifile_download(parts, dest, access_token, on_start, on_progress, on_complete, on_cancelled, on_error)
+
+This is similar to `download()`, but instead of taking a single source,
+it accepts a `parts` argument consisting of a list of
+`RemoteModelFile` objects. Each part corresponds to a URL/Path pair,
+where the URL is the location of the remote file, and the Path is the
+destination. Note that the path *must* be relative.
+
+`RemoteModelFile` can be imported from `invokeai.backend.model_manager.metadata`.
+
+The method returns a `MultiFileDownloadJob`.
+
+```
+from invokeai.backend.model_manager.metadata import RemoteModelFile
+remote_file_1 = RemoteModelFile(url='http://www.foo.bar/my/pytorch_model.safetensors',
+                                path='my_model/textencoder/pytorch_model.safetensors'
+                                )
+remote_file_2 = RemoteModelFile(url='http://www.bar.baz/vae.ckpt',
+                                path='my_model/vae/diffusers_model.safetensors'
+                                )
+job = queue.multifile_download(parts=[remote_file_1, remote_file_2],
+                               dest='/tmp/downloads',
+                               on_progress=TqdmProgress().update)
+queue.wait_for_job(job)
+print(f"The files were downloaded to {job.download_path}")
+```
+
#### jobs = queue.list_jobs()

Return a list of all active and inactive `DownloadJob`s.
diff --git a/docs/contributing/MODEL_MANAGER.md b/docs/contributing/MODEL_MANAGER.md
index 201d11995d..9699db4f1a 100644
--- a/docs/contributing/MODEL_MANAGER.md
+++ b/docs/contributing/MODEL_MANAGER.md
@@ -397,26 +397,25 @@ In the event you wish to create a new installer, you may use the following initialization pattern:

```
-from invokeai.app.services.config import InvokeAIAppConfig
+from invokeai.app.services.config import get_config
from invokeai.app.services.model_records import ModelRecordServiceSQL
from invokeai.app.services.model_install import ModelInstallService
from invokeai.app.services.download import DownloadQueueService
-from invokeai.app.services.shared.sqlite import SqliteDatabase
+from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.util.logging import InvokeAILogger

-config = InvokeAIAppConfig.get_config()
-config.parse_args()
+config = get_config()

logger = InvokeAILogger.get_logger(config=config)
-db = SqliteDatabase(config, logger)
+db = SqliteDatabase(config.db_path, logger)
record_store = ModelRecordServiceSQL(db)
queue = DownloadQueueService()
queue.start()

-installer = ModelInstallService(app_config=config,
+installer = ModelInstallService(app_config=config,
                                record_store=record_store,
-                                download_queue=queue
-                                )
+                                download_queue=queue
+                                )
installer.start()
```

@@ -1602,3 +1601,59 @@ This method takes a model key, looks it up using the `ModelRecordServiceBase` object in `mm.store`, and passes the returned model configuration to `load_model_by_config()`. It may raise a `NotImplementedException`.
+
+## Invocation Context Model Manager API
+
+Within invocations, the following methods are available from the
+`InvocationContext` object:
+
+### context.download_and_cache_model(source) -> Path
+
+This method accepts a `source` of a remote model, downloads and caches
+it locally, and then returns a Path to the local model. The source can
+be a direct download URL or a HuggingFace repo_id.
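+
+For example, a minimal sketch following the signature documented above;
+`context` is the `InvocationContext` passed to `invoke()`, and the URL is
+a placeholder rather than a real model:
+
+```
+# Downloads on first use; later calls return the already-cached Path.
+model_path = context.download_and_cache_model(
+    "https://example.com/checkpoints/my_model.safetensors"  # placeholder URL
+)
+```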
+ +In the case of HuggingFace repo_id, the following variants are +recognized: + +* stabilityai/stable-diffusion-v4 -- default model +* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant +* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder +* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder + +You can also point at an arbitrary individual file within a repo_id +directory using this syntax: + +* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors + +### context.load_local_model(model_path, [loader]) -> LoadedModel + +This method loads a local model from the indicated path, returning a +`LoadedModel`. The optional loader is a Callable that accepts a Path +to the object, and returns a `AnyModel` object. If no loader is +provided, then the method will use `torch.load()` for a .ckpt or .bin +checkpoint file, `safetensors.torch.load_file()` for a safetensors +checkpoint file, or `cls.from_pretrained()` for a directory that looks +like a diffusers directory. + +### context.load_remote_model(source, [loader]) -> LoadedModel + +This method accepts a `source` of a remote model, downloads and caches +it locally, loads it, and returns a `LoadedModel`. The source can be a +direct download URL or a HuggingFace repo_id. + +In the case of HuggingFace repo_id, the following variants are +recognized: + +* stabilityai/stable-diffusion-v4 -- default model +* stabilityai/stable-diffusion-v4:fp16 -- fp16 variant +* stabilityai/stable-diffusion-v4:fp16:vae -- the fp16 vae subfolder +* stabilityai/stable-diffusion-v4:onnx:vae -- the onnx variant vae subfolder + +You can also point at an arbitrary individual file within a repo_id +directory using this syntax: + +* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors + + + diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 4e8103d8d3..19a7bb083d 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -93,7 +93,7 @@ class ApiDependencies: conditioning = ObjectSerializerForwardCache( ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True) ) - download_queue_service = DownloadQueueService(event_bus=events) + download_queue_service = DownloadQueueService(app_config=configuration, event_bus=events) model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images") model_manager = ModelManagerService.build_model_manager( app_config=configuration, diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index f5edd49874..c0b332f27b 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -2,6 +2,7 @@ # initial implementation by Gregg Helt, 2023 # heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux from builtins import bool, float +from pathlib import Path from typing import Dict, List, Literal, Union import cv2 @@ -36,12 +37,13 @@ from invokeai.app.invocations.util import validate_begin_end_step, validate_weig from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize from invokeai.backend.image_util.canny import get_canny_edges -from invokeai.backend.image_util.depth_anything import DepthAnythingDetector -from invokeai.backend.image_util.dw_openpose import 
DWOpenposeDetector +from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector +from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector from invokeai.backend.image_util.hed import HEDProcessor from invokeai.backend.image_util.lineart import LineartProcessor from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor from invokeai.backend.image_util.util import np_to_pil, pil_to_np +from invokeai.backend.util.devices import TorchDevice from .baseinvocation import BaseInvocation, BaseInvocationOutput, Classification, invocation, invocation_output @@ -139,6 +141,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard): return context.images.get_pil(self.image.image_name, "RGB") def invoke(self, context: InvocationContext) -> ImageOutput: + self._context = context raw_image = self.load_image(context) # image type should be PIL.PngImagePlugin.PngImageFile ? processed_image = self.run_processor(raw_image) @@ -284,7 +287,8 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation): # depth_and_normal not supported in controlnet_aux v0.0.3 # depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode") - def run_processor(self, image): + def run_processor(self, image: Image.Image) -> Image.Image: + # TODO: replace from_pretrained() calls with context.models.download_and_cache() (or similar) midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators") processed_image = midas_processor( image, @@ -311,7 +315,7 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation): detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res) image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res) - def run_processor(self, image): + def run_processor(self, image: Image.Image) -> Image.Image: normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators") processed_image = normalbae_processor( image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution @@ -330,7 +334,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation): thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`") thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`") - def run_processor(self, image): + def run_processor(self, image: Image.Image) -> Image.Image: mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators") processed_image = mlsd_processor( image, @@ -353,7 +357,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation): safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode) scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode) - def run_processor(self, image): + def run_processor(self, image: Image.Image) -> Image.Image: pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators") processed_image = pidi_processor( image, @@ -381,7 +385,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation): w: int = InputField(default=512, ge=0, description="Content shuffle `w` parameter") f: int = InputField(default=256, ge=0, description="Content shuffle `f` parameter") - def run_processor(self, image): + def run_processor(self, image: Image.Image) -> Image.Image: content_shuffle_processor = ContentShuffleDetector() processed_image = content_shuffle_processor( image, @@ -405,7 
+409,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation): class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation): """Applies Zoe depth processing to image""" - def run_processor(self, image): + def run_processor(self, image: Image.Image) -> Image.Image: zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators") processed_image = zoe_depth_processor(image) return processed_image @@ -426,7 +430,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation): detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res) image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res) - def run_processor(self, image): + def run_processor(self, image: Image.Image) -> Image.Image: mediapipe_face_processor = MediapipeFaceDetector() processed_image = mediapipe_face_processor( image, @@ -454,7 +458,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation): detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res) image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res) - def run_processor(self, image): + def run_processor(self, image: Image.Image) -> Image.Image: leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators") processed_image = leres_processor( image, @@ -496,8 +500,8 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation): np_img = cv2.resize(np_img, (W, H), interpolation=cv2.INTER_AREA) return np_img - def run_processor(self, img): - np_img = np.array(img, dtype=np.uint8) + def run_processor(self, image: Image.Image) -> Image.Image: + np_img = np.array(image, dtype=np.uint8) processed_np_image = self.tile_resample( np_img, # res=self.tile_size, @@ -520,7 +524,7 @@ class SegmentAnythingProcessorInvocation(ImageProcessorInvocation): detect_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.detect_res) image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res) - def run_processor(self, image): + def run_processor(self, image: Image.Image) -> Image.Image: # segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints") segment_anything_processor = SamDetectorReproducibleColors.from_pretrained( "ybelkada/segment-anything", subfolder="checkpoints" @@ -566,7 +570,7 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation): color_map_tile_size: int = InputField(default=64, ge=1, description=FieldDescriptions.tile_size) - def run_processor(self, image: Image.Image): + def run_processor(self, image: Image.Image) -> Image.Image: np_image = np.array(image, dtype=np.uint8) height, width = np_image.shape[:2] @@ -601,12 +605,18 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation): ) resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res) - def run_processor(self, image: Image.Image): - depth_anything_detector = DepthAnythingDetector() - depth_anything_detector.load_model(model_size=self.model_size) + def run_processor(self, image: Image.Image) -> Image.Image: + def loader(model_path: Path): + return DepthAnythingDetector.load_model( + model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device() + ) - processed_image = depth_anything_detector(image=image, resolution=self.resolution) - return processed_image + with self._context.models.load_remote_model( + 
source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader + ) as model: + depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device()) + processed_image = depth_anything_detector(image=image, resolution=self.resolution) + return processed_image @invocation( @@ -624,8 +634,11 @@ class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation): draw_hands: bool = InputField(default=False) image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res) - def run_processor(self, image: Image.Image): - dw_openpose = DWOpenposeDetector() + def run_processor(self, image: Image.Image) -> Image.Image: + onnx_det = self._context.models.download_and_cache_model(DWPOSE_MODELS["yolox_l.onnx"]) + onnx_pose = self._context.models.download_and_cache_model(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]) + + dw_openpose = DWOpenposeDetector(onnx_det=onnx_det, onnx_pose=onnx_pose) processed_image = dw_openpose( image, draw_face=self.draw_face, diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index 418bc62fdc..7e1a2ee322 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -42,15 +42,16 @@ class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard): """Infill the image with the specified method""" pass - def load_image(self, context: InvocationContext) -> tuple[Image.Image, bool]: + def load_image(self) -> tuple[Image.Image, bool]: """Process the image to have an alpha channel before being infilled""" - image = context.images.get_pil(self.image.image_name) + image = self._context.images.get_pil(self.image.image_name) has_alpha = True if image.mode == "RGBA" else False return image, has_alpha def invoke(self, context: InvocationContext) -> ImageOutput: + self._context = context # Retrieve and process image to be infilled - input_image, has_alpha = self.load_image(context) + input_image, has_alpha = self.load_image() # If the input image has no alpha channel, return it if has_alpha is False: @@ -133,8 +134,12 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation): """Infills transparent areas of an image using the LaMa model""" def infill(self, image: Image.Image): - lama = LaMA() - return lama(image) + with self._context.models.load_remote_model( + source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt", + loader=LaMA.load_jit_model, + ) as model: + lama = LaMA(model) + return lama(image) @invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2") diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index deaf5696c6..f93060f8d3 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -1,5 +1,4 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team -from pathlib import Path from typing import Literal import cv2 @@ -10,10 +9,8 @@ from pydantic import ConfigDict from invokeai.app.invocations.fields import ImageField from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.shared.invocation_context import InvocationContext -from invokeai.app.util.download_with_progress import download_with_progress_bar from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN -from invokeai.backend.util.devices import TorchDevice from .baseinvocation import BaseInvocation, 
invocation from .fields import InputField, WithBoard, WithMetadata @@ -52,7 +49,6 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard): rrdbnet_model = None netscale = None - esrgan_model_path = None if self.model_name in [ "RealESRGAN_x4plus.pth", @@ -95,28 +91,25 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard): context.logger.error(msg) raise ValueError(msg) - esrgan_model_path = Path(context.config.get().models_path, f"core/upscaling/realesrgan/{self.model_name}") - - # Downloads the ESRGAN model if it doesn't already exist - download_with_progress_bar( - name=self.model_name, url=ESRGAN_MODEL_URLS[self.model_name], dest_path=esrgan_model_path + loadnet = context.models.load_remote_model( + source=ESRGAN_MODEL_URLS[self.model_name], ) - upscaler = RealESRGAN( - scale=netscale, - model_path=esrgan_model_path, - model=rrdbnet_model, - half=False, - tile=self.tile_size, - ) + with loadnet as loadnet_model: + upscaler = RealESRGAN( + scale=netscale, + loadnet=loadnet_model, + model=rrdbnet_model, + half=False, + tile=self.tile_size, + ) - # prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL - # TODO: This strips the alpha... is that okay? - cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR) - upscaled_image = upscaler.upscale(cv2_image) - pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA") + # prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL + # TODO: This strips the alpha... is that okay? + cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR) + upscaled_image = upscaler.upscale(cv2_image) - TorchDevice.empty_cache() + pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA") image_dto = context.images.save(image=pil_image) diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 54a092d03e..496988e853 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -86,6 +86,7 @@ class InvokeAIAppConfig(BaseSettings): patchmatch: Enable patchmatch inpaint code. models_dir: Path to the models directory. convert_cache_dir: Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location. + download_cache_dir: Path to the directory that contains dynamically downloaded models. legacy_conf_dir: Path to directory of legacy checkpoint config files. db_dir: Path to InvokeAI databases directory. outputs_dir: Path to directory for outputs. @@ -146,7 +147,8 @@ class InvokeAIAppConfig(BaseSettings): # PATHS models_dir: Path = Field(default=Path("models"), description="Path to the models directory.") - convert_cache_dir: Path = Field(default=Path("models/.cache"), description="Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.") + convert_cache_dir: Path = Field(default=Path("models/.convert_cache"), description="Path to the converted models cache directory. 
When loading a non-diffusers model, it will be converted and store on disk at this location.") + download_cache_dir: Path = Field(default=Path("models/.download_cache"), description="Path to the directory that contains dynamically downloaded models.") legacy_conf_dir: Path = Field(default=Path("configs"), description="Path to directory of legacy checkpoint config files.") db_dir: Path = Field(default=Path("databases"), description="Path to InvokeAI databases directory.") outputs_dir: Path = Field(default=Path("outputs"), description="Path to directory for outputs.") @@ -303,6 +305,11 @@ class InvokeAIAppConfig(BaseSettings): """Path to the converted cache models directory, resolved to an absolute path..""" return self._resolve(self.convert_cache_dir) + @property + def download_cache_path(self) -> Path: + """Path to the downloaded models directory, resolved to an absolute path..""" + return self._resolve(self.download_cache_dir) + @property def custom_nodes_path(self) -> Path: """Path to the custom nodes directory, resolved to an absolute path..""" diff --git a/invokeai/app/services/download/__init__.py b/invokeai/app/services/download/__init__.py index 371c531387..33b0025809 100644 --- a/invokeai/app/services/download/__init__.py +++ b/invokeai/app/services/download/__init__.py @@ -1,10 +1,17 @@ """Init file for download queue.""" -from .download_base import DownloadJob, DownloadJobStatus, DownloadQueueServiceBase, UnknownJobIDException +from .download_base import ( + DownloadJob, + DownloadJobStatus, + DownloadQueueServiceBase, + MultiFileDownloadJob, + UnknownJobIDException, +) from .download_default import DownloadQueueService, TqdmProgress __all__ = [ "DownloadJob", + "MultiFileDownloadJob", "DownloadQueueServiceBase", "DownloadQueueService", "TqdmProgress", diff --git a/invokeai/app/services/download/download_base.py b/invokeai/app/services/download/download_base.py index 2ac13b825f..4880ab98b8 100644 --- a/invokeai/app/services/download/download_base.py +++ b/invokeai/app/services/download/download_base.py @@ -5,11 +5,13 @@ from abc import ABC, abstractmethod from enum import Enum from functools import total_ordering from pathlib import Path -from typing import Any, Callable, List, Optional +from typing import Any, Callable, List, Optional, Set, Union from pydantic import BaseModel, Field, PrivateAttr from pydantic.networks import AnyHttpUrl +from invokeai.backend.model_manager.metadata import RemoteModelFile + class DownloadJobStatus(str, Enum): """State of a download job.""" @@ -33,30 +35,23 @@ class ServiceInactiveException(Exception): """This exception is raised when user attempts to initiate a download before the service is started.""" -DownloadEventHandler = Callable[["DownloadJob"], None] -DownloadExceptionHandler = Callable[["DownloadJob", Optional[Exception]], None] +SingleFileDownloadEventHandler = Callable[["DownloadJob"], None] +SingleFileDownloadExceptionHandler = Callable[["DownloadJob", Optional[Exception]], None] +MultiFileDownloadEventHandler = Callable[["MultiFileDownloadJob"], None] +MultiFileDownloadExceptionHandler = Callable[["MultiFileDownloadJob", Optional[Exception]], None] +DownloadEventHandler = Union[SingleFileDownloadEventHandler, MultiFileDownloadEventHandler] +DownloadExceptionHandler = Union[SingleFileDownloadExceptionHandler, MultiFileDownloadExceptionHandler] -@total_ordering -class DownloadJob(BaseModel): - """Class to monitor and control a model download request.""" +class DownloadJobBase(BaseModel): + """Base of classes to monitor and control 
downloads.""" - # required variables to be passed in on creation - source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.") - dest: Path = Field(description="Destination of downloaded model on local disk; a directory or file path") - access_token: Optional[str] = Field(default=None, description="authorization token for protected resources") # automatically assigned on creation id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel - priority: int = Field(default=10, description="Queue priority; lower values are higher priority") - # set internally during download process + dest: Path = Field(description="Initial destination of downloaded model on local disk; a directory or file path") + download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file or directory") status: DownloadJobStatus = Field(default=DownloadJobStatus.WAITING, description="Status of the download") - download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file") - job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started") - job_ended: Optional[str] = Field( - default=None, description="Timestamp for when the download job ende1d (completed or errored)" - ) - content_type: Optional[str] = Field(default=None, description="Content type of downloaded file") bytes: int = Field(default=0, description="Bytes downloaded so far") total_bytes: int = Field(default=0, description="Total file size (bytes)") @@ -74,14 +69,6 @@ class DownloadJob(BaseModel): _on_cancelled: Optional[DownloadEventHandler] = PrivateAttr(default=None) _on_error: Optional[DownloadExceptionHandler] = PrivateAttr(default=None) - def __hash__(self) -> int: - """Return hash of the string representation of this object, for indexing.""" - return hash(str(self)) - - def __le__(self, other: "DownloadJob") -> bool: - """Return True if this job's priority is less than another's.""" - return self.priority <= other.priority - def cancel(self) -> None: """Call to cancel the job.""" self._cancelled = True @@ -98,6 +85,11 @@ class DownloadJob(BaseModel): """Return true if job completed without errors.""" return self.status == DownloadJobStatus.COMPLETED + @property + def waiting(self) -> bool: + """Return true if the job is waiting to run.""" + return self.status == DownloadJobStatus.WAITING + @property def running(self) -> bool: """Return true if the job is running.""" @@ -154,6 +146,37 @@ class DownloadJob(BaseModel): self._on_cancelled = on_cancelled +@total_ordering +class DownloadJob(DownloadJobBase): + """Class to monitor and control a model download request.""" + + # required variables to be passed in on creation + source: AnyHttpUrl = Field(description="Where to download from. 
Specific types specified in child classes.") + access_token: Optional[str] = Field(default=None, description="authorization token for protected resources") + priority: int = Field(default=10, description="Queue priority; lower values are higher priority") + + # set internally during download process + job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started") + job_ended: Optional[str] = Field( + default=None, description="Timestamp for when the download job ende1d (completed or errored)" + ) + content_type: Optional[str] = Field(default=None, description="Content type of downloaded file") + + def __hash__(self) -> int: + """Return hash of the string representation of this object, for indexing.""" + return hash(str(self)) + + def __le__(self, other: "DownloadJob") -> bool: + """Return True if this job's priority is less than another's.""" + return self.priority <= other.priority + + +class MultiFileDownloadJob(DownloadJobBase): + """Class to monitor and control multifile downloads.""" + + download_parts: Set[DownloadJob] = Field(default_factory=set, description="List of download parts.") + + class DownloadQueueServiceBase(ABC): """Multithreaded queue for downloading models via URL.""" @@ -201,6 +224,48 @@ class DownloadQueueServiceBase(ABC): """ pass + @abstractmethod + def multifile_download( + self, + parts: List[RemoteModelFile], + dest: Path, + access_token: Optional[str] = None, + submit_job: bool = True, + on_start: Optional[DownloadEventHandler] = None, + on_progress: Optional[DownloadEventHandler] = None, + on_complete: Optional[DownloadEventHandler] = None, + on_cancelled: Optional[DownloadEventHandler] = None, + on_error: Optional[DownloadExceptionHandler] = None, + ) -> MultiFileDownloadJob: + """ + Create and enqueue a multifile download job. + + :param parts: Set of URL / filename pairs + :param dest: Path to download to. See below. + :param access_token: Access token to download the indicated files. If not provided, + each file's URL may be matched to an access token using the config file matching + system. + :param submit_job: If true [default] then submit the job for execution. Otherwise, + you will need to pass the job to submit_multifile_download(). + :param on_start, on_progress, on_complete, on_error: Callbacks for the indicated + events. + :returns: A MultiFileDownloadJob object for monitoring the state of the download. + + The `dest` argument is a Path object pointing to a directory. All downloads + with be placed inside this directory. The callbacks will receive the + MultiFileDownloadJob. + """ + pass + + @abstractmethod + def submit_multifile_download(self, job: MultiFileDownloadJob) -> None: + """ + Enqueue a previously-created multi-file download job. + + :param job: A MultiFileDownloadJob created with multifile_download() + """ + pass + @abstractmethod def submit_download_job( self, @@ -252,7 +317,7 @@ class DownloadQueueServiceBase(ABC): pass @abstractmethod - def cancel_job(self, job: DownloadJob) -> None: + def cancel_job(self, job: DownloadJobBase) -> None: """Cancel the job, clearing partial downloads and putting it into ERROR state.""" pass @@ -262,7 +327,7 @@ class DownloadQueueServiceBase(ABC): pass @abstractmethod - def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob: + def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase: """Wait until the indicated download job has reached a terminal state. 
This will block until the indicated install job has completed, diff --git a/invokeai/app/services/download/download_default.py b/invokeai/app/services/download/download_default.py index 180f0f1a8c..4640a656dc 100644 --- a/invokeai/app/services/download/download_default.py +++ b/invokeai/app/services/download/download_default.py @@ -8,23 +8,28 @@ import time import traceback from pathlib import Path from queue import Empty, PriorityQueue -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set import requests from pydantic.networks import AnyHttpUrl from requests import HTTPError from tqdm import tqdm +from invokeai.app.services.config import InvokeAIAppConfig, get_config +from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.util.misc import get_iso_timestamp +from invokeai.backend.model_manager.metadata import RemoteModelFile from invokeai.backend.util.logging import InvokeAILogger from .download_base import ( DownloadEventHandler, DownloadExceptionHandler, DownloadJob, + DownloadJobBase, DownloadJobCancelledException, DownloadJobStatus, DownloadQueueServiceBase, + MultiFileDownloadJob, ServiceInactiveException, UnknownJobIDException, ) @@ -42,20 +47,24 @@ class DownloadQueueService(DownloadQueueServiceBase): def __init__( self, max_parallel_dl: int = 5, + app_config: Optional[InvokeAIAppConfig] = None, event_bus: Optional["EventServiceBase"] = None, requests_session: Optional[requests.sessions.Session] = None, ): """ Initialize DownloadQueue. + :param app_config: InvokeAIAppConfig object :param max_parallel_dl: Number of simultaneous downloads allowed [5]. :param requests_session: Optional requests.sessions.Session object, for unit tests. """ + self._app_config = app_config or get_config() self._jobs: Dict[int, DownloadJob] = {} + self._download_part2parent: Dict[AnyHttpUrl, MultiFileDownloadJob] = {} self._next_job_id = 0 self._queue: PriorityQueue[DownloadJob] = PriorityQueue() self._stop_event = threading.Event() - self._job_completed_event = threading.Event() + self._job_terminated_event = threading.Event() self._worker_pool: Set[threading.Thread] = set() self._lock = threading.Lock() self._logger = InvokeAILogger.get_logger("DownloadQueueService") @@ -107,18 +116,16 @@ class DownloadQueueService(DownloadQueueServiceBase): raise ServiceInactiveException( "The download service is not currently accepting requests. Please call start() to initialize the service." 
) - with self._lock: - job.id = self._next_job_id - self._next_job_id += 1 - job.set_callbacks( - on_start=on_start, - on_progress=on_progress, - on_complete=on_complete, - on_cancelled=on_cancelled, - on_error=on_error, - ) - self._jobs[job.id] = job - self._queue.put(job) + job.id = self._next_id() + job.set_callbacks( + on_start=on_start, + on_progress=on_progress, + on_complete=on_complete, + on_cancelled=on_cancelled, + on_error=on_error, + ) + self._jobs[job.id] = job + self._queue.put(job) def download( self, @@ -141,7 +148,7 @@ class DownloadQueueService(DownloadQueueServiceBase): source=source, dest=dest, priority=priority, - access_token=access_token, + access_token=access_token or self._lookup_access_token(source), ) self.submit_download_job( job, @@ -153,10 +160,63 @@ class DownloadQueueService(DownloadQueueServiceBase): ) return job + def multifile_download( + self, + parts: List[RemoteModelFile], + dest: Path, + access_token: Optional[str] = None, + submit_job: bool = True, + on_start: Optional[DownloadEventHandler] = None, + on_progress: Optional[DownloadEventHandler] = None, + on_complete: Optional[DownloadEventHandler] = None, + on_cancelled: Optional[DownloadEventHandler] = None, + on_error: Optional[DownloadExceptionHandler] = None, + ) -> MultiFileDownloadJob: + mfdj = MultiFileDownloadJob(dest=dest, id=self._next_id()) + mfdj.set_callbacks( + on_start=on_start, + on_progress=on_progress, + on_complete=on_complete, + on_cancelled=on_cancelled, + on_error=on_error, + ) + + for part in parts: + url = part.url + path = dest / part.path + assert path.is_relative_to(dest), "only relative download paths accepted" + job = DownloadJob( + source=url, + dest=path, + access_token=access_token, + ) + mfdj.download_parts.add(job) + self._download_part2parent[job.source] = mfdj + if submit_job: + self.submit_multifile_download(mfdj) + return mfdj + + def submit_multifile_download(self, job: MultiFileDownloadJob) -> None: + for download_job in job.download_parts: + self.submit_download_job( + download_job, + on_start=self._mfd_started, + on_progress=self._mfd_progress, + on_complete=self._mfd_complete, + on_cancelled=self._mfd_cancelled, + on_error=self._mfd_error, + ) + def join(self) -> None: """Wait for all jobs to complete.""" self._queue.join() + def _next_id(self) -> int: + with self._lock: + id = self._next_job_id + self._next_job_id += 1 + return id + def list_jobs(self) -> List[DownloadJob]: """List all the jobs.""" return list(self._jobs.values()) @@ -178,14 +238,14 @@ class DownloadQueueService(DownloadQueueServiceBase): except KeyError as excp: raise UnknownJobIDException("Unrecognized job") from excp - def cancel_job(self, job: DownloadJob) -> None: + def cancel_job(self, job: DownloadJobBase) -> None: """ Cancel the indicated job. If it is running it will be stopped. 
job.status will be set to DownloadJobStatus.CANCELLED """ - with self._lock: + if job.status in [DownloadJobStatus.WAITING, DownloadJobStatus.RUNNING]: job.cancel() def cancel_all_jobs(self) -> None: @@ -194,12 +254,12 @@ class DownloadQueueService(DownloadQueueServiceBase): if not job.in_terminal_state: self.cancel_job(job) - def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob: + def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase: """Block until the indicated job has reached terminal state, or when timeout limit reached.""" start = time.time() while not job.in_terminal_state: - if self._job_completed_event.wait(timeout=0.25): # in case we miss an event - self._job_completed_event.clear() + if self._job_terminated_event.wait(timeout=0.25): # in case we miss an event + self._job_terminated_event.clear() if timeout > 0 and time.time() - start > timeout: raise TimeoutError("Timeout exceeded") return job @@ -228,22 +288,25 @@ class DownloadQueueService(DownloadQueueServiceBase): job.job_started = get_iso_timestamp() self._do_download(job) self._signal_job_complete(job) - except (OSError, HTTPError) as excp: - job.error_type = excp.__class__.__name__ + f"({str(excp)})" - job.error = traceback.format_exc() - self._signal_job_error(job, excp) except DownloadJobCancelledException: self._signal_job_cancelled(job) self._cleanup_cancelled_job(job) - + except Exception as excp: + job.error_type = excp.__class__.__name__ + f"({str(excp)})" + job.error = traceback.format_exc() + self._signal_job_error(job, excp) finally: job.job_ended = get_iso_timestamp() - self._job_completed_event.set() # signal a change to terminal state + self._job_terminated_event.set() # signal a change to terminal state + self._download_part2parent.pop(job.source, None) # if this is a subpart of a multipart job, remove it + self._job_terminated_event.set() self._queue.task_done() + self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.") def _do_download(self, job: DownloadJob) -> None: """Do the actual download.""" + url = job.source header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {} open_mode = "wb" @@ -335,38 +398,29 @@ class DownloadQueueService(DownloadQueueServiceBase): def _in_progress_path(self, path: Path) -> Path: return path.with_name(path.name + ".downloading") + def _lookup_access_token(self, source: AnyHttpUrl) -> Optional[str]: + # Pull the token from config if it exists and matches the URL + token = None + for pair in self._app_config.remote_api_tokens or []: + if re.search(pair.url_regex, str(source)): + token = pair.token + break + return token + def _signal_job_started(self, job: DownloadJob) -> None: job.status = DownloadJobStatus.RUNNING - if job.on_start: - try: - job.on_start(job) - except Exception as e: - self._logger.error( - f"An error occurred while processing the on_start callback: {traceback.format_exception(e)}" - ) + self._execute_cb(job, "on_start") if self._event_bus: self._event_bus.emit_download_started(job) def _signal_job_progress(self, job: DownloadJob) -> None: - if job.on_progress: - try: - job.on_progress(job) - except Exception as e: - self._logger.error( - f"An error occurred while processing the on_progress callback: {traceback.format_exception(e)}" - ) + self._execute_cb(job, "on_progress") if self._event_bus: self._event_bus.emit_download_progress(job) def _signal_job_complete(self, job: DownloadJob) -> None: job.status = DownloadJobStatus.COMPLETED - if 
job.on_complete: - try: - job.on_complete(job) - except Exception as e: - self._logger.error( - f"An error occurred while processing the on_complete callback: {traceback.format_exception(e)}" - ) + self._execute_cb(job, "on_complete") if self._event_bus: self._event_bus.emit_download_complete(job) @@ -374,26 +428,21 @@ class DownloadQueueService(DownloadQueueServiceBase): if job.status not in [DownloadJobStatus.RUNNING, DownloadJobStatus.WAITING]: return job.status = DownloadJobStatus.CANCELLED - if job.on_cancelled: - try: - job.on_cancelled(job) - except Exception as e: - self._logger.error( - f"An error occurred while processing the on_cancelled callback: {traceback.format_exception(e)}" - ) + self._execute_cb(job, "on_cancelled") if self._event_bus: self._event_bus.emit_download_cancelled(job) + # if multifile download, then signal the parent + if parent_job := self._download_part2parent.get(job.source, None): + if not parent_job.in_terminal_state: + parent_job.status = DownloadJobStatus.CANCELLED + self._execute_cb(parent_job, "on_cancelled") + def _signal_job_error(self, job: DownloadJob, excp: Optional[Exception] = None) -> None: job.status = DownloadJobStatus.ERROR self._logger.error(f"{str(job.source)}: {traceback.format_exception(excp)}") - if job.on_error: - try: - job.on_error(job, excp) - except Exception as e: - self._logger.error( - f"An error occurred while processing the on_error callback: {traceback.format_exception(e)}" - ) + self._execute_cb(job, "on_error", excp) + if self._event_bus: self._event_bus.emit_download_error(job) @@ -406,6 +455,97 @@ class DownloadQueueService(DownloadQueueServiceBase): except OSError as excp: self._logger.warning(excp) + ######################################## + # callbacks used for multifile downloads + ######################################## + def _mfd_started(self, download_job: DownloadJob) -> None: + self._logger.info(f"File download started: {download_job.source}") + with self._lock: + mf_job = self._download_part2parent[download_job.source] + if mf_job.waiting: + mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts) + mf_job.status = DownloadJobStatus.RUNNING + assert download_job.download_path is not None + path_relative_to_destdir = download_job.download_path.relative_to(mf_job.dest) + mf_job.download_path = ( + mf_job.dest / path_relative_to_destdir.parts[0] + ) # keep just the first component of the path + self._execute_cb(mf_job, "on_start") + + def _mfd_progress(self, download_job: DownloadJob) -> None: + with self._lock: + mf_job = self._download_part2parent[download_job.source] + if mf_job.cancelled: + for part in mf_job.download_parts: + self.cancel_job(part) + elif mf_job.running: + mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts) + mf_job.bytes = sum(x.total_bytes for x in mf_job.download_parts) + self._execute_cb(mf_job, "on_progress") + + def _mfd_complete(self, download_job: DownloadJob) -> None: + self._logger.info(f"Download complete: {download_job.source}") + with self._lock: + mf_job = self._download_part2parent[download_job.source] + + # are there any more active jobs left in this task? 
+ if mf_job.running and all(x.complete for x in mf_job.download_parts): + mf_job.status = DownloadJobStatus.COMPLETED + self._execute_cb(mf_job, "on_complete") + + # we're done with this sub-job + self._job_terminated_event.set() + + def _mfd_cancelled(self, download_job: DownloadJob) -> None: + with self._lock: + mf_job = self._download_part2parent[download_job.source] + assert mf_job is not None + + if not mf_job.in_terminal_state: + self._logger.warning(f"Download cancelled: {download_job.source}") + mf_job.cancel() + + for s in mf_job.download_parts: + self.cancel_job(s) + + def _mfd_error(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None: + with self._lock: + mf_job = self._download_part2parent[download_job.source] + assert mf_job is not None + if not mf_job.in_terminal_state: + mf_job.status = download_job.status + mf_job.error = download_job.error + mf_job.error_type = download_job.error_type + self._execute_cb(mf_job, "on_error", excp) + self._logger.error( + f"Cancelling {mf_job.dest} due to an error while downloading {download_job.source}: {str(excp)}" + ) + for s in [x for x in mf_job.download_parts if x.running]: + self.cancel_job(s) + self._download_part2parent.pop(download_job.source) + self._job_terminated_event.set() + + def _execute_cb( + self, + job: DownloadJob | MultiFileDownloadJob, + callback_name: Literal[ + "on_start", + "on_progress", + "on_complete", + "on_cancelled", + "on_error", + ], + excp: Optional[Exception] = None, + ) -> None: + if callback := getattr(job, callback_name, None): + args = [job, excp] if excp else [job] + try: + callback(*args) + except Exception as e: + self._logger.error( + f"An error occurred while processing the {callback_name} callback: {traceback.format_exception(e)}" + ) + def get_pc_name_max(directory: str) -> int: if hasattr(os, "pathconf"): diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index 6ee671062d..20afaeaa50 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -13,7 +13,7 @@ from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invoker import Invoker from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource from invokeai.app.services.model_records import ModelRecordServiceBase -from invokeai.backend.model_manager.config import AnyModelConfig +from invokeai.backend.model_manager import AnyModelConfig class ModelInstallServiceBase(ABC): @@ -243,12 +243,11 @@ class ModelInstallServiceBase(ABC): """ @abstractmethod - def download_and_cache(self, source: Union[str, AnyHttpUrl], access_token: Optional[str] = None) -> Path: + def download_and_cache_model(self, source: str | AnyHttpUrl) -> Path: """ Download the model file located at source to the models cache and return its Path. - :param source: A Url or a string that can be converted into one. - :param access_token: Optional access token to access restricted resources. + :param source: A string representing a URL or repo_id. The model file will be downloaded into the system-wide model cache (`models/.cache`) if it isn't already there. 
Note that the model cache diff --git a/invokeai/app/services/model_install/model_install_common.py b/invokeai/app/services/model_install/model_install_common.py index d42e7632f3..c1538f543d 100644 --- a/invokeai/app/services/model_install/model_install_common.py +++ b/invokeai/app/services/model_install/model_install_common.py @@ -8,7 +8,7 @@ from pydantic import BaseModel, Field, PrivateAttr, field_validator from pydantic.networks import AnyHttpUrl from typing_extensions import Annotated -from invokeai.app.services.download import DownloadJob +from invokeai.app.services.download import DownloadJob, MultiFileDownloadJob from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant from invokeai.backend.model_manager.config import ModelSourceType from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata @@ -26,13 +26,6 @@ class InstallStatus(str, Enum): CANCELLED = "cancelled" # terminated with an error message -class ModelInstallPart(BaseModel): - url: AnyHttpUrl - path: Path - bytes: int = 0 - total_bytes: int = 0 - - class UnknownInstallJobException(Exception): """Raised when the status of an unknown job is requested.""" @@ -169,6 +162,7 @@ class ModelInstallJob(BaseModel): ) # internal flags and transitory settings _install_tmpdir: Optional[Path] = PrivateAttr(default=None) + _multifile_job: Optional[MultiFileDownloadJob] = PrivateAttr(default=None) _exception: Optional[Exception] = PrivateAttr(default=None) def set_error(self, e: Exception) -> None: diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index df060caff3..39e38a593f 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -5,21 +5,22 @@ import os import re import threading import time -from hashlib import sha256 from pathlib import Path from queue import Empty, Queue from shutil import copyfile, copytree, move, rmtree from tempfile import mkdtemp -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union import torch import yaml from huggingface_hub import HfFolder from pydantic.networks import AnyHttpUrl +from pydantic_core import Url from requests import Session from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress +from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob +from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invoker import Invoker from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase @@ -44,6 +45,7 @@ from invokeai.backend.model_manager.search import ModelSearch from invokeai.backend.util import InvokeAILogger from invokeai.backend.util.catch_sigint import catch_sigint from invokeai.backend.util.devices import TorchDevice +from invokeai.backend.util.util import slugify from .model_install_common import ( MODEL_SOURCE_TO_TYPE_MAP, @@ -91,7 +93,7 @@ class ModelInstallService(ModelInstallServiceBase): self._downloads_changed_event = threading.Event() self._install_completed_event = threading.Event() self._download_queue = download_queue - self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {} + 
self._download_cache: Dict[int, ModelInstallJob] = {} self._running = False self._session = session self._install_thread: Optional[threading.Thread] = None @@ -210,33 +212,12 @@ class ModelInstallService(ModelInstallServiceBase): access_token: Optional[str] = None, inplace: Optional[bool] = False, ) -> ModelInstallJob: - variants = "|".join(ModelRepoVariant.__members__.values()) - hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$" - source_obj: Optional[StringLikeSource] = None - - if Path(source).exists(): # A local file or directory - source_obj = LocalModelSource(path=Path(source), inplace=inplace) - elif match := re.match(hf_repoid_re, source): - source_obj = HFModelSource( - repo_id=match.group(1), - variant=match.group(2) if match.group(2) else None, # pass None rather than '' - subfolder=Path(match.group(3)) if match.group(3) else None, - access_token=access_token, - ) - elif re.match(r"^https?://[^/]+", source): - # Pull the token from config if it exists and matches the URL - _token = access_token - if _token is None: - for pair in self.app_config.remote_api_tokens or []: - if re.search(pair.url_regex, source): - _token = pair.token - break - source_obj = URLModelSource( - url=AnyHttpUrl(source), - access_token=_token, - ) - else: - raise ValueError(f"Unsupported model source: '{source}'") + """Install a model using pattern matching to infer the type of source.""" + source_obj = self._guess_source(source) + if isinstance(source_obj, LocalModelSource): + source_obj.inplace = inplace + elif isinstance(source_obj, HFModelSource) or isinstance(source_obj, URLModelSource): + source_obj.access_token = access_token return self.import_model(source_obj, config) def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102 @@ -297,8 +278,9 @@ class ModelInstallService(ModelInstallServiceBase): def cancel_job(self, job: ModelInstallJob) -> None: """Cancel the indicated job.""" job.cancel() - with self._lock: - self._cancel_download_parts(job) + self._logger.warning(f"Cancelling {job.source}") + if dj := job._multifile_job: + self._download_queue.cancel_job(dj) def prune_jobs(self) -> None: """Prune all completed and errored jobs.""" @@ -346,7 +328,7 @@ class ModelInstallService(ModelInstallServiceBase): legacy_config_path = stanza.get("config") if legacy_config_path: # In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir. 
- legacy_config_path: Path = self._app_config.root_path / legacy_config_path + legacy_config_path = self._app_config.root_path / legacy_config_path if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path): legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path) config["config_path"] = str(legacy_config_path) @@ -386,38 +368,92 @@ class ModelInstallService(ModelInstallServiceBase): rmtree(model_path) self.unregister(key) - def download_and_cache( + @classmethod + def _download_cache_path(cls, source: Union[str, AnyHttpUrl], app_config: InvokeAIAppConfig) -> Path: + escaped_source = slugify(str(source)) + return app_config.download_cache_path / escaped_source + + def download_and_cache_model( self, - source: Union[str, AnyHttpUrl], - access_token: Optional[str] = None, - timeout: int = 0, + source: str | AnyHttpUrl, ) -> Path: """Download the model file located at source to the models cache and return its Path.""" - model_hash = sha256(str(source).encode("utf-8")).hexdigest()[0:32] - model_path = self._app_config.convert_cache_path / model_hash + model_path = self._download_cache_path(str(source), self._app_config) - # We expect the cache directory to contain one and only one downloaded file. + # We expect the cache directory to contain one and only one downloaded file or directory. # We don't know the file's name in advance, as it is set by the download # content-disposition header. if model_path.exists(): - contents = [x for x in model_path.iterdir() if x.is_file()] + contents: List[Path] = list(model_path.iterdir()) if len(contents) > 0: return contents[0] model_path.mkdir(parents=True, exist_ok=True) - job = self._download_queue.download( - source=AnyHttpUrl(str(source)), + model_source = self._guess_source(str(source)) + remote_files, _ = self._remote_files_from_source(model_source) + job = self._multifile_download( dest=model_path, - access_token=access_token, - on_progress=TqdmProgress().update, + remote_files=remote_files, + subfolder=model_source.subfolder if isinstance(model_source, HFModelSource) else None, ) - self._download_queue.wait_for_job(job, timeout) + files_string = "file" if len(remote_files) == 1 else "files" + self._logger.info(f"Queuing model download: {source} ({len(remote_files)} {files_string})") + self._download_queue.wait_for_job(job) if job.complete: assert job.download_path is not None return job.download_path else: raise Exception(job.error) + def _remote_files_from_source( + self, source: ModelSource + ) -> Tuple[List[RemoteModelFile], Optional[AnyModelRepoMetadata]]: + metadata = None + if isinstance(source, HFModelSource): + metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant) + assert isinstance(metadata, ModelMetadataWithFiles) + return metadata.download_urls( + variant=source.variant or self._guess_variant(), + subfolder=source.subfolder, + session=self._session, + ), metadata + + if isinstance(source, URLModelSource): + try: + fetcher = self.get_fetcher_from_url(str(source.url)) + kwargs: dict[str, Any] = {"session": self._session} + metadata = fetcher(**kwargs).from_url(source.url) + assert isinstance(metadata, ModelMetadataWithFiles) + return metadata.download_urls(session=self._session), metadata + except ValueError: + pass + + return [RemoteModelFile(url=source.url, path=Path("."), size=0)], None + + raise Exception(f"No files associated with {source}") + + def _guess_source(self, source: str) -> ModelSource: + """Turn a source string into a ModelSource object.""" + 
variants = "|".join(ModelRepoVariant.__members__.values()) + hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$" + source_obj: Optional[StringLikeSource] = None + + if Path(source).exists(): # A local file or directory + source_obj = LocalModelSource(path=Path(source)) + elif match := re.match(hf_repoid_re, source): + source_obj = HFModelSource( + repo_id=match.group(1), + variant=ModelRepoVariant(match.group(2)) if match.group(2) else None, # pass None rather than '' + subfolder=Path(match.group(3)) if match.group(3) else None, + ) + elif re.match(r"^https?://[^/]+", source): + source_obj = URLModelSource( + url=Url(source), + ) + else: + raise ValueError(f"Unsupported model source: '{source}'") + return source_obj + # -------------------------------------------------------------------------------------------- # Internal functions that manage the installer threads # -------------------------------------------------------------------------------------------- @@ -478,16 +514,19 @@ class ModelInstallService(ModelInstallServiceBase): job.config_out = self.record_store.get_model(key) self._signal_job_completed(job) - def _set_error(self, job: ModelInstallJob, excp: Exception) -> None: - if any(x.content_type is not None and "text/html" in x.content_type for x in job.download_parts): - job.set_error( + def _set_error(self, install_job: ModelInstallJob, excp: Exception) -> None: + multifile_download_job = install_job._multifile_job + if multifile_download_job and any( + x.content_type is not None and "text/html" in x.content_type for x in multifile_download_job.download_parts + ): + install_job.set_error( InvalidModelConfigException( - f"At least one file in {job.local_path} is an HTML page, not a model. This can happen when an access token is required to download." + f"At least one file in {install_job.local_path} is an HTML page, not a model. This can happen when an access token is required to download." ) ) else: - job.set_error(excp) - self._signal_job_errored(job) + install_job.set_error(excp) + self._signal_job_errored(install_job) # -------------------------------------------------------------------------------------------- # Internal functions that manage the models directory @@ -513,7 +552,6 @@ class ModelInstallService(ModelInstallServiceBase): This is typically only used during testing with a new DB or when using the memory DB, because those are the only situations in which we may have orphaned models in the models directory. """ - installed_model_paths = { (self._app_config.models_path / x.path).resolve() for x in self.record_store.all_models() } @@ -525,8 +563,13 @@ class ModelInstallService(ModelInstallServiceBase): if resolved_path in installed_model_paths: return True # Skip core models entirely - these aren't registered with the model manager. 
- if str(resolved_path).startswith(str(self.app_config.models_path / "core")): - return False + for special_directory in [ + self.app_config.models_path / "core", + self.app_config.convert_cache_dir, + self.app_config.download_cache_dir, + ]: + if resolved_path.is_relative_to(special_directory): + return False try: model_id = self.register_path(model_path) self._logger.info(f"Registered {model_path.name} with id {model_id}") @@ -641,20 +684,15 @@ class ModelInstallService(ModelInstallServiceBase): inplace=source.inplace or False, ) - def _import_from_hf(self, source: HFModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob: + def _import_from_hf( + self, + source: HFModelSource, + config: Optional[Dict[str, Any]] = None, + ) -> ModelInstallJob: # Add user's cached access token to HuggingFace requests - source.access_token = source.access_token or HfFolder.get_token() - if not source.access_token: - self._logger.info("No HuggingFace access token present; some models may not be downloadable.") - - metadata = HuggingFaceMetadataFetch(self._session).from_id(source.repo_id, source.variant) - assert isinstance(metadata, ModelMetadataWithFiles) - remote_files = metadata.download_urls( - variant=source.variant or self._guess_variant(), - subfolder=source.subfolder, - session=self._session, - ) - + if source.access_token is None: + source.access_token = HfFolder.get_token() + remote_files, metadata = self._remote_files_from_source(source) return self._import_remote_model( source=source, config=config, @@ -662,22 +700,12 @@ class ModelInstallService(ModelInstallServiceBase): metadata=metadata, ) - def _import_from_url(self, source: URLModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob: - # URLs from HuggingFace will be handled specially - metadata = None - fetcher = None - try: - fetcher = self.get_fetcher_from_url(str(source.url)) - except ValueError: - pass - kwargs: dict[str, Any] = {"session": self._session} - if fetcher is not None: - metadata = fetcher(**kwargs).from_url(source.url) - self._logger.debug(f"metadata={metadata}") - if metadata and isinstance(metadata, ModelMetadataWithFiles): - remote_files = metadata.download_urls(session=self._session) - else: - remote_files = [RemoteModelFile(url=source.url, path=Path("."), size=0)] + def _import_from_url( + self, + source: URLModelSource, + config: Optional[Dict[str, Any]], + ) -> ModelInstallJob: + remote_files, metadata = self._remote_files_from_source(source) return self._import_remote_model( source=source, config=config, @@ -692,12 +720,9 @@ class ModelInstallService(ModelInstallServiceBase): metadata: Optional[AnyModelRepoMetadata], config: Optional[Dict[str, Any]], ) -> ModelInstallJob: - # TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up. - # Currently the tmpdir isn't automatically removed at exit because it is - # being held in a daemon thread. 
if len(remote_files) == 0: raise ValueError(f"{source}: No downloadable files found") - tmpdir = Path( + destdir = Path( mkdtemp( dir=self._app_config.models_path, prefix=TMPDIR_PREFIX, @@ -708,55 +733,28 @@ class ModelInstallService(ModelInstallServiceBase): source=source, config_in=config or {}, source_metadata=metadata, - local_path=tmpdir, # local path may change once the download has started due to content-disposition handling + local_path=destdir, # local path may change once the download has started due to content-disposition handling bytes=0, total_bytes=0, ) - # In the event that there is a subfolder specified in the source, - # we need to remove it from the destination path in order to avoid - # creating unwanted subfolders - if isinstance(source, HFModelSource) and source.subfolder: - root = Path(remote_files[0].path.parts[0]) - subfolder = root / source.subfolder - else: - root = Path(".") - subfolder = Path(".") + # remember the temporary directory for later removal + install_job._install_tmpdir = destdir + install_job.total_bytes = sum((x.size or 0) for x in remote_files) - # we remember the path up to the top of the tmpdir so that it may be - # removed safely at the end of the install process. - install_job._install_tmpdir = tmpdir - assert install_job.total_bytes is not None # to avoid type checking complaints in the loop below + multifile_job = self._multifile_download( + remote_files=remote_files, + dest=destdir, + subfolder=source.subfolder if isinstance(source, HFModelSource) else None, + access_token=source.access_token, + submit_job=False, # Important! Don't submit the job until we have set our _download_cache dict + ) + self._download_cache[multifile_job.id] = install_job + install_job._multifile_job = multifile_job - files_string = "file" if len(remote_files) == 1 else "file" - self._logger.info(f"Queuing model install: {source} ({len(remote_files)} {files_string})") + files_string = "file" if len(remote_files) == 1 else "files" + self._logger.info(f"Queueing model install: {source} ({len(remote_files)} {files_string})") self._logger.debug(f"remote_files={remote_files}") - for model_file in remote_files: - url = model_file.url - path = root / model_file.path.relative_to(subfolder) - self._logger.debug(f"Downloading {url} => {path}") - install_job.total_bytes += model_file.size - assert hasattr(source, "access_token") - dest = tmpdir / path.parent - dest.mkdir(parents=True, exist_ok=True) - download_job = DownloadJob( - source=url, - dest=dest, - access_token=source.access_token, - ) - self._download_cache[download_job.source] = install_job # matches a download job to an install job - install_job.download_parts.add(download_job) - - # only start the jobs once install_job.download_parts is fully populated - for download_job in install_job.download_parts: - self._download_queue.submit_download_job( - download_job, - on_start=self._download_started_callback, - on_progress=self._download_progress_callback, - on_complete=self._download_complete_callback, - on_error=self._download_error_callback, - on_cancelled=self._download_cancelled_callback, - ) - + self._download_queue.submit_multifile_download(multifile_job) return install_job def _stat_size(self, path: Path) -> int: @@ -768,87 +766,104 @@ class ModelInstallService(ModelInstallServiceBase): size += sum(self._stat_size(Path(root, x)) for x in files) return size + def _multifile_download( + self, + remote_files: List[RemoteModelFile], + dest: Path, + subfolder: Optional[Path] = None, + access_token: Optional[str] = 
None, + submit_job: bool = True, + ) -> MultiFileDownloadJob: + # HuggingFace repo subfolders are a little tricky. If the name of the model is "sdxl-turbo", and + # we are installing the "vae" subfolder, we do not want to create an additional folder level, such + # as "sdxl-turbo/vae", nor do we want to put the contents of the vae folder directly into "sdxl-turbo". + # So what we do is to synthesize a folder named "sdxl-turbo_vae" here. + if subfolder: + top = Path(remote_files[0].path.parts[0]) # e.g. "sdxl-turbo/" + path_to_remove = top / subfolder.parts[-1] # sdxl-turbo/vae/ + path_to_add = Path(f"{top}_{subfolder}") + else: + path_to_remove = Path(".") + path_to_add = Path(".") + + parts: List[RemoteModelFile] = [] + for model_file in remote_files: + assert model_file.size is not None + parts.append( + RemoteModelFile( + url=model_file.url, # if a subfolder, then sdxl-turbo_vae/config.json + path=path_to_add / model_file.path.relative_to(path_to_remove), + ) + ) + + return self._download_queue.multifile_download( + parts=parts, + dest=dest, + access_token=access_token, + submit_job=submit_job, + on_start=self._download_started_callback, + on_progress=self._download_progress_callback, + on_complete=self._download_complete_callback, + on_error=self._download_error_callback, + on_cancelled=self._download_cancelled_callback, + ) + # ------------------------------------------------------------------ # Callbacks are executed by the download queue in a separate thread # ------------------------------------------------------------------ - def _download_started_callback(self, download_job: DownloadJob) -> None: - self._logger.info(f"Model download started: {download_job.source}") + def _download_started_callback(self, download_job: MultiFileDownloadJob) -> None: with self._lock: - install_job = self._download_cache[download_job.source] - install_job.status = InstallStatus.DOWNLOADING + if install_job := self._download_cache.get(download_job.id, None): + install_job.status = InstallStatus.DOWNLOADING - assert download_job.download_path - if install_job.local_path == install_job._install_tmpdir: - partial_path = download_job.download_path.relative_to(install_job._install_tmpdir) - dest_name = partial_path.parts[0] - install_job.local_path = install_job._install_tmpdir / dest_name - - # Update the total bytes count for remote sources. 
- if not install_job.total_bytes: - install_job.total_bytes = sum(x.total_bytes for x in install_job.download_parts) - - def _download_progress_callback(self, download_job: DownloadJob) -> None: - with self._lock: - install_job = self._download_cache[download_job.source] - if install_job.cancelled: # This catches the case in which the caller directly calls job.cancel() - self._cancel_download_parts(install_job) - else: - # update sizes - install_job.bytes = sum(x.bytes for x in install_job.download_parts) + if install_job.local_path == install_job._install_tmpdir: # first time + assert download_job.download_path + install_job.local_path = download_job.download_path + install_job.download_parts = download_job.download_parts + install_job.bytes = sum(x.bytes for x in download_job.download_parts) + install_job.total_bytes = download_job.total_bytes self._signal_job_downloading(install_job) - def _download_complete_callback(self, download_job: DownloadJob) -> None: - self._logger.info(f"Model download complete: {download_job.source}") + def _download_progress_callback(self, download_job: MultiFileDownloadJob) -> None: with self._lock: - install_job = self._download_cache[download_job.source] + if install_job := self._download_cache.get(download_job.id, None): + if install_job.cancelled: # This catches the case in which the caller directly calls job.cancel() + self._download_queue.cancel_job(download_job) + else: + # update sizes + install_job.bytes = sum(x.bytes for x in download_job.download_parts) + install_job.total_bytes = sum(x.total_bytes for x in download_job.download_parts) + self._signal_job_downloading(install_job) - # are there any more active jobs left in this task? - if install_job.downloading and all(x.complete for x in install_job.download_parts): + def _download_complete_callback(self, download_job: MultiFileDownloadJob) -> None: + with self._lock: + if install_job := self._download_cache.pop(download_job.id, None): self._signal_job_downloads_done(install_job) - self._put_in_queue(install_job) + self._put_in_queue(install_job) # this starts the installation and registration - # Let other threads know that the number of downloads has changed - self._download_cache.pop(download_job.source, None) - self._downloads_changed_event.set() + # Let other threads know that the number of downloads has changed + self._downloads_changed_event.set() - def _download_error_callback(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None: + def _download_error_callback(self, download_job: MultiFileDownloadJob, excp: Optional[Exception] = None) -> None: with self._lock: - install_job = self._download_cache.pop(download_job.source, None) - assert install_job is not None - assert excp is not None - install_job.set_error(excp) - self._logger.error( - f"Cancelling {install_job.source} due to an error while downloading {download_job.source}: {str(excp)}" - ) - self._cancel_download_parts(install_job) + if install_job := self._download_cache.pop(download_job.id, None): + assert excp is not None + install_job.set_error(excp) + self._download_queue.cancel_job(download_job) - # Let other threads know that the number of downloads has changed - self._downloads_changed_event.set() + # Let other threads know that the number of downloads has changed + self._downloads_changed_event.set() - def _download_cancelled_callback(self, download_job: DownloadJob) -> None: + def _download_cancelled_callback(self, download_job: MultiFileDownloadJob) -> None: with self._lock: - install_job = 
self._download_cache.pop(download_job.source, None) - if not install_job: - return - self._downloads_changed_event.set() - self._logger.warning(f"Model download canceled: {download_job.source}") - # if install job has already registered an error, then do not replace its status with cancelled - if not install_job.errored: - install_job.cancel() - self._cancel_download_parts(install_job) + if install_job := self._download_cache.pop(download_job.id, None): + self._downloads_changed_event.set() + # if install job has already registered an error, then do not replace its status with cancelled + if not install_job.errored: + install_job.cancel() - # Let other threads know that the number of downloads has changed - self._downloads_changed_event.set() - - def _cancel_download_parts(self, install_job: ModelInstallJob) -> None: - # on multipart downloads, _cancel_components() will get called repeatedly from the download callbacks - # do not lock here because it gets called within a locked context - for s in install_job.download_parts: - self._download_queue.cancel_job(s) - - if all(x.in_terminal_state for x in install_job.download_parts): - # When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources - self._put_in_queue(install_job) + # Let other threads know that the number of downloads has changed + self._downloads_changed_event.set() # ------------------------------------------------------------------------------------------------ # Internal methods that put events on the event bus @@ -861,6 +876,9 @@ class ModelInstallService(ModelInstallServiceBase): def _signal_job_downloading(self, job: ModelInstallJob) -> None: if self._event_bus: + assert job._multifile_job is not None + assert job.bytes is not None + assert job.total_bytes is not None self._event_bus.emit_model_install_download_progress(job) def _signal_job_downloads_done(self, job: ModelInstallJob) -> None: @@ -875,6 +893,8 @@ class ModelInstallService(ModelInstallServiceBase): self._logger.info(f"Model install complete: {job.source}") self._logger.debug(f"{job.local_path} registered key {job.config_out.key}") if self._event_bus: + assert job.local_path is not None + assert job.config_out is not None self._event_bus.emit_model_install_complete(job) def _signal_job_errored(self, job: ModelInstallJob) -> None: @@ -890,7 +910,13 @@ class ModelInstallService(ModelInstallServiceBase): self._event_bus.emit_model_install_cancelled(job) @staticmethod - def get_fetcher_from_url(url: str) -> ModelMetadataFetchBase: + def get_fetcher_from_url(url: str) -> Type[ModelMetadataFetchBase]: + """ + Return a metadata fetcher appropriate for provided url. + + This used to be more useful, but the number of supported model + sources has been reduced to HuggingFace alone. 
+ """ if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()): return HuggingFaceMetadataFetch raise ValueError(f"Unsupported model source: '{url}'") diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py index 9d75aafde1..da56772195 100644 --- a/invokeai/app/services/model_load/model_load_base.py +++ b/invokeai/app/services/model_load/model_load_base.py @@ -2,10 +2,11 @@ """Base class for model loader.""" from abc import ABC, abstractmethod -from typing import Optional +from pathlib import Path +from typing import Callable, Optional from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType -from invokeai.backend.model_manager.load import LoadedModel +from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase @@ -31,3 +32,26 @@ class ModelLoadServiceBase(ABC): @abstractmethod def convert_cache(self) -> ModelConvertCacheBase: """Return the checkpoint convert cache used by this loader.""" + + @abstractmethod + def load_model_from_path( + self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None + ) -> LoadedModelWithoutConfig: + """ + Load the model file or directory located at the indicated Path. + + This will load an arbitrary model file into the RAM cache. If the optional loader + argument is provided, the loader will be invoked to load the model into + memory. Otherwise the method will call safetensors.torch.load_file() or + torch.load() as appropriate to the file suffix. + + Be aware that this returns a LoadedModelWithoutConfig object, which is the same as + LoadedModel, but without the config attribute. + + Args: + model_path: A pathlib.Path to a checkpoint-style models file + loader: A Callable that expects a Path and returns a Dict[str, Tensor] + + Returns: + A LoadedModel object. + """ diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py index e9f527bc86..7067481378 100644 --- a/invokeai/app/services/model_load/model_load_default.py +++ b/invokeai/app/services/model_load/model_load_default.py @@ -1,18 +1,26 @@ # Copyright (c) 2024 Lincoln D. 
Stein and the InvokeAI Team """Implementation of model loader service.""" -from typing import Optional, Type +from pathlib import Path +from typing import Callable, Optional, Type + +from picklescan.scanner import scan_file_path +from safetensors.torch import load_file as safetensors_load_file +from torch import load as torch_load from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.invoker import Invoker from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType from invokeai.backend.model_manager.load import ( LoadedModel, + LoadedModelWithoutConfig, ModelLoaderRegistry, ModelLoaderRegistryBase, ) from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase +from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader +from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.logging import InvokeAILogger from .model_load_base import ModelLoadServiceBase @@ -75,3 +83,41 @@ class ModelLoadService(ModelLoadServiceBase): self._invoker.services.events.emit_model_load_complete(model_config, submodel_type) return loaded_model + + def load_model_from_path( + self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None + ) -> LoadedModelWithoutConfig: + cache_key = str(model_path) + ram_cache = self.ram_cache + try: + return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key)) + except IndexError: + pass + + def torch_load_file(checkpoint: Path) -> AnyModel: + scan_result = scan_file_path(checkpoint) + if scan_result.infected_files != 0: + raise Exception(f"The model at {checkpoint} is potentially infected by malware. 
Aborting load.") + result = torch_load(checkpoint, map_location="cpu") + return result + + def diffusers_load_directory(directory: Path) -> AnyModel: + load_class = GenericDiffusersLoader( + app_config=self._app_config, + logger=self._logger, + ram_cache=self._ram_cache, + convert_cache=self.convert_cache, + ).get_hf_load_class(directory) + return load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype()) + + loader = loader or ( + diffusers_load_directory + if model_path.is_dir() + else torch_load_file + if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")) + else lambda path: safetensors_load_file(path, device="cpu") + ) + assert loader is not None + raw_model = loader(model_path) + ram_cache.put(key=cache_key, model=raw_model) + return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key)) diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index 094ade6383..57531cf3c1 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -12,15 +12,13 @@ from pydantic import BaseModel, Field from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.app.util.model_exclude_null import BaseModelExcludeNull -from invokeai.backend.model_manager import ( +from invokeai.backend.model_manager.config import ( AnyModelConfig, BaseModelType, - ModelFormat, - ModelType, -) -from invokeai.backend.model_manager.config import ( ControlAdapterDefaultSettings, MainModelDefaultSettings, + ModelFormat, + ModelType, ModelVariantType, SchedulerPredictionType, ) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 348be89265..01662335e4 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -3,6 +3,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Callable, Optional, Union from PIL.Image import Image +from pydantic.networks import AnyHttpUrl from torch import Tensor from invokeai.app.invocations.constants import IMAGE_MODES @@ -14,8 +15,15 @@ from invokeai.app.services.images.images_common import ImageDTO from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.model_records.model_records_base import UnknownModelException from invokeai.app.util.step_callback import stable_diffusion_step_callback -from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType -from invokeai.backend.model_manager.load.load_base import LoadedModel +from invokeai.backend.model_manager.config import ( + AnyModel, + AnyModelConfig, + BaseModelType, + ModelFormat, + ModelType, + SubModelType, +) +from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData @@ -320,8 +328,10 @@ class ConditioningInterface(InvocationContextInterface): class ModelsInterface(InvocationContextInterface): + """Common API for loading, downloading and managing models.""" + def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool: - """Checks if a model exists. + """Check if a model exists. Args: identifier: The key or ModelField representing the model. 
@@ -331,13 +341,13 @@ class ModelsInterface(InvocationContextInterface): """ if isinstance(identifier, str): return self._services.model_manager.store.exists(identifier) - - return self._services.model_manager.store.exists(identifier.key) + else: + return self._services.model_manager.store.exists(identifier.key) def load( self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None ) -> LoadedModel: - """Loads a model. + """Load a model. Args: identifier: The key or ModelField representing the model. @@ -361,7 +371,7 @@ class ModelsInterface(InvocationContextInterface): def load_by_attrs( self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None ) -> LoadedModel: - """Loads a model by its attributes. + """Load a model by its attributes. Args: name: Name of the model. @@ -384,7 +394,7 @@ class ModelsInterface(InvocationContextInterface): return self._services.model_manager.load.load_model(configs[0], submodel_type) def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig: - """Gets a model's config. + """Get a model's config. Args: identifier: The key or ModelField representing the model. @@ -394,11 +404,11 @@ class ModelsInterface(InvocationContextInterface): """ if isinstance(identifier, str): return self._services.model_manager.store.get_model(identifier) - - return self._services.model_manager.store.get_model(identifier.key) + else: + return self._services.model_manager.store.get_model(identifier.key) def search_by_path(self, path: Path) -> list[AnyModelConfig]: - """Searches for models by path. + """Search for models by path. Args: path: The path to search for. @@ -415,7 +425,7 @@ class ModelsInterface(InvocationContextInterface): type: Optional[ModelType] = None, format: Optional[ModelFormat] = None, ) -> list[AnyModelConfig]: - """Searches for models by attributes. + """Search for models by attributes. Args: name: The name to search for (exact match). @@ -434,6 +444,72 @@ class ModelsInterface(InvocationContextInterface): model_format=format, ) + def download_and_cache_model( + self, + source: str | AnyHttpUrl, + ) -> Path: + """ + Download the model file located at source to the models cache and return its Path. + + This can be used to single-file install models and other resources of arbitrary types + which should not get registered with the database. If the model is already + installed, the cached path will be returned. Otherwise it will be downloaded. + + Args: + source: A URL that points to the model, or a huggingface repo_id. + + Returns: + Path to the downloaded model + """ + return self._services.model_manager.install.download_and_cache_model(source=source) + + def load_local_model( + self, + model_path: Path, + loader: Optional[Callable[[Path], AnyModel]] = None, + ) -> LoadedModelWithoutConfig: + """ + Load the model file located at the indicated path + + If a loader callable is provided, it will be invoked to load the model. Otherwise, + `safetensors.torch.load_file()` or `torch.load()` will be called to load the model. + + Be aware that the LoadedModelWithoutConfig object has no `config` attribute + + Args: + path: A model Path + loader: A Callable that expects a Path and returns a dict[str|int, Any] + + Returns: + A LoadedModelWithoutConfig object. 
+ """ + return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader) + + def load_remote_model( + self, + source: str | AnyHttpUrl, + loader: Optional[Callable[[Path], AnyModel]] = None, + ) -> LoadedModelWithoutConfig: + """ + Download, cache, and load the model file located at the indicated URL or repo_id. + + If the model is already downloaded, it will be loaded from the cache. + + If the a loader callable is provided, it will be invoked to load the model. Otherwise, + `safetensors.torch.load_file()` or `torch.load()` will be called to load the model. + + Be aware that the LoadedModelWithoutConfig object has no `config` attribute + + Args: + source: A URL or huggingface repoid. + loader: A Callable that expects a Path and returns a dict[str|int, Any] + + Returns: + A LoadedModelWithoutConfig object. + """ + model_path = self._services.model_manager.install.download_and_cache_model(source=str(source)) + return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader) + class ConfigInterface(InvocationContextInterface): def get(self) -> InvokeAIAppConfig: diff --git a/invokeai/app/services/shared/sqlite/sqlite_util.py b/invokeai/app/services/shared/sqlite/sqlite_util.py index cadf09f457..3b5f447306 100644 --- a/invokeai/app/services/shared/sqlite/sqlite_util.py +++ b/invokeai/app/services/shared/sqlite/sqlite_util.py @@ -13,6 +13,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import from invokeai.app.services.shared.sqlite_migrator.migrations.migration_8 import build_migration_8 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_9 import build_migration_9 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_10 import build_migration_10 +from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import build_migration_11 from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator @@ -43,6 +44,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto migrator.register_migration(build_migration_8(app_config=config)) migrator.register_migration(build_migration_9()) migrator.register_migration(build_migration_10()) + migrator.register_migration(build_migration_11(app_config=config, logger=logger)) migrator.run_migrations() return db diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_11.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_11.py new file mode 100644 index 0000000000..f66374e0b1 --- /dev/null +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_11.py @@ -0,0 +1,75 @@ +import shutil +import sqlite3 +from logging import Logger + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration + +LEGACY_CORE_MODELS = [ + # OpenPose + "any/annotators/dwpose/yolox_l.onnx", + "any/annotators/dwpose/dw-ll_ucoco_384.onnx", + # DepthAnything + "any/annotators/depth_anything/depth_anything_vitl14.pth", + "any/annotators/depth_anything/depth_anything_vitb14.pth", + "any/annotators/depth_anything/depth_anything_vits14.pth", + # Lama inpaint + "core/misc/lama/lama.pt", + # RealESRGAN upscale + "core/upscaling/realesrgan/RealESRGAN_x4plus.pth", + "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth", + "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth", + "core/upscaling/realesrgan/RealESRGAN_x2plus.pth", 
+] + + +class Migration11Callback: + def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None: + self._app_config = app_config + self._logger = logger + + def __call__(self, cursor: sqlite3.Cursor) -> None: + self._remove_convert_cache() + self._remove_downloaded_models() + self._remove_unused_core_models() + + def _remove_convert_cache(self) -> None: + """Rename models/.cache to models/.convert_cache.""" + self._logger.info("Removing .cache directory. Converted models will now be cached in .convert_cache.") + legacy_convert_path = self._app_config.root_path / "models" / ".cache" + shutil.rmtree(legacy_convert_path, ignore_errors=True) + + def _remove_downloaded_models(self) -> None: + """Remove models from their old locations; they will re-download when needed.""" + self._logger.info( + "Removing legacy just-in-time models. Downloaded models will now be cached in .download_cache." + ) + for model_path in LEGACY_CORE_MODELS: + legacy_dest_path = self._app_config.models_path / model_path + legacy_dest_path.unlink(missing_ok=True) + + def _remove_unused_core_models(self) -> None: + """Remove unused core models and their directories.""" + self._logger.info("Removing defunct core models.") + for dir in ["face_restoration", "misc", "upscaling"]: + path_to_remove = self._app_config.models_path / "core" / dir + shutil.rmtree(path_to_remove, ignore_errors=True) + shutil.rmtree(self._app_config.models_path / "any" / "annotators", ignore_errors=True) + + +def build_migration_11(app_config: InvokeAIAppConfig, logger: Logger) -> Migration: + """ + Build the migration from database version 10 to 11. + + This migration does the following: + - Moves "core" models previously downloaded with download_with_progress_bar() into new + "models/.download_cache" directory. + - Renames "models/.cache" to "models/.convert_cache". + """ + migration_11 = Migration( + from_version=10, + to_version=11, + callback=Migration11Callback(app_config=app_config, logger=logger), + ) + + return migration_11 diff --git a/invokeai/app/util/download_with_progress.py b/invokeai/app/util/download_with_progress.py deleted file mode 100644 index 97a2abb2f6..0000000000 --- a/invokeai/app/util/download_with_progress.py +++ /dev/null @@ -1,51 +0,0 @@ -from pathlib import Path -from urllib import request - -from tqdm import tqdm - -from invokeai.backend.util.logging import InvokeAILogger - - -class ProgressBar: - """Simple progress bar for urllib.request.urlretrieve using tqdm.""" - - def __init__(self, model_name: str = "file"): - self.pbar = None - self.name = model_name - - def __call__(self, block_num: int, block_size: int, total_size: int): - if not self.pbar: - self.pbar = tqdm( - desc=self.name, - initial=0, - unit="iB", - unit_scale=True, - unit_divisor=1000, - total=total_size, - ) - self.pbar.update(block_size) - - -def download_with_progress_bar(name: str, url: str, dest_path: Path) -> bool: - """Download a file from a URL to a destination path, with a progress bar. - If the file already exists, it will not be downloaded again. - - Exceptions are not caught. - - Args: - name (str): Name of the file being downloaded. - url (str): URL to download the file from. - dest_path (Path): Destination path to save the file to. - - Returns: - bool: True if the file was downloaded, False if it already existed. 
- """ - if dest_path.exists(): - return False # already downloaded - - InvokeAILogger.get_logger().info(f"Downloading {name}...") - - dest_path.parent.mkdir(parents=True, exist_ok=True) - request.urlretrieve(url, dest_path, ProgressBar(name)) - - return True diff --git a/invokeai/backend/image_util/depth_anything/__init__.py b/invokeai/backend/image_util/depth_anything/__init__.py index c854fba3f2..1adcc6b202 100644 --- a/invokeai/backend/image_util/depth_anything/__init__.py +++ b/invokeai/backend/image_util/depth_anything/__init__.py @@ -1,5 +1,5 @@ -import pathlib -from typing import Literal, Union +from pathlib import Path +from typing import Literal import cv2 import numpy as np @@ -10,28 +10,17 @@ from PIL import Image from torchvision.transforms import Compose from invokeai.app.services.config.config_default import get_config -from invokeai.app.util.download_with_progress import download_with_progress_bar from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2 from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize -from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.logging import InvokeAILogger config = get_config() logger = InvokeAILogger.get_logger(config=config) DEPTH_ANYTHING_MODELS = { - "large": { - "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true", - "local": "any/annotators/depth_anything/depth_anything_vitl14.pth", - }, - "base": { - "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true", - "local": "any/annotators/depth_anything/depth_anything_vitb14.pth", - }, - "small": { - "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true", - "local": "any/annotators/depth_anything/depth_anything_vits14.pth", - }, + "large": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true", + "base": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitb14.pth?download=true", + "small": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vits14.pth?download=true", } @@ -53,36 +42,27 @@ transform = Compose( class DepthAnythingDetector: - def __init__(self) -> None: - self.model = None - self.model_size: Union[Literal["large", "base", "small"], None] = None - self.device = TorchDevice.choose_torch_device() + def __init__(self, model: DPT_DINOv2, device: torch.device) -> None: + self.model = model + self.device = device - def load_model(self, model_size: Literal["large", "base", "small"] = "small"): - DEPTH_ANYTHING_MODEL_PATH = config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"] - download_with_progress_bar( - pathlib.Path(DEPTH_ANYTHING_MODELS[model_size]["url"]).name, - DEPTH_ANYTHING_MODELS[model_size]["url"], - DEPTH_ANYTHING_MODEL_PATH, - ) + @staticmethod + def load_model( + model_path: Path, device: torch.device, model_size: Literal["large", "base", "small"] = "small" + ) -> DPT_DINOv2: + match model_size: + case "small": + model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384]) + case "base": + model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768]) + case "large": + model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 
1024]) - if not self.model or model_size != self.model_size: - del self.model - self.model_size = model_size + model.load_state_dict(torch.load(model_path.as_posix(), map_location="cpu")) + model.eval() - match self.model_size: - case "small": - self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384]) - case "base": - self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768]) - case "large": - self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024]) - - self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu")) - self.model.eval() - - self.model.to(self.device) - return self.model + model.to(device) + return model def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image: if not self.model: diff --git a/invokeai/backend/image_util/dw_openpose/__init__.py b/invokeai/backend/image_util/dw_openpose/__init__.py index c258ef2c78..cfd3ea4b0d 100644 --- a/invokeai/backend/image_util/dw_openpose/__init__.py +++ b/invokeai/backend/image_util/dw_openpose/__init__.py @@ -1,30 +1,53 @@ +from pathlib import Path +from typing import Dict + import numpy as np import torch from controlnet_aux.util import resize_image from PIL import Image -from invokeai.backend.image_util.dw_openpose.utils import draw_bodypose, draw_facepose, draw_handpose +from invokeai.backend.image_util.dw_openpose.utils import NDArrayInt, draw_bodypose, draw_facepose, draw_handpose from invokeai.backend.image_util.dw_openpose.wholebody import Wholebody +DWPOSE_MODELS = { + "yolox_l.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true", + "dw-ll_ucoco_384.onnx": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true", +} -def draw_pose(pose, H, W, draw_face=True, draw_body=True, draw_hands=True, resolution=512): + +def draw_pose( + pose: Dict[str, NDArrayInt | Dict[str, NDArrayInt]], + H: int, + W: int, + draw_face: bool = True, + draw_body: bool = True, + draw_hands: bool = True, + resolution: int = 512, +) -> Image.Image: bodies = pose["bodies"] faces = pose["faces"] hands = pose["hands"] + + assert isinstance(bodies, dict) candidate = bodies["candidate"] + + assert isinstance(bodies, dict) subset = bodies["subset"] + canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8) if draw_body: canvas = draw_bodypose(canvas, candidate, subset) if draw_hands: + assert isinstance(hands, np.ndarray) canvas = draw_handpose(canvas, hands) if draw_face: - canvas = draw_facepose(canvas, faces) + assert isinstance(hands, np.ndarray) + canvas = draw_facepose(canvas, faces) # type: ignore - dwpose_image = resize_image( + dwpose_image: Image.Image = resize_image( canvas, resolution, ) @@ -39,11 +62,16 @@ class DWOpenposeDetector: Credits: https://github.com/IDEA-Research/DWPose """ - def __init__(self) -> None: - self.pose_estimation = Wholebody() + def __init__(self, onnx_det: Path, onnx_pose: Path) -> None: + self.pose_estimation = Wholebody(onnx_det=onnx_det, onnx_pose=onnx_pose) def __call__( - self, image: Image.Image, draw_face=False, draw_body=True, draw_hands=False, resolution=512 + self, + image: Image.Image, + draw_face: bool = False, + draw_body: bool = True, + draw_hands: bool = False, + resolution: int = 512, ) -> Image.Image: np_image = np.array(image) H, W, C = np_image.shape @@ -79,3 +107,6 @@ class DWOpenposeDetector: return draw_pose( pose, H, W, draw_face=draw_face, draw_hands=draw_hands, draw_body=draw_body, resolution=resolution ) 
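# A hedged construction sketch (not part of this diff): Wholebody and DWOpenposeDetector no
# longer download their ONNX files themselves, so a caller is expected to supply the two
# model paths, for example via the new download-and-cache API. `context` stands for an
# InvocationContext and is assumed to expose ModelsInterface as `context.models`.
def build_dw_openpose_detector(context) -> DWOpenposeDetector:
    onnx_det = context.models.download_and_cache_model(DWPOSE_MODELS["yolox_l.onnx"])
    onnx_pose = context.models.download_and_cache_model(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])
    return DWOpenposeDetector(onnx_det=onnx_det, onnx_pose=onnx_pose)
    # The detector is then called as detector(image, draw_body=True, resolution=512)
    # and returns a PIL Image with the drawn pose.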
+ + +__all__ = ["DWPOSE_MODELS", "DWOpenposeDetector"] diff --git a/invokeai/backend/image_util/dw_openpose/utils.py b/invokeai/backend/image_util/dw_openpose/utils.py index 428672ab31..dc142dfa71 100644 --- a/invokeai/backend/image_util/dw_openpose/utils.py +++ b/invokeai/backend/image_util/dw_openpose/utils.py @@ -5,11 +5,13 @@ import math import cv2 import matplotlib import numpy as np +import numpy.typing as npt eps = 0.01 +NDArrayInt = npt.NDArray[np.uint8] -def draw_bodypose(canvas, candidate, subset): +def draw_bodypose(canvas: NDArrayInt, candidate: NDArrayInt, subset: NDArrayInt) -> NDArrayInt: H, W, C = canvas.shape candidate = np.array(candidate) subset = np.array(subset) @@ -88,7 +90,7 @@ def draw_bodypose(canvas, candidate, subset): return canvas -def draw_handpose(canvas, all_hand_peaks): +def draw_handpose(canvas: NDArrayInt, all_hand_peaks: NDArrayInt) -> NDArrayInt: H, W, C = canvas.shape edges = [ @@ -142,7 +144,7 @@ def draw_handpose(canvas, all_hand_peaks): return canvas -def draw_facepose(canvas, all_lmks): +def draw_facepose(canvas: NDArrayInt, all_lmks: NDArrayInt) -> NDArrayInt: H, W, C = canvas.shape for lmks in all_lmks: lmks = np.array(lmks) diff --git a/invokeai/backend/image_util/dw_openpose/wholebody.py b/invokeai/backend/image_util/dw_openpose/wholebody.py index 84f5afa989..3f77f20b9c 100644 --- a/invokeai/backend/image_util/dw_openpose/wholebody.py +++ b/invokeai/backend/image_util/dw_openpose/wholebody.py @@ -2,47 +2,26 @@ # Modified pathing to suit Invoke +from pathlib import Path + import numpy as np import onnxruntime as ort from invokeai.app.services.config.config_default import get_config -from invokeai.app.util.download_with_progress import download_with_progress_bar from invokeai.backend.util.devices import TorchDevice from .onnxdet import inference_detector from .onnxpose import inference_pose -DWPOSE_MODELS = { - "yolox_l.onnx": { - "local": "any/annotators/dwpose/yolox_l.onnx", - "url": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true", - }, - "dw-ll_ucoco_384.onnx": { - "local": "any/annotators/dwpose/dw-ll_ucoco_384.onnx", - "url": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true", - }, -} - config = get_config() class Wholebody: - def __init__(self): + def __init__(self, onnx_det: Path, onnx_pose: Path): device = TorchDevice.choose_torch_device() providers = ["CUDAExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"] - DET_MODEL_PATH = config.models_path / DWPOSE_MODELS["yolox_l.onnx"]["local"] - download_with_progress_bar("yolox_l.onnx", DWPOSE_MODELS["yolox_l.onnx"]["url"], DET_MODEL_PATH) - - POSE_MODEL_PATH = config.models_path / DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["local"] - download_with_progress_bar( - "dw-ll_ucoco_384.onnx", DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["url"], POSE_MODEL_PATH - ) - - onnx_det = DET_MODEL_PATH - onnx_pose = POSE_MODEL_PATH - self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers) self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers) diff --git a/invokeai/backend/image_util/infill_methods/lama.py b/invokeai/backend/image_util/infill_methods/lama.py index 4268ec773d..cd5838d1f2 100644 --- a/invokeai/backend/image_util/infill_methods/lama.py +++ b/invokeai/backend/image_util/infill_methods/lama.py @@ -1,4 +1,4 @@ -import gc +from pathlib import Path from typing import Any import numpy as np @@ -6,9 +6,7 @@ import torch from PIL import Image import 
invokeai.backend.util.logging as logger -from invokeai.app.services.config.config_default import get_config -from invokeai.app.util.download_with_progress import download_with_progress_bar -from invokeai.backend.util.devices import TorchDevice +from invokeai.backend.model_manager.config import AnyModel def norm_img(np_img): @@ -19,28 +17,11 @@ def norm_img(np_img): return np_img -def load_jit_model(url_or_path, device): - model_path = url_or_path - logger.info(f"Loading model from: {model_path}") - model = torch.jit.load(model_path, map_location="cpu").to(device) - model.eval() - return model - - class LaMA: + def __init__(self, model: AnyModel): + self._model = model + def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any: - device = TorchDevice.choose_torch_device() - model_location = get_config().models_path / "core/misc/lama/lama.pt" - - if not model_location.exists(): - download_with_progress_bar( - name="LaMa Inpainting Model", - url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt", - dest_path=model_location, - ) - - model = load_jit_model(model_location, device) - image = np.asarray(input_image.convert("RGB")) image = norm_img(image) @@ -48,20 +29,25 @@ class LaMA: mask = np.asarray(mask) mask = np.invert(mask) mask = norm_img(mask) - mask = (mask > 0) * 1 + + device = next(self._model.buffers()).device image = torch.from_numpy(image).unsqueeze(0).to(device) mask = torch.from_numpy(mask).unsqueeze(0).to(device) with torch.inference_mode(): - infilled_image = model(image, mask) + infilled_image = self._model(image, mask) infilled_image = infilled_image[0].permute(1, 2, 0).detach().cpu().numpy() infilled_image = np.clip(infilled_image * 255, 0, 255).astype("uint8") infilled_image = Image.fromarray(infilled_image) - del model - gc.collect() - torch.cuda.empty_cache() - return infilled_image + + @staticmethod + def load_jit_model(url_or_path: str | Path, device: torch.device | str = "cpu") -> torch.nn.Module: + model_path = url_or_path + logger.info(f"Loading model from: {model_path}") + model: torch.nn.Module = torch.jit.load(model_path, map_location="cpu").to(device) # type: ignore + model.eval() + return model diff --git a/invokeai/backend/image_util/realesrgan/realesrgan.py b/invokeai/backend/image_util/realesrgan/realesrgan.py index 663a323967..c5fe3fa598 100644 --- a/invokeai/backend/image_util/realesrgan/realesrgan.py +++ b/invokeai/backend/image_util/realesrgan/realesrgan.py @@ -1,6 +1,5 @@ import math from enum import Enum -from pathlib import Path from typing import Any, Optional import cv2 @@ -11,6 +10,7 @@ from cv2.typing import MatLike from tqdm import tqdm from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet +from invokeai.backend.model_manager.config import AnyModel from invokeai.backend.util.devices import TorchDevice """ @@ -52,7 +52,7 @@ class RealESRGAN: def __init__( self, scale: int, - model_path: Path, + loadnet: AnyModel, model: RRDBNet, tile: int = 0, tile_pad: int = 10, @@ -67,8 +67,6 @@ class RealESRGAN: self.half = half self.device = TorchDevice.choose_torch_device() - loadnet = torch.load(model_path, map_location=torch.device("cpu")) - # prefer to use params_ema if "params_ema" in loadnet: keyname = "params_ema" diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index b19501843c..e3c99c5644 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -36,7 +36,7 @@ from ..raw_model import RawModel # 
ModelMixin is the base class for all diffusers and transformers models # RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime -AnyModel = Union[ModelMixin, RawModel, torch.nn.Module] +AnyModel = Union[ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor]] class InvalidModelConfigException(Exception): @@ -115,7 +115,7 @@ class SchedulerPredictionType(str, Enum): class ModelRepoVariant(str, Enum): """Various hugging face variants on the diffusers format.""" - Default = "" # model files without "fp16" or other qualifier - empty str + Default = "" # model files without "fp16" or other qualifier FP16 = "fp16" FP32 = "fp32" ONNX = "onnx" diff --git a/invokeai/backend/model_manager/load/__init__.py b/invokeai/backend/model_manager/load/__init__.py index f47a2c4368..25125f43fb 100644 --- a/invokeai/backend/model_manager/load/__init__.py +++ b/invokeai/backend/model_manager/load/__init__.py @@ -7,7 +7,7 @@ from importlib import import_module from pathlib import Path from .convert_cache.convert_cache_default import ModelConvertCache -from .load_base import LoadedModel, ModelLoaderBase +from .load_base import LoadedModel, LoadedModelWithoutConfig, ModelLoaderBase from .load_default import ModelLoader from .model_cache.model_cache_default import ModelCache from .model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase @@ -19,6 +19,7 @@ for module in loaders: __all__ = [ "LoadedModel", + "LoadedModelWithoutConfig", "ModelCache", "ModelConvertCache", "ModelLoaderBase", diff --git a/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py b/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py index 8dc2aff74b..cf6448c056 100644 --- a/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py +++ b/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py @@ -7,6 +7,7 @@ from pathlib import Path from invokeai.backend.util import GIG, directory_size from invokeai.backend.util.logging import InvokeAILogger +from invokeai.backend.util.util import safe_filename from .convert_cache_base import ModelConvertCacheBase @@ -35,6 +36,7 @@ class ModelConvertCache(ModelConvertCacheBase): def cache_path(self, key: str) -> Path: """Return the path for a model with the indicated key.""" + key = safe_filename(self._cache_path, key) return self._cache_path / key def make_room(self, size: float) -> None: diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py index 1bb093a990..6748e85dca 100644 --- a/invokeai/backend/model_manager/load/load_base.py +++ b/invokeai/backend/model_manager/load/load_base.py @@ -23,7 +23,7 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod @dataclass -class LoadedModel: +class LoadedModelWithoutConfig: """ Context manager object that mediates transfer from RAM<->VRAM. @@ -61,7 +61,6 @@ class LoadedModel: not have a state_dict, in which case this value will be None. """ - config: AnyModelConfig _locker: ModelLockerBase def __enter__(self) -> AnyModel: @@ -89,6 +88,13 @@ class LoadedModel: return self._locker.model +@dataclass +class LoadedModel(LoadedModelWithoutConfig): + """Context manager object that mediates transfer from RAM<->VRAM.""" + + config: Optional[AnyModelConfig] = None + + # TODO(MM2): # Some "intermediary" subclasses in the ModelLoaderBase class hierarchy define methods that their subclasses don't # know about. 
I think the problem may be related to this class being an ABC. diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index a58741763f..a63cc66a86 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -16,7 +16,7 @@ from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase -from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs +from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init from invokeai.backend.util.devices import TorchDevice @@ -84,7 +84,7 @@ class ModelLoader(ModelLoaderBase): except IndexError: pass - cache_path: Path = self._convert_cache.cache_path(config.key) + cache_path: Path = self._convert_cache.cache_path(str(model_path)) if self._needs_conversion(config, model_path, cache_path): loaded_model = self._do_convert(config, model_path, cache_path, submodel_type) else: @@ -95,7 +95,6 @@ class ModelLoader(ModelLoaderBase): config.key, submodel_type=submodel_type, model=loaded_model, - size=calc_model_size_by_data(loaded_model), ) return self._ram_cache.get( @@ -126,9 +125,7 @@ class ModelLoader(ModelLoaderBase): if subtype == submodel_type: continue if submodel := getattr(pipeline, subtype.value, None): - self._ram_cache.put( - config.key, submodel_type=subtype, model=submodel, size=calc_model_size_by_data(submodel) - ) + self._ram_cache.put(config.key, submodel_type=subtype, model=submodel) return getattr(pipeline, submodel_type.value) if submodel_type else pipeline def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool: diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py index 0106c0ff18..012fd42d55 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py @@ -169,7 +169,6 @@ class ModelCacheBase(ABC, Generic[T]): self, key: str, model: T, - size: int, submodel_type: Optional[SubModelType] = None, ) -> None: """Store model under key and optional submodel_type.""" diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py index a3016a63ef..335a15a5c8 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py @@ -29,6 +29,7 @@ import torch from invokeai.backend.model_manager import AnyModel, SubModelType from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff +from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.logging import InvokeAILogger @@ -153,13 +154,13 @@ class ModelCache(ModelCacheBase[AnyModel]): self, key: str, model: AnyModel, - size: int, submodel_type: Optional[SubModelType] = None, ) -> 
None: """Store model under key and optional submodel_type.""" key = self._make_cache_key(key, submodel_type) if key in self._cached_models: return + size = calc_model_size_by_data(model) self.make_room(size) state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None @@ -252,12 +253,7 @@ class ModelCache(ModelCacheBase[AnyModel]): May raise a torch.cuda.OutOfMemoryError """ - # These attributes are not in the base ModelMixin class but in various derived classes. - # Some models don't have these attributes, in which case they run in RAM/CPU. self.logger.debug(f"Called to move {cache_entry.key} to {target_device}") - if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")): - return - source_device = cache_entry.device # Note: We compare device types only so that 'cuda' == 'cuda:0'. @@ -265,6 +261,10 @@ class ModelCache(ModelCacheBase[AnyModel]): if torch.device(source_device).type == torch.device(target_device).type: return + # Some models don't have a `to` method, in which case they run in RAM/CPU. + if not hasattr(cache_entry.model, "to"): + return + # This roundabout method for moving the model around is done to avoid # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM. # When moving to VRAM, we copy (not move) each element of the state dict from diff --git a/invokeai/backend/model_manager/load/model_cache/model_locker.py b/invokeai/backend/model_manager/load/model_cache/model_locker.py index 6d90ed92e8..9de17ca5f5 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_locker.py +++ b/invokeai/backend/model_manager/load/model_cache/model_locker.py @@ -35,10 +35,6 @@ class ModelLocker(ModelLockerBase): def lock(self) -> AnyModel: """Move the model into the execution device (GPU) and lock it.""" - if not hasattr(self.model, "to"): - return self.model - - # NOTE that the model has to have the to() method in order for this code to move it into GPU! 
self._cache_entry.lock() try: if self._cache.lazy_offloading: @@ -59,9 +55,6 @@ class ModelLocker(ModelLockerBase): def unlock(self) -> None: """Call upon exit from context.""" - if not hasattr(self.model, "to"): - return - self._cache_entry.unlock() if not self._cache.lazy_offloading: self._cache.offload_unlocked_models(0) diff --git a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py index a4874b33ce..6320797b8a 100644 --- a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +++ b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py @@ -65,14 +65,11 @@ class GenericDiffusersLoader(ModelLoader): else: try: config = self._load_diffusers_config(model_path, config_name="config.json") - class_name = config.get("_class_name", None) - if class_name: + if class_name := config.get("_class_name"): result = self._hf_definition_to_type(module="diffusers", class_name=class_name) - if config.get("model_type", None) == "clip_vision_model": - class_name = config.get("architectures") - assert class_name is not None + elif class_name := config.get("architectures"): result = self._hf_definition_to_type(module="transformers", class_name=class_name[0]) - if not class_name: + else: raise InvalidModelConfigException("Unable to decipher Load Class based on given config.json") except KeyError as e: raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e diff --git a/invokeai/backend/model_manager/metadata/fetch/huggingface.py b/invokeai/backend/model_manager/metadata/fetch/huggingface.py index 4e3625fdbe..ab78b3e064 100644 --- a/invokeai/backend/model_manager/metadata/fetch/huggingface.py +++ b/invokeai/backend/model_manager/metadata/fetch/huggingface.py @@ -83,7 +83,7 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase): assert s.size is not None files.append( RemoteModelFile( - url=hf_hub_url(id, s.rfilename, revision=variant), + url=hf_hub_url(id, s.rfilename, revision=variant or "main"), path=Path(name, s.rfilename), size=s.size, sha256=s.lfs.get("sha256") if s.lfs else None, diff --git a/invokeai/backend/model_manager/metadata/metadata_base.py b/invokeai/backend/model_manager/metadata/metadata_base.py index 585c0fa31c..f9f5335d17 100644 --- a/invokeai/backend/model_manager/metadata/metadata_base.py +++ b/invokeai/backend/model_manager/metadata/metadata_base.py @@ -37,9 +37,12 @@ class RemoteModelFile(BaseModel): url: AnyHttpUrl = Field(description="The url to download this model file") path: Path = Field(description="The path to the file, relative to the model root") - size: int = Field(description="The size of this file, in bytes") + size: Optional[int] = Field(description="The size of this file, in bytes", default=0) sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None) + def __hash__(self) -> int: + return hash(str(self)) + class ModelMetadataBase(BaseModel): """Base class for model metadata information.""" diff --git a/invokeai/backend/util/util.py b/invokeai/backend/util/util.py index 7d0d9d03f7..1ee89dcc66 100644 --- a/invokeai/backend/util/util.py +++ b/invokeai/backend/util/util.py @@ -1,6 +1,8 @@ import base64 import io import os +import re +import unicodedata import warnings from pathlib import Path @@ -12,6 +14,33 @@ from transformers import logging as transformers_logging GIG = 1073741824 +def slugify(value: str, allow_unicode: bool = False) -> str: + """ + 
Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated + dashes to single dashes. Remove characters that aren't alphanumerics, + underscores, or hyphens. Replace slashes with underscores. + Convert to lowercase. Also strip leading and + trailing whitespace, dashes, and underscores. + + Adapted from Django: https://github.com/django/django/blob/main/django/utils/text.py + """ + value = str(value) + if allow_unicode: + value = unicodedata.normalize("NFKC", value) + else: + value = unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii") + value = re.sub(r"[/]", "_", value.lower()) + value = re.sub(r"[^.\w\s-]", "", value.lower()) + return re.sub(r"[-\s]+", "-", value).strip("-_") + + +def safe_filename(directory: Path, value: str) -> str: + """Make a string safe to use as a filename.""" + escaped_string = slugify(value) + max_name_length = os.pathconf(directory, "PC_NAME_MAX") if hasattr(os, "pathconf") else 256 + return escaped_string[len(escaped_string) - max_name_length :] + + def directory_size(directory: Path) -> int: """ Return the aggregate size of all files in a directory (bytes). diff --git a/tests/app/services/download/test_download_queue.py b/tests/app/services/download/test_download_queue.py index 72c78da814..fd2e2a65ae 100644 --- a/tests/app/services/download/test_download_queue.py +++ b/tests/app/services/download/test_download_queue.py @@ -2,14 +2,18 @@ import re import time +from contextlib import contextmanager from pathlib import Path +from typing import Any, Generator, Optional import pytest from pydantic.networks import AnyHttpUrl from requests.sessions import Session -from requests_testadapter import TestAdapter, TestSession +from requests_testadapter import TestAdapter -from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService +from invokeai.app.services.config import get_config +from invokeai.app.services.config.config_default import URLRegexTokenPair +from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService, MultiFileDownloadJob from invokeai.app.services.events.events_common import ( DownloadCancelledEvent, DownloadCompleteEvent, @@ -17,56 +21,23 @@ from invokeai.app.services.events.events_common import ( DownloadProgressEvent, DownloadStartedEvent, ) +from invokeai.backend.model_manager.metadata import HuggingFaceMetadataFetch, ModelMetadataWithFiles, RemoteModelFile +from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 from tests.test_nodes import TestEventService # Prevent pytest deprecation warnings -TestAdapter.__test__ = False # type: ignore +TestAdapter.__test__ = False -@pytest.fixture -def session() -> Session: - sess = TestSession() - for i in ["12345", "9999", "54321"]: - content = ( - b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000) - ) # for pause tests, must make content large - sess.mount( - f"http://www.civitai.com/models/{i}", - TestAdapter( - content, - headers={ - "Content-Length": len(content), - "Content-Disposition": f'filename="mock{i}.safetensors"', - }, - ), - ) - - # here are some malformed URLs to test - # missing the content length - sess.mount( - "http://www.civitai.com/models/missing", - TestAdapter( - b"Missing content length", - headers={ - "Content-Disposition": 'filename="missing.txt"', - }, - ), - ) - # not found test - sess.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404)) - - return sess - - -@pytest.mark.timeout(timeout=20, 
method="thread") -def test_basic_queue_download(tmp_path: Path, session: Session) -> None: +@pytest.mark.timeout(timeout=10, method="thread") +def test_basic_queue_download(tmp_path: Path, mm2_session: Session) -> None: events = set() - def event_handler(job: DownloadJob) -> None: + def event_handler(job: DownloadJob, excp: Optional[Exception] = None) -> None: events.add(job.status) queue = DownloadQueueService( - requests_session=session, + requests_session=mm2_session, ) queue.start() job = queue.download( @@ -82,16 +53,17 @@ def test_basic_queue_download(tmp_path: Path, session: Session) -> None: queue.join() assert job.status == DownloadJobStatus("completed"), "expected job status to be completed" + assert job.download_path == tmp_path / "mock12345.safetensors" assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist" assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED} queue.stop() -@pytest.mark.timeout(timeout=20, method="thread") -def test_errors(tmp_path: Path, session: Session) -> None: +@pytest.mark.timeout(timeout=10, method="thread") +def test_errors(tmp_path: Path, mm2_session: Session) -> None: queue = DownloadQueueService( - requests_session=session, + requests_session=mm2_session, ) queue.start() @@ -110,11 +82,11 @@ def test_errors(tmp_path: Path, session: Session) -> None: queue.stop() -@pytest.mark.timeout(timeout=20, method="thread") -def test_event_bus(tmp_path: Path, session: Session) -> None: +@pytest.mark.timeout(timeout=10, method="thread") +def test_event_bus(tmp_path: Path, mm2_session: Session) -> None: event_bus = TestEventService() - queue = DownloadQueueService(requests_session=session, event_bus=event_bus) + queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus) queue.start() queue.download( source=AnyHttpUrl("http://www.civitai.com/models/12345"), @@ -146,10 +118,10 @@ def test_event_bus(tmp_path: Path, session: Session) -> None: queue.stop() -@pytest.mark.timeout(timeout=20, method="thread") -def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None: +@pytest.mark.timeout(timeout=10, method="thread") +def test_broken_callbacks(tmp_path: Path, mm2_session: Session, capsys) -> None: queue = DownloadQueueService( - requests_session=session, + requests_session=mm2_session, ) queue.start() @@ -178,11 +150,11 @@ def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None: queue.stop() -@pytest.mark.timeout(timeout=15, method="thread") -def test_cancel(tmp_path: Path, session: Session) -> None: +@pytest.mark.timeout(timeout=10, method="thread") +def test_cancel(tmp_path: Path, mm2_session: Session) -> None: event_bus = TestEventService() - queue = DownloadQueueService(requests_session=session, event_bus=event_bus) + queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus) queue.start() cancelled = False @@ -194,9 +166,6 @@ def test_cancel(tmp_path: Path, session: Session) -> None: nonlocal cancelled cancelled = True - def handler(signum, frame): - raise TimeoutError("Join took too long to return") - job = queue.download( source=AnyHttpUrl("http://www.civitai.com/models/12345"), dest=tmp_path, @@ -212,3 +181,178 @@ def test_cancel(tmp_path: Path, session: Session) -> None: assert isinstance(events[-1], DownloadCancelledEvent) assert events[-1].source == "http://www.civitai.com/models/12345" queue.stop() + + +@pytest.mark.timeout(timeout=10, method="thread") +def test_multifile_download(tmp_path: Path, 
mm2_session: Session) -> None:
+    fetcher = HuggingFaceMetadataFetch(mm2_session)
+    metadata = fetcher.from_id("stabilityai/sdxl-turbo")
+    assert isinstance(metadata, ModelMetadataWithFiles)
+    events = set()
+
+    def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
+        events.add(job.status)
+
+    queue = DownloadQueueService(
+        requests_session=mm2_session,
+    )
+    queue.start()
+    job = queue.multifile_download(
+        parts=metadata.download_urls(session=mm2_session),
+        dest=tmp_path,
+        on_start=event_handler,
+        on_progress=event_handler,
+        on_complete=event_handler,
+        on_error=event_handler,
+    )
+    assert isinstance(job, MultiFileDownloadJob), "expected the job to be of type MultiFileDownloadJobBase"
+    queue.join()
+
+    assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
+    assert job.bytes > 0, "expected download bytes to be positive"
+    assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes"
+    assert job.download_path == tmp_path / "sdxl-turbo"
+    assert Path(
+        tmp_path, "sdxl-turbo/model_index.json"
+    ).exists(), f"expected {tmp_path}/sdxl-turbo/model_index.json to exist"
+    assert Path(
+        tmp_path, "sdxl-turbo/text_encoder/config.json"
+    ).exists(), f"expected {tmp_path}/sdxl-turbo/text_encoder/config.json to exist"
+
+    assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
+    queue.stop()
+
+
+@pytest.mark.timeout(timeout=10, method="thread")
+def test_multifile_download_error(tmp_path: Path, mm2_session: Session) -> None:
+    fetcher = HuggingFaceMetadataFetch(mm2_session)
+    metadata = fetcher.from_id("stabilityai/sdxl-turbo")
+    assert isinstance(metadata, ModelMetadataWithFiles)
+    events = set()
+
+    def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
+        events.add(job.status)
+
+    queue = DownloadQueueService(
+        requests_session=mm2_session,
+    )
+    queue.start()
+    files = metadata.download_urls(session=mm2_session)
+    # this will give a 404 error
+    files.append(RemoteModelFile(url="https://test.com/missing_model.safetensors", path=Path("sdxl-turbo/broken")))
+    job = queue.multifile_download(
+        parts=files,
+        dest=tmp_path,
+        on_start=event_handler,
+        on_progress=event_handler,
+        on_complete=event_handler,
+        on_error=event_handler,
+    )
+    queue.join()
+
+    assert job.status == DownloadJobStatus("error"), "expected job status to be errored"
+    assert job.error_type is not None
+    assert "HTTPError(NOT FOUND)" in job.error_type
+    assert DownloadJobStatus.ERROR in events
+    queue.stop()
+
+
+@pytest.mark.timeout(timeout=10, method="thread")
+def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch: Any) -> None:
+    event_bus = TestEventService()
+
+    queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
+    queue.start()
+
+    cancelled = False
+
+    def cancelled_callback(job: DownloadJob) -> None:
+        nonlocal cancelled
+        cancelled = True
+
+    fetcher = HuggingFaceMetadataFetch(mm2_session)
+    metadata = fetcher.from_id("stabilityai/sdxl-turbo")
+    assert isinstance(metadata, ModelMetadataWithFiles)
+
+    job = queue.multifile_download(
+        parts=metadata.download_urls(session=mm2_session),
+        dest=tmp_path,
+        on_cancelled=cancelled_callback,
+    )
+    queue.cancel_job(job)
+    queue.join()
+
+    assert job.status == DownloadJobStatus.CANCELLED
+    assert cancelled
+    events = event_bus.events
+    assert DownloadCancelledEvent in [type(x) for x in events]
+    queue.stop()
+
+
+def test_multifile_onefile(tmp_path: 
Path, mm2_session: Session) -> None: + queue = DownloadQueueService( + requests_session=mm2_session, + ) + queue.start() + job = queue.multifile_download( + parts=[ + RemoteModelFile(url=AnyHttpUrl("http://www.civitai.com/models/12345"), path=Path("mock12345.safetensors")) + ], + dest=tmp_path, + ) + assert isinstance(job, MultiFileDownloadJob), "expected the job to be of type MultiFileDownloadJobBase" + queue.join() + + assert job.status == DownloadJobStatus("completed"), "expected job status to be completed" + assert job.bytes > 0, "expected download bytes to be positive" + assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes" + assert job.download_path == tmp_path / "mock12345.safetensors" + assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist" + queue.stop() + + +def test_multifile_no_rel_paths(tmp_path: Path, mm2_session: Session) -> None: + queue = DownloadQueueService( + requests_session=mm2_session, + ) + + with pytest.raises(AssertionError) as error: + queue.multifile_download( + parts=[RemoteModelFile(url=AnyHttpUrl("http://www.civitai.com/models/12345"), path=Path("/etc/passwd"))], + dest=tmp_path, + ) + assert str(error.value) == "only relative download paths accepted" + + +@contextmanager +def clear_config() -> Generator[None, None, None]: + try: + yield None + finally: + get_config.cache_clear() + + +def test_tokens(tmp_path: Path, mm2_session: Session): + with clear_config(): + config = get_config() + config.remote_api_tokens = [URLRegexTokenPair(url_regex="civitai", token="cv_12345")] + queue = DownloadQueueService(requests_session=mm2_session) + queue.start() + # this one has an access token assigned + job1 = queue.download( + source=AnyHttpUrl("http://www.civitai.com/models/12345"), + dest=tmp_path, + ) + # this one doesn't + job2 = queue.download( + source=AnyHttpUrl( + "http://www.huggingface.co/foo.txt", + ), + dest=tmp_path, + ) + queue.join() + # this token is defined in the temporary root invokeai.yaml + # see tests/backend/model_manager/data/invokeai_root/invokeai.yaml + assert job1.access_token == "cv_12345" + assert job2.access_token is None + queue.stop() diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index b380414be8..9602a79a27 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -20,6 +20,7 @@ from invokeai.app.services.events.events_common import ( ModelInstallStartedEvent, ) from invokeai.app.services.model_install import ( + HFModelSource, ModelInstallServiceBase, ) from invokeai.app.services.model_install.model_install_common import ( @@ -29,7 +30,14 @@ from invokeai.app.services.model_install.model_install_common import ( URLModelSource, ) from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException -from invokeai.backend.model_manager.config import BaseModelType, InvalidModelConfigException, ModelFormat, ModelType +from invokeai.backend.model_manager.config import ( + BaseModelType, + InvalidModelConfigException, + ModelFormat, + ModelRepoVariant, + ModelType, +) +from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 from tests.test_nodes import TestEventService OS = platform.uname().system @@ -222,7 +230,7 @@ def test_delete_register( store.get_model(key) -@pytest.mark.timeout(timeout=20, method="thread") +@pytest.mark.timeout(timeout=10, method="thread") def 
test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors")) @@ -243,15 +251,16 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: model_record = store.get_model(key) assert (mm2_app_config.models_path / model_record.path).exists() - assert len(bus.events) == 4 - assert isinstance(bus.events[0], ModelInstallDownloadProgressEvent) - assert isinstance(bus.events[1], ModelInstallDownloadsCompleteEvent) - assert isinstance(bus.events[2], ModelInstallStartedEvent) - assert isinstance(bus.events[3], ModelInstallCompleteEvent) + assert len(bus.events) == 5 + assert isinstance(bus.events[0], ModelInstallDownloadProgressEvent) # download starts + assert isinstance(bus.events[1], ModelInstallDownloadProgressEvent) # download progresses + assert isinstance(bus.events[2], ModelInstallDownloadsCompleteEvent) # download completed + assert isinstance(bus.events[3], ModelInstallStartedEvent) # install started + assert isinstance(bus.events[4], ModelInstallCompleteEvent) # install completed -@pytest.mark.timeout(timeout=20, method="thread") -def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: +@pytest.mark.timeout(timeout=10, method="thread") +def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo")) bus: TestEventService = mm2_installer.event_bus @@ -277,6 +286,49 @@ def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_co assert len(bus.events) >= 3 +@pytest.mark.timeout(timeout=10, method="thread") +def test_huggingface_repo_id(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: + source = HFModelSource(repo_id="stabilityai/sdxl-turbo", variant=ModelRepoVariant.Default) + + bus = mm2_installer.event_bus + store = mm2_installer.record_store + assert isinstance(bus, EventServiceBase) + assert store is not None + + job = mm2_installer.import_model(source) + job_list = mm2_installer.wait_for_installs(timeout=10) + assert len(job_list) == 1 + assert job.complete + assert job.config_out + + key = job.config_out.key + model_record = store.get_model(key) + assert (mm2_app_config.models_path / model_record.path).exists() + assert model_record.type == ModelType.Main + assert model_record.format == ModelFormat.Diffusers + + assert hasattr(bus, "events") # the dummyeventservice has this + assert len(bus.events) >= 3 + event_types = [type(x) for x in bus.events] + assert all( + x in event_types + for x in [ + ModelInstallDownloadProgressEvent, + ModelInstallDownloadsCompleteEvent, + ModelInstallStartedEvent, + ModelInstallCompleteEvent, + ] + ) + + completed_events = [x for x in bus.events if isinstance(x, ModelInstallCompleteEvent)] + downloading_events = [x for x in bus.events if isinstance(x, ModelInstallDownloadProgressEvent)] + assert completed_events[0].total_bytes == downloading_events[-1].bytes + assert job.total_bytes == completed_events[0].total_bytes + print(downloading_events[-1]) + print(job.download_parts) + assert job.total_bytes == sum(x["total_bytes"] for x in downloading_events[-1].parts) + + def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: source = URLModelSource(url=Url("https://test.com/missing_model.safetensors")) job = 
mm2_installer.import_model(source) @@ -308,7 +360,6 @@ def test_other_error_during_install( assert job.error == "Test error" -# TODO: Fix bug in model install causing jobs to get installed multiple times then uncomment this test @pytest.mark.parametrize( "model_params", [ @@ -326,7 +377,7 @@ def test_other_error_during_install( }, ], ) -@pytest.mark.timeout(timeout=40, method="thread") +@pytest.mark.timeout(timeout=10, method="thread") def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]): """Test whether or not type is respected on configs when passed to heuristic import.""" assert "name" in model_params and "type" in model_params @@ -342,7 +393,7 @@ def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, mode } assert "repo_id" in model_params install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1) - mm2_installer.wait_for_job(install_job1, timeout=20) + mm2_installer.wait_for_job(install_job1, timeout=10) if model_params["type"] != "embedding": assert install_job1.errored assert install_job1.error_type == "InvalidModelConfigException" @@ -351,6 +402,6 @@ def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, mode assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2) - mm2_installer.wait_for_job(install_job2, timeout=20) + mm2_installer.wait_for_job(install_job2, timeout=10) assert install_job2.complete assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out diff --git a/tests/app/services/model_load/test_load_api.py b/tests/app/services/model_load/test_load_api.py new file mode 100644 index 0000000000..6f2c7bd931 --- /dev/null +++ b/tests/app/services/model_load/test_load_api.py @@ -0,0 +1,88 @@ +from pathlib import Path + +import pytest +import torch +from diffusers import AutoencoderTiny + +from invokeai.app.services.invocation_services import InvocationServices +from invokeai.app.services.model_manager import ModelManagerServiceBase +from invokeai.app.services.shared.invocation_context import InvocationContext, build_invocation_context +from invokeai.backend.model_manager.load.load_base import LoadedModelWithoutConfig +from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 + + +@pytest.fixture() +def mock_context( + mock_services: InvocationServices, + mm2_model_manager: ModelManagerServiceBase, +) -> InvocationContext: + mock_services.model_manager = mm2_model_manager + return build_invocation_context( + services=mock_services, + data=None, # type: ignore + is_canceled=None, # type: ignore + ) + + +def test_download_and_cache(mock_context: InvocationContext, mm2_root_dir: Path) -> None: + downloaded_path = mock_context.models.download_and_cache_model( + "https://www.test.foo/download/test_embedding.safetensors" + ) + assert downloaded_path.is_file() + assert downloaded_path.exists() + assert downloaded_path.name == "test_embedding.safetensors" + assert downloaded_path.parent.parent == mm2_root_dir / "models/.download_cache" + + downloaded_path_2 = mock_context.models.download_and_cache_model( + "https://www.test.foo/download/test_embedding.safetensors" + ) + assert downloaded_path == downloaded_path_2 + + +def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) -> None: + downloaded_path = 
mock_context.models.download_and_cache_model( + "https://www.test.foo/download/test_embedding.safetensors" + ) + loaded_model_1 = mock_context.models.load_local_model(downloaded_path) + assert isinstance(loaded_model_1, LoadedModelWithoutConfig) + + loaded_model_2 = mock_context.models.load_local_model(downloaded_path) + assert isinstance(loaded_model_2, LoadedModelWithoutConfig) + assert loaded_model_1.model is loaded_model_2.model + + loaded_model_3 = mock_context.models.load_local_model(embedding_file) + assert isinstance(loaded_model_3, LoadedModelWithoutConfig) + assert loaded_model_1.model is not loaded_model_3.model + assert isinstance(loaded_model_1.model, dict) + assert isinstance(loaded_model_3.model, dict) + assert torch.equal(loaded_model_1.model["emb_params"], loaded_model_3.model["emb_params"]) + + +@pytest.mark.skip(reason="This requires a test model to load") +def test_load_from_dir(mock_context: InvocationContext, vae_directory: Path) -> None: + loaded_model = mock_context.models.load_local_model(vae_directory) + assert isinstance(loaded_model, LoadedModelWithoutConfig) + assert isinstance(loaded_model.model, AutoencoderTiny) + + +def test_download_and_load(mock_context: InvocationContext) -> None: + loaded_model_1 = mock_context.models.load_remote_model("https://www.test.foo/download/test_embedding.safetensors") + assert isinstance(loaded_model_1, LoadedModelWithoutConfig) + + loaded_model_2 = mock_context.models.load_remote_model("https://www.test.foo/download/test_embedding.safetensors") + assert isinstance(loaded_model_2, LoadedModelWithoutConfig) + assert loaded_model_1.model is loaded_model_2.model # should be cached copy + + +def test_download_diffusers(mock_context: InvocationContext) -> None: + model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo") + assert (model_path / "model_index.json").exists() + assert (model_path / "vae").is_dir() + + +def test_download_diffusers_subfolder(mock_context: InvocationContext) -> None: + model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo::vae") + assert model_path.is_dir() + assert (model_path / "diffusion_pytorch_model.fp16.safetensors").exists() or ( + model_path / "diffusion_pytorch_model.safetensors" + ).exists() diff --git a/tests/backend/model_manager/model_manager_fixtures.py b/tests/backend/model_manager/model_manager_fixtures.py index 5ddccd05bb..f82239298e 100644 --- a/tests/backend/model_manager/model_manager_fixtures.py +++ b/tests/backend/model_manager/model_manager_fixtures.py @@ -61,6 +61,13 @@ def embedding_file(mm2_model_files: Path) -> Path: return mm2_model_files / "test_embedding.safetensors" +# Can be used to test diffusers model directory loading, but +# the test file adds ~10MB of space. 
+# @pytest.fixture +# def vae_directory(mm2_model_files: Path) -> Path: +# return mm2_model_files / "taesdxl" + + @pytest.fixture def diffusers_dir(mm2_model_files: Path) -> Path: return mm2_model_files / "test-diffusers-main" @@ -294,4 +301,45 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session: }, ), ) + + for i in ["12345", "9999", "54321"]: + content = ( + b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000) + ) # for pause tests, must make content large + sess.mount( + f"http://www.civitai.com/models/{i}", + TestAdapter( + content, + headers={ + "Content-Length": len(content), + "Content-Disposition": f'filename="mock{i}.safetensors"', + }, + ), + ) + + sess.mount( + "http://www.huggingface.co/foo.txt", + TestAdapter( + content, + headers={ + "Content-Length": len(content), + "Content-Disposition": 'filename="foo.safetensors"', + }, + ), + ) + + # here are some malformed URLs to test + # missing the content length + sess.mount( + "http://www.civitai.com/models/missing", + TestAdapter( + b"Missing content length", + headers={ + "Content-Disposition": 'filename="missing.txt"', + }, + ), + ) + # not found test + sess.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404)) + return sess
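
For reference, the new `slugify()` and `safe_filename()` helpers added to `invokeai/backend/util/util.py` in this patch make arbitrary strings (such as model source URLs) safe to use as file names. Below is a minimal usage sketch, not part of the patch itself; it assumes a checkout that already includes these helpers, and the sample inputs are made up for illustration:

```
from pathlib import Path

from invokeai.backend.util.util import safe_filename, slugify

# Slashes become underscores, other unsafe characters are dropped, and the
# result is lowercased with runs of spaces/dashes collapsed to single dashes.
assert slugify("stabilityai/sdxl-turbo") == "stabilityai_sdxl-turbo"
assert slugify("https://civitai.com/api/download/models/12345") == "https__civitai.com_api_download_models_12345"
assert slugify("My Fancy Model v2.0!") == "my-fancy-model-v2.0"

# safe_filename() additionally truncates the slug to the filesystem's maximum
# file name length (PC_NAME_MAX where available, 256 characters otherwise).
print(safe_filename(Path("/tmp"), "My Fancy Model v2.0!"))  # my-fancy-model-v2.0
```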