Allow passing in of precision, use available providers if none provided

This commit is contained in:
Brandon Rising 2023-07-20 13:15:45 -04:00
parent 23f4a4ea1a
commit 4e90376d11
2 changed files with 20 additions and 13 deletions

View File

@ -49,6 +49,9 @@ ORT_TO_NP_TYPE = {
"tensor(double)": np.float64, "tensor(double)": np.float64,
} }
PRECISION_VALUES = Literal[
tuple(list(ORT_TO_NP_TYPE.keys()))
]
class ONNXPromptInvocation(BaseInvocation): class ONNXPromptInvocation(BaseInvocation):
type: Literal["prompt_onnx"] = "prompt_onnx" type: Literal["prompt_onnx"] = "prompt_onnx"
@ -151,6 +154,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image") steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", ) cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" ) scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
precision: PRECISION_VALUES = Field(default = "tensor(float16)", description="The precision to use when generating latents")
unet: UNetField = Field(default=None, description="UNet submodel") unet: UNetField = Field(default=None, description="UNet submodel")
#control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use") #control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", ) #seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
@ -202,7 +206,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
latents = latents.cpu().numpy() latents = latents.cpu().numpy()
# TODO: better execution device handling # TODO: better execution device handling
latents = latents.astype(np.float16) latents = latents.astype(ORT_TO_NP_TYPE[self.precision])
# get the initial random noise unless the user supplied it # get the initial random noise unless the user supplied it
do_classifier_free_guidance = True do_classifier_free_guidance = True
@ -486,7 +490,6 @@ class OnnxModelLoaderInvocation(BaseInvocation):
type: Literal["onnx_model_loader"] = "onnx_model_loader" type: Literal["onnx_model_loader"] = "onnx_model_loader"
model: OnnxModelField = Field(description="The model to load") model: OnnxModelField = Field(description="The model to load")
# TODO: precision?
# Schema customisation # Schema customisation
class Config(InvocationConfig): class Config(InvocationConfig):

View File

@ -20,7 +20,7 @@ from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callab
import onnx import onnx
from onnx import numpy_helper from onnx import numpy_helper
from onnx.external_data_helper import set_external_data from onnx.external_data_helper import set_external_data
from onnxruntime import InferenceSession, OrtValue, SessionOptions, ExecutionMode, GraphOptimizationLevel from onnxruntime import InferenceSession, OrtValue, SessionOptions, ExecutionMode, GraphOptimizationLevel, get_available_providers
class InvalidModelException(Exception): class InvalidModelException(Exception):
pass pass
@ -514,7 +514,7 @@ class IAIOnnxRuntimeModel:
def __init__(self, model_path: str, provider: Optional[str]): def __init__(self, model_path: str, provider: Optional[str]):
self.path = model_path self.path = model_path
self.session = None self.session = None
self.provider = provider or "CPUExecutionProvider" self.provider = provider
""" """
self.data_path = self.path + "_data" self.data_path = self.path + "_data"
if not os.path.exists(self.data_path): if not os.path.exists(self.data_path):
@ -567,15 +567,19 @@ class IAIOnnxRuntimeModel:
# sess.enable_mem_pattern = True # sess.enable_mem_pattern = True
# sess.add_session_config_entry("session.intra_op.use_xnnpack_threadpool", "1") ########### It's the key code # sess.add_session_config_entry("session.intra_op.use_xnnpack_threadpool", "1") ########### It's the key code
# sess.add_free_dimension_override_by_name("unet_sample_batch", 2)
sess.add_free_dimension_override_by_name("unet_sample_batch", 2) # sess.add_free_dimension_override_by_name("unet_sample_channels", 4)
sess.add_free_dimension_override_by_name("unet_sample_channels", 4) # sess.add_free_dimension_override_by_name("unet_hidden_batch", 2)
sess.add_free_dimension_override_by_name("unet_hidden_batch", 2) # sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77)
sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77) # sess.add_free_dimension_override_by_name("unet_sample_height", 64)
sess.add_free_dimension_override_by_name("unet_sample_height", 64) # sess.add_free_dimension_override_by_name("unet_sample_width", 64)
sess.add_free_dimension_override_by_name("unet_sample_width", 64) # sess.add_free_dimension_override_by_name("unet_time_batch", 1)
sess.add_free_dimension_override_by_name("unet_time_batch", 1) providers = []
self.session = InferenceSession(self.proto.SerializeToString(), providers=['CUDAExecutionProvider', 'CPUExecutionProvider'], sess_options=sess) if self.provider:
providers.append(self.provider)
else:
providers = get_available_providers()
self.session = InferenceSession(self.proto.SerializeToString(), providers=providers, sess_options=sess)
#self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options) #self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options)
# self.io_binding = self.session.io_binding() # self.io_binding = self.session.io_binding()