This commit is contained in:
maryhipp 2024-02-27 15:39:34 -05:00
parent 30228ce2a4
commit 16b3718d6a
4 changed files with 117 additions and 119 deletions

View File

@@ -261,13 +261,11 @@ async def update_model_metadata(
     changes: ModelMetadataChanges = Body(description="The changes")
 ) -> Optional[AnyModelRepoMetadata]:
     """Updates or creates a model metadata object."""
-    logger = ApiDependencies.invoker.services.logger
     record_store = ApiDependencies.invoker.services.model_manager.store
     metadata_store = ApiDependencies.invoker.services.model_manager.store.metadata_store

     try:
         original_metadata = record_store.get_metadata(key)
-        print(original_metadata)
         if original_metadata:
             original_metadata.trigger_phrases = changes.trigger_phrases
@@ -275,7 +273,6 @@ async def update_model_metadata(
         else:
             metadata_store.add_metadata(key, BaseMetadata(name="", author="",trigger_phrases=changes.trigger_phrases))
     except Exception as e:
-        ApiDependencies.invoker.services.logger.error(traceback.format_exception(e))
         raise HTTPException(
             status_code=500,
             detail=f"An error occurred while updating the model metadata: {e}",
@@ -286,7 +283,6 @@ async def update_model_metadata(
     return result


 @model_manager_router.get(
     "/tags",
     operation_id="list_tags",

View File

@@ -1,43 +1,40 @@
-from typing import Iterator, List, Optional, Tuple, Union
+from dataclasses import dataclass
+from typing import List, Optional, Union

 import torch
 from compel import Compel, ReturnedEmbeddingsType
 from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
-from transformers import CLIPTokenizer

-import invokeai.backend.util.logging as logger
-from invokeai.app.invocations.fields import (
-    FieldDescriptions,
-    Input,
-    InputField,
-    OutputField,
-    UIComponent,
-)
-from invokeai.app.invocations.primitives import ConditioningOutput
-from invokeai.app.services.model_records import UnknownModelException
-from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.app.util.ti_utils import extract_ti_triggers_from_prompt
-from invokeai.backend.lora import LoRAModelRaw
-from invokeai.backend.model_manager import ModelType
-from invokeai.backend.model_patcher import ModelPatcher
+from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput
+from invokeai.app.shared.fields import FieldDescriptions
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
-    ConditioningFieldData,
     ExtraConditioningInfo,
     SDXLConditioningInfo,
 )
-from invokeai.backend.textual_inversion import TextualInversionModelRaw
-from invokeai.backend.util.devices import torch_dtype

+from ...backend.model_management.lora import ModelPatcher
+from ...backend.model_management.models import ModelNotFoundException, ModelType
+from ...backend.util.devices import torch_dtype
+from ..util.ti_utils import extract_ti_triggers_from_prompt
 from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
+    Input,
+    InputField,
+    InvocationContext,
+    OutputField,
+    UIComponent,
     invocation,
     invocation_output,
 )
 from .model import ClipField

-# unconditioned: Optional[torch.Tensor]
+
+@dataclass
+class ConditioningFieldData:
+    conditionings: List[BasicConditioningInfo]
+    # unconditioned: Optional[torch.Tensor]


 # class ConditioningAlgo(str, Enum):
@@ -51,7 +48,7 @@ from .model import ClipField
     title="Prompt",
     tags=["prompt", "compel"],
     category="conditioning",
-    version="1.0.1",
+    version="1.0.0",
 )
 class CompelInvocation(BaseInvocation):
     """Parse prompt using compel package to conditioning."""
@@ -69,46 +66,49 @@ class CompelInvocation(BaseInvocation):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
-        tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
-        text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
+        tokenizer_info = context.services.model_manager.get_model(
+            **self.clip.tokenizer.model_dump(),
+            context=context,
+        )
+        text_encoder_info = context.services.model_manager.get_model(
+            **self.clip.text_encoder.model_dump(),
+            context=context,
+        )

-        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
+        def _lora_loader():
             for lora in self.clip.loras:
-                lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
-                assert isinstance(lora_info.model, LoRAModelRaw)
-                yield (lora_info.model, lora.weight)
+                lora_info = context.services.model_manager.get_model(
+                    **lora.model_dump(exclude={"weight"}), context=context
+                )
+                yield (lora_info.context.model, lora.weight)
                 del lora_info
             return

-        # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
+        # loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]

         ti_list = []
         for trigger in extract_ti_triggers_from_prompt(self.prompt):
-            name_or_key = trigger[1:-1]
-            print(f"name_or_key: {name_or_key}")
+            name = trigger[1:-1]
             try:
-                loaded_model = context.models.load(key=name_or_key)
-                model = loaded_model.model
-                print(model)
-                assert isinstance(model, TextualInversionModelRaw)
-                ti_list.append((name_or_key, model))
-            except UnknownModelException:
-                try:
-                    print(f"base: {text_encoder_info.config.base}")
-                    loaded_model = context.models.load_by_attrs(
-                        model_name=name_or_key, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion
-                    )
-                    model = loaded_model.model
-                    print(model)
-                    assert isinstance(model, TextualInversionModelRaw)
-                    ti_list.append((name_or_key, model))
-                except UnknownModelException:
-                    logger.warning(f'trigger: "{trigger}" not found')
-            except ValueError:
-                logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
+                ti_list.append(
+                    (
+                        name,
+                        context.services.model_manager.get_model(
+                            model_name=name,
+                            base_model=self.clip.text_encoder.base_model,
+                            model_type=ModelType.TextualInversion,
+                            context=context,
+                        ).context.model,
+                    )
+                )
+            except ModelNotFoundException:
+                # print(e)
+                # import traceback
+                # print(traceback.format_exc())
+                print(f'Warn: trigger: "{trigger}" not found')

         with (
-            ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
+            ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
                 tokenizer,
                 ti_manager,
             ),
@@ -116,7 +116,7 @@ class CompelInvocation(BaseInvocation):
             # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
             ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
-            ModelPatcher.apply_clip_skip(text_encoder_info.model, self.clip.skipped_layers),
+            ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
         ):
             compel = Compel(
                 tokenizer=tokenizer,
@@ -128,7 +128,7 @@ class CompelInvocation(BaseInvocation):
             conjunction = Compel.parse_prompt_string(self.prompt)

-            if context.config.get().log_tokenization:
+            if context.services.configuration.log_tokenization:
                 log_tokenization_for_conjunction(conjunction, tokenizer)

             c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
@@ -149,14 +149,17 @@ class CompelInvocation(BaseInvocation):
             ]
         )

-        conditioning_name = context.conditioning.save(conditioning_data)
+        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
+        context.services.latents.save(conditioning_name, conditioning_data)

-        return ConditioningOutput.build(conditioning_name)
+        return ConditioningOutput(
+            conditioning=ConditioningField(
+                conditioning_name=conditioning_name,
+            ),
+        )


 class SDXLPromptInvocationBase:
-    """Prompt processor for SDXL models."""

     def run_clip_compel(
         self,
         context: InvocationContext,
@@ -165,21 +168,26 @@ class SDXLPromptInvocationBase:
         get_pooled: bool,
         lora_prefix: str,
         zero_on_empty: bool,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
-        tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump())
-        text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
+    ):
+        tokenizer_info = context.services.model_manager.get_model(
+            **clip_field.tokenizer.model_dump(),
+            context=context,
+        )
+        text_encoder_info = context.services.model_manager.get_model(
+            **clip_field.text_encoder.model_dump(),
+            context=context,
+        )

         # return zero on empty
         if prompt == "" and zero_on_empty:
-            cpu_text_encoder = text_encoder_info.model
-            assert isinstance(cpu_text_encoder, torch.nn.Module)
+            cpu_text_encoder = text_encoder_info.context.model
             c = torch.zeros(
                 (
                     1,
                     cpu_text_encoder.config.max_position_embeddings,
                     cpu_text_encoder.config.hidden_size,
                 ),
-                dtype=cpu_text_encoder.dtype,
+                dtype=text_encoder_info.context.cache.precision,
             )
             if get_pooled:
                 c_pooled = torch.zeros(
@@ -190,36 +198,40 @@ class SDXLPromptInvocationBase:
                 c_pooled = None
             return c, c_pooled, None

-        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
+        def _lora_loader():
             for lora in clip_field.loras:
-                lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
-                lora_model = lora_info.model
-                assert isinstance(lora_model, LoRAModelRaw)
-                yield (lora_model, lora.weight)
+                lora_info = context.services.model_manager.get_model(
+                    **lora.model_dump(exclude={"weight"}), context=context
+                )
+                yield (lora_info.context.model, lora.weight)
                 del lora_info
             return

-        # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
+        # loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]

         ti_list = []
         for trigger in extract_ti_triggers_from_prompt(prompt):
             name = trigger[1:-1]
             try:
-                ti_model = context.models.load_by_attrs(
-                    model_name=name, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion
-                ).model
-                assert isinstance(ti_model, TextualInversionModelRaw)
-                ti_list.append((name, ti_model))
-            except UnknownModelException:
+                ti_list.append(
+                    (
+                        name,
+                        context.services.model_manager.get_model(
+                            model_name=name,
+                            base_model=clip_field.text_encoder.base_model,
+                            model_type=ModelType.TextualInversion,
+                            context=context,
+                        ).context.model,
+                    )
+                )
+            except ModelNotFoundException:
                 # print(e)
                 # import traceback
                 # print(traceback.format_exc())
-                logger.warning(f'trigger: "{trigger}" not found')
-            except ValueError:
-                logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
+                print(f'Warn: trigger: "{trigger}" not found')

         with (
-            ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
+            ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
                 tokenizer,
                 ti_manager,
             ),
@@ -227,7 +239,7 @@ class SDXLPromptInvocationBase:
             # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
             ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
-            ModelPatcher.apply_clip_skip(text_encoder_info.model, clip_field.skipped_layers),
+            ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),
         ):
             compel = Compel(
                 tokenizer=tokenizer,
@@ -241,7 +253,7 @@ class SDXLPromptInvocationBase:
             conjunction = Compel.parse_prompt_string(prompt)

-            if context.config.get().log_tokenization:
+            if context.services.configuration.log_tokenization:
                 # TODO: better logging for and syntax
                 log_tokenization_for_conjunction(conjunction, tokenizer)
@@ -274,7 +286,7 @@ class SDXLPromptInvocationBase:
     title="SDXL Prompt",
     tags=["sdxl", "compel", "prompt"],
     category="conditioning",
-    version="1.0.1",
+    version="1.0.0",
 )
 class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     """Parse prompt using compel package to conditioning."""
