This commit is contained in:
maryhipp 2024-02-27 15:39:34 -05:00 committed by Kent Keirsey
parent a13dce18c7
commit 6fff7de2ab
4 changed files with 54 additions and 62 deletions

View File

@ -261,13 +261,11 @@ async def update_model_metadata(
changes: ModelMetadataChanges = Body(description="The changes")
) -> Optional[AnyModelRepoMetadata]:
"""Updates or creates a model metadata object."""
logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_manager.store
metadata_store = ApiDependencies.invoker.services.model_manager.store.metadata_store
try:
original_metadata = record_store.get_metadata(key)
print(original_metadata)
if original_metadata:
original_metadata.trigger_phrases = changes.trigger_phrases
@ -275,7 +273,6 @@ async def update_model_metadata(
else:
metadata_store.add_metadata(key, BaseMetadata(name="", author="",trigger_phrases=changes.trigger_phrases))
except Exception as e:
ApiDependencies.invoker.services.logger.error(traceback.format_exception(e))
raise HTTPException(
status_code=500,
detail=f"An error occurred while updating the model metadata: {e}",
@ -286,7 +283,6 @@ async def update_model_metadata(
return result
@model_manager_router.get(
"/tags",
operation_id="list_tags",

View File

@ -13,7 +13,6 @@ from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
ConditioningFieldData,
ExtraConditioningInfo,
SDXLConditioningInfo,
)
@ -22,7 +21,11 @@ from invokeai.backend.util.devices import torch_dtype
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from .model import ClipField
# unconditioned: Optional[torch.Tensor]
@dataclass
class ConditioningFieldData:
conditionings: List[BasicConditioningInfo]
# unconditioned: Optional[torch.Tensor]
# class ConditioningAlgo(str, Enum):
@ -36,7 +39,7 @@ from .model import ClipField
title="Prompt",
tags=["prompt", "compel"],
category="conditioning",
version="1.0.1",
version="1.0.0",
)
class CompelInvocation(BaseInvocation):
"""Parse prompt using compel package to conditioning."""
@ -61,15 +64,16 @@ class CompelInvocation(BaseInvocation):
text_encoder_model = text_encoder_info.model
assert isinstance(text_encoder_model, CLIPTextModel)
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
def _lora_loader():
for lora in self.clip.loras:
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
lora_info = context.services.model_manager.get_model(
**lora.model_dump(exclude={"weight"}), context=context
)
yield (lora_info.context.model, lora.weight)
del lora_info
return
# loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context)
@ -95,7 +99,7 @@ class CompelInvocation(BaseInvocation):
conjunction = Compel.parse_prompt_string(self.prompt)
if context.config.get().log_tokenization:
if context.services.configuration.log_tokenization:
log_tokenization_for_conjunction(conjunction, tokenizer)
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
@ -116,14 +120,17 @@ class CompelInvocation(BaseInvocation):
]
)
conditioning_name = context.conditioning.save(conditioning_data)
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
context.services.latents.save(conditioning_name, conditioning_data)
return ConditioningOutput.build(conditioning_name)
return ConditioningOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),
)
class SDXLPromptInvocationBase:
"""Prompt processor for SDXL models."""
def run_clip_compel(
self,
context: InvocationContext,
@ -142,15 +149,14 @@ class SDXLPromptInvocationBase:
# return zero on empty
if prompt == "" and zero_on_empty:
cpu_text_encoder = text_encoder_info.model
assert isinstance(cpu_text_encoder, torch.nn.Module)
cpu_text_encoder = text_encoder_info.context.model
c = torch.zeros(
(
1,
cpu_text_encoder.config.max_position_embeddings,
cpu_text_encoder.config.hidden_size,
),
dtype=cpu_text_encoder.dtype,
dtype=text_encoder_info.context.cache.precision,
)
if get_pooled:
c_pooled = torch.zeros(
@ -161,16 +167,16 @@ class SDXLPromptInvocationBase:
c_pooled = None
return c, c_pooled, None
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
def _lora_loader():
for lora in clip_field.loras:
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
lora_model = lora_info.model
assert isinstance(lora_model, LoRAModelRaw)
yield (lora_model, lora.weight)
lora_info = context.services.model_manager.get_model(
**lora.model_dump(exclude={"weight"}), context=context
)
yield (lora_info.context.model, lora.weight)
del lora_info
return
# loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
ti_list = generate_ti_list(prompt, text_encoder_info.config.base, context)
@ -199,7 +205,7 @@ class SDXLPromptInvocationBase:
conjunction = Compel.parse_prompt_string(prompt)
if context.config.get().log_tokenization:
if context.services.configuration.log_tokenization:
# TODO: better logging for and syntax
log_tokenization_for_conjunction(conjunction, tokenizer)
@ -232,7 +238,7 @@ class SDXLPromptInvocationBase:
title="SDXL Prompt",
tags=["sdxl", "compel", "prompt"],
category="conditioning",
version="1.0.1",
version="1.0.0",
)
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning."""
@ -303,7 +309,6 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
dim=1,
)
assert c2_pooled is not None
conditioning_data = ConditioningFieldData(
conditionings=[
SDXLConditioningInfo(
@ -315,9 +320,14 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
]
)
conditioning_name = context.conditioning.save(conditioning_data)
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
context.services.latents.save(conditioning_name, conditioning_data)
return ConditioningOutput.build(conditioning_name)
return ConditioningOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),
)
@invocation(
@ -325,7 +335,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
title="SDXL Refiner Prompt",
tags=["sdxl", "compel", "prompt"],
category="conditioning",
version="1.0.1",
version="1.0.0",
)
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning."""
@ -352,7 +362,6 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
add_time_ids = torch.tensor([original_size + crop_coords + (self.aesthetic_score,)])
assert c2_pooled is not None
conditioning_data = ConditioningFieldData(
conditionings=[
SDXLConditioningInfo(
@ -364,9 +373,14 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
]
)
conditioning_name = context.conditioning.save(conditioning_data)
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
context.services.latents.save(conditioning_name, conditioning_data)
return ConditioningOutput.build(conditioning_name)
return ConditioningOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),
)
@invocation_output("clip_skip_output")
@ -387,7 +401,7 @@ class ClipSkipInvocation(BaseInvocation):
"""Skip layers in clip text_encoder model."""
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
skipped_layers: int = InputField(default=0, ge=0, description=FieldDescriptions.skipped_layers)
skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers)
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
self.clip.skipped_layers += self.skipped_layers
@ -397,9 +411,9 @@ class ClipSkipInvocation(BaseInvocation):
def get_max_token_count(
tokenizer: CLIPTokenizer,
tokenizer,
prompt: Union[FlattenedPrompt, Blend, Conjunction],
truncate_if_too_long: bool = False,
truncate_if_too_long=False,
) -> int:
if type(prompt) is Blend:
blend: Blend = prompt
@ -411,9 +425,7 @@ def get_max_token_count(
return len(get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long))
def get_tokens_for_prompt_object(
tokenizer: CLIPTokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long: bool = True
) -> List[str]:
def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> List[str]:
if type(parsed_prompt) is Blend:
raise ValueError("Blend is not supported here - you need to get tokens for each of its .children")
@ -426,29 +438,24 @@ def get_tokens_for_prompt_object(
for x in parsed_prompt.children
]
text = " ".join(text_fragments)
tokens: List[str] = tokenizer.tokenize(text)
tokens = tokenizer.tokenize(text)
if truncate_if_too_long:
max_tokens_length = tokenizer.model_max_length - 2 # typically 75
tokens = tokens[0:max_tokens_length]
return tokens
def log_tokenization_for_conjunction(
c: Conjunction, tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None
) -> None:
def log_tokenization_for_conjunction(c: Conjunction, tokenizer, display_label_prefix=None):
display_label_prefix = display_label_prefix or ""
for i, p in enumerate(c.prompts):
if len(c.prompts) > 1:
this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
else:
assert display_label_prefix is not None
this_display_label_prefix = display_label_prefix
log_tokenization_for_prompt_object(p, tokenizer, display_label_prefix=this_display_label_prefix)
def log_tokenization_for_prompt_object(
p: Union[Blend, FlattenedPrompt], tokenizer: CLIPTokenizer, display_label_prefix: Optional[str] = None
) -> None:
def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None):
display_label_prefix = display_label_prefix or ""
if type(p) is Blend:
blend: Blend = p
@ -488,12 +495,7 @@ def log_tokenization_for_prompt_object(
log_tokenization_for_text(text, tokenizer, display_label=display_label_prefix)
def log_tokenization_for_text(
text: str,
tokenizer: CLIPTokenizer,
display_label: Optional[str] = None,
truncate_if_too_long: Optional[bool] = False,
) -> None:
def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
"""shows how the prompt is tokenized
# usually tokens have '</w>' to indicate end-of-word,
# but for readability it has been replaced with ' '

View File

@ -38,8 +38,6 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
:param metadata: ModelRepoMetadata object to store
"""
json_serialized = metadata.model_dump_json()
print("json_serialized")
print(json_serialized)
with self._db.lock:
try:
self._cursor.execute(
@ -55,7 +53,7 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
json_serialized,
),
)
# self._update_tags(model_key, metadata.tags)
self._update_tags(model_key, metadata.tags)
self._db.conn.commit()
except sqlite3.IntegrityError as excp: # FOREIGN KEY error: the key was not in model_config table
self._db.conn.rollback()
@ -63,8 +61,6 @@ class ModelMetadataStoreSQL(ModelMetadataStoreBase):
except sqlite3.Error as excp:
self._db.conn.rollback()
raise excp
except Exception as e:
raise e
def get_metadata(self, model_key: str) -> AnyModelRepoMetadata:
"""Retrieve the ModelRepoMetadata corresponding to model key."""

View File

@ -172,8 +172,6 @@ class ModelPatcher:
text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection],
ti_list: List[Tuple[str, TextualInversionModelRaw]],
) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]:
print("TI LIST")
print(ti_list)
init_tokens_count = None
new_tokens_added = None