From 4e90376d119b6300f1a998d56e5df565c2992b1a Mon Sep 17 00:00:00 2001
From: Brandon Rising
Date: Thu, 20 Jul 2023 13:15:45 -0400
Subject: [PATCH] Allow passing in of precision, use available providers if
 none provided

---
 invokeai/app/invocations/onnx.py             |  7 +++--
 .../backend/model_management/models/base.py  | 26 +++++++++++--------
 2 files changed, 20 insertions(+), 13 deletions(-)

diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py
index 6d91819329..5d0512c7f0 100644
--- a/invokeai/app/invocations/onnx.py
+++ b/invokeai/app/invocations/onnx.py
@@ -49,6 +49,9 @@ ORT_TO_NP_TYPE = {
     "tensor(double)": np.float64,
 }
 
+PRECISION_VALUES = Literal[
+    tuple(list(ORT_TO_NP_TYPE.keys()))
+]
 
 class ONNXPromptInvocation(BaseInvocation):
     type: Literal["prompt_onnx"] = "prompt_onnx"
@@ -151,6 +154,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
     steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
     cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
     scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
+    precision: PRECISION_VALUES = Field(default = "tensor(float16)", description="The precision to use when generating latents")
     unet: UNetField = Field(default=None, description="UNet submodel")
     #control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
     #seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
@@ -202,7 +206,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
             latents = latents.cpu().numpy()
 
         # TODO: better execution device handling
-        latents = latents.astype(np.float16)
+        latents = latents.astype(ORT_TO_NP_TYPE[self.precision])
 
         # get the initial random noise unless the user supplied it
         do_classifier_free_guidance = True
@@ -486,7 +490,6 @@ class OnnxModelLoaderInvocation(BaseInvocation):
     type: Literal["onnx_model_loader"] = "onnx_model_loader"
 
     model: OnnxModelField = Field(description="The model to load")
-    # TODO: precision?
 
     # Schema customisation
     class Config(InvocationConfig):

diff --git a/invokeai/backend/model_management/models/base.py b/invokeai/backend/model_management/models/base.py
index 2cfb254b39..026b5c5326 100644
--- a/invokeai/backend/model_management/models/base.py
+++ b/invokeai/backend/model_management/models/base.py
@@ -20,7 +20,7 @@ from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable
 import onnx
 from onnx import numpy_helper
 from onnx.external_data_helper import set_external_data
-from onnxruntime import InferenceSession, OrtValue, SessionOptions, ExecutionMode, GraphOptimizationLevel
+from onnxruntime import InferenceSession, OrtValue, SessionOptions, ExecutionMode, GraphOptimizationLevel, get_available_providers
 
 class InvalidModelException(Exception):
     pass
@@ -514,7 +514,7 @@ class IAIOnnxRuntimeModel:
     def __init__(self, model_path: str, provider: Optional[str]):
         self.path = model_path
         self.session = None
-        self.provider = provider or "CPUExecutionProvider"
+        self.provider = provider
         """
         self.data_path = self.path + "_data"
         if not os.path.exists(self.data_path):
@@ -567,15 +567,19 @@ class IAIOnnxRuntimeModel:
         # sess.enable_mem_pattern = True
         # sess.add_session_config_entry("session.intra_op.use_xnnpack_threadpool", "1")
         ########### It's the key code
-
-        sess.add_free_dimension_override_by_name("unet_sample_batch", 2)
-        sess.add_free_dimension_override_by_name("unet_sample_channels", 4)
-        sess.add_free_dimension_override_by_name("unet_hidden_batch", 2)
-        sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77)
-        sess.add_free_dimension_override_by_name("unet_sample_height", 64)
-        sess.add_free_dimension_override_by_name("unet_sample_width", 64)
-        sess.add_free_dimension_override_by_name("unet_time_batch", 1)
-        self.session = InferenceSession(self.proto.SerializeToString(), providers=['CUDAExecutionProvider', 'CPUExecutionProvider'], sess_options=sess)
+        # sess.add_free_dimension_override_by_name("unet_sample_batch", 2)
+        # sess.add_free_dimension_override_by_name("unet_sample_channels", 4)
+        # sess.add_free_dimension_override_by_name("unet_hidden_batch", 2)
+        # sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77)
+        # sess.add_free_dimension_override_by_name("unet_sample_height", 64)
+        # sess.add_free_dimension_override_by_name("unet_sample_width", 64)
+        # sess.add_free_dimension_override_by_name("unet_time_batch", 1)
+        providers = []
+        if self.provider:
+            providers.append(self.provider)
+        else:
+            providers = get_available_providers()
+        self.session = InferenceSession(self.proto.SerializeToString(), providers=providers, sess_options=sess)
         #self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options)
         # self.io_binding = self.session.io_binding()
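
A minimal, self-contained sketch of the precision plumbing added above,
assuming only numpy and typing. The ORT_TO_NP_TYPE table here is a trimmed
copy of the one in onnx.py, and the latents array is a made-up stand-in for
the invocation's real input:

    import numpy as np
    from typing import Literal

    # ORT type string -> numpy dtype (subset of the table in onnx.py).
    ORT_TO_NP_TYPE = {
        "tensor(float16)": np.float16,
        "tensor(float)": np.float32,
        "tensor(double)": np.float64,
    }

    # At runtime, subscripting Literal with a tuple expands to
    # Literal["tensor(float16)", "tensor(float)", ...]; pydantic can then
    # validate the new `precision` field against the table's keys.
    PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))]

    # The invocation casts latents to the requested dtype instead of the
    # previously hard-coded np.float16:
    precision = "tensor(float)"
    latents = np.zeros((1, 4, 64, 64), dtype=np.float16)
    latents = latents.astype(ORT_TO_NP_TYPE[precision])
    print(latents.dtype)  # float32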
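
Likewise, a sketch of the provider fallback in IAIOnnxRuntimeModel, assuming
only that onnxruntime is installed. get_available_providers is a real
onnxruntime export; pick_providers is an illustrative helper, not part of
the patch:

    from typing import List, Optional

    from onnxruntime import get_available_providers

    def pick_providers(provider: Optional[str] = None) -> List[str]:
        # Honor an explicitly requested provider; otherwise fall back to
        # every provider this onnxruntime build supports, in its default
        # priority order (e.g. CUDA before CPU on a GPU build).
        if provider:
            return [provider]
        return get_available_providers()

    print(pick_providers("CPUExecutionProvider"))  # ['CPUExecutionProvider']
    print(pick_providers())  # e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider']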