commit 4e90376d11
parent 23f4a4ea1a

    Allow passing in of precision, use available providers if none provided
@@ -49,6 +49,9 @@ ORT_TO_NP_TYPE = {
     "tensor(double)": np.float64,
 }
 
+PRECISION_VALUES = Literal[
+    tuple(list(ORT_TO_NP_TYPE.keys()))
+]
 
 class ONNXPromptInvocation(BaseInvocation):
     type: Literal["prompt_onnx"] = "prompt_onnx"
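Note on the hunk above: Literal[tuple(list(ORT_TO_NP_TYPE.keys()))] builds the set of
allowed precision strings dynamically from the dtype map, so the two can never drift
apart. A minimal, self-contained sketch of the same pattern (the abbreviated map here
is illustrative, not the full table from the file):

    from typing import Literal
    import numpy as np

    # Map ONNX Runtime tensor-type strings to numpy dtypes (abbreviated).
    ORT_TO_NP_TYPE = {
        "tensor(float16)": np.float16,
        "tensor(float)": np.float32,
        "tensor(double)": np.float64,
    }

    # Subscripting with an explicit tuple is the same as the comma form, so at
    # runtime this is Literal["tensor(float16)", "tensor(float)", "tensor(double)"].
    PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))]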
@@ -151,6 +154,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
     steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
     cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
     scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
+    precision: PRECISION_VALUES = Field(default = "tensor(float16)", description="The precision to use when generating latents")
     unet: UNetField = Field(default=None, description="UNet submodel")
     #control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
     #seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
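The precision field added above is validated by pydantic against PRECISION_VALUES, so an
unsupported string fails before any inference runs. A hedged sketch of that behavior
(pydantic v1 style to match the Field(...) usage in this diff; DemoInvocation is a
made-up stand-in for the invocation class):

    from typing import Literal
    from pydantic import BaseModel, Field, ValidationError

    PRECISION_VALUES = Literal["tensor(float16)", "tensor(float)", "tensor(double)"]

    class DemoInvocation(BaseModel):
        precision: PRECISION_VALUES = Field(default="tensor(float16)", description="The precision to use when generating latents")

    DemoInvocation(precision="tensor(float)")      # accepted
    try:
        DemoInvocation(precision="tensor(int8)")   # not in the Literal -> ValidationError
    except ValidationError as err:
        print(err)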
@@ -202,7 +206,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
         latents = latents.cpu().numpy()
 
         # TODO: better execution device handling
-        latents = latents.astype(np.float16)
+        latents = latents.astype(ORT_TO_NP_TYPE[self.precision])
 
         # get the initial random noise unless the user supplied it
         do_classifier_free_guidance = True
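This cast is where the new field takes effect: instead of hard-coding float16, the
latents are converted to whatever numpy dtype the selected precision string maps to.
A small sketch, reusing the abbreviated ORT_TO_NP_TYPE map from the note above:

    import numpy as np

    ORT_TO_NP_TYPE = {
        "tensor(float16)": np.float16,
        "tensor(float)": np.float32,
        "tensor(double)": np.float64,
    }

    latents = np.random.randn(1, 4, 64, 64).astype(np.float32)
    precision = "tensor(float16)"  # stand-in for self.precision

    # Mirrors latents.astype(ORT_TO_NP_TYPE[self.precision]) in the hunk above.
    latents = latents.astype(ORT_TO_NP_TYPE[precision])
    print(latents.dtype)  # float16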
@@ -486,7 +490,6 @@ class OnnxModelLoaderInvocation(BaseInvocation):
     type: Literal["onnx_model_loader"] = "onnx_model_loader"
 
     model: OnnxModelField = Field(description="The model to load")
-    # TODO: precision?
 
     # Schema customisation
     class Config(InvocationConfig):
@@ -20,7 +20,7 @@ from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callab
 import onnx
 from onnx import numpy_helper
 from onnx.external_data_helper import set_external_data
-from onnxruntime import InferenceSession, OrtValue, SessionOptions, ExecutionMode, GraphOptimizationLevel
+from onnxruntime import InferenceSession, OrtValue, SessionOptions, ExecutionMode, GraphOptimizationLevel, get_available_providers
 
 class InvalidModelException(Exception):
     pass
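get_available_providers is part of the public onnxruntime API: it returns the execution
providers compiled into the installed build, in onnxruntime's preferred order. A quick
check of what a given environment offers:

    from onnxruntime import get_available_providers

    # A CPU-only wheel typically prints ["CPUExecutionProvider"]; a GPU build
    # lists CUDAExecutionProvider (and others) ahead of the CPU fallback.
    print(get_available_providers())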
@@ -514,7 +514,7 @@ class IAIOnnxRuntimeModel:
     def __init__(self, model_path: str, provider: Optional[str]):
         self.path = model_path
         self.session = None
-        self.provider = provider or "CPUExecutionProvider"
+        self.provider = provider
         """
         self.data_path = self.path + "_data"
         if not os.path.exists(self.data_path):
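The __init__ change is small but load-bearing: with provider or "CPUExecutionProvider",
a caller passing None was silently pinned to CPU at construction time, so the session
code below could never see the "no preference" case. Storing the argument as-is keeps
None alive until the session is built. A toy illustration of the difference:

    provider = None

    # Old behavior: None collapses to CPU immediately.
    old = provider or "CPUExecutionProvider"   # -> "CPUExecutionProvider"

    # New behavior: None survives, so later code can still fall back
    # to get_available_providers() instead of assuming CPU.
    new = provider                             # -> None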
@@ -567,15 +567,19 @@ class IAIOnnxRuntimeModel:
         # sess.enable_mem_pattern = True
         # sess.add_session_config_entry("session.intra_op.use_xnnpack_threadpool", "1") ########### It's the key code
 
-        # sess.add_free_dimension_override_by_name("unet_sample_batch", 2)
-        # sess.add_free_dimension_override_by_name("unet_sample_channels", 4)
-        # sess.add_free_dimension_override_by_name("unet_hidden_batch", 2)
-        # sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77)
-        # sess.add_free_dimension_override_by_name("unet_sample_height", 64)
-        # sess.add_free_dimension_override_by_name("unet_sample_width", 64)
-        # sess.add_free_dimension_override_by_name("unet_time_batch", 1)
-
-        self.session = InferenceSession(self.proto.SerializeToString(), providers=['CUDAExecutionProvider', 'CPUExecutionProvider'], sess_options=sess)
+        sess.add_free_dimension_override_by_name("unet_sample_batch", 2)
+        sess.add_free_dimension_override_by_name("unet_sample_channels", 4)
+        sess.add_free_dimension_override_by_name("unet_hidden_batch", 2)
+        sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77)
+        sess.add_free_dimension_override_by_name("unet_sample_height", 64)
+        sess.add_free_dimension_override_by_name("unet_sample_width", 64)
+        sess.add_free_dimension_override_by_name("unet_time_batch", 1)
+        providers = []
+        if self.provider:
+            providers.append(self.provider)
+        else:
+            providers = get_available_providers()
+        self.session = InferenceSession(self.proto.SerializeToString(), providers=providers, sess_options=sess)
         #self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options)
         # self.io_binding = self.session.io_binding()
 
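Taken together, the last two hunks defer provider selection to session-creation time: an
explicit provider is honored, and otherwise every provider the running build supports is
offered to InferenceSession. The hunk also uncomments the add_free_dimension_override_by_name
calls, pinning the UNet's symbolic input dimensions to fixed sizes so onnxruntime can
optimize for them. A minimal sketch of the provider fallback (resolve_providers is a
hypothetical helper name, not from the diff):

    from typing import List, Optional
    from onnxruntime import get_available_providers

    def resolve_providers(provider: Optional[str]) -> List[str]:
        # Prefer the caller's explicit choice; otherwise take everything
        # the installed onnxruntime build can actually run.
        if provider:
            return [provider]
        return get_available_providers()

    print(resolve_providers("CPUExecutionProvider"))  # ["CPUExecutionProvider"]
    print(resolve_providers(None))                    # all available providers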