mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Pass in dim overrides
This commit is contained in:
parent
ce08aa350c
commit
78750042f5
@ -257,7 +257,8 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
with ONNXModelPatcher.apply_lora_unet(unet, loras):
|
with ONNXModelPatcher.apply_lora_unet(unet, loras):
|
||||||
# TODO:
|
# TODO:
|
||||||
unet.create_session()
|
_, _, h, w = latents.shape
|
||||||
|
unet.create_session(h, w)
|
||||||
|
|
||||||
timestep_dtype = next(
|
timestep_dtype = next(
|
||||||
(input.type for input in unet.session.get_inputs() if input.name == "timestep"), "tensor(float16)"
|
(input.type for input in unet.session.get_inputs() if input.name == "timestep"), "tensor(float16)"
|
||||||
|
@ -548,7 +548,7 @@ class IAIOnnxRuntimeModel:
|
|||||||
self.tensors = self._tensor_access(self)
|
self.tensors = self._tensor_access(self)
|
||||||
|
|
||||||
# TODO: integrate with model manager/cache
|
# TODO: integrate with model manager/cache
|
||||||
def create_session(self):
|
def create_session(self, height=None, width=None):
|
||||||
if self.session is None:
|
if self.session is None:
|
||||||
#onnx.save(self.proto, "tmp.onnx")
|
#onnx.save(self.proto, "tmp.onnx")
|
||||||
#onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False)
|
#onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False)
|
||||||
@ -566,14 +566,14 @@ class IAIOnnxRuntimeModel:
|
|||||||
# sess.enable_cpu_mem_arena = True
|
# sess.enable_cpu_mem_arena = True
|
||||||
# sess.enable_mem_pattern = True
|
# sess.enable_mem_pattern = True
|
||||||
# sess.add_session_config_entry("session.intra_op.use_xnnpack_threadpool", "1") ########### It's the key code
|
# sess.add_session_config_entry("session.intra_op.use_xnnpack_threadpool", "1") ########### It's the key code
|
||||||
|
if height and width:
|
||||||
# sess.add_free_dimension_override_by_name("unet_sample_batch", 2)
|
sess.add_free_dimension_override_by_name("unet_sample_batch", 2)
|
||||||
# sess.add_free_dimension_override_by_name("unet_sample_channels", 4)
|
sess.add_free_dimension_override_by_name("unet_sample_channels", 4)
|
||||||
# sess.add_free_dimension_override_by_name("unet_hidden_batch", 2)
|
sess.add_free_dimension_override_by_name("unet_hidden_batch", 2)
|
||||||
# sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77)
|
sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77)
|
||||||
# sess.add_free_dimension_override_by_name("unet_sample_height", 64)
|
sess.add_free_dimension_override_by_name("unet_sample_height", height)
|
||||||
# sess.add_free_dimension_override_by_name("unet_sample_width", 64)
|
sess.add_free_dimension_override_by_name("unet_sample_width", width)
|
||||||
# sess.add_free_dimension_override_by_name("unet_time_batch", 1)
|
sess.add_free_dimension_override_by_name("unet_time_batch", 1)
|
||||||
providers = []
|
providers = []
|
||||||
if self.provider:
|
if self.provider:
|
||||||
providers.append(self.provider)
|
providers.append(self.provider)
|
||||||
|
Loading…
Reference in New Issue
Block a user