mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
added textual inversion and lora loaders
This commit is contained in:
committed by
psychedelicious
parent
67eb715093
commit
0d3addc69b
@ -28,9 +28,11 @@ from diffusers import ModelMixin
|
||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
|
||||
from typing_extensions import Annotated, Any, Dict
|
||||
|
||||
from .onnx_runtime import IAIOnnxRuntimeModel
|
||||
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
|
||||
|
||||
from ..ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
||||
|
||||
|
||||
class InvalidModelConfigException(Exception):
|
||||
"""Exception for when config parser doesn't recognized this combination of model type and format."""
|
||||
|
||||
|
@ -10,11 +10,17 @@ from diffusers import ModelMixin
|
||||
from diffusers.configuration_utils import ConfigMixin
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, InvalidModelConfigException, ModelRepoVariant, SubModelType
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
InvalidModelConfigException,
|
||||
ModelRepoVariant,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs, calc_model_size_by_data
|
||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs
|
||||
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
||||
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
|
||||
|
||||
@ -160,4 +166,3 @@ class ModelLoader(ModelLoaderBase):
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -97,4 +97,4 @@ def get_pretty_snapshot_diff(snapshot_1: Optional[MemorySnapshot], snapshot_2: O
|
||||
if snapshot_1.vram is not None and snapshot_2.vram is not None:
|
||||
msg += get_msg_line("VRAM", snapshot_1.vram, snapshot_2.vram)
|
||||
|
||||
return "\n"+msg if len(msg)>0 else msg
|
||||
return "\n" + msg if len(msg) > 0 else msg
|
||||
|
@ -1,5 +1,3 @@
|
||||
"""Init file for RamCache."""
|
||||
|
||||
from .model_cache_base import ModelCacheBase
|
||||
from .model_cache_default import ModelCache
|
||||
_all__ = ["ModelCacheBase", "ModelCache"]
|
||||
|
@ -14,8 +14,10 @@ from invokeai.backend.model_manager import (
|
||||
)
|
||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
|
||||
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
|
||||
|
||||
from .generic_diffusers import GenericDiffusersLoader
|
||||
|
||||
|
||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
|
||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
|
||||
class ControlnetLoader(GenericDiffusersLoader):
|
||||
@ -37,7 +39,7 @@ class ControlnetLoader(GenericDiffusersLoader):
|
||||
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
|
||||
raise Exception(f"Vae conversion not supported for model type: {config.base}")
|
||||
else:
|
||||
assert hasattr(config, 'config')
|
||||
assert hasattr(config, "config")
|
||||
config_file = config.config
|
||||
|
||||
if weights_path.suffix == ".safetensors":
|
||||
|
@ -15,6 +15,7 @@ from invokeai.backend.model_manager import (
|
||||
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
|
||||
|
||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
|
||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers)
|
||||
class GenericDiffusersLoader(ModelLoader):
|
||||
|
@ -1,11 +1,11 @@
|
||||
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""Class for IP Adapter model loading in InvokeAI."""
|
||||
|
||||
import torch
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
@ -18,6 +18,7 @@ from invokeai.backend.model_manager import (
|
||||
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
|
||||
|
||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI)
|
||||
class IPAdapterInvokeAILoader(ModelLoader):
|
||||
"""Class to load IP Adapter diffusers models."""
|
||||
@ -36,4 +37,3 @@ class IPAdapterInvokeAILoader(ModelLoader):
|
||||
dtype=self._torch_dtype,
|
||||
)
|
||||
return model
|
||||
|
||||
|
@ -2,13 +2,12 @@
|
||||
"""Class for LoRA model loading in InvokeAI."""
|
||||
|
||||
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
from logging import Logger
|
||||
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.embeddings.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
@ -18,9 +17,11 @@ from invokeai.backend.model_manager import (
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
|
||||
|
||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers)
|
||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris)
|
||||
@ -47,6 +48,7 @@ class LoraLoader(ModelLoader):
|
||||
) -> AnyModel:
|
||||
if submodel_type is not None:
|
||||
raise ValueError("There are no submodels in a LoRA model.")
|
||||
assert self._model_base is not None
|
||||
model = LoRAModelRaw.from_checkpoint(
|
||||
file_path=model_path,
|
||||
dtype=self._torch_dtype,
|
||||
@ -56,9 +58,11 @@ class LoraLoader(ModelLoader):
|
||||
|
||||
# override
|
||||
def _get_model_path(
|
||||
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
|
||||
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
|
||||
) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
|
||||
self._model_base = config.base # cheating a little - setting this variable for later call to _load_model()
|
||||
self._model_base = (
|
||||
config.base
|
||||
) # cheating a little - we remember this variable for using in the subsequent call to _load_model()
|
||||
|
||||
model_base_path = self._app_config.models_path
|
||||
model_path = model_base_path / config.path
|
||||
@ -72,5 +76,3 @@ class LoraLoader(ModelLoader):
|
||||
|
||||
result = model_path.resolve(), config, submodel_type
|
||||
return result
|
||||
|
||||
|
||||
|
@ -0,0 +1,55 @@
|
||||
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""Class for TI model loading in InvokeAI."""
|
||||
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from invokeai.backend.embeddings.model_patcher import TextualInversionModel as TextualInversionModelRaw
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
|
||||
|
||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFile)
|
||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFolder)
|
||||
class TextualInversionLoader(ModelLoader):
|
||||
"""Class to load TI models."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if submodel_type is not None:
|
||||
raise ValueError("There are no submodels in a TI model.")
|
||||
model = TextualInversionModelRaw.from_checkpoint(
|
||||
file_path=model_path,
|
||||
dtype=self._torch_dtype,
|
||||
)
|
||||
return model
|
||||
|
||||
# override
|
||||
def _get_model_path(
|
||||
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
|
||||
) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
|
||||
model_path = self._app_config.models_path / config.path
|
||||
|
||||
if config.format == ModelFormat.EmbeddingFolder:
|
||||
path = model_path / "learned_embeds.bin"
|
||||
else:
|
||||
path = model_path
|
||||
|
||||
if not path.exists():
|
||||
raise OSError(f"The embedding file at {path} was not found")
|
||||
|
||||
return path, config, submodel_type
|
@ -15,6 +15,7 @@ from invokeai.backend.model_manager import (
|
||||
)
|
||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
|
||||
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
|
||||
|
||||
from .generic_diffusers import GenericDiffusersLoader
|
||||
|
||||
|
||||
|
@ -3,13 +3,13 @@
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
from invokeai.backend.model_manager.config import AnyModel
|
||||
from invokeai.backend.model_manager.onnx_runtime import IAIOnnxRuntimeModel
|
||||
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
|
||||
|
||||
|
||||
def calc_model_size_by_data(model: AnyModel) -> int:
|
||||
|
@ -1,620 +0,0 @@
|
||||
# Copyright (c) 2024 The InvokeAI Development team
|
||||
"""LoRA model support."""
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union, List, Tuple
|
||||
from typing_extensions import Self
|
||||
from invokeai.backend.model_manager import BaseModelType
|
||||
|
||||
class LoRALayerBase:
|
||||
# rank: Optional[int]
|
||||
# alpha: Optional[float]
|
||||
# bias: Optional[torch.Tensor]
|
||||
# layer_key: str
|
||||
|
||||
# @property
|
||||
# def scale(self):
|
||||
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
if "alpha" in values:
|
||||
self.alpha = values["alpha"].item()
|
||||
else:
|
||||
self.alpha = None
|
||||
|
||||
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
|
||||
self.bias: Optional[torch.Tensor] = torch.sparse_coo_tensor(
|
||||
values["bias_indices"],
|
||||
values["bias_values"],
|
||||
tuple(values["bias_size"]),
|
||||
)
|
||||
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
self.rank = None # set in layer implementation
|
||||
self.layer_key = layer_key
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError()
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = 0
|
||||
for val in [self.bias]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> None:
|
||||
if self.bias is not None:
|
||||
self.bias = self.bias.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
# TODO: find and debug lora/locon with bias
|
||||
class LoRALayer(LoRALayerBase):
|
||||
# up: torch.Tensor
|
||||
# mid: Optional[torch.Tensor]
|
||||
# down: torch.Tensor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.up = values["lora_up.weight"]
|
||||
self.down = values["lora_down.weight"]
|
||||
if "lora_mid.weight" in values:
|
||||
self.mid: Optional[torch.Tensor] = values["lora_mid.weight"]
|
||||
else:
|
||||
self.mid = None
|
||||
|
||||
self.rank = self.down.shape[0]
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor):
|
||||
if self.mid is not None:
|
||||
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
|
||||
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
|
||||
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
|
||||
else:
|
||||
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
|
||||
|
||||
return weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.up, self.mid, self.down]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.up = self.up.to(device=device, dtype=dtype)
|
||||
self.down = self.down.to(device=device, dtype=dtype)
|
||||
|
||||
if self.mid is not None:
|
||||
self.mid = self.mid.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class LoHALayer(LoRALayerBase):
|
||||
# w1_a: torch.Tensor
|
||||
# w1_b: torch.Tensor
|
||||
# w2_a: torch.Tensor
|
||||
# w2_b: torch.Tensor
|
||||
# t1: Optional[torch.Tensor] = None
|
||||
# t2: Optional[torch.Tensor] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor]
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.w1_a = values["hada_w1_a"]
|
||||
self.w1_b = values["hada_w1_b"]
|
||||
self.w2_a = values["hada_w2_a"]
|
||||
self.w2_b = values["hada_w2_b"]
|
||||
|
||||
if "hada_t1" in values:
|
||||
self.t1: Optional[torch.Tensor] = values["hada_t1"]
|
||||
else:
|
||||
self.t1 = None
|
||||
|
||||
if "hada_t2" in values:
|
||||
self.t2: Optional[torch.Tensor] = values["hada_t2"]
|
||||
else:
|
||||
self.t2 = None
|
||||
|
||||
self.rank = self.w1_b.shape[0]
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
if self.t1 is None:
|
||||
weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
|
||||
|
||||
else:
|
||||
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
|
||||
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
|
||||
weight = rebuild1 * rebuild2
|
||||
|
||||
return weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
||||
if self.t1 is not None:
|
||||
self.t1 = self.t1.to(device=device, dtype=dtype)
|
||||
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
||||
if self.t2 is not None:
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class LoKRLayer(LoRALayerBase):
|
||||
# w1: Optional[torch.Tensor] = None
|
||||
# w1_a: Optional[torch.Tensor] = None
|
||||
# w1_b: Optional[torch.Tensor] = None
|
||||
# w2: Optional[torch.Tensor] = None
|
||||
# w2_a: Optional[torch.Tensor] = None
|
||||
# w2_b: Optional[torch.Tensor] = None
|
||||
# t2: Optional[torch.Tensor] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
if "lokr_w1" in values:
|
||||
self.w1: Optional[torch.Tensor] = values["lokr_w1"]
|
||||
self.w1_a = None
|
||||
self.w1_b = None
|
||||
else:
|
||||
self.w1 = None
|
||||
self.w1_a = values["lokr_w1_a"]
|
||||
self.w1_b = values["lokr_w1_b"]
|
||||
|
||||
if "lokr_w2" in values:
|
||||
self.w2: Optional[torch.Tensor] = values["lokr_w2"]
|
||||
self.w2_a = None
|
||||
self.w2_b = None
|
||||
else:
|
||||
self.w2 = None
|
||||
self.w2_a = values["lokr_w2_a"]
|
||||
self.w2_b = values["lokr_w2_b"]
|
||||
|
||||
if "lokr_t2" in values:
|
||||
self.t2: Optional[torch.Tensor] = values["lokr_t2"]
|
||||
else:
|
||||
self.t2 = None
|
||||
|
||||
if "lokr_w1_b" in values:
|
||||
self.rank = values["lokr_w1_b"].shape[0]
|
||||
elif "lokr_w2_b" in values:
|
||||
self.rank = values["lokr_w2_b"].shape[0]
|
||||
else:
|
||||
self.rank = None # unscaled
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
w1: Optional[torch.Tensor] = self.w1
|
||||
if w1 is None:
|
||||
assert self.w1_a is not None
|
||||
assert self.w1_b is not None
|
||||
w1 = self.w1_a @ self.w1_b
|
||||
|
||||
w2 = self.w2
|
||||
if w2 is None:
|
||||
if self.t2 is None:
|
||||
assert self.w2_a is not None
|
||||
assert self.w2_b is not None
|
||||
w2 = self.w2_a @ self.w2_b
|
||||
else:
|
||||
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
|
||||
|
||||
if len(w2.shape) == 4:
|
||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||
w2 = w2.contiguous()
|
||||
assert w1 is not None
|
||||
assert w2 is not None
|
||||
weight = torch.kron(w1, w2)
|
||||
|
||||
return weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
|
||||
if val is not None:
|
||||
model_size += val.nelement() * val.element_size()
|
||||
return model_size
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
if self.w1 is not None:
|
||||
self.w1 = self.w1.to(device=device, dtype=dtype)
|
||||
else:
|
||||
assert self.w1_a is not None
|
||||
assert self.w1_b is not None
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
||||
|
||||
if self.w2 is not None:
|
||||
self.w2 = self.w2.to(device=device, dtype=dtype)
|
||||
else:
|
||||
assert self.w2_a is not None
|
||||
assert self.w2_b is not None
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
||||
|
||||
if self.t2 is not None:
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class FullLayer(LoRALayerBase):
|
||||
# weight: torch.Tensor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.weight = values["diff"]
|
||||
|
||||
if len(values.keys()) > 1:
|
||||
_keys = list(values.keys())
|
||||
_keys.remove("diff")
|
||||
raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}")
|
||||
|
||||
self.rank = None # unscaled
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
return self.weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
model_size += self.weight.nelement() * self.weight.element_size()
|
||||
return model_size
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class IA3Layer(LoRALayerBase):
|
||||
# weight: torch.Tensor
|
||||
# on_input: torch.Tensor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: Dict[str, torch.Tensor],
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.weight = values["weight"]
|
||||
self.on_input = values["on_input"]
|
||||
|
||||
self.rank = None # unscaled
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor):
|
||||
weight = self.weight
|
||||
if not self.on_input:
|
||||
weight = weight.reshape(-1, 1)
|
||||
return orig_weight * weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
model_size += self.weight.nelement() * self.weight.element_size()
|
||||
model_size += self.on_input.nelement() * self.on_input.element_size()
|
||||
return model_size
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
self.on_input = self.on_input.to(device=device, dtype=dtype)
|
||||
|
||||
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
|
||||
|
||||
# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
|
||||
class LoRAModelRaw: # (torch.nn.Module):
|
||||
_name: str
|
||||
layers: Dict[str, AnyLoRALayer]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
layers: Dict[str, AnyLoRALayer],
|
||||
):
|
||||
self._name = name
|
||||
self.layers = layers
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> None:
|
||||
# TODO: try revert if exception?
|
||||
for _key, layer in self.layers.items():
|
||||
layer.to(device=device, dtype=dtype)
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = 0
|
||||
for _, layer in self.layers.items():
|
||||
model_size += layer.calc_size()
|
||||
return model_size
|
||||
|
||||
@classmethod
|
||||
def _convert_sdxl_keys_to_diffusers_format(cls, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
"""Convert the keys of an SDXL LoRA state_dict to diffusers format.
|
||||
|
||||
The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
|
||||
diffusers format, then this function will have no effect.
|
||||
|
||||
This function is adapted from:
|
||||
https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
|
||||
|
||||
Args:
|
||||
state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.
|
||||
|
||||
Raises:
|
||||
ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.
|
||||
|
||||
Returns:
|
||||
Dict[str, Tensor]: The diffusers-format state_dict.
|
||||
"""
|
||||
converted_count = 0 # The number of Stability AI keys converted to diffusers format.
|
||||
not_converted_count = 0 # The number of keys that were not converted.
|
||||
|
||||
# Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
|
||||
# For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
|
||||
# `input_blocks_4_1_proj_in`.
|
||||
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
|
||||
stability_unet_keys.sort()
|
||||
|
||||
new_state_dict = {}
|
||||
for full_key, value in state_dict.items():
|
||||
if full_key.startswith("lora_unet_"):
|
||||
search_key = full_key.replace("lora_unet_", "")
|
||||
# Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
|
||||
position = bisect.bisect_right(stability_unet_keys, search_key)
|
||||
map_key = stability_unet_keys[position - 1]
|
||||
# Now, check if the map_key *actually* matches the search_key.
|
||||
if search_key.startswith(map_key):
|
||||
new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
|
||||
new_state_dict[new_key] = value
|
||||
converted_count += 1
|
||||
else:
|
||||
new_state_dict[full_key] = value
|
||||
not_converted_count += 1
|
||||
elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
|
||||
# The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
|
||||
new_state_dict[full_key] = value
|
||||
continue
|
||||
else:
|
||||
raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")
|
||||
|
||||
if converted_count > 0 and not_converted_count > 0:
|
||||
raise ValueError(
|
||||
f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
|
||||
f" not_converted={not_converted_count}"
|
||||
)
|
||||
|
||||
return new_state_dict
|
||||
|
||||
@classmethod
|
||||
def from_checkpoint(
|
||||
cls,
|
||||
file_path: Union[str, Path],
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
) -> Self:
|
||||
device = device or torch.device("cpu")
|
||||
dtype = dtype or torch.float32
|
||||
|
||||
if isinstance(file_path, str):
|
||||
file_path = Path(file_path)
|
||||
|
||||
model = cls(
|
||||
name=file_path.stem, # TODO:
|
||||
layers={},
|
||||
)
|
||||
|
||||
if file_path.suffix == ".safetensors":
|
||||
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
|
||||
else:
|
||||
state_dict = torch.load(file_path, map_location="cpu")
|
||||
|
||||
state_dict = cls._group_state(state_dict)
|
||||
|
||||
if base_model == BaseModelType.StableDiffusionXL:
|
||||
state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
|
||||
|
||||
for layer_key, values in state_dict.items():
|
||||
# lora and locon
|
||||
if "lora_down.weight" in values:
|
||||
layer: AnyLoRALayer = LoRALayer(layer_key, values)
|
||||
|
||||
# loha
|
||||
elif "hada_w1_b" in values:
|
||||
layer = LoHALayer(layer_key, values)
|
||||
|
||||
# lokr
|
||||
elif "lokr_w1_b" in values or "lokr_w1" in values:
|
||||
layer = LoKRLayer(layer_key, values)
|
||||
|
||||
# diff
|
||||
elif "diff" in values:
|
||||
layer = FullLayer(layer_key, values)
|
||||
|
||||
# ia3
|
||||
elif "weight" in values and "on_input" in values:
|
||||
layer = IA3Layer(layer_key, values)
|
||||
|
||||
else:
|
||||
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
|
||||
raise Exception("Unknown lora format!")
|
||||
|
||||
# lower memory consumption by removing already parsed layer values
|
||||
state_dict[layer_key].clear()
|
||||
|
||||
layer.to(device=device, dtype=dtype)
|
||||
model.layers[layer_key] = layer
|
||||
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
|
||||
state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {}
|
||||
|
||||
for key, value in state_dict.items():
|
||||
stem, leaf = key.split(".", 1)
|
||||
if stem not in state_dict_groupped:
|
||||
state_dict_groupped[stem] = {}
|
||||
state_dict_groupped[stem][leaf] = value
|
||||
|
||||
return state_dict_groupped
|
||||
|
||||
|
||||
# code from
|
||||
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
|
||||
def make_sdxl_unet_conversion_map() -> List[Tuple[str,str]]:
|
||||
"""Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
|
||||
unet_conversion_map_layer = []
|
||||
|
||||
for i in range(3): # num_blocks is 3 in sdxl
|
||||
# loop over downblocks/upblocks
|
||||
for j in range(2):
|
||||
# loop over resnets/attentions for downblocks
|
||||
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
||||
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
||||
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no attention layers in down_blocks.3
|
||||
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
||||
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
||||
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
||||
|
||||
for j in range(3):
|
||||
# loop over resnets/attentions for upblocks
|
||||
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
||||
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
||||
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
||||
|
||||
# if i > 0: commentout for sdxl
|
||||
# no attention layers in up_blocks.0
|
||||
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
||||
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
||||
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no downsample in down_blocks.3
|
||||
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
||||
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
||||
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
# no upsample in up_blocks.3
|
||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
|
||||
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||
|
||||
hf_mid_atn_prefix = "mid_block.attentions.0."
|
||||
sd_mid_atn_prefix = "middle_block.1."
|
||||
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
||||
|
||||
for j in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
||||
sd_mid_res_prefix = f"middle_block.{2*j}."
|
||||
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
unet_conversion_map_resnet = [
|
||||
# (stable-diffusion, HF Diffusers)
|
||||
("in_layers.0.", "norm1."),
|
||||
("in_layers.2.", "conv1."),
|
||||
("out_layers.0.", "norm2."),
|
||||
("out_layers.3.", "conv2."),
|
||||
("emb_layers.1.", "time_emb_proj."),
|
||||
("skip_connection.", "conv_shortcut."),
|
||||
]
|
||||
|
||||
unet_conversion_map = []
|
||||
for sd, hf in unet_conversion_map_layer:
|
||||
if "resnets" in hf:
|
||||
for sd_res, hf_res in unet_conversion_map_resnet:
|
||||
unet_conversion_map.append((sd + sd_res, hf + hf_res))
|
||||
else:
|
||||
unet_conversion_map.append((sd, hf))
|
||||
|
||||
for j in range(2):
|
||||
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
|
||||
sd_time_embed_prefix = f"time_embed.{j*2}."
|
||||
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
|
||||
|
||||
for j in range(2):
|
||||
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
|
||||
sd_label_embed_prefix = f"label_emb.0.{j*2}."
|
||||
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
|
||||
|
||||
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
|
||||
unet_conversion_map.append(("out.0.", "conv_norm_out."))
|
||||
unet_conversion_map.append(("out.2.", "conv_out."))
|
||||
|
||||
return unet_conversion_map
|
||||
|
||||
|
||||
SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
|
||||
sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()
|
||||
}
|
@ -1,216 +0,0 @@
|
||||
# Copyright (c) 2024 The InvokeAI Development Team
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
from onnx import numpy_helper
|
||||
from onnxruntime import InferenceSession, SessionOptions, get_available_providers
|
||||
|
||||
ONNX_WEIGHTS_NAME = "model.onnx"
|
||||
|
||||
|
||||
# NOTE FROM LS: This was copied from Stalker's original implementation.
|
||||
# I have not yet gone through and fixed all the type hints
|
||||
class IAIOnnxRuntimeModel:
|
||||
class _tensor_access:
|
||||
def __init__(self, model): # type: ignore
|
||||
self.model = model
|
||||
self.indexes = {}
|
||||
for idx, obj in enumerate(self.model.proto.graph.initializer):
|
||||
self.indexes[obj.name] = idx
|
||||
|
||||
def __getitem__(self, key: str): # type: ignore
|
||||
value = self.model.proto.graph.initializer[self.indexes[key]]
|
||||
return numpy_helper.to_array(value)
|
||||
|
||||
def __setitem__(self, key: str, value: np.ndarray): # type: ignore
|
||||
new_node = numpy_helper.from_array(value)
|
||||
# set_external_data(new_node, location="in-memory-location")
|
||||
new_node.name = key
|
||||
# new_node.ClearField("raw_data")
|
||||
del self.model.proto.graph.initializer[self.indexes[key]]
|
||||
self.model.proto.graph.initializer.insert(self.indexes[key], new_node)
|
||||
# self.model.data[key] = OrtValue.ortvalue_from_numpy(value)
|
||||
|
||||
# __delitem__
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
return self.indexes[key] in self.model.proto.graph.initializer
|
||||
|
||||
def items(self) -> List[Tuple[str, Any]]: # fixme
|
||||
raise NotImplementedError("tensor.items")
|
||||
# return [(obj.name, obj) for obj in self.raw_proto]
|
||||
|
||||
def keys(self) -> List[str]:
|
||||
return list(self.indexes.keys())
|
||||
|
||||
def values(self) -> List[Any]: # fixme
|
||||
raise NotImplementedError("tensor.values")
|
||||
# return [obj for obj in self.raw_proto]
|
||||
|
||||
def size(self) -> int:
|
||||
bytesSum = 0
|
||||
for node in self.model.proto.graph.initializer:
|
||||
bytesSum += sys.getsizeof(node.raw_data)
|
||||
return bytesSum
|
||||
|
||||
class _access_helper:
|
||||
def __init__(self, raw_proto): # type: ignore
|
||||
self.indexes = {}
|
||||
self.raw_proto = raw_proto
|
||||
for idx, obj in enumerate(raw_proto):
|
||||
self.indexes[obj.name] = idx
|
||||
|
||||
def __getitem__(self, key: str): # type: ignore
|
||||
return self.raw_proto[self.indexes[key]]
|
||||
|
||||
def __setitem__(self, key: str, value): # type: ignore
|
||||
index = self.indexes[key]
|
||||
del self.raw_proto[index]
|
||||
self.raw_proto.insert(index, value)
|
||||
|
||||
# __delitem__
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
return key in self.indexes
|
||||
|
||||
def items(self) -> List[Tuple[str, Any]]:
|
||||
return [(obj.name, obj) for obj in self.raw_proto]
|
||||
|
||||
def keys(self) -> List[str]:
|
||||
return list(self.indexes.keys())
|
||||
|
||||
def values(self) -> List[Any]: # fixme
|
||||
return list(self.raw_proto)
|
||||
|
||||
def __init__(self, model_path: str, provider: Optional[str]):
|
||||
self.path = model_path
|
||||
self.session = None
|
||||
self.provider = provider
|
||||
"""
|
||||
self.data_path = self.path + "_data"
|
||||
if not os.path.exists(self.data_path):
|
||||
print(f"Moving model tensors to separate file: {self.data_path}")
|
||||
tmp_proto = onnx.load(model_path, load_external_data=True)
|
||||
onnx.save_model(tmp_proto, self.path, save_as_external_data=True, all_tensors_to_one_file=True, location=os.path.basename(self.data_path), size_threshold=1024, convert_attribute=False)
|
||||
del tmp_proto
|
||||
gc.collect()
|
||||
|
||||
self.proto = onnx.load(model_path, load_external_data=False)
|
||||
"""
|
||||
|
||||
self.proto = onnx.load(model_path, load_external_data=True)
|
||||
# self.data = dict()
|
||||
# for tensor in self.proto.graph.initializer:
|
||||
# name = tensor.name
|
||||
|
||||
# if tensor.HasField("raw_data"):
|
||||
# npt = numpy_helper.to_array(tensor)
|
||||
# orv = OrtValue.ortvalue_from_numpy(npt)
|
||||
# # self.data[name] = orv
|
||||
# # set_external_data(tensor, location="in-memory-location")
|
||||
# tensor.name = name
|
||||
# # tensor.ClearField("raw_data")
|
||||
|
||||
self.nodes = self._access_helper(self.proto.graph.node) # type: ignore
|
||||
# self.initializers = self._access_helper(self.proto.graph.initializer)
|
||||
# print(self.proto.graph.input)
|
||||
# print(self.proto.graph.initializer)
|
||||
|
||||
self.tensors = self._tensor_access(self) # type: ignore
|
||||
|
||||
# TODO: integrate with model manager/cache
|
||||
def create_session(self, height=None, width=None):
|
||||
if self.session is None or self.session_width != width or self.session_height != height:
|
||||
# onnx.save(self.proto, "tmp.onnx")
|
||||
# onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False)
|
||||
# TODO: something to be able to get weight when they already moved outside of model proto
|
||||
# (trimmed_model, external_data) = buffer_external_data_tensors(self.proto)
|
||||
sess = SessionOptions()
|
||||
# self._external_data.update(**external_data)
|
||||
# sess.add_external_initializers(list(self.data.keys()), list(self.data.values()))
|
||||
# sess.enable_profiling = True
|
||||
|
||||
# sess.intra_op_num_threads = 1
|
||||
# sess.inter_op_num_threads = 1
|
||||
# sess.execution_mode = ExecutionMode.ORT_SEQUENTIAL
|
||||
# sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
# sess.enable_cpu_mem_arena = True
|
||||
# sess.enable_mem_pattern = True
|
||||
# sess.add_session_config_entry("session.intra_op.use_xnnpack_threadpool", "1") ########### It's the key code
|
||||
self.session_height = height
|
||||
self.session_width = width
|
||||
if height and width:
|
||||
sess.add_free_dimension_override_by_name("unet_sample_batch", 2)
|
||||
sess.add_free_dimension_override_by_name("unet_sample_channels", 4)
|
||||
sess.add_free_dimension_override_by_name("unet_hidden_batch", 2)
|
||||
sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77)
|
||||
sess.add_free_dimension_override_by_name("unet_sample_height", self.session_height)
|
||||
sess.add_free_dimension_override_by_name("unet_sample_width", self.session_width)
|
||||
sess.add_free_dimension_override_by_name("unet_time_batch", 1)
|
||||
providers = []
|
||||
if self.provider:
|
||||
providers.append(self.provider)
|
||||
else:
|
||||
providers = get_available_providers()
|
||||
if "TensorrtExecutionProvider" in providers:
|
||||
providers.remove("TensorrtExecutionProvider")
|
||||
try:
|
||||
self.session = InferenceSession(self.proto.SerializeToString(), providers=providers, sess_options=sess)
|
||||
except Exception as e:
|
||||
raise e
|
||||
# self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options)
|
||||
# self.io_binding = self.session.io_binding()
|
||||
|
||||
def release_session(self):
|
||||
self.session = None
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
return
|
||||
|
||||
def __call__(self, **kwargs):
|
||||
if self.session is None:
|
||||
raise Exception("You should call create_session before running model")
|
||||
|
||||
inputs = {k: np.array(v) for k, v in kwargs.items()}
|
||||
# output_names = self.session.get_outputs()
|
||||
# for k in inputs:
|
||||
# self.io_binding.bind_cpu_input(k, inputs[k])
|
||||
# for name in output_names:
|
||||
# self.io_binding.bind_output(name.name)
|
||||
# self.session.run_with_iobinding(self.io_binding, None)
|
||||
# return self.io_binding.copy_outputs_to_cpu()
|
||||
return self.session.run(None, inputs)
|
||||
|
||||
# compatability with diffusers load code
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
model_id: Union[str, Path],
|
||||
subfolder: Optional[Union[str, Path]] = None,
|
||||
file_name: Optional[str] = None,
|
||||
provider: Optional[str] = None,
|
||||
sess_options: Optional["SessionOptions"] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any: # fixme
|
||||
file_name = file_name or ONNX_WEIGHTS_NAME
|
||||
|
||||
if os.path.isdir(model_id):
|
||||
model_path = model_id
|
||||
if subfolder is not None:
|
||||
model_path = os.path.join(model_path, subfolder)
|
||||
model_path = os.path.join(model_path, file_name)
|
||||
|
||||
else:
|
||||
model_path = model_id
|
||||
|
||||
# load model from local directory
|
||||
if not os.path.isfile(model_path):
|
||||
raise Exception(f"Model not found: {model_path}")
|
||||
|
||||
# TODO: session options
|
||||
return cls(str(model_path), provider=provider)
|
Reference in New Issue
Block a user