Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Allow passing in of precision, use available providers if none provided
commit 4e90376d11
parent 23f4a4ea1a
@@ -49,6 +49,9 @@ ORT_TO_NP_TYPE = {
     "tensor(double)": np.float64,
 }
 
+PRECISION_VALUES = Literal[
+    tuple(list(ORT_TO_NP_TYPE.keys()))
+]
 
 class ONNXPromptInvocation(BaseInvocation):
     type: Literal["prompt_onnx"] = "prompt_onnx"
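Note on the added block: subscripting Literal with a tuple at runtime yields a single Literal type whose allowed values are exactly the keys of ORT_TO_NP_TYPE, which pydantic can then use to validate the new precision field. A minimal sketch of the pattern, using an abbreviated copy of the mapping (the table in the diff has more entries):

from typing import Literal, get_args

import numpy as np

# Abbreviated copy of the mapping; the real table has more entries.
ORT_TO_NP_TYPE = {
    "tensor(float16)": np.float16,
    "tensor(float)": np.float32,
    "tensor(double)": np.float64,
}

# At runtime, subscripting Literal with a tuple expands to one Literal whose
# allowed values are the dict keys. Static checkers dislike dynamically built
# Literals, but runtime validation only needs this form.
PRECISION_VALUES = Literal[tuple(ORT_TO_NP_TYPE.keys())]

print(get_args(PRECISION_VALUES))
# ('tensor(float16)', 'tensor(float)', 'tensor(double)')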
@@ -151,6 +154,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
     steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
     cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
     scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
+    precision: PRECISION_VALUES = Field(default = "tensor(float16)", description="The precision to use when generating latents")
     unet: UNetField = Field(default=None, description="UNet submodel")
     #control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
     #seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
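Because the new precision field is typed with the PRECISION_VALUES literal, invalid strings are rejected when the invocation is constructed. A minimal sketch with a hypothetical stand-in model (LatentsSettings is not an InvokeAI class; the literal values are written out explicitly here):

from typing import Literal

from pydantic import BaseModel, Field, ValidationError

# Written out explicitly here; the diff derives these values from ORT_TO_NP_TYPE.
PRECISION_VALUES = Literal["tensor(float16)", "tensor(float)", "tensor(double)"]


class LatentsSettings(BaseModel):  # hypothetical stand-in, not ONNXTextToLatentsInvocation
    precision: PRECISION_VALUES = Field(
        default="tensor(float16)",
        description="The precision to use when generating latents",
    )


print(LatentsSettings().precision)                           # tensor(float16)
print(LatentsSettings(precision="tensor(float)").precision)  # tensor(float)

try:
    LatentsSettings(precision="float64")  # not one of the allowed literals
except ValidationError as err:
    print("rejected:", err.errors()[0]["loc"])  # ('precision',)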
@@ -202,7 +206,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
         latents = latents.cpu().numpy()
 
         # TODO: better execution device handling
-        latents = latents.astype(np.float16)
+        latents = latents.astype(ORT_TO_NP_TYPE[self.precision])
 
         # get the initial random noise unless the user supplied it
         do_classifier_free_guidance = True
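The cast now follows the selected precision instead of a hard-coded np.float16: the precision string is used as a key into ORT_TO_NP_TYPE and the resulting NumPy dtype is passed to astype. A small sketch with placeholder latents and an abbreviated mapping:

import numpy as np

# Abbreviated copy of the ORT_TO_NP_TYPE mapping used above.
ORT_TO_NP_TYPE = {
    "tensor(float16)": np.float16,
    "tensor(float)": np.float32,
    "tensor(double)": np.float64,
}

latents = np.random.randn(1, 4, 64, 64).astype(np.float32)  # placeholder latents

precision = "tensor(float16)"  # stands in for self.precision on the invocation
latents = latents.astype(ORT_TO_NP_TYPE[precision])
print(latents.dtype)  # float16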
@@ -486,7 +490,6 @@ class OnnxModelLoaderInvocation(BaseInvocation):
     type: Literal["onnx_model_loader"] = "onnx_model_loader"
 
     model: OnnxModelField = Field(description="The model to load")
-    # TODO: precision?
 
     # Schema customisation
     class Config(InvocationConfig):
@@ -20,7 +20,7 @@ from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callab
 import onnx
 from onnx import numpy_helper
 from onnx.external_data_helper import set_external_data
-from onnxruntime import InferenceSession, OrtValue, SessionOptions, ExecutionMode, GraphOptimizationLevel
+from onnxruntime import InferenceSession, OrtValue, SessionOptions, ExecutionMode, GraphOptimizationLevel, get_available_providers
 
 class InvalidModelException(Exception):
     pass
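The only change in this hunk is the import of get_available_providers, which reports the execution providers the installed onnxruntime build supports. A minimal sketch of the call (the printed list depends on the build):

from onnxruntime import get_available_providers

# Lists the execution providers compiled into the installed onnxruntime
# package, in the order onnxruntime prefers them.
providers = get_available_providers()
print(providers)
# e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider'] on a GPU build,
#      ['CPUExecutionProvider'] on a CPU-only build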
@@ -514,7 +514,7 @@ class IAIOnnxRuntimeModel:
     def __init__(self, model_path: str, provider: Optional[str]):
         self.path = model_path
         self.session = None
-        self.provider = provider or "CPUExecutionProvider"
+        self.provider = provider
         """
         self.data_path = self.path + "_data"
         if not os.path.exists(self.data_path):
@@ -567,15 +567,19 @@ class IAIOnnxRuntimeModel:
             # sess.enable_mem_pattern = True
             # sess.add_session_config_entry("session.intra_op.use_xnnpack_threadpool", "1") ########### It's the key code
 
-
-            sess.add_free_dimension_override_by_name("unet_sample_batch", 2)
-            sess.add_free_dimension_override_by_name("unet_sample_channels", 4)
-            sess.add_free_dimension_override_by_name("unet_hidden_batch", 2)
-            sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77)
-            sess.add_free_dimension_override_by_name("unet_sample_height", 64)
-            sess.add_free_dimension_override_by_name("unet_sample_width", 64)
-            sess.add_free_dimension_override_by_name("unet_time_batch", 1)
-            self.session = InferenceSession(self.proto.SerializeToString(), providers=['CUDAExecutionProvider', 'CPUExecutionProvider'], sess_options=sess)
+            # sess.add_free_dimension_override_by_name("unet_sample_batch", 2)
+            # sess.add_free_dimension_override_by_name("unet_sample_channels", 4)
+            # sess.add_free_dimension_override_by_name("unet_hidden_batch", 2)
+            # sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77)
+            # sess.add_free_dimension_override_by_name("unet_sample_height", 64)
+            # sess.add_free_dimension_override_by_name("unet_sample_width", 64)
+            # sess.add_free_dimension_override_by_name("unet_time_batch", 1)
+            providers = []
+            if self.provider:
+                providers.append(self.provider)
+            else:
+                providers = get_available_providers()
+            self.session = InferenceSession(self.proto.SerializeToString(), providers=providers, sess_options=sess)
             #self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options)
             # self.io_binding = self.session.io_binding()
 
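Together with the __init__ change above (self.provider may now be None), the session is created with the caller's provider when one is supplied and with every available provider otherwise. A minimal standalone sketch of that selection logic (create_session and the model path are hypothetical, not InvokeAI code):

from typing import Optional

from onnxruntime import InferenceSession, SessionOptions, get_available_providers


def create_session(model_path: str, provider: Optional[str] = None) -> InferenceSession:
    sess_options = SessionOptions()
    # Honour an explicit provider; otherwise offer every provider the build exposes.
    providers = [provider] if provider else get_available_providers()
    return InferenceSession(model_path, providers=providers, sess_options=sess_options)


# session = create_session("unet/model.onnx")                           # all available providers
# session = create_session("unet/model.onnx", "CUDAExecutionProvider")  # pinned to one provider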