ONNX Model/runtime first implementation

Sergey Borisov
2023-06-21 02:12:21 +03:00
parent 92c86fd0b8
commit 4d337f6abc
7 changed files with 935 additions and 16 deletions

View File

@@ -11,6 +11,8 @@ from torch.utils.hooks import RemovableHandle
from diffusers.models import UNet2DConditionModel
from transformers import CLIPTextModel
from onnx import numpy_helper
import numpy as np
from compel.embeddings_provider import BaseTextualInversionManager
@@ -70,7 +72,7 @@ class LoRALayerBase:
op = torch.nn.functional.linear
extra_args = {}
weight = self.get_weight(module)
weight = self.get_weight()
bias = self.bias if self.bias is not None else 0
scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
@@ -81,7 +83,7 @@ class LoRALayerBase:
**extra_args,
) * multiplier * scale
def get_weight(self, module: torch.nn.Module):
def get_weight(self):
raise NotImplementedError()
def calc_size(self) -> int:
@@ -122,7 +124,7 @@ class LoRALayer(LoRALayerBase):
self.rank = self.down.shape[0]
def get_weight(self, module: torch.nn.Module):
def get_weight(self):
if self.mid is not None:
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
@@ -185,7 +187,7 @@ class LoHALayer(LoRALayerBase):
self.rank = self.w1_b.shape[0]
def get_weight(self, module: torch.nn.Module):
def get_weight(self):
if self.t1 is None:
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
@@ -271,7 +273,7 @@ class LoKRLayer(LoRALayerBase):
else:
self.rank = None # unscaled
def get_weight(self, module: torch.nn.Module):
def get_weight(self):
w1 = self.w1
if w1 is None:
w1 = self.w1_a @ self.w1_b
@@ -286,7 +288,7 @@ class LoKRLayer(LoRALayerBase):
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
w2 = w2.contiguous()
weight = torch.kron(w1, w2).reshape(module.weight.shape) # TODO: can we remove reshape?
weight = torch.kron(w1, w2)#.reshape(module.weight.shape) # TODO: can we remove reshape?
return weight
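Note on the change above (not part of the diff): torch.kron already yields the fully expanded LoKR weight, so the reshape against module.weight.shape can be deferred to wherever the delta is merged, e.g. the ONNX patcher further down. A minimal sketch with illustrative, assumed dimensions:

import torch

w1 = torch.randn(4, 8)      # (a, b) - toy dimensions, not from any real model
w2 = torch.randn(80, 40)    # (c, d)
full = torch.kron(w1, w2)   # Kronecker product has shape (a*c, b*d) == (320, 320)
assert full.shape == (320, 320)
# the consumer of this weight (here: the ONNX initializer merge) can reshape it
# to the target layer's shape if needed, so the layer itself no longer has to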
@@ -676,3 +678,212 @@ class TextualInversionManager(BaseTextualInversionManager):
return new_token_ids
class ONNXModelPatcher:
@classmethod
@contextmanager
def apply_lora_unet(
cls,
unet: OnnxRuntimeModel,
loras: List[Tuple[LoRAModel, float]],
):
with cls.apply_lora(unet, loras, "lora_unet_"):
yield
@classmethod
@contextmanager
def apply_lora_text_encoder(
cls,
text_encoder: OnnxRuntimeModel,
loras: List[Tuple[LoRAModel, float]],
):
with cls.apply_lora(text_encoder, loras, "lora_te_"):
yield
@classmethod
@contextmanager
def apply_lora(
cls,
model: IAIOnnxRuntimeModel,
loras: List[Tuple[LoRAModel, float]],
prefix: str,
):
from .models.base import IAIOnnxRuntimeModel
if not isinstance(model, IAIOnnxRuntimeModel):
raise Exception("Only IAIOnnxRuntimeModel models supported")
base_model = model.proto
orig_nodes = dict()
try:
blended_loras = dict()
for lora, lora_weight in loras:
for layer_key, layer in lora.layers.items():
if not layer_key.startswith(prefix):
continue
layer_key = layer_key.replace(prefix, "")
layer_weight = layer.get_weight().detach().cpu().numpy() * lora_weight
if layer_key in blended_loras:
blended_loras[layer_key] += layer_weight
else:
blended_loras[layer_key] = layer_weight
initializer_idx = dict()
for idx, init in enumerate(base_model.graph.initializer):
initializer_idx[init.name.replace(".", "_")] = idx
node_idx = dict()
for idx, node in enumerate(base_model.graph.node):
node_idx[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = idx
for layer_key, weights in blended_loras.items():
conv_key = layer_key + "_Conv"
gemm_key = layer_key + "_Gemm"
matmul_key = layer_key + "_MatMul"
if conv_key in node_idx or gemm_key in node_idx:
if conv_key in node_idx:
conv_node = base_model.graph.node[node_idx[conv_key]]
else:
conv_node = base_model.graph.node[node_idx[gemm_key]]
weight_name = [n for n in conv_node.input if ".weight" in n][0]
weight_name = weight_name.replace(".", "_")
weight_idx = initializer_idx[weight_name]
weight_node = base_model.graph.initializer[weight_idx]
orig_weights = numpy_helper.to_array(weight_node)
if orig_weights.shape[-2:] == (1, 1):
if weights.shape[-2:] == (1, 1):
new_weights = orig_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
else:
new_weights = orig_weights.squeeze((3, 2)) + weights
new_weights = np.expand_dims(new_weights, (2, 3))
else:
if orig_weights.shape != weights.shape:
new_weights = orig_weights + weights.reshape(orig_weights.shape)
else:
new_weights = orig_weights + weights
new_node = numpy_helper.from_array(new_weights.astype(orig_weights.dtype), weight_node.name)
orig_nodes[weight_idx] = base_model.graph.initializer[weight_idx]
del base_model.graph.initializer[weight_idx]
base_model.graph.initializer.insert(weight_idx, new_node)
elif matmul_key in node_idx:
weight_node = base_model.graph.node[node_idx[matmul_key]]
matmul_name = [n for n in weight_node.input if "MatMul" in n][0]
matmul_idx = initializer_idx[matmul_name]
matmul_node = base_model.graph.initializer[matmul_idx]
orig_weights = numpy_helper.to_array(matmul_node)
new_weights = orig_weights + weights.transpose()
# replace the original initializer
new_node = numpy_helper.from_array(new_weights.astype(orig_weights.dtype), matmul_node.name)
orig_nodes[matmul_idx] = base_model.graph.initializer[matmul_idx]
del base_model.graph.initializer[matmul_idx]
base_model.graph.initializer.insert(matmul_idx, new_node)
else:
# warn? err?
pass
yield
finally:
# restore original weights
for idx, orig_node in orig_nodes.items():
del base_model.graph.initializer[idx]
base_model.graph.initializer.insert(idx, orig_node)
@classmethod
@contextmanager
def apply_ti(
cls,
tokenizer: CLIPTokenizer,
text_encoder: IAIOnnxRuntimeModel,
ti_list: List[Any],
) -> Tuple[CLIPTokenizer, TextualInversionManager]:
from .models.base import IAIOnnxRuntimeModel
if not isinstance(text_encoder, IAIOnnxRuntimeModel):
raise Exception("Only IAIOnnxRuntimeModel models supported")
init_tokens_count = None
new_tokens_added = None
embeddings_node_orig = None
try:
ti_tokenizer = copy.deepcopy(tokenizer)
ti_manager = TextualInversionManager(ti_tokenizer)
def _get_trigger(ti, index):
trigger = ti.name
if index > 0:
trigger += f"-!pad-{index}"
return f"<{trigger}>"
# modify tokenizer
new_tokens_added = 0
for ti in ti_list:
for i in range(ti.embedding.shape[0]):
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
# modify text_encoder
for i in range(len(text_encoder.proto.graph.initializer)):
if text_encoder.proto.graph.initializer[i].name == "text_model.embeddings.token_embedding.weight":
embeddings_node_idx = i
break
else:
raise Exception("text_model.embeddings.token_embedding.weight node not found")
embeddings_node_orig = text_encoder.proto.graph.initializer[embeddings_node_idx]
base_weights = numpy_helper.to_array(embeddings_node_orig)
embedding_weights = np.concatenate((base_weights, np.zeros((new_tokens_added, base_weights.shape[1]))), axis=0)
for ti in ti_list:
ti_tokens = []
for i in range(ti.embedding.shape[0]):
embedding = ti.embedding[i].detach().numpy()
trigger = _get_trigger(ti, i)
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
if token_id == ti_tokenizer.unk_token_id:
raise RuntimeError(f"Unable to find token id for token '{trigger}'")
if embedding_weights[token_id].shape != embedding.shape:
raise ValueError(
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {embedding_weights[token_id].shape[0]}."
)
embedding_weights[token_id] = embedding
ti_tokens.append(token_id)
if len(ti_tokens) > 1:
ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]
new_embeddings_node = numpy_helper.from_array(embedding_weights.astype(base_weights.dtype), embeddings_node_orig.name)
del text_encoder.proto.graph.initializer[embeddings_node_idx]
text_encoder.proto.graph.initializer.insert(embeddings_node_idx, new_embeddings_node)
yield ti_tokenizer, ti_manager
finally:
# restore
if embeddings_node_orig is not None:
del text_encoder.proto.graph.initializer[embeddings_node_idx]
text_encoder.proto.graph.initializer.insert(embeddings_node_idx, embeddings_node_orig)
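A minimal usage sketch of the patcher above (not part of the commit); onnx_unet and lora are assumed placeholders for an IAIOnnxRuntimeModel and a loaded LoRAModel:

# sketch only: patch the UNet initializers, run, and let the context manager restore them
with ONNXModelPatcher.apply_lora_unet(onnx_unet, [(lora, 0.75)]):
    onnx_unet.create_session()   # serialize the patched proto into an InferenceSession
    # ... run the denoising loop against onnx_unet(...) here ...
    onnx_unet.release_session()
# on exit, the original weight initializers are re-inserted into the graph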

View File

@@ -9,9 +9,12 @@ from .lora import LoRAModel
from .controlnet import ControlNetModel # TODO:
from .textual_inversion import TextualInversionModel
from .stable_diffusion_onnx import ONNXStableDiffusion1Model, ONNXStableDiffusion2Model
MODEL_CLASSES = {
BaseModelType.StableDiffusion1: {
ModelType.Pipeline: StableDiffusion1Model,
ModelType.ONNX: ONNXStableDiffusion1Model,
ModelType.Vae: VaeModel,
ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel,
@@ -19,6 +22,7 @@ MODEL_CLASSES = {
},
BaseModelType.StableDiffusion2: {
ModelType.Pipeline: StableDiffusion2Model,
ModelType.ONNX: ONNXStableDiffusion2Model,
ModelType.Vae: VaeModel,
ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel,
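For orientation (not part of the diff): the registry is keyed by base model and model type, so ONNX pipelines resolve the same way as diffusers pipelines. A sketch using the names from this hunk; model_path is an assumed placeholder:

# sketch only: resolve and construct the ONNX SD1 model class via the registry
model_class = MODEL_CLASSES[BaseModelType.StableDiffusion1][ModelType.ONNX]
assert model_class is ONNXStableDiffusion1Model
model = model_class(model_path, BaseModelType.StableDiffusion1, ModelType.ONNX)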

View File

@@ -5,19 +5,27 @@ import inspect
from enum import Enum
from abc import ABCMeta, abstractmethod
import torch
import numpy as np
import safetensors.torch
from diffusers import DiffusionPipeline, ConfigMixin
from pathlib import Path
from diffusers import DiffusionPipeline, ConfigMixin, OnnxRuntimeModel
from contextlib import suppress
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
import onnx
from onnx import numpy_helper
from onnx.external_data_helper import set_external_data
from onnxruntime import InferenceSession, OrtValue, SessionOptions
class BaseModelType(str, Enum):
StableDiffusion1 = "sd-1"
StableDiffusion2 = "sd-2"
#Kandinsky2_1 = "kandinsky-2.1"
class ModelType(str, Enum):
ONNX = "onnx"
Pipeline = "pipeline"
Vae = "vae"
Lora = "lora"
@@ -29,6 +37,8 @@ class SubModelType(str, Enum):
TextEncoder = "text_encoder"
Tokenizer = "tokenizer"
Vae = "vae"
VaeDecoder = "vae_decoder"
VaeEncoder = "vae_encoder"
Scheduler = "scheduler"
SafetyChecker = "safety_checker"
#MoVQ = "movq"
@@ -240,16 +250,18 @@ class DiffusersModel(ModelBase):
try:
# TODO: set cache_dir to /dev/null to be sure that cache not used?
model = self.child_types[child_type].from_pretrained(
self.model_path,
subfolder=child_type.value,
os.path.join(self.model_path, child_type.value),
#subfolder=child_type.value,
torch_dtype=torch_dtype,
variant=variant,
local_files_only=True,
)
break
except Exception as e:
#print("====ERR LOAD====")
#print(f"{variant}: {e}")
print("====ERR LOAD====")
print(f"{variant}: {e}")
import traceback
traceback.print_exc()
pass
else:
raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model")
@@ -413,3 +425,92 @@ class SilenceWarnings(object):
transformers_logging.set_verbosity(self.transformers_verbosity)
diffusers_logging.set_verbosity(self.diffusers_verbosity)
warnings.simplefilter('default')
def buffer_external_data_tensors(model):
external_data = dict()
for tensor in model.graph.initializer:
name = tensor.name
if tensor.HasField("raw_data"):
npt = numpy_helper.to_array(tensor)
orv = OrtValue.ortvalue_from_numpy(npt)
external_data[name] = orv
set_external_data(tensor, location="tmp.bin")
tensor.name = name
tensor.ClearField("raw_data")
return (model, external_data)
ONNX_WEIGHTS_NAME = "model.onnx"
class IAIOnnxRuntimeModel(OnnxRuntimeModel):
def __init__(self, model: tuple, **kwargs):
self.proto, self.provider, self.sess_options = model
self.session = None
self._external_data = dict()
def __call__(self, **kwargs):
if self.session is None:
raise Exception("You should call create_session before running model")
inputs = {k: np.array(v) for k, v in kwargs.items()}
return self.session.run(None, inputs)
def create_session(self):
if self.session is None:
#onnx.save(self.proto, "tmp.onnx")
#onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False)
(trimmed_model, external_data) = buffer_external_data_tensors(self.proto)
sess = SessionOptions()
self._external_data.update(**external_data)
sess.add_external_initializers(list(self._external_data.keys()), list(self._external_data.values()))
self.session = InferenceSession(trimmed_model.SerializeToString(), providers=[self.provider], sess_options=sess)
#self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options)
def release_session(self):
self.session = None
import gc
gc.collect()
@staticmethod
def load_model(path: Union[str, Path], provider=None, sess_options=None):
"""
Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider`
Arguments:
path (`str` or `Path`):
Directory from which to load
provider(`str`, *optional*):
Onnxruntime execution provider to use for loading the model, defaults to `CPUExecutionProvider`
"""
if provider is None:
#logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
print("No onnxruntime provider specified, using CPUExecutionProvider")
provider = "CPUExecutionProvider"
# TODO: check that provider available?
return (onnx.load(path), provider, sess_options)
@classmethod
def _from_pretrained(
cls,
model_id: Union[str, Path],
use_auth_token: Optional[Union[bool, str, None]] = None,
revision: Optional[Union[str, None]] = None,
force_download: bool = False,
cache_dir: Optional[str] = None,
file_name: Optional[str] = None,
provider: Optional[str] = None,
sess_options: Optional["SessionOptions"] = None,
**kwargs,
):
model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
# load model from local directory
if not os.path.isdir(model_id):
raise Exception(f"Model not found: {model_id}")
model = IAIOnnxRuntimeModel.load_model(
os.path.join(model_id, model_file_name), provider=provider, sess_options=sess_options
)
return cls(model=model, **kwargs)
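A minimal end-to-end sketch of the class above (not part of the commit); the directory, the placeholder arrays (latents, timestep, embeddings), and the ONNX input names assume a typical diffusers-exported SD UNet:

# sketch only: load the proto, create a session backed by buffered initializers, run, release
proto_tuple = IAIOnnxRuntimeModel.load_model("path/to/unet/model.onnx", provider="CPUExecutionProvider")
unet = IAIOnnxRuntimeModel(model=proto_tuple)
unet.create_session()   # raw initializers become OrtValues fed via add_external_initializers
noise_pred = unet(sample=latents, timestep=timestep, encoder_hidden_states=embeddings)
unet.release_session()  # drop the InferenceSession, keep the in-memory proto for re-patching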

View File

@@ -0,0 +1,156 @@
import os
import json
from enum import Enum
from pydantic import Field
from pathlib import Path
from typing import Literal, Optional, Union
from .base import (
ModelBase,
ModelConfigBase,
BaseModelType,
ModelType,
SubModelType,
ModelVariantType,
DiffusersModel,
SchedulerPredictionType,
SilenceWarnings,
read_checkpoint_meta,
classproperty,
OnnxRuntimeModel,
IAIOnnxRuntimeModel,
)
from invokeai.app.services.config import InvokeAIAppConfig
class ONNXStableDiffusion1Model(DiffusersModel):
class Config(ModelConfigBase):
model_format: None
variant: ModelVariantType
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model == BaseModelType.StableDiffusion1
assert model_type == ModelType.ONNX
super().__init__(
model_path=model_path,
base_model=BaseModelType.StableDiffusion1,
model_type=ModelType.ONNX,
)
for child_name, child_type in self.child_types.items():
if child_type is OnnxRuntimeModel:
self.child_types[child_name] = IAIOnnxRuntimeModel
# TODO: check that no optimum models provided
@classmethod
def probe_config(cls, path: str, **kwargs):
model_format = cls.detect_format(path)
in_channels = 4 # TODO:
if in_channels == 9:
variant = ModelVariantType.Inpaint
elif in_channels == 4:
variant = ModelVariantType.Normal
else:
raise Exception("Unkown stable diffusion 1.* model format")
return cls.create_config(
path=path,
model_format=model_format,
variant=variant,
)
@classproperty
def save_to_config(cls) -> bool:
return True
@classmethod
def detect_format(cls, model_path: str):
return None
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
return model_path
class ONNXStableDiffusion2Model(DiffusersModel):
# TODO: check that configs are overwritten properly
class Config(ModelConfigBase):
model_format: None
variant: ModelVariantType
prediction_type: SchedulerPredictionType
upcast_attention: bool
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model == BaseModelType.StableDiffusion2
assert model_type == ModelType.ONNX
super().__init__(
model_path=model_path,
base_model=BaseModelType.StableDiffusion2,
model_type=ModelType.ONNX,
)
for child_name, child_type in self.child_types.items():
if child_type is OnnxRuntimeModel:
self.child_types[child_name] = IAIOnnxRuntimeModel
# TODO: check that no optimum models provided
@classmethod
def probe_config(cls, path: str, **kwargs):
model_format = cls.detect_format(path)
in_channels = 4 # TODO:
if in_channels == 9:
variant = ModelVariantType.Inpaint
elif in_channels == 5:
variant = ModelVariantType.Depth
elif in_channels == 4:
variant = ModelVariantType.Normal
else:
raise Exception("Unkown stable diffusion 2.* model format")
if variant == ModelVariantType.Normal:
prediction_type = SchedulerPredictionType.VPrediction
upcast_attention = True
else:
prediction_type = SchedulerPredictionType.Epsilon
upcast_attention = False
return cls.create_config(
path=path,
model_format=model_format,
variant=variant,
prediction_type=prediction_type,
upcast_attention=upcast_attention,
)
@classproperty
def save_to_config(cls) -> bool:
return True
@classmethod
def detect_format(cls, model_path: str):
return None
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
return model_path
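A hypothetical sketch (not in this commit) of one way the `in_channels = 4  # TODO:` placeholder above could later be resolved, assuming the ONNX export keeps the diffusers unet/config.json next to the exported graph; the helper name and layout are assumptions:

import json
import os

def _read_unet_in_channels(model_path: str, default: int = 4) -> int:
    # hypothetical helper, not part of the commit: read in_channels from the unet config if present
    config_file = os.path.join(model_path, "unet", "config.json")
    if not os.path.exists(config_file):
        return default
    with open(config_file) as f:
        return json.load(f).get("in_channels", default)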