From fbbc4b3f69141f09e1617cc94821736cb6759054 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Tue, 18 Jul 2023 16:51:16 +0300 Subject: [PATCH] Fixes --- invokeai/app/invocations/generate.py | 13 +++++++++---- invokeai/app/invocations/latent.py | 5 +++++ invokeai/app/invocations/sdxl.py | 4 ++-- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/invokeai/app/invocations/generate.py b/invokeai/app/invocations/generate.py index 6cdb83effc..c5967e4074 100644 --- a/invokeai/app/invocations/generate.py +++ b/invokeai/app/invocations/generate.py @@ -146,9 +146,13 @@ class InpaintInvocation(BaseInvocation): source_node_id=source_node_id, ) - def get_conditioning(self, context): - c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name) - uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name) + def get_conditioning(self, context, unet): + positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name) + c = positive_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype) + extra_conditioning_info = positive_cond_data.conditionings[0].extra_conditioning + + negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name) + uc = negative_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype) return (uc, c, extra_conditioning_info) @@ -209,7 +213,6 @@ class InpaintInvocation(BaseInvocation): ) source_node_id = graph_execution_state.prepared_source_mapping[self.id] - conditioning = self.get_conditioning(context) scheduler = get_scheduler( context=context, scheduler_info=self.unet.scheduler, @@ -217,6 +220,8 @@ class InpaintInvocation(BaseInvocation): ) with self.load_model_old_way(context, scheduler) as model: + conditioning = self.get_conditioning(context, model.context.model.unet) + outputs = Inpaint(model).generate( conditioning=conditioning, scheduler=scheduler, diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 1a34e58d28..8af5a943a8 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -334,6 +334,8 @@ class TextToLatentsInvocation(BaseInvocation): ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ unet_info as unet: + noise = noise.to(device=unet.device, dtype=unet.dtype) + scheduler = get_scheduler( context=context, scheduler_info=self.unet.scheduler, @@ -425,6 +427,9 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ unet_info as unet: + noise = noise.to(device=unet.device, dtype=unet.dtype) + latent = latent.to(device=unet.device, dtype=unet.dtype) + scheduler = get_scheduler( context=context, scheduler_info=self.unet.scheduler, diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index f73418e18e..2cc12f779e 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -305,7 +305,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation): add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype) latents = latents.to(device=unet.device, dtype=unet.dtype) - with tqdm(total=self.steps) as progress_bar: + with tqdm(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents @@ -351,7 +351,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation): add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype) latents = latents.to(device=unet.device, dtype=unet.dtype) - with tqdm(total=self.steps) as progress_bar: + with tqdm(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance #latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents