From 6fff7de2ab805b10165e6ff4f611808aa0c27eb6 Mon Sep 17 00:00:00 2001
From: maryhipp
Date: Tue, 27 Feb 2024 15:39:34 -0500
Subject: [PATCH] cleanup

---
 invokeai/app/api/routers/model_manager.py    |   4 -
 invokeai/app/invocations/compel.py           | 104 +++++++++---------
 .../model_metadata/metadata_store_sql.py     |   6 +-
 invokeai/backend/model_patcher.py            |   2 -
 4 files changed, 54 insertions(+), 62 deletions(-)

diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py
index 774f39909d..8b66c70618 100644
--- a/invokeai/app/api/routers/model_manager.py
+++ b/invokeai/app/api/routers/model_manager.py
@@ -261,13 +261,11 @@ async def update_model_metadata(
     changes: ModelMetadataChanges = Body(description="The changes")
 ) -> Optional[AnyModelRepoMetadata]:
     """Updates or creates a model metadata object."""
-    logger = ApiDependencies.invoker.services.logger
     record_store = ApiDependencies.invoker.services.model_manager.store
     metadata_store = ApiDependencies.invoker.services.model_manager.store.metadata_store
 
     try:
         original_metadata = record_store.get_metadata(key)
-        print(original_metadata)
         if original_metadata:
             original_metadata.trigger_phrases = changes.trigger_phrases
 
@@ -275,7 +273,6 @@
         else:
             metadata_store.add_metadata(key, BaseMetadata(name="", author="",trigger_phrases=changes.trigger_phrases))
     except Exception as e:
-        ApiDependencies.invoker.services.logger.error(traceback.format_exception(e))
         raise HTTPException(
             status_code=500,
             detail=f"An error occurred while updating the model metadata: {e}",
@@ -286,7 +283,6 @@
 
     return result
 
-
 @model_manager_router.get(
     "/tags",
     operation_id="list_tags",
diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index ff13658052..233643af7c 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -13,7 +13,6 @@ from invokeai.backend.lora import LoRAModelRaw
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
-    ConditioningFieldData,
     ExtraConditioningInfo,
     SDXLConditioningInfo,
 )
@@ -22,7 +21,11 @@ from invokeai.backend.util.devices import torch_dtype
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from .model import ClipField
 
-# unconditioned: Optional[torch.Tensor]
+
+@dataclass
+class ConditioningFieldData:
+    conditionings: List[BasicConditioningInfo]
+    # unconditioned: Optional[torch.Tensor]
 
 
 # class ConditioningAlgo(str, Enum):
@@ -36,7 +39,7 @@
     title="Prompt",
     tags=["prompt", "compel"],
     category="conditioning",
-    version="1.0.1",
+    version="1.0.0",
 )
 class CompelInvocation(BaseInvocation):
     """Parse prompt using compel package to conditioning."""
@@ -61,15 +64,16 @@
         text_encoder_model = text_encoder_info.model
         assert isinstance(text_encoder_model, CLIPTextModel)
 
-        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
+        def _lora_loader():
             for lora in self.clip.loras:
-                lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
-                assert isinstance(lora_info.model, LoRAModelRaw)
-                yield (lora_info.model, lora.weight)
+                lora_info = context.services.model_manager.get_model(
+                    **lora.model_dump(exclude={"weight"}), context=context
+                )
+                yield (lora_info.context.model, lora.weight)
                 del lora_info
             return
 
-        # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
+        # loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
 
         ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context)
 
@@ -95,7 +99,7 @@
             conjunction = Compel.parse_prompt_string(self.prompt)
 
-            if context.config.get().log_tokenization:
+            if context.services.configuration.log_tokenization:
                 log_tokenization_for_conjunction(conjunction, tokenizer)
 
             c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
 
@@ -116,14 +120,17 @@
             ]
         )
 
-        conditioning_name = context.conditioning.save(conditioning_data)
+        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
+        context.services.latents.save(conditioning_name, conditioning_data)
 
-        return ConditioningOutput.build(conditioning_name)
+        return ConditioningOutput(
+            conditioning=ConditioningField(
+                conditioning_name=conditioning_name,
+            ),
+        )
 
 
 class SDXLPromptInvocationBase:
