Mirror of https://github.com/invoke-ai/InvokeAI
Synced 2024-08-30 20:32:17 +00:00
Commit 38343917f8
In #6490 we enabled non-blocking torch device transfers throughout the model manager's memory management code. When this torch feature is used, torch attempts to wait until the tensor transfer has completed before allowing any access to the tensor, so in theory it should be safe. In practice it provides a small performance improvement but causes race conditions in some situations: specific platforms/systems are affected, and complicated data dependencies can make it unsafe.

- Intermittent black images on MPS devices - reported on Discord and in #6545, fixed with special handling in #6549.
- Intermittent OOMs and black images on a P4000 GPU on Windows - reported in #6613, fixed in this commit.

On my system, I haven't experienced any issues with generation, but targeted testing of non-blocking ops did expose a race condition when moving tensors from CUDA to CPU. One workaround is to use torch streams with manual sync points, but our application logic is complicated enough that this would be a lot of work and feels ripe for edge cases and missed spots. It is much safer to fully revert non-blocking transfers - which is what this change does.
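For context, a minimal sketch of the CUDA-to-CPU race described above (illustrative only, not InvokeAI code; assumes a CUDA device is available):

    import torch

    src = torch.randn(1024, 1024, device="cuda")

    # Pinned host destination: the device-to-host copy below is queued
    # asynchronously on the current CUDA stream and returns immediately.
    dst = torch.empty(src.shape, dtype=src.dtype, pin_memory=True)
    dst.copy_(src, non_blocking=True)

    # Racy: dst may still hold uninitialized memory at this point, because
    # nothing has waited for the async copy to complete.
    # bad = dst.sum()

    # The "manual sync point" mentioned above makes the read safe:
    torch.cuda.synchronize()
    good = dst.sum()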
224 lines
9.1 KiB
Python
# Copyright (c) 2024 The InvokeAI Development Team

import os
import sys
from pathlib import Path
from typing import Any, List, Optional, Tuple, Union

import numpy as np
import onnx
import torch
from onnx import numpy_helper
from onnxruntime import InferenceSession, SessionOptions, get_available_providers

from invokeai.backend.raw_model import RawModel

ONNX_WEIGHTS_NAME = "model.onnx"

# NOTE FROM LS: This was copied from Stalker's original implementation.
# I have not yet gone through and fixed all the type hints
class IAIOnnxRuntimeModel(RawModel):
    class _tensor_access:
        def __init__(self, model):  # type: ignore
            self.model = model
            self.indexes = {}
            for idx, obj in enumerate(self.model.proto.graph.initializer):
                self.indexes[obj.name] = idx

        def __getitem__(self, key: str):  # type: ignore
            value = self.model.proto.graph.initializer[self.indexes[key]]
            return numpy_helper.to_array(value)

        def __setitem__(self, key: str, value: np.ndarray):  # type: ignore
            new_node = numpy_helper.from_array(value)
            # set_external_data(new_node, location="in-memory-location")
            new_node.name = key
            # new_node.ClearField("raw_data")
            del self.model.proto.graph.initializer[self.indexes[key]]
            self.model.proto.graph.initializer.insert(self.indexes[key], new_node)
            # self.model.data[key] = OrtValue.ortvalue_from_numpy(value)

        # __delitem__
        def __contains__(self, key: str) -> bool:
            # Fixed: the original `self.indexes[key] in ...graph.initializer`
            # compared an int index against TensorProto entries (and raised
            # KeyError for unknown keys) instead of testing key membership.
            return key in self.indexes
        def items(self) -> List[Tuple[str, Any]]:  # fixme
            raise NotImplementedError("tensor.items")
            # return [(obj.name, obj) for obj in self.raw_proto]

        def keys(self) -> List[str]:
            return list(self.indexes.keys())

        def values(self) -> List[Any]:  # fixme
            raise NotImplementedError("tensor.values")
            # return [obj for obj in self.raw_proto]

        def size(self) -> int:
            bytesSum = 0
            for node in self.model.proto.graph.initializer:
                bytesSum += sys.getsizeof(node.raw_data)
            return bytesSum
    class _access_helper:
        def __init__(self, raw_proto):  # type: ignore
            self.indexes = {}
            self.raw_proto = raw_proto
            for idx, obj in enumerate(raw_proto):
                self.indexes[obj.name] = idx

        def __getitem__(self, key: str):  # type: ignore
            return self.raw_proto[self.indexes[key]]

        def __setitem__(self, key: str, value):  # type: ignore
            index = self.indexes[key]
            del self.raw_proto[index]
            self.raw_proto.insert(index, value)

        # __delitem__

        def __contains__(self, key: str) -> bool:
            return key in self.indexes

        def items(self) -> List[Tuple[str, Any]]:
            return [(obj.name, obj) for obj in self.raw_proto]

        def keys(self) -> List[str]:
            return list(self.indexes.keys())

        def values(self) -> List[Any]:  # fixme
            return list(self.raw_proto)
    def __init__(self, model_path: str, provider: Optional[str]):
        self.path = model_path
        self.session = None
        self.provider = provider
        """
        self.data_path = self.path + "_data"
        if not os.path.exists(self.data_path):
            print(f"Moving model tensors to separate file: {self.data_path}")
            tmp_proto = onnx.load(model_path, load_external_data=True)
            onnx.save_model(tmp_proto, self.path, save_as_external_data=True, all_tensors_to_one_file=True, location=os.path.basename(self.data_path), size_threshold=1024, convert_attribute=False)
            del tmp_proto
            gc.collect()

        self.proto = onnx.load(model_path, load_external_data=False)
        """

        self.proto = onnx.load(model_path, load_external_data=True)
        # self.data = dict()
        # for tensor in self.proto.graph.initializer:
        #     name = tensor.name

        #     if tensor.HasField("raw_data"):
        #         npt = numpy_helper.to_array(tensor)
        #         orv = OrtValue.ortvalue_from_numpy(npt)
        #         # self.data[name] = orv
        #         # set_external_data(tensor, location="in-memory-location")
        #         tensor.name = name
        #         # tensor.ClearField("raw_data")

        self.nodes = self._access_helper(self.proto.graph.node)  # type: ignore
        # self.initializers = self._access_helper(self.proto.graph.initializer)
        # print(self.proto.graph.input)
        # print(self.proto.graph.initializer)

        self.tensors = self._tensor_access(self)  # type: ignore
    # TODO: integrate with model manager/cache
    def create_session(self, height=None, width=None):
        if self.session is None or self.session_width != width or self.session_height != height:
            # onnx.save(self.proto, "tmp.onnx")
            # onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False)
            # TODO: something to be able to get weights when they have already been moved outside of the model proto
            # (trimmed_model, external_data) = buffer_external_data_tensors(self.proto)
            sess = SessionOptions()
            # self._external_data.update(**external_data)
            # sess.add_external_initializers(list(self.data.keys()), list(self.data.values()))
            # sess.enable_profiling = True

            # sess.intra_op_num_threads = 1
            # sess.inter_op_num_threads = 1
            # sess.execution_mode = ExecutionMode.ORT_SEQUENTIAL
            # sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
            # sess.enable_cpu_mem_arena = True
            # sess.enable_mem_pattern = True
            # sess.add_session_config_entry("session.intra_op.use_xnnpack_threadpool", "1")  # this is the key setting
            self.session_height = height
            self.session_width = width
            if height and width:
                # Pin the graph's free (symbolic) dimensions to concrete
                # values so the session is specialized for this shape.
                sess.add_free_dimension_override_by_name("unet_sample_batch", 2)
                sess.add_free_dimension_override_by_name("unet_sample_channels", 4)
                sess.add_free_dimension_override_by_name("unet_hidden_batch", 2)
                sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77)
                sess.add_free_dimension_override_by_name("unet_sample_height", self.session_height)
                sess.add_free_dimension_override_by_name("unet_sample_width", self.session_width)
                sess.add_free_dimension_override_by_name("unet_time_batch", 1)
            providers = []
            if self.provider:
                providers.append(self.provider)
            else:
                providers = get_available_providers()
            if "TensorrtExecutionProvider" in providers:
                providers.remove("TensorrtExecutionProvider")
            self.session = InferenceSession(self.proto.SerializeToString(), providers=providers, sess_options=sess)
            # self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options)
            # self.io_binding = self.session.io_binding()
    def release_session(self):
        self.session = None
        import gc

        gc.collect()
        return
    def __call__(self, **kwargs):
        if self.session is None:
            raise Exception("You should call create_session before running the model")

        inputs = {k: np.array(v) for k, v in kwargs.items()}
        # output_names = self.session.get_outputs()
        # for k in inputs:
        #     self.io_binding.bind_cpu_input(k, inputs[k])
        # for name in output_names:
        #     self.io_binding.bind_output(name.name)
        # self.session.run_with_iobinding(self.io_binding, None)
        # return self.io_binding.copy_outputs_to_cpu()
        return self.session.run(None, inputs)
    # compatibility with RawModel ABC
    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        pass
    # compatibility with diffusers load code
    @classmethod
    def from_pretrained(
        cls,
        model_id: Union[str, Path],
        subfolder: Optional[Union[str, Path]] = None,
        file_name: Optional[str] = None,
        provider: Optional[str] = None,
        sess_options: Optional["SessionOptions"] = None,
        **kwargs: Any,
    ) -> Any:  # fixme
        file_name = file_name or ONNX_WEIGHTS_NAME

        if os.path.isdir(model_id):
            model_path = model_id
            if subfolder is not None:
                model_path = os.path.join(model_path, subfolder)
            model_path = os.path.join(model_path, file_name)
        else:
            model_path = model_id

        # load model from local directory
        if not os.path.isfile(model_path):
            raise Exception(f"Model not found: {model_path}")

        # TODO: session options
        return cls(str(model_path), provider=provider)
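
# Hypothetical usage sketch (not part of the original file): the path,
# provider string, and input names below are assumptions. The input names
# match a typical Stable Diffusion UNet ONNX export, and height/width are
# the latent-space dimensions pinned by create_session() above.
#
#     model = IAIOnnxRuntimeModel.from_pretrained(
#         "/path/to/model_dir",  # assumed directory containing model.onnx
#         provider="CPUExecutionProvider",
#     )
#     model.create_session(height=64, width=64)
#     outputs = model(
#         sample=np.zeros((2, 4, 64, 64), dtype=np.float32),
#         timestep=np.array([1], dtype=np.int64),
#         encoder_hidden_states=np.zeros((2, 77, 768), dtype=np.float32),
#     )
#     model.release_session()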