mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Switch to io_binding for run, testing different session options
This commit is contained in:
parent
869f418b03
commit
e201ad2f51
@ -539,6 +539,8 @@ class IAIOnnxRuntimeModel:
|
||||
|
||||
self.nodes = self._access_helper(self.proto.graph.node)
|
||||
self.initializers = self._access_helper(self.proto.graph.initializer)
|
||||
# print(self.proto.graph.input)
|
||||
# print(self.proto.graph.initializer)
|
||||
|
||||
self.tensors = self._tensor_access(self)
|
||||
|
||||
@ -552,25 +554,45 @@ class IAIOnnxRuntimeModel:
|
||||
sess = SessionOptions()
|
||||
#self._external_data.update(**external_data)
|
||||
# sess.add_external_initializers(list(self.data.keys()), list(self.data.values()))
|
||||
sess.execution_mode = ExecutionMode.ORT_PARALLEL
|
||||
# sess.enable_profiling = True
|
||||
|
||||
sess.intra_op_num_threads = 1
|
||||
sess.inter_op_num_threads = 1
|
||||
sess.execution_mode = ExecutionMode.ORT_SEQUENTIAL
|
||||
sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
# sess.add_free_dimension_override_by_name("unet_sample_height", 64)
|
||||
# sess.add_free_dimension_override_by_name("unet_sample_width", 64)
|
||||
sess.enable_cpu_mem_arena = True
|
||||
sess.enable_mem_pattern = True
|
||||
sess.add_session_config_entry("session.intra_op.use_xnnpack_threadpool", "1") ########### It's the key code
|
||||
|
||||
|
||||
sess.add_free_dimension_override_by_name("unet_sample_batch", 2)
|
||||
sess.add_free_dimension_override_by_name("unet_sample_channels", 4)
|
||||
sess.add_free_dimension_override_by_name("unet_hidden_batch", 2)
|
||||
sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77)
|
||||
sess.add_free_dimension_override_by_name("unet_sample_height", 64)
|
||||
sess.add_free_dimension_override_by_name("unet_sample_width", 64)
|
||||
sess.add_free_dimension_override_by_name("unet_time_batch", 1)
|
||||
self.session = InferenceSession(self.proto.SerializeToString(), providers=['CUDAExecutionProvider', 'CPUExecutionProvider'], sess_options=sess)
|
||||
#self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options)
|
||||
self.io_binding = self.session.io_binding()
|
||||
|
||||
def release_session(self):
    """Drop the inference session and its IO binding, then force a GC pass.

    After this call ``self.session`` is ``None``, so ``__call__`` raises
    until a new session has been created.
    """
    # The IO binding is created from the session when the session is built
    # and keeps whatever inputs/outputs were last bound to it. Clear it along
    # with the session; otherwise those objects outlive the release and the
    # gc.collect() below has nothing to reclaim.
    self.io_binding = None
    self.session = None

    # Local import kept from the original; collect immediately so the
    # memory held by the session is returned as soon as possible.
    import gc

    gc.collect()
|
||||
|
||||
|
||||
def __call__(self, **kwargs):
    """Run the model on the given named inputs through the session's IO binding.

    Args:
        **kwargs: Model inputs keyed by input name. Each value is converted
            with ``np.array`` before being bound.

    Returns:
        Whatever ``self.io_binding.copy_outputs_to_cpu()`` returns after the
        run; every output reported by ``self.session.get_outputs()`` is bound.

    Raises:
        Exception: If no session has been created yet.
    """
    if self.session is None:
        raise Exception("You should call create_session before running model")

    inputs = {k: np.array(v) for k, v in kwargs.items()}

    # NOTE: the earlier `return self.session.run(None, inputs)` was removed.
    # It sat ahead of the binding code and made the io_binding path
    # unreachable.
    for input_name, value in inputs.items():
        self.io_binding.bind_cpu_input(input_name, value)

    for output in self.session.get_outputs():
        self.io_binding.bind_output(output.name)

    self.session.run_with_iobinding(self.io_binding, None)
    return self.io_binding.copy_outputs_to_cpu()
|
||||
|
||||
# compatibility with diffusers load code
|
||||
@classmethod
|
||||
|
Loading…
Reference in New Issue
Block a user