-    """Prompt processor for SDXL models."""
-
     def run_clip_compel(
         self,
         context: InvocationContext,
@@ -142,15 +149,14 @@
 
         # return zero on empty
         if prompt == "" and zero_on_empty:
-            cpu_text_encoder = text_encoder_info.model
-            assert isinstance(cpu_text_encoder, torch.nn.Module)
+            cpu_text_encoder = text_encoder_info.context.model
             c = torch.zeros(
                 (
                     1,
                     cpu_text_encoder.config.max_position_embeddings,
                     cpu_text_encoder.config.hidden_size,
                 ),
-                dtype=cpu_text_encoder.dtype,
+                dtype=text_encoder_info.context.cache.precision,
             )
             if get_pooled:
                 c_pooled = torch.zeros(
@@ -161,16 +167,16 @@
                 c_pooled = None
             return c, c_pooled, None
 
-        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
+        def _lora_loader():
             for lora in clip_field.loras:
-                lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
-                lora_model = lora_info.model
-                assert isinstance(lora_model, LoRAModelRaw)
-                yield (lora_model, lora.weight)
+                lora_info = context.services.model_manager.get_model(
+                    **lora.model_dump(exclude={"weight"}), context=context
+                )
+                yield (lora_info.context.model, lora.weight)
                 del lora_info
             return
 
-        # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
+        # loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
 
         ti_list = generate_ti_list(prompt, text_encoder_info.config.base, context)
 
@@ -199,7 +205,7 @@
             conjunction = Compel.parse_prompt_string(prompt)
 
-            if context.config.get().log_tokenization:
+            if context.services.configuration.log_tokenization:
                 # TODO: better logging for and syntax
                 log_tokenization_for_conjunction(conjunction, tokenizer)
 
             c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
@@ -232,7 +238,7 @@
     title="SDXL Prompt",
     tags=["sdxl", "compel", "prompt"],
     category="conditioning",
-    version="1.0.1",
+    version="1.0.0",
 )
 class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     """Parse prompt using compel package to conditioning."""
@@ -303,7 +309,6 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
                 dim=1,
             )
 
-        assert c2_pooled is not None
         conditioning_data = ConditioningFieldData(
             conditionings=[
                 SDXLConditioningInfo(
@@ -315,9 +320,14 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
             ]
         )
 
-        conditioning_name = context.conditioning.save(conditioning_data)
+        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
+        context.services.latents.save(conditioning_name, conditioning_data)
 
-        return ConditioningOutput.build(conditioning_name)
+        return ConditioningOutput(
+            conditioning=ConditioningField(
+                conditioning_name=conditioning_name,
+            ),
+        )
 
 
 @invocation(
@@ -325,7 +335,7 @@
     title="SDXL Refiner Prompt",
     tags=["sdxl", "compel", "prompt"],
     category="conditioning",
-    version="1.0.1",
+    version="1.0.0",
 )
 class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     """Parse prompt using compel package to conditioning."""
@@ -352,7 +362,6 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
 
         add_time_ids = torch.tensor([original_size + crop_coords + (self.aesthetic_score,)])
 
-        assert c2_pooled is not None
         conditioning_data = ConditioningFieldData(
             conditionings=[
                 SDXLConditioningInfo(
@@ -364,9 +373,14 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
             ]
         )
 
-        conditioning_name = context.conditioning.save(conditioning_data)
+        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
+        context.services.latents.save(conditioning_name, conditioning_data)
 
-        return ConditioningOutput.build(conditioning_name)
+        return ConditioningOutput(
+            conditioning=ConditioningField(
+                conditioning_name=conditioning_name,
+            ),
+        )
 
 
 @invocation_output("clip_skip_output")
@@ -387,7 +401,7 @@ class ClipSkipInvocation(BaseInvocation):
     """Skip layers in clip text_encoder model."""
 
     clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
-    skipped_layers: int = InputField(default=0, ge=0, description=FieldDescriptions.skipped_layers)
+    skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers)
 
     def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
         self.clip.skipped_layers += self.skipped_layers
@@ -397,9 +411,9 @@
 
 
 def get_max_token_count(
-    tokenizer: CLIPTokenizer,
+    tokenizer,
     prompt: Union[FlattenedPrompt, Blend, Conjunction],
-    truncate_if_too_long: bool = False,
+    truncate_if_too_long=False,
 ) -> int:
     if type(prompt) is Blend:
         blend: Blend = prompt
@@ -411,9 +425,7 @@
     return len(get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long))
 
 
