Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
testing being super wasteful with data
This commit is contained in:
parent 91112167b1
commit 932112b640
@@ -454,9 +454,9 @@ class IAIOnnxRuntimeModel:
 def __setitem__(self, key: str, value: np.ndarray):
     new_node = numpy_helper.from_array(value)
-    set_external_data(new_node, location="in-memory-location")
+    # set_external_data(new_node, location="in-memory-location")
     new_node.name = key
-    new_node.ClearField("raw_data")
+    # new_node.ClearField("raw_data")
     del self.model.proto.graph.initializer[self.indexes[key]]
     self.model.proto.graph.initializer.insert(self.indexes[key], new_node)
     self.model.data[key] = OrtValue.ortvalue_from_numpy(value)
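For context on what the two commented-out calls did: the pre-change __setitem__ marked the replacement tensor as externally stored and stripped its raw bytes, so the graph kept only a reference while the actual buffer lived in an OrtValue. Below is a minimal sketch of that pattern, assuming onnx's public numpy_helper / set_external_data helpers and onnxruntime's OrtValue; the function name is illustrative, not from the repo.

# Sketch (assumed): the external-data pattern this commit disables.
import numpy as np
from onnx import numpy_helper
from onnx.external_data_helper import set_external_data
from onnxruntime import OrtValue

def make_external_initializer(name: str, value: np.ndarray):
    # TensorProto with the array bytes in raw_data
    node = numpy_helper.from_array(value)
    # mark the tensor as living outside the graph, then drop the bytes from the proto
    set_external_data(node, location="in-memory-location")
    node.name = name
    node.ClearField("raw_data")
    # the actual data is now held only by this OrtValue
    buffer = OrtValue.ortvalue_from_numpy(value)
    return node, buffer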
@@ -491,7 +491,7 @@ class IAIOnnxRuntimeModel:
 def __setitem__(self, key: str, value):
     index = self.indexes[key]
-    del self.raw_proto[index]
+    # del self.raw_proto[index]
     self.raw_proto.insert(index, value)

 # __delitem__
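With the del commented out, the following insert no longer replaces the entry at index; it pushes the new value in front of the old one, so the repeated field grows by one on every assignment, which fits the commit title. A small sketch of that behaviour on an onnx graph's initializer list, using illustrative names:

# Sketch (assumed behaviour of protobuf repeated fields; names are illustrative):
import numpy as np
from onnx import helper, numpy_helper

old = numpy_helper.from_array(np.zeros(2, dtype=np.float32), name="w")
new = numpy_helper.from_array(np.ones(2, dtype=np.float32), name="w")
graph = helper.make_graph([], "g", [], [], initializer=[old])

# del graph.initializer[0]        # skipped, as in this commit
graph.initializer.insert(0, new)  # inserted in front of the old tensor instead of replacing it
print(len(graph.initializer))     # 2 -- both copies of "w" are now in the graph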
@@ -533,9 +533,9 @@ class IAIOnnxRuntimeModel:
 npt = numpy_helper.to_array(tensor)
 orv = OrtValue.ortvalue_from_numpy(npt)
 self.data[name] = orv
-set_external_data(tensor, location="in-memory-location")
+# set_external_data(tensor, location="in-memory-location")
 tensor.name = name
-tensor.ClearField("raw_data")
+# tensor.ClearField("raw_data")

 self.nodes = self._access_helper(self.proto.graph.node)
 self.initializers = self._access_helper(self.proto.graph.initializer)
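This hunk sits in model loading: each initializer is copied into an OrtValue keyed by name, and with the external-data lines disabled the raw bytes also remain inside the TensorProto, so every weight is held twice in memory. A self-contained sketch of that loop, assuming the model is read from an .onnx file on disk; the function and variable names are illustrative, not the repo's exact code.

# Sketch (assumed shape of the surrounding loop, not the repo's exact code):
from onnx import load, numpy_helper
from onnxruntime import OrtValue

def collect_initializer_ortvalues(model_path: str):
    proto = load(model_path)
    data = {}
    for tensor in proto.graph.initializer:
        npt = numpy_helper.to_array(tensor)                  # numpy copy of the weight
        data[tensor.name] = OrtValue.ortvalue_from_numpy(npt)
        # set_external_data(tensor, location="in-memory-location")  # disabled by this commit,
        # tensor.ClearField("raw_data")                              # so raw_data stays in the proto too
    return proto, data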
@@ -551,8 +551,8 @@ class IAIOnnxRuntimeModel:
 #(trimmed_model, external_data) = buffer_external_data_tensors(self.proto)
 sess = SessionOptions()
 #self._external_data.update(**external_data)
-sess.add_external_initializers(list(self.data.keys()), list(self.data.values()))
-self.session = InferenceSession(self.proto.SerializeToString(), providers=[self.provider], sess_options=sess)
+# sess.add_external_initializers(list(self.data.keys()), list(self.data.values()))
+self.session = InferenceSession(self.proto.SerializeToString(), providers=['CUDAExecutionProvider', 'CPUExecutionProvider'], sess_options=sess)
 #self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options)

 def release_session(self):