mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Allow TIs to be either a key or a name in the prompt during our transition to using keys
This commit is contained in:
parent
94e3857110
commit
5922cee541
@ -3,7 +3,7 @@ from typing import Iterator, List, Optional, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
from compel import Compel, ReturnedEmbeddingsType
|
from compel import Compel, ReturnedEmbeddingsType
|
||||||
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
||||||
from transformers import CLIPTokenizer
|
from transformers import CLIPTokenizer, CLIPTextModel
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.invocations.fields import (
|
from invokeai.app.invocations.fields import (
|
||||||
@ -18,7 +18,7 @@ from invokeai.app.services.model_records import UnknownModelException
|
|||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.app.util.ti_utils import extract_ti_triggers_from_prompt
|
from invokeai.app.util.ti_utils import extract_ti_triggers_from_prompt
|
||||||
from invokeai.backend.lora import LoRAModelRaw
|
from invokeai.backend.lora import LoRAModelRaw
|
||||||
from invokeai.backend.model_manager import ModelType
|
from invokeai.backend.model_manager.config import ModelType
|
||||||
from invokeai.backend.model_patcher import ModelPatcher
|
from invokeai.backend.model_patcher import ModelPatcher
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||||
BasicConditioningInfo,
|
BasicConditioningInfo,
|
||||||
@ -70,7 +70,11 @@ class CompelInvocation(BaseInvocation):
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||||
tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
|
tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
|
||||||
|
tokenizer_model = tokenizer_info.model
|
||||||
|
assert isinstance(tokenizer_model, CLIPTokenizer)
|
||||||
text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
|
text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
|
||||||
|
text_encoder_model = text_encoder_info.model
|
||||||
|
assert isinstance(text_encoder_model, CLIPTextModel)
|
||||||
|
|
||||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||||
for lora in self.clip.loras:
|
for lora in self.clip.loras:
|
||||||
@ -82,21 +86,29 @@ class CompelInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
# loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||||
|
|
||||||
ti_list = []
|
ti_list: List[Tuple[str, TextualInversionModelRaw]] = []
|
||||||
for trigger in extract_ti_triggers_from_prompt(self.prompt):
|
for trigger in extract_ti_triggers_from_prompt(self.prompt):
|
||||||
name = trigger[1:-1]
|
name_or_key = trigger[1:-1]
|
||||||
try:
|
try:
|
||||||
loaded_model = context.models.load(key=name).model
|
loaded_model = context.models.load(key=name_or_key)
|
||||||
assert isinstance(loaded_model, TextualInversionModelRaw)
|
model = loaded_model.model
|
||||||
ti_list.append((name, loaded_model))
|
assert isinstance(model, TextualInversionModelRaw)
|
||||||
|
ti_list.append((name_or_key, model))
|
||||||
except UnknownModelException:
|
except UnknownModelException:
|
||||||
# print(e)
|
try:
|
||||||
# import traceback
|
loaded_model = context.models.load_by_attrs(
|
||||||
# print(traceback.format_exc())
|
model_name=name_or_key, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion
|
||||||
print(f'Warn: trigger: "{trigger}" not found')
|
)
|
||||||
|
model = loaded_model.model
|
||||||
|
assert isinstance(model, TextualInversionModelRaw)
|
||||||
|
ti_list.append((name_or_key, model))
|
||||||
|
except UnknownModelException:
|
||||||
|
logger.warning(f'trigger: "{trigger}" not found')
|
||||||
|
except ValueError:
|
||||||
|
logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
|
||||||
|
|
||||||
with (
|
with (
|
||||||
ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
|
ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
|
||||||
tokenizer,
|
tokenizer,
|
||||||
ti_manager,
|
ti_manager,
|
||||||
),
|
),
|
||||||
@ -106,6 +118,7 @@ class CompelInvocation(BaseInvocation):
|
|||||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||||
ModelPatcher.apply_clip_skip(text_encoder_info.model, self.clip.skipped_layers),
|
ModelPatcher.apply_clip_skip(text_encoder_info.model, self.clip.skipped_layers),
|
||||||
):
|
):
|
||||||
|
assert isinstance(text_encoder, CLIPTextModel)
|
||||||
compel = Compel(
|
compel = Compel(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
text_encoder=text_encoder,
|
text_encoder=text_encoder,
|
||||||
@ -155,7 +168,11 @@ class SDXLPromptInvocationBase:
|
|||||||
zero_on_empty: bool,
|
zero_on_empty: bool,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
|
||||||
tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump())
|
tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump())
|
||||||
|
tokenizer_model = tokenizer_info.model
|
||||||
|
assert isinstance(tokenizer_model, CLIPTokenizer)
|
||||||
text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
|
text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
|
||||||
|
text_encoder_model = text_encoder_info.model
|
||||||
|
assert isinstance(text_encoder_model, CLIPTextModel)
|
||||||
|
|
||||||
# return zero on empty
|
# return zero on empty
|
||||||
if prompt == "" and zero_on_empty:
|
if prompt == "" and zero_on_empty:
|
||||||
@ -189,25 +206,29 @@ class SDXLPromptInvocationBase:
|
|||||||
|
|
||||||
# loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
# loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||||
|
|
||||||
ti_list = []
|
ti_list: List[Tuple[str, TextualInversionModelRaw]] = []
|
||||||
for trigger in extract_ti_triggers_from_prompt(prompt):
|
for trigger in extract_ti_triggers_from_prompt(prompt):
|
||||||
name = trigger[1:-1]
|
name_or_key = trigger[1:-1]
|
||||||
try:
|
try:
|
||||||
ti_model = context.models.load_by_attrs(
|
loaded_model = context.models.load(key=name_or_key)
|
||||||
model_name=name, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion
|
model = loaded_model.model
|
||||||
).model
|
assert isinstance(model, TextualInversionModelRaw)
|
||||||
assert isinstance(ti_model, TextualInversionModelRaw)
|
ti_list.append((name_or_key, model))
|
||||||
ti_list.append((name, ti_model))
|
|
||||||
except UnknownModelException:
|
except UnknownModelException:
|
||||||
# print(e)
|
try:
|
||||||
# import traceback
|
loaded_model = context.models.load_by_attrs(
|
||||||
# print(traceback.format_exc())
|
model_name=name_or_key, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion
|
||||||
logger.warning(f'trigger: "{trigger}" not found')
|
)
|
||||||
|
model = loaded_model.model
|
||||||
|
assert isinstance(model, TextualInversionModelRaw)
|
||||||
|
ti_list.append((name_or_key, model))
|
||||||
|
except UnknownModelException:
|
||||||
|
logger.warning(f'trigger: "{trigger}" not found')
|
||||||
except ValueError:
|
except ValueError:
|
||||||
logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
|
logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
|
||||||
|
|
||||||
with (
|
with (
|
||||||
ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
|
ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
|
||||||
tokenizer,
|
tokenizer,
|
||||||
ti_manager,
|
ti_manager,
|
||||||
),
|
),
|
||||||
@ -215,8 +236,9 @@ class SDXLPromptInvocationBase:
|
|||||||
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
|
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
|
||||||
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
|
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
|
||||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||||
ModelPatcher.apply_clip_skip(text_encoder_info.model, clip_field.skipped_layers),
|
ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
|
||||||
):
|
):
|
||||||
|
assert isinstance(text_encoder, CLIPTextModel)
|
||||||
compel = Compel(
|
compel = Compel(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
text_encoder=text_encoder,
|
text_encoder=text_encoder,
|
||||||
|
Loading…
Reference in New Issue
Block a user