what have i done

maryhipp 2024-02-27 15:41:03 -05:00 committed by Kent Keirsey
parent 6fff7de2ab
commit 600b4c6a90


@@ -13,6 +13,7 @@ from invokeai.backend.lora import LoRAModelRaw
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
+    ConditioningFieldData,
     ExtraConditioningInfo,
     SDXLConditioningInfo,
 )
@@ -21,10 +22,6 @@ from invokeai.backend.util.devices import torch_dtype
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from .model import ClipField

-@dataclass
-class ConditioningFieldData:
-    conditionings: List[BasicConditioningInfo]
-    # unconditioned: Optional[torch.Tensor]
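Note that the dataclass removed here is not deleted outright: the first hunk adds ConditioningFieldData to the imports from invokeai.backend.stable_diffusion.diffusion.conditioning_data, so the definition has evidently moved into the backend module. The target module is not shown in this diff, but presumably it lands there essentially unchanged, along these lines (BasicConditioningInfo is a stand-in here; the real class carries the embedding tensors):

from dataclasses import dataclass
from typing import List


@dataclass
class BasicConditioningInfo:
    """Stand-in for illustration; the real class holds conditioning tensors."""
    ...


@dataclass
class ConditioningFieldData:
    # One conditioning entry per prompt in the blend/conjunction.
    conditionings: List[BasicConditioningInfo]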
@@ -39,7 +36,7 @@ class ConditioningFieldData:
     title="Prompt",
     tags=["prompt", "compel"],
     category="conditioning",
-    version="1.0.0",
+    version="1.0.1",
 )
 class CompelInvocation(BaseInvocation):
     """Parse prompt using compel package to conditioning."""
@@ -64,16 +61,15 @@ class CompelInvocation(BaseInvocation):
         text_encoder_model = text_encoder_info.model
         assert isinstance(text_encoder_model, CLIPTextModel)

-        def _lora_loader():
+        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in self.clip.loras:
-                lora_info = context.services.model_manager.get_model(
-                    **lora.model_dump(exclude={"weight"}), context=context
-                )
-                yield (lora_info.context.model, lora.weight)
+                lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
+                assert isinstance(lora_info.model, LoRAModelRaw)
+                yield (lora_info.model, lora.weight)
                 del lora_info
             return

-        # loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
+        # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]

         ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context)
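Both _lora_loader rewrites in this file follow the same shape: a typed generator that loads each LoRA through the new context.models.load API, asserts the runtime type, and yields (model, weight) pairs lazily so ModelPatcher can consume them one at a time. The pattern in isolation (FakeLoRA and the loras list are illustrative stand-ins, not InvokeAI types):

from dataclasses import dataclass
from typing import Iterator, List, Tuple


@dataclass
class FakeLoRA:
    """Illustrative stand-in for LoRAModelRaw."""
    name: str


def lora_loader(loras: List[Tuple[FakeLoRA, float]]) -> Iterator[Tuple[FakeLoRA, float]]:
    for model, weight in loras:
        # Runtime check mirroring the assert added in the diff.
        assert isinstance(model, FakeLoRA)
        yield (model, weight)


for model, weight in lora_loader([(FakeLoRA("detail-slider"), 0.75)]):
    print(model.name, weight)  # detail-slider 0.75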
@@ -99,7 +95,7 @@ class CompelInvocation(BaseInvocation):
             conjunction = Compel.parse_prompt_string(self.prompt)

-            if context.services.configuration.log_tokenization:
+            if context.config.get().log_tokenization:
                 log_tokenization_for_conjunction(conjunction, tokenizer)

             c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
@ -120,17 +116,14 @@ class CompelInvocation(BaseInvocation):
] ]
) )
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" conditioning_name = context.conditioning.save(conditioning_data)
context.services.latents.save(conditioning_name, conditioning_data)
return ConditioningOutput( return ConditioningOutput.build(conditioning_name)
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),
)
class SDXLPromptInvocationBase: class SDXLPromptInvocationBase:
"""Prompt processor for SDXL models."""
def run_clip_compel( def run_clip_compel(
self, self,
context: InvocationContext, context: InvocationContext,
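Saving conditioning also moves behind the context API: the old code built a name from the graph execution state id and wrote through context.services.latents, while the new code lets context.conditioning.save pick the storage name and wraps it via ConditioningOutput.build. A plausible sketch of what build does, inferred from this diff rather than checked against the InvokeAI source:

from pydantic import BaseModel


class ConditioningField(BaseModel):
    conditioning_name: str


class ConditioningOutput(BaseModel):
    conditioning: ConditioningField

    @classmethod
    def build(cls, conditioning_name: str) -> "ConditioningOutput":
        # Folds the five-line constructor call removed above into one helper.
        return cls(conditioning=ConditioningField(conditioning_name=conditioning_name))


out = ConditioningOutput.build("some-saved-conditioning")
print(out.conditioning.conditioning_name)  # some-saved-conditioning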
@@ -149,14 +142,15 @@ class SDXLPromptInvocationBase:
         # return zero on empty
         if prompt == "" and zero_on_empty:
-            cpu_text_encoder = text_encoder_info.context.model
+            cpu_text_encoder = text_encoder_info.model
+            assert isinstance(cpu_text_encoder, torch.nn.Module)
             c = torch.zeros(
                 (
                     1,
                     cpu_text_encoder.config.max_position_embeddings,
                     cpu_text_encoder.config.hidden_size,
                 ),
-                dtype=text_encoder_info.context.cache.precision,
+                dtype=cpu_text_encoder.dtype,
             )
             if get_pooled:
                 c_pooled = torch.zeros(
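The zero-conditioning branch now takes its dtype from the encoder module itself instead of the old cache-precision setting, so the placeholder tensor always matches the precision of the loaded weights. A minimal illustration (a Linear layer stands in for the CLIP text encoder, and 77/768 are typical CLIP values; HF models expose .dtype directly, so here we read it off a parameter):

import torch

# A Linear layer stands in for the CLIP text encoder.
encoder = torch.nn.Linear(768, 768).to(torch.float16)

# Shape is (batch, max_position_embeddings, hidden_size); dtype follows
# the module's own parameters rather than a separate precision setting.
c = torch.zeros((1, 77, 768), dtype=encoder.weight.dtype)
print(c.dtype)  # torch.float16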
@@ -167,16 +161,16 @@ class SDXLPromptInvocationBase:
                 c_pooled = None
             return c, c_pooled, None

-        def _lora_loader():
+        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in clip_field.loras:
-                lora_info = context.services.model_manager.get_model(
-                    **lora.model_dump(exclude={"weight"}), context=context
-                )
-                yield (lora_info.context.model, lora.weight)
+                lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
+                lora_model = lora_info.model
+                assert isinstance(lora_model, LoRAModelRaw)
+                yield (lora_model, lora.weight)
                 del lora_info
             return

-        # loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
+        # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]

         ti_list = generate_ti_list(prompt, text_encoder_info.config.base, context)
@@ -205,7 +199,7 @@ class SDXLPromptInvocationBase:
         conjunction = Compel.parse_prompt_string(prompt)

-        if context.services.configuration.log_tokenization:
+        if context.config.get().log_tokenization:
             # TODO: better logging for and syntax
             log_tokenization_for_conjunction(conjunction, tokenizer)
@ -238,7 +232,7 @@ class SDXLPromptInvocationBase:
title="SDXL Prompt", title="SDXL Prompt",
tags=["sdxl", "compel", "prompt"], tags=["sdxl", "compel", "prompt"],
category="conditioning", category="conditioning",
version="1.0.0", version="1.0.1",
) )
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning.""" """Parse prompt using compel package to conditioning."""
@@ -309,6 +303,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
                 dim=1,
             )

+        assert c2_pooled is not None
         conditioning_data = ConditioningFieldData(
             conditionings=[
                 SDXLConditioningInfo(
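run_clip_compel returns the pooled embedding as Optional, so the added assert narrows c2_pooled to a plain Tensor before it is handed to SDXLConditioningInfo; the same line is added in the refiner invocation below. The narrowing idiom in isolation (the 1280-wide tensor is just an example shape):

from typing import Optional

import torch


def maybe_pooled(flag: bool) -> Optional[torch.Tensor]:
    return torch.ones(1, 1280) if flag else None


pooled = maybe_pooled(True)
assert pooled is not None  # narrows Optional[Tensor] to Tensor for mypy
print(pooled.shape)  # torch.Size([1, 1280])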
@ -320,14 +315,9 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
] ]
) )
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" conditioning_name = context.conditioning.save(conditioning_data)
context.services.latents.save(conditioning_name, conditioning_data)
return ConditioningOutput( return ConditioningOutput.build(conditioning_name)
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),
)
@invocation( @invocation(
@@ -335,7 +325,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     title="SDXL Refiner Prompt",
     tags=["sdxl", "compel", "prompt"],
     category="conditioning",
-    version="1.0.0",
+    version="1.0.1",
 )
 class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     """Parse prompt using compel package to conditioning."""
@@ -362,6 +352,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
         add_time_ids = torch.tensor([original_size + crop_coords + (self.aesthetic_score,)])

+        assert c2_pooled is not None
         conditioning_data = ConditioningFieldData(
             conditionings=[
                 SDXLConditioningInfo(
@@ -373,14 +364,9 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
             ]
         )

-        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
-        context.services.latents.save(conditioning_name, conditioning_data)
+        conditioning_name = context.conditioning.save(conditioning_data)

-        return ConditioningOutput(
-            conditioning=ConditioningField(
-                conditioning_name=conditioning_name,
-            ),
-        )
+        return ConditioningOutput.build(conditioning_name)


 @invocation_output("clip_skip_output")
@@ -401,7 +387,7 @@ class ClipSkipInvocation(BaseInvocation):
     """Skip layers in clip text_encoder model."""

     clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
-    skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers)
+    skipped_layers: int = InputField(default=0, ge=0, description=FieldDescriptions.skipped_layers)

     def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
         self.clip.skipped_layers += self.skipped_layers
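The new ge=0 keyword makes validation reject negative skipped_layers up front instead of letting a bad layer index through to the text encoder. Assuming InputField forwards constraint keywords to pydantic's Field, the behavior is equivalent to this minimal stand-in model:

from pydantic import BaseModel, Field, ValidationError


class ClipSkipArgs(BaseModel):
    """Illustrative stand-in for the invocation's validated fields."""
    skipped_layers: int = Field(default=0, ge=0)


print(ClipSkipArgs(skipped_layers=2).skipped_layers)  # 2
try:
    ClipSkipArgs(skipped_layers=-1)
except ValidationError as err:
    print(err.errors()[0]["type"])  # greater_than_equal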
@@ -411,9 +397,9 @@ class ClipSkipInvocation(BaseInvocation):
 def get_max_token_count(
-    tokenizer,
+    tokenizer: CLIPTokenizer,
     prompt: Union[FlattenedPrompt, Blend, Conjunction],
-    truncate_if_too_long=False,
+    truncate_if_too_long: bool = False,
 ) -> int:
     if type(prompt) is Blend:
         blend: Blend = prompt
@@ -425,7 +411,9 @@ def get_max_token_count(
     return len(get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long))


-def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> List[str]:
+def get_tokens_for_prompt_object(
+    tokenizer: CLIPTokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long: bool = True
+) -> List[str]:
     if type(parsed_prompt) is Blend:
         raise ValueError("Blend is not supported here - you need to get tokens for each of its .children")
@ -438,24 +426,29 @@ def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, trun
for x in parsed_prompt.children for x in parsed_prompt.children
] ]
text = " ".join(text_fragments) text = " ".join(text_fragments)
tokens = tokenizer.tokenize(text) tokens: List[str] = tokenizer.tokenize(text)
if truncate_if_too_long: if truncate_if_too_long:
max_tokens_length = tokenizer.model_max_length - 2 # typically 75 max_tokens_length = tokenizer.model_max_length - 2 # typically 75
tokens = tokens[0:max_tokens_length] tokens = tokens[0:max_tokens_length]
return tokens return tokens
def log_tokenization_for_conjunction(c: Conjunction, tokenizer, display_label_prefix=None): def log_tokenization_for_conjunction(
c: Conjunction, tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None
) -> None:
display_label_prefix = display_label_prefix or "" display_label_prefix = display_label_prefix or ""
for i, p in enumerate(c.prompts): for i, p in enumerate(c.prompts):
if len(c.prompts) > 1: if len(c.prompts) > 1:
this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})" this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
else: else:
assert display_label_prefix is not None
this_display_label_prefix = display_label_prefix this_display_label_prefix = display_label_prefix
log_tokenization_for_prompt_object(p, tokenizer, display_label_prefix=this_display_label_prefix) log_tokenization_for_prompt_object(p, tokenizer, display_label_prefix=this_display_label_prefix)
def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None): def log_tokenization_for_prompt_object(
p: Union[Blend, FlattenedPrompt], tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None
) -> None:
display_label_prefix = display_label_prefix or "" display_label_prefix = display_label_prefix or ""
if type(p) is Blend: if type(p) is Blend:
blend: Blend = p blend: Blend = p
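The truncation in get_tokens_for_prompt_object keeps model_max_length - 2 tokens because CLIP reserves two of its 77 positions for the begin/end-of-text markers, leaving 75 for the prompt. A runnable check with the real tokenizer (downloads openai/clip-vit-base-patch32 on first use):

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
tokens = tokenizer.tokenize("a photograph of an astronaut riding a horse")

# model_max_length is 77 for CLIP; two slots are reserved for the
# <|startoftext|> / <|endoftext|> markers, leaving 75 prompt tokens.
max_tokens_length = tokenizer.model_max_length - 2
print(max_tokens_length, tokens[:max_tokens_length][:4])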
@ -495,7 +488,12 @@ def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokeniz
log_tokenization_for_text(text, tokenizer, display_label=display_label_prefix) log_tokenization_for_text(text, tokenizer, display_label=display_label_prefix)
def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False): def log_tokenization_for_text(
text: str,
tokenizer: CLIPTokenizer,
display_label: Optional[str] = None,
truncate_if_too_long: Optional[bool] = False,
) -> None:
"""shows how the prompt is tokenized """shows how the prompt is tokenized
# usually tokens have '</w>' to indicate end-of-word, # usually tokens have '</w>' to indicate end-of-word,
# but for readability it has been replaced with ' ' # but for readability it has been replaced with ' '