mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Testing caching onnx sessions
This commit is contained in:
parent
59716938bf
commit
bfdc8c80f3
@ -353,7 +353,7 @@ class ModelCache(object):
|
|||||||
# 2 refs:
|
# 2 refs:
|
||||||
# 1 from cache_entry
|
# 1 from cache_entry
|
||||||
# 1 from getrefcount function
|
# 1 from getrefcount function
|
||||||
if not cache_entry.locked and refs <= 2:
|
if not cache_entry.locked and refs <= 3 if 'onnx' in model_key else 2:
|
||||||
self.logger.debug(f'Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)')
|
self.logger.debug(f'Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)')
|
||||||
current_size -= cache_entry.size
|
current_size -= cache_entry.size
|
||||||
del self._cache_stack[pos]
|
del self._cache_stack[pos]
|
||||||
|
@ -239,6 +239,7 @@ class DiffusersModel(ModelBase):
|
|||||||
self.child_sizes[child_name] = calc_model_size_by_fs(self.model_path, subfolder=child_name)
|
self.child_sizes[child_name] = calc_model_size_by_fs(self.model_path, subfolder=child_name)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_size(self, child_type: Optional[SubModelType] = None):
|
def get_size(self, child_type: Optional[SubModelType] = None):
|
||||||
if child_type is None:
|
if child_type is None:
|
||||||
return sum(self.child_sizes.values())
|
return sum(self.child_sizes.values())
|
||||||
@ -363,6 +364,8 @@ def calc_model_size_by_data(model) -> int:
|
|||||||
return _calc_pipeline_by_data(model)
|
return _calc_pipeline_by_data(model)
|
||||||
elif isinstance(model, torch.nn.Module):
|
elif isinstance(model, torch.nn.Module):
|
||||||
return _calc_model_by_data(model)
|
return _calc_model_by_data(model)
|
||||||
|
elif isinstance(model, IAIOnnxRuntimeModel):
|
||||||
|
return _calc_onnx_model_by_data(model)
|
||||||
else:
|
else:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
@ -383,6 +386,12 @@ def _calc_model_by_data(model) -> int:
|
|||||||
return mem
|
return mem
|
||||||
|
|
||||||
|
|
||||||
|
def _calc_onnx_model_by_data(model) -> int:
|
||||||
|
tensor_size = model.tensors.size()
|
||||||
|
mem = tensor_size # in bytes
|
||||||
|
return mem
|
||||||
|
|
||||||
|
|
||||||
def _fast_safetensors_reader(path: str):
|
def _fast_safetensors_reader(path: str):
|
||||||
checkpoint = dict()
|
checkpoint = dict()
|
||||||
device = torch.device("meta")
|
device = torch.device("meta")
|
||||||
@ -455,7 +464,8 @@ class IAIOnnxRuntimeModel:
|
|||||||
self.indexes[obj.name] = idx
|
self.indexes[obj.name] = idx
|
||||||
|
|
||||||
def __getitem__(self, key: str):
|
def __getitem__(self, key: str):
|
||||||
return self.model.data[key].numpy()
|
value = self.model.proto.graph.initializer[self.indexes[key]]
|
||||||
|
return numpy_helper.to_array(value)
|
||||||
|
|
||||||
def __setitem__(self, key: str, value: np.ndarray):
|
def __setitem__(self, key: str, value: np.ndarray):
|
||||||
new_node = numpy_helper.from_array(value)
|
new_node = numpy_helper.from_array(value)
|
||||||
@ -464,24 +474,29 @@ class IAIOnnxRuntimeModel:
|
|||||||
# new_node.ClearField("raw_data")
|
# new_node.ClearField("raw_data")
|
||||||
del self.model.proto.graph.initializer[self.indexes[key]]
|
del self.model.proto.graph.initializer[self.indexes[key]]
|
||||||
self.model.proto.graph.initializer.insert(self.indexes[key], new_node)
|
self.model.proto.graph.initializer.insert(self.indexes[key], new_node)
|
||||||
self.model.data[key] = OrtValue.ortvalue_from_numpy(value)
|
# self.model.data[key] = OrtValue.ortvalue_from_numpy(value)
|
||||||
|
|
||||||
# __delitem__
|
# __delitem__
|
||||||
|
|
||||||
def __contains__(self, key: str):
|
def __contains__(self, key: str):
|
||||||
return key in self.model.data
|
return self.indexes[key] in self.model.proto.graph.initializer
|
||||||
|
|
||||||
def items(self):
|
def items(self):
|
||||||
raise NotImplementedError("tensor.items")
|
raise NotImplementedError("tensor.items")
|
||||||
#return [(obj.name, obj) for obj in self.raw_proto]
|
#return [(obj.name, obj) for obj in self.raw_proto]
|
||||||
|
|
||||||
def keys(self):
|
def keys(self):
|
||||||
return self.model.data.keys()
|
return self.indexes.keys()
|
||||||
|
|
||||||
def values(self):
|
def values(self):
|
||||||
raise NotImplementedError("tensor.values")
|
raise NotImplementedError("tensor.values")
|
||||||
#return [obj for obj in self.raw_proto]
|
#return [obj for obj in self.raw_proto]
|
||||||
|
|
||||||
|
def size(self):
|
||||||
|
bytesSum = 0
|
||||||
|
for node in self.model.proto.graph.initializer:
|
||||||
|
bytesSum += sys.getsizeof(node.raw_data)
|
||||||
|
return bytesSum
|
||||||
|
|
||||||
|
|
||||||
class _access_helper:
|
class _access_helper:
|
||||||
@ -530,20 +545,20 @@ class IAIOnnxRuntimeModel:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
self.proto = onnx.load(model_path, load_external_data=True)
|
self.proto = onnx.load(model_path, load_external_data=True)
|
||||||
self.data = dict()
|
# self.data = dict()
|
||||||
for tensor in self.proto.graph.initializer:
|
# for tensor in self.proto.graph.initializer:
|
||||||
name = tensor.name
|
# name = tensor.name
|
||||||
|
|
||||||
if tensor.HasField("raw_data"):
|
# if tensor.HasField("raw_data"):
|
||||||
npt = numpy_helper.to_array(tensor)
|
# npt = numpy_helper.to_array(tensor)
|
||||||
orv = OrtValue.ortvalue_from_numpy(npt)
|
# orv = OrtValue.ortvalue_from_numpy(npt)
|
||||||
self.data[name] = orv
|
# # self.data[name] = orv
|
||||||
# set_external_data(tensor, location="in-memory-location")
|
# # set_external_data(tensor, location="in-memory-location")
|
||||||
tensor.name = name
|
# tensor.name = name
|
||||||
# tensor.ClearField("raw_data")
|
# # tensor.ClearField("raw_data")
|
||||||
|
|
||||||
self.nodes = self._access_helper(self.proto.graph.node)
|
self.nodes = self._access_helper(self.proto.graph.node)
|
||||||
self.initializers = self._access_helper(self.proto.graph.initializer)
|
# self.initializers = self._access_helper(self.proto.graph.initializer)
|
||||||
# print(self.proto.graph.input)
|
# print(self.proto.graph.input)
|
||||||
# print(self.proto.graph.initializer)
|
# print(self.proto.graph.initializer)
|
||||||
|
|
||||||
@ -551,7 +566,7 @@ class IAIOnnxRuntimeModel:
|
|||||||
|
|
||||||
# TODO: integrate with model manager/cache
|
# TODO: integrate with model manager/cache
|
||||||
def create_session(self, height=None, width=None):
|
def create_session(self, height=None, width=None):
|
||||||
if self.session is None:
|
if self.session is None or self.session_width != width or self.session_height != height:
|
||||||
#onnx.save(self.proto, "tmp.onnx")
|
#onnx.save(self.proto, "tmp.onnx")
|
||||||
#onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False)
|
#onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False)
|
||||||
# TODO: something to be able to get weight when they already moved outside of model proto
|
# TODO: something to be able to get weight when they already moved outside of model proto
|
||||||
@ -568,13 +583,15 @@ class IAIOnnxRuntimeModel:
|
|||||||
# sess.enable_cpu_mem_arena = True
|
# sess.enable_cpu_mem_arena = True
|
||||||
# sess.enable_mem_pattern = True
|
# sess.enable_mem_pattern = True
|
||||||
# sess.add_session_config_entry("session.intra_op.use_xnnpack_threadpool", "1") ########### It's the key code
|
# sess.add_session_config_entry("session.intra_op.use_xnnpack_threadpool", "1") ########### It's the key code
|
||||||
|
self.session_height = height
|
||||||
|
self.session_width = width
|
||||||
if height and width:
|
if height and width:
|
||||||
sess.add_free_dimension_override_by_name("unet_sample_batch", 2)
|
sess.add_free_dimension_override_by_name("unet_sample_batch", 2)
|
||||||
sess.add_free_dimension_override_by_name("unet_sample_channels", 4)
|
sess.add_free_dimension_override_by_name("unet_sample_channels", 4)
|
||||||
sess.add_free_dimension_override_by_name("unet_hidden_batch", 2)
|
sess.add_free_dimension_override_by_name("unet_hidden_batch", 2)
|
||||||
sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77)
|
sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77)
|
||||||
sess.add_free_dimension_override_by_name("unet_sample_height", height)
|
sess.add_free_dimension_override_by_name("unet_sample_height", self.session_height)
|
||||||
sess.add_free_dimension_override_by_name("unet_sample_width", width)
|
sess.add_free_dimension_override_by_name("unet_sample_width", self.session_width)
|
||||||
sess.add_free_dimension_override_by_name("unet_time_batch", 1)
|
sess.add_free_dimension_override_by_name("unet_time_batch", 1)
|
||||||
providers = []
|
providers = []
|
||||||
if self.provider:
|
if self.provider:
|
||||||
@ -591,9 +608,10 @@ class IAIOnnxRuntimeModel:
|
|||||||
# self.io_binding = self.session.io_binding()
|
# self.io_binding = self.session.io_binding()
|
||||||
|
|
||||||
def release_session(self):
|
def release_session(self):
|
||||||
self.session = None
|
# self.session = None
|
||||||
import gc
|
# import gc
|
||||||
gc.collect()
|
# gc.collect()
|
||||||
|
return
|
||||||
|
|
||||||
def __call__(self, **kwargs):
|
def __call__(self, **kwargs):
|
||||||
if self.session is None:
|
if self.session is None:
|
||||||
|
Loading…
Reference in New Issue
Block a user