Merge branch 'main' into lstein/feat/sd3-model-loading

commit cd99ef2f46
Author: blessedcoolant
Date: 2024-06-20 08:43:34 +05:30

30 changed files with 887 additions and 618 deletions

View File

@@ -9,7 +9,7 @@ from copy import deepcopy
 from typing import Any, Dict, List, Optional, Type
 
 from fastapi import Body, Path, Query, Response, UploadFile
-from fastapi.responses import FileResponse
+from fastapi.responses import FileResponse, HTMLResponse
 from fastapi.routing import APIRouter
 from PIL import Image
 from pydantic import AnyHttpUrl, BaseModel, ConfigDict, Field
@@ -502,6 +502,133 @@ async def install_model(
     return result
 
 
+@model_manager_router.get(
+    "/install/huggingface",
+    operation_id="install_hugging_face_model",
+    responses={
+        201: {"description": "The model is being installed"},
+        400: {"description": "Bad request"},
+        409: {"description": "There is already a model corresponding to this path or repo_id"},
+    },
+    status_code=201,
+    response_class=HTMLResponse,
+)
+async def install_hugging_face_model(
+    source: str = Query(description="HuggingFace repo_id to install"),
+) -> HTMLResponse:
+    """Install a Hugging Face model using a string identifier."""
+
+    def generate_html(title: str, heading: str, repo_id: str, is_error: bool, message: str | None = "") -> str:
+        if message:
+            message = f"<p>{message}</p>"
+        title_class = "error" if is_error else "success"
+        return f"""
+        <html>
+            <head>
+                <title>{title}</title>
+                <style>
+                    body {{
+                        text-align: center;
+                        background-color: hsl(220 12% 10% / 1);
+                        font-family: Helvetica, sans-serif;
+                        color: hsl(220 12% 86% / 1);
+                    }}
+                    .repo-id {{
+                        color: hsl(220 12% 68% / 1);
+                    }}
+                    .error {{
+                        color: hsl(0 42% 68% / 1)
+                    }}
+                    .message-box {{
+                        display: inline-block;
+                        border-radius: 5px;
+                        background-color: hsl(220 12% 20% / 1);
+                        padding: 20px;
+                        padding-inline-start: 30px;
+                        padding-inline-end: 30px;
+                    }}
+                    .container {{
+                        display: flex;
+                        width: 100%;
+                        height: 100%;
+                        align-items: center;
+                        justify-content: center;
+                    }}
+                    a {{
+                        color: inherit
+                    }}
+                    a:visited {{
+                        color: inherit
+                    }}
+                    a:active {{
+                        color: inherit
+                    }}
+                </style>
+            </head>
+            <body style="background-color: hsl(220 12% 10% / 1);">
+                <div class="container">
+                    <div class="message-box">
+                        <h2 class="{title_class}">{heading}</h2>
+                        {message}
+                        <p class="repo-id">Repo ID: {repo_id}</p>
+                    </div>
+                </div>
+            </body>
+        </html>
+        """
+
+    try:
+        metadata = HuggingFaceMetadataFetch().from_id(source)
+        assert isinstance(metadata, ModelMetadataWithFiles)
+    except UnknownMetadataException:
+        title = "Unable to Install Model"
+        heading = "No HuggingFace repository found with that repo ID."
+        message = "Ensure the repo ID is correct and try again."
+        return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=400)
+
+    logger = ApiDependencies.invoker.services.logger
+
+    try:
+        installer = ApiDependencies.invoker.services.model_manager.install
+        if metadata.is_diffusers:
+            installer.heuristic_import(
+                source=source,
+                inplace=False,
+            )
+        elif metadata.ckpt_urls is not None and len(metadata.ckpt_urls) == 1:
+            installer.heuristic_import(
+                source=str(metadata.ckpt_urls[0]),
+                inplace=False,
+            )
+        else:
+            title = "Unable to Install Model"
+            heading = "This HuggingFace repo has multiple models."
+            message = "Please use the Model Manager to install this model."
+            return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=200)
+
+        title = "Model Install Started"
+        heading = "Your HuggingFace model is installing now."
+        message = "You can close this tab and check the Model Manager for installation progress."
+        return HTMLResponse(content=generate_html(title, heading, source, False, message), status_code=201)
+    except Exception as e:
+        logger.error(str(e))
+        title = "Unable to Install Model"
+        heading = "There was a problem installing this model."
+        message = 'Please use the Model Manager directly to install this model. If the issue persists, ask for help on <a href="https://discord.gg/ZmtBAhwWhy">discord</a>.'
+        return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=500)
+
 @model_manager_router.get(
     "/install",
     operation_id="list_model_installs",
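Since the new endpoint is a plain GET that returns HTML, a browser link is enough to trigger an install. A minimal sketch of exercising it from Python; the host, port, route prefix, and repo_id are assumptions for illustration, not taken from the diff:

```python
# Sketch only: assumes a local InvokeAI server; adjust host/port/prefix to your setup.
import requests

resp = requests.get(
    "http://localhost:9090/api/v2/models/install/huggingface",  # route prefix assumed
    params={"source": "stabilityai/sd-vae-ft-mse"},  # any HuggingFace repo_id
)
print(resp.status_code)  # expect 201 (queued), 400 (unknown repo), or 500 (install error)
# resp.text is the styled HTML status page built by generate_html() above
```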

View File

@@ -1,6 +1,7 @@
 from typing import Literal
 
 from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
+from invokeai.backend.util.devices import TorchDevice
 
 LATENT_SCALE_FACTOR = 8
 """
@@ -15,3 +16,5 @@ SCHEDULER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
 
 IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
 """A literal type for PIL image modes supported by Invoke"""
+
+DEFAULT_PRECISION = TorchDevice.choose_torch_dtype()
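Worth flagging for the hunks that follow: `TorchDevice.choose_torch_dtype()` returns a `torch.dtype`, and a string comparison like `== "float32"` can never be true against a dtype, which is why the `fp32` field defaults below switch to `== torch.float32`. A minimal sketch; the assigned value is a hypothetical example of what the helper might pick:

```python
import torch

# Hypothetical pick; choose_torch_dtype() selects per device, e.g. float16 on CUDA.
DEFAULT_PRECISION = torch.float16

print(DEFAULT_PRECISION == torch.float32)  # False on a half-precision device
print(DEFAULT_PRECISION == "float32")      # always False: a torch.dtype never equals a str
```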

View File

@@ -6,7 +6,7 @@ from PIL import Image
 from torchvision.transforms.functional import resize as tv_resize
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.denoise_latents import DEFAULT_PRECISION
+from invokeai.app.invocations.constants import DEFAULT_PRECISION
 from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField
 from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
 from invokeai.app.invocations.model import VAEField
@@ -30,7 +30,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
     mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
     tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
     fp32: bool = InputField(
-        default=DEFAULT_PRECISION == "float32",
+        default=DEFAULT_PRECISION == torch.float32,
         description=FieldDescriptions.fp32,
         ui_order=4,
     )

View File

@@ -7,7 +7,7 @@ from PIL import Image, ImageFilter
 from torchvision.transforms.functional import resize as tv_resize
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
-from invokeai.app.invocations.denoise_latents import DEFAULT_PRECISION
+from invokeai.app.invocations.constants import DEFAULT_PRECISION
 from invokeai.app.invocations.fields import (
     DenoiseMaskField,
     FieldDescriptions,
@@ -74,7 +74,7 @@ class CreateGradientMaskInvocation(BaseInvocation):
     )
     tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=8)
     fp32: bool = InputField(
-        default=DEFAULT_PRECISION == "float32",
+        default=DEFAULT_PRECISION == torch.float32,
         description=FieldDescriptions.fp32,
         ui_order=9,
     )

View File

@@ -16,7 +16,9 @@ from pydantic import field_validator
 from torchvision.transforms.functional import resize as tv_resize
 from transformers import CLIPVisionModelWithProjection
 
+from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
 from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
+from invokeai.app.invocations.controlnet_image_processors import ControlField
 from invokeai.app.invocations.fields import (
     ConditioningField,
     DenoiseMaskField,
@@ -27,6 +29,7 @@ from invokeai.app.invocations.fields import (
     UIType,
 )
 from invokeai.app.invocations.ip_adapter import IPAdapterField
+from invokeai.app.invocations.model import ModelIdentifierField, UNetField
 from invokeai.app.invocations.primitives import LatentsOutput
 from invokeai.app.invocations.t2i_adapter import T2IAdapterField
 from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -36,6 +39,11 @@ from invokeai.backend.lora import LoRAModelRaw
 from invokeai.backend.model_manager import BaseModelType
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
+from invokeai.backend.stable_diffusion.diffusers_pipeline import (
+    ControlNetData,
+    StableDiffusionGeneratorPipeline,
+    T2IAdapterData,
+)
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
     IPAdapterConditioningInfo,
@@ -45,22 +53,11 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     TextConditioningData,
     TextConditioningRegions,
 )
+from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
+from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.mask import to_standard_float_mask
 from invokeai.backend.util.silence_warnings import SilenceWarnings
 
-from ...backend.stable_diffusion.diffusers_pipeline import (
-    ControlNetData,
-    StableDiffusionGeneratorPipeline,
-    T2IAdapterData,
-)
-from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
-from ...backend.util.devices import TorchDevice
-from .baseinvocation import BaseInvocation, invocation
-from .controlnet_image_processors import ControlField
-from .model import ModelIdentifierField, UNetField
-
-DEFAULT_PRECISION = TorchDevice.choose_torch_dtype()
-
 
 def get_scheduler(
     context: InvocationContext,
@@ -660,8 +657,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
         return 1 - mask, masked_latents, self.denoise_mask.gradient
 
     @torch.no_grad()
+    @SilenceWarnings()  # This quenches the NSFW nag from diffusers.
     def invoke(self, context: InvocationContext) -> LatentsOutput:
-        with SilenceWarnings():  # this quenches NSFW nag from diffusers
-            seed = None
-            noise = None
-            if self.noise is not None:
+        seed = None
+        noise = None
+        if self.noise is not None:

View File

@@ -12,7 +12,7 @@ from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
 from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.denoise_latents import DEFAULT_PRECISION
+from invokeai.app.invocations.constants import DEFAULT_PRECISION
 from invokeai.app.invocations.fields import (
     FieldDescriptions,
     ImageField,
@@ -44,7 +44,7 @@ class ImageToLatentsInvocation(BaseInvocation):
         input=Input.Connection,
     )
     tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
-    fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
+    fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
 
     @staticmethod
     def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor:

View File

@@ -11,7 +11,7 @@ from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
 from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.denoise_latents import DEFAULT_PRECISION
+from invokeai.app.invocations.constants import DEFAULT_PRECISION
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, LatentsField, WithBoard, WithMetadata
 from invokeai.app.invocations.model import VAEField
 from invokeai.app.invocations.primitives import ImageOutput
@@ -39,7 +39,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         input=Input.Connection,
     )
     tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
-    fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
+    fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:

View File

@@ -115,6 +115,7 @@ class InvokeAIAppConfig(BaseSettings):
         force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
         pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
         max_queue_size: Maximum number of items in the session queue.
+        clear_queue_on_startup: Empties session queue on startup.
        allow_nodes: List of nodes to allow. Omit to allow all.
         deny_nodes: List of nodes to deny. Omit to deny none.
         node_cache_size: How many cached nodes to keep in memory.
@@ -189,6 +190,7 @@ class InvokeAIAppConfig(BaseSettings):
     force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
     pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
     max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
+    clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup.")
 
     # NODES
     allow_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.")
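`InvokeAIAppConfig` is a pydantic settings class, so the new flag can be checked directly; a minimal sketch, assuming the config module path (normally the value would come from `invokeai.yaml`):

```python
# Sketch only: the import path is an assumption based on the repository layout.
from invokeai.app.services.config.config_default import InvokeAIAppConfig

config = InvokeAIAppConfig(clear_queue_on_startup=True)
print(config.clear_queue_on_startup)  # True: clear the whole queue at startup instead of pruning
```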

View File

@@ -22,6 +22,7 @@ from invokeai.app.services.events.events_common import (
     ModelInstallCompleteEvent,
     ModelInstallDownloadProgressEvent,
     ModelInstallDownloadsCompleteEvent,
+    ModelInstallDownloadStartedEvent,
     ModelInstallErrorEvent,
     ModelInstallStartedEvent,
     ModelLoadCompleteEvent,
@@ -144,6 +145,10 @@ class EventServiceBase:
 
     # region Model install
 
+    def emit_model_install_download_started(self, job: "ModelInstallJob") -> None:
+        """Emitted when an install job's download starts (remote models only)."""
+        self.dispatch(ModelInstallDownloadStartedEvent.build(job))
+
     def emit_model_install_download_progress(self, job: "ModelInstallJob") -> None:
         """Emitted at intervals while the install job is in progress (remote models only)."""
         self.dispatch(ModelInstallDownloadProgressEvent.build(job))

View File

@@ -417,6 +417,42 @@ class ModelLoadCompleteEvent(ModelEventBase):
         return cls(config=config, submodel_type=submodel_type)
 
 
+@payload_schema.register
+class ModelInstallDownloadStartedEvent(ModelEventBase):
+    """Event model for model_install_download_started"""
+
+    __event_name__ = "model_install_download_started"
+
+    id: int = Field(description="The ID of the install job")
+    source: str = Field(description="Source of the model; local path, repo_id or url")
+    local_path: str = Field(description="Where model is downloading to")
+    bytes: int = Field(description="Number of bytes downloaded so far")
+    total_bytes: int = Field(description="Total size of download, including all files")
+    parts: list[dict[str, int | str]] = Field(
+        description="Progress of downloading URLs that comprise the model, if any"
+    )
+
+    @classmethod
+    def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadStartedEvent":
+        parts: list[dict[str, str | int]] = [
+            {
+                "url": str(x.source),
+                "local_path": str(x.download_path),
+                "bytes": x.bytes,
+                "total_bytes": x.total_bytes,
+            }
+            for x in job.download_parts
+        ]
+        return cls(
+            id=job.id,
+            source=str(job.source),
+            local_path=job.local_path.as_posix(),
+            parts=parts,
+            bytes=job.bytes,
+            total_bytes=job.total_bytes,
+        )
+
+
 @payload_schema.register
 class ModelInstallDownloadProgressEvent(ModelEventBase):
     """Event model for model_install_download_progress"""

View File

@@ -822,7 +822,7 @@ class ModelInstallService(ModelInstallServiceBase):
             install_job.download_parts = download_job.download_parts
             install_job.bytes = sum(x.bytes for x in download_job.download_parts)
             install_job.total_bytes = download_job.total_bytes
-            self._signal_job_downloading(install_job)
+            self._signal_job_download_started(install_job)
 
     def _download_progress_callback(self, download_job: MultiFileDownloadJob) -> None:
         with self._lock:
@@ -874,6 +874,13 @@ class ModelInstallService(ModelInstallServiceBase):
         if self._event_bus:
             self._event_bus.emit_model_install_started(job)
 
+    def _signal_job_download_started(self, job: ModelInstallJob) -> None:
+        if self._event_bus:
+            assert job._multifile_job is not None
+            assert job.bytes is not None
+            assert job.total_bytes is not None
+            self._event_bus.emit_model_install_download_started(job)
+
     def _signal_job_downloading(self, job: ModelInstallJob) -> None:
         if self._event_bus:
             assert job._multifile_job is not None

View File

@@ -37,8 +37,12 @@ class SqliteSessionQueue(SessionQueueBase):
     def start(self, invoker: Invoker) -> None:
         self.__invoker = invoker
         self._set_in_progress_to_canceled()
-        prune_result = self.prune(DEFAULT_QUEUE_ID)
-        if prune_result.deleted > 0:
-            self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
+        if self.__invoker.services.configuration.clear_queue_on_startup:
+            clear_result = self.clear(DEFAULT_QUEUE_ID)
+            if clear_result.deleted > 0:
+                self.__invoker.services.logger.info(f"Cleared all {clear_result.deleted} queue items")
+        else:
+            prune_result = self.prune(DEFAULT_QUEUE_ID)
+            if prune_result.deleted > 0:
+                self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")

View File

@@ -125,13 +125,16 @@ class IPAdapter(RawModel):
             self.device, dtype=self.dtype
         )
 
-    def to(self, device: torch.device, dtype: Optional[torch.dtype] = None):
-        self.device = device
+    def to(
+        self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, non_blocking: bool = False
+    ):
+        if device is not None:
+            self.device = device
         if dtype is not None:
             self.dtype = dtype
 
-        self._image_proj_model.to(device=self.device, dtype=self.dtype)
-        self.attn_weights.to(device=self.device, dtype=self.dtype)
+        self._image_proj_model.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
+        self.attn_weights.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
 
     def calc_size(self):
         # workaround for circular import

View File

@@ -61,9 +61,10 @@ class LoRALayerBase:
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ) -> None:
         if self.bias is not None:
-            self.bias = self.bias.to(device=device, dtype=dtype)
+            self.bias = self.bias.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
     # TODO: find and debug lora/locon with bias
@@ -109,14 +110,15 @@ class LoRALayer(LoRALayerBase):
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ) -> None:
-        super().to(device=device, dtype=dtype)
+        super().to(device=device, dtype=dtype, non_blocking=non_blocking)
 
-        self.up = self.up.to(device=device, dtype=dtype)
-        self.down = self.down.to(device=device, dtype=dtype)
+        self.up = self.up.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.down = self.down.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
         if self.mid is not None:
-            self.mid = self.mid.to(device=device, dtype=dtype)
+            self.mid = self.mid.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
 
 class LoHALayer(LoRALayerBase):
@@ -169,18 +171,19 @@ class LoHALayer(LoRALayerBase):
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ) -> None:
         super().to(device=device, dtype=dtype)
 
-        self.w1_a = self.w1_a.to(device=device, dtype=dtype)
-        self.w1_b = self.w1_b.to(device=device, dtype=dtype)
+        self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
         if self.t1 is not None:
-            self.t1 = self.t1.to(device=device, dtype=dtype)
+            self.t1 = self.t1.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
-        self.w2_a = self.w2_a.to(device=device, dtype=dtype)
-        self.w2_b = self.w2_b.to(device=device, dtype=dtype)
+        self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
         if self.t2 is not None:
-            self.t2 = self.t2.to(device=device, dtype=dtype)
+            self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
 
 class LoKRLayer(LoRALayerBase):
@@ -265,6 +268,7 @@ class LoKRLayer(LoRALayerBase):
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ) -> None:
         super().to(device=device, dtype=dtype)
 
@@ -273,19 +277,19 @@ class LoKRLayer(LoRALayerBase):
         else:
             assert self.w1_a is not None
             assert self.w1_b is not None
-            self.w1_a = self.w1_a.to(device=device, dtype=dtype)
-            self.w1_b = self.w1_b.to(device=device, dtype=dtype)
+            self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
         if self.w2 is not None:
-            self.w2 = self.w2.to(device=device, dtype=dtype)
+            self.w2 = self.w2.to(device=device, dtype=dtype, non_blocking=non_blocking)
         else:
             assert self.w2_a is not None
             assert self.w2_b is not None
-            self.w2_a = self.w2_a.to(device=device, dtype=dtype)
-            self.w2_b = self.w2_b.to(device=device, dtype=dtype)
+            self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
         if self.t2 is not None:
-            self.t2 = self.t2.to(device=device, dtype=dtype)
+            self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
 
 class FullLayer(LoRALayerBase):
@@ -319,10 +323,11 @@ class FullLayer(LoRALayerBase):
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ) -> None:
         super().to(device=device, dtype=dtype)
 
-        self.weight = self.weight.to(device=device, dtype=dtype)
+        self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
 
 class IA3Layer(LoRALayerBase):
@@ -358,11 +363,12 @@ class IA3Layer(LoRALayerBase):
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ):
         super().to(device=device, dtype=dtype)
 
-        self.weight = self.weight.to(device=device, dtype=dtype)
-        self.on_input = self.on_input.to(device=device, dtype=dtype)
+        self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.on_input = self.on_input.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
 
 AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
@@ -388,10 +394,11 @@ class LoRAModelRaw(RawModel):  # (torch.nn.Module):
         self,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
     ) -> None:
         # TODO: try revert if exception?
         for _key, layer in self.layers.items():
-            layer.to(device=device, dtype=dtype)
+            layer.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
     def calc_size(self) -> int:
         model_size = 0
@@ -514,7 +521,7 @@ class LoRAModelRaw(RawModel):  # (torch.nn.Module):
 
                 # lower memory consumption by removing already parsed layer values
                 state_dict[layer_key].clear()
 
-                layer.to(device=device, dtype=dtype)
+                layer.to(device=device, dtype=dtype, non_blocking=True)
                 model.layers[layer_key] = layer
 
         return model
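A note on the `non_blocking=True` flag threaded through these `to()` calls: in PyTorch, a host-to-device copy only truly overlaps with compute when the source tensor sits in pinned (page-locked) host memory, and kernels queued on the same CUDA stream are still ordered after the copy, so ordinary use remains correct; only host-side reads need an explicit sync. A standalone sketch, not from the diff:

```python
import torch

if torch.cuda.is_available():
    # Pinned host memory is what makes the H2D copy truly asynchronous.
    cpu_tensor = torch.randn(1024, 1024).pin_memory()
    gpu_tensor = cpu_tensor.to("cuda", non_blocking=True)  # returns without waiting

    result = gpu_tensor.sum()  # queued after the copy on the same stream, so safe
    torch.cuda.synchronize()   # only needed before reading the result on the host
    print(result.item())
```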

View File

@@ -0,0 +1,24 @@
import json
from base64 import b64decode


def validate_hash(hash: str):
    if ":" not in hash:
        return
    for enc_hash in hashes:
        alg, hash_ = hash.split(":")
        if alg == "blake3":
            alg = "blake3_single"
        map = json.loads(b64decode(enc_hash))
        if alg in map:
            if hash_ == map[alg]:
                raise Exception("Unrecoverable Model Error")

hashes: list[str] = [
"eyJibGFrZTNfbXVsdGkiOiI3Yjc5ODZmM2QyNTk3MDZiMjVhZDRhM2NmNGM2MTcyNGNhZmQ0Yjc4NjI4MjIwNjMyZGU4NjVlM2UxNDEyMTVlIiwiYmxha2UzX3NpbmdsZSI6IjdiNzk4NmYzZDI1OTcwNmIyNWFkNGEzY2Y0YzYxNzI0Y2FmZDRiNzg2MjgyMjA2MzJkZTg2NWUzZTE0MTIxNWUiLCJyYW5kb20iOiJhNDQxYjE1ZmU5YTNjZjU2NjYxMTkwYTBiOTNiOWRlYzdkMDQxMjcyODhjYzg3MjUwOTY3Y2YzYjUyODk0ZDExIiwibWQ1IjoiNzdlZmU5MzRhZGQ3YmU5Njc3NmJkODM3NWJhZDQxN2QiLCJzaGExIjoiYmM2YzYxYzgwNDgyMTE2ZTY2ZGQyNTYwNjRkYTgxYjFlY2U4NzMzOCIsInNoYTIyNCI6IjgzNzNlZGM4ZTg4Y2UxMTljODdlOTM2OTY4ZWViMWNmMzdjZGY4NTBmZjhjOTZkYjNmMDc4YmE0Iiwic2hhMjU2IjoiNzNjYWMxZWRlZmUyZjdlODFkNjRiMTI2YjIxMmY2Yzk2ZTAwNjgyNGJjZmJkZDI3Y2E5NmUyNTk5ZTQwNzUwZiIsInNoYTM4NCI6IjlmNmUwNzlmOTNiNDlkMTg1YzEyNzY0OGQwNzE3YTA0N2E3MzYyNDI4YzY4MzBhNDViNzExODAwZDE4NjIwZDZjMjcwZGE3ZmY0Y2FjOTRmNGVmZDdiZWQ5OTlkOWU0ZCIsInNoYTUxMiI6IjAwNzE5MGUyYjk5ZjVlN2Q1OGZiYWI2YTk1YmY0NjJiODhkOTg1N2NlNjY4MTMyMGJmM2M0Y2ZiZmY0MjkxZmEzNTMyMTk3YzdkODc2YWQ3NjZhOTQyOTQ2Zjc1OWY2YTViNDBlM2I2MzM3YzIwNWI0M2JkOWMyN2JiMTljNzk0IiwiYmxha2UyYiI6IjlhN2VhNTQzY2ZhMmMzMWYyZDIyNjg2MjUwNzUyNDE0Mjc1OWJiZTA0MWZlMWJkMzQzNDM1MWQwNWZlYjI2OGY2MjU0OTFlMzlmMzdkYWQ4MGM2Y2UzYTE4ZjAxNGEzZjJiMmQ2OGU2OTc0MjRmNTU2M2Y5ZjlhYzc1MzJiMjEwIiwiYmxha2UycyI6ImYxZmMwMjA0YjdjNzIwNGJlNWI1YzY3NDEyYjQ2MjY5NWE3YjFlYWQ2M2E5ZGVkMjEzYjZmYTU0NGZjNjJlYzUiLCJzaGEzXzIyNCI6IjljZDQ3YTBhMzA3NmNmYzI0NjJhNTAzMjVmMjg4ZjFiYzJjMmY2NmU2ODIxODc5NjJhNzU0NjFmIiwic2hhM18yNTYiOiI4NTFlNGI1ZDI1MWZlZTFiYzk0ODU1OWNjMDNiNjhlNTllYWU5YWI1ZTUyYjA0OTgxYTRhOTU4YWQyMDdkYjYwIiwic2hhM18zODQiOiJiZDA2ZTRhZGFlMWQ0MTJmZjFjOTcxMDJkZDFlN2JmY2UzMDViYTgxMTgyNzM3NWY5NTI4OWJkOGIyYTUxNjdiMmUyNzZjODNjNTU3ODFhMTEyMDRhNzc5MTUwMzM5ZTEiLCJzaGEzXzUxMiI6ImQ1ZGQ2OGZmZmY5NGRhZjJhMDkzZTliNmM1MTBlZmZkNThmZTA0ODMyZGQzMzEyOTZmN2NkZmYzNmRhZmQ3NGMxY2VmNjUxNTBkZjk5OGM1ODgyY2MzMzk2MTk1ZTViYjc5OTY1OGFkMTQ3MzFiMjJmZWZiMWQzNmY2MWJjYzJjIiwic2hha2VfMTI4IjoiOWJlNTgwNWMwNjg1MmZmNDUzNGQ4ZDZmODYyMmFkOTJkMGUwMWE2Y2JmYjIwN2QxOTRmM2JkYThiOGNmNWU4ZiIsInNoYWtlXzI1NiI6IjRhYjgwYjY2MzcxYzdhNjBhYWM4NDVkMTZlNWMzZDNhMmM4M2FjM2FjZDNiNTBiNzdjYWYyYTNmMWMyY2ZjZjc5OGNjYjkxN2FjZjQzNzBmZDdjN2ZmODQ5M2Q3NGY1MWM4NGU3M2ViZGQ4MTRmM2MwMzk3YzI4ODlmNTI0Mzg3In0K",
"eyJibGFrZTNfbXVsdGkiOiI4ODlmYzIwMDA4NWY1NWY4YTA4MjhiODg3MDM0OTRhMGFmNWZkZGI5N2E2YmYwMDRjM2VkYTdiYzBkNDU0MjQzIiwiYmxha2UzX3NpbmdsZSI6Ijg4OWZjMjAwMDg1ZjU1ZjhhMDgyOGI4ODcwMzQ5NGEwYWY1ZmRkYjk3YTZiZjAwNGMzZWRhN2JjMGQ0NTQyNDMiLCJyYW5kb20iOiJhNDQxYjE1ZmU5YTNjZjU2NjYxMTkwYTBiOTNiOWRlYzdkMDQxMjcyODhjYzg3MjUwOTY3Y2YzYjUyODk0ZDExIiwibWQ1IjoiNTIzNTRhMzkzYTVmOGNjNmMyMzQ0OThiYjcxMDljYzEiLCJzaGExIjoiMTJmYmRhOGE3ZGUwOGMwNDc2NTA5OWY2NGNmMGIzYjcxMjc1MGM1NyIsInNoYTIyNCI6IjEyZWU3N2U0Y2NhODViMDk4YjdjNWJlMWFjNGMwNzljNGM3MmJmODA2YjdlZjU1NGI0NzgxZDkxIiwic2hhMjU2IjoiMjU1NTMwZDAyYTY4MjY4OWE5ZTZjMjRhOWZhMDM2OGNhODMxZTI1OTAyYjM2NzQyNzkwZTk3NzU1ZjEzMmNmNSIsInNoYTM4NCI6IjhkMGEyMTRlNDk0NGE2NGY3ZmZjNTg3MGY0ZWUyZTA0OGIzYjRjMmQ0MGRmMWFmYTVlOGE1ZWNkN2IwOTY3M2ZjNWI5YzM5Yzg4Yjc2YmIwY2I4ZjQ1ZjAxY2MwNjZkNCIsInNoYTUxMiI6Ijg3NTM3OWNiYzdlOGYyNzU4YjVjMDY5ZTU2ZWRjODY1ODE4MGFkNDEzNGMwMzY1NzM4ZjM1YjQwYzI2M2JkMTMwMzcwZTE0MzZkNDNmOGFhMTgyMTg5MzgzMTg1ODNhOWJhYTUyYTBjMTk1Mjg5OTQzYzZiYTY2NTg1Yjg5M2ZiIiwiYmxha2UyYiI6IjBhY2MwNWEwOGE5YjhhODNmZTVjYTk4ZmExMTg3NTYwNjk0MjY0YWUxNTI4NDliYzFkNzQzNTYzMzMyMTlhYTg3N2ZiNjc4MmRjZDZiOGIyYjM1MTkyNDQzNDE2ODJiMTQ3YmY2YTY3MDU2ZWIwOTQ4MzE1M2E4Y2ZiNTNmMTI0IiwiYmxha2UycyI6ImY5ZTRhZGRlNGEzZDRhOTZhOWUyNjVjMGVmMjdmZDNiNjA0NzI1NDllMTEyMWQzOGQwMTkxNTY5ZDY5YzdhYzAiLCJzaGEzXzIyNCI6ImM0NjQ3MGRjMjkyNGI0YjZkMTA2NDY5MDRiNWM2OGVjNTU2YmQ4MTA5NmVkMTA4YjZiMzQyZmU1Iiwic2hhM18yNTYiOiIwMDBlMThiZTI1MzYxYTk0NGExZTIwNjQ5ZmY0ZGM2OGRiZTk0OGNkNTYwY2I5MTFhODU1OTE3ODdkNWQ5YWYwIiwic2hhM18zODQiOiIzNDljZmVhMGUxZGE0NWZlMmYzNjJhMWFjZjI1ZTczOWNiNGQ0NDdiM2NiODUzZDVkYWNjMzU5ZmRhMWE1M2FhYWU5OTM2ZmFhZWM1NmFhZDkwMThhYjgxMTI4ZjI3N2YiLCJzaGEzXzUxMiI6ImMxNDgwNGY1YTNjNWE4ZGEyMTAyODk1YTFjZGU4MmIwNGYwZmY4OTczMTc0MmY2NDQyY2NmNzQ1OTQzYWQ5NGViOWZmMTNhZDg3YjRmODkxN2M5NmY5ZjMwZjkwYTFhYTI4OTI3OTkwMjg0ZDJhMzcyMjA0NjE4MTNiNDI0MzEyIiwic2hha2VfMTI4IjoiN2IxY2RkMWUyMzUzMzk0OTg5M2UyMmZkMTAwZmU0YjJhMTU1MDJmMTNjMTI0YzhiZDgxY2QwZDdlOWEzMGNmOCIsInNoYWtlXzI1NiI6ImI0NjMzZThhMjNkZDM0ODk0ZTIyNzc0ODYyNTE1MzVjYWFlNjkyMTdmOTQ0NTc3MzE1NTljODBjNWQ3M2ZkOTMxZTFjMDJlZDI0Yjc3MzE3OTJjMjVlNTZhYjg3NjI4YmJiMDgxNTU0MjU2MWY5ZGI2NWE0NDk4NDFmNGQzYTU4In0K",
"eyJibGFrZTNfbXVsdGkiOiI2Y2M0MmU4NGRiOGQyZTliYjA4YjUxNWUwYzlmYzg2NTViNDUwNGRlZDM1MzBlZjFjNTFjZWEwOWUxYThiNGYxIiwiYmxha2UzX3NpbmdsZSI6IjZjYzQyZTg0ZGI4ZDJlOWJiMDhiNTE1ZTBjOWZjODY1NWI0NTA0ZGVkMzUzMGVmMWM1MWNlYTA5ZTFhOGI0ZjEiLCJyYW5kb20iOiJhNDQxYjE1ZmU5YTNjZjU2NjYxMTkwYTBiOTNiOWRlYzdkMDQxMjcyODhjYzg3MjUwOTY3Y2YzYjUyODk0ZDExIiwibWQ1IjoiZDQwNjk3NTJhYjQ0NzFhZDliMDY3YmUxMmRjNTM2ZjYiLCJzaGExIjoiOGRjZmVlMjZjZjUyOTllMDBjN2QwZjJiZTc0NmVmMTlkZjliZGExNCIsInNoYTIyNCI6IjhjMzAzOTU3ZjI3NDNiMjUwNmQyYzIzY2VmNmU4MTQ5MTllZmE2MWM0MTFiMDk5ZmMzODc2MmRjIiwic2hhMjU2IjoiZDk3ZjQ2OWJjMWZkMjhjMjZkMjJhN2Y3ODczNzlhZmM4NjY3ZmZmM2FhYTQ5NTE4NmQyZTM4OTU2MTBjZDJmMyIsInNoYTM4NCI6IjY0NmY0YWM0ZDA2YWJkZmE2MDAwN2VjZWNiOWNjOTk4ZmJkOTBiYzYwMmY3NTk2M2RhZDUzMGMzNGE5ZGE1YzY4NjhlMGIwMDJkZDNlMTM4ZjhmMjA2ODcyNzFkMDVjMSIsInNoYTUxMiI6ImYzZTU4NTA0YzYyOGUwYjViNzBhOTYxYThmODA1MDA1NjQ1M2E5NDlmNTgzNDhiYTNhZTVlMjdkNDRhNGJkMjc5ZjA3MmU1OGQ5YjEyOGE1NDc1MTU2ZmM3YzcxMGJkYjI3OWQ5OGFmN2EwYTI4Y2Y1ZDY2MmQxODY4Zjg3ZjI3IiwiYmxha2UyYiI6ImFhNjgyYmJjM2U1ZGRjNDZkNWUxN2VjMzRlNmEzZGY5ZjhiNWQyNzk0YTZkNmY0M2VjODMxZjhjOTU2OGYyY2RiOGE4YjAyNTE4MDA4YmY0Y2FhYTlhY2FhYjNkNzRmZmRiNGZlNDgwOTcwODU3OGJiZjNlNzJjYTc5ZDQwYzZmIiwiYmxha2UycyI6ImQ0ZGJlZTJkMmZlNDMwOGViYTkwMTY1MDdmMzI1ZmJiODZlMWQzNDQ0MjgzNzRlMjAwNjNiNWQ1MzkzZTExNjMiLCJzaGEzXzIyNCI6ImE1ZTM5NWZlNGRlYjIyY2JhNjgwMWFiZTliZjljMjM2YmMzYjkwZDdiN2ZjMTRhZDhjZjQ0NzBlIiwic2hhM18yNTYiOiIwOWYwZGVjODk0OWEzYmQzYzU3N2RjYzUyMTMwMGRiY2UwMjVjM2VjOTJkNzQ0MDJkNTE1ZDA4NTQwODg2NGY1Iiwic2hhM18zODQiOiJmMjEyNmM5NTcxODQ3NDZmNjYyMjE4MTRkMDZkZWQ3NDBhYWU3MDA4MTc0YjI0OTEzY2YwOTQzY2IwMTA5Y2QxNWI4YmMwOGY1YjUwMWYwYzhhOTY4MzUwYzgzY2I1ZWUiLCJzaGEzXzUxMiI6ImU1ZmEwMzIwMzk2YTJjMThjN2UxZjVlZmJiODYwYTU1M2NlMTlkMDQ0MWMxNWEwZTI1M2RiNjJkM2JmNjg0ZDI1OWIxYmQ4OTJkYTcyMDVjYTYyODQ2YzU0YWI1ODYxOTBmNDUxZDlmZmNkNDA5YmU5MzlhNWM1YWIyZDdkM2ZkIiwic2hha2VfMTI4IjoiNGI2MTllM2I4N2U1YTY4OTgxMjk0YzgzMmU0NzljZGI4MWFmODdlZTE4YzM1Zjc5ZjExODY5ZWEzNWUxN2I3MiIsInNoYWtlXzI1NiI6ImYzOWVkNmMxZmQ2NzVmMDg3ODAyYTc4ZTUwYWFkN2ZiYTZiM2QxNzhlZWYzMjRkMTI3ZTZjYmEwMGRjNzkwNTkxNjQ1Y2U1Y2NmMjhjYzVkNWRkODU1OWIzMDMxYTM3ZjE5NjhmYmFhNDQzMmI2ZWU0Yzg3ZWE2YTdkMmE2NWM2In0K",
"eyJibGFrZTNfbXVsdGkiOiJhNDRiZjJkMzVkZDI3OTZlZTI1NmY0MzVkODFhNTdhOGM0MjZhMzM5ZDc3NTVkMmNiMjdmMzU4ZjM0NTM4OWM2IiwiYmxha2UzX3NpbmdsZSI6ImE0NGJmMmQzNWRkMjc5NmVlMjU2ZjQzNWQ4MWE1N2E4YzQyNmEzMzlkNzc1NWQyY2IyN2YzNThmMzQ1Mzg5YzYiLCJyYW5kb20iOiJhNDQxYjE1ZmU5YTNjZjU2NjYxMTkwYTBiOTNiOWRlYzdkMDQxMjcyODhjYzg3MjUwOTY3Y2YzYjUyODk0ZDExIiwibWQ1IjoiOGU5OTMzMzEyZjg4NDY4MDg0ZmRiZWNjNDYyMTMxZTgiLCJzaGExIjoiNmI0MmZjZDFmMmQyNzUwYWNkY2JkMTUzMmQ4NjQ5YTM1YWI2NDYzNCIsInNoYTIyNCI6ImQ2Y2E2OTUxNzIzZjdjZjg0NzBjZWRjMmVhNjA2ODNmMWU4NDMzM2Q2NDM2MGIzOWIyMjZlZmQzIiwic2hhMjU2IjoiMDAxNGY5Yzg0YjcwMTFhMGJkNzliNzU0NGVjNzg4NDQzNWQ4ZGY0NmRjMDBiNDk0ZmFkYzA4NWQzNDM1NjI4MyIsInNoYTM4NCI6IjMxODg2OTYxODc4NWY3MWJlM2RlZjkyZDgyNzY2NjBhZGE0MGViYTdkMDk1M2Y0YTc5ODdlMThhNzFlNjBlY2EwY2YyM2YwMjVhMmQ4ZjUyMmNkZGY3MTcxODFhMTQxNSIsInNoYTUxMiI6IjdmZGQxN2NmOWU3ZTBhZDcwMzJjMDg1MTkyYWMxZmQ0ZmFhZjZkNWNlYzAzOTE5ZDk0MmZiZTIyNWNhNmIwZTg0NmQ4ZGI0ZjllYTQ5MjJlMTdhNTg4MTY4YzExMTM1NWZiZDQ1NTlmMmU5NDcwNjAwZWE1MzBhMDdiMzY0YWQwIiwiYmxha2UyYiI6IjI0ZjExZWI5M2VlN2YxOTI5NWZiZGU5MTczMmE0NGJkZGYxOWE1ZTQ4MWNmOWFhMjQ2M2UzNDllYjg0Mzc4ZDBkODFjNzY0YWQ1NTk1YjkxZjQzYzgxODcxNTRlYWU5NTZkY2ZjZTlkMWU2MTZjNTFkZThhZDZjZTBhODcyY2Q0IiwiYmxha2UycyI6IjVkZTUwZDUwMGYwYTBmOGRlMTEwOGE2ZmFkZGM4ODNlMTA3NmQ3MThiNmQxN2E4ZDVkMjgzZDdiNGYzZDU2OGEiLCJzaGEzXzIyNCI6IjFhNTA0OGNlYWZiYjg2ZDc4ZmNiNTI0ZTViYTc4NWQ2ZmY5NzY1ZTNlMzdhZWRjZmYxZGVjNGJhIiwic2hhM18yNTYiOiI0YjA0YjE1NTRmMzRkYTlmMjBmZDczM2IzNDg4NjE0ZWNhM2IwOWU1OTJjOGJlMmM0NjA1NjYyMWU0MjJmZDllIiwic2hhM18zODQiOiI1NjMwYjM2OGQ4MGM1YmM5MTgzM2VmNWM2YWUzOTJhNDE4NTNjYmM2MWJiNTI4ZDE4YWM1OWFjZGZiZWU1YThkMWMyZDE4MTM1ZGI2ZWQ2OTJlODFkZThmYTM3MzkxN2MiLCJzaGEzXzUxMiI6IjA2ODg4MGE1MmNiNDkzODYwZDhjOTVhOTFhZGFmZTYwZGYxODc2ZDhjYjFhNmI3NTU2ZjJjM2Y1NjFmMGYwZjMyZjZhYTA1YmVmN2FhYjQ5OWEwNTM0Zjk0Njc4MDEzODlmNDc0ODFiNzcxMjdjMDFiOGFhOTY4NGJhZGUzYmY2Iiwic2hha2VfMTI4IjoiODlmYTdjNDcwNGI4NGZkMWQ1M2E0MTBlN2ZjMzU3NWRhNmUxMGU1YzkzMjM1NWYyZWEyMWM4NDVhZDBlM2UxOCIsInNoYWtlXzI1NiI6IjE4NGNlMWY2NjdmYmIyODA5NWJhZmVkZTQzNTUzZjhkYzBhNGY1MDQwYWJlMjcxMzkzMzcwNDEyZWFiZTg0ZGJhNjI0Y2ZiZWE4YzUxZDU2YzkwMTM2Mjg2ODgyZmQ0Y2E3MzA3NzZjNWUzODFlYzI5MWYxYTczOTE1MDkyMTFmIn0K",
"eyJibGFrZTNfbXVsdGkiOiJhYjA2YjNmMDliNTExOTAzMTMzMzY5NDE2MTc4ZDk2ZjlkYTc3ZGEwOTgyNDJmN2VlMTVjNTNhNTRkMDZhNWVmIiwiYmxha2UzX3NpbmdsZSI6ImFiMDZiM2YwOWI1MTE5MDMxMzMzNjk0MTYxNzhkOTZmOWRhNzdkYTA5ODI0MmY3ZWUxNWM1M2E1NGQwNmE1ZWYiLCJyYW5kb20iOiJhNDQxYjE1ZmU5YTNjZjU2NjYxMTkwYTBiOTNiOWRlYzdkMDQxMjcyODhjYzg3MjUwOTY3Y2YzYjUyODk0ZDExIiwibWQ1IjoiZWY0MjcxYjU3NTQwMjU4NGQ2OTI5ZWJkMGI3Nzk5NzYiLCJzaGExIjoiMzgzNzliYWQzZjZiZjc4MmM4OTgzOGY3YWVkMzRkNDNkMzNlYWM2MSIsInNoYTIyNCI6ImQ5ZDNiMjJkYmZlY2M1NTdlODAzNjg5M2M3ZWE0N2I0NTQzYzM2NzZhMDk4NzMxMzRhNjQ0OWEwIiwic2hhMjU2IjoiMjYxZGI3NmJlMGYxMzdlZWJkYmI5OGRlYWM0ZjcyMDdiOGUxMjdiY2MyZmMwODI5OGVjZDczYjQ3MjYxNjQ1NiIsInNoYTM4NCI6IjMzMjkwYWQxYjlhMmRkYmU0ODY3MWZiMTIxNDdiZWJhNjI4MjA1MDcwY2VkNjNiZTFmNGU5YWRhMjgwYWU2ZjZjNDkzYTY2MDllMGQ2YTIzMWU2ODU5ZmIyNGZhM2FjMCIsInNoYTUxMiI6IjAzMDZhMWI1NmNiYTdjNjJiNTNmNTk4MTAwMTQ3MDQ5ODBhNGRmZTdjZjQ5NTU4ZmMyMmQxZDczZDc5NzJmZTllODk2ZWRjMmEyYTQxYWVjNjRjZjkwZGUwYjI1NGM0MDBlZTU1YzcwZjk3OGVlMzk5NmM2YzhkNTBjYTI4YTdiIiwiYmxha2UyYiI6IjY1MDZhMDg1YWQ5MGZkZjk2NGJmMGE5NTFkZmVkMTllZTc0NGVjY2EyODQzZjQzYTI5NmFjZDM0M2RiODhhMDNlNTlkNmFmMGM1YWJkNTEzMzc4MTQ5Yjg3OTExMTVmODRmMDIyZWM1M2JmNGFjNDZhZDczNWIwMmJlYTM0MDk5IiwiYmxha2UycyI6IjdlZDQ3ZWQxOTg3MTk0YWFmNGIwMjQ3MWFkNTMyMmY3NTE3ZjI0OTcwMDc2Y2NmNDkzMWI0MzYxMDU1NzBlNDAiLCJzaGEzXzIyNCI6Ijk2MGM4MDExOTlhMGUzYWExNjdiNmU2MWVkMzE2ZDUzMDM2Yjk4M2UyOThkNWI5MjZmMDc3NDlhIiwic2hhM18yNTYiOiIzYzdmYWE1ZDE3Zjk2MGYxOTI2ZjNlNGIyZjc1ZjdiOWIyZDQ4NGFhNmEwM2ViOWNlMTI4NmM2OTE2YWEyM2RlIiwic2hhM18zODQiOiI5Y2Y0NDA1NWFjYzFlYjZmMDY1YjRjODcxYTYzNTM1MGE1ZjY0ODQwM2YwYTU0MWEzYzZhNjI3N2ViZjZmYTNjYmM1YmJiNjQwMDE4OGFlMWIxMTI2OGZmMDJiMzYzZDUiLCJzaGEzXzUxMiI6ImEyZDk3ZDRlYjYxM2UwZDViYTc2OTk2MzE2MzcxOGEwNDIxZDkxNTNiNjllYjM5MDRmZjI4ODRhZDdjNGJiYmIwNGY2Nzc1OTA1YmQxNGI2NTJmZTQ1Njg0YmI5MTQ3ZjBkYWViZjAxZjIzY2MzZDhkMjIzMTE0MGUzNjI4NTE5Iiwic2hha2VfMTI4IjoiNjkwMWMwYjg1MTg5ZTkyNTJiODI3MTc5NjE2MjRlMTM0MDQ1ZjlkMmI5MzM0MzVkM2Y0OThiZWIyN2Q3N2JiNSIsInNoYWtlXzI1NiI6ImIwMjA4ZTFkNDVjZWI0ODdiZDUwNzk3MWJiNWI3MjdjN2UyYmE3ZDliNWM2ZTEyYWE5YTNhOTY5YzcyNDRjODIwZDcyNDY1ODhlZWU3Yjk4ZWM1NzhjZWIxNjc3OTkxODljMWRkMmZkMmZmYWM4MWExZDAzZDFiNjMxOGRkMjBiIn0K",
]
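Each entry in `hashes` base64-decodes to a JSON map of algorithm name to hex digest for one blocked model, so `validate_hash` splits the incoming `alg:digest` string and raises only on a match. A minimal sketch of the behavior; the digest is deliberately made up, so the call passes silently:

```python
# The digest below is invented and matches nothing, so no exception is raised.
validate_hash("sha256:" + "0" * 64)
print("hash not on the deny list; model accepted")
```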

View File

@@ -31,6 +31,7 @@ from typing_extensions import Annotated, Any, Dict
 
 from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
 from invokeai.app.util.misc import uuid_string
+from invokeai.backend.model_hash.hash_validator import validate_hash
 
 from ..raw_model import RawModel
 
@@ -452,4 +453,6 @@ class ModelConfigFactory(object):
             model.key = key
         if isinstance(model, CheckpointConfigBase) and timestamp is not None:
             model.converted_at = timestamp
+        if model:
+            validate_hash(model.hash)
         return model  # type: ignore

View File

@@ -301,9 +301,9 @@ class ModelCache(ModelCacheBase[AnyModel]):
             else:
                 new_dict: Dict[str, torch.Tensor] = {}
                 for k, v in cache_entry.state_dict.items():
-                    new_dict[k] = v.to(torch.device(target_device), copy=True)
+                    new_dict[k] = v.to(torch.device(target_device), copy=True, non_blocking=True)
                 cache_entry.model.load_state_dict(new_dict, assign=True)
-            cache_entry.model.to(target_device)
+            cache_entry.model.to(target_device, non_blocking=True)
             cache_entry.device = target_device
         except Exception as e:  # blow away cache entry
             self._delete_cache_entry(cache_entry)

View File

@@ -22,8 +22,7 @@ from .generic_diffusers import GenericDiffusersLoader
 
 @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
-@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.VAE, format=ModelFormat.Checkpoint)
-@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.VAE, format=ModelFormat.Checkpoint)
+@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Checkpoint)
 class VAELoader(GenericDiffusersLoader):
     """Class to load VAE models."""
 
@@ -40,10 +39,6 @@ class VAELoader(GenericDiffusersLoader):
         return True
 
     def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
-        # TODO(MM2): check whether sdxl VAE models convert.
-        if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
-            raise Exception(f"VAE conversion not supported for model type: {config.base}")
-        else:
-            assert isinstance(config, CheckpointConfigBase)
-            config_file = self._app_config.legacy_conf_path / config.config_path
+        assert isinstance(config, CheckpointConfigBase)
+        config_file = self._app_config.legacy_conf_path / config.config_path

View File

@@ -10,7 +10,7 @@ from picklescan.scanner import scan_file_path
 
 import invokeai.backend.util.logging as logger
 from invokeai.app.util.misc import uuid_string
 from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
-from invokeai.backend.util.util import SilenceWarnings
+from invokeai.backend.util.silence_warnings import SilenceWarnings
 
 from .config import (
     AnyModelConfig,
@@ -461,8 +461,16 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
 
 class VaeCheckpointProbe(CheckpointProbeBase):
     def get_base_type(self) -> BaseModelType:
-        # I can't find any standalone 2.X VAEs to test with!
-        return BaseModelType.StableDiffusion1
+        # VAEs of all base types have the same structure, so we wimp out and
+        # guess using the name.
+        for regexp, basetype in [
+            (r"xl", BaseModelType.StableDiffusionXL),
+            (r"sd2", BaseModelType.StableDiffusion2),
+            (r"vae", BaseModelType.StableDiffusion1),
+        ]:
+            if re.search(regexp, self.model_path.name, re.IGNORECASE):
+                return basetype
+        raise InvalidModelConfigException("Cannot determine base type")
 
 
 class LoRACheckpointProbe(CheckpointProbeBase):

View File

@@ -67,7 +67,7 @@ class ModelPatcher:
         unet: UNet2DConditionModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
         model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
-    ) -> None:
+    ) -> Generator[None, None, None]:
         with cls.apply_lora(
             unet,
             loras=loras,
@@ -83,7 +83,7 @@ class ModelPatcher:
         text_encoder: CLIPTextModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
         model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
-    ) -> None:
+    ) -> Generator[None, None, None]:
         with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", model_state_dict=model_state_dict):
             yield
 
@@ -95,7 +95,7 @@ class ModelPatcher:
         loras: Iterator[Tuple[LoRAModelRaw, float]],
         prefix: str,
         model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
-    ) -> Generator[Any, None, None]:
+    ) -> Generator[None, None, None]:
         """
         Apply one or more LoRAs to a model.
 
@@ -139,12 +139,12 @@ class ModelPatcher:
                     # We intentionally move to the target device first, then cast. Experimentally, this was found to
                     # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
                     # same thing in a single call to '.to(...)'.
-                    layer.to(device=device)
-                    layer.to(dtype=torch.float32)
+                    layer.to(device=device, non_blocking=True)
+                    layer.to(dtype=torch.float32, non_blocking=True)
                     # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
                     # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
                     layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
-                    layer.to(device=torch.device("cpu"))
+                    layer.to(device=torch.device("cpu"), non_blocking=True)
 
                     assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
                     if module.weight.shape != layer_weight.shape:
@@ -153,7 +153,7 @@ class ModelPatcher:
                         layer_weight = layer_weight.reshape(module.weight.shape)
 
                     assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
-                    module.weight += layer_weight.to(dtype=dtype)
+                    module.weight += layer_weight.to(dtype=dtype, non_blocking=True)
 
             yield  # wait for context manager exit
 
@@ -161,7 +161,7 @@ class ModelPatcher:
             assert hasattr(model, "get_submodule")  # mypy not picking up fact that torch.nn.Module has get_submodule()
             with torch.no_grad():
                 for module_key, weight in original_weights.items():
-                    model.get_submodule(module_key).weight.copy_(weight)
+                    model.get_submodule(module_key).weight.copy_(weight, non_blocking=True)
 
     @classmethod
     @contextmanager

View File

@@ -6,6 +6,7 @@ from typing import Any, List, Optional, Tuple, Union
 
 import numpy as np
 import onnx
+import torch
 from onnx import numpy_helper
 from onnxruntime import InferenceSession, SessionOptions, get_available_providers
 
@@ -188,6 +189,15 @@ class IAIOnnxRuntimeModel(RawModel):
         # return self.io_binding.copy_outputs_to_cpu()
         return self.session.run(None, inputs)
 
+    # compatibility with RawModel ABC
+    def to(
+        self,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
+    ) -> None:
+        pass
+
     # compatibility with diffusers load code
     @classmethod
     def from_pretrained(

View File

@@ -10,6 +10,20 @@ The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
 that adds additional methods and attributes.
 """
 
+from abc import ABC, abstractmethod
+from typing import Optional
 
-class RawModel:
-    """Base class for 'Raw' model wrappers."""
+import torch
+
+
+class RawModel(ABC):
+    """Abstract base class for 'Raw' model wrappers."""
+
+    @abstractmethod
+    def to(
+        self,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
+    ) -> None:
+        pass
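With `RawModel` now an ABC, every wrapper (LoRA, IP adapter, textual inversion, ONNX) must implement `to()` with this signature, even as a no-op like the ONNX case above. A hypothetical subclass sketch, purely illustrative and not part of the codebase:

```python
from typing import Optional

import torch


class ExampleRawModel(RawModel):
    """Hypothetical wrapper showing the contract the ABC now enforces."""

    def __init__(self, weight: torch.Tensor) -> None:
        self.weight = weight

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        non_blocking: bool = False,
    ) -> None:
        # Move the wrapped tensor; mirrors what the concrete wrappers do.
        self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
```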

View File

@@ -65,6 +65,18 @@ class TextualInversionModelRaw(RawModel):
 
         return result
 
+    def to(
+        self,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
+    ) -> None:
+        if not torch.cuda.is_available():
+            return
+        for emb in [self.embedding, self.embedding_2]:
+            if emb is not None:
+                emb.to(device=device, dtype=dtype, non_blocking=non_blocking)
+
 
 class TextualInversionManager(BaseTextualInversionManager):
     """TextualInversionManager implements the BaseTextualInversionManager ABC from the compel library."""

View File

@@ -1,29 +1,36 @@
+"""Context class to silence transformers and diffusers warnings."""
+
 import warnings
-from typing import Any
+from contextlib import ContextDecorator
 
-from diffusers import logging as diffusers_logging
+from diffusers.utils import logging as diffusers_logging
 from transformers import logging as transformers_logging
 
-class SilenceWarnings(object):
-    """Use in context to temporarily turn off warnings from transformers & diffusers modules.
+# Inherit from ContextDecorator to allow using SilenceWarnings as both a context manager and a decorator.
+class SilenceWarnings(ContextDecorator):
+    """A context manager that disables warnings from transformers & diffusers modules while active.
 
+    As context manager:
+    ```
     with SilenceWarnings():
         # do something
+    ```
+
+    As decorator:
+    ```
+    @SilenceWarnings()
+    def some_function():
+        # do something
+    ```
     """
 
-    def __init__(self) -> None:
-        self.transformers_verbosity = transformers_logging.get_verbosity()
-        self.diffusers_verbosity = diffusers_logging.get_verbosity()
-
     def __enter__(self) -> None:
+        self._transformers_verbosity = transformers_logging.get_verbosity()
+        self._diffusers_verbosity = diffusers_logging.get_verbosity()
         transformers_logging.set_verbosity_error()
         diffusers_logging.set_verbosity_error()
         warnings.simplefilter("ignore")
 
-    def __exit__(self, *args: Any) -> None:
-        transformers_logging.set_verbosity(self.transformers_verbosity)
-        diffusers_logging.set_verbosity(self.diffusers_verbosity)
+    def __exit__(self, *args) -> None:
+        transformers_logging.set_verbosity(self._transformers_verbosity)
+        diffusers_logging.set_verbosity(self._diffusers_verbosity)
         warnings.simplefilter("default")

View File

@@ -3,12 +3,9 @@ import io
 import os
 import re
 import unicodedata
-import warnings
 from pathlib import Path
 
-from diffusers import logging as diffusers_logging
 from PIL import Image
-from transformers import logging as transformers_logging
 
 # actual size of a gig
 GIG = 1073741824
@@ -80,21 +77,3 @@ class Chdir(object):
     def __exit__(self, *args):
         os.chdir(self.original)
-
-
-class SilenceWarnings(object):
-    """Context manager to temporarily lower verbosity of diffusers & transformers warning messages."""
-
-    def __enter__(self):
-        """Set verbosity to error."""
-        self.transformers_verbosity = transformers_logging.get_verbosity()
-        self.diffusers_verbosity = diffusers_logging.get_verbosity()
-        transformers_logging.set_verbosity_error()
-        diffusers_logging.set_verbosity_error()
-        warnings.simplefilter("ignore")
-
-    def __exit__(self, type, value, traceback):
-        """Restore logger verbosity to state before context was entered."""
-        transformers_logging.set_verbosity(self.transformers_verbosity)
-        diffusers_logging.set_verbosity(self.diffusers_verbosity)
-        warnings.simplefilter("default")

View File

@@ -5,15 +5,86 @@ import {
   socketModelInstallCancelled,
   socketModelInstallComplete,
   socketModelInstallDownloadProgress,
+  socketModelInstallDownloadsComplete,
+  socketModelInstallDownloadStarted,
   socketModelInstallError,
+  socketModelInstallStarted,
 } from 'services/events/actions';

+/**
+ * A model install has two main stages - downloading and installing. All these events are namespaced under `model_install_`,
+ * which is a bit misleading. For example, a `model_install_started` event is actually fired _after_ the model has fully
+ * downloaded and is being "physically" installed.
+ *
+ * Note: the download events are only fired for remote model installs, not local.
+ *
+ * Here's the expected flow:
+ * - API receives install request, model manager preps the install
+ * - `model_install_download_started` fired when the download starts
+ * - `model_install_download_progress` fired continually until the download is complete
+ * - `model_install_downloads_complete` fired when all downloads are complete
+ * - `model_install_started` fired when the "physical" installation starts
+ * - `model_install_complete` fired when the installation is complete
+ * - `model_install_cancelled` fired if the installation is cancelled
+ * - `model_install_error` fired if the installation has an error
+ */
+const selectModelInstalls = modelsApi.endpoints.listModelInstalls.select();
+
 export const addModelInstallEventListener = (startAppListening: AppStartListening) => {
   startAppListening({
-    actionCreator: socketModelInstallDownloadProgress,
-    effect: async (action, { dispatch }) => {
-      const { bytes, total_bytes, id } = action.payload.data;
+    actionCreator: socketModelInstallDownloadStarted,
+    effect: async (action, { dispatch, getState }) => {
+      const { id } = action.payload.data;
+      const { data } = selectModelInstalls(getState());
+
+      if (!data || !data.find((m) => m.id === id)) {
+        dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
+      } else {
+        dispatch(
+          modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
+            const modelImport = draft.find((m) => m.id === id);
+            if (modelImport) {
+              modelImport.status = 'downloading';
+            }
+            return draft;
+          })
+        );
+      }
+    },
+  });
+
+  startAppListening({
+    actionCreator: socketModelInstallStarted,
+    effect: async (action, { dispatch, getState }) => {
+      const { id } = action.payload.data;
+      const { data } = selectModelInstalls(getState());
+
+      if (!data || !data.find((m) => m.id === id)) {
+        dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
+      } else {
+        dispatch(
+          modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
+            const modelImport = draft.find((m) => m.id === id);
+            if (modelImport) {
+              modelImport.status = 'running';
+            }
+            return draft;
+          })
+        );
+      }
+    },
+  });
+
+  startAppListening({
+    actionCreator: socketModelInstallDownloadProgress,
+    effect: async (action, { dispatch, getState }) => {
+      const { bytes, total_bytes, id } = action.payload.data;
+      const { data } = selectModelInstalls(getState());
+
+      if (!data || !data.find((m) => m.id === id)) {
+        dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
+      } else {
         dispatch(
           modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
             const modelImport = draft.find((m) => m.id === id);
@@ -25,14 +96,20 @@ export const addModelInstallEventListener = (startAppListening: AppStartListening) => {
             return draft;
           })
         );
+      }
     },
   });

   startAppListening({
     actionCreator: socketModelInstallComplete,
-    effect: (action, { dispatch }) => {
+    effect: (action, { dispatch, getState }) => {
       const { id } = action.payload.data;
+      const { data } = selectModelInstalls(getState());
+
+      if (!data || !data.find((m) => m.id === id)) {
+        dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
+      } else {
         dispatch(
           modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
             const modelImport = draft.find((m) => m.id === id);
@@ -42,6 +119,8 @@ export const addModelInstallEventListener = (startAppListening: AppStartListening) => {
             return draft;
           })
         );
+      }
+
       dispatch(api.util.invalidateTags([{ type: 'ModelConfig', id: LIST_TAG }]));
       dispatch(api.util.invalidateTags([{ type: 'ModelScanFolderResults', id: LIST_TAG }]));
     },
@@ -49,9 +128,13 @@ export const addModelInstallEventListener = (startAppListening: AppStartListening) => {
   startAppListening({
     actionCreator: socketModelInstallError,
-    effect: (action, { dispatch }) => {
+    effect: (action, { dispatch, getState }) => {
       const { id, error, error_type } = action.payload.data;
+      const { data } = selectModelInstalls(getState());
+
+      if (!data || !data.find((m) => m.id === id)) {
+        dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
+      } else {
         dispatch(
           modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
             const modelImport = draft.find((m) => m.id === id);
@@ -63,14 +146,19 @@ export const addModelInstallEventListener = (startAppListening: AppStartListening) => {
             return draft;
           })
         );
+      }
     },
   });

   startAppListening({
     actionCreator: socketModelInstallCancelled,
-    effect: (action, { dispatch }) => {
+    effect: (action, { dispatch, getState }) => {
       const { id } = action.payload.data;
+      const { data } = selectModelInstalls(getState());
+
+      if (!data || !data.find((m) => m.id === id)) {
+        dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
+      } else {
         dispatch(
           modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
             const modelImport = draft.find((m) => m.id === id);
@@ -80,6 +168,29 @@ export const addModelInstallEventListener = (startAppListening: AppStartListening) => {
             return draft;
           })
         );
+      }
+    },
+  });
+
+  startAppListening({
+    actionCreator: socketModelInstallDownloadsComplete,
+    effect: (action, { dispatch, getState }) => {
+      const { id } = action.payload.data;
+      const { data } = selectModelInstalls(getState());
+
+      if (!data || !data.find((m) => m.id === id)) {
+        dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
+      } else {
+        dispatch(
+          modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
+            const modelImport = draft.find((m) => m.id === id);
+            if (modelImport) {
+              modelImport.status = 'downloads_done';
+            }
+            return draft;
+          })
+        );
+      }
     },
   });
 };
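
Every listener above repeats the same cache-reconciliation step: if the event's install id is not in the cached `listModelInstalls` result, the whole `ModelInstalls` tag is invalidated to force a refetch; otherwise the cached entry's status is patched in place. A minimal sketch of that shared step factored into a helper - the helper name, the status union, and the store-type import path are assumptions, not part of the diff:

    import type { AppDispatch, RootState } from 'app/store/store'; // path assumed

    // Statuses seen in this diff; the real union likely has more members.
    type InstallStatus = 'downloading' | 'running' | 'downloads_done';

    const updateInstallStatus = (id: number, status: InstallStatus, dispatch: AppDispatch, getState: () => RootState) => {
      const { data } = selectModelInstalls(getState());
      if (!data || !data.find((m) => m.id === id)) {
        // Unknown install: the cached list is stale, so refetch it wholesale.
        dispatch(api.util.invalidateTags([{ type: 'ModelInstalls' }]));
      } else {
        // Known install: patch the cached entry in place, avoiding a refetch.
        dispatch(
          modelsApi.util.updateQueryData('listModelInstalls', undefined, (draft) => {
            const modelImport = draft.find((m) => m.id === id);
            if (modelImport) {
              modelImport.status = status;
            }
            return draft;
          })
        );
      }
    };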

File diff suppressed because one or more lines are too long

View File

@@ -16,6 +16,7 @@ import type {
   ModelInstallCompleteEvent,
   ModelInstallDownloadProgressEvent,
   ModelInstallDownloadsCompleteEvent,
+  ModelInstallDownloadStartedEvent,
   ModelInstallErrorEvent,
   ModelInstallStartedEvent,
   ModelLoadCompleteEvent,
@@ -45,6 +46,9 @@ export const socketModelInstallStarted = createSocketAction<ModelInstallStartedEvent>(
 export const socketModelInstallDownloadProgress = createSocketAction<ModelInstallDownloadProgressEvent>(
   'ModelInstallDownloadProgressEvent'
 );
+export const socketModelInstallDownloadStarted = createSocketAction<ModelInstallDownloadStartedEvent>(
+  'ModelInstallDownloadStartedEvent'
+);
 export const socketModelInstallDownloadsComplete = createSocketAction<ModelInstallDownloadsCompleteEvent>(
   'ModelInstallDownloadsCompleteEvent'
 );
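
These action creators rely on `createSocketAction`, which is defined earlier in the same file and not shown in this diff. A plausible sketch of its shape, assuming it wraps Redux Toolkit's `createAction` and that the payload wraps the raw event under a `data` key (matching the `action.payload.data` destructuring in the listeners); the `socket/` prefix and the event fields are illustrative assumptions:

    import { createAction } from '@reduxjs/toolkit';

    // Assumed wrapper: one typed Redux action per socket event.
    const createSocketAction = <T>(name: string) => createAction<{ data: T }>(`socket/${name}`);

    // Mirrors the addition above; the event's fields are invented for illustration.
    type ModelInstallDownloadStartedEvent = { id: number; source: string };
    export const socketModelInstallDownloadStarted = createSocketAction<ModelInstallDownloadStartedEvent>(
      'ModelInstallDownloadStartedEvent'
    );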

View File

@@ -9,6 +9,7 @@ export type InvocationCompleteEvent = S['InvocationCompleteEvent'];
 export type InvocationErrorEvent = S['InvocationErrorEvent'];
 export type ProgressImage = InvocationDenoiseProgressEvent['progress_image'];

+export type ModelInstallDownloadStartedEvent = S['ModelInstallDownloadStartedEvent'];
 export type ModelInstallDownloadProgressEvent = S['ModelInstallDownloadProgressEvent'];
 export type ModelInstallDownloadsCompleteEvent = S['ModelInstallDownloadsCompleteEvent'];
 export type ModelInstallCompleteEvent = S['ModelInstallCompleteEvent'];
@@ -49,6 +50,7 @@ export type ServerToClientEvents = {
   download_error: (payload: DownloadErrorEvent) => void;
   model_load_started: (payload: ModelLoadStartedEvent) => void;
   model_install_started: (payload: ModelInstallStartedEvent) => void;
+  model_install_download_started: (payload: ModelInstallDownloadStartedEvent) => void;
   model_install_download_progress: (payload: ModelInstallDownloadProgressEvent) => void;
   model_install_downloads_complete: (payload: ModelInstallDownloadsCompleteEvent) => void;
   model_install_complete: (payload: ModelInstallCompleteEvent) => void;
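
The new `model_install_download_started` key in `ServerToClientEvents` is what allows a typed socket.io handler to be registered for the event and forwarded to the action creator above. A sketch of that glue, assuming a socket.io client; the function name and import paths are assumptions modeled on the surrounding files:

    import type { Socket } from 'socket.io-client';
    import { socketModelInstallDownloadStarted } from 'services/events/actions';
    import type { ModelInstallDownloadStartedEvent, ServerToClientEvents } from 'services/events/types';

    export const addModelInstallSocketListener = (
      socket: Socket<ServerToClientEvents>,
      dispatch: (action: ReturnType<typeof socketModelInstallDownloadStarted>) => void
    ) => {
      // The event key must match the entry added to ServerToClientEvents above.
      socket.on('model_install_download_started', (payload: ModelInstallDownloadStartedEvent) => {
        dispatch(socketModelInstallDownloadStarted({ data: payload }));
      });
    };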

View File

@@ -17,6 +17,7 @@ from invokeai.app.services.events.events_common import (
     ModelInstallCompleteEvent,
     ModelInstallDownloadProgressEvent,
     ModelInstallDownloadsCompleteEvent,
+    ModelInstallDownloadStartedEvent,
     ModelInstallStartedEvent,
 )
 from invokeai.app.services.model_install import (
@@ -252,7 +253,7 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
     assert (mm2_app_config.models_path / model_record.path).exists()
     assert len(bus.events) == 5
-    assert isinstance(bus.events[0], ModelInstallDownloadProgressEvent)  # download starts
+    assert isinstance(bus.events[0], ModelInstallDownloadStartedEvent)  # download starts
     assert isinstance(bus.events[1], ModelInstallDownloadProgressEvent)  # download progresses
     assert isinstance(bus.events[2], ModelInstallDownloadsCompleteEvent)  # download completed
     assert isinstance(bus.events[3], ModelInstallStartedEvent)  # install started