From 600b4c6a909ddc817090efd0da84f3423a742cb3 Mon Sep 17 00:00:00 2001
From: maryhipp
Date: Tue, 27 Feb 2024 15:41:03 -0500
Subject: [PATCH] what have i done

---
 invokeai/app/invocations/compel.py | 104 ++++++++++++++---------------
 1 file changed, 51 insertions(+), 53 deletions(-)

diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index 233643af7c..ff13658052 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -13,6 +13,7 @@ from invokeai.backend.lora import LoRAModelRaw
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
+    ConditioningFieldData,
     ExtraConditioningInfo,
     SDXLConditioningInfo,
 )
@@ -21,11 +22,7 @@ from invokeai.backend.util.devices import torch_dtype
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from .model import ClipField
 
-
-@dataclass
-class ConditioningFieldData:
-    conditionings: List[BasicConditioningInfo]
-    # unconditioned: Optional[torch.Tensor]
+# unconditioned: Optional[torch.Tensor]
 
 
 # class ConditioningAlgo(str, Enum):
@@ -39,7 +36,7 @@ class ConditioningFieldData:
     title="Prompt",
     tags=["prompt", "compel"],
     category="conditioning",
-    version="1.0.0",
+    version="1.0.1",
 )
 class CompelInvocation(BaseInvocation):
     """Parse prompt using compel package to conditioning."""
@@ -64,16 +61,15 @@ class CompelInvocation(BaseInvocation):
         text_encoder_model = text_encoder_info.model
         assert isinstance(text_encoder_model, CLIPTextModel)
 
-        def _lora_loader():
+        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in self.clip.loras:
-                lora_info = context.services.model_manager.get_model(
-                    **lora.model_dump(exclude={"weight"}), context=context
-                )
-                yield (lora_info.context.model, lora.weight)
+                lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
+                assert isinstance(lora_info.model, LoRAModelRaw)
+                yield (lora_info.model, lora.weight)
                 del lora_info
             return
 
-        # loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
+        # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
 
         ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context)
 
@@ -99,7 +95,7 @@ class CompelInvocation(BaseInvocation):
 
         conjunction = Compel.parse_prompt_string(self.prompt)
 
-        if context.services.configuration.log_tokenization:
+        if context.config.get().log_tokenization:
             log_tokenization_for_conjunction(conjunction, tokenizer)
 
         c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
@@ -120,17 +116,14 @@ class CompelInvocation(BaseInvocation):
             ]
         )
 
-        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
-        context.services.latents.save(conditioning_name, conditioning_data)
+        conditioning_name = context.conditioning.save(conditioning_data)
 
-        return ConditioningOutput(
-            conditioning=ConditioningField(
-                conditioning_name=conditioning_name,
-            ),
-        )
+        return ConditioningOutput.build(conditioning_name)
 
 
 class SDXLPromptInvocationBase:
+    """Prompt processor for SDXL models."""
+
     def run_clip_compel(
         self,
         context: InvocationContext,
@@ -149,14 +142,15 @@ class SDXLPromptInvocationBase:
 
         # return zero on empty
         if prompt == "" and zero_on_empty:
-            cpu_text_encoder = text_encoder_info.context.model
+            cpu_text_encoder = text_encoder_info.model
+            assert isinstance(cpu_text_encoder, torch.nn.Module)
             c = torch.zeros(
                 (
                     1,
                     cpu_text_encoder.config.max_position_embeddings,
                     cpu_text_encoder.config.hidden_size,
                 ),
-                dtype=text_encoder_info.context.cache.precision,
+                dtype=cpu_text_encoder.dtype,
             )
             if get_pooled:
                 c_pooled = torch.zeros(
@@ -167,16 +161,16 @@ class SDXLPromptInvocationBase:
                 c_pooled = None
             return c, c_pooled, None
 
-        def _lora_loader():
+        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in clip_field.loras:
-                lora_info = context.services.model_manager.get_model(
-                    **lora.model_dump(exclude={"weight"}), context=context
-                )
-                yield (lora_info.context.model, lora.weight)
+                lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
+                lora_model = lora_info.model
+                assert isinstance(lora_model, LoRAModelRaw)
+                yield (lora_model, lora.weight)
                 del lora_info
             return
 
-        # loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
+        # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
 
         ti_list = generate_ti_list(prompt, text_encoder_info.config.base, context)
 
@@ -205,7 +199,7 @@ class SDXLPromptInvocationBase:
 
         conjunction = Compel.parse_prompt_string(prompt)
 
-        if context.services.configuration.log_tokenization:
+        if context.config.get().log_tokenization:
             # TODO: better logging for and syntax
             log_tokenization_for_conjunction(conjunction, tokenizer)
 
@@ -238,7 +232,7 @@ class SDXLPromptInvocationBase:
     title="SDXL Prompt",
     tags=["sdxl", "compel", "prompt"],
     category="conditioning",
-    version="1.0.0",
+    version="1.0.1",
 )
 class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     """Parse prompt using compel package to conditioning."""
@@ -309,6 +303,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
                 dim=1,
             )
 
+        assert c2_pooled is not None
         conditioning_data = ConditioningFieldData(
             conditionings=[
                 SDXLConditioningInfo(
@@ -320,14 +315,9 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
             ]
         )
 
-        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
-        context.services.latents.save(conditioning_name, conditioning_data)
+        conditioning_name = context.conditioning.save(conditioning_data)
 
-        return ConditioningOutput(
-            conditioning=ConditioningField(
-                conditioning_name=conditioning_name,
-            ),
-        )
+        return ConditioningOutput.build(conditioning_name)
 
 
 @invocation(
@@ -335,7 +325,7 @@
     title="SDXL Refiner Prompt",
     tags=["sdxl", "compel", "prompt"],
     category="conditioning",
-    version="1.0.0",
+    version="1.0.1",
 )
 class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     """Parse prompt using compel package to conditioning."""
@@ -362,6 +352,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
 
         add_time_ids = torch.tensor([original_size + crop_coords + (self.aesthetic_score,)])
 
+        assert c2_pooled is not None
        conditioning_data = ConditioningFieldData(
             conditionings=[
                 SDXLConditioningInfo(
@@ -373,14 +364,9 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
             ]
         )
 
-        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
-        context.services.latents.save(conditioning_name, conditioning_data)
+        conditioning_name = context.conditioning.save(conditioning_data)
 
-        return ConditioningOutput(
-            conditioning=ConditioningField(
-                conditioning_name=conditioning_name,
-            ),
-        )
+        return ConditioningOutput.build(conditioning_name)
 
 
 @invocation_output("clip_skip_output")
@@ -401,7 +387,7 @@ class ClipSkipInvocation(BaseInvocation):
     """Skip layers in clip text_encoder model."""
 
     clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
-    skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers)
+    skipped_layers: int = InputField(default=0, ge=0, description=FieldDescriptions.skipped_layers)
 
     def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
         self.clip.skipped_layers += self.skipped_layers
@@ -411,9 +397,9 @@
 
 
 def get_max_token_count(
-    tokenizer,
+    tokenizer: CLIPTokenizer,
     prompt: Union[FlattenedPrompt, Blend, Conjunction],
-    truncate_if_too_long=False,
+    truncate_if_too_long: bool = False,
 ) -> int:
     if type(prompt) is Blend:
         blend: Blend = prompt
@@ -425,7 +411,9 @@ def get_max_token_count(
     return len(get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long))
 
 
-def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> List[str]:
+def get_tokens_for_prompt_object(
+    tokenizer: CLIPTokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long: bool = True
+) -> List[str]:
     if type(parsed_prompt) is Blend:
         raise ValueError("Blend is not supported here - you need to get tokens for each of its .children")
 
@@ -438,24 +426,29 @@ def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, trun
         for x in parsed_prompt.children
     ]
     text = " ".join(text_fragments)
-    tokens = tokenizer.tokenize(text)
+    tokens: List[str] = tokenizer.tokenize(text)
     if truncate_if_too_long:
         max_tokens_length = tokenizer.model_max_length - 2  # typically 75
        tokens = tokens[0:max_tokens_length]
     return tokens
 
 
-def log_tokenization_for_conjunction(c: Conjunction, tokenizer, display_label_prefix=None):
+def log_tokenization_for_conjunction(
+    c: Conjunction, tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None
+) -> None:
     display_label_prefix = display_label_prefix or ""
     for i, p in enumerate(c.prompts):
         if len(c.prompts) > 1:
             this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
         else:
+            assert display_label_prefix is not None
             this_display_label_prefix = display_label_prefix
         log_tokenization_for_prompt_object(p, tokenizer, display_label_prefix=this_display_label_prefix)
 
 
-def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None):
+def log_tokenization_for_prompt_object(
+    p: Union[Blend, FlattenedPrompt], tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None
+) -> None:
     display_label_prefix = display_label_prefix or ""
     if type(p) is Blend:
         blend: Blend = p
@@ -495,7 +488,12 @@ def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokeniz
         log_tokenization_for_text(text, tokenizer, display_label=display_label_prefix)
 
 
-def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
+def log_tokenization_for_text(
+    text: str,
+    tokenizer: CLIPTokenizer,
+    display_label: Optional[str] = None,
+    truncate_if_too_long: Optional[bool] = False,
+) -> None:
     """shows how the prompt is tokenized
     # usually tokens have '</w>' to indicate end-of-word,
     # but for readability it has been replaced with ' '
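
------
Reviewer note, not part of the patch: every hunk above applies the same migration
onto the new InvocationContext surface. The sketch below collects that pattern in
one place so the hunks can be checked against it. It is a minimal sketch, assuming
only the names visible in this diff (context.models.load, context.config.get,
context.conditioning.save, ConditioningOutput.build); the surrounding invoke()
frame and the elided tensor math are illustrative, not code from the repository.

    # Post-patch shape of a compel invocation (illustrative sketch).
    from typing import Iterator, Tuple

    from invokeai.backend.lora import LoRAModelRaw
    from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
        ConditioningFieldData,  # now imported from the backend, not defined locally
    )

    def invoke(self, context: "InvocationContext") -> "ConditioningOutput":
        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
            # Loading goes through context.models.load(); the returned wrapper
            # exposes the raw model as .model (previously .context.model).
            for lora in self.clip.loras:
                lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
                assert isinstance(lora_info.model, LoRAModelRaw)
                yield (lora_info.model, lora.weight)
                del lora_info

        # App config is read via context.config.get() rather than
        # context.services.configuration.
        if context.config.get().log_tokenization:
            ...  # tokenization logging, unchanged by this patch

        # The conditioning store now issues the storage name itself; the manual
        # f"{graph_execution_state_id}_{id}_conditioning" key is gone, and the
        # output is built by a helper instead of nested field constructors.
        conditioning_data = ConditioningFieldData(conditionings=[...])
        conditioning_name = context.conditioning.save(conditioning_data)
        return ConditioningOutput.build(conditioning_name)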