Testing caching onnx sessions

This commit is contained in:
Brandon Rising 2023-07-27 14:13:29 -04:00
parent 59716938bf
commit bfdc8c80f3
2 changed files with 40 additions and 22 deletions

View File

@ -353,7 +353,7 @@ class ModelCache(object):
# 2 refs: # 2 refs:
# 1 from cache_entry # 1 from cache_entry
# 1 from getrefcount function # 1 from getrefcount function
if not cache_entry.locked and refs <= 2: if not cache_entry.locked and refs <= 3 if 'onnx' in model_key else 2:
self.logger.debug(f'Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)') self.logger.debug(f'Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)')
current_size -= cache_entry.size current_size -= cache_entry.size
del self._cache_stack[pos] del self._cache_stack[pos]

View File

@ -239,6 +239,7 @@ class DiffusersModel(ModelBase):
self.child_sizes[child_name] = calc_model_size_by_fs(self.model_path, subfolder=child_name) self.child_sizes[child_name] = calc_model_size_by_fs(self.model_path, subfolder=child_name)
def get_size(self, child_type: Optional[SubModelType] = None): def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is None: if child_type is None:
return sum(self.child_sizes.values()) return sum(self.child_sizes.values())
@ -363,6 +364,8 @@ def calc_model_size_by_data(model) -> int:
return _calc_pipeline_by_data(model) return _calc_pipeline_by_data(model)
elif isinstance(model, torch.nn.Module): elif isinstance(model, torch.nn.Module):
return _calc_model_by_data(model) return _calc_model_by_data(model)
elif isinstance(model, IAIOnnxRuntimeModel):
return _calc_onnx_model_by_data(model)
else: else:
return 0 return 0
@ -383,6 +386,12 @@ def _calc_model_by_data(model) -> int:
return mem return mem
def _calc_onnx_model_by_data(model) -> int:
tensor_size = model.tensors.size()
mem = tensor_size # in bytes
return mem
def _fast_safetensors_reader(path: str): def _fast_safetensors_reader(path: str):
checkpoint = dict() checkpoint = dict()
device = torch.device("meta") device = torch.device("meta")
@ -455,7 +464,8 @@ class IAIOnnxRuntimeModel:
self.indexes[obj.name] = idx self.indexes[obj.name] = idx
def __getitem__(self, key: str): def __getitem__(self, key: str):
return self.model.data[key].numpy() value = self.model.proto.graph.initializer[self.indexes[key]]
return numpy_helper.to_array(value)
def __setitem__(self, key: str, value: np.ndarray): def __setitem__(self, key: str, value: np.ndarray):
new_node = numpy_helper.from_array(value) new_node = numpy_helper.from_array(value)
@ -464,24 +474,29 @@ class IAIOnnxRuntimeModel:
# new_node.ClearField("raw_data") # new_node.ClearField("raw_data")
del self.model.proto.graph.initializer[self.indexes[key]] del self.model.proto.graph.initializer[self.indexes[key]]
self.model.proto.graph.initializer.insert(self.indexes[key], new_node) self.model.proto.graph.initializer.insert(self.indexes[key], new_node)
self.model.data[key] = OrtValue.ortvalue_from_numpy(value) # self.model.data[key] = OrtValue.ortvalue_from_numpy(value)
# __delitem__ # __delitem__
def __contains__(self, key: str): def __contains__(self, key: str):
return key in self.model.data return self.indexes[key] in self.model.proto.graph.initializer
def items(self): def items(self):
raise NotImplementedError("tensor.items") raise NotImplementedError("tensor.items")
#return [(obj.name, obj) for obj in self.raw_proto] #return [(obj.name, obj) for obj in self.raw_proto]
def keys(self): def keys(self):
return self.model.data.keys() return self.indexes.keys()
def values(self): def values(self):
raise NotImplementedError("tensor.values") raise NotImplementedError("tensor.values")
#return [obj for obj in self.raw_proto] #return [obj for obj in self.raw_proto]
def size(self):
bytesSum = 0
for node in self.model.proto.graph.initializer:
bytesSum += sys.getsizeof(node.raw_data)
return bytesSum
class _access_helper: class _access_helper:
@ -530,20 +545,20 @@ class IAIOnnxRuntimeModel:
""" """
self.proto = onnx.load(model_path, load_external_data=True) self.proto = onnx.load(model_path, load_external_data=True)
self.data = dict() # self.data = dict()
for tensor in self.proto.graph.initializer: # for tensor in self.proto.graph.initializer:
name = tensor.name # name = tensor.name
if tensor.HasField("raw_data"): # if tensor.HasField("raw_data"):
npt = numpy_helper.to_array(tensor) # npt = numpy_helper.to_array(tensor)
orv = OrtValue.ortvalue_from_numpy(npt) # orv = OrtValue.ortvalue_from_numpy(npt)
self.data[name] = orv # # self.data[name] = orv
# set_external_data(tensor, location="in-memory-location") # # set_external_data(tensor, location="in-memory-location")
tensor.name = name # tensor.name = name
# tensor.ClearField("raw_data") # # tensor.ClearField("raw_data")
self.nodes = self._access_helper(self.proto.graph.node) self.nodes = self._access_helper(self.proto.graph.node)
self.initializers = self._access_helper(self.proto.graph.initializer) # self.initializers = self._access_helper(self.proto.graph.initializer)
# print(self.proto.graph.input) # print(self.proto.graph.input)
# print(self.proto.graph.initializer) # print(self.proto.graph.initializer)
@ -551,7 +566,7 @@ class IAIOnnxRuntimeModel:
# TODO: integrate with model manager/cache # TODO: integrate with model manager/cache
def create_session(self, height=None, width=None): def create_session(self, height=None, width=None):
if self.session is None: if self.session is None or self.session_width != width or self.session_height != height:
#onnx.save(self.proto, "tmp.onnx") #onnx.save(self.proto, "tmp.onnx")
#onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False) #onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False)
# TODO: something to be able to get weight when they already moved outside of model proto # TODO: something to be able to get weight when they already moved outside of model proto
@ -568,13 +583,15 @@ class IAIOnnxRuntimeModel:
# sess.enable_cpu_mem_arena = True # sess.enable_cpu_mem_arena = True
# sess.enable_mem_pattern = True # sess.enable_mem_pattern = True
# sess.add_session_config_entry("session.intra_op.use_xnnpack_threadpool", "1") ########### It's the key code # sess.add_session_config_entry("session.intra_op.use_xnnpack_threadpool", "1") ########### It's the key code
self.session_height = height
self.session_width = width
if height and width: if height and width:
sess.add_free_dimension_override_by_name("unet_sample_batch", 2) sess.add_free_dimension_override_by_name("unet_sample_batch", 2)
sess.add_free_dimension_override_by_name("unet_sample_channels", 4) sess.add_free_dimension_override_by_name("unet_sample_channels", 4)
sess.add_free_dimension_override_by_name("unet_hidden_batch", 2) sess.add_free_dimension_override_by_name("unet_hidden_batch", 2)
sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77) sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77)
sess.add_free_dimension_override_by_name("unet_sample_height", height) sess.add_free_dimension_override_by_name("unet_sample_height", self.session_height)
sess.add_free_dimension_override_by_name("unet_sample_width", width) sess.add_free_dimension_override_by_name("unet_sample_width", self.session_width)
sess.add_free_dimension_override_by_name("unet_time_batch", 1) sess.add_free_dimension_override_by_name("unet_time_batch", 1)
providers = [] providers = []
if self.provider: if self.provider:
@ -591,9 +608,10 @@ class IAIOnnxRuntimeModel:
# self.io_binding = self.session.io_binding() # self.io_binding = self.session.io_binding()
def release_session(self): def release_session(self):
self.session = None # self.session = None
import gc # import gc
gc.collect() # gc.collect()
return
def __call__(self, **kwargs): def __call__(self, **kwargs):
if self.session is None: if self.session is None: