From aae318425d52af7a573f3fb4c0a6a00576ce2db2 Mon Sep 17 00:00:00 2001 From: chainchompa Date: Fri, 14 Jun 2024 17:08:39 -0400 Subject: [PATCH 01/10] added route for installing huggingface model from model marketplace --- invokeai/app/api/routers/model_manager.py | 76 ++++++++++++++++++++++- 1 file changed, 75 insertions(+), 1 deletion(-) diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index b1221f7a34..01590b21be 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -9,7 +9,7 @@ from copy import deepcopy from typing import Any, Dict, List, Optional, Type from fastapi import Body, Path, Query, Response, UploadFile -from fastapi.responses import FileResponse +from fastapi.responses import FileResponse, HTMLResponse from fastapi.routing import APIRouter from PIL import Image from pydantic import AnyHttpUrl, BaseModel, ConfigDict, Field @@ -501,6 +501,80 @@ async def install_model( raise HTTPException(status_code=409, detail=str(e)) return result +@model_manager_router.get( + "/install/huggingface", + operation_id="install_hugging_face_model", + responses={ + 201: {"description": "The model is being installed"}, + 400: {"description": "Bad request"}, + 409: {"description": "There is already a model corresponding to this path or repo_id"}, + }, + status_code=201, + response_class=HTMLResponse +) +async def install_hugging_face_model( + source: str = Query(description="Hugging Face repo_id to install"), +) -> HTMLResponse: + """Install a Hugging Face model using a string identifier.""" + + def generate_html(message: str) -> str: + return f""" + + + + + +
+

+                    {message}
+        [HTML page markup garbled in extraction — the template simply renders the message on a plain page]
+ + + """ + + try: + metadata = HuggingFaceMetadataFetch().from_id(source) + assert isinstance(metadata, ModelMetadataWithFiles) + message = "Your Hugging Face model is installing now. You can close this tab and check the Model Manager for installation progress." + except UnknownMetadataException: + message = "No HuggingFace repository found with that repo id." + return HTMLResponse(content=generate_html(message), status_code=400) + + logger = ApiDependencies.invoker.services.logger + + try: + installer = ApiDependencies.invoker.services.model_manager.install + + if metadata.is_diffusers: + installer.heuristic_import( + source=source, + inplace=False, + ) + elif metadata.ckpt_urls is not None and len(metadata.ckpt_urls) == 1: + installer.heuristic_import( + source=str(metadata.ckpt_urls[0]), + inplace=False, + ) + else: + message = "This HuggingFace repo has multiple models. Please use the Model Manager to install this." + except Exception as e: + logger.error(str(e)) + message = "There was an error with installing this model. Please use the Model Manager to install this." + + return HTMLResponse(content=generate_html(message), status_code=201) + @model_manager_router.get( "/install", From 328f160e88b403b193f3a0708ffce4f878084ec5 Mon Sep 17 00:00:00 2001 From: chainchompa Date: Fri, 14 Jun 2024 17:09:07 -0400 Subject: [PATCH 02/10] refetch model installs when a new model install starts --- .../listeners/socketio/socketModelInstall.ts | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts index 7fafb8302c..adfa7edd06 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts @@ -6,9 +6,17 @@ import { socketModelInstallComplete, socketModelInstallDownloadProgress, socketModelInstallError, + socketModelInstallStarted, } from 'services/events/actions'; export const addModelInstallEventListener = (startAppListening: AppStartListening) => { + startAppListening({ + actionCreator: socketModelInstallStarted, + effect: async (action, { dispatch }) => { + dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + }, + }); + startAppListening({ actionCreator: socketModelInstallDownloadProgress, effect: async (action, { dispatch }) => { From 40299725304e3fa3271b2d6d746512b7c448e1cd Mon Sep 17 00:00:00 2001 From: chainchompa Date: Fri, 14 Jun 2024 17:15:55 -0400 Subject: [PATCH 03/10] formatting --- invokeai/app/api/routers/model_manager.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index 01590b21be..f2fb0932e5 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -501,6 +501,7 @@ async def install_model( raise HTTPException(status_code=409, detail=str(e)) return result + @model_manager_router.get( "/install/huggingface", operation_id="install_hugging_face_model", @@ -510,7 +511,7 @@ async def install_model( 409: {"description": "There is already a model corresponding to this path or repo_id"}, }, status_code=201, - response_class=HTMLResponse + response_class=HTMLResponse, ) async def install_hugging_face_model( source: str = 
Query(description="Hugging Face repo_id to install"), From 1bc98abc7626690b4f850b80a72dd4d4573190dc Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 17 Jun 2024 09:33:46 +1000 Subject: [PATCH 04/10] docs(ui): explain model install events --- .../listeners/socketio/socketModelInstall.ts | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts index adfa7edd06..113d2cbd66 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts @@ -9,6 +9,21 @@ import { socketModelInstallStarted, } from 'services/events/actions'; +/** + * A model install has two main stages - downloading and installing. All these events are namespaced under `model_install_` + * which is a bit misleading. For example, a `model_install_started` event is actually fired _after_ the model has fully + * downloaded and is being "physically" installed. + * + * Here's the expected flow: + * - Model manager does some prep + * - `model_install_download_progress` fired when the download starts and continually until the download is complete + * - `model_install_download_complete` fired when the download is complete + * - `model_install_started` fired when the "physical" installation starts + * - `model_install_complete` fired when the installation is complete + * - `model_install_cancelled` fired if the installation is cancelled + * - `model_install_error` fired if the installation has an error + */ + export const addModelInstallEventListener = (startAppListening: AppStartListening) => { startAppListening({ actionCreator: socketModelInstallStarted, From fb694b3e179733400676c8cc7bcbb1c7ae7ba60e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 17 Jun 2024 09:50:25 +1000 Subject: [PATCH 05/10] feat(app): add `model_install_download_started` event Previously, we used `model_install_download_progress` for both download starting and progressing. When handling this event, we don't know which actual thing it represents. Add `model_install_download_started` event to explicitly represent a model download started event. 
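
For illustration, a minimal sketch of how a consumer could now tell the two events apart. Only the event names and payload keys (`id`, `bytes`, `total_bytes`) come from this change; the python-socketio client wiring, server URL, and socket path below are assumptions and may need adjusting for a real deployment.

```python
# Sketch: handle the new "download started" event separately from progress updates.
# Assumes a python-socketio AsyncClient; URL and socketio_path are guesses for a local install.
import asyncio

import socketio

sio = socketio.AsyncClient()


@sio.on("model_install_download_started")
async def on_download_started(payload: dict) -> None:
    # New event: fired once, when the remote download actually begins.
    print(f"install {payload['id']}: download started, expecting {payload['total_bytes']} bytes")


@sio.on("model_install_download_progress")
async def on_download_progress(payload: dict) -> None:
    # Existing event: now emitted only for ongoing progress, never as the "started" signal.
    print(f"install {payload['id']}: {payload['bytes']} / {payload['total_bytes']} bytes")


async def main() -> None:
    await sio.connect("http://127.0.0.1:9090", socketio_path="/ws/socket.io")
    await sio.wait()


if __name__ == "__main__":
    asyncio.run(main())
```
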
--- invokeai/app/services/events/events_base.py | 5 +++ invokeai/app/services/events/events_common.py | 36 +++++++++++++++++++ .../model_install/model_install_default.py | 9 ++++- 3 files changed, 49 insertions(+), 1 deletion(-) diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index cf49cc0626..bb578c23e8 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -22,6 +22,7 @@ from invokeai.app.services.events.events_common import ( ModelInstallCompleteEvent, ModelInstallDownloadProgressEvent, ModelInstallDownloadsCompleteEvent, + ModelInstallDownloadStartedEvent, ModelInstallErrorEvent, ModelInstallStartedEvent, ModelLoadCompleteEvent, @@ -144,6 +145,10 @@ class EventServiceBase: # region Model install + def emit_model_install_download_started(self, job: "ModelInstallJob") -> None: + """Emitted at intervals while the install job is started (remote models only).""" + self.dispatch(ModelInstallDownloadStartedEvent.build(job)) + def emit_model_install_download_progress(self, job: "ModelInstallJob") -> None: """Emitted at intervals while the install job is in progress (remote models only).""" self.dispatch(ModelInstallDownloadProgressEvent.build(job)) diff --git a/invokeai/app/services/events/events_common.py b/invokeai/app/services/events/events_common.py index 0adcaa2ab1..c6a867fb08 100644 --- a/invokeai/app/services/events/events_common.py +++ b/invokeai/app/services/events/events_common.py @@ -417,6 +417,42 @@ class ModelLoadCompleteEvent(ModelEventBase): return cls(config=config, submodel_type=submodel_type) +@payload_schema.register +class ModelInstallDownloadStartedEvent(ModelEventBase): + """Event model for model_install_download_started""" + + __event_name__ = "model_install_download_started" + + id: int = Field(description="The ID of the install job") + source: str = Field(description="Source of the model; local path, repo_id or url") + local_path: str = Field(description="Where model is downloading to") + bytes: int = Field(description="Number of bytes downloaded so far") + total_bytes: int = Field(description="Total size of download, including all files") + parts: list[dict[str, int | str]] = Field( + description="Progress of downloading URLs that comprise the model, if any" + ) + + @classmethod + def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadStartedEvent": + parts: list[dict[str, str | int]] = [ + { + "url": str(x.source), + "local_path": str(x.download_path), + "bytes": x.bytes, + "total_bytes": x.total_bytes, + } + for x in job.download_parts + ] + return cls( + id=job.id, + source=str(job.source), + local_path=job.local_path.as_posix(), + parts=parts, + bytes=job.bytes, + total_bytes=job.total_bytes, + ) + + @payload_schema.register class ModelInstallDownloadProgressEvent(ModelEventBase): """Event model for model_install_download_progress""" diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 0a2e2d798a..dd1b44d899 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -822,7 +822,7 @@ class ModelInstallService(ModelInstallServiceBase): install_job.download_parts = download_job.download_parts install_job.bytes = sum(x.bytes for x in download_job.download_parts) install_job.total_bytes = download_job.total_bytes - self._signal_job_downloading(install_job) + 
self._signal_job_download_started(install_job) def _download_progress_callback(self, download_job: MultiFileDownloadJob) -> None: with self._lock: @@ -874,6 +874,13 @@ class ModelInstallService(ModelInstallServiceBase): if self._event_bus: self._event_bus.emit_model_install_started(job) + def _signal_job_download_started(self, job: ModelInstallJob) -> None: + if self._event_bus: + assert job._multifile_job is not None + assert job.bytes is not None + assert job.total_bytes is not None + self._event_bus.emit_model_install_download_started(job) + def _signal_job_downloading(self, job: ModelInstallJob) -> None: if self._event_bus: assert job._multifile_job is not None From c11478a94ac53e9f61ff57ff69e75fb45d6504f2 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 17 Jun 2024 09:51:18 +1000 Subject: [PATCH 06/10] chore(ui): typegen --- .../frontend/web/src/services/api/schema.ts | 435 +++++++++++------- 1 file changed, 260 insertions(+), 175 deletions(-) diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 5482b57c0b..fe2732d06b 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -123,6 +123,13 @@ export type paths = { */ delete: operations["prune_model_install_jobs"]; }; + "/api/v2/models/install/huggingface": { + /** + * Install Hugging Face Model + * @description Install a Hugging Face model using a string identifier. + */ + get: operations["install_hugging_face_model"]; + }; "/api/v2/models/install/{id}": { /** * Get Model Install Job @@ -3788,23 +3795,6 @@ export type components = { * @description Class to monitor and control a model download request. */ DownloadJob: { - /** - * Source - * Format: uri - * @description Where to download from. Specific types specified in child classes. - */ - source: string; - /** - * Dest - * Format: path - * @description Destination of downloaded model on local disk; a directory or file path - */ - dest: string; - /** - * Access Token - * @description authorization token for protected resources - */ - access_token?: string | null; /** * Id * @description Numeric ID of this job @@ -3812,36 +3802,21 @@ export type components = { */ id?: number; /** - * Priority - * @description Queue priority; lower values are higher priority - * @default 10 + * Dest + * Format: path + * @description Initial destination of downloaded model on local disk; a directory or file path */ - priority?: number; + dest: string; + /** + * Download Path + * @description Final location of downloaded file or directory + */ + download_path?: string | null; /** * @description Status of the download * @default waiting */ status?: components["schemas"]["DownloadJobStatus"]; - /** - * Download Path - * @description Final location of downloaded file - */ - download_path?: string | null; - /** - * Job Started - * @description Timestamp for when the download job started - */ - job_started?: string | null; - /** - * Job Ended - * @description Timestamp for when the download job ende1d (completed or errored) - */ - job_ended?: string | null; - /** - * Content Type - * @description Content type of downloaded file - */ - content_type?: string | null; /** * Bytes * @description Bytes downloaded so far @@ -3864,6 +3839,38 @@ export type components = { * @description Traceback of the exception that caused an error */ error?: string | null; + /** + * Source + * Format: uri + * @description Where to download from. 
Specific types specified in child classes. + */ + source: string; + /** + * Access Token + * @description authorization token for protected resources + */ + access_token?: string | null; + /** + * Priority + * @description Queue priority; lower values are higher priority + * @default 10 + */ + priority?: number; + /** + * Job Started + * @description Timestamp for when the download job started + */ + job_started?: string | null; + /** + * Job Ended + * @description Timestamp for when the download job ende1d (completed or errored) + */ + job_ended?: string | null; + /** + * Content Type + * @description Content type of downloaded file + */ + content_type?: string | null; }; /** * DownloadJobStatus @@ -7276,144 +7283,144 @@ export type components = { project_id: string | null; }; InvocationOutputMap: { - pidi_image_processor: components["schemas"]["ImageOutput"]; - image_mask_to_tensor: components["schemas"]["MaskOutput"]; - vae_loader: components["schemas"]["VAEOutput"]; - collect: components["schemas"]["CollectInvocationOutput"]; - string_join_three: components["schemas"]["StringOutput"]; - content_shuffle_image_processor: components["schemas"]["ImageOutput"]; - random_range: components["schemas"]["IntegerCollectionOutput"]; - ip_adapter: components["schemas"]["IPAdapterOutput"]; - step_param_easing: components["schemas"]["FloatCollectionOutput"]; - core_metadata: components["schemas"]["MetadataOutput"]; - main_model_loader: components["schemas"]["ModelLoaderOutput"]; - leres_image_processor: components["schemas"]["ImageOutput"]; - calculate_image_tiles_even_split: components["schemas"]["CalculateImageTilesOutput"]; - color_correct: components["schemas"]["ImageOutput"]; - calculate_image_tiles: components["schemas"]["CalculateImageTilesOutput"]; - float_range: components["schemas"]["FloatCollectionOutput"]; - infill_cv2: components["schemas"]["ImageOutput"]; - img_channel_multiply: components["schemas"]["ImageOutput"]; - img_pad_crop: components["schemas"]["ImageOutput"]; - sdxl_refiner_compel_prompt: components["schemas"]["ConditioningOutput"]; - face_mask_detection: components["schemas"]["FaceMaskOutput"]; - infill_lama: components["schemas"]["ImageOutput"]; - mask_combine: components["schemas"]["ImageOutput"]; - sdxl_compel_prompt: components["schemas"]["ConditioningOutput"]; - segment_anything_processor: components["schemas"]["ImageOutput"]; - merge_metadata: components["schemas"]["MetadataOutput"]; - img_ilerp: components["schemas"]["ImageOutput"]; - heuristic_resize: components["schemas"]["ImageOutput"]; - cv_inpaint: components["schemas"]["ImageOutput"]; - div: components["schemas"]["IntegerOutput"]; - pair_tile_image: components["schemas"]["PairTileImageOutput"]; - float_math: components["schemas"]["FloatOutput"]; - img_channel_offset: components["schemas"]["ImageOutput"]; - canvas_paste_back: components["schemas"]["ImageOutput"]; - canny_image_processor: components["schemas"]["ImageOutput"]; - integer_collection: components["schemas"]["IntegerCollectionOutput"]; - freeu: components["schemas"]["UNetOutput"]; - lresize: components["schemas"]["LatentsOutput"]; - range_of_size: components["schemas"]["IntegerCollectionOutput"]; - depth_anything_image_processor: components["schemas"]["ImageOutput"]; - float_to_int: components["schemas"]["IntegerOutput"]; - rand_int: components["schemas"]["IntegerOutput"]; - lineart_anime_image_processor: components["schemas"]["ImageOutput"]; - string_split: components["schemas"]["String2Output"]; - img_nsfw: components["schemas"]["ImageOutput"]; - 
string: components["schemas"]["StringOutput"]; - mask_edge: components["schemas"]["ImageOutput"]; - i2l: components["schemas"]["LatentsOutput"]; - face_identifier: components["schemas"]["ImageOutput"]; - compel: components["schemas"]["ConditioningOutput"]; - esrgan: components["schemas"]["ImageOutput"]; - seamless: components["schemas"]["SeamlessModeOutput"]; - mask_from_id: components["schemas"]["ImageOutput"]; - invert_tensor_mask: components["schemas"]["MaskOutput"]; - rectangle_mask: components["schemas"]["MaskOutput"]; - conditioning: components["schemas"]["ConditioningOutput"]; - t2i_adapter: components["schemas"]["T2IAdapterOutput"]; - string_collection: components["schemas"]["StringCollectionOutput"]; - show_image: components["schemas"]["ImageOutput"]; - dw_openpose_image_processor: components["schemas"]["ImageOutput"]; - string_split_neg: components["schemas"]["StringPosNegOutput"]; - conditioning_collection: components["schemas"]["ConditioningCollectionOutput"]; - infill_patchmatch: components["schemas"]["ImageOutput"]; - img_conv: components["schemas"]["ImageOutput"]; - unsharp_mask: components["schemas"]["ImageOutput"]; - metadata_item: components["schemas"]["MetadataItemOutput"]; - image: components["schemas"]["ImageOutput"]; - image_collection: components["schemas"]["ImageCollectionOutput"]; - tile_to_properties: components["schemas"]["TileToPropertiesOutput"]; - lblend: components["schemas"]["LatentsOutput"]; - float: components["schemas"]["FloatOutput"]; - boolean_collection: components["schemas"]["BooleanCollectionOutput"]; - color: components["schemas"]["ColorOutput"]; midas_depth_image_processor: components["schemas"]["ImageOutput"]; - zoe_depth_image_processor: components["schemas"]["ImageOutput"]; - infill_rgba: components["schemas"]["ImageOutput"]; - mlsd_image_processor: components["schemas"]["ImageOutput"]; + lscale: components["schemas"]["LatentsOutput"]; + string_split: components["schemas"]["String2Output"]; + mask_edge: components["schemas"]["ImageOutput"]; + content_shuffle_image_processor: components["schemas"]["ImageOutput"]; + color_correct: components["schemas"]["ImageOutput"]; + save_image: components["schemas"]["ImageOutput"]; + show_image: components["schemas"]["ImageOutput"]; + segment_anything_processor: components["schemas"]["ImageOutput"]; + latents: components["schemas"]["LatentsOutput"]; + lineart_image_processor: components["schemas"]["ImageOutput"]; + hed_image_processor: components["schemas"]["ImageOutput"]; + infill_lama: components["schemas"]["ImageOutput"]; + infill_patchmatch: components["schemas"]["ImageOutput"]; + float_collection: components["schemas"]["FloatCollectionOutput"]; + denoise_latents: components["schemas"]["LatentsOutput"]; + metadata: components["schemas"]["MetadataOutput"]; + compel: components["schemas"]["ConditioningOutput"]; + img_blur: components["schemas"]["ImageOutput"]; + img_crop: components["schemas"]["ImageOutput"]; + sdxl_lora_collection_loader: components["schemas"]["SDXLLoRALoaderOutput"]; + img_ilerp: components["schemas"]["ImageOutput"]; + img_paste: components["schemas"]["ImageOutput"]; + core_metadata: components["schemas"]["MetadataOutput"]; + lora_collection_loader: components["schemas"]["LoRALoaderOutput"]; + lora_selector: components["schemas"]["LoRASelectorOutput"]; + create_denoise_mask: components["schemas"]["DenoiseMaskOutput"]; + rectangle_mask: components["schemas"]["MaskOutput"]; + noise: components["schemas"]["NoiseOutput"]; + float_to_int: components["schemas"]["IntegerOutput"]; + esrgan: 
components["schemas"]["ImageOutput"]; merge_tiles_to_image: components["schemas"]["ImageOutput"]; prompt_from_file: components["schemas"]["StringCollectionOutput"]; - boolean: components["schemas"]["BooleanOutput"]; - create_gradient_mask: components["schemas"]["GradientMaskOutput"]; - rand_float: components["schemas"]["FloatOutput"]; - img_mul: components["schemas"]["ImageOutput"]; - controlnet: components["schemas"]["ControlOutput"]; - latents_collection: components["schemas"]["LatentsCollectionOutput"]; - img_lerp: components["schemas"]["ImageOutput"]; - noise: components["schemas"]["NoiseOutput"]; - iterate: components["schemas"]["IterateInvocationOutput"]; - lineart_image_processor: components["schemas"]["ImageOutput"]; - tomask: components["schemas"]["ImageOutput"]; - integer: components["schemas"]["IntegerOutput"]; - create_denoise_mask: components["schemas"]["DenoiseMaskOutput"]; - clip_skip: components["schemas"]["CLIPSkipInvocationOutput"]; - denoise_latents: components["schemas"]["LatentsOutput"]; - string_join: components["schemas"]["StringOutput"]; - scheduler: components["schemas"]["SchedulerOutput"]; - model_identifier: components["schemas"]["ModelIdentifierOutput"]; - normalbae_image_processor: components["schemas"]["ImageOutput"]; - face_off: components["schemas"]["FaceOffOutput"]; - hed_image_processor: components["schemas"]["ImageOutput"]; - img_paste: components["schemas"]["ImageOutput"]; - img_chan: components["schemas"]["ImageOutput"]; - img_watermark: components["schemas"]["ImageOutput"]; - l2i: components["schemas"]["ImageOutput"]; - string_replace: components["schemas"]["StringOutput"]; - color_map_image_processor: components["schemas"]["ImageOutput"]; - tile_image_processor: components["schemas"]["ImageOutput"]; - crop_latents: components["schemas"]["LatentsOutput"]; - sdxl_lora_collection_loader: components["schemas"]["SDXLLoRALoaderOutput"]; - add: components["schemas"]["IntegerOutput"]; - sub: components["schemas"]["IntegerOutput"]; - img_scale: components["schemas"]["ImageOutput"]; - range: components["schemas"]["IntegerCollectionOutput"]; - dynamic_prompt: components["schemas"]["StringCollectionOutput"]; - img_crop: components["schemas"]["ImageOutput"]; - infill_tile: components["schemas"]["ImageOutput"]; - img_resize: components["schemas"]["ImageOutput"]; - mediapipe_face_processor: components["schemas"]["ImageOutput"]; - sdxl_model_loader: components["schemas"]["SDXLModelLoaderOutput"]; - lora_selector: components["schemas"]["LoRASelectorOutput"]; - img_hue_adjust: components["schemas"]["ImageOutput"]; - latents: components["schemas"]["LatentsOutput"]; - lora_collection_loader: components["schemas"]["LoRALoaderOutput"]; - img_blur: components["schemas"]["ImageOutput"]; - ideal_size: components["schemas"]["IdealSizeOutput"]; - float_collection: components["schemas"]["FloatCollectionOutput"]; - blank_image: components["schemas"]["ImageOutput"]; - integer_math: components["schemas"]["IntegerOutput"]; - lora_loader: components["schemas"]["LoRALoaderOutput"]; - metadata: components["schemas"]["MetadataOutput"]; + infill_rgba: components["schemas"]["ImageOutput"]; sdxl_lora_loader: components["schemas"]["SDXLLoRALoaderOutput"]; - round_float: components["schemas"]["FloatOutput"]; - sdxl_refiner_model_loader: components["schemas"]["SDXLRefinerModelLoaderOutput"]; - mul: components["schemas"]["IntegerOutput"]; - alpha_mask_to_tensor: components["schemas"]["MaskOutput"]; - lscale: components["schemas"]["LatentsOutput"]; - save_image: 
components["schemas"]["ImageOutput"]; + lora_loader: components["schemas"]["LoRALoaderOutput"]; + iterate: components["schemas"]["IterateInvocationOutput"]; + t2i_adapter: components["schemas"]["T2IAdapterOutput"]; + color_map_image_processor: components["schemas"]["ImageOutput"]; + blank_image: components["schemas"]["ImageOutput"]; + normalbae_image_processor: components["schemas"]["ImageOutput"]; + canvas_paste_back: components["schemas"]["ImageOutput"]; + string_split_neg: components["schemas"]["StringPosNegOutput"]; + img_channel_offset: components["schemas"]["ImageOutput"]; + face_mask_detection: components["schemas"]["FaceMaskOutput"]; + cv_inpaint: components["schemas"]["ImageOutput"]; + clip_skip: components["schemas"]["CLIPSkipInvocationOutput"]; + invert_tensor_mask: components["schemas"]["MaskOutput"]; + tomask: components["schemas"]["ImageOutput"]; + main_model_loader: components["schemas"]["ModelLoaderOutput"]; + img_watermark: components["schemas"]["ImageOutput"]; + img_pad_crop: components["schemas"]["ImageOutput"]; + random_range: components["schemas"]["IntegerCollectionOutput"]; + mlsd_image_processor: components["schemas"]["ImageOutput"]; + merge_metadata: components["schemas"]["MetadataOutput"]; + string_join: components["schemas"]["StringOutput"]; + vae_loader: components["schemas"]["VAEOutput"]; + calculate_image_tiles_even_split: components["schemas"]["CalculateImageTilesOutput"]; calculate_image_tiles_min_overlap: components["schemas"]["CalculateImageTilesOutput"]; + mask_from_id: components["schemas"]["ImageOutput"]; + zoe_depth_image_processor: components["schemas"]["ImageOutput"]; + img_resize: components["schemas"]["ImageOutput"]; + string_replace: components["schemas"]["StringOutput"]; + face_identifier: components["schemas"]["ImageOutput"]; + canny_image_processor: components["schemas"]["ImageOutput"]; + collect: components["schemas"]["CollectInvocationOutput"]; + infill_tile: components["schemas"]["ImageOutput"]; + integer_collection: components["schemas"]["IntegerCollectionOutput"]; + img_lerp: components["schemas"]["ImageOutput"]; + step_param_easing: components["schemas"]["FloatCollectionOutput"]; + lresize: components["schemas"]["LatentsOutput"]; + img_mul: components["schemas"]["ImageOutput"]; + create_gradient_mask: components["schemas"]["GradientMaskOutput"]; + img_scale: components["schemas"]["ImageOutput"]; + rand_float: components["schemas"]["FloatOutput"]; + tile_to_properties: components["schemas"]["TileToPropertiesOutput"]; + calculate_image_tiles: components["schemas"]["CalculateImageTilesOutput"]; + range_of_size: components["schemas"]["IntegerCollectionOutput"]; + sdxl_refiner_model_loader: components["schemas"]["SDXLRefinerModelLoaderOutput"]; + heuristic_resize: components["schemas"]["ImageOutput"]; + controlnet: components["schemas"]["ControlOutput"]; + string: components["schemas"]["StringOutput"]; + tile_image_processor: components["schemas"]["ImageOutput"]; + metadata_item: components["schemas"]["MetadataItemOutput"]; + freeu: components["schemas"]["UNetOutput"]; + round_float: components["schemas"]["FloatOutput"]; + conditioning: components["schemas"]["ConditioningOutput"]; + ideal_size: components["schemas"]["IdealSizeOutput"]; + float: components["schemas"]["FloatOutput"]; + conditioning_collection: components["schemas"]["ConditioningCollectionOutput"]; + alpha_mask_to_tensor: components["schemas"]["MaskOutput"]; + integer_math: components["schemas"]["IntegerOutput"]; + string_collection: 
components["schemas"]["StringCollectionOutput"]; + img_conv: components["schemas"]["ImageOutput"]; + img_channel_multiply: components["schemas"]["ImageOutput"]; + lblend: components["schemas"]["LatentsOutput"]; + color: components["schemas"]["ColorOutput"]; + image: components["schemas"]["ImageOutput"]; + sdxl_model_loader: components["schemas"]["SDXLModelLoaderOutput"]; + image_collection: components["schemas"]["ImageCollectionOutput"]; + model_identifier: components["schemas"]["ModelIdentifierOutput"]; + l2i: components["schemas"]["ImageOutput"]; + seamless: components["schemas"]["SeamlessModeOutput"]; + boolean_collection: components["schemas"]["BooleanCollectionOutput"]; + string_join_three: components["schemas"]["StringOutput"]; + ip_adapter: components["schemas"]["IPAdapterOutput"]; + add: components["schemas"]["IntegerOutput"]; + crop_latents: components["schemas"]["LatentsOutput"]; + float_range: components["schemas"]["FloatCollectionOutput"]; + mul: components["schemas"]["IntegerOutput"]; + dw_openpose_image_processor: components["schemas"]["ImageOutput"]; + boolean: components["schemas"]["BooleanOutput"]; + dynamic_prompt: components["schemas"]["StringCollectionOutput"]; + mediapipe_face_processor: components["schemas"]["ImageOutput"]; + i2l: components["schemas"]["LatentsOutput"]; + latents_collection: components["schemas"]["LatentsCollectionOutput"]; + integer: components["schemas"]["IntegerOutput"]; + img_chan: components["schemas"]["ImageOutput"]; + pair_tile_image: components["schemas"]["PairTileImageOutput"]; + unsharp_mask: components["schemas"]["ImageOutput"]; + img_hue_adjust: components["schemas"]["ImageOutput"]; + lineart_anime_image_processor: components["schemas"]["ImageOutput"]; + face_off: components["schemas"]["FaceOffOutput"]; + mask_combine: components["schemas"]["ImageOutput"]; + leres_image_processor: components["schemas"]["ImageOutput"]; + image_mask_to_tensor: components["schemas"]["MaskOutput"]; + sdxl_refiner_compel_prompt: components["schemas"]["ConditioningOutput"]; + scheduler: components["schemas"]["SchedulerOutput"]; + sub: components["schemas"]["IntegerOutput"]; + pidi_image_processor: components["schemas"]["ImageOutput"]; + infill_cv2: components["schemas"]["ImageOutput"]; + div: components["schemas"]["IntegerOutput"]; + img_nsfw: components["schemas"]["ImageOutput"]; + depth_anything_image_processor: components["schemas"]["ImageOutput"]; + sdxl_compel_prompt: components["schemas"]["ConditioningOutput"]; + range: components["schemas"]["IntegerCollectionOutput"]; + rand_int: components["schemas"]["IntegerOutput"]; + float_math: components["schemas"]["FloatOutput"]; }; /** * InvocationStartedEvent @@ -9443,6 +9450,49 @@ export type components = { [key: string]: number | string; })[]; }; + /** + * ModelInstallDownloadStartedEvent + * @description Event model for model_install_download_started + */ + ModelInstallDownloadStartedEvent: { + /** + * Timestamp + * @description The timestamp of the event + */ + timestamp: number; + /** + * Id + * @description The ID of the install job + */ + id: number; + /** + * Source + * @description Source of the model; local path, repo_id or url + */ + source: string; + /** + * Local Path + * @description Where model is downloading to + */ + local_path: string; + /** + * Bytes + * @description Number of bytes downloaded so far + */ + bytes: number; + /** + * Total Bytes + * @description Total size of download, including all files + */ + total_bytes: number; + /** + * Parts + * @description Progress of downloading URLs 
that comprise the model, if any + */ + parts: ({ + [key: string]: number | string; + })[]; + }; /** * ModelInstallDownloadsCompleteEvent * @description Emitted once when an install job becomes active. @@ -10671,8 +10721,9 @@ export type components = { /** * Size * @description The size of this file, in bytes + * @default 0 */ - size: number; + size?: number | null; /** * Sha256 * @description SHA256 hash of this model (not always available) @@ -14050,6 +14101,40 @@ export type operations = { }; }; }; + /** + * Install Hugging Face Model + * @description Install a Hugging Face model using a string identifier. + */ + install_hugging_face_model: { + parameters: { + query: { + /** @description Hugging Face repo_id to install */ + source: string; + }; + }; + responses: { + /** @description The model is being installed */ + 201: { + content: { + "text/html": string; + }; + }; + /** @description Bad request */ + 400: { + content: never; + }; + /** @description There is already a model corresponding to this path or repo_id */ + 409: { + content: never; + }; + /** @description Validation Error */ + 422: { + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; /** * Get Model Install Job * @description Return model install job corresponding to the given source. See the documentation for 'List Model Install Jobs' From 56771de8565d2ff6d435fc6f234c0ee5254ca433 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 17 Jun 2024 09:52:46 +1000 Subject: [PATCH 07/10] feat(ui): add redux actions for `model_install_download_started` event --- invokeai/frontend/web/src/services/events/actions.ts | 4 ++++ invokeai/frontend/web/src/services/events/types.ts | 2 ++ 2 files changed, 6 insertions(+) diff --git a/invokeai/frontend/web/src/services/events/actions.ts b/invokeai/frontend/web/src/services/events/actions.ts index 257819b4c8..a97bdcbf8b 100644 --- a/invokeai/frontend/web/src/services/events/actions.ts +++ b/invokeai/frontend/web/src/services/events/actions.ts @@ -16,6 +16,7 @@ import type { ModelInstallCompleteEvent, ModelInstallDownloadProgressEvent, ModelInstallDownloadsCompleteEvent, + ModelInstallDownloadStartedEvent, ModelInstallErrorEvent, ModelInstallStartedEvent, ModelLoadCompleteEvent, @@ -45,6 +46,9 @@ export const socketModelInstallStarted = createSocketAction( 'ModelInstallDownloadProgressEvent' ); +export const socketModelInstallDownloadStarted = createSocketAction( + 'ModelInstallDownloadStartedEvent' +); export const socketModelInstallDownloadsComplete = createSocketAction( 'ModelInstallDownloadsCompleteEvent' ); diff --git a/invokeai/frontend/web/src/services/events/types.ts b/invokeai/frontend/web/src/services/events/types.ts index a84049cc28..2d3725394d 100644 --- a/invokeai/frontend/web/src/services/events/types.ts +++ b/invokeai/frontend/web/src/services/events/types.ts @@ -9,6 +9,7 @@ export type InvocationCompleteEvent = S['InvocationCompleteEvent']; export type InvocationErrorEvent = S['InvocationErrorEvent']; export type ProgressImage = InvocationDenoiseProgressEvent['progress_image']; +export type ModelInstallDownloadStartedEvent = S['ModelInstallDownloadStartedEvent']; export type ModelInstallDownloadProgressEvent = S['ModelInstallDownloadProgressEvent']; export type ModelInstallDownloadsCompleteEvent = S['ModelInstallDownloadsCompleteEvent']; export type ModelInstallCompleteEvent = S['ModelInstallCompleteEvent']; @@ -49,6 +50,7 @@ export type ServerToClientEvents = { 
download_error: (payload: DownloadErrorEvent) => void; model_load_started: (payload: ModelLoadStartedEvent) => void; model_install_started: (payload: ModelInstallStartedEvent) => void; + model_install_download_started: (payload: ModelInstallDownloadStartedEvent) => void; model_install_download_progress: (payload: ModelInstallDownloadProgressEvent) => void; model_install_downloads_complete: (payload: ModelInstallDownloadsCompleteEvent) => void; model_install_complete: (payload: ModelInstallCompleteEvent) => void; From f002bca2fa057aa9dcb997fc3520a3b31a12578a Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 17 Jun 2024 10:07:10 +1000 Subject: [PATCH 08/10] feat(ui): handle new `model_install_download_started` event When a model install is initiated from outside the client, we now trigger the model manager tab's model install list to update. - Handle new `model_install_download_started` event - Handle `model_install_download_complete` event (this event is not new but was never handled) - Update optimistic updates/cache invalidation logic to efficiently update the model install list --- .../listeners/socketio/socketModelInstall.ts | 184 +++++++++++++----- 1 file changed, 136 insertions(+), 48 deletions(-) diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts index 113d2cbd66..22ad87fbe9 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall.ts @@ -5,6 +5,8 @@ import { socketModelInstallCancelled, socketModelInstallComplete, socketModelInstallDownloadProgress, + socketModelInstallDownloadsComplete, + socketModelInstallDownloadStarted, socketModelInstallError, socketModelInstallStarted, } from 'services/events/actions'; @@ -14,9 +16,12 @@ import { * which is a bit misleading. For example, a `model_install_started` event is actually fired _after_ the model has fully * downloaded and is being "physically" installed. * + * Note: the download events are only fired for remote model installs, not local. 
+ * * Here's the expected flow: - * - Model manager does some prep - * - `model_install_download_progress` fired when the download starts and continually until the download is complete + * - API receives install request, model manager preps the install + * - `model_install_download_started` fired when the download starts + * - `model_install_download_progress` fired continually until the download is complete * - `model_install_download_complete` fired when the download is complete * - `model_install_started` fired when the "physical" installation starts * - `model_install_complete` fired when the installation is complete @@ -24,47 +29,98 @@ import { * - `model_install_error` fired if the installation has an error */ +const selectModelInstalls = modelsApi.endpoints.listModelInstalls.select(); + export const addModelInstallEventListener = (startAppListening: AppStartListening) => { + startAppListening({ + actionCreator: socketModelInstallDownloadStarted, + effect: async (action, { dispatch, getState }) => { + const { id } = action.payload.data; + const { data } = selectModelInstalls(getState()); + + if (!data || !data.find((m) => m.id === id)) { + dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + } else { + dispatch( + modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { + const modelImport = draft.find((m) => m.id === id); + if (modelImport) { + modelImport.status = 'downloading'; + } + return draft; + }) + ); + } + }, + }); + startAppListening({ actionCreator: socketModelInstallStarted, - effect: async (action, { dispatch }) => { - dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + effect: async (action, { dispatch, getState }) => { + const { id } = action.payload.data; + const { data } = selectModelInstalls(getState()); + + if (!data || !data.find((m) => m.id === id)) { + dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + } else { + dispatch( + modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { + const modelImport = draft.find((m) => m.id === id); + if (modelImport) { + modelImport.status = 'running'; + } + return draft; + }) + ); + } }, }); startAppListening({ actionCreator: socketModelInstallDownloadProgress, - effect: async (action, { dispatch }) => { + effect: async (action, { dispatch, getState }) => { const { bytes, total_bytes, id } = action.payload.data; + const { data } = selectModelInstalls(getState()); - dispatch( - modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { - const modelImport = draft.find((m) => m.id === id); - if (modelImport) { - modelImport.bytes = bytes; - modelImport.total_bytes = total_bytes; - modelImport.status = 'downloading'; - } - return draft; - }) - ); + if (!data || !data.find((m) => m.id === id)) { + dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + } else { + dispatch( + modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { + const modelImport = draft.find((m) => m.id === id); + if (modelImport) { + modelImport.bytes = bytes; + modelImport.total_bytes = total_bytes; + modelImport.status = 'downloading'; + } + return draft; + }) + ); + } }, }); startAppListening({ actionCreator: socketModelInstallComplete, - effect: (action, { dispatch }) => { + effect: (action, { dispatch, getState }) => { const { id } = action.payload.data; - dispatch( - modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { - const modelImport = draft.find((m) => m.id === id); - if (modelImport) { - 
modelImport.status = 'completed'; - } - return draft; - }) - ); + const { data } = selectModelInstalls(getState()); + + if (!data || !data.find((m) => m.id === id)) { + dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + } else { + dispatch( + modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { + const modelImport = draft.find((m) => m.id === id); + if (modelImport) { + modelImport.status = 'completed'; + } + return draft; + }) + ); + } + dispatch(api.util.invalidateTags([{ type: 'ModelConfig', id: LIST_TAG }])); dispatch(api.util.invalidateTags([{ type: 'ModelScanFolderResults', id: LIST_TAG }])); }, @@ -72,37 +128,69 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin startAppListening({ actionCreator: socketModelInstallError, - effect: (action, { dispatch }) => { + effect: (action, { dispatch, getState }) => { const { id, error, error_type } = action.payload.data; + const { data } = selectModelInstalls(getState()); - dispatch( - modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { - const modelImport = draft.find((m) => m.id === id); - if (modelImport) { - modelImport.status = 'error'; - modelImport.error_reason = error_type; - modelImport.error = error; - } - return draft; - }) - ); + if (!data || !data.find((m) => m.id === id)) { + dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + } else { + dispatch( + modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { + const modelImport = draft.find((m) => m.id === id); + if (modelImport) { + modelImport.status = 'error'; + modelImport.error_reason = error_type; + modelImport.error = error; + } + return draft; + }) + ); + } }, }); startAppListening({ actionCreator: socketModelInstallCancelled, - effect: (action, { dispatch }) => { + effect: (action, { dispatch, getState }) => { const { id } = action.payload.data; + const { data } = selectModelInstalls(getState()); - dispatch( - modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { - const modelImport = draft.find((m) => m.id === id); - if (modelImport) { - modelImport.status = 'cancelled'; - } - return draft; - }) - ); + if (!data || !data.find((m) => m.id === id)) { + dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + } else { + dispatch( + modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { + const modelImport = draft.find((m) => m.id === id); + if (modelImport) { + modelImport.status = 'cancelled'; + } + return draft; + }) + ); + } + }, + }); + + startAppListening({ + actionCreator: socketModelInstallDownloadsComplete, + effect: (action, { dispatch, getState }) => { + const { id } = action.payload.data; + const { data } = selectModelInstalls(getState()); + + if (!data || !data.find((m) => m.id === id)) { + dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }])); + } else { + dispatch( + modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => { + const modelImport = draft.find((m) => m.id === id); + if (modelImport) { + modelImport.status = 'downloads_done'; + } + return draft; + }) + ); + } }, }); }; From cd70937b7f3a44669f6c011376bbb44fb8201045 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 17 Jun 2024 10:51:08 +1000 Subject: [PATCH 09/10] feat(api): improved model install confirmation page styling & messaging --- invokeai/app/api/routers/model_manager.py | 118 ++++++++++++++++------ 1 file changed, 85 insertions(+), 
33 deletions(-) diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index f2fb0932e5..99f00423c6 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -514,50 +514,93 @@ async def install_model( response_class=HTMLResponse, ) async def install_hugging_face_model( - source: str = Query(description="Hugging Face repo_id to install"), + source: str = Query(description="HuggingFace repo_id to install"), ) -> HTMLResponse: """Install a Hugging Face model using a string identifier.""" - def generate_html(message: str) -> str: + def generate_html(title: str, heading: str, repo_id: str, is_error: bool, message: str | None = "") -> str: + if message: + message = f"

{message}"
+        title_class = "error" if is_error else "success"
         return f"""
        [HTML page markup garbled in extraction — the old message-only template is removed and replaced with a page that sets the browser tab title to {title}, shows {heading} styled as error or success via title_class, then the optional {message}, then a "Repo ID: {repo_id}" line]
+ + + """ try: metadata = HuggingFaceMetadataFetch().from_id(source) assert isinstance(metadata, ModelMetadataWithFiles) - message = "Your Hugging Face model is installing now. You can close this tab and check the Model Manager for installation progress." except UnknownMetadataException: - message = "No HuggingFace repository found with that repo id." - return HTMLResponse(content=generate_html(message), status_code=400) + title = "Unable to Install Model" + heading = "No HuggingFace repository found with that repo ID." + message = "Ensure the repo ID is correct and try again." + return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=400) logger = ApiDependencies.invoker.services.logger try: installer = ApiDependencies.invoker.services.model_manager.install - if metadata.is_diffusers: installer.heuristic_import( source=source, @@ -569,12 +612,21 @@ async def install_hugging_face_model( inplace=False, ) else: - message = "This HuggingFace repo has multiple models. Please use the Model Manager to install this." + title = "Unable to Install Model" + heading = "This HuggingFace repo has multiple models." + message = "Please use the Model Manager to install this model." + return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=200) + + title = "Model Install Started" + heading = "Your HuggingFace model is installing now." + message = "You can close this tab and check the Model Manager for installation progress." + return HTMLResponse(content=generate_html(title, heading, source, False, message), status_code=201) except Exception as e: logger.error(str(e)) - message = "There was an error with installing this model. Please use the Model Manager to install this." - - return HTMLResponse(content=generate_html(message), status_code=201) + title = "Unable to Install Model" + heading = "There was an problem installing this model." + message = 'Please use the Model Manager directly to install this model. If the issue persists, ask for help on discord.' 
+ return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=500) @model_manager_router.get( From e26125b734be2bcfb3188c1a0e29ccd3b6f474ba Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 17 Jun 2024 10:57:11 +1000 Subject: [PATCH 10/10] tests: fix test_model_install.py --- tests/app/services/model_install/test_model_install.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index 9602a79a27..0c212cca76 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -17,6 +17,7 @@ from invokeai.app.services.events.events_common import ( ModelInstallCompleteEvent, ModelInstallDownloadProgressEvent, ModelInstallDownloadsCompleteEvent, + ModelInstallDownloadStartedEvent, ModelInstallStartedEvent, ) from invokeai.app.services.model_install import ( @@ -252,7 +253,7 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: assert (mm2_app_config.models_path / model_record.path).exists() assert len(bus.events) == 5 - assert isinstance(bus.events[0], ModelInstallDownloadProgressEvent) # download starts + assert isinstance(bus.events[0], ModelInstallDownloadStartedEvent) # download starts assert isinstance(bus.events[1], ModelInstallDownloadProgressEvent) # download progresses assert isinstance(bus.events[2], ModelInstallDownloadsCompleteEvent) # download completed assert isinstance(bus.events[3], ModelInstallStartedEvent) # install started
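
As a closing illustration, a quick way to exercise the new route by hand. The path, query parameter, and response codes come from this patch series; the host, port, and example repo_id are placeholders and not part of the change.

```python
# Sketch: trigger a Hugging Face install via the new GET endpoint and inspect the HTML response.
import requests

resp = requests.get(
    "http://127.0.0.1:9090/api/v2/models/install/huggingface",  # host/port assumed
    params={"source": "stabilityai/sd-vae-ft-mse"},  # any repo_id works; this one is only an example
    timeout=30,
)

# Per the route: 201 = install queued, 400 = repo_id not found, 200 = multi-model repo (use the
# Model Manager instead), 500 = install error. Every case returns a human-readable HTML page.
print(resp.status_code)
print(resp.headers.get("content-type"))
```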