mirror of https://github.com/invoke-ai/InvokeAI

Review changes

commit 58d7833c5c
parent 5012f61599
@@ -30,35 +30,24 @@ class CompelOutput(BaseInvocationOutput):
    #fmt: off
    type: Literal["compel_output"] = "compel_output"

    # name + loras -> pipeline + loras
    # model: ModelField = Field(default=None, description="Model")
    # src? + loras -> tokenizer + text_encoder + loras
    # clip: ClipField = Field(default=None, description="Text encoder(clip)")
    positive: ConditioningField = Field(default=None, description="Positive conditioning")
    negative: ConditioningField = Field(default=None, description="Negative conditioning")

    conditioning: ConditioningField = Field(default=None, description="Conditioning")
    #fmt: on


class CompelInvocation(BaseInvocation):
    """Parse prompt using compel package to conditioning."""

    type: Literal["compel"] = "compel"

    positive_prompt: str = Field(default="", description="Positive prompt")
    negative_prompt: str = Field(default="", description="Negative prompt")

    prompt: str = Field(default="", description="Prompt")
    model: str = Field(default="", description="Model to use")
    truncate_long_prompts: bool = Field(default=False, description="Whether or not to truncate long prompt to 77 tokens")

    # name + loras -> pipeline + loras
    # model: ModelField = Field(default=None, description="Model to use")
    # src? + loras -> tokenizer + text_encoder + loras
    # clip: ClipField = Field(default=None, description="Text encoder(clip) to use")

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "tags": ["latents", "noise"],
                "tags": ["prompt", "compel"],
                "type_hints": {
                    "model": "model"
                }
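Taken together, this hunk collapses the paired positive/negative prompt handling into a single prompt per invocation. Roughly, the public schema after the change looks like this (a sketch assembled only from the new field lines above, not the full file; imports, the Config class, and the invoke() body are omitted):

    # Sketch of the post-change schema; assembled from the added lines above, not the complete file.
    class CompelOutput(BaseInvocationOutput):
        type: Literal["compel_output"] = "compel_output"
        conditioning: ConditioningField = Field(default=None, description="Conditioning")


    class CompelInvocation(BaseInvocation):
        """Parse prompt using compel package to conditioning."""
        type: Literal["compel"] = "compel"
        prompt: str = Field(default="", description="Prompt")
        model: str = Field(default="", description="Model to use")

A graph that needs both a positive and a negative conditioning now instantiates two CompelInvocation nodes, one per prompt (see the default-graph hunk further down).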
@@ -88,14 +77,8 @@ class CompelInvocation(BaseInvocation):
            pipeline.textual_inversion_manager.load_huggingface_concepts(concepts)

        # apply the concepts library to the prompt
        positive_prompt_str = pipeline.textual_inversion_manager.hf_concepts_library.replace_concepts_with_triggers(
            self.positive_prompt,
            lambda concepts: load_huggingface_concepts(concepts),
            pipeline.textual_inversion_manager.get_all_trigger_strings(),
        )

        negative_prompt_str = pipeline.textual_inversion_manager.hf_concepts_library.replace_concepts_with_triggers(
            self.negative_prompt,
        prompt_str = pipeline.textual_inversion_manager.hf_concepts_library.replace_concepts_with_triggers(
            self.prompt,
            lambda concepts: load_huggingface_concepts(concepts),
            pipeline.textual_inversion_manager.get_all_trigger_strings(),
        )
@@ -103,7 +86,7 @@ class CompelInvocation(BaseInvocation):
        # lazy-load any deferred textual inversions.
        # this might take a couple of seconds the first time a textual inversion is used.
        pipeline.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(
            positive_prompt_str + "[" + negative_prompt_str + "]"
            prompt_str
        )

        compel = Compel(
@@ -111,43 +94,35 @@ class CompelInvocation(BaseInvocation):
            text_encoder=text_encoder,
            textual_inversion_manager=pipeline.textual_inversion_manager,
            dtype_for_device_getter=torch_dtype,
            truncate_long_prompts=self.truncate_long_prompts,
            truncate_long_prompts=True, # TODO:
        )

        # TODO: support legacy blend?

        positive_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(positive_prompt_str)
        negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(negative_prompt_str)
        prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(prompt_str)

        if getattr(Globals, "log_tokenization", False):
            log_tokenization(positive_prompt, negative_prompt, tokenizer=tokenizer)
            log_tokenization_for_prompt_object(prompt, tokenizer)

        # TODO: add lora(with model and clip field types)
        c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
        uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
        c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)

        if not self.truncate_long_prompts:
            [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
        # TODO: long prompt support
        #if not self.truncate_long_prompts:
        #    [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])

        ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
            tokens_count_including_eos_bos=get_max_token_count(tokenizer, positive_prompt),
            tokens_count_including_eos_bos=get_max_token_count(tokenizer, prompt),
            cross_attention_control_args=options.get("cross_attention_control", None),
        )

        name_prefix = f'{context.graph_execution_state_id}__{self.id}'
        name_positive = f"{name_prefix}_positive"
        name_negative = f"{name_prefix}_negative"
        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"

        # TODO: hacky but works ;D maybe rename latents somehow?
        context.services.latents.set(name_positive, (c, ec))
        context.services.latents.set(name_negative, (uc, None))
        context.services.latents.set(conditioning_name, (c, ec))

        return CompelOutput(
            positive=ConditioningField(
                conditioning_name=name_positive,
            ),
            negative=ConditioningField(
                conditioning_name=name_negative,
            conditioning=ConditioningField(
                conditioning_name=conditioning_name,
            ),
        )
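Worth noting from this hunk: the conditioning tensor itself never travels through the graph. CompelInvocation stores it in the latents service under a deterministic key and only that key is passed along inside a ConditioningField. A minimal sketch of the round trip, with the producer side taken from this hunk and the consumer side from the get_conditioning_data hunk below (variable names as in the diff):

    # Producer (CompelInvocation.invoke): stash the tensor, return only its name.
    conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
    context.services.latents.set(conditioning_name, (c, ec))
    return CompelOutput(conditioning=ConditioningField(conditioning_name=conditioning_name))

    # Consumer (TextToLatentsInvocation.get_conditioning_data): resolve the names back to tensors.
    c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
    uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)

As the "hacky but works" comment acknowledges, the conditioning tuple simply reuses the latents storage service rather than a dedicated conditioning store.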
@@ -195,20 +170,6 @@ def get_tokens_for_prompt_object(
    return tokens


def log_tokenization(
    positive_prompt: Union[Blend, FlattenedPrompt],
    negative_prompt: Union[Blend, FlattenedPrompt],
    tokenizer,
):
    print(f"\n>> [TOKENLOG] Parsed Prompt: {positive_prompt}")
    print(f"\n>> [TOKENLOG] Parsed Negative Prompt: {negative_prompt}")

    log_tokenization_for_prompt_object(positive_prompt, tokenizer)
    log_tokenization_for_prompt_object(
        negative_prompt, tokenizer, display_label_prefix="(negative prompt)"
    )


def log_tokenization_for_prompt_object(
    p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
):
@@ -138,14 +138,14 @@ class NoiseInvocation(BaseInvocation):

# Text to image
class TextToLatentsInvocation(BaseInvocation):
    """Generates latents from a prompt."""
    """Generates latents from conditionings."""

    type: Literal["t2l"] = "t2l"

    # Inputs
    # fmt: off
    positive: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
    negative: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
    positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
    negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
    noise: Optional[LatentsField] = Field(description="The noise to use")
    steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
    cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )

@@ -204,8 +204,8 @@ class TextToLatentsInvocation(BaseInvocation):


    def get_conditioning_data(self, context: InvocationContext, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
        c, extra_conditioning_info = context.services.latents.get(self.positive.conditioning_name)
        uc, _ = context.services.latents.get(self.negative.conditioning_name)
        c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
        uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)

        conditioning_data = ConditioningData(
            uc,
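Because the input fields were renamed, anything that builds a TextToLatentsInvocation directly has to use the new names. A hypothetical construction, purely for illustration (the id and conditioning names here are made up, not taken from the diff):

    # Hypothetical example of the renamed fields; values are illustrative only.
    t2l = TextToLatentsInvocation(
        id="6",
        positive_conditioning=ConditioningField(conditioning_name="some_state_4_conditioning"),
        negative_conditioning=ConditioningField(conditioning_name="some_state_5_conditioning"),
        steps=10,
        cfg_scale=7.5,
    )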
@@ -20,28 +20,29 @@ def create_text_to_image() -> LibraryGraph:
                'seed': ParamIntInvocation(id='seed', a=-1),
                '3': NoiseInvocation(id='3'),
                '4': CompelInvocation(id='4'),
                '5': TextToLatentsInvocation(id='5'),
                '6': LatentsToImageInvocation(id='6'),
                '5': CompelInvocation(id='5'),
                '6': TextToLatentsInvocation(id='6'),
                '7': LatentsToImageInvocation(id='7'),
            },
            edges=[
                Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')),
                Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')),
                Edge(source=EdgeConnection(node_id='seed', field='a'), destination=EdgeConnection(node_id='3', field='seed')),
                Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='5', field='noise')),
                Edge(source=EdgeConnection(node_id='5', field='latents'), destination=EdgeConnection(node_id='6', field='latents')),
                Edge(source=EdgeConnection(node_id='4', field='positive'), destination=EdgeConnection(node_id='5', field='positive')),
                Edge(source=EdgeConnection(node_id='4', field='negative'), destination=EdgeConnection(node_id='5', field='negative')),
                Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='6', field='noise')),
                Edge(source=EdgeConnection(node_id='6', field='latents'), destination=EdgeConnection(node_id='7', field='latents')),
                Edge(source=EdgeConnection(node_id='4', field='conditioning'), destination=EdgeConnection(node_id='6', field='positive_conditioning')),
                Edge(source=EdgeConnection(node_id='5', field='conditioning'), destination=EdgeConnection(node_id='6', field='negative_conditioning')),
            ]
        ),
        exposed_inputs=[
            ExposedNodeInput(node_path='4', field='positive_prompt', alias='positive_prompt'),
            ExposedNodeInput(node_path='4', field='negative_prompt', alias='negative_prompt'),
            ExposedNodeInput(node_path='4', field='prompt', alias='positive_prompt'),
            ExposedNodeInput(node_path='5', field='prompt', alias='negative_prompt'),
            ExposedNodeInput(node_path='width', field='a', alias='width'),
            ExposedNodeInput(node_path='height', field='a', alias='height'),
            ExposedNodeInput(node_path='seed', field='a', alias='seed'),
        ],
        exposed_outputs=[
            ExposedNodeOutput(node_path='6', field='image', alias='image')
            ExposedNodeOutput(node_path='7', field='image', alias='image')
        ])
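The default text-to-image graph now routes the two prompts through two separate compel nodes. How the public aliases and outputs resolve after this hunk, shown as a plain mapping for illustration (the real code uses ExposedNodeInput/ExposedNodeOutput objects, as above):

    # Where the library graph's exposed aliases land after this change (illustrative mapping only).
    exposed_input_targets = {
        "positive_prompt": ("4", "prompt"),   # first CompelInvocation
        "negative_prompt": ("5", "prompt"),   # second CompelInvocation
        "width": ("width", "a"),
        "height": ("height", "a"),
        "seed": ("seed", "a"),
    }
    exposed_output_targets = {
        "image": ("7", "image"),              # LatentsToImageInvocation, shifted from node '6' to '7'
    }

Node ids shift by one because the extra CompelInvocation takes id '5', pushing TextToLatentsInvocation to '6' and LatentsToImageInvocation to '7'; the test hunk below updates its expected node and edge sets accordingly.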
@@ -463,16 +463,16 @@ def test_graph_subgraph_t2i():

    n4 = ShowImageInvocation(id = "4")
    g.add_node(n4)
    g.add_edge(create_edge("1.6","image","4","image"))
    g.add_edge(create_edge("1.7","image","4","image"))

    # Validate
    dg = g.nx_graph_flat()
    assert set(dg.nodes) == set(['1.width', '1.height', '1.seed', '1.3', '1.4', '1.5', '1.6', '2', '3', '4'])
    assert set(dg.nodes) == set(['1.width', '1.height', '1.seed', '1.3', '1.4', '1.5', '1.6', '1.7', '2', '3', '4'])
    expected_edges = [(f'1.{e.source.node_id}',f'1.{e.destination.node_id}') for e in lg.graph.edges]
    expected_edges.extend([
        ('2','1.width'),
        ('3','1.height'),
        ('1.6','4')
        ('1.7','4')
    ])
    print(expected_edges)
    print(list(dg.edges))