Remove TensorRT support at the current time until we validate it works, remove time step recorder

This commit is contained in:
Brandon Rising 2023-07-27 11:18:50 -04:00
parent 918a0dedc0
commit 59716938bf
2 changed files with 2 additions and 5 deletions

View File

@@ -264,8 +264,6 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
             (input.type for input in unet.session.get_inputs() if input.name == "timestep"), "tensor(float16)"
         )
         timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
-        import time
-        times = []
         for i in tqdm(range(len(scheduler.timesteps))):
             t = scheduler.timesteps[i]
             # expand the latents if we are doing classifier free guidance
@@ -275,9 +273,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
             # predict the noise residual
             timestep = np.array([t], dtype=timestep_dtype)
-            start_time = time.time()
             noise_pred = unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)
-            times.append(time.time() - start_time)
             noise_pred = noise_pred[0]

             # perform guidance
@@ -307,7 +303,6 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
             # call the callback, if provided
             #if callback is not None and i % callback_steps == 0:
             #    callback(i, t, latents)
-        print(times)
         unet.release_session()
         torch.cuda.empty_cache()

View File

@@ -581,6 +581,8 @@ class IAIOnnxRuntimeModel:
             providers.append(self.provider)
         else:
             providers = get_available_providers()
+            if "TensorrtExecutionProvider" in providers:
+                providers.remove("TensorrtExecutionProvider")
         try:
             self.session = InferenceSession(self.proto.SerializeToString(), providers=providers, sess_options=sess)
         except Exception as e: