Update onnx model structure, change code accordingly
This commit is contained in:
parent
7759b3f75a
commit
6c7668aaca
@@ -63,20 +63,26 @@ class ONNXPromptInvocation(BaseInvocation):
             text_encoder_info as text_encoder,\
             ExitStack() as stack:

            loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
            #loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
            loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]

            ti_list = []
            for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
                name = trigger[1:-1]
                try:
                    ti_list.append(
                        stack.enter_context(
                            context.services.model_manager.get_model(
                                model_name=name,
                                base_model=self.clip.text_encoder.base_model,
                                model_type=ModelType.TextualInversion,
                            )
                        )
                        #stack.enter_context(
                        #    context.services.model_manager.get_model(
                        #        model_name=name,
                        #        base_model=self.clip.text_encoder.base_model,
                        #        model_type=ModelType.TextualInversion,
                        #    )
                        #)
                        context.services.model_manager.get_model(
                            model_name=name,
                            base_model=self.clip.text_encoder.base_model,
                            model_type=ModelType.TextualInversion,
                        ).context.model
                    )
                except Exception:
                    #print(e)
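For context, the hunk above stops entering the LoRA/textual-inversion models on the ExitStack and instead reads the already-loaded model object from the model manager via `.context.model`. A minimal sketch of the new access pattern, assuming the model-manager API used in this diff (the `context` and `lora` objects are the invocation's own; nothing here is part of the commit):

    # get_model(...) returns an object whose .context.model is the loaded ONNX model
    lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))
    lora_model = lora_info.context.model        # use the loaded model directly, no context manager
    loras = [(lora_model, lora.weight)]         # (model, weight) pairs consumed by ONNXModelPatcher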
@@ -218,7 +224,8 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
        with unet_info as unet,\
             ExitStack() as stack:

            loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
            #loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
            loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]

            with ONNXModelPatcher.apply_lora_unet(unet, loras):
                # TODO:
@@ -12,6 +12,7 @@ from torch.utils.hooks import RemovableHandle
from diffusers.models import UNet2DConditionModel
from transformers import CLIPTextModel
from onnx import numpy_helper
from onnxruntime import OrtValue
import numpy as np

from compel.embeddings_provider import BaseTextualInversionManager
@@ -718,8 +719,7 @@ class ONNXModelPatcher:
        if not isinstance(model, IAIOnnxRuntimeModel):
            raise Exception("Only IAIOnnxRuntimeModel models supported")

        base_model = model.proto
        orig_nodes = dict()
        orig_weights = dict()

        try:
            blended_loras = dict()
@@ -736,68 +736,49 @@ class ONNXModelPatcher:
                else:
                    blended_loras[layer_key] = layer_weight

            initializer_idx = dict()
            for idx, init in enumerate(base_model.graph.initializer):
                initializer_idx[init.name.replace(".", "_")] = idx
            node_names = dict()
            for node in model.nodes.values():
                node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name

            node_idx = dict()
            for idx, node in enumerate(base_model.graph.node):
                node_idx[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = idx

            for layer_key, weights in blended_loras.items():
            for layer_key, lora_weight in blended_loras.items():
                conv_key = layer_key + "_Conv"
                gemm_key = layer_key + "_Gemm"
                matmul_key = layer_key + "_MatMul"

                if conv_key in node_idx or gemm_key in node_idx:
                    if conv_key in node_idx:
                        conv_node = base_model.graph.node[node_idx[conv_key]]
                if conv_key in node_names or gemm_key in node_names:
                    if conv_key in node_names:
                        conv_node = model.nodes[node_names[conv_key]]
                    else:
                        conv_node = base_model.graph.node[node_idx[gemm_key]]
                        conv_node = model.nodes[node_names[gemm_key]]

                    weight_name = [n for n in conv_node.input if ".weight" in n][0]
                    weight_name = weight_name.replace(".", "_")
                    orig_weight = model.tensors[weight_name]

                    weight_idx = initializer_idx[weight_name]
                    weight_node = base_model.graph.initializer[weight_idx]

                    orig_weights = numpy_helper.to_array(weight_node)

                    if orig_weights.shape[-2:] == (1, 1):
                        if weights.shape[-2:] == (1, 1):
                            new_weights = orig_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
                    if orig_weight.shape[-2:] == (1, 1):
                        if lora_weight.shape[-2:] == (1, 1):
                            new_weight = orig_weight.squeeze((3, 2)) + lora_weight.squeeze((3, 2))
                        else:
                            new_weights = orig_weights.squeeze((3, 2)) + weights
                            new_weight = orig_weight.squeeze((3, 2)) + lora_weight

                        new_weights = np.expand_dims(new_weights, (2, 3))
                        new_weight = np.expand_dims(new_weight, (2, 3))
                    else:
                        if orig_weights.shape != weights.shape:
                            new_weights = orig_weights + weights.reshape(orig_weights.shape)
                        if orig_weight.shape != lora_weight.shape:
                            new_weight = orig_weight + lora_weight.reshape(orig_weight.shape)
                        else:
                            new_weights = orig_weights + weights
                            new_weight = orig_weight + lora_weight

                    new_node = numpy_helper.from_array(new_weights.astype(orig_weights.dtype), weight_node.name)
                    orig_nodes[weight_idx] = base_model.graph.initializer[weight_idx]
                    del base_model.graph.initializer[weight_idx]
                    base_model.graph.initializer.insert(weight_idx, new_node)

                elif matmul_key in node_idx:
                    weight_node = base_model.graph.node[node_idx[matmul_key]]
                    orig_weights[weight_name] = orig_weight
                    model.tensors[weight_name] = new_weight.astype(orig_weight.dtype)

                elif matmul_key in node_names:
                    weight_node = model.nodes[node_names[matmul_key]]
                    matmul_name = [n for n in weight_node.input if "MatMul" in n][0]

                    matmul_idx = initializer_idx[matmul_name]
                    matmul_node = base_model.graph.initializer[matmul_idx]
                    orig_weight = model.tensors[matmul_name]
                    new_weight = orig_weight + lora_weight.transpose()

                    orig_weights = numpy_helper.to_array(matmul_node)

                    new_weights = orig_weights + weights.transpose()

                    # replace the original initializer
                    new_node = numpy_helper.from_array(new_weights.astype(orig_weights.dtype), matmul_node.name)
                    orig_nodes[matmul_idx] = base_model.graph.initializer[matmul_idx]
                    del base_model.graph.initializer[matmul_idx]
                    base_model.graph.initializer.insert(matmul_idx, new_node)
                    orig_weights[matmul_name] = orig_weight
                    model.tensors[matmul_name] = new_weight.astype(orig_weight.dtype)

                else:
                    # warn? err?
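The patcher above now blends LoRA deltas directly into `model.tensors` instead of splicing new initializers into the graph. As a standalone illustration of the arithmetic only (shapes below are made up for the example and are not taken from the commit), the Conv/Gemm branch collapses the trailing 1x1 kernel dims before adding, while the MatMul branch adds the transposed delta:

    import numpy as np

    orig_weight = np.random.rand(320, 320, 1, 1).astype(np.float32)   # 1x1 Conv kernel read from model.tensors
    lora_weight = np.random.rand(320, 320, 1, 1).astype(np.float32)   # blended LoRA delta for that layer

    if orig_weight.shape[-2:] == (1, 1):
        # drop the trailing (1, 1) dims, add, then restore them
        new_weight = orig_weight.squeeze((3, 2)) + lora_weight.squeeze((3, 2))
        new_weight = np.expand_dims(new_weight, (2, 3))
    else:
        new_weight = orig_weight + lora_weight.reshape(orig_weight.shape)

    # For MatMul initializers the stored weight is transposed relative to the LoRA
    # layout, so that branch computes orig_weight + lora_weight.transpose() instead.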
@@ -807,9 +788,8 @@ class ONNXModelPatcher:

        finally:
            # restore original weights
            for idx, orig_node in orig_nodes.items():
                del base_model.graph.initializer[idx]
                base_model.graph.initializer.insert(idx, orig_node)
            for name, orig_weight in orig_weights.items():
                model.tensors[name] = orig_weight
@@ -825,8 +805,7 @@ class ONNXModelPatcher:
        if not isinstance(text_encoder, IAIOnnxRuntimeModel):
            raise Exception("Only IAIOnnxRuntimeModel models supported")

        init_tokens_count = None
        new_tokens_added = None
        orig_embeddings = None

        try:
            ti_tokenizer = copy.deepcopy(tokenizer)
@@ -845,17 +824,15 @@ class ONNXModelPatcher:
                    new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))

            # modify text_encoder
            for i in range(len(text_encoder.proto.graph.initializer)):
                if text_encoder.proto.graph.initializer[i].name == "text_model.embeddings.token_embedding.weight":
                    embeddings_node_idx = i
                    break
            else:
                raise Exception("text_model.embeddings.token_embedding.weight node not found")
            orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]

            embeddings_node_orig = text_encoder.proto.graph.initializer[embeddings_node_idx]
            base_weights = numpy_helper.to_array(embeddings_node_orig)

            embedding_weights = np.concatenate((base_weights, np.zeros((new_tokens_added, base_weights.shape[1]))), axis=0)
            embeddings = np.concatenate(
                (
                    np.copy(orig_embeddings),
                    np.zeros((new_tokens_added, orig_embeddings.shape[1]))
                ),
                axis=0,
            )

            for ti in ti_list:
                ti_tokens = []
@@ -867,26 +844,22 @@ class ONNXModelPatcher:
                    if token_id == ti_tokenizer.unk_token_id:
                        raise RuntimeError(f"Unable to find token id for token '{trigger}'")

                    if embedding_weights[token_id].shape != embedding.shape:
                    if embeddings[token_id].shape != embedding.shape:
                        raise ValueError(
                            f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {embedding_weights[token_id].shape[0]}."
                            f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {embeddings[token_id].shape[0]}."
                        )

                    embedding_weights[token_id] = embedding
                    embeddings[token_id] = embedding
                    ti_tokens.append(token_id)

                if len(ti_tokens) > 1:
                    ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]

            new_embeddings_node = numpy_helper.from_array(embedding_weights.astype(base_weights.dtype), embeddings_node_orig.name)
            del text_encoder.proto.graph.initializer[embeddings_node_idx]
            text_encoder.proto.graph.initializer.insert(embeddings_node_idx, new_embeddings_node)
            text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = embeddings.astype(orig_embeddings.dtype)

            yield ti_tokenizer, ti_manager

        finally:
            # restore
            if embeddings_node_orig is not None:
                del text_encoder.proto.graph.initializer[embeddings_node_idx]
                text_encoder.proto.graph.initializer.insert(embeddings_node_idx, embeddings_node_orig)
            if orig_embeddings is not None:
                text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = orig_embeddings
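The textual-inversion path above now grows the token-embedding matrix through the same `tensors` accessor rather than rebuilding the initializer node. A small numpy sketch of the resize-and-assign step, with illustrative dimensions that are not taken from the commit:

    import numpy as np

    orig_embeddings = np.random.rand(49408, 768).astype(np.float32)   # token embedding table (example size)
    new_tokens_added = 2                                               # trigger tokens added to the tokenizer
    embedding = np.random.rand(768).astype(np.float32)                 # vector loaded from a TI file

    # append zero rows for the new tokens, then write each trigger's vector into its row
    embeddings = np.concatenate(
        (np.copy(orig_embeddings), np.zeros((new_tokens_added, orig_embeddings.shape[1]))),
        axis=0,
    )
    token_id = orig_embeddings.shape[0]       # id assigned to the first newly added token
    embeddings[token_id] = embedding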
@@ -426,27 +426,127 @@ class SilenceWarnings(object):
        diffusers_logging.set_verbosity(self.diffusers_verbosity)
        warnings.simplefilter('default')

def buffer_external_data_tensors(model):
    external_data = dict()
    for tensor in model.graph.initializer:
        name = tensor.name

        if tensor.HasField("raw_data"):
            npt = numpy_helper.to_array(tensor)
            orv = OrtValue.ortvalue_from_numpy(npt)
            external_data[name] = orv
            set_external_data(tensor, location="tmp.bin")
            tensor.name = name
            tensor.ClearField("raw_data")

    return (model, external_data)

ONNX_WEIGHTS_NAME = "model.onnx"
class IAIOnnxRuntimeModel(OnnxRuntimeModel):
    def __init__(self, model: tuple, **kwargs):
        self.proto, self.provider, self.sess_options = model
class IAIOnnxRuntimeModel:
    class _tensor_access:

        def __init__(self, model):
            self.model = model
            self.indexes = dict()
            for idx, obj in enumerate(self.model.proto.graph.initializer):
                self.indexes[obj.name] = idx

        def __getitem__(self, key: str):
            return self.model.data[key].numpy()

        def __setitem__(self, key: str, value: np.ndarray):
            new_node = numpy_helper.from_array(value)
            set_external_data(new_node, location="in-memory-location")
            new_node.name = key
            new_node.ClearField("raw_data")
            del self.model.proto.graph.initializer[self.indexes[key]]
            self.model.proto.graph.initializer.insert(self.indexes[key], new_node)
            self.model.data[key] = OrtValue.ortvalue_from_numpy(value)

        # __delitem__

        def __contains__(self, key: str):
            return key in self.model.data

        def items(self):
            raise NotImplementedError("tensor.items")
            #return [(obj.name, obj) for obj in self.raw_proto]

        def keys(self):
            return self.model.data.keys()

        def values(self):
            raise NotImplementedError("tensor.values")
            #return [obj for obj in self.raw_proto]


    class _access_helper:
        def __init__(self, raw_proto):
            self.indexes = dict()
            self.raw_proto = raw_proto
            for idx, obj in enumerate(raw_proto):
                self.indexes[obj.name] = idx

        def __getitem__(self, key: str):
            return self.raw_proto[self.indexes[key]]

        def __setitem__(self, key: str, value):
            index = self.indexes[key]
            del self.raw_proto[index]
            self.raw_proto.insert(index, value)

        # __delitem__

        def __contains__(self, key: str):
            return key in self.indexes

        def items(self):
            return [(obj.name, obj) for obj in self.raw_proto]

        def keys(self):
            return self.indexes.keys()

        def values(self):
            return [obj for obj in self.raw_proto]

    def __init__(self, model_path: str, provider: Optional[str]):
        self.path = model_path
        self.session = None
        self._external_data = dict()
        self.provider = provider or "CPUExecutionProvider"
        """
        self.data_path = self.path + "_data"
        if not os.path.exists(self.data_path):
            print(f"Moving model tensors to separate file: {self.data_path}")
            tmp_proto = onnx.load(model_path, load_external_data=True)
            onnx.save_model(tmp_proto, self.path, save_as_external_data=True, all_tensors_to_one_file=True, location=os.path.basename(self.data_path), size_threshold=1024, convert_attribute=False)
            del tmp_proto
            gc.collect()

        self.proto = onnx.load(model_path, load_external_data=False)
        """

        self.proto = onnx.load(model_path, load_external_data=True)
        self.data = dict()
        for tensor in self.proto.graph.initializer:
            name = tensor.name

            if tensor.HasField("raw_data"):
                npt = numpy_helper.to_array(tensor)
                orv = OrtValue.ortvalue_from_numpy(npt)
                self.data[name] = orv
                set_external_data(tensor, location="in-memory-location")
                tensor.name = name
                tensor.ClearField("raw_data")

        self.nodes = self._access_helper(self.proto.graph.node)
        self.initializers = self._access_helper(self.proto.graph.initializer)

        self.tensors = self._tensor_access(self)

    # TODO: integrate with model manager/cache
    def create_session(self):
        if self.session is None:
            #onnx.save(self.proto, "tmp.onnx")
            #onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False)
            # TODO: something to be able to get weight when they already moved outside of model proto
            #(trimmed_model, external_data) = buffer_external_data_tensors(self.proto)
            sess = SessionOptions()
            #self._external_data.update(**external_data)
            sess.add_external_initializers(list(self.data.keys()), list(self.data.values()))
            self.session = InferenceSession(self.proto.SerializeToString(), providers=[self.provider], sess_options=sess)
            #self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options)

    def release_session(self):
        self.session = None
        import gc
        gc.collect()


    def __call__(self, **kwargs):
        if self.session is None:
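Taken together, the new class keeps every initializer as an in-memory OrtValue in `self.data` and exposes it through the `tensors` accessor, so weights can be read and patched without touching `raw_data` in the proto, and `create_session` hands those OrtValues to onnxruntime via `add_external_initializers`. A hedged usage sketch of the class defined above (the model path and the idea of grabbing the first tensor key are purely illustrative):

    # Hypothetical usage; requires an actual ONNX file on disk.
    model = IAIOnnxRuntimeModel("/path/to/model.onnx", provider="CPUExecutionProvider")

    name = list(model.tensors.keys())[0]       # initializer names double as tensor keys
    w = model.tensors[name]                    # numpy array backed by the in-memory OrtValue
    model.tensors[name] = w                    # writing re-registers the OrtValue used by the session

    model.create_session()                     # initializers passed via SessionOptions.add_external_initializers
    # outputs = model(sample=..., timestep=..., encoder_hidden_states=...)   # __call__ wraps session.run
    model.release_session()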
@@ -455,63 +555,32 @@ class IAIOnnxRuntimeModel(OnnxRuntimeModel):
        inputs = {k: np.array(v) for k, v in kwargs.items()}
        return self.session.run(None, inputs)

    def create_session(self):
        if self.session is None:
            #onnx.save(self.proto, "tmp.onnx")
            #onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False)
            # TODO: something to be able to get weight when they already moved outside of model proto
            (trimmed_model, external_data) = buffer_external_data_tensors(self.proto)
            sess = SessionOptions()
            self._external_data.update(**external_data)
            sess.add_external_initializers(list(self._external_data.keys()), list(self._external_data.values()))
            self.session = InferenceSession(trimmed_model.SerializeToString(), providers=[self.provider], sess_options=sess)
            #self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options)

    def release_session(self):
        self.session = None
        import gc
        gc.collect()

    @staticmethod
    def load_model(path: Union[str, Path], provider=None, sess_options=None):
        """
        Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider`

        Arguments:
            path (`str` or `Path`):
                Directory from which to load
            provider(`str`, *optional*):
                Onnxruntime execution provider to use for loading the model, defaults to `CPUExecutionProvider`
        """
        if provider is None:
            #logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
            print("No onnxruntime provider specified, using CPUExecutionProvider")
            provider = "CPUExecutionProvider"

        # TODO: check that provider available?
        return (onnx.load(path), provider, sess_options)

    # compatability with diffusers load code
    @classmethod
    def _from_pretrained(
    def from_pretrained(
        cls,
        model_id: Union[str, Path],
        use_auth_token: Optional[Union[bool, str, None]] = None,
        revision: Optional[Union[str, None]] = None,
        force_download: bool = False,
        cache_dir: Optional[str] = None,
        subfolder: Union[str, Path] = None,
        file_name: Optional[str] = None,
        provider: Optional[str] = None,
        sess_options: Optional["SessionOptions"] = None,
        **kwargs,
    ):
        model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
        file_name = file_name or ONNX_WEIGHTS_NAME

        if os.path.isdir(model_id):
            model_path = model_id
            if subfolder is not None:
                model_path = os.path.join(model_path, subfolder)
            model_path = os.path.join(model_path, file_name)

        else:
            model_path = model_id

        # load model from local directory
        if not os.path.isdir(model_id):
            raise Exception(f"Model not found: {model_id}")
        model = IAIOnnxRuntimeModel.load_model(
            os.path.join(model_id, model_file_name), provider=provider, sess_options=sess_options
        )

        return cls(model=model, **kwargs)
        if not os.path.isfile(model_path):
            raise Exception(f"Model not found: {model_path}")

        # TODO: session options
        return cls(model_path, provider=provider)
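With the reworked `from_pretrained` above, loading resolves a local directory plus optional `subfolder` to `model.onnx` and constructs the wrapper from that path instead of a pre-parsed proto tuple. A hedged example call, with a hypothetical directory layout in the style of a diffusers ONNX pipeline export:

    # Hypothetical paths; only the keyword arguments come from the code above.
    text_encoder = IAIOnnxRuntimeModel.from_pretrained(
        "/models/sd-1.5-onnx",
        subfolder="text_encoder",          # resolved to /models/sd-1.5-onnx/text_encoder/model.onnx
        provider="CPUExecutionProvider",
    )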