Update ONNX model structure, change code accordingly

Sergey Borisov 2023-06-22 20:03:17 +03:00
parent 7759b3f75a
commit 6c7668aaca
3 changed files with 197 additions and 148 deletions

View File

@@ -63,20 +63,26 @@ class ONNXPromptInvocation(BaseInvocation):
             text_encoder_info as text_encoder,\
             ExitStack() as stack:
 
-            loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
+            #loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
+            loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
 
             ti_list = []
             for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
                 name = trigger[1:-1]
                 try:
                     ti_list.append(
-                        stack.enter_context(
-                            context.services.model_manager.get_model(
-                                model_name=name,
-                                base_model=self.clip.text_encoder.base_model,
-                                model_type=ModelType.TextualInversion,
-                            )
-                        )
+                        #stack.enter_context(
+                        #    context.services.model_manager.get_model(
+                        #        model_name=name,
+                        #        base_model=self.clip.text_encoder.base_model,
+                        #        model_type=ModelType.TextualInversion,
+                        #    )
+                        #)
+                        context.services.model_manager.get_model(
+                            model_name=name,
+                            base_model=self.clip.text_encoder.base_model,
+                            model_type=ModelType.TextualInversion,
+                        ).context.model
                     )
                 except Exception:
                     #print(e)
@@ -218,7 +224,8 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
         with unet_info as unet,\
             ExitStack() as stack:
 
-            loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
+            #loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
+            loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]
 
             with ONNXModelPatcher.apply_lora_unet(unet, loras):
                 # TODO:
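Both hunks in this file make the same substitution: rather than pinning each LoRA with stack.enter_context() for the life of the ExitStack, the code now takes the already-loaded model object directly off the ModelInfo's context. A minimal sketch of the two patterns, with model_manager and lora_defs as hypothetical stand-ins for the invocation context's services:

from contextlib import ExitStack

def load_loras_old(model_manager, lora_defs, stack: ExitStack):
    # Old pattern: enter each ModelInfo as a context manager, so it stays
    # pinned until the surrounding ExitStack unwinds.
    return [(stack.enter_context(model_manager.get_model(**d)), w) for d, w in lora_defs]

def load_loras_new(model_manager, lora_defs):
    # New pattern: take the already-loaded model object off the ModelInfo's
    # context and let the model cache manage its lifetime.
    return [(model_manager.get_model(**d).context.model, w) for d, w in lora_defs]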

View File

@@ -12,6 +12,7 @@ from torch.utils.hooks import RemovableHandle
 from diffusers.models import UNet2DConditionModel
 from transformers import CLIPTextModel
 from onnx import numpy_helper
+from onnxruntime import OrtValue
 import numpy as np
 
 from compel.embeddings_provider import BaseTextualInversionManager
@@ -718,8 +719,7 @@ class ONNXModelPatcher:
         if not isinstance(model, IAIOnnxRuntimeModel):
             raise Exception("Only IAIOnnxRuntimeModel models supported")
 
-        base_model = model.proto
-        orig_nodes = dict()
+        orig_weights = dict()
 
         try:
             blended_loras = dict()
@@ -736,68 +736,49 @@
                     else:
                         blended_loras[layer_key] = layer_weight
 
-            initializer_idx = dict()
-            for idx, init in enumerate(base_model.graph.initializer):
-                initializer_idx[init.name.replace(".", "_")] = idx
-
-            node_idx = dict()
-            for idx, node in enumerate(base_model.graph.node):
-                node_idx[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = idx
+            node_names = dict()
+            for node in model.nodes.values():
+                node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name
 
-            for layer_key, weights in blended_loras.items():
+            for layer_key, lora_weight in blended_loras.items():
                 conv_key = layer_key + "_Conv"
                 gemm_key = layer_key + "_Gemm"
                 matmul_key = layer_key + "_MatMul"
 
-                if conv_key in node_idx or gemm_key in node_idx:
-                    if conv_key in node_idx:
-                        conv_node = base_model.graph.node[node_idx[conv_key]]
+                if conv_key in node_names or gemm_key in node_names:
+                    if conv_key in node_names:
+                        conv_node = model.nodes[node_names[conv_key]]
                     else:
-                        conv_node = base_model.graph.node[node_idx[gemm_key]]
+                        conv_node = model.nodes[node_names[gemm_key]]
 
                     weight_name = [n for n in conv_node.input if ".weight" in n][0]
-                    weight_name = weight_name.replace(".", "_")
-
-                    weight_idx = initializer_idx[weight_name]
-                    weight_node = base_model.graph.initializer[weight_idx]
-
-                    orig_weights = numpy_helper.to_array(weight_node)
+                    orig_weight = model.tensors[weight_name]
 
-                    if orig_weights.shape[-2:] == (1, 1):
-                        if weights.shape[-2:] == (1, 1):
-                            new_weights = orig_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
+                    if orig_weight.shape[-2:] == (1, 1):
+                        if lora_weight.shape[-2:] == (1, 1):
+                            new_weight = orig_weight.squeeze((3, 2)) + lora_weight.squeeze((3, 2))
                         else:
-                            new_weights = orig_weights.squeeze((3, 2)) + weights
-                        new_weights = np.expand_dims(new_weights, (2, 3))
+                            new_weight = orig_weight.squeeze((3, 2)) + lora_weight
+                        new_weight = np.expand_dims(new_weight, (2, 3))
                     else:
-                        if orig_weights.shape != weights.shape:
-                            new_weights = orig_weights + weights.reshape(orig_weights.shape)
+                        if orig_weight.shape != lora_weight.shape:
+                            new_weight = orig_weight + lora_weight.reshape(orig_weight.shape)
                         else:
-                            new_weights = orig_weights + weights
+                            new_weight = orig_weight + lora_weight
 
-                    new_node = numpy_helper.from_array(new_weights.astype(orig_weights.dtype), weight_node.name)
-                    orig_nodes[weight_idx] = base_model.graph.initializer[weight_idx]
-                    del base_model.graph.initializer[weight_idx]
-                    base_model.graph.initializer.insert(weight_idx, new_node)
-
-                elif matmul_key in node_idx:
-                    weight_node = base_model.graph.node[node_idx[matmul_key]]
+                    orig_weights[weight_name] = orig_weight
+                    model.tensors[weight_name] = new_weight.astype(orig_weight.dtype)
+
+                elif matmul_key in node_names:
+                    weight_node = model.nodes[node_names[matmul_key]]
                     matmul_name = [n for n in weight_node.input if "MatMul" in n][0]
 
-                    matmul_idx = initializer_idx[matmul_name]
-                    matmul_node = base_model.graph.initializer[matmul_idx]
-
-                    orig_weights = numpy_helper.to_array(matmul_node)
-                    new_weights = orig_weights + weights.transpose()
-
-                    # replace the original initializer
-                    new_node = numpy_helper.from_array(new_weights.astype(orig_weights.dtype), matmul_node.name)
-                    orig_nodes[matmul_idx] = base_model.graph.initializer[matmul_idx]
-                    del base_model.graph.initializer[matmul_idx]
-                    base_model.graph.initializer.insert(matmul_idx, new_node)
+                    orig_weight = model.tensors[matmul_name]
+                    new_weight = orig_weight + lora_weight.transpose()
+
+                    orig_weights[matmul_name] = orig_weight
+                    model.tensors[matmul_name] = new_weight.astype(orig_weight.dtype)
 
                 else:
                     # warn? err?
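The rewritten loop no longer splices replacement initializer protos into the graph; it assigns blended weights through model.tensors and stashes the originals for the finally block. The arithmetic itself is unchanged: the blended LoRA delta is added to the original weight, with the trailing 1x1 dims squeezed for pointwise convolutions and a transpose for MatMul initializers. A standalone numpy sketch of those shape cases (blend_weight is a hypothetical helper, all sizes made up):

import numpy as np

def blend_weight(orig_weight: np.ndarray, lora_weight: np.ndarray) -> np.ndarray:
    # 1x1 conv weights are (out, in, 1, 1): drop the spatial dims, add the
    # delta, then restore them so the tensor keeps its original rank.
    if orig_weight.shape[-2:] == (1, 1):
        if lora_weight.shape[-2:] == (1, 1):
            blended = orig_weight.squeeze((3, 2)) + lora_weight.squeeze((3, 2))
        else:
            blended = orig_weight.squeeze((3, 2)) + lora_weight
        return np.expand_dims(blended, (2, 3))
    # Otherwise reshape the delta to the stored layout if needed and add.
    if orig_weight.shape != lora_weight.shape:
        return orig_weight + lora_weight.reshape(orig_weight.shape)
    return orig_weight + lora_weight

# MatMul initializers are stored transposed relative to the LoRA delta,
# hence the extra .transpose() in the elif branch above:
w = np.zeros((320, 768), dtype=np.float32)      # made-up initializer
delta = np.full((768, 320), 0.01, np.float32)   # made-up blended delta
patched = (w + delta.transpose()).astype(w.dtype)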
@@ -807,9 +788,8 @@
         finally:
             # restore original weights
-            for idx, orig_node in orig_nodes.items():
-                del base_model.graph.initializer[idx]
-                base_model.graph.initializer.insert(idx, orig_node)
+            for name, orig_weight in orig_weights.items():
+                model.tensors[name] = orig_weight
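Because patched weights now live behind model.tensors, the restore path is a plain write-back of the stashed arrays. The same patch-in-try / restore-in-finally discipline, as a generic sketch (patched_tensors is a hypothetical helper, not the project's API):

from contextlib import contextmanager

@contextmanager
def patched_tensors(tensors, updates):
    # Stash the originals, apply the patch, and restore in finally so the
    # model is left untouched even if the body raises.
    orig = {name: tensors[name] for name in updates}
    try:
        for name, value in updates.items():
            tensors[name] = value
        yield tensors
    finally:
        for name, value in orig.items():
            tensors[name] = value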
@@ -825,8 +805,7 @@
         if not isinstance(text_encoder, IAIOnnxRuntimeModel):
             raise Exception("Only IAIOnnxRuntimeModel models supported")
 
-        init_tokens_count = None
-        new_tokens_added = None
+        orig_embeddings = None
 
         try:
             ti_tokenizer = copy.deepcopy(tokenizer)
@@ -845,17 +824,15 @@
                     new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
 
             # modify text_encoder
-            for i in range(len(text_encoder.proto.graph.initializer)):
-                if text_encoder.proto.graph.initializer[i].name == "text_model.embeddings.token_embedding.weight":
-                    embeddings_node_idx = i
-                    break
-            else:
-                raise Exception("text_model.embeddings.token_embedding.weight node not found")
+            orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]
 
-            embeddings_node_orig = text_encoder.proto.graph.initializer[embeddings_node_idx]
-            base_weights = numpy_helper.to_array(embeddings_node_orig)
-
-            embedding_weights = np.concatenate((base_weights, np.zeros((new_tokens_added, base_weights.shape[1]))), axis=0)
+            embeddings = np.concatenate(
+                (
+                    np.copy(orig_embeddings),
+                    np.zeros((new_tokens_added, orig_embeddings.shape[1]))
+                ),
+                axis=0,
+            )
 
             for ti in ti_list:
                 ti_tokens = []
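On the tokenizer side, the trigger words are added to a deep copy so the shared tokenizer is never mutated, and each added token's id later indexes a fresh row of the embedding table. A sketch assuming the standard transformers CLIPTokenizer API, with made-up trigger names (the real code derives them via _get_trigger(ti, i)):

import copy
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
ti_tokenizer = copy.deepcopy(tokenizer)  # never mutate the shared tokenizer

new_tokens_added = ti_tokenizer.add_tokens(["<my-ti>", "<my-ti-pad-1>"])
token_id = ti_tokenizer.convert_tokens_to_ids("<my-ti>")
assert token_id != ti_tokenizer.unk_token_id
# token_id indexes one of the rows appended to the embedding table.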
@@ -867,26 +844,22 @@
                     if token_id == ti_tokenizer.unk_token_id:
                         raise RuntimeError(f"Unable to find token id for token '{trigger}'")
 
-                    if embedding_weights[token_id].shape != embedding.shape:
+                    if embeddings[token_id].shape != embedding.shape:
                         raise ValueError(
-                            f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {embedding_weights[token_id].shape[0]}."
+                            f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {embeddings[token_id].shape[0]}."
                         )
 
-                    embedding_weights[token_id] = embedding
+                    embeddings[token_id] = embedding
                     ti_tokens.append(token_id)
 
                 if len(ti_tokens) > 1:
                     ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]
 
-            new_embeddings_node = numpy_helper.from_array(embedding_weights.astype(base_weights.dtype), embeddings_node_orig.name)
-            del text_encoder.proto.graph.initializer[embeddings_node_idx]
-            text_encoder.proto.graph.initializer.insert(embeddings_node_idx, new_embeddings_node)
+            text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = embeddings.astype(orig_embeddings.dtype)
 
             yield ti_tokenizer, ti_manager
 
         finally:
             # restore
-            if embeddings_node_orig is not None:
-                del text_encoder.proto.graph.initializer[embeddings_node_idx]
-                text_encoder.proto.graph.initializer.insert(embeddings_node_idx, embeddings_node_orig)
+            if orig_embeddings is not None:
+                text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = orig_embeddings
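The embedding patch reads the token-embedding tensor once, appends a zero row per new token, fills those rows from the textual-inversion files, and writes the whole table back through text_encoder.tensors in one shot. The array manipulation in isolation (all sizes made up):

import numpy as np

# Made-up sizes: a CLIP text encoder with a 49408-token vocabulary and
# 768-dim embeddings, extended by two textual-inversion tokens.
orig_embeddings = np.zeros((49408, 768), dtype=np.float32)
new_tokens_added = 2

embeddings = np.concatenate(
    (np.copy(orig_embeddings), np.zeros((new_tokens_added, orig_embeddings.shape[1]))),
    axis=0,
)

ti_vector = np.ones(768, dtype=np.float32)  # stand-in TI embedding
token_id = orig_embeddings.shape[0]         # id of the first appended token
assert embeddings[token_id].shape == ti_vector.shape
embeddings[token_id] = ti_vector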

View File

@@ -426,27 +426,127 @@ class SilenceWarnings(object):
         diffusers_logging.set_verbosity(self.diffusers_verbosity)
         warnings.simplefilter('default')
 
 def buffer_external_data_tensors(model):
     external_data = dict()
     for tensor in model.graph.initializer:
         name = tensor.name
 
         if tensor.HasField("raw_data"):
             npt = numpy_helper.to_array(tensor)
             orv = OrtValue.ortvalue_from_numpy(npt)
             external_data[name] = orv
             set_external_data(tensor, location="tmp.bin")
 
             tensor.name = name
             tensor.ClearField("raw_data")
 
     return (model, external_data)
 
 ONNX_WEIGHTS_NAME = "model.onnx"
 
-class IAIOnnxRuntimeModel(OnnxRuntimeModel):
-    def __init__(self, model: tuple, **kwargs):
-        self.proto, self.provider, self.sess_options = model
+class IAIOnnxRuntimeModel:
+    class _tensor_access:
+        def __init__(self, model):
+            self.model = model
+            self.indexes = dict()
+            for idx, obj in enumerate(self.model.proto.graph.initializer):
+                self.indexes[obj.name] = idx
+
+        def __getitem__(self, key: str):
+            return self.model.data[key].numpy()
+
+        def __setitem__(self, key: str, value: np.ndarray):
+            new_node = numpy_helper.from_array(value)
+            set_external_data(new_node, location="in-memory-location")
+            new_node.name = key
+            new_node.ClearField("raw_data")
+            del self.model.proto.graph.initializer[self.indexes[key]]
+            self.model.proto.graph.initializer.insert(self.indexes[key], new_node)
+            self.model.data[key] = OrtValue.ortvalue_from_numpy(value)
+
+        # __delitem__
+
+        def __contains__(self, key: str):
+            return key in self.model.data
+
+        def items(self):
+            raise NotImplementedError("tensor.items")
+            #return [(obj.name, obj) for obj in self.raw_proto]
+
+        def keys(self):
+            return self.model.data.keys()
+
+        def values(self):
+            raise NotImplementedError("tensor.values")
+            #return [obj for obj in self.raw_proto]
+
+    class _access_helper:
+        def __init__(self, raw_proto):
+            self.indexes = dict()
+            self.raw_proto = raw_proto
+            for idx, obj in enumerate(raw_proto):
+                self.indexes[obj.name] = idx
+
+        def __getitem__(self, key: str):
+            return self.raw_proto[self.indexes[key]]
+
+        def __setitem__(self, key: str, value):
+            index = self.indexes[key]
+            del self.raw_proto[index]
+            self.raw_proto.insert(index, value)
+
+        # __delitem__
+
+        def __contains__(self, key: str):
+            return key in self.indexes
+
+        def items(self):
+            return [(obj.name, obj) for obj in self.raw_proto]
+
+        def keys(self):
+            return self.indexes.keys()
+
+        def values(self):
+            return [obj for obj in self.raw_proto]
+
+    def __init__(self, model_path: str, provider: Optional[str]):
+        self.path = model_path
+        self.session = None
+        self._external_data = dict()
+        self.provider = provider or "CPUExecutionProvider"
+        """
+        self.data_path = self.path + "_data"
+        if not os.path.exists(self.data_path):
+            print(f"Moving model tensors to separate file: {self.data_path}")
+            tmp_proto = onnx.load(model_path, load_external_data=True)
+            onnx.save_model(tmp_proto, self.path, save_as_external_data=True, all_tensors_to_one_file=True, location=os.path.basename(self.data_path), size_threshold=1024, convert_attribute=False)
+            del tmp_proto
+            gc.collect()
+
+        self.proto = onnx.load(model_path, load_external_data=False)
+        """
+
+        self.proto = onnx.load(model_path, load_external_data=True)
+        self.data = dict()
+        for tensor in self.proto.graph.initializer:
+            name = tensor.name
+
+            if tensor.HasField("raw_data"):
+                npt = numpy_helper.to_array(tensor)
+                orv = OrtValue.ortvalue_from_numpy(npt)
+                self.data[name] = orv
+                set_external_data(tensor, location="in-memory-location")
+                tensor.name = name
+                tensor.ClearField("raw_data")
+
+        self.nodes = self._access_helper(self.proto.graph.node)
+        self.initializers = self._access_helper(self.proto.graph.initializer)
+
+        self.tensors = self._tensor_access(self)
+
+    # TODO: integrate with model manager/cache
+    def create_session(self):
+        if self.session is None:
+            #onnx.save(self.proto, "tmp.onnx")
+            #onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False)
+            # TODO: something to be able to get weight when they already moved outside of model proto
+            #(trimmed_model, external_data) = buffer_external_data_tensors(self.proto)
+            sess = SessionOptions()
+            #self._external_data.update(**external_data)
+            sess.add_external_initializers(list(self.data.keys()), list(self.data.values()))
+            self.session = InferenceSession(self.proto.SerializeToString(), providers=[self.provider], sess_options=sess)
+            #self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options)
+
+    def release_session(self):
+        self.session = None
+        import gc
+        gc.collect()
 
     def __call__(self, **kwargs):
         if self.session is None:
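The new class keeps every initializer's payload as an onnxruntime OrtValue and feeds it back through SessionOptions.add_external_initializers, so the weights live in memory exactly once while the stripped proto is serialized for session creation. A condensed, self-contained sketch of that technique (the model path and external-data location are placeholders):

import onnx
from onnx import numpy_helper
from onnx.external_data_helper import set_external_data
from onnxruntime import InferenceSession, OrtValue, SessionOptions

# Load the proto with weights resident, then move each initializer's
# raw_data into an OrtValue and mark the tensor as externally stored.
proto = onnx.load("tiny.onnx", load_external_data=True)  # made-up path
data = {}
for tensor in proto.graph.initializer:
    if tensor.HasField("raw_data"):
        name = tensor.name
        data[name] = OrtValue.ortvalue_from_numpy(numpy_helper.to_array(tensor))
        set_external_data(tensor, location="in-memory-location")  # dummy location
        tensor.name = name
        tensor.ClearField("raw_data")

# Hand the buffered weights back to onnxruntime at session creation, so the
# now-weightless proto serializes cheaply.
opts = SessionOptions()
opts.add_external_initializers(list(data.keys()), list(data.values()))
session = InferenceSession(proto.SerializeToString(), sess_options=opts, providers=["CPUExecutionProvider"])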
@@ -455,63 +555,32 @@ class IAIOnnxRuntimeModel(OnnxRuntimeModel):
         inputs = {k: np.array(v) for k, v in kwargs.items()}
         return self.session.run(None, inputs)
 
-    def create_session(self):
-        if self.session is None:
-            #onnx.save(self.proto, "tmp.onnx")
-            #onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False)
-            # TODO: something to be able to get weight when they already moved outside of model proto
-            (trimmed_model, external_data) = buffer_external_data_tensors(self.proto)
-            sess = SessionOptions()
-            self._external_data.update(**external_data)
-            sess.add_external_initializers(list(self._external_data.keys()), list(self._external_data.values()))
-            self.session = InferenceSession(trimmed_model.SerializeToString(), providers=[self.provider], sess_options=sess)
-            #self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options)
-
-    def release_session(self):
-        self.session = None
-        import gc
-        gc.collect()
-
-    @staticmethod
-    def load_model(path: Union[str, Path], provider=None, sess_options=None):
-        """
-        Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider`
-
-        Arguments:
-            path (`str` or `Path`):
-                Directory from which to load
-            provider(`str`, *optional*):
-                Onnxruntime execution provider to use for loading the model, defaults to `CPUExecutionProvider`
-        """
-        if provider is None:
-            #logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
-            print("No onnxruntime provider specified, using CPUExecutionProvider")
-            provider = "CPUExecutionProvider"
-
-        # TODO: check that provider available?
-        return (onnx.load(path), provider, sess_options)
-
     # compatability with diffusers load code
     @classmethod
-    def _from_pretrained(
+    def from_pretrained(
         cls,
         model_id: Union[str, Path],
-        use_auth_token: Optional[Union[bool, str, None]] = None,
-        revision: Optional[Union[str, None]] = None,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
         subfolder: Union[str, Path] = None,
         file_name: Optional[str] = None,
         provider: Optional[str] = None,
         sess_options: Optional["SessionOptions"] = None,
         **kwargs,
     ):
-        model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
+        file_name = file_name or ONNX_WEIGHTS_NAME
+        if os.path.isdir(model_id):
+            model_path = model_id
+            if subfolder is not None:
+                model_path = os.path.join(model_path, subfolder)
+            model_path = os.path.join(model_path, file_name)
+        else:
+            model_path = model_id
 
-        # load model from local directory
-        if not os.path.isdir(model_id):
-            raise Exception(f"Model not found: {model_id}")
-        model = IAIOnnxRuntimeModel.load_model(
-            os.path.join(model_id, model_file_name), provider=provider, sess_options=sess_options
-        )
-
-        return cls(model=model, **kwargs)
+        if not os.path.isfile(model_path):
+            raise Exception(f"Model not found: {model_path}")
+
+        # TODO: session options
+        return cls(model_path, provider=provider)
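from_pretrained now resolves the model file path itself and returns an instance that owns the proto and the buffered weights; session creation is deferred until create_session() is called. A hypothetical usage sketch, assuming a local diffusers-style ONNX export (the path is made up; file_name defaults to model.onnx):

text_encoder = IAIOnnxRuntimeModel.from_pretrained(
    "/models/sd15-onnx",
    subfolder="text_encoder",
    provider="CPUExecutionProvider",
)
text_encoder.create_session()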