From fe78a08e3719b3e45866ffac743e58aa14f57946 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Sun, 16 Jul 2023 06:24:24 +0300 Subject: [PATCH] Fix sd1/2 models conditionings --- invokeai/app/invocations/compel.py | 12 +++++++++--- invokeai/app/invocations/latent.py | 12 ++++++------ 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index b8f2ec4250..1c8faee50b 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -147,11 +147,17 @@ class CompelInvocation(BaseInvocation): cross_attention_control_args=options.get( "cross_attention_control", None),) - raise NotImplementedError("TODO: redo to new conditionings") + conditioning_data = ConditioningFieldData( + conditionings=[ + BasicConditioningInfo( + embeds=c, + extra_conditioning=ec, + ) + ] + ) conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" - # TODO: hacky but works ;D maybe rename latents somehow? - context.services.latents.save(conditioning_name, (c, ec)) + context.services.latents.save(conditioning_name, conditioning_data) return CompelOutput( conditioning=ConditioningField( diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index c8f9897aa6..19207d78d7 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -168,12 +168,12 @@ class TextToLatentsInvocation(BaseInvocation): context: InvocationContext, scheduler, ) -> ConditioningData: - c, extra_conditioning_info = context.services.latents.get( - self.positive_conditioning.conditioning_name - ) - uc, _ = context.services.latents.get( - self.negative_conditioning.conditioning_name - ) + positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name) + c = positive_cond_data.conditionings[0].embeds + extra_conditioning_info = positive_cond_data.conditionings[0].extra_conditioning + + negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name) + uc = negative_cond_data.conditionings[0].embeds conditioning_data = ConditioningData( unconditioned_embeddings=uc,