From 16b3718d6ab81e9f0e26d55e6455cc6bb29e4805 Mon Sep 17 00:00:00 2001
From: maryhipp
Date: Tue, 27 Feb 2024 15:39:34 -0500
Subject: [PATCH] cleanup

---
 invokeai/app/api/routers/model_manager.py    |   4 -
 invokeai/app/invocations/compel.py           | 224 +++++++++---------
 .../model_metadata/metadata_store_sql.py     |   6 +-
 invokeai/backend/model_patcher.py            |   2 -
 4 files changed, 117 insertions(+), 119 deletions(-)

diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py
index 774f39909d..8b66c70618 100644
--- a/invokeai/app/api/routers/model_manager.py
+++ b/invokeai/app/api/routers/model_manager.py
@@ -261,13 +261,11 @@ async def update_model_metadata(
     changes: ModelMetadataChanges = Body(description="The changes")
 ) -> Optional[AnyModelRepoMetadata]:
     """Updates or creates a model metadata object."""
-    logger = ApiDependencies.invoker.services.logger
     record_store = ApiDependencies.invoker.services.model_manager.store
     metadata_store = ApiDependencies.invoker.services.model_manager.store.metadata_store
 
     try:
         original_metadata = record_store.get_metadata(key)
-        print(original_metadata)
         if original_metadata:
             original_metadata.trigger_phrases = changes.trigger_phrases
 
@@ -275,7 +273,6 @@ async def update_model_metadata(
         else:
             metadata_store.add_metadata(key, BaseMetadata(name="", author="",trigger_phrases=changes.trigger_phrases))
     except Exception as e:
-        ApiDependencies.invoker.services.logger.error(traceback.format_exception(e))
         raise HTTPException(
             status_code=500,
             detail=f"An error occurred while updating the model metadata: {e}",
@@ -286,7 +283,6 @@ async def update_model_metadata(
     return result
 
 
-
 @model_manager_router.get(
     "/tags",
     operation_id="list_tags",
diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index 0d558ec898..49c62cff56 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -1,43 +1,40 @@
-from typing import Iterator, List, Optional, Tuple, Union
+from dataclasses import dataclass
+from typing import List, Optional, Union
 
 import torch
 from compel import Compel, ReturnedEmbeddingsType
 from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
-from transformers import CLIPTokenizer
 
-import invokeai.backend.util.logging as logger
-from invokeai.app.invocations.fields import (
-    FieldDescriptions,
-    Input,
-    InputField,
-    OutputField,
-    UIComponent,
-)
-from invokeai.app.invocations.primitives import ConditioningOutput
-from invokeai.app.services.model_records import UnknownModelException
-from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.app.util.ti_utils import extract_ti_triggers_from_prompt
-from invokeai.backend.lora import LoRAModelRaw
-from invokeai.backend.model_manager import ModelType
-from invokeai.backend.model_patcher import ModelPatcher
+from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput
+from invokeai.app.shared.fields import FieldDescriptions
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
-    ConditioningFieldData,
     ExtraConditioningInfo,
     SDXLConditioningInfo,
 )
-from invokeai.backend.textual_inversion import TextualInversionModelRaw
-from invokeai.backend.util.devices import torch_dtype
 
+from ...backend.model_management.lora import ModelPatcher
+from ...backend.model_management.models import ModelNotFoundException, ModelType
+from ...backend.util.devices import torch_dtype
+from ..util.ti_utils import extract_ti_triggers_from_prompt
 from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
+    Input,
+    InputField,
+    InvocationContext,
+    OutputField,
+    UIComponent,
     invocation,
     invocation_output,
 )
 from .model import ClipField
 
-# unconditioned: Optional[torch.Tensor]
+
+@dataclass
+class ConditioningFieldData:
+    conditionings: List[BasicConditioningInfo]
+    # unconditioned: Optional[torch.Tensor]
 
 
 # class ConditioningAlgo(str, Enum):
@@ -51,7 +48,7 @@ from .model import ClipField
     title="Prompt",
     tags=["prompt", "compel"],
     category="conditioning",
-    version="1.0.1",
+    version="1.0.0",
 )
 class CompelInvocation(BaseInvocation):
     """Parse prompt using compel package to conditioning."""
@@ -69,46 +66,49 @@ class CompelInvocation(BaseInvocation):
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
-        tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
-        text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
+        tokenizer_info = context.services.model_manager.get_model(
+            **self.clip.tokenizer.model_dump(),
+            context=context,
+        )
+        text_encoder_info = context.services.model_manager.get_model(
+            **self.clip.text_encoder.model_dump(),
+            context=context,
+        )
 
-        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
+        def _lora_loader():
             for lora in self.clip.loras:
-                lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
-                assert isinstance(lora_info.model, LoRAModelRaw)
-                yield (lora_info.model, lora.weight)
+                lora_info = context.services.model_manager.get_model(
+                    **lora.model_dump(exclude={"weight"}), context=context
+                )
+                yield (lora_info.context.model, lora.weight)
                 del lora_info
             return
 
-        # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
+        # loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
 
         ti_list = []
         for trigger in extract_ti_triggers_from_prompt(self.prompt):
-            name_or_key = trigger[1:-1]
-            print(f"name_or_key: {name_or_key}")
+            name = trigger[1:-1]
             try:
-                loaded_model = context.models.load(key=name_or_key)
-                model = loaded_model.model
-                print(model)
-                assert isinstance(model, TextualInversionModelRaw)
-                ti_list.append((name_or_key, model))
-            except UnknownModelException:
-                try:
-                    print(f"base: {text_encoder_info.config.base}")
-                    loaded_model = context.models.load_by_attrs(
-                        model_name=name_or_key, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion
+                ti_list.append(
+                    (
+                        name,
+                        context.services.model_manager.get_model(
+                            model_name=name,
+                            base_model=self.clip.text_encoder.base_model,
+                            model_type=ModelType.TextualInversion,
+                            context=context,
+                        ).context.model,
                     )
-                    model = loaded_model.model
-                    print(model)
-                    assert isinstance(model, TextualInversionModelRaw)
-                    ti_list.append((name_or_key, model))
-                except UnknownModelException:
-                    logger.warning(f'trigger: "{trigger}" not found')
-                except ValueError:
-                    logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
+                )
+            except ModelNotFoundException:
+                # print(e)
+                # import traceback
+                # print(traceback.format_exc())
+                print(f'Warn: trigger: "{trigger}" not found')
 
         with (
-            ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
+            ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
                 tokenizer,
                 ti_manager,
             ),
@@ -116,7 +116,7 @@ class CompelInvocation(BaseInvocation):
             # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
             ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
-            ModelPatcher.apply_clip_skip(text_encoder_info.model, self.clip.skipped_layers),
+            ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
         ):
             compel = Compel(
                 tokenizer=tokenizer,
@@ -128,7 +128,7 @@ class CompelInvocation(BaseInvocation):
 
             conjunction = Compel.parse_prompt_string(self.prompt)
 
-            if context.config.get().log_tokenization:
+            if context.services.configuration.log_tokenization:
                 log_tokenization_for_conjunction(conjunction, tokenizer)
 
             c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
@@ -149,14 +149,17 @@ class CompelInvocation(BaseInvocation):
             ]
         )
 
