diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py
index 5d621ab9b4..71a4fc5e4e 100644
--- a/invokeai/app/invocations/onnx.py
+++ b/invokeai/app/invocations/onnx.py
@@ -63,20 +63,26 @@ class ONNXPromptInvocation(BaseInvocation):
             text_encoder_info as text_encoder,\
             ExitStack() as stack:
 
-            loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
+            #loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
+            loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
 
             ti_list = []
             for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
                 name = trigger[1:-1]
                 try:
                     ti_list.append(
-                        stack.enter_context(
-                            context.services.model_manager.get_model(
-                                model_name=name,
-                                base_model=self.clip.text_encoder.base_model,
-                                model_type=ModelType.TextualInversion,
-                            )
-                        )
+                        #stack.enter_context(
+                        #    context.services.model_manager.get_model(
+                        #        model_name=name,
+                        #        base_model=self.clip.text_encoder.base_model,
+                        #        model_type=ModelType.TextualInversion,
+                        #    )
+                        #)
+                        context.services.model_manager.get_model(
+                            model_name=name,
+                            base_model=self.clip.text_encoder.base_model,
+                            model_type=ModelType.TextualInversion,
+                        ).context.model
                     )
                 except Exception:
                     #print(e)
@@ -218,7 +224,8 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
 
         with unet_info as unet,\
             ExitStack() as stack:
-            loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
+            #loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
+            loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]
 
             with ONNXModelPatcher.apply_lora_unet(unet, loras):
                 # TODO:
diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py
index c092ecb384..b6187469c5 100644
--- a/invokeai/backend/model_management/lora.py
+++ b/invokeai/backend/model_management/lora.py
@@ -12,6 +12,7 @@ from torch.utils.hooks import RemovableHandle
 from diffusers.models import UNet2DConditionModel
 from transformers import CLIPTextModel
 from onnx import numpy_helper
+from onnxruntime import OrtValue
 import numpy as np
 
 from compel.embeddings_provider import BaseTextualInversionManager
@@ -718,8 +719,7 @@ class ONNXModelPatcher:
         if not isinstance(model, IAIOnnxRuntimeModel):
             raise Exception("Only IAIOnnxRuntimeModel models supported")
 
-        base_model = model.proto
-        orig_nodes = dict()
+        orig_weights = dict()
 
         try:
             blended_loras = dict()
@@ -736,68 +736,49 @@ class ONNXModelPatcher:
                     else:
                         blended_loras[layer_key] = layer_weight
 
-            initializer_idx = dict()
-            for idx, init in enumerate(base_model.graph.initializer):
-                initializer_idx[init.name.replace(".", "_")] = idx
+            node_names = dict()
+            for node in model.nodes.values():
+                node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name
 
-            node_idx = dict()
-            for idx, node in enumerate(base_model.graph.node):
-                node_idx[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = idx
-
-            for layer_key, weights in blended_loras.items():
+            for layer_key, lora_weight in blended_loras.items():
                 conv_key = layer_key + "_Conv"
                 gemm_key = layer_key + "_Gemm"
                 matmul_key = layer_key + "_MatMul"
 
-                if conv_key in node_idx or gemm_key in node_idx:
-                    if conv_key in node_idx:
-                        conv_node = base_model.graph.node[node_idx[conv_key]]
+                if conv_key in node_names or gemm_key in node_names:
+                    if conv_key in node_names:
+                        conv_node = model.nodes[node_names[conv_key]]
                     else:
-                        conv_node = base_model.graph.node[node_idx[gemm_key]]
+                        conv_node = model.nodes[node_names[gemm_key]]
 
                     weight_name = [n for n in conv_node.input if ".weight" in n][0]
-                    weight_name = weight_name.replace(".", "_")
+                    orig_weight = model.tensors[weight_name]
 
-                    weight_idx = initializer_idx[weight_name]
-                    weight_node = base_model.graph.initializer[weight_idx]
-
-                    orig_weights = numpy_helper.to_array(weight_node)
-
-                    if orig_weights.shape[-2:] == (1, 1):
-                        if weights.shape[-2:] == (1, 1):
-                            new_weights = orig_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
+                    if orig_weight.shape[-2:] == (1, 1):
+                        if lora_weight.shape[-2:] == (1, 1):
+                            new_weight = orig_weight.squeeze((3, 2)) + lora_weight.squeeze((3, 2))
                         else:
-                            new_weights = orig_weights.squeeze((3, 2)) + weights
+                            new_weight = orig_weight.squeeze((3, 2)) + lora_weight
 
-                        new_weights = np.expand_dims(new_weights, (2, 3))
+                        new_weight = np.expand_dims(new_weight, (2, 3))
                     else:
-                        if orig_weights.shape != weights.shape:
-                            new_weights = orig_weights + weights.reshape(orig_weights.shape)
+                        if orig_weight.shape != lora_weight.shape:
+                            new_weight = orig_weight + lora_weight.reshape(orig_weight.shape)
                         else:
-                            new_weights = orig_weights + weights
+                            new_weight = orig_weight + lora_weight
 
-                    new_node = numpy_helper.from_array(new_weights.astype(orig_weights.dtype), weight_node.name)
-                    orig_nodes[weight_idx] = base_model.graph.initializer[weight_idx]
-                    del base_model.graph.initializer[weight_idx]
-                    base_model.graph.initializer.insert(weight_idx, new_node)
-
-                elif matmul_key in node_idx:
-                    weight_node = base_model.graph.node[node_idx[matmul_key]]
+                    orig_weights[weight_name] = orig_weight
+                    model.tensors[weight_name] = new_weight.astype(orig_weight.dtype)
+                elif matmul_key in node_names:
+                    weight_node = model.nodes[node_names[matmul_key]]
                     matmul_name = [n for n in weight_node.input if "MatMul" in n][0]
 
-                    matmul_idx = initializer_idx[matmul_name]
-                    matmul_node = base_model.graph.initializer[matmul_idx]
+                    orig_weight = model.tensors[matmul_name]
+                    new_weight = orig_weight + lora_weight.transpose()
 
-                    orig_weights = numpy_helper.to_array(matmul_node)
-
-                    new_weights = orig_weights + weights.transpose()
-
-                    # replace the original initializer
-                    new_node = numpy_helper.from_array(new_weights.astype(orig_weights.dtype), matmul_node.name)
-                    orig_nodes[matmul_idx] = base_model.graph.initializer[matmul_idx]
-                    del base_model.graph.initializer[matmul_idx]
-                    base_model.graph.initializer.insert(matmul_idx, new_node)
+                    orig_weights[matmul_name] = orig_weight
+                    model.tensors[matmul_name] = new_weight.astype(orig_weight.dtype)
 
                 else:
                     # warn? err?
@@ -807,9 +788,8 @@ class ONNXModelPatcher:
 
         finally:
             # restore original weights
-            for idx, orig_node in orig_nodes.items():
-                del base_model.graph.initializer[idx]
-                base_model.graph.initializer.insert(idx, orig_node)
+            for name, orig_weight in orig_weights.items():
+                model.tensors[name] = orig_weight
 
@@ -825,8 +805,7 @@ class ONNXModelPatcher:
         if not isinstance(text_encoder, IAIOnnxRuntimeModel):
             raise Exception("Only IAIOnnxRuntimeModel models supported")
 
-        init_tokens_count = None
-        new_tokens_added = None
+        orig_embeddings = None
 
         try:
             ti_tokenizer = copy.deepcopy(tokenizer)
@@ -845,17 +824,15 @@ class ONNXModelPatcher:
                 new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
 
             # modify text_encoder
-            for i in range(len(text_encoder.proto.graph.initializer)):
-                if text_encoder.proto.graph.initializer[i].name == "text_model.embeddings.token_embedding.weight":
-                    embeddings_node_idx = i
-                    break
-            else:
-                raise Exception("text_model.embeddings.token_embedding.weight node not found")
+            orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]
 
-            embeddings_node_orig = text_encoder.proto.graph.initializer[embeddings_node_idx]
-            base_weights = numpy_helper.to_array(embeddings_node_orig)
-
-            embedding_weights = np.concatenate((base_weights, np.zeros((new_tokens_added, base_weights.shape[1]))), axis=0)
+            embeddings = np.concatenate(
+                (
+                    np.copy(orig_embeddings),
+                    np.zeros((new_tokens_added, orig_embeddings.shape[1]))
+                ),
+                axis=0,
+            )
 
             for ti in ti_list:
                 ti_tokens = []
@@ -867,26 +844,22 @@ class ONNXModelPatcher:
                     if token_id == ti_tokenizer.unk_token_id:
                         raise RuntimeError(f"Unable to find token id for token '{trigger}'")
 
-                    if embedding_weights[token_id].shape != embedding.shape:
+                    if embeddings[token_id].shape != embedding.shape:
                         raise ValueError(
-                            f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {embedding_weights[token_id].shape[0]}."
+                            f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {embeddings[token_id].shape[0]}."
                         )
 
-                    embedding_weights[token_id] = embedding
+                    embeddings[token_id] = embedding
                     ti_tokens.append(token_id)
 
                 if len(ti_tokens) > 1:
                     ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]
 
-
-            new_embeddings_node = numpy_helper.from_array(embedding_weights.astype(base_weights.dtype), embeddings_node_orig.name)
-            del text_encoder.proto.graph.initializer[embeddings_node_idx]
-            text_encoder.proto.graph.initializer.insert(embeddings_node_idx, new_embeddings_node)
+            text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = embeddings.astype(orig_embeddings.dtype)
 
             yield ti_tokenizer, ti_manager
 
         finally:
             # restore
-            if embeddings_node_orig is not None:
-                del text_encoder.proto.graph.initializer[embeddings_node_idx]
-                text_encoder.proto.graph.initializer.insert(embeddings_node_idx, embeddings_node_orig)
+            if orig_embeddings is not None:
+                text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = orig_embeddings
diff --git a/invokeai/backend/model_management/models/base.py b/invokeai/backend/model_management/models/base.py
index dd8fe2ee19..a3c9b4bc87 100644
--- a/invokeai/backend/model_management/models/base.py
+++ b/invokeai/backend/model_management/models/base.py
@@ -426,27 +426,127 @@ class SilenceWarnings(object):
         diffusers_logging.set_verbosity(self.diffusers_verbosity)
         warnings.simplefilter('default')
 
-def buffer_external_data_tensors(model):
-    external_data = dict()
-    for tensor in model.graph.initializer:
-        name = tensor.name
-
-        if tensor.HasField("raw_data"):
-            npt = numpy_helper.to_array(tensor)
-            orv = OrtValue.ortvalue_from_numpy(npt)
-            external_data[name] = orv
-            set_external_data(tensor, location="tmp.bin")
-            tensor.name = name
-            tensor.ClearField("raw_data")
-
-    return (model, external_data)
-
 ONNX_WEIGHTS_NAME = "model.onnx"
 
-class IAIOnnxRuntimeModel(OnnxRuntimeModel):
-    def __init__(self, model: tuple, **kwargs):
-        self.proto, self.provider, self.sess_options = model
+class IAIOnnxRuntimeModel:
+    class _tensor_access:
+
+        def __init__(self, model):
+            self.model = model
+            self.indexes = dict()
+            for idx, obj in enumerate(self.model.proto.graph.initializer):
+                self.indexes[obj.name] = idx
+
+        def __getitem__(self, key: str):
+            return self.model.data[key].numpy()
+
+        def __setitem__(self, key: str, value: np.ndarray):
+            new_node = numpy_helper.from_array(value)
+            set_external_data(new_node, location="in-memory-location")
+            new_node.name = key
+            new_node.ClearField("raw_data")
+            del self.model.proto.graph.initializer[self.indexes[key]]
+            self.model.proto.graph.initializer.insert(self.indexes[key], new_node)
+            self.model.data[key] = OrtValue.ortvalue_from_numpy(value)
+
+        # __delitem__
+
+        def __contains__(self, key: str):
+            return key in self.model.data
+
+        def items(self):
+            raise NotImplementedError("tensor.items")
+            #return [(obj.name, obj) for obj in self.raw_proto]
+
+        def keys(self):
+            return self.model.data.keys()
+
+        def values(self):
+            raise NotImplementedError("tensor.values")
+            #return [obj for obj in self.raw_proto]
+
+
+
+    class _access_helper:
+        def __init__(self, raw_proto):
+            self.indexes = dict()
+            self.raw_proto = raw_proto
+            for idx, obj in enumerate(raw_proto):
+                self.indexes[obj.name] = idx
+
+        def __getitem__(self, key: str):
+            return self.raw_proto[self.indexes[key]]
+
+        def __setitem__(self, key: str, value):
+            index = self.indexes[key]
+            del self.raw_proto[index]
+            self.raw_proto.insert(index, value)
+
+        # __delitem__
+
+        def __contains__(self, key: str):
+            return key in self.indexes
+
+        def items(self):
+            return [(obj.name, obj) for obj in self.raw_proto]
+
+        def keys(self):
+            return self.indexes.keys()
+
+        def values(self):
+            return [obj for obj in self.raw_proto]
+
+    def __init__(self, model_path: str, provider: Optional[str]):
+        self.path = model_path
         self.session = None
-        self._external_data = dict()
+        self.provider = provider or "CPUExecutionProvider"
+        """
+        self.data_path = self.path + "_data"
+        if not os.path.exists(self.data_path):
+            print(f"Moving model tensors to separate file: {self.data_path}")
+            tmp_proto = onnx.load(model_path, load_external_data=True)
+            onnx.save_model(tmp_proto, self.path, save_as_external_data=True, all_tensors_to_one_file=True, location=os.path.basename(self.data_path), size_threshold=1024, convert_attribute=False)
+            del tmp_proto
+            gc.collect()
+
+        self.proto = onnx.load(model_path, load_external_data=False)
+        """
+
+        self.proto = onnx.load(model_path, load_external_data=True)
+        self.data = dict()
+        for tensor in self.proto.graph.initializer:
+            name = tensor.name
+
+            if tensor.HasField("raw_data"):
+                npt = numpy_helper.to_array(tensor)
+                orv = OrtValue.ortvalue_from_numpy(npt)
+                self.data[name] = orv
+                set_external_data(tensor, location="in-memory-location")
+                tensor.name = name
+                tensor.ClearField("raw_data")
+
+        self.nodes = self._access_helper(self.proto.graph.node)
+        self.initializers = self._access_helper(self.proto.graph.initializer)
+
+        self.tensors = self._tensor_access(self)
+
+    # TODO: integrate with model manager/cache
+    def create_session(self):
+        if self.session is None:
+            #onnx.save(self.proto, "tmp.onnx")
+            #onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False)
+            # TODO: something to be able to get weight when they already moved outside of model proto
+            #(trimmed_model, external_data) = buffer_external_data_tensors(self.proto)
+            sess = SessionOptions()
+            #self._external_data.update(**external_data)
+            sess.add_external_initializers(list(self.data.keys()), list(self.data.values()))
+            self.session = InferenceSession(self.proto.SerializeToString(), providers=[self.provider], sess_options=sess)
+            #self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options)
+
+    def release_session(self):
+        self.session = None
+        import gc
+        gc.collect()
+
     def __call__(self, **kwargs):
         if self.session is None:
@@ -455,63 +555,32 @@ class IAIOnnxRuntimeModel(OnnxRuntimeModel):
         inputs = {k: np.array(v) for k, v in kwargs.items()}
         return self.session.run(None, inputs)
 
-    def create_session(self):
-        if self.session is None:
-            #onnx.save(self.proto, "tmp.onnx")
-            #onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False)
-            # TODO: something to be able to get weight when they already moved outside of model proto
-            (trimmed_model, external_data) = buffer_external_data_tensors(self.proto)
-            sess = SessionOptions()
-            self._external_data.update(**external_data)
-            sess.add_external_initializers(list(self._external_data.keys()), list(self._external_data.values()))
-            self.session = InferenceSession(trimmed_model.SerializeToString(), providers=[self.provider], sess_options=sess)
-            #self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options)
-
-    def release_session(self):
-        self.session = None
-        import gc
-        gc.collect()
-
-    @staticmethod
-    def load_model(path: Union[str, Path], provider=None, sess_options=None):
-        """
-        Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider`
-
-        Arguments:
-            path (`str` or `Path`):
-                Directory from which to load
-            provider(`str`, *optional*):
-                Onnxruntime execution provider to use for loading the model, defaults to `CPUExecutionProvider`
-        """
-        if provider is None:
-            #logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
-            print("No onnxruntime provider specified, using CPUExecutionProvider")
-            provider = "CPUExecutionProvider"
-
-        # TODO: check that provider available?
-        return (onnx.load(path), provider, sess_options)
-
+    # compatability with diffusers load code
     @classmethod
-    def _from_pretrained(
+    def from_pretrained(
         cls,
         model_id: Union[str, Path],
-        use_auth_token: Optional[Union[bool, str, None]] = None,
-        revision: Optional[Union[str, None]] = None,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
+        subfolder: Union[str, Path] = None,
         file_name: Optional[str] = None,
         provider: Optional[str] = None,
         sess_options: Optional["SessionOptions"] = None,
         **kwargs,
     ):
-        model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
+        file_name = file_name or ONNX_WEIGHTS_NAME
+
+        if os.path.isdir(model_id):
+            model_path = model_id
+            if subfolder is not None:
+                model_path = os.path.join(model_path, subfolder)
+            model_path = os.path.join(model_path, file_name)
+
+        else:
+            model_path = model_id
+
         # load model from local directory
-        if not os.path.isdir(model_id):
-            raise Exception(f"Model not found: {model_id}")
-        model = IAIOnnxRuntimeModel.load_model(
-            os.path.join(model_id, model_file_name), provider=provider, sess_options=sess_options
-        )
-
-        return cls(model=model, **kwargs)
+        if not os.path.isfile(model_path):
+            raise Exception(f"Model not found: {model_path}")
+        # TODO: session options
+        return cls(model_path, provider=provider)
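
Usage sketch (illustrative only, not part of the patch): a minimal example of how the reworked IAIOnnxRuntimeModel introduced above is expected to be driven by callers such as ONNXModelPatcher: weights are read and patched through the tensors mapping, and the session is built from the in-memory OrtValue initializers. The model path below is hypothetical.

    from invokeai.backend.model_management.models.base import IAIOnnxRuntimeModel

    # Load the proto and stage every initializer as an OrtValue (path is hypothetical).
    model = IAIOnnxRuntimeModel("/path/to/text_encoder/model.onnx", provider=None)

    key = "text_model.embeddings.token_embedding.weight"
    if key in model.tensors:
        orig = model.tensors[key]                               # numpy array from the staged OrtValue
        model.tensors[key] = (orig * 1.0).astype(orig.dtype)    # __setitem__ swaps the initializer and OrtValue
        model.tensors[key] = orig                               # restore, as the patchers do in their finally blocks

    model.create_session()  # InferenceSession built with the external in-memory initializers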