https://github.com/invoke-ai/InvokeAI
feat(nodes): update invocation context for mm2, update nodes model usage
parent 7a36cd2832
commit 8958e820c8
@@ -69,20 +69,12 @@ class CompelInvocation(BaseInvocation):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
-        tokenizer_info = context.services.model_manager.load.load_model_by_key(
-            **self.clip.tokenizer.model_dump(),
-            context=context,
-        )
-        text_encoder_info = context.services.model_manager.load.load_model_by_key(
-            **self.clip.text_encoder.model_dump(),
-            context=context,
-        )
+        tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
+        text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
 
         def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in self.clip.loras:
-                lora_info = context.services.model_manager.load.load_model_by_key(
-                    **lora.model_dump(exclude={"weight"}), context=context
-                )
+                lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
                 assert isinstance(lora_info.model, LoRAModelRaw)
                 yield (lora_info.model, lora.weight)
                 del lora_info
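
The change above is the migration pattern repeated throughout this commit: node code stops reaching into context.services.model_manager.load.load_model_by_key(..., context=context) and instead calls the new context.models.load() facade, which supplies the invocation context data itself. A minimal before/after sketch (the field contents are illustrative; a submodel field such as self.clip.tokenizer is assumed to dump to kwargs like {"key": "...", "submodel_type": ...}):

    # Old API: the node hands the whole InvocationContext to the loader service.
    tokenizer_info = context.services.model_manager.load.load_model_by_key(
        **self.clip.tokenizer.model_dump(),
        context=context,
    )

    # New API: context.models.load() forwards the context data internally,
    # so the same load collapses to one line.
    tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
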
@@ -94,10 +86,7 @@ class CompelInvocation(BaseInvocation):
         for trigger in extract_ti_triggers_from_prompt(self.prompt):
             name = trigger[1:-1]
             try:
-                loaded_model = context.services.model_manager.load.load_model_by_key(
-                    **self.clip.text_encoder.model_dump(),
-                    context=context,
-                ).model
+                loaded_model = context.models.load(**self.clip.text_encoder.model_dump()).model
                 assert isinstance(loaded_model, TextualInversionModelRaw)
                 ti_list.append((name, loaded_model))
             except UnknownModelException:
@@ -165,14 +154,8 @@ class SDXLPromptInvocationBase:
         lora_prefix: str,
         zero_on_empty: bool,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
-        tokenizer_info = context.services.model_manager.load.load_model_by_key(
-            **clip_field.tokenizer.model_dump(),
-            context=context,
-        )
-        text_encoder_info = context.services.model_manager.load.load_model_by_key(
-            **clip_field.text_encoder.model_dump(),
-            context=context,
-        )
+        tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump())
+        text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
 
         # return zero on empty
         if prompt == "" and zero_on_empty:
@@ -197,9 +180,7 @@ class SDXLPromptInvocationBase:
 
         def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in clip_field.loras:
-                lora_info = context.services.model_manager.load.load_model_by_key(
-                    **lora.model_dump(exclude={"weight"}), context=context
-                )
+                lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
                 lora_model = lora_info.model
                 assert isinstance(lora_model, LoRAModelRaw)
                 yield (lora_model, lora.weight)
@@ -212,11 +193,8 @@ class SDXLPromptInvocationBase:
         for trigger in extract_ti_triggers_from_prompt(prompt):
             name = trigger[1:-1]
             try:
-                ti_model = context.services.model_manager.load.load_model_by_attr(
-                    model_name=name,
-                    base_model=text_encoder_info.config.base,
-                    model_type=ModelType.TextualInversion,
-                    context=context,
+                ti_model = context.models.load_by_attrs(
+                    model_name=name, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion
                 ).model
                 assert isinstance(ti_model, TextualInversionModelRaw)
                 ti_list.append((name, ti_model))
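
Textual inversions are looked up by name rather than by key, so they go through the new load_by_attrs() method instead of load(). A sketch of the call, assuming a hypothetical trigger "<my-embedding>" was extracted from the prompt:

    # Strip the angle brackets from the hypothetical trigger "<my-embedding>".
    name = trigger[1:-1]
    ti_model = context.models.load_by_attrs(
        model_name=name,
        base_model=text_encoder_info.config.base,  # match the text encoder's base model
        model_type=ModelType.TextualInversion,
    ).model
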
@@ -14,8 +14,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
 from invokeai.app.invocations.primitives import ImageField
 from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.model_management.models.base import BaseModelType, ModelType
-from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id
+from invokeai.backend.model_manager.config import BaseModelType, ModelType
 
 
 # LS: Consider moving these two classes into model.py
@@ -90,10 +89,10 @@ class IPAdapterInvocation(BaseInvocation):
 
     def invoke(self, context: InvocationContext) -> IPAdapterOutput:
         # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
-        ip_adapter_info = context.services.model_manager.store.get_model(self.ip_adapter_model.key)
+        ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
         image_encoder_model_id = ip_adapter_info.image_encoder_model_id
         image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
-        image_encoder_models = context.services.model_manager.store.search_by_attr(
+        image_encoder_models = context.models.search_by_attrs(
             model_name=image_encoder_model_name, base_model=BaseModelType.Any, model_type=ModelType.CLIPVision
         )
         assert len(image_encoder_models) == 1
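
Config lookups follow the same shape as loads: context.models.get_config() replaces model_manager.store.get_model(), and context.models.search_by_attrs() replaces store.search_by_attr(). A sketch of the lookup above, assuming the IP-Adapter config carries a HuggingFace-style image_encoder_model_id such as "InvokeAI/ip_adapter_sd_image_encoder" (illustrative value):

    ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
    # "InvokeAI/ip_adapter_sd_image_encoder" -> "ip_adapter_sd_image_encoder"
    image_encoder_model_name = ip_adapter_info.image_encoder_model_id.split("/")[-1].strip()
    image_encoder_models = context.models.search_by_attrs(
        model_name=image_encoder_model_name,
        base_model=BaseModelType.Any,
        model_type=ModelType.CLIPVision,
    )
    assert len(image_encoder_models) == 1  # exactly one installed encoder is expected
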
@@ -141,7 +141,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
         if self.image is not None:
-            image = context.services.images.get_pil_image(self.image.image_name)
+            image = context.images.get_pil(self.image.image_name)
             image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
             if image_tensor.dim() == 3:
                 image_tensor = image_tensor.unsqueeze(0)
@@ -153,10 +153,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
             )
 
         if image_tensor is not None:
-            vae_info = context.services.model_manager.load.load_model_by_key(
-                **self.vae.vae.model_dump(),
-                context=context,
-            )
+            vae_info = context.models.load(**self.vae.vae.model_dump())
 
             img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
             masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
@@ -182,10 +179,7 @@ def get_scheduler(
     seed: int,
 ) -> Scheduler:
     scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
-    orig_scheduler_info = context.services.model_manager.load.load_model_by_key(
-        **scheduler_info.model_dump(),
-        context=context,
-    )
+    orig_scheduler_info = context.models.load(**scheduler_info.model_dump())
     with orig_scheduler_info as orig_scheduler:
         scheduler_config = orig_scheduler.config
 
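
Note that the object returned by context.models.load() is still used as a context manager, just as before the refactor: entering it yields the underlying model and keeps it resident in the model cache for the duration of the block. A sketch:

    orig_scheduler_info = context.models.load(**scheduler_info.model_dump())
    with orig_scheduler_info as orig_scheduler:
        # The model is only guaranteed to be loaded inside the with-block.
        scheduler_config = orig_scheduler.config
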
@@ -399,12 +393,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
             # and if weight is None, populate with default 1.0?
             controlnet_data = []
             for control_info in control_list:
-                control_model = exit_stack.enter_context(
-                    context.services.model_manager.load.load_model_by_key(
-                        key=control_info.control_model.key,
-                        context=context,
-                    )
-                )
+                control_model = exit_stack.enter_context(context.models.load(key=control_info.control_model.key))
 
                 # control_models.append(control_model)
                 control_image_field = control_info.image
@@ -466,25 +455,17 @@ class DenoiseLatentsInvocation(BaseInvocation):
         conditioning_data.ip_adapter_conditioning = []
         for single_ip_adapter in ip_adapter:
             ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
-                context.services.model_manager.load.load_model_by_key(
-                    key=single_ip_adapter.ip_adapter_model.key,
-                    context=context,
-                )
+                context.models.load(key=single_ip_adapter.ip_adapter_model.key)
             )
 
-            image_encoder_model_info = context.services.model_manager.load.load_model_by_key(
-                key=single_ip_adapter.image_encoder_model.key,
-                context=context,
-            )
+            image_encoder_model_info = context.models.load(key=single_ip_adapter.image_encoder_model.key)
 
             # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
             single_ipa_image_fields = single_ip_adapter.image
             if not isinstance(single_ipa_image_fields, list):
                 single_ipa_image_fields = [single_ipa_image_fields]
 
-            single_ipa_images = [
-                context.services.images.get_pil_image(image.image_name) for image in single_ipa_image_fields
-            ]
+            single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_image_fields]
 
             # TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
             # models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
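
Where several models must stay loaded across the whole denoising loop, the loads are entered on an ExitStack so they are all released together when the stack unwinds. A condensed sketch of the pattern above:

    from contextlib import ExitStack

    with ExitStack() as exit_stack:
        for single_ip_adapter in ip_adapter:
            # enter_context() keeps the loaded model alive until the stack exits.
            ip_adapter_model = exit_stack.enter_context(
                context.models.load(key=single_ip_adapter.ip_adapter_model.key)
            )
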
@@ -528,10 +509,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
 
         t2i_adapter_data = []
         for t2i_adapter_field in t2i_adapter:
-            t2i_adapter_model_info = context.services.model_manager.load.load_model_by_key(
-                key=t2i_adapter_field.t2i_adapter_model.key,
-                context=context,
-            )
+            t2i_adapter_model_info = context.models.load(key=t2i_adapter_field.t2i_adapter_model.key)
             image = context.images.get_pil(t2i_adapter_field.image.image_name)
 
             # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
@@ -676,30 +654,20 @@ class DenoiseLatentsInvocation(BaseInvocation):
             do_classifier_free_guidance=True,
         )
 
-        # Get the source node id (we are invoking the prepared node)
-        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
-        source_node_id = graph_execution_state.prepared_source_mapping[self.id]
-
         # get the unet's config so that we can pass the base to dispatch_progress()
-        unet_config = context.services.model_manager.store.get_model(self.unet.unet.key)
+        unet_config = context.models.get_config(self.unet.unet.key)
 
         def step_callback(state: PipelineIntermediateState) -> None:
-            self.dispatch_progress(context, source_node_id, state, unet_config.base)
+            context.util.sd_step_callback(state, unet_config.base)
 
         def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in self.unet.loras:
-                lora_info = context.services.model_manager.load.load_model_by_key(
-                    **lora.model_dump(exclude={"weight"}),
-                    context=context,
-                )
+                lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
                 yield (lora_info.model, lora.weight)
                 del lora_info
             return
 
-        unet_info = context.services.model_manager.load.load_model_by_key(
-            **self.unet.unet.model_dump(),
-            context=context,
-        )
+        unet_info = context.models.load(**self.unet.unet.model_dump())
         assert isinstance(unet_info.model, UNet2DConditionModel)
         with (
             ExitStack() as exit_stack,
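
Progress reporting moves with the same refactor. Nodes no longer fetch the graph execution state and look up their own source node id; context.util.sd_step_callback() now carries that bookkeeping. A sketch:

    # Get the UNet's config so the callback knows which base model is running.
    unet_config = context.models.get_config(self.unet.unet.key)

    def step_callback(state: PipelineIntermediateState) -> None:
        # Queue ids and the source node id are resolved inside the context.
        context.util.sd_step_callback(state, unet_config.base)
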
@@ -806,10 +774,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     def invoke(self, context: InvocationContext) -> ImageOutput:
         latents = context.tensors.load(self.latents.latents_name)
 
-        vae_info = context.services.model_manager.load.load_model_by_key(
-            **self.vae.vae.model_dump(),
-            context=context,
-        )
+        vae_info = context.models.load(**self.vae.vae.model_dump())
 
         with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
             assert isinstance(vae, torch.nn.Module)
@@ -1032,10 +997,7 @@ class ImageToLatentsInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         image = context.images.get_pil(self.image.image_name)
 
-        vae_info = context.services.model_manager.load.load_model_by_key(
-            **self.vae.vae.model_dump(),
-            context=context,
-        )
+        vae_info = context.models.load(**self.vae.vae.model_dump())
 
         image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
         if image_tensor.dim() == 3:
@@ -1239,10 +1201,7 @@ class IdealSizeInvocation(BaseInvocation):
         return tuple((x - x % multiple_of) for x in args)
 
     def invoke(self, context: InvocationContext) -> IdealSizeOutput:
-        unet_config = context.services.model_manager.load.load_model_by_key(
-            **self.unet.unet.model_dump(),
-            context=context,
-        )
+        unet_config = context.models.get_config(**self.unet.unet.model_dump())
         aspect = self.width / self.height
         dimension: float = 512
         if unet_config.base == BaseModelType.StableDiffusion2:

@@ -103,7 +103,7 @@ class MainModelLoaderInvocation(BaseInvocation):
         key = self.model.key
 
         # TODO: not found exceptions
-        if not context.services.model_manager.store.exists(key):
+        if not context.models.exists(key):
             raise Exception(f"Unknown model {key}")
 
         return ModelLoaderOutput(
@@ -172,7 +172,7 @@ class LoraLoaderInvocation(BaseInvocation):
 
         lora_key = self.lora.key
 
-        if not context.services.model_manager.store.exists(lora_key):
+        if not context.models.exists(lora_key):
             raise Exception(f"Unkown lora: {lora_key}!")
 
         if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
@@ -252,7 +252,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
 
         lora_key = self.lora.key
 
-        if not context.services.model_manager.store.exists(lora_key):
+        if not context.models.exists(lora_key):
             raise Exception(f"Unknown lora: {lora_key}!")
 
         if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
@@ -318,7 +318,7 @@ class VaeLoaderInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> VAEOutput:
         key = self.vae_model.key
 
-        if not context.services.model_manager.store.exists(key):
+        if not context.models.exists(key):
             raise Exception(f"Unkown vae: {key}!")
 
         return VAEOutput(vae=VaeField(vae=ModelInfo(key=key)))
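
All of the loader nodes share the same existence check, now a single call against the context. A sketch of the guard, with a hypothetical model key:

    key = self.vae_model.key  # hypothetical key, e.g. "f1a3..."
    if not context.models.exists(key):
        raise Exception(f"Unknown model: {key}")
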
@@ -43,7 +43,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
         model_key = self.model.key
 
         # TODO: not found exceptions
-        if not context.services.model_manager.store.exists(model_key):
+        if not context.models.exists(model_key):
             raise Exception(f"Unknown model: {model_key}")
 
         return SDXLModelLoaderOutput(
@@ -112,7 +112,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
         model_key = self.model.key
 
         # TODO: not found exceptions
-        if not context.services.model_manager.store.exists(model_key):
+        if not context.models.exists(model_key):
             raise Exception(f"Unknown model: {model_key}")
 
         return SDXLRefinerModelLoaderOutput(

@@ -4,7 +4,7 @@
 from abc import ABC, abstractmethod
 from typing import Optional
 
-from invokeai.app.invocations.baseinvocation import InvocationContext
+from invokeai.app.services.shared.invocation_context import InvocationContextData
 from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
 from invokeai.backend.model_manager.load import LoadedModel
 from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
@@ -19,14 +19,14 @@ class ModelLoadServiceBase(ABC):
         self,
         key: str,
         submodel_type: Optional[SubModelType] = None,
-        context: Optional[InvocationContext] = None,
+        context_data: Optional[InvocationContextData] = None,
     ) -> LoadedModel:
         """
         Given a model's key, load it and return the LoadedModel object.
 
         :param key: Key of model config to be fetched.
         :param submodel: For main (pipeline models), the submodel to fetch.
-        :param context: Invocation context used for event reporting
+        :param context_data: Invocation context data used for event reporting
         """
         pass
 
@@ -35,14 +35,14 @@ class ModelLoadServiceBase(ABC):
         self,
         model_config: AnyModelConfig,
         submodel_type: Optional[SubModelType] = None,
-        context: Optional[InvocationContext] = None,
+        context_data: Optional[InvocationContextData] = None,
     ) -> LoadedModel:
         """
         Given a model's configuration, load it and return the LoadedModel object.
 
         :param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
         :param submodel: For main (pipeline models), the submodel to fetch.
-        :param context: Invocation context used for event reporting
+        :param context_data: Invocation context data used for event reporting
         """
         pass
 
@@ -53,7 +53,7 @@ class ModelLoadServiceBase(ABC):
         base_model: BaseModelType,
         model_type: ModelType,
         submodel: Optional[SubModelType] = None,
-        context: Optional[InvocationContext] = None,
+        context_data: Optional[InvocationContextData] = None,
     ) -> LoadedModel:
         """
         Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
@@ -66,7 +66,7 @@ class ModelLoadServiceBase(ABC):
         :param base_model: Base model
         :param model_type: Type of the model
         :param submodel: For main (pipeline models), the submodel to fetch
-        :param context: The invocation context.
+        :param context_data: The invocation context data.
 
         Exceptions: UnknownModelException -- model with these attributes not known
             NotImplementedException -- a model loader was not provided at initialization time
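
On the service side, the context: Optional[InvocationContext] parameter becomes context_data: Optional[InvocationContextData] throughout: the loader no longer receives the full invocation context (which carries live service handles), only the plain data record it needs to build event payloads. A sketch of a caller under that assumption:

    loaded = model_load_service.load_model_by_key(
        key="some-model-key",             # hypothetical key
        submodel_type=SubModelType.UNet,  # optional; main (pipeline) models only
        context_data=context_data,        # enables load started/completed events
    )
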
@@ -3,10 +3,11 @@
 
 from typing import Optional
 
-from invokeai.app.invocations.baseinvocation import InvocationContext
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
+from invokeai.app.services.invoker import Invoker
 from invokeai.app.services.model_records import ModelRecordServiceBase, UnknownModelException
+from invokeai.app.services.shared.invocation_context import InvocationContextData
 from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
 from invokeai.backend.model_manager.load import AnyModelLoader, LoadedModel, ModelCache, ModelConvertCache
 from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
@@ -46,6 +47,9 @@ class ModelLoadService(ModelLoadServiceBase):
             ),
         )
 
+    def start(self, invoker: Invoker) -> None:
+        self._invoker = invoker
+
     @property
     def ram_cache(self) -> ModelCacheBase[AnyModel]:
         """Return the RAM cache used by this loader."""
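
The new start() hook is how the service reaches events and the session queue without an InvocationContext: the Invoker calls start() on each service at application startup (the usual service lifecycle here), and the loader stashes the reference. A sketch of the resulting wiring:

    class ModelLoadService(ModelLoadServiceBase):
        _invoker: Optional[Invoker] = None

        def start(self, invoker: Invoker) -> None:
            # Called once at startup; provides a route to
            # invoker.services.events and invoker.services.queue.
            self._invoker = invoker
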
@@ -60,7 +64,7 @@ class ModelLoadService(ModelLoadServiceBase):
         self,
         key: str,
         submodel_type: Optional[SubModelType] = None,
-        context: Optional[InvocationContext] = None,
+        context_data: Optional[InvocationContextData] = None,
     ) -> LoadedModel:
         """
         Given a model's key, load it and return the LoadedModel object.
@@ -70,7 +74,7 @@ class ModelLoadService(ModelLoadServiceBase):
         :param context: Invocation context used for event reporting
         """
         config = self._store.get_model(key)
-        return self.load_model_by_config(config, submodel_type, context)
+        return self.load_model_by_config(config, submodel_type, context_data)
 
     def load_model_by_attr(
         self,
@@ -78,7 +82,7 @@ class ModelLoadService(ModelLoadServiceBase):
         base_model: BaseModelType,
         model_type: ModelType,
         submodel: Optional[SubModelType] = None,
-        context: Optional[InvocationContext] = None,
+        context_data: Optional[InvocationContextData] = None,
     ) -> LoadedModel:
         """
         Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
@@ -109,7 +113,7 @@ class ModelLoadService(ModelLoadServiceBase):
         self,
         model_config: AnyModelConfig,
         submodel_type: Optional[SubModelType] = None,
-        context: Optional[InvocationContext] = None,
+        context_data: Optional[InvocationContextData] = None,
     ) -> LoadedModel:
         """
         Given a model's configuration, load it and return the LoadedModel object.
@@ -118,15 +122,15 @@ class ModelLoadService(ModelLoadServiceBase):
         :param submodel: For main (pipeline models), the submodel to fetch.
         :param context: Invocation context used for event reporting
         """
-        if context:
+        if context_data:
             self._emit_load_event(
-                context=context,
+                context_data=context_data,
                 model_config=model_config,
             )
         loaded_model = self._any_loader.load_model(model_config, submodel_type)
-        if context:
+        if context_data:
             self._emit_load_event(
-                context=context,
+                context_data=context_data,
                 model_config=model_config,
                 loaded=True,
             )
@@ -134,26 +138,28 @@ class ModelLoadService(ModelLoadServiceBase):
 
     def _emit_load_event(
         self,
-        context: InvocationContext,
+        context_data: InvocationContextData,
         model_config: AnyModelConfig,
         loaded: Optional[bool] = False,
     ) -> None:
-        if context.services.queue.is_canceled(context.graph_execution_state_id):
+        if not self._invoker:
+            return
+        if self._invoker.services.queue.is_canceled(context_data.session_id):
             raise CanceledException()
 
         if not loaded:
-            context.services.events.emit_model_load_started(
-                queue_id=context.queue_id,
-                queue_item_id=context.queue_item_id,
-                queue_batch_id=context.queue_batch_id,
-                graph_execution_state_id=context.graph_execution_state_id,
+            self._invoker.services.events.emit_model_load_started(
+                queue_id=context_data.queue_id,
+                queue_item_id=context_data.queue_item_id,
+                queue_batch_id=context_data.batch_id,
+                graph_execution_state_id=context_data.session_id,
                 model_config=model_config,
             )
         else:
-            context.services.events.emit_model_load_completed(
-                queue_id=context.queue_id,
-                queue_item_id=context.queue_item_id,
-                queue_batch_id=context.queue_batch_id,
-                graph_execution_state_id=context.graph_execution_state_id,
+            self._invoker.services.events.emit_model_load_completed(
+                queue_id=context_data.queue_id,
+                queue_item_id=context_data.queue_item_id,
+                queue_batch_id=context_data.batch_id,
+                graph_execution_state_id=context_data.session_id,
                 model_config=model_config,
             )
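
Because the invoker is now optional state rather than a guaranteed argument, _emit_load_event() gains an early return: if the service was never started (for example, when used outside the running app), events are silently skipped instead of crashing. The control flow, condensed:

    def _emit_load_event(self, context_data, model_config, loaded=False):
        if not self._invoker:
            return  # no invoker, so there is nowhere to emit events to
        if self._invoker.services.queue.is_canceled(context_data.session_id):
            raise CanceledException()  # abort the load if the session was canceled
        # ...emit model_load_started or model_load_completed as appropriate...
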
@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+from pathlib import Path
 from typing import TYPE_CHECKING, Optional
 
 from PIL.Image import Image
@@ -12,8 +13,9 @@ from invokeai.app.services.images.images_common import ImageDTO
 from invokeai.app.services.invocation_services import InvocationServices
 from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
-from invokeai.backend.model_management.model_manager import LoadedModelInfo
-from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
+from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
+from invokeai.backend.model_manager.load.load_base import LoadedModel
+from invokeai.backend.model_manager.metadata.metadata_base import AnyModelRepoMetadata
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
 
@@ -259,45 +261,95 @@ class ConditioningInterface(InvocationContextInterface):
 
 
 class ModelsInterface(InvocationContextInterface):
-    def exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> bool:
+    def exists(self, key: str) -> bool:
         """
         Checks if a model exists.
 
-        :param model_name: The name of the model to check.
-        :param base_model: The base model of the model to check.
-        :param model_type: The type of the model to check.
+        :param key: The key of the model.
         """
-        return self._services.model_manager.model_exists(model_name, base_model, model_type)
+        return self._services.model_manager.store.exists(key)
 
-    def load(
-        self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None
-    ) -> LoadedModelInfo:
+    def load(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
         """
         Loads a model.
 
-        :param model_name: The name of the model to get.
-        :param base_model: The base model of the model to get.
-        :param model_type: The type of the model to get.
-        :param submodel: The submodel of the model to get.
+        :param key: The key of the model.
+        :param submodel_type: The submodel of the model to get.
         :returns: An object representing the loaded model.
         """
 
         # The model manager emits events as it loads the model. It needs the context data to build
         # the event payloads.
 
-        return self._services.model_manager.get_model(
-            model_name, base_model, model_type, submodel, context_data=self._context_data
+        return self._services.model_manager.load.load_model_by_key(
+            key=key, submodel_type=submodel_type, context_data=self._context_data
         )
 
-    def get_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
+    def load_by_attrs(
+        self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None
+    ) -> LoadedModel:
+        """
+        Loads a model by its attributes.
+
+        :param model_name: Name of to be fetched.
+        :param base_model: Base model
+        :param model_type: Type of the model
+        :param submodel: For main (pipeline models), the submodel to fetch
+        """
+        return self._services.model_manager.load.load_model_by_attr(
+            model_name=model_name,
+            base_model=base_model,
+            model_type=model_type,
+            submodel=submodel,
+            context_data=self._context_data,
+        )
+
+    def get_config(self, key: str) -> AnyModelConfig:
         """
         Gets a model's info, an dict-like object.
 
-        :param model_name: The name of the model to get.
-        :param base_model: The base model of the model to get.
-        :param model_type: The type of the model to get.
+        :param key: The key of the model.
         """
-        return self._services.model_manager.model_info(model_name, base_model, model_type)
+        return self._services.model_manager.store.get_model(key=key)
+
+    def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
+        """
+        Gets a model's metadata, if it has any.
+
+        :param key: The key of the model.
+        """
+        return self._services.model_manager.store.get_metadata(key=key)
+
+    def search_by_path(self, path: Path) -> list[AnyModelConfig]:
+        """
+        Searches for models by path.
+
+        :param path: The path to search for.
+        """
+        return self._services.model_manager.store.search_by_path(path)
+
+    def search_by_attrs(
+        self,
+        model_name: Optional[str] = None,
+        base_model: Optional[BaseModelType] = None,
+        model_type: Optional[ModelType] = None,
+        model_format: Optional[ModelFormat] = None,
+    ) -> list[AnyModelConfig]:
+        """
+        Searches for models by attributes.
+
+        :param model_name: Name of to be fetched.
+        :param base_model: Base model
+        :param model_type: Type of the model
+        :param submodel: For main (pipeline models), the submodel to fetch
+        """
+
+        return self._services.model_manager.store.search_by_attr(
+            model_name=model_name,
+            base_model=base_model,
+            model_type=model_type,
+            model_format=model_format,
+        )
 
 
 class ConfigInterface(InvocationContextInterface):
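
Taken together, ModelsInterface gives nodes a key-centric surface over the model manager. A usage sketch from node code (keys and names are hypothetical):

    key = "abc123"  # hypothetical model key
    if context.models.exists(key):
        config = context.models.get_config(key)      # AnyModelConfig record
        metadata = context.models.get_metadata(key)  # None if no repo metadata
        loaded_model = context.models.load(key)      # LoadedModel, a context manager

    # Attribute-based search returns zero or more matching configs.
    matches = context.models.search_by_attrs(
        model_name="clip-vit-large",  # hypothetical name
        model_type=ModelType.CLIPVision,
    )
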
@@ -4,8 +4,8 @@ import torch
 from PIL import Image
 
 from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException, ProgressImage
+from invokeai.backend.model_manager.config import BaseModelType
 
-from ...backend.model_management.models import BaseModelType
 from ...backend.stable_diffusion import PipelineIntermediateState
 from ...backend.util.util import image_to_dataURL
 