Combine conditioning to one field(better fits for multiple type conditioning like perp-neg)

This commit is contained in:
StAlKeR7779 2023-05-04 20:14:22 +03:00
parent 56d3cbead0
commit 7d221e2518
2 changed files with 7 additions and 16 deletions

View File

@ -34,8 +34,7 @@ class CompelOutput(BaseInvocationOutput):
# model: ModelField = Field(default=None, description="Model")
# src? + loras -> tokenizer + text_encoder + loras
# clip: ClipField = Field(default=None, description="Text encoder(clip)")
positive: ConditioningField = Field(default=None, description="Positive conditioning")
negative: ConditioningField = Field(default=None, description="Negative conditioning")
conditioning: ConditioningField = Field(default=None, description="Conditioning")
#fmt: on
@ -134,20 +133,14 @@ class CompelInvocation(BaseInvocation):
cross_attention_control_args=options.get("cross_attention_control", None),
)
name_prefix = f'{context.graph_execution_state_id}__{self.id}'
name_positive = f"{name_prefix}_positive"
name_negative = f"{name_prefix}_negative"
name_cond = f"{context.graph_execution_state_id}_{self.id}_conditioning"
# TODO: hacky but works ;D maybe rename latents somehow?
context.services.latents.set(name_positive, (c, ec))
context.services.latents.set(name_negative, (uc, None))
context.services.latents.set(name_cond, (c, uc, ec))
return CompelOutput(
positive=ConditioningField(
conditioning_name=name_positive,
),
negative=ConditioningField(
conditioning_name=name_negative,
conditioning=ConditioningField(
conditioning_name=name_cond,
),
)

View File

@ -144,8 +144,7 @@ class TextToLatentsInvocation(BaseInvocation):
# Inputs
# fmt: off
positive: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
negative: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
conditioning: Optional[ConditioningField] = Field(description="Conditioning for generation")
seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
noise: Optional[LatentsField] = Field(description="The noise to use")
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
@ -205,8 +204,7 @@ class TextToLatentsInvocation(BaseInvocation):
def get_conditioning_data(self, context: InvocationContext, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
c, extra_conditioning_info = context.services.latents.get(self.positive.conditioning_name)
uc, _ = context.services.latents.get(self.negative.conditioning_name)
c, uc, extra_conditioning_info = context.services.latents.get(self.conditioning.conditioning_name)
conditioning_data = ConditioningData(
uc,