mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Testing different session opts, added timings for testing
This commit is contained in:
parent: 932112b640
commit: bcce70fca6
@@ -195,7 +195,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
         latents = latents.cpu().numpy()

         # TODO: better execution device handling
-        latents = latents.astype(np.float32)
+        latents = latents.astype(np.float16)

         # get the initial random noise unless the user supplied it
         do_classifier_free_guidance = True
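Note: the change above casts the latents to half precision before they reach the ONNX UNet. A minimal standalone sketch of the idea (array name and shape are illustrative, not from this commit):

import numpy as np

# Latents coming off the torch side are float32; an fp16-exported ONNX UNet
# expects float16 inputs, so the array is cast before inference.
latents = np.random.randn(1, 4, 64, 64).astype(np.float32)  # illustrative shape
latents = latents.astype(np.float16)
assert latents.dtype == np.float16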
@@ -232,10 +232,11 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
         unet.create_session()

         timestep_dtype = next(
-            (input.type for input in unet.session.get_inputs() if input.name == "timestep"), "tensor(float)"
+            (input.type for input in unet.session.get_inputs() if input.name == "timestep"), "tensor(float16)"
         )
         timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
+        import time
+        times = []
         for i in tqdm(range(len(scheduler.timesteps))):
             t = scheduler.timesteps[i]
             # expand the latents if we are doing classifier free guidance
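Note: this hunk asks the ONNX session which dtype its "timestep" input expects and falls back to a default when that input is not found; the commit changes the fallback from "tensor(float)" to "tensor(float16)" to match the fp16 model. A sketch of the same lookup, with an illustrative subset of the ORT-to-NumPy type table (the code imports the full mapping as ORT_TO_NP_TYPE):

import numpy as np

# Illustrative subset of the mapping, for the sketch only.
ORT_TO_NP_TYPE = {"tensor(float)": np.float32, "tensor(float16)": np.float16}

def timestep_dtype_for(session, default="tensor(float16)"):
    # session is assumed to be an onnxruntime.InferenceSession
    ort_type = next(
        (inp.type for inp in session.get_inputs() if inp.name == "timestep"),
        default,
    )
    return ORT_TO_NP_TYPE[ort_type]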
@@ -245,7 +246,9 @@ class ONNXTextToLatentsInvocation(BaseInvocation):

             # predict the noise residual
             timestep = np.array([t], dtype=timestep_dtype)
+            start_time = time.time()
             noise_pred = unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)
+            times.append(time.time() - start_time)
             noise_pred = noise_pred[0]

             # perform guidance
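Note: the two added lines bracket each UNet call with wall-clock timing. The same pattern as a self-contained sketch; the commit uses time.time(), though time.perf_counter() is the usual choice for short intervals:

import time

times = []
for _ in range(3):  # stand-in for the denoising loop
    start = time.perf_counter()
    _ = sum(i * i for i in range(100_000))  # placeholder for the unet(...) call
    times.append(time.perf_counter() - start)

print(f"per-step seconds: {times}; mean: {sum(times) / len(times):.4f}")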
@@ -262,14 +265,14 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
             # call the callback, if provided
             #if callback is not None and i % callback_steps == 0:
             #    callback(i, t, latents)
+        print(times)
         unet.release_session()

         torch.cuda.empty_cache()

         name = f'{context.graph_execution_state_id}__{self.id}'
         context.services.latents.save(name, latents)
-        return build_latents_output(latents_name=name, latents=latents)
+        return build_latents_output(latents_name=name, latents=torch.from_numpy(latents))

 # Latent to image
 class ONNXLatentsToImageInvocation(BaseInvocation):
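Note: the return path now wraps the NumPy latents with torch.from_numpy before building the output, which suggests build_latents_output expects a torch tensor while the ONNX loop produces NumPy arrays. torch.from_numpy shares the underlying buffer rather than copying:

import numpy as np
import torch

latents_np = np.zeros((1, 4, 64, 64), dtype=np.float16)  # illustrative shape
latents_t = torch.from_numpy(latents_np)  # zero-copy view over the NumPy buffer
assert latents_t.dtype == torch.float16
assert latents_t.shape == (1, 4, 64, 64)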

@@ -20,7 +20,7 @@ from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable
 import onnx
 from onnx import numpy_helper
 from onnx.external_data_helper import set_external_data
-from onnxruntime import InferenceSession, OrtValue, SessionOptions
+from onnxruntime import InferenceSession, OrtValue, SessionOptions, ExecutionMode, GraphOptimizationLevel

 class InvalidModelException(Exception):
     pass
@@ -552,6 +552,8 @@ class IAIOnnxRuntimeModel:
         sess = SessionOptions()
         #self._external_data.update(**external_data)
         # sess.add_external_initializers(list(self.data.keys()), list(self.data.values()))
+        sess.execution_mode = ExecutionMode.ORT_PARALLEL
+        sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
         self.session = InferenceSession(self.proto.SerializeToString(), providers=['CUDAExecutionProvider', 'CPUExecutionProvider'], sess_options=sess)
         #self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options)
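Note: the two added lines are the "session opts" from the commit message: they enable parallel execution of independent graph branches and turn on all graph-level optimizations. A standalone sketch of the same configuration ("model.onnx" is a placeholder path; the commit builds the session from an in-memory proto instead):

from onnxruntime import (
    ExecutionMode,
    GraphOptimizationLevel,
    InferenceSession,
    SessionOptions,
)

sess_opts = SessionOptions()
# Run independent subgraphs in parallel rather than sequentially.
sess_opts.execution_mode = ExecutionMode.ORT_PARALLEL
# Apply basic, extended, and layout graph optimizations.
sess_opts.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

session = InferenceSession(
    "model.onnx",  # placeholder; the real code passes serialized model bytes
    sess_options=sess_opts,
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)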