-def get_tokens_for_prompt_object(
-    tokenizer: CLIPTokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long: bool = True
-) -> List[str]:
+def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> List[str]:
     if type(parsed_prompt) is Blend:
         raise ValueError("Blend is not supported here - you need to get tokens for each of its .children")
 
@@ -426,29 +438,24 @@
         for x in parsed_prompt.children
     ]
     text = " ".join(text_fragments)
-    tokens: List[str] = tokenizer.tokenize(text)
+    tokens = tokenizer.tokenize(text)
     if truncate_if_too_long:
         max_tokens_length = tokenizer.model_max_length - 2  # typically 75
         tokens = tokens[0:max_tokens_length]
     return tokens
 
 
-def log_tokenization_for_conjunction(
-    c: Conjunction, tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None
-) -> None:
+def log_tokenization_for_conjunction(c: Conjunction, tokenizer, display_label_prefix=None):
     display_label_prefix = display_label_prefix or ""
     for i, p in enumerate(c.prompts):
         if len(c.prompts) > 1:
             this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
         else:
-            assert display_label_prefix is not None
             this_display_label_prefix = display_label_prefix
         log_tokenization_for_prompt_object(p, tokenizer, display_label_prefix=this_display_label_prefix)
 
 
-def log_tokenization_for_prompt_object(
-    p: Union[Blend, FlattenedPrompt], tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None
-) -> None:
+def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None):
     display_label_prefix = display_label_prefix or ""
     if type(p) is Blend:
         blend: Blend = p
@@ -488,12 +495,7 @@ def log_tokenization_for_prompt_object(
         log_tokenization_for_text(text, tokenizer, display_label=display_label_prefix)
 
 
-def log_tokenization_for_text(
-    text: str,
-    tokenizer: CLIPTokenizer,
-    display_label: Optional[str] = None,
-    truncate_if_too_long: Optional[bool] = False,
-) -> None:
+def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
     """shows how the prompt is tokenized
     # usually tokens have '</w>' to indicate end-of-word,
     # but for readability it has been replaced with ' '
diff --git a/invokeai/app/services/model_metadata/metadata_store_sql.py b/invokeai/app/services/model_metadata/metadata_store_sql.py
index 9d9057d0c5..849130e363 100644
--- a/invokeai/app/services/model_metadata/metadata_store_sql.py
+++ b/invokeai/app/services/model_metadata/metadata_store_sql.py
@@ -38,8 +38,6 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
         :param metadata: ModelRepoMetadata object to store
         """
         json_serialized = metadata.model_dump_json()
-        print("json_serialized")
-        print(json_serialized)
         with self._db.lock:
             try:
                 self._cursor.execute(
@@ -55,7 +53,7 @@
                     json_serialized,
                 ),
             )
-            # self._update_tags(model_key, metadata.tags)
+            self._update_tags(model_key, metadata.tags)
             self._db.conn.commit()
         except sqlite3.IntegrityError as excp:  # FOREIGN KEY error: the key was not in model_config table
             self._db.conn.rollback()
@@ -63,8 +61,6 @@
         except sqlite3.Error as excp:
             self._db.conn.rollback()
             raise excp
-        except Exception as e:
-            raise e
 
     def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
         """Retrieve the ModelRepoMetadata corresponding to model key."""
diff --git a/invokeai/backend/model_patcher.py b/invokeai/backend/model_patcher.py
index 7887802598..76271fc025 100644
--- a/invokeai/backend/model_patcher.py
+++ b/invokeai/backend/model_patcher.py
@@ -172,8 +172,6 @@ class ModelPatcher:
         text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection],
         ti_list: List[Tuple[str, TextualInversionModelRaw]],
     ) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]:
-        print("TI LIST")
-        print(ti_list)
         init_tokens_count = None
         new_tokens_added = None