Review changes

StAlKeR7779 2023-05-05 21:09:29 +03:00
parent 5012f61599
commit 58d7833c5c
4 changed files with 38 additions and 76 deletions

View File

@@ -30,35 +30,24 @@ class CompelOutput(BaseInvocationOutput):
#fmt: off
type: Literal["compel_output"] = "compel_output"
# name + loras -> pipeline + loras
# model: ModelField = Field(default=None, description="Model")
# src? + loras -> tokenizer + text_encoder + loras
# clip: ClipField = Field(default=None, description="Text encoder(clip)")
positive: ConditioningField = Field(default=None, description="Positive conditioning")
negative: ConditioningField = Field(default=None, description="Negative conditioning")
conditioning: ConditioningField = Field(default=None, description="Conditioning")
#fmt: on
class CompelInvocation(BaseInvocation):
"""Parse prompt using compel package to conditioning."""
type: Literal["compel"] = "compel"
positive_prompt: str = Field(default="", description="Positive prompt")
negative_prompt: str = Field(default="", description="Negative prompt")
prompt: str = Field(default="", description="Prompt")
model: str = Field(default="", description="Model to use")
truncate_long_prompts: bool = Field(default=False, description="Whether or not to truncate long prompt to 77 tokens")
# name + loras -> pipeline + loras
# model: ModelField = Field(default=None, description="Model to use")
# src? + loras -> tokenizer + text_encoder + loras
# clip: ClipField = Field(default=None, description="Text encoder(clip) to use")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents", "noise"],
"tags": ["prompt", "compel"],
"type_hints": {
"model": "model"
}
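In effect, the reworked compel node takes a single prompt (plus the model name) and emits a single conditioning field. A minimal stand-in sketch of that shape, using a plain pydantic model rather than the project's BaseInvocation subclass; it also shows that a negative prompt now means a second compel node:

```python
from typing import Literal
from pydantic import BaseModel, Field

# Minimal stand-in for the reworked node (field names taken from the hunk above);
# this is not the project's actual class, which derives from BaseInvocation.
class CompelInvocation(BaseModel):
    type: Literal["compel"] = "compel"
    id: str
    prompt: str = Field(default="", description="Prompt")
    model: str = Field(default="", description="Model to use")

# With one prompt per node, a negative prompt is simply a second compel node.
positive = CompelInvocation(id="4", prompt="a photo of an astronaut")
negative = CompelInvocation(id="5", prompt="blurry, low quality")
print(positive.prompt, "|", negative.prompt)
```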
@@ -88,14 +77,8 @@ class CompelInvocation(BaseInvocation):
pipeline.textual_inversion_manager.load_huggingface_concepts(concepts)
# apply the concepts library to the prompt
positive_prompt_str = pipeline.textual_inversion_manager.hf_concepts_library.replace_concepts_with_triggers(
self.positive_prompt,
lambda concepts: load_huggingface_concepts(concepts),
pipeline.textual_inversion_manager.get_all_trigger_strings(),
)
negative_prompt_str = pipeline.textual_inversion_manager.hf_concepts_library.replace_concepts_with_triggers(
self.negative_prompt,
prompt_str = pipeline.textual_inversion_manager.hf_concepts_library.replace_concepts_with_triggers(
self.prompt,
lambda concepts: load_huggingface_concepts(concepts),
pipeline.textual_inversion_manager.get_all_trigger_strings(),
)
@@ -103,7 +86,7 @@ class CompelInvocation(BaseInvocation):
# lazy-load any deferred textual inversions.
# this might take a couple of seconds the first time a textual inversion is used.
pipeline.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(
positive_prompt_str + "[" + negative_prompt_str + "]"
prompt_str
)
compel = Compel(
@@ -111,43 +94,35 @@ class CompelInvocation(BaseInvocation):
text_encoder=text_encoder,
textual_inversion_manager=pipeline.textual_inversion_manager,
dtype_for_device_getter=torch_dtype,
truncate_long_prompts=self.truncate_long_prompts,
truncate_long_prompts=True, # TODO:
)
# TODO: support legacy blend?
positive_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(positive_prompt_str)
negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(negative_prompt_str)
prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(prompt_str)
if getattr(Globals, "log_tokenization", False):
log_tokenization(positive_prompt, negative_prompt, tokenizer=tokenizer)
log_tokenization_for_prompt_object(prompt, tokenizer)
# TODO: add lora(with model and clip field types)
c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
if not self.truncate_long_prompts:
[c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
# TODO: long prompt support
#if not self.truncate_long_prompts:
# [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
tokens_count_including_eos_bos=get_max_token_count(tokenizer, positive_prompt),
tokens_count_including_eos_bos=get_max_token_count(tokenizer, prompt),
cross_attention_control_args=options.get("cross_attention_control", None),
)
name_prefix = f'{context.graph_execution_state_id}__{self.id}'
name_positive = f"{name_prefix}_positive"
name_negative = f"{name_prefix}_negative"
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
# TODO: hacky but works ;D maybe rename latents somehow?
context.services.latents.set(name_positive, (c, ec))
context.services.latents.set(name_negative, (uc, None))
context.services.latents.set(conditioning_name, (c, ec))
return CompelOutput(
positive=ConditioningField(
conditioning_name=name_positive,
),
negative=ConditioningField(
conditioning_name=name_negative,
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),
)
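The invoke body now parses one prompt, builds one conditioning tensor, and stores it under a deterministic name before returning it as the node's only output. A self-contained sketch of that naming and hand-off, with a plain dict standing in for context.services.latents (an assumption for illustration only):

```python
from dataclasses import dataclass
from typing import Any, Dict, Tuple

# A plain dict stands in for context.services.latents; the real service and the
# actual conditioning tensors are assumed, not shown.
latents_store: Dict[str, Tuple[Any, Any]] = {}

@dataclass
class ConditioningField:
    conditioning_name: str

def store_conditioning(graph_execution_state_id: str, node_id: str, c: Any, ec: Any) -> ConditioningField:
    # Naming scheme from the hunk above: "<execution id>_<node id>_conditioning"
    name = f"{graph_execution_state_id}_{node_id}_conditioning"
    latents_store[name] = (c, ec)
    return ConditioningField(conditioning_name=name)

# Each compel node stores exactly one tensor; a consumer (e.g. t2l) looks it up by name.
field = store_conditioning("exec-1", "4", c="<conditioning tensor>", ec=None)
c, extra_conditioning_info = latents_store[field.conditioning_name]
print(field.conditioning_name)
```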
@@ -195,20 +170,6 @@ def get_tokens_for_prompt_object(
return tokens
def log_tokenization(
positive_prompt: Union[Blend, FlattenedPrompt],
negative_prompt: Union[Blend, FlattenedPrompt],
tokenizer,
):
print(f"\n>> [TOKENLOG] Parsed Prompt: {positive_prompt}")
print(f"\n>> [TOKENLOG] Parsed Negative Prompt: {negative_prompt}")
log_tokenization_for_prompt_object(positive_prompt, tokenizer)
log_tokenization_for_prompt_object(
negative_prompt, tokenizer, display_label_prefix="(negative prompt)"
)
def log_tokenization_for_prompt_object(
p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
):
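With the paired log_tokenization wrapper gone, each compel node logs its own prompt via log_tokenization_for_prompt_object. A sketch of the resulting call pattern; the function body below is a simplified stub, not the real implementation, which prints the actual token breakdown:

```python
# Stub for illustration only: one call per prompt, with an optional label
# for the negative prompt.
def log_tokenization_for_prompt_object(p, tokenizer, display_label_prefix=None):
    label = f"{display_label_prefix} " if display_label_prefix else ""
    print(f"\n>> [TOKENLOG] {label}Parsed Prompt: {p}")

log_tokenization_for_prompt_object("a photo of an astronaut", tokenizer=None)
log_tokenization_for_prompt_object(
    "blurry, low quality", tokenizer=None, display_label_prefix="(negative prompt)"
)
```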

View File

@@ -138,14 +138,14 @@ class NoiseInvocation(BaseInvocation):
# Text to image
class TextToLatentsInvocation(BaseInvocation):
"""Generates latents from a prompt."""
"""Generates latents from conditionings."""
type: Literal["t2l"] = "t2l"
# Inputs
# fmt: off
positive: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
negative: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
noise: Optional[LatentsField] = Field(description="The noise to use")
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
@@ -204,8 +204,8 @@ class TextToLatentsInvocation(BaseInvocation):
def get_conditioning_data(self, context: InvocationContext, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
c, extra_conditioning_info = context.services.latents.get(self.positive.conditioning_name)
uc, _ = context.services.latents.get(self.negative.conditioning_name)
c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
conditioning_data = ConditioningData(
uc,

View File

@@ -20,28 +20,29 @@ def create_text_to_image() -> LibraryGraph:
'seed': ParamIntInvocation(id='seed', a=-1),
'3': NoiseInvocation(id='3'),
'4': CompelInvocation(id='4'),
'5': TextToLatentsInvocation(id='5'),
'6': LatentsToImageInvocation(id='6'),
'5': CompelInvocation(id='5'),
'6': TextToLatentsInvocation(id='6'),
'7': LatentsToImageInvocation(id='7'),
},
edges=[
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')),
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')),
Edge(source=EdgeConnection(node_id='seed', field='a'), destination=EdgeConnection(node_id='3', field='seed')),
Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='5', field='noise')),
Edge(source=EdgeConnection(node_id='5', field='latents'), destination=EdgeConnection(node_id='6', field='latents')),
Edge(source=EdgeConnection(node_id='4', field='positive'), destination=EdgeConnection(node_id='5', field='positive')),
Edge(source=EdgeConnection(node_id='4', field='negative'), destination=EdgeConnection(node_id='5', field='negative')),
Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='6', field='noise')),
Edge(source=EdgeConnection(node_id='6', field='latents'), destination=EdgeConnection(node_id='7', field='latents')),
Edge(source=EdgeConnection(node_id='4', field='conditioning'), destination=EdgeConnection(node_id='6', field='positive_conditioning')),
Edge(source=EdgeConnection(node_id='5', field='conditioning'), destination=EdgeConnection(node_id='6', field='negative_conditioning')),
]
),
exposed_inputs=[
ExposedNodeInput(node_path='4', field='positive_prompt', alias='positive_prompt'),
ExposedNodeInput(node_path='4', field='negative_prompt', alias='negative_prompt'),
ExposedNodeInput(node_path='4', field='prompt', alias='positive_prompt'),
ExposedNodeInput(node_path='5', field='prompt', alias='negative_prompt'),
ExposedNodeInput(node_path='width', field='a', alias='width'),
ExposedNodeInput(node_path='height', field='a', alias='height'),
ExposedNodeInput(node_path='seed', field='a', alias='seed'),
],
exposed_outputs=[
ExposedNodeOutput(node_path='6', field='image', alias='image')
ExposedNodeOutput(node_path='7', field='image', alias='image')
])
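The graph still exposes positive_prompt and negative_prompt aliases, but they now route to the prompt field of two different compel nodes. A small illustration of that routing, with plain dicts standing in for ExposedNodeInput:

```python
# Stand-in mapping (plain dicts, not the project's ExposedNodeInput model) showing
# how the external aliases survive while pointing at two separate compel nodes.
exposed_inputs = {
    "positive_prompt": ("4", "prompt"),
    "negative_prompt": ("5", "prompt"),
    "width": ("width", "a"),
    "height": ("height", "a"),
    "seed": ("seed", "a"),
}

def resolve(alias: str) -> str:
    node_path, field = exposed_inputs[alias]
    return f"node {node_path!r}, field {field!r}"

print(resolve("positive_prompt"))  # node '4', field 'prompt'
print(resolve("negative_prompt"))  # node '5', field 'prompt'
```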

View File

@@ -463,16 +463,16 @@ def test_graph_subgraph_t2i():
n4 = ShowImageInvocation(id = "4")
g.add_node(n4)
g.add_edge(create_edge("1.6","image","4","image"))
g.add_edge(create_edge("1.7","image","4","image"))
# Validate
dg = g.nx_graph_flat()
assert set(dg.nodes) == set(['1.width', '1.height', '1.seed', '1.3', '1.4', '1.5', '1.6', '2', '3', '4'])
assert set(dg.nodes) == set(['1.width', '1.height', '1.seed', '1.3', '1.4', '1.5', '1.6', '1.7', '2', '3', '4'])
expected_edges = [(f'1.{e.source.node_id}',f'1.{e.destination.node_id}') for e in lg.graph.edges]
expected_edges.extend([
('2','1.width'),
('3','1.height'),
('1.6','4')
('1.7','4')
])
print(expected_edges)
print(list(dg.edges))
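The updated assertions follow the flattening convention for embedded graphs: nodes of the nested text-to-image graph (node id "1") appear as "1.<node id>", so the image output moves from 1.6 to 1.7. A quick sketch of how the expected node set is formed:

```python
# The flattened view prefixes nodes of the embedded text-to-image graph (id "1")
# with "<graph id>.<node id>"; top-level nodes keep their own ids.
subgraph_id = "1"
subgraph_nodes = ["width", "height", "seed", "3", "4", "5", "6", "7"]
top_level_nodes = ["2", "3", "4"]

expected_nodes = {f"{subgraph_id}.{n}" for n in subgraph_nodes} | set(top_level_nodes)
print(sorted(expected_nodes))
```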