@@ -345,7 +357,6 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
                 dim=1,
             )

-        assert c2_pooled is not None
         conditioning_data = ConditioningFieldData(
             conditionings=[
                 SDXLConditioningInfo(
@@ -357,9 +368,14 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
             ]
         )

-        conditioning_name = context.conditioning.save(conditioning_data)
+        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
+        context.services.latents.save(conditioning_name, conditioning_data)

-        return ConditioningOutput.build(conditioning_name)
+        return ConditioningOutput(
+            conditioning=ConditioningField(
+                conditioning_name=conditioning_name,
+            ),
+        )


 @invocation(
@@ -367,7 +383,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     title="SDXL Refiner Prompt",
     tags=["sdxl", "compel", "prompt"],
     category="conditioning",
-    version="1.0.1",
+    version="1.0.0",
 )
 class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     """Parse prompt using compel package to conditioning."""
@@ -394,7 +410,6 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
         add_time_ids = torch.tensor([original_size + crop_coords + (self.aesthetic_score,)])

-        assert c2_pooled is not None
         conditioning_data = ConditioningFieldData(
             conditionings=[
                 SDXLConditioningInfo(
@@ -406,9 +421,14 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
             ]
         )

-        conditioning_name = context.conditioning.save(conditioning_data)
+        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
+        context.services.latents.save(conditioning_name, conditioning_data)

-        return ConditioningOutput.build(conditioning_name)
+        return ConditioningOutput(
+            conditioning=ConditioningField(
+                conditioning_name=conditioning_name,
+            ),
+        )


 @invocation_output("clip_skip_output")
@@ -429,7 +449,7 @@ class ClipSkipInvocation(BaseInvocation):
     """Skip layers in clip text_encoder model."""

     clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
-    skipped_layers: int = InputField(default=0, ge=0, description=FieldDescriptions.skipped_layers)
+    skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers)

     def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
         self.clip.skipped_layers += self.skipped_layers
@@ -439,9 +459,9 @@ class ClipSkipInvocation(BaseInvocation):
 def get_max_token_count(
-    tokenizer: CLIPTokenizer,
+    tokenizer,
     prompt: Union[FlattenedPrompt, Blend, Conjunction],
-    truncate_if_too_long: bool = False,
+    truncate_if_too_long=False,
 ) -> int:
     if type(prompt) is Blend:
         blend: Blend = prompt
@@ -453,9 +473,7 @@ def get_max_token_count(
         return len(get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long))


-def get_tokens_for_prompt_object(
-    tokenizer: CLIPTokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long: bool = True
-) -> List[str]:
+def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> List[str]:
     if type(parsed_prompt) is Blend:
         raise ValueError("Blend is not supported here - you need to get tokens for each of its .children")
@@ -468,29 +486,24 @@ def get_tokens_for_prompt_object(
         for x in parsed_prompt.children
     ]
     text = " ".join(text_fragments)
-    tokens: List[str] = tokenizer.tokenize(text)
+    tokens = tokenizer.tokenize(text)
     if truncate_if_too_long:
         max_tokens_length = tokenizer.model_max_length - 2  # typically 75
         tokens = tokens[0:max_tokens_length]
     return tokens


-def log_tokenization_for_conjunction(
-    c: Conjunction, tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None
-) -> None:
+def log_tokenization_for_conjunction(c: Conjunction, tokenizer, display_label_prefix=None):
     display_label_prefix = display_label_prefix or ""
     for i, p in enumerate(c.prompts):
         if len(c.prompts) > 1:
             this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
         else:
-            assert display_label_prefix is not None
             this_display_label_prefix = display_label_prefix
         log_tokenization_for_prompt_object(p, tokenizer, display_label_prefix=this_display_label_prefix)


-def log_tokenization_for_prompt_object(
-    p: Union[Blend, FlattenedPrompt], tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None
-) -> None:
+def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None):
     display_label_prefix = display_label_prefix or ""
     if type(p) is Blend:
         blend: Blend = p
