mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor(nodes): model identifiers
- All models are identified by a key and optionally a submodel type via new model `ModelField`. Previously, a few model types had their own class, but not all of them. This inconsistency just added complexity without any benefit. - Update all invocation to use the new format. - In the node API, models are loaded by key or an instance of `ModelField` as a convenience. - Add an enriched model schema for metadata. It includes key, hash, name, base and type.
This commit is contained in:
@ -54,16 +54,16 @@ class CompelInvocation(BaseInvocation):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||||
tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
|
tokenizer_info = context.models.load(self.clip.tokenizer)
|
||||||
tokenizer_model = tokenizer_info.model
|
tokenizer_model = tokenizer_info.model
|
||||||
assert isinstance(tokenizer_model, CLIPTokenizer)
|
assert isinstance(tokenizer_model, CLIPTokenizer)
|
||||||
text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
|
text_encoder_info = context.models.load(self.clip.text_encoder)
|
||||||
text_encoder_model = text_encoder_info.model
|
text_encoder_model = text_encoder_info.model
|
||||||
assert isinstance(text_encoder_model, CLIPTextModel)
|
assert isinstance(text_encoder_model, CLIPTextModel)
|
||||||
|
|
||||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||||
for lora in self.clip.loras:
|
for lora in self.clip.loras:
|
||||||
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
|
lora_info = context.models.load(lora.lora)
|
||||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||||
yield (lora_info.model, lora.weight)
|
yield (lora_info.model, lora.weight)
|
||||||
del lora_info
|
del lora_info
|
||||||
@ -133,10 +133,10 @@ class SDXLPromptInvocationBase:
|
|||||||
lora_prefix: str,
|
lora_prefix: str,
|
||||||
zero_on_empty: bool,
|
zero_on_empty: bool,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
|
||||||
tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump())
|
tokenizer_info = context.models.load(clip_field.tokenizer)
|
||||||
tokenizer_model = tokenizer_info.model
|
tokenizer_model = tokenizer_info.model
|
||||||
assert isinstance(tokenizer_model, CLIPTokenizer)
|
assert isinstance(tokenizer_model, CLIPTokenizer)
|
||||||
text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
|
text_encoder_info = context.models.load(clip_field.text_encoder)
|
||||||
text_encoder_model = text_encoder_info.model
|
text_encoder_model = text_encoder_info.model
|
||||||
assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
|
assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
|
||||||
|
|
||||||
@ -163,7 +163,7 @@ class SDXLPromptInvocationBase:
|
|||||||
|
|
||||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||||
for lora in clip_field.loras:
|
for lora in clip_field.loras:
|
||||||
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
|
lora_info = context.models.load(lora.lora)
|
||||||
lora_model = lora_info.model
|
lora_model = lora_info.model
|
||||||
assert isinstance(lora_model, LoRAModelRaw)
|
assert isinstance(lora_model, LoRAModelRaw)
|
||||||
yield (lora_model, lora.weight)
|
yield (lora_model, lora.weight)
|
||||||
|
@ -34,6 +34,7 @@ from invokeai.app.invocations.fields import (
|
|||||||
WithBoard,
|
WithBoard,
|
||||||
WithMetadata,
|
WithMetadata,
|
||||||
)
|
)
|
||||||
|
from invokeai.app.invocations.model import ModelField
|
||||||
from invokeai.app.invocations.primitives import ImageOutput
|
from invokeai.app.invocations.primitives import ImageOutput
|
||||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
@ -51,15 +52,9 @@ CONTROLNET_RESIZE_VALUES = Literal[
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class ControlNetModelField(BaseModel):
|
|
||||||
"""ControlNet model field"""
|
|
||||||
|
|
||||||
key: str = Field(description="Model config record key for the ControlNet model")
|
|
||||||
|
|
||||||
|
|
||||||
class ControlField(BaseModel):
|
class ControlField(BaseModel):
|
||||||
image: ImageField = Field(description="The control image")
|
image: ImageField = Field(description="The control image")
|
||||||
control_model: ControlNetModelField = Field(description="The ControlNet model to use")
|
control_model: ModelField = Field(description="The ControlNet model to use")
|
||||||
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||||
begin_step_percent: float = Field(
|
begin_step_percent: float = Field(
|
||||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||||
@ -95,7 +90,7 @@ class ControlNetInvocation(BaseInvocation):
|
|||||||
"""Collects ControlNet info to pass to other nodes"""
|
"""Collects ControlNet info to pass to other nodes"""
|
||||||
|
|
||||||
image: ImageField = InputField(description="The control image")
|
image: ImageField = InputField(description="The control image")
|
||||||
control_model: ControlNetModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
|
control_model: ModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
|
||||||
control_weight: Union[float, List[float]] = InputField(
|
control_weight: Union[float, List[float]] = InputField(
|
||||||
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
|
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
|
||||||
)
|
)
|
||||||
|
@ -228,7 +228,7 @@ class ConditioningField(BaseModel):
|
|||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
|
||||||
class MetadataField(RootModel):
|
class MetadataField(RootModel[dict[str, Any]]):
|
||||||
"""
|
"""
|
||||||
Pydantic model for metadata with custom root of type dict[str, Any].
|
Pydantic model for metadata with custom root of type dict[str, Any].
|
||||||
Metadata is stored without a strict schema.
|
Metadata is stored without a strict schema.
|
||||||
|
@ -11,25 +11,17 @@ from invokeai.app.invocations.baseinvocation import (
|
|||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||||
|
from invokeai.app.invocations.model import ModelField
|
||||||
from invokeai.app.invocations.primitives import ImageField
|
from invokeai.app.invocations.primitives import ImageField
|
||||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.backend.model_manager.config import BaseModelType, ModelType
|
from invokeai.backend.model_manager.config import BaseModelType, IPAdapterConfig, ModelType
|
||||||
|
|
||||||
|
|
||||||
# LS: Consider moving these two classes into model.py
|
|
||||||
class IPAdapterModelField(BaseModel):
|
|
||||||
key: str = Field(description="Key to the IP-Adapter model")
|
|
||||||
|
|
||||||
|
|
||||||
class CLIPVisionModelField(BaseModel):
|
|
||||||
key: str = Field(description="Key to the CLIP Vision image encoder model")
|
|
||||||
|
|
||||||
|
|
||||||
class IPAdapterField(BaseModel):
|
class IPAdapterField(BaseModel):
|
||||||
image: Union[ImageField, List[ImageField]] = Field(description="The IP-Adapter image prompt(s).")
|
image: Union[ImageField, List[ImageField]] = Field(description="The IP-Adapter image prompt(s).")
|
||||||
ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.")
|
ip_adapter_model: ModelField = Field(description="The IP-Adapter model to use.")
|
||||||
image_encoder_model: CLIPVisionModelField = Field(description="The name of the CLIP image encoder model.")
|
image_encoder_model: ModelField = Field(description="The name of the CLIP image encoder model.")
|
||||||
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||||
begin_step_percent: float = Field(
|
begin_step_percent: float = Field(
|
||||||
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
|
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
|
||||||
@ -62,7 +54,7 @@ class IPAdapterInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).")
|
image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).")
|
||||||
ip_adapter_model: IPAdapterModelField = InputField(
|
ip_adapter_model: ModelField = InputField(
|
||||||
description="The IP-Adapter model.", title="IP-Adapter Model", input=Input.Direct, ui_order=-1
|
description="The IP-Adapter model.", title="IP-Adapter Model", input=Input.Direct, ui_order=-1
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -90,18 +82,18 @@ class IPAdapterInvocation(BaseInvocation):
|
|||||||
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
|
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
|
||||||
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
|
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
|
||||||
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
|
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
|
||||||
|
assert isinstance(ip_adapter_info, IPAdapterConfig)
|
||||||
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
|
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
|
||||||
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
|
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
|
||||||
image_encoder_models = context.models.search_by_attrs(
|
image_encoder_models = context.models.search_by_attrs(
|
||||||
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
|
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
|
||||||
)
|
)
|
||||||
assert len(image_encoder_models) == 1
|
assert len(image_encoder_models) == 1
|
||||||
image_encoder_model = CLIPVisionModelField(key=image_encoder_models[0].key)
|
|
||||||
return IPAdapterOutput(
|
return IPAdapterOutput(
|
||||||
ip_adapter=IPAdapterField(
|
ip_adapter=IPAdapterField(
|
||||||
image=self.image,
|
image=self.image,
|
||||||
ip_adapter_model=self.ip_adapter_model,
|
ip_adapter_model=self.ip_adapter_model,
|
||||||
image_encoder_model=image_encoder_model,
|
image_encoder_model=ModelField(key=image_encoder_models[0].key),
|
||||||
weight=self.weight,
|
weight=self.weight,
|
||||||
begin_step_percent=self.begin_step_percent,
|
begin_step_percent=self.begin_step_percent,
|
||||||
end_step_percent=self.end_step_percent,
|
end_step_percent=self.end_step_percent,
|
||||||
|
@ -26,6 +26,7 @@ from diffusers.schedulers import SchedulerMixin as Scheduler
|
|||||||
from PIL import Image, ImageFilter
|
from PIL import Image, ImageFilter
|
||||||
from pydantic import field_validator
|
from pydantic import field_validator
|
||||||
from torchvision.transforms.functional import resize as tv_resize
|
from torchvision.transforms.functional import resize as tv_resize
|
||||||
|
from transformers import CLIPVisionModelWithProjection
|
||||||
|
|
||||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
|
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
|
||||||
from invokeai.app.invocations.fields import (
|
from invokeai.app.invocations.fields import (
|
||||||
@ -75,7 +76,7 @@ from .baseinvocation import (
|
|||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
from .controlnet_image_processors import ControlField
|
from .controlnet_image_processors import ControlField
|
||||||
from .model import ModelInfo, UNetField, VaeField
|
from .model import ModelField, UNetField, VaeField
|
||||||
|
|
||||||
if choose_torch_device() == torch.device("mps"):
|
if choose_torch_device() == torch.device("mps"):
|
||||||
from torch import mps
|
from torch import mps
|
||||||
@ -153,7 +154,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if image_tensor is not None:
|
if image_tensor is not None:
|
||||||
vae_info = context.models.load(**self.vae.vae.model_dump())
|
vae_info = context.models.load(self.vae.vae)
|
||||||
|
|
||||||
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
||||||
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
|
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
|
||||||
@ -244,12 +245,12 @@ class CreateGradientMaskInvocation(BaseInvocation):
|
|||||||
|
|
||||||
def get_scheduler(
|
def get_scheduler(
|
||||||
context: InvocationContext,
|
context: InvocationContext,
|
||||||
scheduler_info: ModelInfo,
|
scheduler_info: ModelField,
|
||||||
scheduler_name: str,
|
scheduler_name: str,
|
||||||
seed: int,
|
seed: int,
|
||||||
) -> Scheduler:
|
) -> Scheduler:
|
||||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
||||||
orig_scheduler_info = context.models.load(**scheduler_info.model_dump())
|
orig_scheduler_info = context.models.load(scheduler_info)
|
||||||
with orig_scheduler_info as orig_scheduler:
|
with orig_scheduler_info as orig_scheduler:
|
||||||
scheduler_config = orig_scheduler.config
|
scheduler_config = orig_scheduler.config
|
||||||
|
|
||||||
@ -461,7 +462,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
# and if weight is None, populate with default 1.0?
|
# and if weight is None, populate with default 1.0?
|
||||||
controlnet_data = []
|
controlnet_data = []
|
||||||
for control_info in control_list:
|
for control_info in control_list:
|
||||||
control_model = exit_stack.enter_context(context.models.load(key=control_info.control_model.key))
|
control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
|
||||||
|
|
||||||
# control_models.append(control_model)
|
# control_models.append(control_model)
|
||||||
control_image_field = control_info.image
|
control_image_field = control_info.image
|
||||||
@ -523,11 +524,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
conditioning_data.ip_adapter_conditioning = []
|
conditioning_data.ip_adapter_conditioning = []
|
||||||
for single_ip_adapter in ip_adapter:
|
for single_ip_adapter in ip_adapter:
|
||||||
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
|
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
|
||||||
context.models.load(key=single_ip_adapter.ip_adapter_model.key)
|
context.models.load(single_ip_adapter.ip_adapter_model)
|
||||||
)
|
)
|
||||||
|
|
||||||
image_encoder_model_info = context.models.load(key=single_ip_adapter.image_encoder_model.key)
|
image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
|
||||||
|
|
||||||
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
|
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
|
||||||
single_ipa_image_fields = single_ip_adapter.image
|
single_ipa_image_fields = single_ip_adapter.image
|
||||||
if not isinstance(single_ipa_image_fields, list):
|
if not isinstance(single_ipa_image_fields, list):
|
||||||
@ -538,6 +538,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
|
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
|
||||||
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
|
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
|
||||||
with image_encoder_model_info as image_encoder_model:
|
with image_encoder_model_info as image_encoder_model:
|
||||||
|
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
|
||||||
# Get image embeddings from CLIP and ImageProjModel.
|
# Get image embeddings from CLIP and ImageProjModel.
|
||||||
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
|
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
|
||||||
single_ipa_images, image_encoder_model
|
single_ipa_images, image_encoder_model
|
||||||
@ -577,8 +578,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
t2i_adapter_data = []
|
t2i_adapter_data = []
|
||||||
for t2i_adapter_field in t2i_adapter:
|
for t2i_adapter_field in t2i_adapter:
|
||||||
t2i_adapter_model_config = context.models.get_config(key=t2i_adapter_field.t2i_adapter_model.key)
|
t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key)
|
||||||
t2i_adapter_loaded_model = context.models.load(key=t2i_adapter_field.t2i_adapter_model.key)
|
t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model)
|
||||||
image = context.images.get_pil(t2i_adapter_field.image.image_name)
|
image = context.images.get_pil(t2i_adapter_field.image.image_name)
|
||||||
|
|
||||||
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
|
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
|
||||||
@ -731,12 +732,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||||
for lora in self.unet.loras:
|
for lora in self.unet.loras:
|
||||||
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
|
lora_info = context.models.load(lora.lora)
|
||||||
|
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||||
yield (lora_info.model, lora.weight)
|
yield (lora_info.model, lora.weight)
|
||||||
del lora_info
|
del lora_info
|
||||||
return
|
return
|
||||||
|
|
||||||
unet_info = context.models.load(**self.unet.unet.model_dump())
|
unet_info = context.models.load(self.unet.unet)
|
||||||
assert isinstance(unet_info.model, UNet2DConditionModel)
|
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||||
with (
|
with (
|
||||||
ExitStack() as exit_stack,
|
ExitStack() as exit_stack,
|
||||||
@ -841,8 +843,8 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
latents = context.tensors.load(self.latents.latents_name)
|
latents = context.tensors.load(self.latents.latents_name)
|
||||||
|
|
||||||
vae_info = context.models.load(**self.vae.vae.model_dump())
|
vae_info = context.models.load(self.vae.vae)
|
||||||
|
assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL))
|
||||||
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
||||||
assert isinstance(vae, torch.nn.Module)
|
assert isinstance(vae, torch.nn.Module)
|
||||||
latents = latents.to(vae.device)
|
latents = latents.to(vae.device)
|
||||||
@ -1064,7 +1066,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
image = context.images.get_pil(self.image.image_name)
|
image = context.images.get_pil(self.image.image_name)
|
||||||
|
|
||||||
vae_info = context.models.load(**self.vae.vae.model_dump())
|
vae_info = context.models.load(self.vae.vae)
|
||||||
|
|
||||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||||
if image_tensor.dim() == 3:
|
if image_tensor.dim() == 3:
|
||||||
|
@ -8,7 +8,10 @@ from invokeai.app.invocations.baseinvocation import (
|
|||||||
invocation,
|
invocation,
|
||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
from invokeai.app.invocations.controlnet_image_processors import (
|
||||||
|
CONTROLNET_MODE_VALUES,
|
||||||
|
CONTROLNET_RESIZE_VALUES,
|
||||||
|
)
|
||||||
from invokeai.app.invocations.fields import (
|
from invokeai.app.invocations.fields import (
|
||||||
FieldDescriptions,
|
FieldDescriptions,
|
||||||
ImageField,
|
ImageField,
|
||||||
@ -17,10 +20,8 @@ from invokeai.app.invocations.fields import (
|
|||||||
OutputField,
|
OutputField,
|
||||||
UIType,
|
UIType,
|
||||||
)
|
)
|
||||||
from invokeai.app.invocations.ip_adapter import IPAdapterModelField
|
|
||||||
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
|
|
||||||
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
|
from invokeai.backend.model_manager.config import BaseModelType, ModelType
|
||||||
|
|
||||||
from ...version import __version__
|
from ...version import __version__
|
||||||
|
|
||||||
@ -30,10 +31,20 @@ class MetadataItemField(BaseModel):
|
|||||||
value: Any = Field(description=FieldDescriptions.metadata_item_value)
|
value: Any = Field(description=FieldDescriptions.metadata_item_value)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelMetadataField(BaseModel):
|
||||||
|
"""Model Metadata Field"""
|
||||||
|
|
||||||
|
key: str
|
||||||
|
hash: str
|
||||||
|
name: str
|
||||||
|
base: BaseModelType
|
||||||
|
type: ModelType
|
||||||
|
|
||||||
|
|
||||||
class LoRAMetadataField(BaseModel):
|
class LoRAMetadataField(BaseModel):
|
||||||
"""LoRA Metadata Field"""
|
"""LoRA Metadata Field"""
|
||||||
|
|
||||||
model: LoRAModelField = Field(description=FieldDescriptions.lora_model)
|
model: ModelMetadataField = Field(description=FieldDescriptions.lora_model)
|
||||||
weight: float = Field(description=FieldDescriptions.lora_weight)
|
weight: float = Field(description=FieldDescriptions.lora_weight)
|
||||||
|
|
||||||
|
|
||||||
@ -41,7 +52,7 @@ class IPAdapterMetadataField(BaseModel):
|
|||||||
"""IP Adapter Field, minus the CLIP Vision Encoder model"""
|
"""IP Adapter Field, minus the CLIP Vision Encoder model"""
|
||||||
|
|
||||||
image: ImageField = Field(description="The IP-Adapter image prompt.")
|
image: ImageField = Field(description="The IP-Adapter image prompt.")
|
||||||
ip_adapter_model: IPAdapterModelField = Field(
|
ip_adapter_model: ModelMetadataField = Field(
|
||||||
description="The IP-Adapter model.",
|
description="The IP-Adapter model.",
|
||||||
)
|
)
|
||||||
weight: Union[float, list[float]] = Field(
|
weight: Union[float, list[float]] = Field(
|
||||||
@ -51,6 +62,33 @@ class IPAdapterMetadataField(BaseModel):
|
|||||||
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")
|
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")
|
||||||
|
|
||||||
|
|
||||||
|
class T2IAdapterMetadataField(BaseModel):
|
||||||
|
image: ImageField = Field(description="The T2I-Adapter image prompt.")
|
||||||
|
t2i_adapter_model: ModelMetadataField = Field(description="The T2I-Adapter model to use.")
|
||||||
|
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
|
||||||
|
begin_step_percent: float = Field(
|
||||||
|
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
|
||||||
|
)
|
||||||
|
end_step_percent: float = Field(
|
||||||
|
default=1, ge=0, le=1, description="When the T2I-Adapter is last applied (% of total steps)"
|
||||||
|
)
|
||||||
|
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
||||||
|
|
||||||
|
|
||||||
|
class ControlNetMetadataField(BaseModel):
|
||||||
|
image: ImageField = Field(description="The control image")
|
||||||
|
control_model: ModelMetadataField = Field(description="The ControlNet model to use")
|
||||||
|
control_weight: Union[float, list[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||||
|
begin_step_percent: float = Field(
|
||||||
|
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||||
|
)
|
||||||
|
end_step_percent: float = Field(
|
||||||
|
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
||||||
|
)
|
||||||
|
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
|
||||||
|
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("metadata_item_output")
|
@invocation_output("metadata_item_output")
|
||||||
class MetadataItemOutput(BaseInvocationOutput):
|
class MetadataItemOutput(BaseInvocationOutput):
|
||||||
"""Metadata Item Output"""
|
"""Metadata Item Output"""
|
||||||
@ -140,14 +178,14 @@ class CoreMetadataInvocation(BaseInvocation):
|
|||||||
default=None,
|
default=None,
|
||||||
description="The number of skipped CLIP layers",
|
description="The number of skipped CLIP layers",
|
||||||
)
|
)
|
||||||
model: Optional[MainModelField] = InputField(default=None, description="The main model used for inference")
|
model: Optional[ModelMetadataField] = InputField(default=None, description="The main model used for inference")
|
||||||
controlnets: Optional[list[ControlField]] = InputField(
|
controlnets: Optional[list[ControlNetMetadataField]] = InputField(
|
||||||
default=None, description="The ControlNets used for inference"
|
default=None, description="The ControlNets used for inference"
|
||||||
)
|
)
|
||||||
ipAdapters: Optional[list[IPAdapterMetadataField]] = InputField(
|
ipAdapters: Optional[list[IPAdapterMetadataField]] = InputField(
|
||||||
default=None, description="The IP Adapters used for inference"
|
default=None, description="The IP Adapters used for inference"
|
||||||
)
|
)
|
||||||
t2iAdapters: Optional[list[T2IAdapterField]] = InputField(
|
t2iAdapters: Optional[list[T2IAdapterMetadataField]] = InputField(
|
||||||
default=None, description="The IP Adapters used for inference"
|
default=None, description="The IP Adapters used for inference"
|
||||||
)
|
)
|
||||||
loras: Optional[list[LoRAMetadataField]] = InputField(default=None, description="The LoRAs used for inference")
|
loras: Optional[list[LoRAMetadataField]] = InputField(default=None, description="The LoRAs used for inference")
|
||||||
@ -159,7 +197,7 @@ class CoreMetadataInvocation(BaseInvocation):
|
|||||||
default=None,
|
default=None,
|
||||||
description="The name of the initial image",
|
description="The name of the initial image",
|
||||||
)
|
)
|
||||||
vae: Optional[VAEModelField] = InputField(
|
vae: Optional[ModelMetadataField] = InputField(
|
||||||
default=None,
|
default=None,
|
||||||
description="The VAE used for decoding, if the main model's default was not used",
|
description="The VAE used for decoding, if the main model's default was not used",
|
||||||
)
|
)
|
||||||
@ -190,7 +228,7 @@ class CoreMetadataInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# SDXL Refiner
|
# SDXL Refiner
|
||||||
refiner_model: Optional[MainModelField] = InputField(
|
refiner_model: Optional[ModelMetadataField] = InputField(
|
||||||
default=None,
|
default=None,
|
||||||
description="The SDXL Refiner model used",
|
description="The SDXL Refiner model used",
|
||||||
)
|
)
|
||||||
@ -222,10 +260,9 @@ class CoreMetadataInvocation(BaseInvocation):
|
|||||||
def invoke(self, context: InvocationContext) -> MetadataOutput:
|
def invoke(self, context: InvocationContext) -> MetadataOutput:
|
||||||
"""Collects and outputs a CoreMetadata object"""
|
"""Collects and outputs a CoreMetadata object"""
|
||||||
|
|
||||||
return MetadataOutput(
|
as_dict = self.model_dump(exclude_none=True, exclude={"id", "type", "is_intermediate", "use_cache"})
|
||||||
metadata=MetadataField.model_validate(
|
as_dict["app_version"] = __version__
|
||||||
self.model_dump(exclude_none=True, exclude={"id", "type", "is_intermediate", "use_cache"})
|
|
||||||
)
|
return MetadataOutput(metadata=MetadataField.model_validate(as_dict))
|
||||||
)
|
|
||||||
|
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
@ -6,8 +6,8 @@ from pydantic import BaseModel, Field
|
|||||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.app.shared.models import FreeUConfig
|
from invokeai.app.shared.models import FreeUConfig
|
||||||
|
from invokeai.backend.model_manager.config import SubModelType
|
||||||
|
|
||||||
from ...backend.model_manager import SubModelType
|
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
@ -16,33 +16,34 @@ from .baseinvocation import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ModelInfo(BaseModel):
|
class ModelField(BaseModel):
|
||||||
key: str = Field(description="Key of model as returned by ModelRecordServiceBase.get_model()")
|
key: str = Field(description="Key of the model")
|
||||||
submodel_type: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
|
submodel_type: Optional[SubModelType] = Field(description="Submodel type", default=None)
|
||||||
|
|
||||||
|
|
||||||
class LoraInfo(ModelInfo):
|
class LoRAField(BaseModel):
|
||||||
weight: float = Field(description="Lora's weight which to use when apply to model")
|
lora: ModelField = Field(description="Info to load lora model")
|
||||||
|
weight: float = Field(description="Weight to apply to lora model")
|
||||||
|
|
||||||
|
|
||||||
class UNetField(BaseModel):
|
class UNetField(BaseModel):
|
||||||
unet: ModelInfo = Field(description="Info to load unet submodel")
|
unet: ModelField = Field(description="Info to load unet submodel")
|
||||||
scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
|
scheduler: ModelField = Field(description="Info to load scheduler submodel")
|
||||||
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
|
loras: List[LoRAField] = Field(description="Loras to apply on model loading")
|
||||||
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
||||||
freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")
|
freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")
|
||||||
|
|
||||||
|
|
||||||
class ClipField(BaseModel):
|
class ClipField(BaseModel):
|
||||||
tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel")
|
tokenizer: ModelField = Field(description="Info to load tokenizer submodel")
|
||||||
text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel")
|
text_encoder: ModelField = Field(description="Info to load text_encoder submodel")
|
||||||
skipped_layers: int = Field(description="Number of skipped layers in text_encoder")
|
skipped_layers: int = Field(description="Number of skipped layers in text_encoder")
|
||||||
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
|
loras: List[LoRAField] = Field(description="Loras to apply on model loading")
|
||||||
|
|
||||||
|
|
||||||
class VaeField(BaseModel):
|
class VaeField(BaseModel):
|
||||||
# TODO: better naming?
|
# TODO: better naming?
|
||||||
vae: ModelInfo = Field(description="Info to load vae submodel")
|
vae: ModelField = Field(description="Info to load vae submodel")
|
||||||
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
||||||
|
|
||||||
|
|
||||||
@ -74,18 +75,6 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class MainModelField(BaseModel):
|
|
||||||
"""Main model field"""
|
|
||||||
|
|
||||||
key: str = Field(description="Model key")
|
|
||||||
|
|
||||||
|
|
||||||
class LoRAModelField(BaseModel):
|
|
||||||
"""LoRA model field"""
|
|
||||||
|
|
||||||
key: str = Field(description="LoRA model key")
|
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
"main_model_loader",
|
"main_model_loader",
|
||||||
title="Main Model",
|
title="Main Model",
|
||||||
@ -96,46 +85,24 @@ class LoRAModelField(BaseModel):
|
|||||||
class MainModelLoaderInvocation(BaseInvocation):
|
class MainModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads a main model, outputting its submodels."""
|
"""Loads a main model, outputting its submodels."""
|
||||||
|
|
||||||
model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
|
model: ModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
|
||||||
# TODO: precision?
|
# TODO: precision?
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||||
key = self.model.key
|
|
||||||
|
|
||||||
# TODO: not found exceptions
|
# TODO: not found exceptions
|
||||||
if not context.models.exists(key):
|
if not context.models.exists(self.model.key):
|
||||||
raise Exception(f"Unknown model {key}")
|
raise Exception(f"Unknown model {self.model.key}")
|
||||||
|
|
||||||
|
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
|
||||||
|
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
|
||||||
|
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
|
||||||
|
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
|
||||||
|
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
|
||||||
|
|
||||||
return ModelLoaderOutput(
|
return ModelLoaderOutput(
|
||||||
unet=UNetField(
|
unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
|
||||||
unet=ModelInfo(
|
clip=ClipField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
|
||||||
key=key,
|
vae=VaeField(vae=vae),
|
||||||
submodel_type=SubModelType.UNet,
|
|
||||||
),
|
|
||||||
scheduler=ModelInfo(
|
|
||||||
key=key,
|
|
||||||
submodel_type=SubModelType.Scheduler,
|
|
||||||
),
|
|
||||||
loras=[],
|
|
||||||
),
|
|
||||||
clip=ClipField(
|
|
||||||
tokenizer=ModelInfo(
|
|
||||||
key=key,
|
|
||||||
submodel_type=SubModelType.Tokenizer,
|
|
||||||
),
|
|
||||||
text_encoder=ModelInfo(
|
|
||||||
key=key,
|
|
||||||
submodel_type=SubModelType.TextEncoder,
|
|
||||||
),
|
|
||||||
loras=[],
|
|
||||||
skipped_layers=0,
|
|
||||||
),
|
|
||||||
vae=VaeField(
|
|
||||||
vae=ModelInfo(
|
|
||||||
key=key,
|
|
||||||
submodel_type=SubModelType.VAE,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -151,7 +118,7 @@ class LoraLoaderOutput(BaseInvocationOutput):
|
|||||||
class LoraLoaderInvocation(BaseInvocation):
|
class LoraLoaderInvocation(BaseInvocation):
|
||||||
"""Apply selected lora to unet and text_encoder."""
|
"""Apply selected lora to unet and text_encoder."""
|
||||||
|
|
||||||
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
lora: ModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||||
unet: Optional[UNetField] = InputField(
|
unet: Optional[UNetField] = InputField(
|
||||||
default=None,
|
default=None,
|
||||||
@ -167,38 +134,33 @@ class LoraLoaderInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
|
def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
|
||||||
if self.lora is None:
|
|
||||||
raise Exception("No LoRA provided")
|
|
||||||
|
|
||||||
lora_key = self.lora.key
|
lora_key = self.lora.key
|
||||||
|
|
||||||
if not context.models.exists(lora_key):
|
if not context.models.exists(lora_key):
|
||||||
raise Exception(f"Unkown lora: {lora_key}!")
|
raise Exception(f"Unkown lora: {lora_key}!")
|
||||||
|
|
||||||
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
|
if self.unet is not None and any(lora.lora.key == lora_key for lora in self.unet.loras):
|
||||||
raise Exception(f'Lora "{lora_key}" already applied to unet')
|
raise Exception(f'Lora "{lora_key}" already applied to unet')
|
||||||
|
|
||||||
if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
|
if self.clip is not None and any(lora.lora.key == lora_key for lora in self.clip.loras):
|
||||||
raise Exception(f'Lora "{lora_key}" already applied to clip')
|
raise Exception(f'Lora "{lora_key}" already applied to clip')
|
||||||
|
|
||||||
output = LoraLoaderOutput()
|
output = LoraLoaderOutput()
|
||||||
|
|
||||||
if self.unet is not None:
|
if self.unet is not None:
|
||||||
output.unet = copy.deepcopy(self.unet)
|
output.unet = self.unet.model_copy(deep=True)
|
||||||
output.unet.loras.append(
|
output.unet.loras.append(
|
||||||
LoraInfo(
|
LoRAField(
|
||||||
key=lora_key,
|
lora=self.lora,
|
||||||
submodel_type=None,
|
|
||||||
weight=self.weight,
|
weight=self.weight,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.clip is not None:
|
if self.clip is not None:
|
||||||
output.clip = copy.deepcopy(self.clip)
|
output.clip = self.clip.model_copy(deep=True)
|
||||||
output.clip.loras.append(
|
output.clip.loras.append(
|
||||||
LoraInfo(
|
LoRAField(
|
||||||
key=lora_key,
|
lora=self.lora,
|
||||||
submodel_type=None,
|
|
||||||
weight=self.weight,
|
weight=self.weight,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -225,7 +187,7 @@ class SDXLLoraLoaderOutput(BaseInvocationOutput):
|
|||||||
class SDXLLoraLoaderInvocation(BaseInvocation):
|
class SDXLLoraLoaderInvocation(BaseInvocation):
|
||||||
"""Apply selected lora to unet and text_encoder."""
|
"""Apply selected lora to unet and text_encoder."""
|
||||||
|
|
||||||
lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
lora: ModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA")
|
||||||
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
|
||||||
unet: Optional[UNetField] = InputField(
|
unet: Optional[UNetField] = InputField(
|
||||||
default=None,
|
default=None,
|
||||||
@ -247,51 +209,45 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
|
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
|
||||||
if self.lora is None:
|
|
||||||
raise Exception("No LoRA provided")
|
|
||||||
|
|
||||||
lora_key = self.lora.key
|
lora_key = self.lora.key
|
||||||
|
|
||||||
if not context.models.exists(lora_key):
|
if not context.models.exists(lora_key):
|
||||||
raise Exception(f"Unknown lora: {lora_key}!")
|
raise Exception(f"Unknown lora: {lora_key}!")
|
||||||
|
|
||||||
if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
|
if self.unet is not None and any(lora.lora.key == lora_key for lora in self.unet.loras):
|
||||||
raise Exception(f'Lora "{lora_key}" already applied to unet')
|
raise Exception(f'Lora "{lora_key}" already applied to unet')
|
||||||
|
|
||||||
if self.clip is not None and any(lora.key == lora_key for lora in self.clip.loras):
|
if self.clip is not None and any(lora.lora.key == lora_key for lora in self.clip.loras):
|
||||||
raise Exception(f'Lora "{lora_key}" already applied to clip')
|
raise Exception(f'Lora "{lora_key}" already applied to clip')
|
||||||
|
|
||||||
if self.clip2 is not None and any(lora.key == lora_key for lora in self.clip2.loras):
|
if self.clip2 is not None and any(lora.lora.key == lora_key for lora in self.clip2.loras):
|
||||||
raise Exception(f'Lora "{lora_key}" already applied to clip2')
|
raise Exception(f'Lora "{lora_key}" already applied to clip2')
|
||||||
|
|
||||||
output = SDXLLoraLoaderOutput()
|
output = SDXLLoraLoaderOutput()
|
||||||
|
|
||||||
if self.unet is not None:
|
if self.unet is not None:
|
||||||
output.unet = copy.deepcopy(self.unet)
|
output.unet = self.unet.model_copy(deep=True)
|
||||||
output.unet.loras.append(
|
output.unet.loras.append(
|
||||||
LoraInfo(
|
LoRAField(
|
||||||
key=lora_key,
|
lora=self.lora,
|
||||||
submodel_type=None,
|
|
||||||
weight=self.weight,
|
weight=self.weight,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.clip is not None:
|
if self.clip is not None:
|
||||||
output.clip = copy.deepcopy(self.clip)
|
output.clip = self.clip.model_copy(deep=True)
|
||||||
output.clip.loras.append(
|
output.clip.loras.append(
|
||||||
LoraInfo(
|
LoRAField(
|
||||||
key=lora_key,
|
lora=self.lora,
|
||||||
submodel_type=None,
|
|
||||||
weight=self.weight,
|
weight=self.weight,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.clip2 is not None:
|
if self.clip2 is not None:
|
||||||
output.clip2 = copy.deepcopy(self.clip2)
|
output.clip2 = self.clip2.model_copy(deep=True)
|
||||||
output.clip2.loras.append(
|
output.clip2.loras.append(
|
||||||
LoraInfo(
|
LoRAField(
|
||||||
key=lora_key,
|
lora=self.lora,
|
||||||
submodel_type=None,
|
|
||||||
weight=self.weight,
|
weight=self.weight,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -299,17 +255,11 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
class VAEModelField(BaseModel):
|
|
||||||
"""Vae model field"""
|
|
||||||
|
|
||||||
key: str = Field(description="Model's key")
|
|
||||||
|
|
||||||
|
|
||||||
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.1")
|
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.1")
|
||||||
class VaeLoaderInvocation(BaseInvocation):
|
class VaeLoaderInvocation(BaseInvocation):
|
||||||
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
"""Loads a VAE model, outputting a VaeLoaderOutput"""
|
||||||
|
|
||||||
vae_model: VAEModelField = InputField(
|
vae_model: ModelField = InputField(
|
||||||
description=FieldDescriptions.vae_model,
|
description=FieldDescriptions.vae_model,
|
||||||
input=Input.Direct,
|
input=Input.Direct,
|
||||||
title="VAE",
|
title="VAE",
|
||||||
@ -321,7 +271,7 @@ class VaeLoaderInvocation(BaseInvocation):
|
|||||||
if not context.models.exists(key):
|
if not context.models.exists(key):
|
||||||
raise Exception(f"Unkown vae: {key}!")
|
raise Exception(f"Unkown vae: {key}!")
|
||||||
|
|
||||||
return VAEOutput(vae=VaeField(vae=ModelInfo(key=key)))
|
return VAEOutput(vae=VaeField(vae=self.vae_model))
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("seamless_output")
|
@invocation_output("seamless_output")
|
||||||
|
@ -8,7 +8,7 @@ from .baseinvocation import (
|
|||||||
invocation,
|
invocation,
|
||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
from .model import ClipField, MainModelField, ModelInfo, UNetField, VaeField
|
from .model import ClipField, ModelField, UNetField, VaeField
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("sdxl_model_loader_output")
|
@invocation_output("sdxl_model_loader_output")
|
||||||
@ -34,7 +34,7 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
|
|||||||
class SDXLModelLoaderInvocation(BaseInvocation):
|
class SDXLModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads an sdxl base model, outputting its submodels."""
|
"""Loads an sdxl base model, outputting its submodels."""
|
||||||
|
|
||||||
model: MainModelField = InputField(
|
model: ModelField = InputField(
|
||||||
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel
|
description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel
|
||||||
)
|
)
|
||||||
# TODO: precision?
|
# TODO: precision?
|
||||||
@ -46,48 +46,19 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
|||||||
if not context.models.exists(model_key):
|
if not context.models.exists(model_key):
|
||||||
raise Exception(f"Unknown model: {model_key}")
|
raise Exception(f"Unknown model: {model_key}")
|
||||||
|
|
||||||
|
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
|
||||||
|
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
|
||||||
|
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
|
||||||
|
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
|
||||||
|
tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
|
||||||
|
text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
|
||||||
|
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
|
||||||
|
|
||||||
return SDXLModelLoaderOutput(
|
return SDXLModelLoaderOutput(
|
||||||
unet=UNetField(
|
unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
|
||||||
unet=ModelInfo(
|
clip=ClipField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
|
||||||
key=model_key,
|
clip2=ClipField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
|
||||||
submodel_type=SubModelType.UNet,
|
vae=VaeField(vae=vae),
|
||||||
),
|
|
||||||
scheduler=ModelInfo(
|
|
||||||
key=model_key,
|
|
||||||
submodel_type=SubModelType.Scheduler,
|
|
||||||
),
|
|
||||||
loras=[],
|
|
||||||
),
|
|
||||||
clip=ClipField(
|
|
||||||
tokenizer=ModelInfo(
|
|
||||||
key=model_key,
|
|
||||||
submodel_type=SubModelType.Tokenizer,
|
|
||||||
),
|
|
||||||
text_encoder=ModelInfo(
|
|
||||||
key=model_key,
|
|
||||||
submodel_type=SubModelType.TextEncoder,
|
|
||||||
),
|
|
||||||
loras=[],
|
|
||||||
skipped_layers=0,
|
|
||||||
),
|
|
||||||
clip2=ClipField(
|
|
||||||
tokenizer=ModelInfo(
|
|
||||||
key=model_key,
|
|
||||||
submodel_type=SubModelType.Tokenizer2,
|
|
||||||
),
|
|
||||||
text_encoder=ModelInfo(
|
|
||||||
key=model_key,
|
|
||||||
submodel_type=SubModelType.TextEncoder2,
|
|
||||||
),
|
|
||||||
loras=[],
|
|
||||||
skipped_layers=0,
|
|
||||||
),
|
|
||||||
vae=VaeField(
|
|
||||||
vae=ModelInfo(
|
|
||||||
key=model_key,
|
|
||||||
submodel_type=SubModelType.VAE,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -101,10 +72,8 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
|||||||
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads an sdxl refiner model, outputting its submodels."""
|
"""Loads an sdxl refiner model, outputting its submodels."""
|
||||||
|
|
||||||
model: MainModelField = InputField(
|
model: ModelField = InputField(
|
||||||
description=FieldDescriptions.sdxl_refiner_model,
|
description=FieldDescriptions.sdxl_refiner_model, input=Input.Direct, ui_type=UIType.SDXLRefinerModel
|
||||||
input=Input.Direct,
|
|
||||||
ui_type=UIType.SDXLRefinerModel,
|
|
||||||
)
|
)
|
||||||
# TODO: precision?
|
# TODO: precision?
|
||||||
|
|
||||||
@ -115,34 +84,14 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
|||||||
if not context.models.exists(model_key):
|
if not context.models.exists(model_key):
|
||||||
raise Exception(f"Unknown model: {model_key}")
|
raise Exception(f"Unknown model: {model_key}")
|
||||||
|
|
||||||
|
unet = self.model.model_copy(update={"submodel_type": SubModelType.UNet})
|
||||||
|
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
|
||||||
|
tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
|
||||||
|
text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
|
||||||
|
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
|
||||||
|
|
||||||
return SDXLRefinerModelLoaderOutput(
|
return SDXLRefinerModelLoaderOutput(
|
||||||
unet=UNetField(
|
unet=UNetField(unet=unet, scheduler=scheduler, loras=[]),
|
||||||
unet=ModelInfo(
|
clip2=ClipField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
|
||||||
key=model_key,
|
vae=VaeField(vae=vae),
|
||||||
submodel_type=SubModelType.UNet,
|
|
||||||
),
|
|
||||||
scheduler=ModelInfo(
|
|
||||||
key=model_key,
|
|
||||||
submodel_type=SubModelType.Scheduler,
|
|
||||||
),
|
|
||||||
loras=[],
|
|
||||||
),
|
|
||||||
clip2=ClipField(
|
|
||||||
tokenizer=ModelInfo(
|
|
||||||
key=model_key,
|
|
||||||
submodel_type=SubModelType.Tokenizer2,
|
|
||||||
),
|
|
||||||
text_encoder=ModelInfo(
|
|
||||||
key=model_key,
|
|
||||||
submodel_type=SubModelType.TextEncoder2,
|
|
||||||
),
|
|
||||||
loras=[],
|
|
||||||
skipped_layers=0,
|
|
||||||
),
|
|
||||||
vae=VaeField(
|
|
||||||
vae=ModelInfo(
|
|
||||||
key=model_key,
|
|
||||||
submodel_type=SubModelType.VAE,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
@ -10,17 +10,14 @@ from invokeai.app.invocations.baseinvocation import (
|
|||||||
)
|
)
|
||||||
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
|
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
|
||||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField
|
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField
|
||||||
|
from invokeai.app.invocations.model import ModelField
|
||||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
|
|
||||||
|
|
||||||
class T2IAdapterModelField(BaseModel):
|
|
||||||
key: str = Field(description="Model record key for the T2I-Adapter model")
|
|
||||||
|
|
||||||
|
|
||||||
class T2IAdapterField(BaseModel):
|
class T2IAdapterField(BaseModel):
|
||||||
image: ImageField = Field(description="The T2I-Adapter image prompt.")
|
image: ImageField = Field(description="The T2I-Adapter image prompt.")
|
||||||
t2i_adapter_model: T2IAdapterModelField = Field(description="The T2I-Adapter model to use.")
|
t2i_adapter_model: ModelField = Field(description="The T2I-Adapter model to use.")
|
||||||
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
|
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
|
||||||
begin_step_percent: float = Field(
|
begin_step_percent: float = Field(
|
||||||
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
|
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
|
||||||
@ -55,7 +52,7 @@ class T2IAdapterInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: ImageField = InputField(description="The IP-Adapter image prompt.")
|
image: ImageField = InputField(description="The IP-Adapter image prompt.")
|
||||||
t2i_adapter_model: T2IAdapterModelField = InputField(
|
t2i_adapter_model: ModelField = InputField(
|
||||||
description="The T2I-Adapter model.",
|
description="The T2I-Adapter model.",
|
||||||
title="T2I-Adapter Model",
|
title="T2I-Adapter Model",
|
||||||
input=Input.Direct,
|
input=Input.Direct,
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||||
"""Implementation of ModelManagerServiceBase."""
|
"""Implementation of ModelManagerServiceBase."""
|
||||||
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
@ -130,6 +130,17 @@ class ModelRecordServiceBase(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
|
||||||
|
"""
|
||||||
|
Retrieve the configuration for the indicated model.
|
||||||
|
|
||||||
|
:param hash: Hash of model config to be fetched.
|
||||||
|
|
||||||
|
Exceptions: UnknownModelException
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def list_models(
|
def list_models(
|
||||||
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
|
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
|
||||||
|
@ -203,6 +203,21 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
|
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
|
||||||
|
with self._db.lock:
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT config, strftime('%s',updated_at) FROM models
|
||||||
|
WHERE hash=?;
|
||||||
|
""",
|
||||||
|
(hash,),
|
||||||
|
)
|
||||||
|
rows = self._cursor.fetchone()
|
||||||
|
if not rows:
|
||||||
|
raise UnknownModelException("model not found")
|
||||||
|
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
|
||||||
|
return model
|
||||||
|
|
||||||
def exists(self, key: str) -> bool:
|
def exists(self, key: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Return True if a model with the indicated key exists in the databse.
|
Return True if a model with the indicated key exists in the databse.
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import threading
|
import threading
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Optional, Union
|
||||||
|
|
||||||
from PIL.Image import Image
|
from PIL.Image import Image
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
@ -13,15 +13,16 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
|||||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||||
from invokeai.app.services.images.images_common import ImageDTO
|
from invokeai.app.services.images.images_common import ImageDTO
|
||||||
from invokeai.app.services.invocation_services import InvocationServices
|
from invokeai.app.services.invocation_services import InvocationServices
|
||||||
|
from invokeai.app.services.model_records.model_records_base import UnknownModelException
|
||||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
|
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
|
||||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||||
from invokeai.backend.model_manager.metadata.metadata_base import AnyModelRepoMetadata
|
|
||||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||||
|
from invokeai.app.invocations.model import ModelField
|
||||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@ -299,22 +300,25 @@ class ConditioningInterface(InvocationContextInterface):
|
|||||||
|
|
||||||
|
|
||||||
class ModelsInterface(InvocationContextInterface):
|
class ModelsInterface(InvocationContextInterface):
|
||||||
def exists(self, key: str) -> bool:
|
def exists(self, identifier: Union[str, "ModelField"]) -> bool:
|
||||||
"""Checks if a model exists.
|
"""Checks if a model exists.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
key: The key of the model.
|
identifier: The key or ModelField representing the model.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if the model exists, False if not.
|
True if the model exists, False if not.
|
||||||
"""
|
"""
|
||||||
return self._services.model_manager.store.exists(key)
|
if isinstance(identifier, str):
|
||||||
|
return self._services.model_manager.store.exists(identifier)
|
||||||
|
|
||||||
def load(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
return self._services.model_manager.store.exists(identifier.key)
|
||||||
|
|
||||||
|
def load(self, identifier: Union[str, "ModelField"], submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||||
"""Loads a model.
|
"""Loads a model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
key: The key of the model.
|
identifier: The key or ModelField representing the model.
|
||||||
submodel_type: The submodel of the model to get.
|
submodel_type: The submodel of the model to get.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -324,9 +328,13 @@ class ModelsInterface(InvocationContextInterface):
|
|||||||
# The model manager emits events as it loads the model. It needs the context data to build
|
# The model manager emits events as it loads the model. It needs the context data to build
|
||||||
# the event payloads.
|
# the event payloads.
|
||||||
|
|
||||||
return self._services.model_manager.load_model_by_key(
|
if isinstance(identifier, str):
|
||||||
key=key, submodel_type=submodel_type, context_data=self._data
|
model = self._services.model_manager.store.get_model(identifier)
|
||||||
)
|
return self._services.model_manager.load.load_model(model, submodel_type, self._data)
|
||||||
|
else:
|
||||||
|
_submodel_type = submodel_type or identifier.submodel_type
|
||||||
|
model = self._services.model_manager.store.get_model(identifier.key)
|
||||||
|
return self._services.model_manager.load.load_model(model, _submodel_type, self._data)
|
||||||
|
|
||||||
def load_by_attrs(
|
def load_by_attrs(
|
||||||
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
|
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
|
||||||
@ -343,35 +351,29 @@ class ModelsInterface(InvocationContextInterface):
|
|||||||
Returns:
|
Returns:
|
||||||
An object representing the loaded model.
|
An object representing the loaded model.
|
||||||
"""
|
"""
|
||||||
return self._services.model_manager.load_model_by_attr(
|
|
||||||
model_name=name,
|
|
||||||
base_model=base,
|
|
||||||
model_type=type,
|
|
||||||
submodel=submodel_type,
|
|
||||||
context_data=self._data,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_config(self, key: str) -> AnyModelConfig:
|
configs = self._services.model_manager.store.search_by_attr(model_name=name, base_model=base, model_type=type)
|
||||||
|
if len(configs) == 0:
|
||||||
|
raise UnknownModelException(f"No model found with name {name}, base {base}, and type {type}")
|
||||||
|
|
||||||
|
if len(configs) > 1:
|
||||||
|
raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")
|
||||||
|
|
||||||
|
return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data)
|
||||||
|
|
||||||
|
def get_config(self, identifier: Union[str, "ModelField"]) -> AnyModelConfig:
|
||||||
"""Gets a model's config.
|
"""Gets a model's config.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
key: The key of the model.
|
identifier: The key or ModelField representing the model.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The model's config.
|
The model's config.
|
||||||
"""
|
"""
|
||||||
return self._services.model_manager.store.get_model(key=key)
|
if isinstance(identifier, str):
|
||||||
|
return self._services.model_manager.store.get_model(identifier)
|
||||||
|
|
||||||
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
|
return self._services.model_manager.store.get_model(identifier.key)
|
||||||
"""Gets a model's metadata, if it has any.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key: The key of the model.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The model's metadata, if it has any.
|
|
||||||
"""
|
|
||||||
return self._services.model_manager.store.get_metadata(key=key)
|
|
||||||
|
|
||||||
def search_by_path(self, path: Path) -> list[AnyModelConfig]:
|
def search_by_path(self, path: Path) -> list[AnyModelConfig]:
|
||||||
"""Searches for models by path.
|
"""Searches for models by path.
|
||||||
|
@ -22,7 +22,7 @@ def generate_ti_list(
|
|||||||
for trigger in extract_ti_triggers_from_prompt(prompt):
|
for trigger in extract_ti_triggers_from_prompt(prompt):
|
||||||
name_or_key = trigger[1:-1]
|
name_or_key = trigger[1:-1]
|
||||||
try:
|
try:
|
||||||
loaded_model = context.models.load(key=name_or_key)
|
loaded_model = context.models.load(name_or_key)
|
||||||
model = loaded_model.model
|
model = loaded_model.model
|
||||||
assert isinstance(model, TextualInversionModelRaw)
|
assert isinstance(model, TextualInversionModelRaw)
|
||||||
assert loaded_model.config.base == base
|
assert loaded_model.config.base == base
|
||||||
|
@ -35,17 +35,13 @@ from invokeai.app.invocations.metadata import MetadataItemField, MetadataItemOut
|
|||||||
from invokeai.app.invocations.model import (
|
from invokeai.app.invocations.model import (
|
||||||
ClipField,
|
ClipField,
|
||||||
CLIPOutput,
|
CLIPOutput,
|
||||||
LoraInfo,
|
|
||||||
LoraLoaderOutput,
|
LoraLoaderOutput,
|
||||||
LoRAModelField,
|
ModelField,
|
||||||
MainModelField,
|
|
||||||
ModelInfo,
|
|
||||||
ModelLoaderOutput,
|
ModelLoaderOutput,
|
||||||
SDXLLoraLoaderOutput,
|
SDXLLoraLoaderOutput,
|
||||||
UNetField,
|
UNetField,
|
||||||
UNetOutput,
|
UNetOutput,
|
||||||
VaeField,
|
VaeField,
|
||||||
VAEModelField,
|
|
||||||
VAEOutput,
|
VAEOutput,
|
||||||
)
|
)
|
||||||
from invokeai.app.invocations.primitives import (
|
from invokeai.app.invocations.primitives import (
|
||||||
@ -73,8 +69,8 @@ from invokeai.app.services.image_records.image_records_common import ImageCatego
|
|||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
from invokeai.backend.model_management.model_manager import LoadedModelInfo
|
from invokeai.backend.model_manager.config import BaseModelType, ModelType, SubModelType
|
||||||
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
|
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||||
BasicConditioningInfo,
|
BasicConditioningInfo,
|
||||||
@ -118,14 +114,10 @@ __all__ = [
|
|||||||
"MetadataItemOutput",
|
"MetadataItemOutput",
|
||||||
"MetadataOutput",
|
"MetadataOutput",
|
||||||
# invokeai.app.invocations.model
|
# invokeai.app.invocations.model
|
||||||
"ModelInfo",
|
"ModelField",
|
||||||
"LoraInfo",
|
|
||||||
"UNetField",
|
"UNetField",
|
||||||
"ClipField",
|
"ClipField",
|
||||||
"VaeField",
|
"VaeField",
|
||||||
"MainModelField",
|
|
||||||
"LoRAModelField",
|
|
||||||
"VAEModelField",
|
|
||||||
"UNetOutput",
|
"UNetOutput",
|
||||||
"VAEOutput",
|
"VAEOutput",
|
||||||
"CLIPOutput",
|
"CLIPOutput",
|
||||||
@ -166,7 +158,7 @@ __all__ = [
|
|||||||
# invokeai.app.services.config.config_default
|
# invokeai.app.services.config.config_default
|
||||||
"InvokeAIAppConfig",
|
"InvokeAIAppConfig",
|
||||||
# invokeai.backend.model_management.model_manager
|
# invokeai.backend.model_management.model_manager
|
||||||
"LoadedModelInfo",
|
"LoadedModel",
|
||||||
# invokeai.backend.model_management.models.base
|
# invokeai.backend.model_management.models.base
|
||||||
"BaseModelType",
|
"BaseModelType",
|
||||||
"ModelType",
|
"ModelType",
|
||||||
|
Reference in New Issue
Block a user