invoke-ai/InvokeAI (https://github.com/invoke-ai/InvokeAI)
Commit 58d7833c5c ("Review changes"), parent 5012f61599
@@ -30,35 +30,24 @@ class CompelOutput(BaseInvocationOutput):
     #fmt: off
     type: Literal["compel_output"] = "compel_output"
-    # name + loras -> pipeline + loras
-    # model: ModelField = Field(default=None, description="Model")
-    # src? + loras -> tokenizer + text_encoder + loras
-    # clip: ClipField = Field(default=None, description="Text encoder(clip)")
-    positive: ConditioningField = Field(default=None, description="Positive conditioning")
-    negative: ConditioningField = Field(default=None, description="Negative conditioning")
+    conditioning: ConditioningField = Field(default=None, description="Conditioning")
     #fmt: on


 class CompelInvocation(BaseInvocation):
-
+    """Parse prompt using compel package to conditioning."""

     type: Literal["compel"] = "compel"

-    positive_prompt: str = Field(default="", description="Positive prompt")
-    negative_prompt: str = Field(default="", description="Negative prompt")
+    prompt: str = Field(default="", description="Prompt")

     model: str = Field(default="", description="Model to use")
-    truncate_long_prompts: bool = Field(default=False, description="Whether or not to truncate long prompt to 77 tokens")

-    # name + loras -> pipeline + loras
-    # model: ModelField = Field(default=None, description="Model to use")
-    # src? + loras -> tokenizer + text_encoder + loras
-    # clip: ClipField = Field(default=None, description="Text encoder(clip) to use")

     # Schema customisation
     class Config(InvocationConfig):
         schema_extra = {
             "ui": {
-                "tags": ["latents", "noise"],
+                "tags": ["prompt", "compel"],
                 "type_hints": {
                     "model": "model"
                 }
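Note: after this hunk a Compel node carries one prompt and emits one conditioning; negative prompts come from a second Compel node (see the default-graph hunk further down). Below is a standalone sketch of the resulting field layout, using plain pydantic stand-ins rather than InvokeAI's actual base classes so it runs on its own; the example prompts are hypothetical.

# Standalone sketch of the post-change schema (plain pydantic stand-ins, not the
# actual InvokeAI base classes), showing the one-prompt / one-conditioning shape.
from typing import Literal, Optional
from pydantic import BaseModel, Field

class ConditioningField(BaseModel):
    conditioning_name: Optional[str] = Field(default=None, description="Name of stored conditioning")

class CompelOutput(BaseModel):
    type: Literal["compel_output"] = "compel_output"
    conditioning: Optional[ConditioningField] = Field(default=None, description="Conditioning")

class CompelInvocation(BaseModel):
    """Parse prompt using compel package to conditioning."""
    type: Literal["compel"] = "compel"
    prompt: str = Field(default="", description="Prompt")
    model: str = Field(default="", description="Model to use")

# Two node instances cover the positive and negative prompts respectively:
positive_node = CompelInvocation(prompt="a photo of an astronaut")  # hypothetical prompt
negative_node = CompelInvocation(prompt="low quality, blurry")      # hypothetical prompt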
@@ -88,14 +77,8 @@ class CompelInvocation(BaseInvocation):
             pipeline.textual_inversion_manager.load_huggingface_concepts(concepts)

         # apply the concepts library to the prompt
-        positive_prompt_str = pipeline.textual_inversion_manager.hf_concepts_library.replace_concepts_with_triggers(
-            self.positive_prompt,
-            lambda concepts: load_huggingface_concepts(concepts),
-            pipeline.textual_inversion_manager.get_all_trigger_strings(),
-        )
-
-        negative_prompt_str = pipeline.textual_inversion_manager.hf_concepts_library.replace_concepts_with_triggers(
-            self.negative_prompt,
+        prompt_str = pipeline.textual_inversion_manager.hf_concepts_library.replace_concepts_with_triggers(
+            self.prompt,
             lambda concepts: load_huggingface_concepts(concepts),
             pipeline.textual_inversion_manager.get_all_trigger_strings(),
         )
@@ -103,7 +86,7 @@ class CompelInvocation(BaseInvocation):
         # lazy-load any deferred textual inversions.
         # this might take a couple of seconds the first time a textual inversion is used.
         pipeline.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(
-            positive_prompt_str + "[" + negative_prompt_str + "]"
+            prompt_str
         )

         compel = Compel(
@@ -111,43 +94,35 @@ class CompelInvocation(BaseInvocation):
             text_encoder=text_encoder,
             textual_inversion_manager=pipeline.textual_inversion_manager,
             dtype_for_device_getter=torch_dtype,
-            truncate_long_prompts=self.truncate_long_prompts,
+            truncate_long_prompts=True,  # TODO:
         )

         # TODO: support legacy blend?

-        positive_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(positive_prompt_str)
-        negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(negative_prompt_str)
+        prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(prompt_str)

         if getattr(Globals, "log_tokenization", False):
-            log_tokenization(positive_prompt, negative_prompt, tokenizer=tokenizer)
+            log_tokenization_for_prompt_object(prompt, tokenizer)

-        # TODO: add lora(with model and clip field types)
-        c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
-        uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
+        c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)

-        if not self.truncate_long_prompts:
-            [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
+        # TODO: long prompt support
+        #if not self.truncate_long_prompts:
+        # [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])

         ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
-            tokens_count_including_eos_bos=get_max_token_count(tokenizer, positive_prompt),
+            tokens_count_including_eos_bos=get_max_token_count(tokenizer, prompt),
             cross_attention_control_args=options.get("cross_attention_control", None),
         )

-        name_prefix = f'{context.graph_execution_state_id}__{self.id}'
-        name_positive = f"{name_prefix}_positive"
-        name_negative = f"{name_prefix}_negative"
+        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"

         # TODO: hacky but works ;D maybe rename latents somehow?
-        context.services.latents.set(name_positive, (c, ec))
-        context.services.latents.set(name_negative, (uc, None))
+        context.services.latents.set(conditioning_name, (c, ec))

         return CompelOutput(
-            positive=ConditioningField(
-                conditioning_name=name_positive,
-            ),
-            negative=ConditioningField(
-                conditioning_name=name_negative,
+            conditioning=ConditioningField(
+                conditioning_name=conditioning_name,
             ),
         )

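The pattern above stores each node's (conditioning, extra-info) tuple in the latents service under a name built from the execution id and node id, and downstream nodes fetch it by that same name. Here is a minimal sketch of that round trip with a plain dict standing in for context.services.latents; only the set/get calls and the name format are taken from the diff, while the ids and values are made up.

# Illustrative round trip of the conditioning-name pattern from the hunk above.
class FakeLatentsService:
    def __init__(self):
        self._store = {}

    def set(self, name, value):
        self._store[name] = value

    def get(self, name):
        return self._store[name]

latents = FakeLatentsService()

graph_execution_state_id = "exec-123"   # hypothetical execution id
node_id = "4"                           # hypothetical Compel node id
conditioning_name = f"{graph_execution_state_id}_{node_id}_conditioning"

c = "conditioning-tensor-placeholder"          # stands in for the Compel-built tensor
ec = {"tokens_count_including_eos_bos": 77}    # stands in for ExtraConditioningInfo

latents.set(conditioning_name, (c, ec))                              # producer (CompelInvocation) side
fetched_c, extra_conditioning_info = latents.get(conditioning_name)  # consumer side
assert fetched_c == c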
@@ -195,20 +170,6 @@ def get_tokens_for_prompt_object(
     return tokens


-def log_tokenization(
-    positive_prompt: Union[Blend, FlattenedPrompt],
-    negative_prompt: Union[Blend, FlattenedPrompt],
-    tokenizer,
-):
-    print(f"\n>> [TOKENLOG] Parsed Prompt: {positive_prompt}")
-    print(f"\n>> [TOKENLOG] Parsed Negative Prompt: {negative_prompt}")
-
-    log_tokenization_for_prompt_object(positive_prompt, tokenizer)
-    log_tokenization_for_prompt_object(
-        negative_prompt, tokenizer, display_label_prefix="(negative prompt)"
-    )
-
-
 def log_tokenization_for_prompt_object(
     p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
 ):
@@ -138,14 +138,14 @@ class NoiseInvocation(BaseInvocation):

 # Text to image
 class TextToLatentsInvocation(BaseInvocation):
-    """Generates latents from a prompt."""
+    """Generates latents from conditionings."""

     type: Literal["t2l"] = "t2l"

     # Inputs
     # fmt: off
-    positive: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
-    negative: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
+    positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
+    negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
     noise: Optional[LatentsField] = Field(description="The noise to use")
     steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
     cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
@@ -204,8 +204,8 @@ class TextToLatentsInvocation(BaseInvocation):


     def get_conditioning_data(self, context: InvocationContext, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
-        c, extra_conditioning_info = context.services.latents.get(self.positive.conditioning_name)
-        uc, _ = context.services.latents.get(self.negative.conditioning_name)
+        c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
+        uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)

         conditioning_data = ConditioningData(
             uc,
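The c and uc tensors fetched here feed classifier-free guidance, which the cfg_scale field above controls. A small sketch of the usual combine step follows; it is the standard CFG formula with scalar stand-ins, not code from this repository.

# Standard classifier-free guidance combine step (textbook formula, not this repo's code).
def apply_cfg(uc_pred: float, c_pred: float, cfg_scale: float = 7.5) -> float:
    # Push the noise prediction away from the unconditioned result (uc_pred)
    # and toward the conditioned one (c_pred), scaled by cfg_scale.
    return uc_pred + cfg_scale * (c_pred - uc_pred)

# Scalar stand-ins for what would normally be noise-prediction tensors:
print(apply_cfg(uc_pred=0.2, c_pred=0.5))  # 0.2 + 7.5 * (0.5 - 0.2) = 2.45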
@@ -20,28 +20,29 @@ def create_text_to_image() -> LibraryGraph:
             'seed': ParamIntInvocation(id='seed', a=-1),
             '3': NoiseInvocation(id='3'),
             '4': CompelInvocation(id='4'),
-            '5': TextToLatentsInvocation(id='5'),
-            '6': LatentsToImageInvocation(id='6'),
+            '5': CompelInvocation(id='5'),
+            '6': TextToLatentsInvocation(id='6'),
+            '7': LatentsToImageInvocation(id='7'),
         },
         edges=[
             Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')),
             Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')),
             Edge(source=EdgeConnection(node_id='seed', field='a'), destination=EdgeConnection(node_id='3', field='seed')),
-            Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='5', field='noise')),
-            Edge(source=EdgeConnection(node_id='5', field='latents'), destination=EdgeConnection(node_id='6', field='latents')),
-            Edge(source=EdgeConnection(node_id='4', field='positive'), destination=EdgeConnection(node_id='5', field='positive')),
-            Edge(source=EdgeConnection(node_id='4', field='negative'), destination=EdgeConnection(node_id='5', field='negative')),
+            Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='6', field='noise')),
+            Edge(source=EdgeConnection(node_id='6', field='latents'), destination=EdgeConnection(node_id='7', field='latents')),
+            Edge(source=EdgeConnection(node_id='4', field='conditioning'), destination=EdgeConnection(node_id='6', field='positive_conditioning')),
+            Edge(source=EdgeConnection(node_id='5', field='conditioning'), destination=EdgeConnection(node_id='6', field='negative_conditioning')),
         ]
     ),
     exposed_inputs=[
-        ExposedNodeInput(node_path='4', field='positive_prompt', alias='positive_prompt'),
-        ExposedNodeInput(node_path='4', field='negative_prompt', alias='negative_prompt'),
+        ExposedNodeInput(node_path='4', field='prompt', alias='positive_prompt'),
+        ExposedNodeInput(node_path='5', field='prompt', alias='negative_prompt'),
         ExposedNodeInput(node_path='width', field='a', alias='width'),
         ExposedNodeInput(node_path='height', field='a', alias='height'),
         ExposedNodeInput(node_path='seed', field='a', alias='seed'),
     ],
     exposed_outputs=[
-        ExposedNodeOutput(node_path='6', field='image', alias='image')
+        ExposedNodeOutput(node_path='7', field='image', alias='image')
     ])

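The rewired default text-to-image graph now runs two Compel nodes, node 4 for the positive prompt and node 5 for the negative prompt, both feeding TextToLatentsInvocation (node 6), which feeds LatentsToImageInvocation (node 7). Below is a plain-Python summary of that data flow, for illustration only; the node labels paraphrase the graph above.

# Illustrative summary of the post-change default graph's data flow.
default_t2i_flow = {
    "3: NoiseInvocation": ["6.noise"],
    "4: CompelInvocation (positive prompt)": ["6.positive_conditioning"],
    "5: CompelInvocation (negative prompt)": ["6.negative_conditioning"],
    "6: TextToLatentsInvocation": ["7.latents"],
    "7: LatentsToImageInvocation": ["image (exposed output)"],
}

for node, targets in default_t2i_flow.items():
    print(node, "->", ", ".join(targets))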
@@ -463,16 +463,16 @@ def test_graph_subgraph_t2i():

     n4 = ShowImageInvocation(id = "4")
     g.add_node(n4)
-    g.add_edge(create_edge("1.6","image","4","image"))
+    g.add_edge(create_edge("1.7","image","4","image"))

     # Validate
     dg = g.nx_graph_flat()
-    assert set(dg.nodes) == set(['1.width', '1.height', '1.seed', '1.3', '1.4', '1.5', '1.6', '2', '3', '4'])
+    assert set(dg.nodes) == set(['1.width', '1.height', '1.seed', '1.3', '1.4', '1.5', '1.6', '1.7', '2', '3', '4'])
     expected_edges = [(f'1.{e.source.node_id}',f'1.{e.destination.node_id}') for e in lg.graph.edges]
     expected_edges.extend([
         ('2','1.width'),
         ('3','1.height'),
-        ('1.6','4')
+        ('1.7','4')
     ])
     print(expected_edges)
     print(list(dg.edges))