# Copyright (c) 2024 The InvokeAI Development Team
import os
import sys
from pathlib import Path
from typing import Any, List, Optional, Tuple, Union

import numpy as np
import onnx
import torch
from onnx import numpy_helper
from onnxruntime import InferenceSession, SessionOptions, get_available_providers

from invokeai.backend.raw_model import RawModel

ONNX_WEIGHTS_NAME = "model.onnx"


# NOTE FROM LS: This was copied from Stalker's original implementation.
# I have not yet gone through and fixed all the type hints
class IAIOnnxRuntimeModel(RawModel):
    """Wrapper around an ONNX model proto and a lazily created onnxruntime session.

    The proto is loaded eagerly in ``__init__``; an ``InferenceSession`` is only
    built when ``create_session`` is called and is dropped again by
    ``release_session``. ``nodes`` and ``tensors`` expose the graph's nodes and
    initializers through dict-like helpers keyed by name.
    """

    class _tensor_access:
        """Dict-like view of ``model.proto.graph.initializer`` keyed by tensor name."""

        def __init__(self, model):  # type: ignore
            self.model = model
            # Map each initializer name to its position in the repeated field.
            self.indexes = {}
            for idx, obj in enumerate(self.model.proto.graph.initializer):
                self.indexes[obj.name] = idx

        def __getitem__(self, key: str):  # type: ignore
            """Return the initializer named ``key`` converted to a numpy array."""
            value = self.model.proto.graph.initializer[self.indexes[key]]
            return numpy_helper.to_array(value)

        def __setitem__(self, key: str, value: np.ndarray):  # type: ignore
            """Replace the initializer named ``key`` with a tensor built from ``value``.

            Raises ``KeyError`` if ``key`` is not an existing initializer name.
            """
            new_node = numpy_helper.from_array(value)
            # set_external_data(new_node, location="in-memory-location")
            new_node.name = key
            # new_node.ClearField("raw_data")
            # Delete-then-insert at the same position keeps every stored index valid.
            del self.model.proto.graph.initializer[self.indexes[key]]
            self.model.proto.graph.initializer.insert(self.indexes[key], new_node)
            # self.model.data[key] = OrtValue.ortvalue_from_numpy(value)

        # __delitem__

        def __contains__(self, key: str) -> bool:
            """Return True if an initializer named ``key`` exists."""
            # Membership is decided by name. Looking the key up with
            # ``self.indexes[key]`` would raise KeyError for unknown names
            # instead of answering False.
            return key in self.indexes

        def items(self) -> List[Tuple[str, Any]]:
            # fixme
            raise NotImplementedError("tensor.items")
            # return [(obj.name, obj) for obj in self.raw_proto]

        def keys(self) -> List[str]:
            """Return the names of all indexed initializers."""
            return list(self.indexes.keys())

        def values(self) -> List[Any]:
            # fixme
            raise NotImplementedError("tensor.values")
            # return [obj for obj in self.raw_proto]

        def size(self) -> int:
            """Return the summed ``sys.getsizeof`` of every initializer's ``raw_data``.

            ``sys.getsizeof`` is shallow and includes per-object overhead, so the
            result is an estimate rather than an exact payload byte count.
            """
            bytesSum = 0
            for node in self.model.proto.graph.initializer:
                bytesSum += sys.getsizeof(node.raw_data)
            return bytesSum

    class _access_helper:
        """Dict-like view of a repeated proto field whose elements have a ``name``."""

        def __init__(self, raw_proto):  # type: ignore
            self.indexes = {}
            self.raw_proto = raw_proto
            # Later elements with a duplicate name overwrite earlier ones.
            for idx, obj in enumerate(raw_proto):
                self.indexes[obj.name] = idx

        def __getitem__(self, key: str):  # type: ignore
            """Return the element whose name is ``key``."""
            return self.raw_proto[self.indexes[key]]

        def __setitem__(self, key: str, value):  # type: ignore
            """Replace the element named ``key`` with ``value`` at the same position."""
            index = self.indexes[key]
            del self.raw_proto[index]
            self.raw_proto.insert(index, value)

        # __delitem__

        def __contains__(self, key: str) -> bool:
            """Return True if an element named ``key`` was indexed."""
            return key in self.indexes

        def items(self) -> List[Tuple[str, Any]]:
            """Return ``(name, element)`` pairs in field order."""
            return [(obj.name, obj) for obj in self.raw_proto]

        def keys(self) -> List[str]:
            """Return the indexed element names."""
            return list(self.indexes.keys())

        def values(self) -> List[Any]:
            # fixme
            return list(self.raw_proto)

    def __init__(self, model_path: str, provider: Optional[str]):
        """Load the ONNX proto at ``model_path``; no inference session is created yet.

        Args:
            model_path: Path to the ``.onnx`` file passed to ``onnx.load``.
            provider: Execution provider name to use for sessions, or a falsy
                value to let ``create_session`` pick from the available ones.
        """
        self.path = model_path
        self.session = None
        # Dimensions the current session was built for. Initialised here so
        # create_session can compare against them without relying on
        # short-circuit evaluation to avoid an AttributeError.
        self.session_height = None
        self.session_width = None
        self.provider = provider
        """
        self.data_path = self.path + "_data"
        if not os.path.exists(self.data_path):
            print(f"Moving model tensors to separate file: {self.data_path}")
            tmp_proto = onnx.load(model_path, load_external_data=True)
            onnx.save_model(tmp_proto, self.path, save_as_external_data=True, all_tensors_to_one_file=True, location=os.path.basename(self.data_path), size_threshold=1024, convert_attribute=False)
            del tmp_proto
            gc.collect()

        self.proto = onnx.load(model_path, load_external_data=False)
        """

        self.proto = onnx.load(model_path, load_external_data=True)
        # self.data = dict()
        # for tensor in self.proto.graph.initializer:
        #     name = tensor.name

        #     if tensor.HasField("raw_data"):
        #         npt = numpy_helper.to_array(tensor)
        #         orv = OrtValue.ortvalue_from_numpy(npt)
        #         # self.data[name] = orv
        #         # set_external_data(tensor, location="in-memory-location")
        #         tensor.name = name
        #         # tensor.ClearField("raw_data")

        self.nodes = self._access_helper(self.proto.graph.node)  # type: ignore
        # self.initializers = self._access_helper(self.proto.graph.initializer)
        # print(self.proto.graph.input)
        # print(self.proto.graph.initializer)

        self.tensors = self._tensor_access(self)  # type: ignore

    # TODO: integrate with model manager/cache
    def create_session(self, height=None, width=None):
        """Create the inference session if none exists or the dimensions changed.

        When both ``height`` and ``width`` are truthy, the named free dimensions
        of the graph are overridden with fixed values before the session is
        built. Any exception raised by ``InferenceSession`` propagates unchanged.

        Args:
            height: Value for the ``unet_sample_height`` override, or None.
            width: Value for the ``unet_sample_width`` override, or None.
        """
        if self.session is None or self.session_width != width or self.session_height != height:
            # onnx.save(self.proto, "tmp.onnx")
            # onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False)
            # TODO: something to be able to get weight when they already moved outside of model proto
            # (trimmed_model, external_data) = buffer_external_data_tensors(self.proto)
            sess = SessionOptions()
            # self._external_data.update(**external_data)
            # sess.add_external_initializers(list(self.data.keys()), list(self.data.values()))
            # sess.enable_profiling = True

            # sess.intra_op_num_threads = 1
            # sess.inter_op_num_threads = 1
            # sess.execution_mode = ExecutionMode.ORT_SEQUENTIAL
            # sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
            # sess.enable_cpu_mem_arena = True
            # sess.enable_mem_pattern = True
            # sess.add_session_config_entry("session.intra_op.use_xnnpack_threadpool", "1") ########### It's the key code
            self.session_height = height
            self.session_width = width
            if height and width:
                sess.add_free_dimension_override_by_name("unet_sample_batch", 2)
                sess.add_free_dimension_override_by_name("unet_sample_channels", 4)
                sess.add_free_dimension_override_by_name("unet_hidden_batch", 2)
                sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77)
                sess.add_free_dimension_override_by_name("unet_sample_height", self.session_height)
                sess.add_free_dimension_override_by_name("unet_sample_width", self.session_width)
                sess.add_free_dimension_override_by_name("unet_time_batch", 1)
            providers = []
            if self.provider:
                providers.append(self.provider)
            else:
                providers = get_available_providers()
            # TensorRT is always dropped, even when it was the explicitly
            # requested provider (which leaves the list empty).
            if "TensorrtExecutionProvider" in providers:
                providers.remove("TensorrtExecutionProvider")
            self.session = InferenceSession(self.proto.SerializeToString(), providers=providers, sess_options=sess)
            # self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options)
            # self.io_binding = self.session.io_binding()

    def release_session(self):
        """Drop the current inference session and run a garbage collection pass."""
        self.session = None
        import gc

        gc.collect()
        return

    def __call__(self, **kwargs):
        """Run the session with ``kwargs`` as named inputs and return its outputs.

        Every keyword value is converted with ``np.array`` before being passed
        to ``session.run``.

        Raises:
            Exception: If ``create_session`` has not been called.
        """
        if self.session is None:
            raise Exception("You should call create_session before running model")

        inputs = {k: np.array(v) for k, v in kwargs.items()}
        # output_names = self.session.get_outputs()
        # for k in inputs:
        #     self.io_binding.bind_cpu_input(k, inputs[k])
        # for name in output_names:
        #     self.io_binding.bind_output(name.name)
        # self.session.run_with_iobinding(self.io_binding, None)
        # return self.io_binding.copy_outputs_to_cpu()
        return self.session.run(None, inputs)

    # compatibility with RawModel ABC
    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        non_blocking: bool = False,
    ) -> None:
        """Do nothing; all arguments are ignored."""
        pass

    # compatibility with diffusers load code
    @classmethod
    def from_pretrained(
        cls,
        model_id: Union[str, Path],
        subfolder: Optional[Union[str, Path]] = None,
        file_name: Optional[str] = None,
        provider: Optional[str] = None,
        sess_options: Optional["SessionOptions"] = None,
        **kwargs: Any,
    ) -> Any:  # fixme
        """Build a model from a directory or a direct path to an ONNX file.

        Args:
            model_id: Either a directory containing the model or the model file
                itself.
            subfolder: Optional sub-directory of ``model_id``; only used when
                ``model_id`` is a directory.
            file_name: File name inside the directory; defaults to
                ``ONNX_WEIGHTS_NAME``.
            provider: Execution provider name forwarded to the constructor.
            sess_options: Currently unused.
            **kwargs: Currently unused.

        Raises:
            Exception: If the resolved model path is not an existing file.
        """
        file_name = file_name or ONNX_WEIGHTS_NAME

        if os.path.isdir(model_id):
            model_path = model_id
            if subfolder is not None:
                model_path = os.path.join(model_path, subfolder)
            model_path = os.path.join(model_path, file_name)

        else:
            model_path = model_id

        # load model from local directory
        if not os.path.isfile(model_path):
            raise Exception(f"Model not found: {model_path}")

        # TODO: session options
        return cls(str(model_path), provider=provider)