-        conditioning_name = context.conditioning.save(conditioning_data)
+        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
+        context.services.latents.save(conditioning_name, conditioning_data)
 
-        return ConditioningOutput.build(conditioning_name)
+        return ConditioningOutput(
+            conditioning=ConditioningField(
+                conditioning_name=conditioning_name,
+            ),
+        )
 
 
 class SDXLPromptInvocationBase:
-    """Prompt processor for SDXL models."""
-
     def run_clip_compel(
         self,
         context: InvocationContext,
@@ -165,21 +168,26 @@ class SDXLPromptInvocationBase:
         get_pooled: bool,
         lora_prefix: str,
         zero_on_empty: bool,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
-        tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump())
-        text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
+    ):
+        tokenizer_info = context.services.model_manager.get_model(
+            **clip_field.tokenizer.model_dump(),
+            context=context,
+        )
+        text_encoder_info = context.services.model_manager.get_model(
+            **clip_field.text_encoder.model_dump(),
+            context=context,
+        )
 
         # return zero on empty
         if prompt == "" and zero_on_empty:
-            cpu_text_encoder = text_encoder_info.model
-            assert isinstance(cpu_text_encoder, torch.nn.Module)
+            cpu_text_encoder = text_encoder_info.context.model
             c = torch.zeros(
                 (
                     1,
                     cpu_text_encoder.config.max_position_embeddings,
                     cpu_text_encoder.config.hidden_size,
                 ),
-                dtype=cpu_text_encoder.dtype,
+                dtype=text_encoder_info.context.cache.precision,
             )
             if get_pooled:
                 c_pooled = torch.zeros(
@@ -190,36 +198,40 @@ class SDXLPromptInvocationBase:
                 c_pooled = None
             return c, c_pooled, None
 
-        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
+        def _lora_loader():
             for lora in clip_field.loras:
-                lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
-                lora_model = lora_info.model
-                assert isinstance(lora_model, LoRAModelRaw)
-                yield (lora_model, lora.weight)
+                lora_info = context.services.model_manager.get_model(
+                    **lora.model_dump(exclude={"weight"}), context=context
+                )
+                yield (lora_info.context.model, lora.weight)
                 del lora_info
             return
 
-        # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
+        # loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
 
         ti_list = []
         for trigger in extract_ti_triggers_from_prompt(prompt):
             name = trigger[1:-1]
             try:
-                ti_model = context.models.load_by_attrs(
-                    model_name=name, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion
-                ).model
-                assert isinstance(ti_model, TextualInversionModelRaw)
-                ti_list.append((name, ti_model))
-            except UnknownModelException:
+                ti_list.append(
+                    (
+                        name,
+                        context.services.model_manager.get_model(
+                            model_name=name,
+                            base_model=clip_field.text_encoder.base_model,
+                            model_type=ModelType.TextualInversion,
+                            context=context,
+                        ).context.model,
+                    )
+                )
+            except ModelNotFoundException:
                 # print(e)
                 # import traceback
                 # print(traceback.format_exc())
-                logger.warning(f'trigger: "{trigger}" not found')
-            except ValueError:
-                logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
+                print(f'Warn: trigger: "{trigger}" not found')
 
         with (
-            ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
+            ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
                 tokenizer,
                 ti_manager,
             ),
@@ -227,7 +239,7 @@ class SDXLPromptInvocationBase:
             # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
             ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
-            ModelPatcher.apply_clip_skip(text_encoder_info.model, clip_field.skipped_layers),
+            ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),
         ):
             compel = Compel(
                 tokenizer=tokenizer,
@@ -241,7 +253,7 @@ class SDXLPromptInvocationBase:
 
             conjunction = Compel.parse_prompt_string(prompt)
 
-            if context.config.get().log_tokenization:
+            if context.services.configuration.log_tokenization:
                 # TODO: better logging for and syntax
                 log_tokenization_for_conjunction(conjunction, tokenizer)
 
@@ -274,7 +286,7 @@ class SDXLPromptInvocationBase:
     title="SDXL Prompt",
    tags=["sdxl", "compel", "prompt"],
     category="conditioning",
-    version="1.0.1",
+    version="1.0.0",
 )
 class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     """Parse prompt using compel package to conditioning."""
