Pass in dim overrides

This commit is contained in:
Brandon Rising 2023-07-21 12:16:24 -04:00
parent ce08aa350c
commit 78750042f5
2 changed files with 11 additions and 10 deletions

View File

@ -257,7 +257,8 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
with ONNXModelPatcher.apply_lora_unet(unet, loras):
# TODO:
unet.create_session()
_, _, h, w = latents.shape
unet.create_session(h, w)
timestep_dtype = next(
(input.type for input in unet.session.get_inputs() if input.name == "timestep"), "tensor(float16)"

View File

@ -548,7 +548,7 @@ class IAIOnnxRuntimeModel:
self.tensors = self._tensor_access(self)
# TODO: integrate with model manager/cache
def create_session(self):
def create_session(self, height=None, width=None):
if self.session is None:
#onnx.save(self.proto, "tmp.onnx")
#onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False)
@ -566,14 +566,14 @@ class IAIOnnxRuntimeModel:
# sess.enable_cpu_mem_arena = True
# sess.enable_mem_pattern = True
# sess.add_session_config_entry("session.intra_op.use_xnnpack_threadpool", "1") ########### It's the key code
# sess.add_free_dimension_override_by_name("unet_sample_batch", 2)
# sess.add_free_dimension_override_by_name("unet_sample_channels", 4)
# sess.add_free_dimension_override_by_name("unet_hidden_batch", 2)
# sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77)
# sess.add_free_dimension_override_by_name("unet_sample_height", 64)
# sess.add_free_dimension_override_by_name("unet_sample_width", 64)
# sess.add_free_dimension_override_by_name("unet_time_batch", 1)
if height and width:
sess.add_free_dimension_override_by_name("unet_sample_batch", 2)
sess.add_free_dimension_override_by_name("unet_sample_channels", 4)
sess.add_free_dimension_override_by_name("unet_hidden_batch", 2)
sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77)
sess.add_free_dimension_override_by_name("unet_sample_height", height)
sess.add_free_dimension_override_by_name("unet_sample_width", width)
sess.add_free_dimension_override_by_name("unet_time_batch", 1)
providers = []
if self.provider:
providers.append(self.provider)