psychedelicious 38343917f8 fix(backend): revert non-blocking device transfer
In a previous change, we enabled non-blocking torch device transfers throughout the model manager's memory management code. When using this torch feature, torch attempts to wait until the tensor transfer has completed before allowing any access to the tensor. Theoretically, that should make this a safe feature to use.

This provides a small performance improvement but causes race conditions in some situations. Specific platforms/systems are affected, and complicated data dependencies can make this unsafe.

- Intermittent black images on MPS devices - reported on Discord and on the issue tracker, fixed with special handling in an earlier change.
- Intermittent OOMs and black images on a P4000 GPU on Windows - reported on the issue tracker, fixed in this commit.

On my system, I haven't experienced any issues with generation, but targeted testing of non-blocking ops did expose a race condition when moving tensors from CUDA to CPU.

One workaround is to use torch streams with manual sync points. Our application logic is complicated enough that this would be a lot of work and feels ripe for edge cases and missed spots.

Much safer is to fully revert non-blocking transfers - which is what this change does.
2024-07-16 08:59:42 +10:00

224 lines
9.1 KiB

# Copyright (c) 2024 The InvokeAI Development Team
import os
import sys
from pathlib import Path
from typing import Any, List, Optional, Tuple, Union
import numpy as np
import onnx
import torch
from onnx import numpy_helper
from onnxruntime import InferenceSession, SessionOptions, get_available_providers
from invokeai.backend.raw_model import RawModel
# Default model file name used by `from_pretrained` when no `file_name` is given.
ONNX_WEIGHTS_NAME = "model.onnx"
# NOTE FROM LS: This was copied from Stalker's original implementation.
# I have not yet gone through and fixed all the type hints
class IAIOnnxRuntimeModel(RawModel):
    """An ONNX model proto plus a lazily created onnxruntime inference session.

    The proto is loaded eagerly in ``__init__``; the session is only built by
    ``create_session`` and can be dropped again with ``release_session``.
    ``nodes`` and ``tensors`` expose the graph's nodes and initializers through
    name-keyed, dict-like helpers so they can be read and replaced before a
    session is created.
    """

    class _tensor_access:
        """Dict-like view of ``model.proto.graph.initializer`` keyed by tensor name.

        Values are exchanged as numpy arrays.
        """

        def __init__(self, model):  # type: ignore
            self.model = model
            # Initializer name -> position in the initializer list.
            self.indexes = {}
            for idx, obj in enumerate(self.model.proto.graph.initializer):
                self.indexes[obj.name] = idx

        def __getitem__(self, key: str):  # type: ignore
            """Return the initializer named ``key`` as a numpy array."""
            value = self.model.proto.graph.initializer[self.indexes[key]]
            return numpy_helper.to_array(value)

        def __setitem__(self, key: str, value: np.ndarray):  # type: ignore
            """Replace the existing initializer named ``key`` with ``value``.

            Raises ``KeyError`` if no initializer with that name exists.
            """
            index = self.indexes[key]
            new_node = numpy_helper.from_array(value)
            # The replacement must keep the original name so graph inputs still resolve.
            new_node.name = key
            # Delete + insert at the same index keeps every other position stable.
            del self.model.proto.graph.initializer[index]
            self.model.proto.graph.initializer.insert(index, new_node)

        # __delitem__

        def __contains__(self, key: str) -> bool:
            """Return True if an initializer named ``key`` exists."""
            return key in self.indexes

        def items(self) -> List[Tuple[str, Any]]:  # fixme
            raise NotImplementedError("tensor.items")

        def keys(self) -> List[str]:
            """Return the names of all initializers."""
            return list(self.indexes.keys())

        def values(self) -> List[Any]:  # fixme
            raise NotImplementedError("tensor.values")

        def size(self) -> int:
            """Return the summed ``sys.getsizeof`` of every initializer's ``raw_data``."""
            return sum(sys.getsizeof(node.raw_data) for node in self.model.proto.graph.initializer)

    class _access_helper:
        """Dict-like view of a repeated protobuf field whose items have a ``name``."""

        def __init__(self, raw_proto):  # type: ignore
            # Item name -> position in ``raw_proto``.
            self.indexes = {}
            self.raw_proto = raw_proto
            for idx, obj in enumerate(raw_proto):
                self.indexes[obj.name] = idx

        def __getitem__(self, key: str):  # type: ignore
            """Return the item named ``key``."""
            return self.raw_proto[self.indexes[key]]

        def __setitem__(self, key: str, value):  # type: ignore
            """Replace the existing item named ``key`` with ``value`` at the same position."""
            index = self.indexes[key]
            del self.raw_proto[index]
            self.raw_proto.insert(index, value)

        # __delitem__

        def __contains__(self, key: str) -> bool:
            """Return True if an item named ``key`` exists."""
            return key in self.indexes

        def items(self) -> List[Tuple[str, Any]]:
            """Return ``(name, item)`` pairs in field order."""
            return [(obj.name, obj) for obj in self.raw_proto]

        def keys(self) -> List[str]:
            """Return the names of all items."""
            return list(self.indexes.keys())

        def values(self) -> List[Any]:  # fixme
            """Return all items in field order."""
            return list(self.raw_proto)

    def __init__(self, model_path: str, provider: Optional[str]):
        """Load the ONNX proto at ``model_path``; no inference session is created yet.

        Args:
            model_path: Path to the ``.onnx`` file.
            provider: onnxruntime execution provider to use, or ``None``/empty
                to let ``create_session`` pick from the available providers.
        """
        self.path = model_path
        self.session = None
        # Dimensions the current session was built for; set by `create_session`.
        self.session_height = None
        self.session_width = None
        self.provider = provider

        # NOTE: an earlier variant moved the tensors into a separate
        # "<model_path>_data" file and loaded the proto without external data.
        # That path is disabled; the proto is loaded whole, tensors included.
        self.proto = onnx.load(model_path, load_external_data=True)

        self.nodes = self._access_helper(self.proto.graph.node)  # type: ignore
        self.tensors = self._tensor_access(self)  # type: ignore

    # TODO: integrate with model manager/cache
    def create_session(self, height=None, width=None):
        """Create the inference session, or rebuild it if the requested size changed.

        Does nothing when a session already exists for the same ``height`` and
        ``width``. When both are truthy, the named free dimensions below are
        pinned to fixed sizes before the session is built.
        """
        if self.session is not None and self.session_width == width and self.session_height == height:
            return

        # TODO: something to be able to get weight when they already moved outside of model proto
        sess = SessionOptions()

        self.session_height = height
        self.session_width = width
        if height and width:
            sess.add_free_dimension_override_by_name("unet_sample_batch", 2)
            sess.add_free_dimension_override_by_name("unet_sample_channels", 4)
            sess.add_free_dimension_override_by_name("unet_hidden_batch", 2)
            sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77)
            sess.add_free_dimension_override_by_name("unet_sample_height", self.session_height)
            sess.add_free_dimension_override_by_name("unet_sample_width", self.session_width)
            sess.add_free_dimension_override_by_name("unet_time_batch", 1)

        # An explicit provider wins; otherwise offer everything onnxruntime reports.
        if self.provider:
            providers = [self.provider]
        else:
            providers = get_available_providers()
        # The TensorRT provider is never passed to the session.
        providers = [p for p in providers if p != "TensorrtExecutionProvider"]

        self.session = InferenceSession(self.proto.SerializeToString(), providers=providers, sess_options=sess)

    def release_session(self):
        """Drop the inference session and run a garbage collection pass."""
        self.session = None
        import gc

        gc.collect()

    def __call__(self, **kwargs):
        """Run the session with ``kwargs`` as named inputs and return its outputs.

        Each value is converted with ``np.array`` before being passed on.

        Raises:
            Exception: if ``create_session`` has not been called.
        """
        if self.session is None:
            raise Exception("You should call create_session before running model")

        inputs = {k: np.array(v) for k, v in kwargs.items()}
        # `None` as the output-name list requests every model output.
        return self.session.run(None, inputs)

    # compatibility with RawModel ABC
    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        """No-op; present only to satisfy the ``RawModel`` interface."""
        pass

    # compatibility with diffusers load code
    @classmethod
    def from_pretrained(
        cls,
        model_id: Union[str, Path],
        subfolder: Optional[Union[str, Path]] = None,
        file_name: Optional[str] = None,
        provider: Optional[str] = None,
        sess_options: Optional["SessionOptions"] = None,
        **kwargs: Any,
    ) -> Any:  # fixme
        """Build a model from a directory or a direct path to an ONNX file.

        Args:
            model_id: A directory containing the model, or the model file itself.
            subfolder: Optional sub-directory of ``model_id``; only used when
                ``model_id`` is a directory.
            file_name: File name inside the directory; defaults to
                ``ONNX_WEIGHTS_NAME``.
            provider: Forwarded to the constructor.
            sess_options: Accepted for call compatibility; currently ignored.
            **kwargs: Accepted for call compatibility; currently ignored.

        Raises:
            Exception: if the resolved path is not an existing file.
        """
        file_name = file_name or ONNX_WEIGHTS_NAME

        if os.path.isdir(model_id):
            model_path = model_id
            if subfolder is not None:
                model_path = os.path.join(model_path, subfolder)
            model_path = os.path.join(model_path, file_name)
        else:
            model_path = model_id

        # load model from local directory
        if not os.path.isfile(model_path):
            raise Exception(f"Model not found: {model_path}")

        # TODO: session options
        return cls(str(model_path), provider=provider)