Review changes

StAlKeR7779 2023-05-05 21:09:29 +03:00
parent 5012f61599
commit 58d7833c5c
4 changed files with 38 additions and 76 deletions


@@ -30,35 +30,24 @@ class CompelOutput(BaseInvocationOutput):
     #fmt: off
     type: Literal["compel_output"] = "compel_output"
-    # name + loras -> pipeline + loras
-    # model: ModelField = Field(default=None, description="Model")
-    # src? + loras -> tokenizer + text_encoder + loras
-    # clip: ClipField = Field(default=None, description="Text encoder(clip)")
-    positive: ConditioningField = Field(default=None, description="Positive conditioning")
-    negative: ConditioningField = Field(default=None, description="Negative conditioning")
+    conditioning: ConditioningField = Field(default=None, description="Conditioning")
     #fmt: on

 class CompelInvocation(BaseInvocation):
     """Parse prompt using compel package to conditioning."""

     type: Literal["compel"] = "compel"

-    positive_prompt: str = Field(default="", description="Positive prompt")
-    negative_prompt: str = Field(default="", description="Negative prompt")
+    prompt: str = Field(default="", description="Prompt")
     model: str = Field(default="", description="Model to use")
-    truncate_long_prompts: bool = Field(default=False, description="Whether or not to truncate long prompt to 77 tokens")
-    # name + loras -> pipeline + loras
-    # model: ModelField = Field(default=None, description="Model to use")
-    # src? + loras -> tokenizer + text_encoder + loras
-    # clip: ClipField = Field(default=None, description="Text encoder(clip) to use")

     # Schema customisation
     class Config(InvocationConfig):
         schema_extra = {
             "ui": {
-                "tags": ["latents", "noise"],
+                "tags": ["prompt", "compel"],
                 "type_hints": {
                     "model": "model"
                 }

@@ -88,14 +77,8 @@ class CompelInvocation(BaseInvocation):
         pipeline.textual_inversion_manager.load_huggingface_concepts(concepts)

         # apply the concepts library to the prompt
-        positive_prompt_str = pipeline.textual_inversion_manager.hf_concepts_library.replace_concepts_with_triggers(
-            self.positive_prompt,
-            lambda concepts: load_huggingface_concepts(concepts),
-            pipeline.textual_inversion_manager.get_all_trigger_strings(),
-        )
-        negative_prompt_str = pipeline.textual_inversion_manager.hf_concepts_library.replace_concepts_with_triggers(
-            self.negative_prompt,
+        prompt_str = pipeline.textual_inversion_manager.hf_concepts_library.replace_concepts_with_triggers(
+            self.prompt,
             lambda concepts: load_huggingface_concepts(concepts),
             pipeline.textual_inversion_manager.get_all_trigger_strings(),
         )

@@ -103,7 +86,7 @@ class CompelInvocation(BaseInvocation):
         # lazy-load any deferred textual inversions.
         # this might take a couple of seconds the first time a textual inversion is used.
         pipeline.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(
-            positive_prompt_str + "[" + negative_prompt_str + "]"
+            prompt_str
         )

         compel = Compel(

@@ -111,43 +94,35 @@ class CompelInvocation(BaseInvocation):
             text_encoder=text_encoder,
             textual_inversion_manager=pipeline.textual_inversion_manager,
             dtype_for_device_getter=torch_dtype,
-            truncate_long_prompts=self.truncate_long_prompts,
+            truncate_long_prompts=True,  # TODO:
         )

         # TODO: support legacy blend?
-        positive_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(positive_prompt_str)
-        negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(negative_prompt_str)
+        prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(prompt_str)

         if getattr(Globals, "log_tokenization", False):
-            log_tokenization(positive_prompt, negative_prompt, tokenizer=tokenizer)
+            log_tokenization_for_prompt_object(prompt, tokenizer)

-        # TODO: add lora(with model and clip field types)
-        c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
-        uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
-        if not self.truncate_long_prompts:
-            [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
+        c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
+        # TODO: long prompt support
+        #if not self.truncate_long_prompts:
+        #    [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])

         ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
-            tokens_count_including_eos_bos=get_max_token_count(tokenizer, positive_prompt),
+            tokens_count_including_eos_bos=get_max_token_count(tokenizer, prompt),
             cross_attention_control_args=options.get("cross_attention_control", None),
         )

-        name_prefix = f'{context.graph_execution_state_id}__{self.id}'
-        name_positive = f"{name_prefix}_positive"
-        name_negative = f"{name_prefix}_negative"
+        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"

         # TODO: hacky but works ;D maybe rename latents somehow?
-        context.services.latents.set(name_positive, (c, ec))
-        context.services.latents.set(name_negative, (uc, None))
+        context.services.latents.set(conditioning_name, (c, ec))

         return CompelOutput(
-            positive=ConditioningField(
-                conditioning_name=name_positive,
-            ),
-            negative=ConditioningField(
-                conditioning_name=name_negative,
+            conditioning=ConditioningField(
+                conditioning_name=conditioning_name,
             ),
         )

@@ -195,20 +170,6 @@ def get_tokens_for_prompt_object(
     return tokens

-def log_tokenization(
-    positive_prompt: Union[Blend, FlattenedPrompt],
-    negative_prompt: Union[Blend, FlattenedPrompt],
-    tokenizer,
-):
-    print(f"\n>> [TOKENLOG] Parsed Prompt: {positive_prompt}")
-    print(f"\n>> [TOKENLOG] Parsed Negative Prompt: {negative_prompt}")
-    log_tokenization_for_prompt_object(positive_prompt, tokenizer)
-    log_tokenization_for_prompt_object(
-        negative_prompt, tokenizer, display_label_prefix="(negative prompt)"
-    )

 def log_tokenization_for_prompt_object(
     p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
 ):
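
For orientation, the single-prompt path above boils down to this compel usage: parse one prompt string, build one conditioning tensor for it, and skip cross-prompt padding while truncation is pinned to True. Below is a minimal standalone sketch mirroring the calls in the hunks above; the checkpoint id and the use of compel outside the InvokeAI pipeline are assumptions for illustration only.

from compel import Compel
from transformers import CLIPTextModel, CLIPTokenizer

# Assumed example checkpoint; any SD 1.x tokenizer / text-encoder pair is used the same way.
tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")

compel = Compel(
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    truncate_long_prompts=True,  # pinned to True in this commit (see the TODO above)
)

# Each CompelInvocation now handles exactly one prompt string.
prompt = Compel.parse_prompt_string("a photo of a castle at dusk, highly detailed")
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)

# With truncation disabled, positive and negative tensors built by separate nodes could
# differ in length; pad_conditioning_tensors_to_same_length([c, uc]) is the call the
# commit comments out, since a single node no longer sees both prompts.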


@@ -138,14 +138,14 @@ class NoiseInvocation(BaseInvocation):

 # Text to image
 class TextToLatentsInvocation(BaseInvocation):
-    """Generates latents from a prompt."""
+    """Generates latents from conditionings."""

     type: Literal["t2l"] = "t2l"

     # Inputs
     # fmt: off
-    positive: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
-    negative: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
+    positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
+    negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
     noise: Optional[LatentsField] = Field(description="The noise to use")
     steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
     cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )

@@ -204,8 +204,8 @@ class TextToLatentsInvocation(BaseInvocation):
     def get_conditioning_data(self, context: InvocationContext, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
-        c, extra_conditioning_info = context.services.latents.get(self.positive.conditioning_name)
-        uc, _ = context.services.latents.get(self.negative.conditioning_name)
+        c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
+        uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
         conditioning_data = ConditioningData(
             uc,


@@ -20,28 +20,29 @@ def create_text_to_image() -> LibraryGraph:
             'seed': ParamIntInvocation(id='seed', a=-1),
             '3': NoiseInvocation(id='3'),
             '4': CompelInvocation(id='4'),
-            '5': TextToLatentsInvocation(id='5'),
-            '6': LatentsToImageInvocation(id='6'),
+            '5': CompelInvocation(id='5'),
+            '6': TextToLatentsInvocation(id='6'),
+            '7': LatentsToImageInvocation(id='7'),
         },
         edges=[
             Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')),
             Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')),
             Edge(source=EdgeConnection(node_id='seed', field='a'), destination=EdgeConnection(node_id='3', field='seed')),
-            Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='5', field='noise')),
-            Edge(source=EdgeConnection(node_id='5', field='latents'), destination=EdgeConnection(node_id='6', field='latents')),
-            Edge(source=EdgeConnection(node_id='4', field='positive'), destination=EdgeConnection(node_id='5', field='positive')),
-            Edge(source=EdgeConnection(node_id='4', field='negative'), destination=EdgeConnection(node_id='5', field='negative')),
+            Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='6', field='noise')),
+            Edge(source=EdgeConnection(node_id='6', field='latents'), destination=EdgeConnection(node_id='7', field='latents')),
+            Edge(source=EdgeConnection(node_id='4', field='conditioning'), destination=EdgeConnection(node_id='6', field='positive_conditioning')),
+            Edge(source=EdgeConnection(node_id='5', field='conditioning'), destination=EdgeConnection(node_id='6', field='negative_conditioning')),
         ]
     ),
     exposed_inputs=[
-        ExposedNodeInput(node_path='4', field='positive_prompt', alias='positive_prompt'),
-        ExposedNodeInput(node_path='4', field='negative_prompt', alias='negative_prompt'),
+        ExposedNodeInput(node_path='4', field='prompt', alias='positive_prompt'),
+        ExposedNodeInput(node_path='5', field='prompt', alias='negative_prompt'),
         ExposedNodeInput(node_path='width', field='a', alias='width'),
         ExposedNodeInput(node_path='height', field='a', alias='height'),
         ExposedNodeInput(node_path='seed', field='a', alias='seed'),
     ],
     exposed_outputs=[
-        ExposedNodeOutput(node_path='6', field='image', alias='image')
+        ExposedNodeOutput(node_path='7', field='image', alias='image')
     ])
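
To see the reworked wiring in isolation: each prompt now gets its own CompelInvocation, and both conditioning outputs feed one TextToLatentsInvocation. Here is a hedged sketch of just that sub-graph; the import paths are assumptions, while the node classes and field names come from the diffs above.

# Sketch only: the module paths below are assumed, not shown on this page.
from invokeai.app.invocations.compel import CompelInvocation
from invokeai.app.invocations.latents import TextToLatentsInvocation
from invokeai.app.services.graph import Edge, EdgeConnection, Graph

# One CompelInvocation per prompt; both conditionings feed the single t2l node.
g = Graph(
    nodes={
        'pos': CompelInvocation(id='pos', prompt='a castle at dusk'),
        'neg': CompelInvocation(id='neg', prompt='blurry, low quality'),
        't2l': TextToLatentsInvocation(id='t2l', steps=20),
    },
    edges=[
        Edge(source=EdgeConnection(node_id='pos', field='conditioning'),
             destination=EdgeConnection(node_id='t2l', field='positive_conditioning')),
        Edge(source=EdgeConnection(node_id='neg', field='conditioning'),
             destination=EdgeConnection(node_id='t2l', field='negative_conditioning')),
    ],
)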


@@ -463,16 +463,16 @@ def test_graph_subgraph_t2i():
     n4 = ShowImageInvocation(id = "4")
     g.add_node(n4)
-    g.add_edge(create_edge("1.6","image","4","image"))
+    g.add_edge(create_edge("1.7","image","4","image"))

     # Validate
     dg = g.nx_graph_flat()
-    assert set(dg.nodes) == set(['1.width', '1.height', '1.seed', '1.3', '1.4', '1.5', '1.6', '2', '3', '4'])
+    assert set(dg.nodes) == set(['1.width', '1.height', '1.seed', '1.3', '1.4', '1.5', '1.6', '1.7', '2', '3', '4'])
     expected_edges = [(f'1.{e.source.node_id}',f'1.{e.destination.node_id}') for e in lg.graph.edges]
     expected_edges.extend([
         ('2','1.width'),
         ('3','1.height'),
-        ('1.6','4')
+        ('1.7','4')
     ])
     print(expected_edges)
     print(list(dg.edges))