Testing different session opts, added timings for testing

Brandon Rising 2023-07-17 16:27:33 -04:00
parent 932112b640
commit bcce70fca6
2 changed files with 11 additions and 6 deletions

View File

@@ -195,7 +195,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
         latents = latents.cpu().numpy()
         # TODO: better execution device handling
-        latents = latents.astype(np.float32)
+        latents = latents.astype(np.float16)
         # get the initial random noise unless the user supplied it
         do_classifier_free_guidance = True
@@ -232,10 +232,11 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
         unet.create_session()
         timestep_dtype = next(
-            (input.type for input in unet.session.get_inputs() if input.name == "timestep"), "tensor(float)"
+            (input.type for input in unet.session.get_inputs() if input.name == "timestep"), "tensor(float16)"
         )
         timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
+        import time
+        times = []
         for i in tqdm(range(len(scheduler.timesteps))):
             t = scheduler.timesteps[i]
             # expand the latents if we are doing classifier free guidance
@@ -245,7 +246,9 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
             # predict the noise residual
             timestep = np.array([t], dtype=timestep_dtype)
+            start_time = time.time()
             noise_pred = unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)
+            times.append(time.time() - start_time)
             noise_pred = noise_pred[0]
             # perform guidance
@@ -262,14 +265,14 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
             # call the callback, if provided
             #if callback is not None and i % callback_steps == 0:
             #    callback(i, t, latents)
+        print(times)
         unet.release_session()
         torch.cuda.empty_cache()
         name = f'{context.graph_execution_state_id}__{self.id}'
         context.services.latents.save(name, latents)
-        return build_latents_output(latents_name=name, latents=latents)
+        return build_latents_output(latents_name=name, latents=torch.from_numpy(latents))
 # Latent to image
 class ONNXLatentsToImageInvocation(BaseInvocation):
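
For context, the timing changes above boil down to the following pattern. This is a minimal standalone sketch only; run_step and timesteps are illustrative stand-ins for the unet call and scheduler timesteps, not names from this diff:

import time

def time_denoising_loop(run_step, timesteps):
    # run_step is a hypothetical stand-in for the unet(...) call above;
    # it only needs to accept a timestep argument.
    times = []
    for t in timesteps:
        start_time = time.time()
        run_step(t)
        times.append(time.time() - start_time)
    # Same ad-hoc reporting the commit adds after the loop.
    print(times)
    return times

time.time() is coarse but adequate for step latencies in this range; time.perf_counter() is the higher-resolution alternative if finer measurements are needed.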

View File

@@ -20,7 +20,7 @@ from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callab
 import onnx
 from onnx import numpy_helper
 from onnx.external_data_helper import set_external_data
-from onnxruntime import InferenceSession, OrtValue, SessionOptions
+from onnxruntime import InferenceSession, OrtValue, SessionOptions, ExecutionMode, GraphOptimizationLevel
 class InvalidModelException(Exception):
     pass
@@ -552,6 +552,8 @@ class IAIOnnxRuntimeModel:
         sess = SessionOptions()
         #self._external_data.update(**external_data)
         # sess.add_external_initializers(list(self.data.keys()), list(self.data.values()))
+        sess.execution_mode = ExecutionMode.ORT_PARALLEL
+        sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
         self.session = InferenceSession(self.proto.SerializeToString(), providers=['CUDAExecutionProvider', 'CPUExecutionProvider'], sess_options=sess)
         #self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options)
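
For reference, the session options added above can be exercised in isolation roughly like this. It is a sketch only, and model_bytes (a serialized ONNX model) is an assumed input rather than anything from this diff:

from onnxruntime import (
    ExecutionMode,
    GraphOptimizationLevel,
    InferenceSession,
    SessionOptions,
)

def make_session(model_bytes: bytes) -> InferenceSession:
    # Mirror the options set in the diff: run independent graph branches
    # in parallel and apply all graph-level optimizations.
    sess = SessionOptions()
    sess.execution_mode = ExecutionMode.ORT_PARALLEL
    sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
    # Prefer CUDA when available, falling back to CPU.
    return InferenceSession(
        model_bytes,
        sess_options=sess,
        providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    )

ORT_PARALLEL lets ONNX Runtime execute independent branches of the graph concurrently, and ORT_ENABLE_ALL turns on all graph optimizations, including extended and layout fusions.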