@@ -530,12 +543,7 @@ def log_tokenization_for_prompt_object(
         log_tokenization_for_text(text, tokenizer, display_label=display_label_prefix)


-def log_tokenization_for_text(
-    text: str,
-    tokenizer: CLIPTokenizer,
-    display_label: Optional[str] = None,
-    truncate_if_too_long: Optional[bool] = False,
-) -> None:
+def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
     """shows how the prompt is tokenized
     # usually tokens have '</w>' to indicate end-of-word,
     # but for readability it has been replaced with ' '

View File

@@ -38,8 +38,6 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
         :param metadata: ModelRepoMetadata object to store
         """
         json_serialized = metadata.model_dump_json()
-        print("json_serialized")
-        print(json_serialized)
         with self._db.lock:
             try:
                 self._cursor.execute(
@@ -55,7 +53,7 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
                         json_serialized,
                     ),
                 )
-                # self._update_tags(model_key, metadata.tags)
+                self._update_tags(model_key, metadata.tags)
                 self._db.conn.commit()
             except sqlite3.IntegrityError as excp:  # FOREIGN KEY error: the key was not in model_config table
                 self._db.conn.rollback()
@@ -63,8 +61,6 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
             except sqlite3.Error as excp:
                 self._db.conn.rollback()
                 raise excp
-            except Exception as e:
-                raise e

     def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
         """Retrieve the ModelRepoMetadata corresponding to model key."""

View File

@@ -171,8 +171,6 @@ class ModelPatcher:
         text_encoder: CLIPTextModel,
         ti_list: List[Tuple[str, TextualInversionModelRaw]],
     ) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]:
-        print("TI LIST")
-        print(ti_list)
         init_tokens_count = None
         new_tokens_added = None