@@ -345,7 +357,6 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
                 dim=1,
             )
 
-        assert c2_pooled is not None
         conditioning_data = ConditioningFieldData(
             conditionings=[
                 SDXLConditioningInfo(
@@ -357,9 +368,14 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
             ]
         )
 
-        conditioning_name = context.conditioning.save(conditioning_data)
+        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
+        context.services.latents.save(conditioning_name, conditioning_data)
 
-        return ConditioningOutput.build(conditioning_name)
+        return ConditioningOutput(
+            conditioning=ConditioningField(
+                conditioning_name=conditioning_name,
+            ),
+        )
 
 
 @invocation(
@@ -367,7 +383,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     title="SDXL Refiner Prompt",
     tags=["sdxl", "compel", "prompt"],
     category="conditioning",
-    version="1.0.1",
+    version="1.0.0",
 )
 class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     """Parse prompt using compel package to conditioning."""
@@ -394,7 +410,6 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
 
         add_time_ids = torch.tensor([original_size + crop_coords + (self.aesthetic_score,)])
 
-        assert c2_pooled is not None
         conditioning_data = ConditioningFieldData(
             conditionings=[
                 SDXLConditioningInfo(
@@ -406,9 +421,14 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
             ]
         )
 
-        conditioning_name = context.conditioning.save(conditioning_data)
+        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
+        context.services.latents.save(conditioning_name, conditioning_data)
 
-        return ConditioningOutput.build(conditioning_name)
+        return ConditioningOutput(
+            conditioning=ConditioningField(
+                conditioning_name=conditioning_name,
+            ),
+        )
 
 
 @invocation_output("clip_skip_output")
@@ -429,7 +449,7 @@ class ClipSkipInvocation(BaseInvocation):
     """Skip layers in clip text_encoder model."""
 
     clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
-    skipped_layers: int = InputField(default=0, ge=0, description=FieldDescriptions.skipped_layers)
+    skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers)
 
     def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
         self.clip.skipped_layers += self.skipped_layers
@@ -439,9 +459,9 @@ class ClipSkipInvocation(BaseInvocation):
 
 
 def get_max_token_count(
-    tokenizer: CLIPTokenizer,
+    tokenizer,
     prompt: Union[FlattenedPrompt, Blend, Conjunction],
-    truncate_if_too_long: bool = False,
+    truncate_if_too_long=False,
 ) -> int:
     if type(prompt) is Blend:
         blend: Blend = prompt
@@ -453,9 +473,7 @@ def get_max_token_count(
     return len(get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long))
 
 
-def get_tokens_for_prompt_object(
-    tokenizer: CLIPTokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long: bool = True
-) -> List[str]:
+def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> List[str]:
     if type(parsed_prompt) is Blend:
         raise ValueError("Blend is not supported here - you need to get tokens for each of its .children")
 
@@ -468,29 +486,24 @@ def get_tokens_for_prompt_object(
         for x in parsed_prompt.children
     ]
     text = " ".join(text_fragments)
-    tokens: List[str] = tokenizer.tokenize(text)
+    tokens = tokenizer.tokenize(text)
     if truncate_if_too_long:
         max_tokens_length = tokenizer.model_max_length - 2  # typically 75
         tokens = tokens[0:max_tokens_length]
     return tokens
 
 
-def log_tokenization_for_conjunction(
-    c: Conjunction, tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None
-) -> None:
+def log_tokenization_for_conjunction(c: Conjunction, tokenizer, display_label_prefix=None):
     display_label_prefix = display_label_prefix or ""
     for i, p in enumerate(c.prompts):
         if len(c.prompts) > 1:
             this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
         else:
-            assert display_label_prefix is not None
             this_display_label_prefix = display_label_prefix
         log_tokenization_for_prompt_object(p, tokenizer, display_label_prefix=this_display_label_prefix)
 
 
-def log_tokenization_for_prompt_object(
-    p: Union[Blend, FlattenedPrompt], tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None
-) -> None:
+def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None):
     display_label_prefix = display_label_prefix or ""
     if type(p) is Blend:
         blend: Blend = p
