Emit step completions

This commit is contained in:
Brandon Rising 2023-07-18 12:35:07 -04:00
parent bcce70fca6
commit 35d5ef9118
4 changed files with 43 additions and 5 deletions

View File

@ -15,6 +15,7 @@ from diffusers.schedulers import SchedulerMixin as Scheduler
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from ...backend.model_management import ONNXModelPatcher
from ...backend.util import choose_torch_device
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext)
from .compel import ConditioningField
@ -23,6 +24,8 @@ from .image import ImageOutput
from .model import ModelInfo, UNetField, VaeField
from invokeai.backend import BaseModelType, ModelType, SubModelType
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from ...backend.stable_diffusion import PipelineIntermediateState
from tqdm import tqdm
from .model import ClipField
@ -183,11 +186,14 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput:
c, _ = context.services.latents.get(self.positive_conditioning.conditioning_name)
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
if isinstance(c, torch.Tensor):
c = c.cpu().numpy()
if isinstance(uc, torch.Tensor):
uc = uc.cpu().numpy()
device = torch.device(choose_torch_device())
prompt_embeds = np.concatenate([uc, c])
latents = context.services.latents.get(self.noise.latents_name)
@ -210,6 +216,22 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
scheduler_name=self.scheduler,
)
def torch2numpy(latent: torch.Tensor):
return latent.cpu().numpy()
def numpy2torch(latent, device):
return torch.from_numpy(latent).to(device)
def dispatch_progress(
self, context: InvocationContext, source_node_id: str,
intermediate_state: PipelineIntermediateState) -> None:
stable_diffusion_step_callback(
context=context,
intermediate_state=intermediate_state,
node=self.dict(),
source_node_id=source_node_id,
)
scheduler.set_timesteps(self.steps)
latents = latents * np.float64(scheduler.init_noise_sigma)
@ -241,7 +263,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
t = scheduler.timesteps[i]
# expand the latents if we are doing classifier free guidance
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
latent_model_input = scheduler.scale_model_input(numpy2torch(latent_model_input, device), t)
latent_model_input = latent_model_input.cpu().numpy()
# predict the noise residual
@ -258,9 +280,22 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
# compute the previous noisy sample x_t -> x_t-1
scheduler_output = scheduler.step(
torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
numpy2torch(noise_pred, device), t, numpy2torch(latents, device), **extra_step_kwargs
)
latents = torch2numpy(scheduler_output.prev_sample)
state = PipelineIntermediateState(
run_id= "test",
step=i,
timestep=timestep,
latents=scheduler_output.prev_sample
)
dispatch_progress(
self,
context=context,
source_node_id=source_node_id,
intermediate_state=state
)
latents = scheduler_output.prev_sample.numpy()
# call the callback, if provided
#if callback is not None and i % callback_steps == 0:

View File

@ -466,7 +466,6 @@ class Generator:
dtype=samples.dtype,
device=samples.device,
)
latent_image = samples[0].permute(1, 2, 0) @ v1_5_latent_rgb_factors
latents_ubyte = (
((latent_image + 1) / 2)

View File

@ -554,6 +554,8 @@ class IAIOnnxRuntimeModel:
# sess.add_external_initializers(list(self.data.keys()), list(self.data.values()))
sess.execution_mode = ExecutionMode.ORT_PARALLEL
sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
# sess.add_free_dimension_override_by_name("unet_sample_height", 64)
# sess.add_free_dimension_override_by_name("unet_sample_width", 64)
self.session = InferenceSession(self.proto.SerializeToString(), providers=['CUDAExecutionProvider', 'CPUExecutionProvider'], sess_options=sess)
#self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options)

View File

@ -63,6 +63,8 @@ dependencies = [
"npyscreen",
"numpy<1.24",
"omegaconf",
"onnx",
"onnxruntime-gpu",
"opencv-python",
"picklescan",
"pillow",