From 78750042f505181bd6f57570a56a19b838b4ff8d Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Fri, 21 Jul 2023 12:16:24 -0400 Subject: [PATCH] Pass in dim overrides --- invokeai/app/invocations/onnx.py | 3 ++- .../backend/model_management/models/base.py | 18 +++++++++--------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py index e89ed32eb5..d69eaeb02a 100644 --- a/invokeai/app/invocations/onnx.py +++ b/invokeai/app/invocations/onnx.py @@ -257,7 +257,8 @@ class ONNXTextToLatentsInvocation(BaseInvocation): with ONNXModelPatcher.apply_lora_unet(unet, loras): # TODO: - unet.create_session() + _, _, h, w = latents.shape + unet.create_session(h, w) timestep_dtype = next( (input.type for input in unet.session.get_inputs() if input.name == "timestep"), "tensor(float16)" diff --git a/invokeai/backend/model_management/models/base.py b/invokeai/backend/model_management/models/base.py index 026b5c5326..23c6906b53 100644 --- a/invokeai/backend/model_management/models/base.py +++ b/invokeai/backend/model_management/models/base.py @@ -548,7 +548,7 @@ class IAIOnnxRuntimeModel: self.tensors = self._tensor_access(self) # TODO: integrate with model manager/cache - def create_session(self): + def create_session(self, height=None, width=None): if self.session is None: #onnx.save(self.proto, "tmp.onnx") #onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False) @@ -566,14 +566,14 @@ class IAIOnnxRuntimeModel: # sess.enable_cpu_mem_arena = True # sess.enable_mem_pattern = True # sess.add_session_config_entry("session.intra_op.use_xnnpack_threadpool", "1") ########### It's the key code - - # sess.add_free_dimension_override_by_name("unet_sample_batch", 2) - # sess.add_free_dimension_override_by_name("unet_sample_channels", 4) - # sess.add_free_dimension_override_by_name("unet_hidden_batch", 2) - # sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77) - # sess.add_free_dimension_override_by_name("unet_sample_height", 64) - # sess.add_free_dimension_override_by_name("unet_sample_width", 64) - # sess.add_free_dimension_override_by_name("unet_time_batch", 1) + if height and width: + sess.add_free_dimension_override_by_name("unet_sample_batch", 2) + sess.add_free_dimension_override_by_name("unet_sample_channels", 4) + sess.add_free_dimension_override_by_name("unet_hidden_batch", 2) + sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77) + sess.add_free_dimension_override_by_name("unet_sample_height", height) + sess.add_free_dimension_override_by_name("unet_sample_width", width) + sess.add_free_dimension_override_by_name("unet_time_batch", 1) providers = [] if self.provider: providers.append(self.provider)