what have i done

maryhipp 2024-02-27 15:41:03 -05:00
parent 16b3718d6a
commit ef474a3196
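At a glance, this commit moves compel.py from the old context.services.* accessors to the new InvocationContext helpers (model loading, config access, conditioning storage) and bumps each affected invocation's version. A minimal, hypothetical sketch of the before/after pattern inside an invocation's invoke() method, using only names that appear in the diff below (not a complete invocation; context, self.clip, and conditioning_data are assumed to come from the surrounding invocation code):

# Old pattern (left side of the diff): model lookups went through
# context.services.model_manager and the loaded weights lived at .context.model.
#     text_encoder_info = context.services.model_manager.get_model(
#         **self.clip.text_encoder.model_dump(), context=context
#     )
#     text_encoder = text_encoder_info.context.model

# New pattern (right side of the diff): context.models.load() returns an object
# whose loaded weights live directly at .model.
text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
text_encoder = text_encoder_info.model

# Conditioning tensors are now stored via context.conditioning rather than
# context.services.latents, and the output is built with a helper.
conditioning_name = context.conditioning.save(conditioning_data)
return ConditioningOutput.build(conditioning_name)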


@@ -1,40 +1,43 @@
-from dataclasses import dataclass
-from typing import List, Optional, Union
+from typing import Iterator, List, Optional, Tuple, Union

 import torch
 from compel import Compel, ReturnedEmbeddingsType
 from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
+from transformers import CLIPTokenizer

-from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput
-from invokeai.app.shared.fields import FieldDescriptions
+import invokeai.backend.util.logging as logger
+from invokeai.app.invocations.fields import (
+    FieldDescriptions,
+    Input,
+    InputField,
+    OutputField,
+    UIComponent,
+)
+from invokeai.app.invocations.primitives import ConditioningOutput
+from invokeai.app.services.model_records import UnknownModelException
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.app.util.ti_utils import extract_ti_triggers_from_prompt
+from invokeai.backend.lora import LoRAModelRaw
+from invokeai.backend.model_manager import ModelType
+from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
+    ConditioningFieldData,
     ExtraConditioningInfo,
     SDXLConditioningInfo,
 )
+from invokeai.backend.textual_inversion import TextualInversionModelRaw
+from invokeai.backend.util.devices import torch_dtype

-from ...backend.model_management.lora import ModelPatcher
-from ...backend.model_management.models import ModelNotFoundException, ModelType
-from ...backend.util.devices import torch_dtype
-from ..util.ti_utils import extract_ti_triggers_from_prompt
 from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
-    Input,
-    InputField,
-    InvocationContext,
-    OutputField,
-    UIComponent,
     invocation,
     invocation_output,
 )
 from .model import ClipField

+# unconditioned: Optional[torch.Tensor]

-@dataclass
-class ConditioningFieldData:
-    conditionings: List[BasicConditioningInfo]
-    # unconditioned: Optional[torch.Tensor]

 # class ConditioningAlgo(str, Enum):
@@ -48,7 +51,7 @@ class ConditioningFieldData:
     title="Prompt",
     tags=["prompt", "compel"],
     category="conditioning",
-    version="1.0.0",
+    version="1.0.1",
 )
 class CompelInvocation(BaseInvocation):
     """Parse prompt using compel package to conditioning."""
@@ -66,49 +69,34 @@ class CompelInvocation(BaseInvocation):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
-        tokenizer_info = context.services.model_manager.get_model(
-            **self.clip.tokenizer.model_dump(),
-            context=context,
-        )
-        text_encoder_info = context.services.model_manager.get_model(
-            **self.clip.text_encoder.model_dump(),
-            context=context,
-        )
+        tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
+        text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())

-        def _lora_loader():
+        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in self.clip.loras:
-                lora_info = context.services.model_manager.get_model(
-                    **lora.model_dump(exclude={"weight"}), context=context
-                )
-                yield (lora_info.context.model, lora.weight)
+                lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
+                assert isinstance(lora_info.model, LoRAModelRaw)
+                yield (lora_info.model, lora.weight)
                 del lora_info
             return

-        # loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
+        # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]

         ti_list = []
         for trigger in extract_ti_triggers_from_prompt(self.prompt):
             name = trigger[1:-1]
             try:
-                ti_list.append(
-                    (
-                        name,
-                        context.services.model_manager.get_model(
-                            model_name=name,
-                            base_model=self.clip.text_encoder.base_model,
-                            model_type=ModelType.TextualInversion,
-                            context=context,
-                        ).context.model,
-                    )
-                )
-            except ModelNotFoundException:
+                loaded_model = context.models.load(key=name).model
+                assert isinstance(loaded_model, TextualInversionModelRaw)
+                ti_list.append((name, loaded_model))
+            except UnknownModelException:
                 # print(e)
                 # import traceback
                 # print(traceback.format_exc())
                 print(f'Warn: trigger: "{trigger}" not found')

         with (
-            ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
+            ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
                 tokenizer,
                 ti_manager,
             ),
@@ -116,7 +104,7 @@ class CompelInvocation(BaseInvocation):
             # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
             ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
-            ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
+            ModelPatcher.apply_clip_skip(text_encoder_info.model, self.clip.skipped_layers),
         ):
             compel = Compel(
                 tokenizer=tokenizer,
@@ -128,7 +116,7 @@ class CompelInvocation(BaseInvocation):
             conjunction = Compel.parse_prompt_string(self.prompt)

-            if context.services.configuration.log_tokenization:
+            if context.config.get().log_tokenization:
                 log_tokenization_for_conjunction(conjunction, tokenizer)

             c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
@@ -149,17 +137,14 @@ class CompelInvocation(BaseInvocation):
             ]
         )

-        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
-        context.services.latents.save(conditioning_name, conditioning_data)
+        conditioning_name = context.conditioning.save(conditioning_data)

-        return ConditioningOutput(
-            conditioning=ConditioningField(
-                conditioning_name=conditioning_name,
-            ),
-        )
+        return ConditioningOutput.build(conditioning_name)


 class SDXLPromptInvocationBase:
+    """Prompt processor for SDXL models."""
+
     def run_clip_compel(
         self,
         context: InvocationContext,
@@ -168,26 +153,21 @@ class SDXLPromptInvocationBase:
         get_pooled: bool,
         lora_prefix: str,
         zero_on_empty: bool,
-    ):
-        tokenizer_info = context.services.model_manager.get_model(
-            **clip_field.tokenizer.model_dump(),
-            context=context,
-        )
-        text_encoder_info = context.services.model_manager.get_model(
-            **clip_field.text_encoder.model_dump(),
-            context=context,
-        )
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
+        tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump())
+        text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())

         # return zero on empty
         if prompt == "" and zero_on_empty:
-            cpu_text_encoder = text_encoder_info.context.model
+            cpu_text_encoder = text_encoder_info.model
+            assert isinstance(cpu_text_encoder, torch.nn.Module)
             c = torch.zeros(
                 (
                     1,
                     cpu_text_encoder.config.max_position_embeddings,
                     cpu_text_encoder.config.hidden_size,
                 ),
-                dtype=text_encoder_info.context.cache.precision,
+                dtype=cpu_text_encoder.dtype,
             )
             if get_pooled:
                 c_pooled = torch.zeros(
@@ -198,40 +178,36 @@ class SDXLPromptInvocationBase:
                 c_pooled = None
             return c, c_pooled, None

-        def _lora_loader():
+        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in clip_field.loras:
-                lora_info = context.services.model_manager.get_model(
-                    **lora.model_dump(exclude={"weight"}), context=context
-                )
-                yield (lora_info.context.model, lora.weight)
+                lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
+                lora_model = lora_info.model
+                assert isinstance(lora_model, LoRAModelRaw)
+                yield (lora_model, lora.weight)
                 del lora_info
             return

-        # loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
+        # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]

         ti_list = []
         for trigger in extract_ti_triggers_from_prompt(prompt):
             name = trigger[1:-1]
             try:
-                ti_list.append(
-                    (
-                        name,
-                        context.services.model_manager.get_model(
-                            model_name=name,
-                            base_model=clip_field.text_encoder.base_model,
-                            model_type=ModelType.TextualInversion,
-                            context=context,
-                        ).context.model,
-                    )
-                )
-            except ModelNotFoundException:
+                ti_model = context.models.load_by_attrs(
+                    model_name=name, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion
+                ).model
+                assert isinstance(ti_model, TextualInversionModelRaw)
+                ti_list.append((name, ti_model))
+            except UnknownModelException:
                 # print(e)
                 # import traceback
                 # print(traceback.format_exc())
-                print(f'Warn: trigger: "{trigger}" not found')
+                logger.warning(f'trigger: "{trigger}" not found')
+            except ValueError:
+                logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')

         with (
-            ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
+            ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
                 tokenizer,
                 ti_manager,
             ),
@@ -239,7 +215,7 @@ class SDXLPromptInvocationBase:
             # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
             ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
-            ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),
+            ModelPatcher.apply_clip_skip(text_encoder_info.model, clip_field.skipped_layers),
         ):
             compel = Compel(
                 tokenizer=tokenizer,
@@ -253,7 +229,7 @@ class SDXLPromptInvocationBase:
             conjunction = Compel.parse_prompt_string(prompt)

-            if context.services.configuration.log_tokenization:
+            if context.config.get().log_tokenization:
                 # TODO: better logging for and syntax
                 log_tokenization_for_conjunction(conjunction, tokenizer)
@@ -286,7 +262,7 @@ class SDXLPromptInvocationBase:
     title="SDXL Prompt",
     tags=["sdxl", "compel", "prompt"],
     category="conditioning",
-    version="1.0.0",
+    version="1.0.1",
 )
 class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     """Parse prompt using compel package to conditioning."""
@@ -357,6 +333,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
                 dim=1,
             )

+        assert c2_pooled is not None
         conditioning_data = ConditioningFieldData(
             conditionings=[
                 SDXLConditioningInfo(
@@ -368,14 +345,9 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
             ]
         )

-        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
-        context.services.latents.save(conditioning_name, conditioning_data)
+        conditioning_name = context.conditioning.save(conditioning_data)

-        return ConditioningOutput(
-            conditioning=ConditioningField(
-                conditioning_name=conditioning_name,
-            ),
-        )
+        return ConditioningOutput.build(conditioning_name)


 @invocation(
@@ -383,7 +355,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     title="SDXL Refiner Prompt",
     tags=["sdxl", "compel", "prompt"],
     category="conditioning",
-    version="1.0.0",
+    version="1.0.1",
 )
 class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     """Parse prompt using compel package to conditioning."""
@@ -410,6 +382,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
         add_time_ids = torch.tensor([original_size + crop_coords + (self.aesthetic_score,)])

+        assert c2_pooled is not None
         conditioning_data = ConditioningFieldData(
             conditionings=[
                 SDXLConditioningInfo(
@@ -421,14 +394,9 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
             ]
         )

-        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
-        context.services.latents.save(conditioning_name, conditioning_data)
+        conditioning_name = context.conditioning.save(conditioning_data)

-        return ConditioningOutput(
-            conditioning=ConditioningField(
-                conditioning_name=conditioning_name,
-            ),
-        )
+        return ConditioningOutput.build(conditioning_name)


 @invocation_output("clip_skip_output")
@@ -449,7 +417,7 @@ class ClipSkipInvocation(BaseInvocation):
     """Skip layers in clip text_encoder model."""

     clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
-    skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers)
+    skipped_layers: int = InputField(default=0, ge=0, description=FieldDescriptions.skipped_layers)

     def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
         self.clip.skipped_layers += self.skipped_layers
@@ -459,9 +427,9 @@ class ClipSkipInvocation(BaseInvocation):
 def get_max_token_count(
-    tokenizer,
+    tokenizer: CLIPTokenizer,
     prompt: Union[FlattenedPrompt, Blend, Conjunction],
-    truncate_if_too_long=False,
+    truncate_if_too_long: bool = False,
 ) -> int:
     if type(prompt) is Blend:
         blend: Blend = prompt
@@ -473,7 +441,9 @@ def get_max_token_count(
     return len(get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long))


-def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> List[str]:
+def get_tokens_for_prompt_object(
+    tokenizer: CLIPTokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long: bool = True
+) -> List[str]:
     if type(parsed_prompt) is Blend:
         raise ValueError("Blend is not supported here - you need to get tokens for each of its .children")
@@ -486,24 +456,29 @@ def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, trun
         for x in parsed_prompt.children
     ]
     text = " ".join(text_fragments)
-    tokens = tokenizer.tokenize(text)
+    tokens: List[str] = tokenizer.tokenize(text)
     if truncate_if_too_long:
         max_tokens_length = tokenizer.model_max_length - 2  # typically 75
         tokens = tokens[0:max_tokens_length]
     return tokens


-def log_tokenization_for_conjunction(c: Conjunction, tokenizer, display_label_prefix=None):
+def log_tokenization_for_conjunction(
+    c: Conjunction, tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None
+) -> None:
     display_label_prefix = display_label_prefix or ""
     for i, p in enumerate(c.prompts):
         if len(c.prompts) > 1:
             this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
         else:
+            assert display_label_prefix is not None
             this_display_label_prefix = display_label_prefix
         log_tokenization_for_prompt_object(p, tokenizer, display_label_prefix=this_display_label_prefix)


-def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None):
+def log_tokenization_for_prompt_object(
+    p: Union[Blend, FlattenedPrompt], tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None
+) -> None:
     display_label_prefix = display_label_prefix or ""
     if type(p) is Blend:
         blend: Blend = p
@@ -543,7 +518,12 @@ def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokeniz
         log_tokenization_for_text(text, tokenizer, display_label=display_label_prefix)


-def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
+def log_tokenization_for_text(
+    text: str,
+    tokenizer: CLIPTokenizer,
+    display_label: Optional[str] = None,
+    truncate_if_too_long: Optional[bool] = False,
+) -> None:
     """shows how the prompt is tokenized
     # usually tokens have '</w>' to indicate end-of-word,
     # but for readability it has been replaced with ' '