@@ -530,12 +543,7 @@ def log_tokenization_for_prompt_object(
         log_tokenization_for_text(text, tokenizer, display_label=display_label_prefix)
 
 
-def log_tokenization_for_text(
-    text: str,
-    tokenizer: CLIPTokenizer,
-    display_label: Optional[str] = None,
-    truncate_if_too_long: Optional[bool] = False,
-) -> None:
+def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
     """shows how the prompt is tokenized
     # usually tokens have '</w>' to indicate end-of-word,
     # but for readability it has been replaced with ' '
diff --git a/invokeai/app/services/model_metadata/metadata_store_sql.py b/invokeai/app/services/model_metadata/metadata_store_sql.py
index 9d9057d0c5..849130e363 100644
--- a/invokeai/app/services/model_metadata/metadata_store_sql.py
+++ b/invokeai/app/services/model_metadata/metadata_store_sql.py
@@ -38,8 +38,6 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
         :param metadata: ModelRepoMetadata object to store
         """
         json_serialized = metadata.model_dump_json()
-        print("json_serialized")
-        print(json_serialized)
         with self._db.lock:
             try:
                 self._cursor.execute(
@@ -55,7 +53,7 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
                         json_serialized,
                     ),
                 )
-                # self._update_tags(model_key, metadata.tags)
+                self._update_tags(model_key, metadata.tags)
                 self._db.conn.commit()
             except sqlite3.IntegrityError as excp:  # FOREIGN KEY error: the key was not in model_config table
                 self._db.conn.rollback()
@@ -63,8 +61,6 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
             except sqlite3.Error as excp:
                 self._db.conn.rollback()
                 raise excp
-            except Exception as e:
-                raise e
 
     def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
         """Retrieve the ModelRepoMetadata corresponding to model key."""
diff --git a/invokeai/backend/model_patcher.py b/invokeai/backend/model_patcher.py
index 87f10e4adc..bee8909c31 100644
--- a/invokeai/backend/model_patcher.py
+++ b/invokeai/backend/model_patcher.py
@@ -171,8 +171,6 @@ class ModelPatcher:
         text_encoder: CLIPTextModel,
         ti_list: List[Tuple[str, TextualInversionModelRaw]],
     ) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]:
-        print("TI LIST")
-        print(ti_list)
         init_tokens_count = None
         new_tokens_added = None
 