Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Compare commits: next-fix-t ... next-allow (10 commits)

Commits (SHA1):
b8ef9407d1
d1e7a2f094
6b595ecd8a
e391bf7c25
8c6860a2c5
fa8263e6f0
e4b8cb1d34
408a800593
9e5e3f1019
98a13aa7dc
@@ -3,9 +3,8 @@ from typing import Iterator, List, Optional, Tuple, Union
 import torch
 from compel import Compel, ReturnedEmbeddingsType
 from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
-from transformers import CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTokenizer

-import invokeai.backend.util.logging as logger
 from invokeai.app.invocations.fields import (
     FieldDescriptions,
     Input,
@@ -14,11 +13,9 @@ from invokeai.app.invocations.fields import (
     UIComponent,
 )
 from invokeai.app.invocations.primitives import ConditioningOutput
-from invokeai.app.services.model_records import UnknownModelException
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.app.util.ti_utils import extract_ti_triggers_from_prompt
+from invokeai.app.util.ti_utils import generate_ti_list
 from invokeai.backend.lora import LoRAModelRaw
-from invokeai.backend.model_manager import ModelType
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
@@ -26,7 +23,6 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     ExtraConditioningInfo,
     SDXLConditioningInfo,
 )
-from invokeai.backend.textual_inversion import TextualInversionModelRaw
 from invokeai.backend.util.devices import torch_dtype

 from .baseinvocation import (
@@ -70,7 +66,11 @@ class CompelInvocation(BaseInvocation):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
         tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
+        tokenizer_model = tokenizer_info.model
+        assert isinstance(tokenizer_model, CLIPTokenizer)
         text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
+        text_encoder_model = text_encoder_info.model
+        assert isinstance(text_encoder_model, CLIPTextModel)

         def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in self.clip.loras:
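The pattern these additions follow: the `model` attribute of the loaded-model object is typed broadly, so the code binds it to a local and asserts the concrete class, which both guards at runtime and narrows the type for static checkers. A minimal sketch of the idea, with a hypothetical `LoadedModel` stand-in rather than the real InvokeAI class:

from transformers import CLIPTokenizer

class LoadedModel:
    # Hypothetical stand-in: the real object comes from context.models.load().
    def __init__(self, model: object) -> None:
        self.model = model  # typed broadly in the real code

def get_tokenizer(info: LoadedModel) -> CLIPTokenizer:
    tokenizer_model = info.model
    # The assert is the narrowing step: after it, type checkers treat
    # tokenizer_model as CLIPTokenizer, and a wrong model fails loudly here
    # instead of deep inside Compel.
    assert isinstance(tokenizer_model, CLIPTokenizer)
    return tokenizer_model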
@@ -82,21 +82,10 @@ class CompelInvocation(BaseInvocation):

         # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]

-        ti_list = []
-        for trigger in extract_ti_triggers_from_prompt(self.prompt):
-            name = trigger[1:-1]
-            try:
-                loaded_model = context.models.load(key=name).model
-                assert isinstance(loaded_model, TextualInversionModelRaw)
-                ti_list.append((name, loaded_model))
-            except UnknownModelException:
-                # print(e)
-                # import traceback
-                # print(traceback.format_exc())
-                print(f'Warn: trigger: "{trigger}" not found')
+        ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context)

         with (
-            ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
+            ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
                 tokenizer,
                 ti_manager,
             ),
@@ -104,8 +93,9 @@ class CompelInvocation(BaseInvocation):
             # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
             ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
-            ModelPatcher.apply_clip_skip(text_encoder_info.model, self.clip.skipped_layers),
+            ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers),
         ):
+            assert isinstance(text_encoder, CLIPTextModel)
             compel = Compel(
                 tokenizer=tokenizer,
                 text_encoder=text_encoder,
@@ -155,7 +145,11 @@ class SDXLPromptInvocationBase:
         zero_on_empty: bool,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
         tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump())
+        tokenizer_model = tokenizer_info.model
+        assert isinstance(tokenizer_model, CLIPTokenizer)
         text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
+        text_encoder_model = text_encoder_info.model
+        assert isinstance(text_encoder_model, CLIPTextModel)

         # return zero on empty
         if prompt == "" and zero_on_empty:
@@ -189,25 +183,10 @@ class SDXLPromptInvocationBase:

         # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]

-        ti_list = []
-        for trigger in extract_ti_triggers_from_prompt(prompt):
-            name = trigger[1:-1]
-            try:
-                ti_model = context.models.load_by_attrs(
-                    model_name=name, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion
-                ).model
-                assert isinstance(ti_model, TextualInversionModelRaw)
-                ti_list.append((name, ti_model))
-            except UnknownModelException:
-                # print(e)
-                # import traceback
-                # print(traceback.format_exc())
-                logger.warning(f'trigger: "{trigger}" not found')
-            except ValueError:
-                logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
+        ti_list = generate_ti_list(prompt, text_encoder_info.config.base, context)

         with (
-            ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
+            ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
                 tokenizer,
                 ti_manager,
             ),
@@ -215,8 +194,9 @@ class SDXLPromptInvocationBase:
             # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
             ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
-            ModelPatcher.apply_clip_skip(text_encoder_info.model, clip_field.skipped_layers),
+            ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
         ):
+            assert isinstance(text_encoder, CLIPTextModel)
             compel = Compel(
                 tokenizer=tokenizer,
                 text_encoder=text_encoder,

@@ -197,6 +197,7 @@ class Categories(object):
    Development: JsonDict = {"category": "Development"}
    Other: JsonDict = {"category": "Other"}
    ModelCache: JsonDict = {"category": "Model Cache"}
+   ModelImport: JsonDict = {"category": "Model Import"}
    Device: JsonDict = {"category": "Device"}
    Generation: JsonDict = {"category": "Generation"}
    Queue: JsonDict = {"category": "Queue"}
@@ -286,7 +287,8 @@ class InvokeAIAppConfig(InvokeAISettings):
    node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory", json_schema_extra=Categories.Nodes)

    # MODEL IMPORT
-   civitai_api_key : Optional[str] = Field(default=os.environ.get("CIVITAI_API_KEY"), description="API key for CivitAI", json_schema_extra=Categories.Other)
+   civitai_api_key : Optional[str] = Field(default=os.environ.get("CIVITAI_API_KEY"), description="API key for CivitAI", json_schema_extra=Categories.ModelImport)
+   model_sym_links : bool = Field(default=False, description="If true, create symbolic links to models instead of copying them. [REQUIRES ADMIN PERMISSIONS OR DEVELOPER MODE IN WINDOWS]", json_schema_extra=Categories.ModelImport)

    # DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
    always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", json_schema_extra=Categories.MemoryPerformance)

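For context, `Categories` entries are plain dicts passed to pydantic's `json_schema_extra`, which merges them into each field's JSON schema; that is all the recategorization above changes. A self-contained sketch of the mechanism (class and field layout here are illustrative, not the real config class):

from typing import Any, Dict
from pydantic import BaseModel, ConfigDict, Field

JsonDict = Dict[str, Any]
ModelImport: JsonDict = {"category": "Model Import"}

class DemoConfig(BaseModel):
    # protected_namespaces=() silences pydantic's warning about the
    # "model_" prefix on the field name.
    model_config = ConfigDict(protected_namespaces=())

    model_sym_links: bool = Field(
        default=False,
        description="If true, symlink models instead of copying them.",
        json_schema_extra=ModelImport,
    )

schema = DemoConfig.model_json_schema()
print(schema["properties"]["model_sym_links"]["category"])  # -> Model Import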
@@ -507,6 +507,9 @@ class ModelInstallService(ModelInstallServiceBase):
         if old_path == new_path:
             return old_path
         new_path.parent.mkdir(parents=True, exist_ok=True)
+        if self.app_config.model_sym_links:
+            new_path.symlink_to(old_path, target_is_directory=old_path.is_dir())
+        else:
             if old_path.is_dir():
                 copytree(old_path, new_path)
             else:

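Behaviour of the new branch in isolation: with `model_sym_links` enabled, installs point at the source files instead of duplicating them, which matters for multi-gigabyte checkpoints but, as the config description warns, needs admin rights or Developer Mode on Windows. A hedged standalone sketch; the single-file copy fallback via `copy2` is an assumption, since that part of the method is outside this hunk:

from pathlib import Path
from shutil import copy2, copytree

def place_model(old_path: Path, new_path: Path, use_symlinks: bool) -> Path:
    if old_path == new_path:
        return old_path
    new_path.parent.mkdir(parents=True, exist_ok=True)
    if use_symlinks:
        # One copy on disk; target_is_directory matters on Windows, where
        # directory and file symlinks are distinct.
        new_path.symlink_to(old_path, target_is_directory=old_path.is_dir())
    elif old_path.is_dir():
        copytree(old_path, new_path)
    else:
        copy2(old_path, new_path)  # assumed fallback for single files
    return new_path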
@@ -1,8 +1,47 @@
 import re
+from typing import List, Tuple
+
+import invokeai.backend.util.logging as logger
+from invokeai.app.services.model_records import UnknownModelException
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.model_manager.config import BaseModelType, ModelType
+from invokeai.backend.textual_inversion import TextualInversionModelRaw


-def extract_ti_triggers_from_prompt(prompt: str) -> list[str]:
-    ti_triggers = []
+def extract_ti_triggers_from_prompt(prompt: str) -> List[str]:
+    ti_triggers: List[str] = []
     for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
-        ti_triggers.append(trigger)
+        ti_triggers.append(str(trigger))
     return ti_triggers
+
+
+def generate_ti_list(
+    prompt: str, base: BaseModelType, context: InvocationContext
+) -> List[Tuple[str, TextualInversionModelRaw]]:
+    ti_list: List[Tuple[str, TextualInversionModelRaw]] = []
+    for trigger in extract_ti_triggers_from_prompt(prompt):
+        name_or_key = trigger[1:-1]
+        try:
+            loaded_model = context.models.load(key=name_or_key)
+            model = loaded_model.model
+            assert isinstance(model, TextualInversionModelRaw)
+            assert loaded_model.config.base == base
+            ti_list.append((name_or_key, model))
+        except UnknownModelException:
+            try:
+                loaded_model = context.models.load_by_attrs(
+                    model_name=name_or_key, base_model=base, model_type=ModelType.TextualInversion
+                )
+                model = loaded_model.model
+                assert isinstance(model, TextualInversionModelRaw)
+                assert loaded_model.config.base == base
+                ti_list.append((name_or_key, model))
+            except UnknownModelException:
+                pass
+        except ValueError:
+            logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
+        except AssertionError:
+            logger.warning(f'trigger: "{trigger}" not a valid textual inversion model for this graph')
+        except Exception:
+            logger.warning(f'Failed to load TI model for trigger: "{trigger}"')
+    return ti_list

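For reference, `re.findall` keeps the angle brackets on each match, which is why `generate_ti_list` strips them with `trigger[1:-1]` before treating the remainder as a key or model name. A quick hypothetical session:

>>> extract_ti_triggers_from_prompt("portrait, <easynegative>, in the style of <my_ti-v2.0>")
['<easynegative>', '<my_ti-v2.0>']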
@@ -160,7 +160,7 @@ class CivitaiMetadataFetch(ModelMetadataFetchBase):
             nsfw=model_json["nsfw"],
             restrictions=LicenseRestrictions(
                 AllowNoCredit=model_json["allowNoCredit"],
-                AllowCommercialUse=CommercialUsage(model_json["allowCommercialUse"]),
+                AllowCommercialUse={CommercialUsage(x) for x in model_json["allowCommercialUse"]},
                 AllowDerivatives=model_json["allowDerivatives"],
                 AllowDifferentLicense=model_json["allowDifferentLicense"],
             ),

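The comprehension tracks a change in the Civitai payload, where `allowCommercialUse` is now a JSON array of usage strings rather than a single string. A self-contained sketch; the enum members and the payload fragment are assumptions inferred from the values used elsewhere in this compare:

from enum import Enum

class CommercialUsage(str, Enum):
    # Assumed members, inferred from CommercialUsage("None") and
    # CommercialUsage("RentCivit") appearing in these diffs.
    No = "None"
    Image = "Image"
    Rent = "Rent"
    RentCivit = "RentCivit"
    Sell = "Sell"

# Old payload shape: {"allowCommercialUse": "Image"}
# New payload shape (hypothetical): {"allowCommercialUse": ["Image", "RentCivit"]}
model_json = {"allowCommercialUse": ["Image", "RentCivit"]}
allowed = {CommercialUsage(x) for x in model_json["allowCommercialUse"]}
assert allowed == {CommercialUsage.Image, CommercialUsage.RentCivit}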
@@ -54,8 +54,8 @@ class LicenseRestrictions(BaseModel):
    AllowDifferentLicense: bool = Field(
        description="if true, derivatives of this model be redistributed under a different license", default=False
    )
-   AllowCommercialUse: Optional[CommercialUsage] = Field(
-       description="Type of commercial use allowed or 'No' if no commercial use is allowed.", default=None
+   AllowCommercialUse: Optional[Set[CommercialUsage] | CommercialUsage] = Field(
+       description="Type of commercial use allowed if no commercial use is allowed.", default=None
    )

@@ -142,7 +142,10 @@ class CivitaiMetadata(ModelMetadataWithFiles):
         if self.restrictions.AllowCommercialUse is None:
             return False
         else:
-            return self.restrictions.AllowCommercialUse != CommercialUsage("None")
+            # accommodate schema change
+            acu = self.restrictions.AllowCommercialUse
+            commercial_usage = acu if isinstance(acu, set) else {acu}
+            return CommercialUsage.No not in commercial_usage

     @property
     def allow_derivatives(self) -> bool:

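Since `AllowCommercialUse` can now hold either a single legacy value or a set, the property normalizes to a set before testing membership. The same logic as a standalone function, sketched with plain strings in place of the `CommercialUsage` enum:

from typing import Optional, Set, Union

def allows_commercial_use(acu: Optional[Union[Set[str], str]]) -> bool:
    if acu is None:
        return False
    # Accommodate the schema change: wrap a legacy single value in a set.
    usage = acu if isinstance(acu, set) else {acu}
    return "None" not in usage

assert allows_commercial_use(None) is False
assert allows_commercial_use("None") is False
assert allows_commercial_use({"Image", "RentCivit"}) is True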
[File diff suppressed because one or more lines are too long]
@@ -133,7 +133,7 @@ def test_metadata_civitai_fetch(mm2_session: Session) -> None:
     assert metadata.id == 215485
     assert metadata.author == "test_author"  # note that this is not the same as the original from Civitai
     assert metadata.allow_commercial_use  # changed to make sure we are reading locally not remotely
-    assert metadata.restrictions.AllowCommercialUse == CommercialUsage("RentCivit")
+    assert CommercialUsage("RentCivit") in metadata.restrictions.AllowCommercialUse
     assert metadata.version_id == 242807
     assert metadata.tags == {"tool", "turbo", "sdxl turbo"}