diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management/model_cache.py
index 5ca17f00fc..79ddd624fc 100644
--- a/invokeai/backend/model_management/model_cache.py
+++ b/invokeai/backend/model_management/model_cache.py
@@ -353,7 +353,7 @@ class ModelCache(object):
             # 2 refs:
             # 1 from cache_entry
             # 1 from getrefcount function
-            if not cache_entry.locked and refs <= 2:
+            if not cache_entry.locked and refs <= (3 if 'onnx' in model_key else 2):
                 self.logger.debug(f'Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)')
                 current_size -= cache_entry.size
                 del self._cache_stack[pos]
diff --git a/invokeai/backend/model_management/models/base.py b/invokeai/backend/model_management/models/base.py
index 09e949baba..06255ac6f6 100644
--- a/invokeai/backend/model_management/models/base.py
+++ b/invokeai/backend/model_management/models/base.py
@@ -239,6 +239,7 @@ class DiffusersModel(ModelBase):
 
                 self.child_sizes[child_name] = calc_model_size_by_fs(self.model_path, subfolder=child_name)
 
+
     def get_size(self, child_type: Optional[SubModelType] = None):
         if child_type is None:
             return sum(self.child_sizes.values())
@@ -363,6 +364,8 @@ def calc_model_size_by_data(model) -> int:
         return _calc_pipeline_by_data(model)
     elif isinstance(model, torch.nn.Module):
         return _calc_model_by_data(model)
+    elif isinstance(model, IAIOnnxRuntimeModel):
+        return _calc_onnx_model_by_data(model)
     else:
         return 0
 
@@ -383,6 +386,12 @@ def _calc_model_by_data(model) -> int:
     return mem
 
 
+def _calc_onnx_model_by_data(model) -> int:
+    tensor_size = model.tensors.size()
+    mem = tensor_size  # in bytes
+    return mem
+
+
 def _fast_safetensors_reader(path: str):
     checkpoint = dict()
     device = torch.device("meta")
@@ -455,7 +464,8 @@ class IAIOnnxRuntimeModel:
                 self.indexes[obj.name] = idx
 
         def __getitem__(self, key: str):
-            return self.model.data[key].numpy()
+            value = self.model.proto.graph.initializer[self.indexes[key]]
+            return numpy_helper.to_array(value)
 
         def __setitem__(self, key: str, value: np.ndarray):
             new_node = numpy_helper.from_array(value)
@@ -464,24 +474,29 @@ class IAIOnnxRuntimeModel:
             # new_node.ClearField("raw_data")
             del self.model.proto.graph.initializer[self.indexes[key]]
             self.model.proto.graph.initializer.insert(self.indexes[key], new_node)
-            self.model.data[key] = OrtValue.ortvalue_from_numpy(value)
+            # self.model.data[key] = OrtValue.ortvalue_from_numpy(value)
 
         # __delitem__
 
         def __contains__(self, key: str):
-            return key in self.model.data
+            return key in self.indexes
 
         def items(self):
             raise NotImplementedError("tensor.items")
             #return [(obj.name, obj) for obj in self.raw_proto]
 
         def keys(self):
-            return self.model.data.keys()
+            return self.indexes.keys()
 
         def values(self):
             raise NotImplementedError("tensor.values")
             #return [obj for obj in self.raw_proto]
 
+        def size(self):
+            bytes_sum = 0
+            for node in self.model.proto.graph.initializer:
+                bytes_sum += sys.getsizeof(node.raw_data)  # byte size of each initializer blob
+            return bytes_sum
 
     class _access_helper:
@@ -530,20 +545,20 @@ class IAIOnnxRuntimeModel:
         """
         self.proto = onnx.load(model_path, load_external_data=True)
 
-        self.data = dict()
-        for tensor in self.proto.graph.initializer:
-            name = tensor.name
+        # self.data = dict()
+        # for tensor in self.proto.graph.initializer:
+        #     name = tensor.name
 
-            if tensor.HasField("raw_data"):
-                npt = numpy_helper.to_array(tensor)
-                orv = OrtValue.ortvalue_from_numpy(npt)
-                self.data[name] = orv
-                # set_external_data(tensor, location="in-memory-location")
-                tensor.name = name
-                # tensor.ClearField("raw_data")
+            # if tensor.HasField("raw_data"):
+            #     npt = numpy_helper.to_array(tensor)
+            #     orv = OrtValue.ortvalue_from_numpy(npt)
+            #     # self.data[name] = orv
+            #     # set_external_data(tensor, location="in-memory-location")
+            #     tensor.name = name
+            #     # tensor.ClearField("raw_data")
 
         self.nodes = self._access_helper(self.proto.graph.node)
-        self.initializers = self._access_helper(self.proto.graph.initializer)
+        # self.initializers = self._access_helper(self.proto.graph.initializer)
 
         # print(self.proto.graph.input)
         # print(self.proto.graph.initializer)
@@ -551,7 +566,7 @@ class IAIOnnxRuntimeModel:
 
     # TODO: integrate with model manager/cache
     def create_session(self, height=None, width=None):
-        if self.session is None:
+        if self.session is None or self.session_width != width or self.session_height != height:
            #onnx.save(self.proto, "tmp.onnx")
            #onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False)
            # TODO: something to be able to get weight when they already moved outside of model proto
@@ -568,13 +583,15 @@ class IAIOnnxRuntimeModel:
             # sess.enable_cpu_mem_arena = True
             # sess.enable_mem_pattern = True
             # sess.add_session_config_entry("session.intra_op.use_xnnpack_threadpool", "1") ########### It's the key code
+            self.session_height = height
+            self.session_width = width
             if height and width:
                 sess.add_free_dimension_override_by_name("unet_sample_batch", 2)
                 sess.add_free_dimension_override_by_name("unet_sample_channels", 4)
                 sess.add_free_dimension_override_by_name("unet_hidden_batch", 2)
                 sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77)
-                sess.add_free_dimension_override_by_name("unet_sample_height", height)
-                sess.add_free_dimension_override_by_name("unet_sample_width", width)
+                sess.add_free_dimension_override_by_name("unet_sample_height", self.session_height)
+                sess.add_free_dimension_override_by_name("unet_sample_width", self.session_width)
                 sess.add_free_dimension_override_by_name("unet_time_batch", 1)
             providers = []
             if self.provider:
@@ -591,9 +608,10 @@ class IAIOnnxRuntimeModel:
         # self.io_binding = self.session.io_binding()
 
     def release_session(self):
-        self.session = None
-        import gc
-        gc.collect()
+        # self.session = None
+        # import gc
+        # gc.collect()
+        return
 
     def __call__(self, **kwargs):
         if self.session is None:
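
For reviewers, the sizing path added above (`_calc_onnx_model_by_data()` calling the new `tensors.size()` helper) just walks the ONNX graph initializers and sums their `raw_data` payloads. A minimal standalone sketch of the same idea, not InvokeAI code, assuming the `onnx` package is installed and a hypothetical `model.onnx` exists on disk:

```python
import sys

import onnx

# Load the graph the way IAIOnnxRuntimeModel does, pulling external weight
# files into each initializer's raw_data so their bytes are countable.
# Note: tensors stored in typed fields (e.g. float_data) instead of
# raw_data would contribute only the empty-bytes size here.
proto = onnx.load("model.onnx", load_external_data=True)

# The patch uses sys.getsizeof(raw_data), which adds a few dozen bytes of
# Python object overhead per tensor on top of len(raw_data); for cache
# accounting measured in GB the difference is negligible.
approx_bytes = sum(sys.getsizeof(t.raw_data) for t in proto.graph.initializer)
exact_bytes = sum(len(t.raw_data) for t in proto.graph.initializer)

GIG = 2**30
print(f"initializers: ~{approx_bytes / GIG:.2f} GB (payload {exact_bytes / GIG:.2f} GB)")
```

The `refs <= (3 if 'onnx' in model_key else 2)` eviction tweak follows the accounting already spelled out in the comment: `sys.getrefcount()` reports one extra reference held by its own argument, and an ONNX cache entry appears to hold one more reference to the model than a torch/diffusers entry does, so its idle-unload threshold is one higher.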