Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Remove TensorRT support at the current time until we validate it works, remove time step recorder
parent: 918a0dedc0
commit: 59716938bf
@@ -264,8 +264,6 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
             (input.type for input in unet.session.get_inputs() if input.name == "timestep"), "tensor(float16)"
         )
         timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
-        import time
-        times = []
         for i in tqdm(range(len(scheduler.timesteps))):
             t = scheduler.timesteps[i]
             # expand the latents if we are doing classifier free guidance
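For context, the loop above feeds the UNet a timestep whose dtype is read from the ONNX graph itself and mapped to a NumPy dtype. A minimal sketch of that lookup, assuming a plain onnxruntime session and a hand-rolled type map standing in for InvokeAI's ORT_TO_NP_TYPE (all names here are illustrative, not the project's code):

import numpy as np
import onnxruntime as ort

# Illustrative subset of an ONNX-type -> NumPy-dtype map; InvokeAI uses its
# own ORT_TO_NP_TYPE table for the same purpose.
ONNX_TO_NP = {
    "tensor(float16)": np.float16,
    "tensor(float)": np.float32,
    "tensor(int64)": np.int64,
}

def resolve_timestep_dtype(session: ort.InferenceSession) -> np.dtype:
    """Return the NumPy dtype the graph declares for its 'timestep' input,
    falling back to float16 when no such input is found."""
    declared = next(
        (inp.type for inp in session.get_inputs() if inp.name == "timestep"),
        "tensor(float16)",
    )
    return np.dtype(ONNX_TO_NP[declared])

# Hypothetical usage inside a denoising loop:
# timestep = np.array([t], dtype=resolve_timestep_dtype(unet.session))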
@@ -275,9 +273,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
 
             # predict the noise residual
             timestep = np.array([t], dtype=timestep_dtype)
-            start_time = time.time()
             noise_pred = unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)
-            times.append(time.time() - start_time)
             noise_pred = noise_pred[0]
 
             # perform guidance
@@ -307,7 +303,6 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
             # call the callback, if provided
             #if callback is not None and i % callback_steps == 0:
             #    callback(i, t, latents)
-        print(times)
         unet.release_session()
 
         torch.cuda.empty_cache()
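The three hunks above remove the ad-hoc "time step recorder": the import time, the times list, the wall-clock measurement around the UNet call, and the final print. For reference, the pattern being removed amounts to timing each denoising step; a generic sketch of that idea is below (function and variable names are placeholders, not InvokeAI code):

import time

def run_with_step_timing(steps, step_fn):
    """Run step_fn once per step and record how long each call takes.
    Generic stand-in for the removed per-step timer."""
    durations = []
    for step in steps:
        start = time.time()
        step_fn(step)
        durations.append(time.time() - start)
    return durations

# Hypothetical usage:
# times = run_with_step_timing(scheduler.timesteps, lambda t: denoise_one_step(t))
# print(times)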
|
@@ -581,6 +581,8 @@ class IAIOnnxRuntimeModel:
             providers.append(self.provider)
         else:
             providers = get_available_providers()
+        if "TensorrtExecutionProvider" in providers:
+            providers.remove("TensorrtExecutionProvider")
         try:
             self.session = InferenceSession(self.proto.SerializeToString(), providers=providers, sess_options=sess)
         except Exception as e:
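The hunk above follows a common onnxruntime pattern: start from get_available_providers() and drop TensorrtExecutionProvider before building the session, so the (currently unvalidated) TensorRT backend is never selected. A standalone sketch of that filtering, assuming a model file path rather than InvokeAI's serialized proto (the path and function name are illustrative):

import onnxruntime as ort

def make_session_without_tensorrt(model_path: str) -> ort.InferenceSession:
    """Build an InferenceSession from the available providers,
    excluding TensorRT until it has been validated."""
    providers = ort.get_available_providers()
    if "TensorrtExecutionProvider" in providers:
        providers.remove("TensorrtExecutionProvider")
    return ort.InferenceSession(model_path, providers=providers)

# Hypothetical usage:
# session = make_session_without_tensorrt("unet/model.onnx")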