Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Merge branch 'main' into stalker7779/backend_base
@@ -98,7 +98,7 @@ class UnetSkipConnectionBlock(nn.Module):
         """
         super(UnetSkipConnectionBlock, self).__init__()
         self.outermost = outermost
-        if type(norm_layer) == functools.partial:
+        if isinstance(norm_layer, functools.partial):
             use_bias = norm_layer.func == nn.InstanceNorm2d
         else:
             use_bias = norm_layer == nn.InstanceNorm2d
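The `isinstance` check is the robust form here: `type(norm_layer) == functools.partial` fails for subclasses and is flagged by linters (flake8's E721, for example), while `isinstance` states the intent directly. A small self-contained sketch of the pattern:

import functools

import torch.nn as nn

# A norm-layer factory like the ones passed to UnetSkipConnectionBlock.
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)

if isinstance(norm_layer, functools.partial):
    # The partial wraps the real class; inspect .func to find it.
    use_bias = norm_layer.func == nn.InstanceNorm2d
else:
    use_bias = norm_layer == nn.InstanceNorm2d

print(use_bias)  # True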
@@ -124,16 +124,14 @@ class IPAdapter(RawModel):
             self.device, dtype=self.dtype
         )

-    def to(
-        self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, non_blocking: bool = False
-    ):
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
         if device is not None:
             self.device = device
         if dtype is not None:
             self.dtype = dtype

-        self._image_proj_model.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
-        self.attn_weights.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
+        self._image_proj_model.to(device=self.device, dtype=self.dtype)
+        self.attn_weights.to(device=self.device, dtype=self.dtype)

     def calc_size(self) -> int:
         # HACK(ryand): Fix this issue with circular imports.
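The simplified two-argument `to(...)` keeps the same partial-update semantics: each argument is optional, an omitted value leaves the current setting untouched, and the wrapped modules are moved with plain `Tensor.to`/`Module.to` calls. A minimal sketch of that contract (class and field names invented for illustration):

from typing import Optional

import torch


class ToyAdapter:
    """Sketch of the simplified .to(...) contract: optional device/dtype,
    with the state recorded on the wrapper and applied to wrapped weights."""

    def __init__(self) -> None:
        self.device = torch.device("cpu")
        self.dtype = torch.float32
        self.weights = torch.zeros(4, 4)

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        if device is not None:
            self.device = device
        if dtype is not None:
            self.dtype = dtype
        # Tensor.to returns a new tensor, so reassign the attribute.
        self.weights = self.weights.to(device=self.device, dtype=self.dtype)


adapter = ToyAdapter()
adapter.to(dtype=torch.float16)  # cast only; the device is unchanged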
@@ -11,7 +11,6 @@ from typing_extensions import Self

 from invokeai.backend.model_manager import BaseModelType
 from invokeai.backend.raw_model import RawModel
-from invokeai.backend.util.devices import TorchDevice


 class LoRALayerBase:
@@ -57,14 +56,9 @@ class LoRALayerBase:
             model_size += val.nelement() * val.element_size()
         return model_size

-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
-    ) -> None:
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
         if self.bias is not None:
-            self.bias = self.bias.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.bias = self.bias.to(device=device, dtype=dtype)


 # TODO: find and debug lora/locon with bias
@@ -106,19 +100,14 @@ class LoRALayer(LoRALayerBase):
             model_size += val.nelement() * val.element_size()
         return model_size

-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
-    ) -> None:
-        super().to(device=device, dtype=dtype, non_blocking=non_blocking)
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
+        super().to(device=device, dtype=dtype)

-        self.up = self.up.to(device=device, dtype=dtype, non_blocking=non_blocking)
-        self.down = self.down.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.up = self.up.to(device=device, dtype=dtype)
+        self.down = self.down.to(device=device, dtype=dtype)

         if self.mid is not None:
-            self.mid = self.mid.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.mid = self.mid.to(device=device, dtype=dtype)


 class LoHALayer(LoRALayerBase):
@@ -167,23 +156,18 @@ class LoHALayer(LoRALayerBase):
             model_size += val.nelement() * val.element_size()
         return model_size

-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
-    ) -> None:
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
         super().to(device=device, dtype=dtype)

-        self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
-        self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.w1_a = self.w1_a.to(device=device, dtype=dtype)
+        self.w1_b = self.w1_b.to(device=device, dtype=dtype)
         if self.t1 is not None:
-            self.t1 = self.t1.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.t1 = self.t1.to(device=device, dtype=dtype)

-        self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
-        self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.w2_a = self.w2_a.to(device=device, dtype=dtype)
+        self.w2_b = self.w2_b.to(device=device, dtype=dtype)
         if self.t2 is not None:
-            self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.t2 = self.t2.to(device=device, dtype=dtype)


 class LoKRLayer(LoRALayerBase):
@@ -264,12 +248,7 @@ class LoKRLayer(LoRALayerBase):
             model_size += val.nelement() * val.element_size()
         return model_size

-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
-    ) -> None:
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
         super().to(device=device, dtype=dtype)

         if self.w1 is not None:
@@ -277,19 +256,19 @@ class LoKRLayer(LoRALayerBase):
         else:
             assert self.w1_a is not None
             assert self.w1_b is not None
-            self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
-            self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.w1_a = self.w1_a.to(device=device, dtype=dtype)
+            self.w1_b = self.w1_b.to(device=device, dtype=dtype)

         if self.w2 is not None:
-            self.w2 = self.w2.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.w2 = self.w2.to(device=device, dtype=dtype)
         else:
             assert self.w2_a is not None
             assert self.w2_b is not None
-            self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
-            self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.w2_a = self.w2_a.to(device=device, dtype=dtype)
+            self.w2_b = self.w2_b.to(device=device, dtype=dtype)

         if self.t2 is not None:
-            self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            self.t2 = self.t2.to(device=device, dtype=dtype)


 class FullLayer(LoRALayerBase):
@@ -319,15 +298,10 @@ class FullLayer(LoRALayerBase):
         model_size += self.weight.nelement() * self.weight.element_size()
         return model_size

-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
-    ) -> None:
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
         super().to(device=device, dtype=dtype)

-        self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.weight = self.weight.to(device=device, dtype=dtype)


 class IA3Layer(LoRALayerBase):
@@ -359,16 +333,11 @@ class IA3Layer(LoRALayerBase):
         model_size += self.on_input.nelement() * self.on_input.element_size()
         return model_size

-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
-    ):
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
         super().to(device=device, dtype=dtype)

-        self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
-        self.on_input = self.on_input.to(device=device, dtype=dtype, non_blocking=non_blocking)
+        self.weight = self.weight.to(device=device, dtype=dtype)
+        self.on_input = self.on_input.to(device=device, dtype=dtype)


 AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
@@ -390,15 +359,10 @@ class LoRAModelRaw(RawModel):  # (torch.nn.Module):
     def name(self) -> str:
         return self._name

-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
-    ) -> None:
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
         # TODO: try revert if exception?
         for _key, layer in self.layers.items():
-            layer.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            layer.to(device=device, dtype=dtype)

     def calc_size(self) -> int:
         model_size = 0
@@ -521,7 +485,7 @@ class LoRAModelRaw(RawModel):  # (torch.nn.Module):
             # lower memory consumption by removing already parsed layer values
            state_dict[layer_key].clear()

-            layer.to(device=device, dtype=dtype, non_blocking=TorchDevice.get_non_blocking(device))
+            layer.to(device=device, dtype=dtype)
             model.layers[layer_key] = layer

         return model
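`LoRAModelRaw.to` simply fans the move out over its `layers` dict, so an entire LoRA migrates with one call. A toy version of that fan-out (types and names invented for illustration):

from typing import Dict, Optional

import torch


class TinyLayer:
    def __init__(self) -> None:
        self.up = torch.randn(8, 4)
        self.down = torch.randn(4, 8)

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        # Tensor.to is not in-place, so reassign the attributes.
        self.up = self.up.to(device=device, dtype=dtype)
        self.down = self.down.to(device=device, dtype=dtype)


class TinyLoRA:
    def __init__(self) -> None:
        self.layers: Dict[str, TinyLayer] = {"attn": TinyLayer()}

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        # Fan the move out to every layer, mirroring LoRAModelRaw.to above.
        for _key, layer in self.layers.items():
            layer.to(device=device, dtype=dtype)


TinyLoRA().to(dtype=torch.float16)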
@@ -67,6 +67,7 @@ class ModelType(str, Enum):
     IPAdapter = "ip_adapter"
     CLIPVision = "clip_vision"
     T2IAdapter = "t2i_adapter"
+    SpandrelImageToImage = "spandrel_image_to_image"


 class SubModelType(str, Enum):
@@ -371,6 +372,17 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase):
         return Tag(f"{ModelType.T2IAdapter.value}.{ModelFormat.Diffusers.value}")


+class SpandrelImageToImageConfig(ModelConfigBase):
+    """Model config for Spandrel Image to Image models."""
+
+    type: Literal[ModelType.SpandrelImageToImage] = ModelType.SpandrelImageToImage
+    format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
+
+    @staticmethod
+    def get_tag() -> Tag:
+        return Tag(f"{ModelType.SpandrelImageToImage.value}.{ModelFormat.Checkpoint.value}")
+
+
 def get_model_discriminator_value(v: Any) -> str:
     """
     Computes the discriminator value for a model config.
@@ -407,6 +419,7 @@ AnyModelConfig = Annotated[
         Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],
         Annotated[IPAdapterCheckpointConfig, IPAdapterCheckpointConfig.get_tag()],
         Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
+        Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
         Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
     ],
     Discriminator(get_model_discriminator_value),
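For readers unfamiliar with the `AnyModelConfig` pattern above, it is a Pydantic v2 callable-discriminated union: each config class contributes a `Tag`, and `Discriminator(get_model_discriminator_value)` selects the matching branch at validation time. A self-contained toy version (class names here are invented for illustration):

from typing import Annotated, Any, Literal, Union

from pydantic import BaseModel, Discriminator, Tag, TypeAdapter


class CatConfig(BaseModel):
    type: Literal["cat"] = "cat"


class DogConfig(BaseModel):
    type: Literal["dog"] = "dog"


def get_discriminator_value(v: Any) -> str:
    # Handles both raw dicts and model instances, like
    # get_model_discriminator_value above.
    if isinstance(v, dict):
        return v["type"]
    return v.type


AnyAnimalConfig = Annotated[
    Union[
        Annotated[CatConfig, Tag("cat")],
        Annotated[DogConfig, Tag("dog")],
    ],
    Discriminator(get_discriminator_value),
]

adapter = TypeAdapter(AnyAnimalConfig)
print(type(adapter.validate_python({"type": "dog"})))  # DogConfig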
@@ -289,11 +289,9 @@ class ModelCache(ModelCacheBase[AnyModel]):
             else:
                 new_dict: Dict[str, torch.Tensor] = {}
                 for k, v in cache_entry.state_dict.items():
-                    new_dict[k] = v.to(
-                        target_device, copy=True, non_blocking=TorchDevice.get_non_blocking(target_device)
-                    )
+                    new_dict[k] = v.to(target_device, copy=True)
                 cache_entry.model.load_state_dict(new_dict, assign=True)
-            cache_entry.model.to(target_device, non_blocking=TorchDevice.get_non_blocking(target_device))
+            cache_entry.model.to(target_device)
             cache_entry.device = target_device
         except Exception as e:  # blow away cache entry
             self._delete_cache_entry(cache_entry)
@@ -0,0 +1,45 @@
+from pathlib import Path
+from typing import Optional
+
+import torch
+
+from invokeai.backend.model_manager.config import (
+    AnyModel,
+    AnyModelConfig,
+    BaseModelType,
+    ModelFormat,
+    ModelType,
+    SubModelType,
+)
+from invokeai.backend.model_manager.load.load_default import ModelLoader
+from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
+from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
+
+
+@ModelLoaderRegistry.register(
+    base=BaseModelType.Any, type=ModelType.SpandrelImageToImage, format=ModelFormat.Checkpoint
+)
+class SpandrelImageToImageModelLoader(ModelLoader):
+    """Class for loading Spandrel Image-to-Image models (i.e. models wrapped by spandrel.ImageModelDescriptor)."""
+
+    def _load_model(
+        self,
+        config: AnyModelConfig,
+        submodel_type: Optional[SubModelType] = None,
+    ) -> AnyModel:
+        if submodel_type is not None:
+            raise ValueError("Unexpected submodel requested for Spandrel model.")
+
+        model_path = Path(config.path)
+        model = SpandrelImageToImageModel.load_from_file(model_path)
+
+        torch_dtype = self._torch_dtype
+        if not model.supports_dtype(torch_dtype):
+            self._logger.warning(
+                f"The configured dtype ('{self._torch_dtype}') is not supported by the {model.get_model_type_name()} "
+                "model. Falling back to 'float32'."
+            )
+            torch_dtype = torch.float32
+        model.to(dtype=torch_dtype)
+
+        return model
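For context, `ModelLoaderRegistry.register` is a class-decorator pattern: it files the loader under a (base, type, format) key so the model manager can look it up when a matching config is loaded. A simplified toy registry showing the idea (everything here is invented for illustration, not InvokeAI's actual API):

from typing import Callable, Dict, Tuple, Type

# Keyed on the same three fields the real registry uses.
_REGISTRY: Dict[Tuple[str, str, str], Type] = {}


def register(base: str, type: str, format: str) -> Callable[[Type], Type]:
    # Keyword names mirror the real decorator, hence the shadowed builtins.
    def decorator(cls: Type) -> Type:
        _REGISTRY[(base, type, format)] = cls
        return cls

    return decorator


@register(base="any", type="spandrel_image_to_image", format="checkpoint")
class ToyLoader:
    pass


print(_REGISTRY[("any", "spandrel_image_to_image", "checkpoint")])  # ToyLoader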
@@ -15,6 +15,7 @@ from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
 from invokeai.backend.lora import LoRAModelRaw
 from invokeai.backend.model_manager.config import AnyModel
 from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
+from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
 from invokeai.backend.textual_inversion import TextualInversionModelRaw

@@ -33,7 +34,7 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
     elif isinstance(model, CLIPTokenizer):
         # TODO(ryand): Accurately calculate the tokenizer's size. It's small enough that it shouldn't matter for now.
         return 0
-    elif isinstance(model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw)):
+    elif isinstance(model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw, SpandrelImageToImageModel)):
         return model.calc_size()
     else:
         # TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
@@ -4,6 +4,7 @@ from pathlib import Path
 from typing import Any, Dict, Literal, Optional, Union

 import safetensors.torch
+import spandrel
 import torch
 from picklescan.scanner import scan_file_path

@@ -25,6 +26,7 @@ from invokeai.backend.model_manager.config import (
     SchedulerPredictionType,
 )
 from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
+from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
 from invokeai.backend.util.silence_warnings import SilenceWarnings

 CkptType = Dict[str | int, Any]
@@ -220,24 +222,46 @@ class ModelProbe(object):
         ckpt = ckpt.get("state_dict", ckpt)

         for key in [str(k) for k in ckpt.keys()]:
-            if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
+            if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.")):
                 return ModelType.Main
-            elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
+            elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
                 return ModelType.VAE
-            elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
+            elif key.startswith(("lora_te_", "lora_unet_")):
                 return ModelType.LoRA
-            elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
+            elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight")):
                 return ModelType.LoRA
-            elif any(key.startswith(v) for v in {"controlnet", "control_model", "input_blocks"}):
+            elif key.startswith(("controlnet", "control_model", "input_blocks")):
                 return ModelType.ControlNet
-            elif any(key.startswith(v) for v in {"image_proj.", "ip_adapter."}):
+            elif key.startswith(("image_proj.", "ip_adapter.")):
                 return ModelType.IPAdapter
             elif key in {"emb_params", "string_to_param"}:
                 return ModelType.TextualInversion
-            else:
-                # diffusers-ti
-                if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
-                    return ModelType.TextualInversion
+
+        # diffusers-ti
+        if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
+            return ModelType.TextualInversion
+
+        # Check if the model can be loaded as a SpandrelImageToImageModel.
+        # This check is intentionally performed last, as it can be expensive (it requires loading the model from disk).
+        try:
+            # It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were
+            # explored to avoid this:
+            # 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta
+            #    device. Unfortunately, some Spandrel models perform operations during initialization that are not
+            #    supported on meta tensors.
+            # 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model.
+            #    This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to
+            #    maintain it, and the risk of false positive detections is higher.
+            SpandrelImageToImageModel.load_from_file(model_path)
+            return ModelType.SpandrelImageToImage
+        except spandrel.UnsupportedModelError:
+            pass
+        except RuntimeError as e:
+            if "No such file or directory" in str(e):
+                # This error is expected if the model_path does not exist (which is the case in some unit tests).
+                pass
+            else:
+                raise e

         raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")

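The rewritten conditions rely on the fact that `str.startswith` and `str.endswith` accept a tuple of affixes natively, which reads better and skips the generator machinery of the old `any(...)` form:

key = "lora_te_text_model_encoder"

# The two forms are equivalent; the tuple is checked left to right.
old_style = any(key.startswith(v) for v in {"lora_te_", "lora_unet_"})
new_style = key.startswith(("lora_te_", "lora_unet_"))
assert old_style == new_style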
@@ -569,6 +593,11 @@ class T2IAdapterCheckpointProbe(CheckpointProbeBase):
         raise NotImplementedError()


+class SpandrelImageToImageCheckpointProbe(CheckpointProbeBase):
+    def get_base_type(self) -> BaseModelType:
+        return BaseModelType.Any
+
+
 ########################################################
 # classes for probing folders
 #######################################################
@@ -776,6 +805,11 @@ class CLIPVisionFolderProbe(FolderProbeBase):
         return BaseModelType.Any


+class SpandrelImageToImageFolderProbe(FolderProbeBase):
+    def get_base_type(self) -> BaseModelType:
+        raise NotImplementedError()
+
+
 class T2IAdapterFolderProbe(FolderProbeBase):
     def get_base_type(self) -> BaseModelType:
         config_file = self.model_path / "config.json"
@@ -805,6 +839,7 @@ ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderPro
 ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
 ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
 ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
+ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe)

 ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
@@ -814,5 +849,6 @@ ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpoi
 ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
+ModelProbe.register_probe("checkpoint", ModelType.SpandrelImageToImage, SpandrelImageToImageCheckpointProbe)

 ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
@@ -399,6 +399,43 @@ STARTER_MODELS: list[StarterModel] = [
         type=ModelType.T2IAdapter,
     ),
     # endregion
+    # region SpandrelImageToImage
+    StarterModel(
+        name="RealESRGAN_x4plus_anime_6B",
+        base=BaseModelType.Any,
+        source="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
+        description="A Real-ESRGAN 4x upscaling model (optimized for anime images).",
+        type=ModelType.SpandrelImageToImage,
+    ),
+    StarterModel(
+        name="RealESRGAN_x4plus",
+        base=BaseModelType.Any,
+        source="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
+        description="A Real-ESRGAN 4x upscaling model (general-purpose).",
+        type=ModelType.SpandrelImageToImage,
+    ),
+    StarterModel(
+        name="ESRGAN_SRx4_DF2KOST_official",
+        base=BaseModelType.Any,
+        source="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
+        description="The official ESRGAN 4x upscaling model.",
+        type=ModelType.SpandrelImageToImage,
+    ),
+    StarterModel(
+        name="RealESRGAN_x2plus",
+        base=BaseModelType.Any,
+        source="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
+        description="A Real-ESRGAN 2x upscaling model (general-purpose).",
+        type=ModelType.SpandrelImageToImage,
+    ),
+    StarterModel(
+        name="SwinIR - realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN",
+        base=BaseModelType.Any,
+        source="https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN-with-dict-keys-params-and-params_ema.pth",
+        description="A SwinIR 4x upscaling model.",
+        type=ModelType.SpandrelImageToImage,
+    ),
+    # endregion
 ]

 assert len(STARTER_MODELS) == len({m.source for m in STARTER_MODELS}), "Duplicate starter models"
@@ -158,15 +158,12 @@ class ModelPatcher:
                     # We intentionally move to the target device first, then cast. Experimentally, this was found to
                     # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
                     # same thing in a single call to '.to(...)'.
-                    layer.to(device=device, non_blocking=TorchDevice.get_non_blocking(device))
-                    layer.to(dtype=torch.float32, non_blocking=TorchDevice.get_non_blocking(device))
+                    layer.to(device=device)
+                    layer.to(dtype=torch.float32)
                     # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
                     # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
                     layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
-                    layer.to(
-                        device=TorchDevice.CPU_DEVICE,
-                        non_blocking=TorchDevice.get_non_blocking(TorchDevice.CPU_DEVICE),
-                    )
+                    layer.to(device=TorchDevice.CPU_DEVICE)

                     assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
                     if module.weight.shape != layer_weight.shape:
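The move-then-cast comment above is the interesting part of this hunk: transferring the compact 16-bit payload first and upcasting on the GPU avoids shipping twice the bytes over the bus. A rough sketch of the two orderings (timings will vary by hardware; this is illustrative, not a benchmark):

import torch

if torch.cuda.is_available():
    x = torch.randn(1024, 1024, dtype=torch.float16)

    # One combined call: casts and copies together.
    y_combined = x.to(device="cuda", dtype=torch.float32)

    # Two-step form used by ModelPatcher: copy 2 MiB of fp16 first,
    # then upcast to fp32 on the device.
    y_split = x.to(device="cuda").to(dtype=torch.float32)

    assert torch.equal(y_combined, y_split)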
@@ -175,7 +172,7 @@ class ModelPatcher:
                        layer_weight = layer_weight.reshape(module.weight.shape)

                    assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
-                    module.weight += layer_weight.to(dtype=dtype, non_blocking=TorchDevice.get_non_blocking(device))
+                    module.weight += layer_weight.to(dtype=dtype)

             yield  # wait for context manager exit

@@ -183,9 +180,7 @@ class ModelPatcher:
             assert hasattr(model, "get_submodule")  # mypy not picking up fact that torch.nn.Module has get_submodule()
             with torch.no_grad():
                 for module_key, weight in original_weights.items():
-                    model.get_submodule(module_key).weight.copy_(
-                        weight, non_blocking=TorchDevice.get_non_blocking(weight.device)
-                    )
+                    model.get_submodule(module_key).weight.copy_(weight)

     @classmethod
     @contextmanager
@@ -190,12 +190,7 @@ class IAIOnnxRuntimeModel(RawModel):
         return self.session.run(None, inputs)

     # compatability with RawModel ABC
-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
-    ) -> None:
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
         pass

     # compatability with diffusers load code
@@ -1,15 +1,3 @@
-"""Base class for 'Raw' models.
-
-The RawModel class is the base class of LoRAModelRaw and TextualInversionModelRaw,
-and is used for type checking of calls to the model patcher. Its main purpose
-is to avoid a circular import issues when lora.py tries to import BaseModelType
-from invokeai.backend.model_manager.config, and the latter tries to import LoRAModelRaw
-from lora.py.
-
-The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
-that adds additional methods and attributes.
-"""
-
 from abc import ABC, abstractmethod
 from typing import Optional

@@ -17,13 +5,18 @@ import torch


 class RawModel(ABC):
-    """Abstract base class for 'Raw' model wrappers."""
+    """Base class for 'Raw' models.
+
+    The RawModel class is the base class of LoRAModelRaw, TextualInversionModelRaw, etc.
+    and is used for type checking of calls to the model patcher. Its main purpose
+    is to avoid a circular import issues when lora.py tries to import BaseModelType
+    from invokeai.backend.model_manager.config, and the latter tries to import LoRAModelRaw
+    from lora.py.
+
+    The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
+    that adds additional methods and attributes.
+    """

     @abstractmethod
-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
-    ) -> None:
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
         pass
invokeai/backend/spandrel_image_to_image_model.py (new file, +139 lines)
@@ -0,0 +1,139 @@
+from pathlib import Path
+from typing import Any, Optional
+
+import numpy as np
+import torch
+from PIL import Image
+from spandrel import ImageModelDescriptor, ModelLoader
+
+from invokeai.backend.raw_model import RawModel
+
+
+class SpandrelImageToImageModel(RawModel):
+    """A wrapper for a Spandrel Image-to-Image model.
+
+    The main reason for having a wrapper class is to integrate with the type handling of RawModel.
+    """
+
+    def __init__(self, spandrel_model: ImageModelDescriptor[Any]):
+        self._spandrel_model = spandrel_model
+
+    @staticmethod
+    def pil_to_tensor(image: Image.Image) -> torch.Tensor:
+        """Convert PIL Image to the torch.Tensor format expected by SpandrelImageToImageModel.run().
+
+        Args:
+            image (Image.Image): A PIL Image with shape (H, W, C) and values in the range [0, 255].
+
+        Returns:
+            torch.Tensor: A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
+        """
+        image_np = np.array(image)
+        # (H, W, C) -> (C, H, W)
+        image_np = np.transpose(image_np, (2, 0, 1))
+        image_np = image_np / 255
+        image_tensor = torch.from_numpy(image_np).float()
+        # (C, H, W) -> (N, C, H, W)
+        image_tensor = image_tensor.unsqueeze(0)
+        return image_tensor
+
+    @staticmethod
+    def tensor_to_pil(tensor: torch.Tensor) -> Image.Image:
+        """Convert a torch.Tensor produced by SpandrelImageToImageModel.run() to a PIL Image.
+
+        Args:
+            tensor (torch.Tensor): A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
+
+        Returns:
+            Image.Image: A PIL Image with shape (H, W, C) and values in the range [0, 255].
+        """
+        # (N, C, H, W) -> (C, H, W)
+        tensor = tensor.squeeze(0)
+        # (C, H, W) -> (H, W, C)
+        tensor = tensor.permute(1, 2, 0)
+        tensor = tensor.clamp(0, 1)
+        tensor = (tensor * 255).cpu().detach().numpy().astype(np.uint8)
+        image = Image.fromarray(tensor)
+        return image
+
+    def run(self, image_tensor: torch.Tensor) -> torch.Tensor:
+        """Run the image-to-image model.
+
+        Args:
+            image_tensor (torch.Tensor): A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
+        """
+        return self._spandrel_model(image_tensor)
+
+    @classmethod
+    def load_from_file(cls, file_path: str | Path):
+        model = ModelLoader().load_from_file(file_path)
+        if not isinstance(model, ImageModelDescriptor):
+            raise ValueError(
+                f"Loaded a spandrel model of type '{type(model)}'. Only image-to-image models are supported "
+                "('ImageModelDescriptor')."
+            )
+
+        return cls(spandrel_model=model)
+
+    @classmethod
+    def load_from_state_dict(cls, state_dict: dict[str, torch.Tensor]):
+        model = ModelLoader().load_from_state_dict(state_dict)
+        if not isinstance(model, ImageModelDescriptor):
+            raise ValueError(
+                f"Loaded a spandrel model of type '{type(model)}'. Only image-to-image models are supported "
+                "('ImageModelDescriptor')."
+            )
+
+        return cls(spandrel_model=model)
+
+    def supports_dtype(self, dtype: torch.dtype) -> bool:
+        """Check if the model supports the given dtype."""
+        if dtype == torch.float16:
+            return self._spandrel_model.supports_half
+        elif dtype == torch.bfloat16:
+            return self._spandrel_model.supports_bfloat16
+        elif dtype == torch.float32:
+            # All models support float32.
+            return True
+        else:
+            raise ValueError(f"Unexpected dtype '{dtype}'.")
+
+    def get_model_type_name(self) -> str:
+        """The model type name. Intended for logging / debugging purposes. Do not rely on this field remaining
+        consistent over time.
+        """
+        return str(type(self._spandrel_model.model))
+
+    def to(
+        self,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        non_blocking: bool = False,
+    ) -> None:
+        """Note: Some models have limited dtype support. Call supports_dtype(...) to check if the dtype is supported.
+        Note: The non_blocking parameter is currently ignored."""
+        # TODO(ryand): spandrel.ImageModelDescriptor.to(...) does not support non_blocking. We will have to access the
+        # model directly if we want to apply this optimization.
+        self._spandrel_model.to(device=device, dtype=dtype)
+
+    @property
+    def device(self) -> torch.device:
+        """The device of the underlying model."""
+        return self._spandrel_model.device
+
+    @property
+    def dtype(self) -> torch.dtype:
+        """The dtype of the underlying model."""
+        return self._spandrel_model.dtype
+
+    @property
+    def scale(self) -> int:
+        """The scale of the model (e.g. 1x, 2x, 4x, etc.)."""
+        return self._spandrel_model.scale
+
+    def calc_size(self) -> int:
+        """Get size of the model in memory in bytes."""
+        # HACK(ryand): Fix this issue with circular imports.
+        from invokeai.backend.model_manager.load.model_util import calc_module_size
+
+        return calc_module_size(self._spandrel_model.model)
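Putting the new wrapper's pieces together, a plausible end-to-end call sequence (the file paths here are hypothetical) loads a checkpoint, round-trips a PIL image through the expected (N, C, H, W) layout, and saves the result:

import torch
from PIL import Image

from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel

# Hypothetical paths; any spandrel-supported checkpoint should work.
model = SpandrelImageToImageModel.load_from_file("RealESRGAN_x4plus.pth")
model.to(device=torch.device("cuda"), dtype=torch.float32)

image = Image.open("input.png").convert("RGB")
tensor = SpandrelImageToImageModel.pil_to_tensor(image).to(model.device)

with torch.no_grad():
    upscaled = model.run(tensor)

SpandrelImageToImageModel.tensor_to_pil(upscaled).save("output.png")
print(f"Model scale: {model.scale}x")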
@@ -65,17 +65,12 @@ class TextualInversionModelRaw(RawModel):

         return result

-    def to(
-        self,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        non_blocking: bool = False,
-    ) -> None:
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
         if not torch.cuda.is_available():
             return
         for emb in [self.embedding, self.embedding_2]:
             if emb is not None:
-                emb.to(device=device, dtype=dtype, non_blocking=non_blocking)
+                emb.to(device=device, dtype=dtype)

     def calc_size(self) -> int:
         """Get the size of this model in bytes."""
@@ -112,15 +112,3 @@ class TorchDevice:
     @classmethod
     def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
         return NAME_TO_PRECISION[precision_name]
-
-    @staticmethod
-    def get_non_blocking(to_device: torch.device) -> bool:
-        """Return the non_blocking flag to be used when moving a tensor to a given device.
-        MPS may have unexpected errors with non-blocking operations - we should not use non-blocking when moving _to_ MPS.
-        When moving _from_ MPS, we can use non-blocking operations.
-
-        See:
-        - https://github.com/pytorch/pytorch/issues/107455
-        - https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/28
-        """
-        return False if to_device.type == "mps" else True