Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
feat(nodes): update all invocations to use new invocation context
Update all invocations to use the new context. The changes are all fairly simple, but there are a lot of them.

Supporting minor changes:
- Patch bump for all nodes that use the context
- Update invocation processor to provide new context
- Minor change to `EventServiceBase` to accept a node's ID instead of the dict version of a node
- Minor change to `ModelManagerService` to support the new wrapped context
- Finagling of imports to avoid circular dependencies
This commit is contained in:
parent 97a6c6eea7, commit 7e5ba2795e
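The migration pattern is the same across nodes: service-level plumbing in each `invoke()` collapses into a few intent-level calls on the wrapped context. A minimal before/after sketch, condensed from the image-handling hunks below (method bodies abridged):

    # Before: nodes reached into raw services and threaded bookkeeping through every call.
    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get_pil_image(self.image.image_name)
        image_dto = context.services.images.create(
            image=image,
            image_origin=ResourceOrigin.INTERNAL,
            image_category=ImageCategory.GENERAL,
            node_id=self.id,
            session_id=context.graph_execution_state_id,
            is_intermediate=self.is_intermediate,
            workflow=context.workflow,
        )
        return ImageOutput(
            image=ImageField(image_name=image_dto.image_name),
            width=image_dto.width,
            height=image_dto.height,
        )

    # After: the wrapped context supplies origin, session, and intermediate bookkeeping itself.
    def invoke(self, context) -> ImageOutput:
        image = context.images.get_pil(self.image.image_name)
        image_dto = context.images.save(image=image)
        return ImageOutput.build(image_dto)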
invokeai/app/invocations/baseinvocation.py

@@ -16,10 +16,16 @@ from pydantic import BaseModel, ConfigDict, Field, create_model
 from pydantic.fields import FieldInfo
 from pydantic_core import PydanticUndefined
 
-from invokeai.app.invocations.fields import FieldKind, Input
+from invokeai.app.invocations.fields import (
+    FieldDescriptions,
+    FieldKind,
+    Input,
+    InputFieldJSONSchemaExtra,
+    MetadataField,
+    logger,
+)
 from invokeai.app.services.config.config_default import InvokeAIAppConfig
-from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
-from invokeai.app.shared.fields import FieldDescriptions
+from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.metaenum import MetaEnum
 from invokeai.app.util.misc import uuid_string
 from invokeai.backend.util.logging import InvokeAILogger
@@ -219,7 +225,7 @@ class BaseInvocation(ABC, BaseModel):
         """Invoke with provided context and return outputs."""
         pass
 
-    def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput:
+    def invoke_internal(self, context: InvocationContext, services: "InvocationServices") -> BaseInvocationOutput:
         """
         Internal invoke method, calls `invoke()` after some prep.
         Handles optional fields that are required to call `invoke()` and invocation cache.
@@ -244,23 +250,23 @@ class BaseInvocation(ABC, BaseModel):
                 raise MissingInputException(self.model_fields["type"].default, field_name)
 
         # skip node cache codepath if it's disabled
-        if context.services.configuration.node_cache_size == 0:
+        if services.configuration.node_cache_size == 0:
            return self.invoke(context)
 
        output: BaseInvocationOutput
        if self.use_cache:
-            key = context.services.invocation_cache.create_key(self)
-            cached_value = context.services.invocation_cache.get(key)
+            key = services.invocation_cache.create_key(self)
+            cached_value = services.invocation_cache.get(key)
            if cached_value is None:
-                context.services.logger.debug(f'Invocation cache miss for type "{self.get_type()}": {self.id}')
+                services.logger.debug(f'Invocation cache miss for type "{self.get_type()}": {self.id}')
                output = self.invoke(context)
-                context.services.invocation_cache.save(key, output)
+                services.invocation_cache.save(key, output)
                return output
            else:
-                context.services.logger.debug(f'Invocation cache hit for type "{self.get_type()}": {self.id}')
+                services.logger.debug(f'Invocation cache hit for type "{self.get_type()}": {self.id}')
                return cached_value
        else:
-            context.services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}')
+            services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}')
            return self.invoke(context)
 
    id: str = Field(
@@ -513,3 +519,29 @@ def invocation_output(
         return cls
 
     return wrapper
+
+
+class WithMetadata(BaseModel):
+    """
+    Inherit from this class if your node needs a metadata input field.
+    """
+
+    metadata: Optional[MetadataField] = Field(
+        default=None,
+        description=FieldDescriptions.metadata,
+        json_schema_extra=InputFieldJSONSchemaExtra(
+            field_kind=FieldKind.Internal,
+            input=Input.Connection,
+            orig_required=False,
+        ).model_dump(exclude_none=True),
+    )
+
+
+class WithWorkflow:
+    workflow = None
+
+    def __init_subclass__(cls) -> None:
+        logger.warn(
+            f"{cls.__module__.split('.')[0]}.{cls.__name__}: WithWorkflow is deprecated. Use `context.workflow` to access the workflow."
+        )
+        super().__init_subclass__()
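The two classes added at the bottom of baseinvocation.py are mixins for node authors. Opting into a metadata input is a matter of inheritance; a sketch with a hypothetical node (the node name and fields here are illustrative, not part of this commit):

    @invocation("example_passthrough", title="Example", tags=["image"], category="image", version="1.0.0")
    class ExamplePassthroughInvocation(BaseInvocation, WithMetadata):
        # WithMetadata contributes the optional `metadata` connection input.
        image: ImageField = InputField(description="The image to pass through")

        def invoke(self, context) -> ImageOutput:
            image = context.images.get_pil(self.image.image_name)
            # Metadata and workflow are presumably attached by the wrapped context on save.
            image_dto = context.images.save(image=image)
            return ImageOutput.build(image_dto)

`WithWorkflow` survives only as a deprecation shim: defining any subclass fires the `logger.warn` in `__init_subclass__`, steering authors toward `context.workflow` instead.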
invokeai/app/invocations/collections.py

@@ -7,7 +7,7 @@ from pydantic import ValidationInfo, field_validator
 from invokeai.app.invocations.primitives import IntegerCollectionOutput
 from invokeai.app.util.misc import SEED_MAX
 
-from .baseinvocation import BaseInvocation, InvocationContext, invocation
+from .baseinvocation import BaseInvocation, invocation
 from .fields import InputField
 
 
@@ -27,7 +27,7 @@ class RangeInvocation(BaseInvocation):
             raise ValueError("stop must be greater than start")
         return v
 
-    def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
+    def invoke(self, context) -> IntegerCollectionOutput:
         return IntegerCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
 
 
@@ -45,7 +45,7 @@ class RangeOfSizeInvocation(BaseInvocation):
     size: int = InputField(default=1, gt=0, description="The number of values")
     step: int = InputField(default=1, description="The step of the range")
 
-    def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
+    def invoke(self, context) -> IntegerCollectionOutput:
         return IntegerCollectionOutput(
             collection=list(range(self.start, self.start + (self.step * self.size), self.step))
         )
@@ -72,6 +72,6 @@ class RandomRangeInvocation(BaseInvocation):
         description="The seed for the RNG (omit for random)",
     )
 
-    def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
+    def invoke(self, context) -> IntegerCollectionOutput:
         rng = np.random.default_rng(self.seed)
         return IntegerCollectionOutput(collection=list(rng.integers(low=self.low, high=self.high, size=self.size)))
invokeai/app/invocations/compel.py

@@ -1,12 +1,18 @@
-from dataclasses import dataclass
-from typing import List, Optional, Union
+from typing import TYPE_CHECKING, List, Optional, Union
 
 import torch
 from compel import Compel, ReturnedEmbeddingsType
 from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
 
-from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent
-from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput
+from invokeai.app.invocations.fields import (
+    ConditioningFieldData,
+    FieldDescriptions,
+    Input,
+    InputField,
+    OutputField,
+    UIComponent,
+)
+from invokeai.app.invocations.primitives import ConditioningOutput
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
     ExtraConditioningInfo,
@@ -20,16 +26,14 @@ from ..util.ti_utils import extract_ti_triggers_from_prompt
 from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
-    InvocationContext,
     invocation,
     invocation_output,
 )
 from .model import ClipField
 
+if TYPE_CHECKING:
+    from invokeai.app.services.shared.invocation_context import InvocationContext
 
-@dataclass
-class ConditioningFieldData:
-    conditionings: List[BasicConditioningInfo]
     # unconditioned: Optional[torch.Tensor]
 
 
@@ -44,7 +48,7 @@ class ConditioningFieldData:
     title="Prompt",
     tags=["prompt", "compel"],
     category="conditioning",
-    version="1.0.0",
+    version="1.0.1",
 )
 class CompelInvocation(BaseInvocation):
     """Parse prompt using compel package to conditioning."""
@@ -61,26 +65,18 @@ class CompelInvocation(BaseInvocation):
     )
 
     @torch.no_grad()
-    def invoke(self, context: InvocationContext) -> ConditioningOutput:
-        tokenizer_info = context.services.model_manager.get_model(
-            **self.clip.tokenizer.model_dump(),
-            context=context,
-        )
-        text_encoder_info = context.services.model_manager.get_model(
-            **self.clip.text_encoder.model_dump(),
-            context=context,
-        )
+    def invoke(self, context) -> ConditioningOutput:
+        tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
+        text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
 
         def _lora_loader():
             for lora in self.clip.loras:
-                lora_info = context.services.model_manager.get_model(
-                    **lora.model_dump(exclude={"weight"}), context=context
-                )
+                lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
                 yield (lora_info.context.model, lora.weight)
                 del lora_info
             return
 
-        # loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
+        # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
 
         ti_list = []
         for trigger in extract_ti_triggers_from_prompt(self.prompt):
@@ -89,11 +85,10 @@ class CompelInvocation(BaseInvocation):
             ti_list.append(
                 (
                     name,
-                    context.services.model_manager.get_model(
+                    context.models.load(
                         model_name=name,
                         base_model=self.clip.text_encoder.base_model,
                         model_type=ModelType.TextualInversion,
-                        context=context,
                     ).context.model,
                 )
             )
@@ -124,7 +119,7 @@ class CompelInvocation(BaseInvocation):
 
         conjunction = Compel.parse_prompt_string(self.prompt)
 
-        if context.services.configuration.log_tokenization:
+        if context.config.get().log_tokenization:
             log_tokenization_for_conjunction(conjunction, tokenizer)
 
         c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
@@ -145,34 +140,23 @@ class CompelInvocation(BaseInvocation):
             ]
         )
 
-        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
-        context.services.latents.save(conditioning_name, conditioning_data)
+        conditioning_name = context.conditioning.save(conditioning_data)
 
-        return ConditioningOutput(
-            conditioning=ConditioningField(
-                conditioning_name=conditioning_name,
-            ),
-        )
+        return ConditioningOutput.build(conditioning_name)
 
 
 class SDXLPromptInvocationBase:
     def run_clip_compel(
         self,
-        context: InvocationContext,
+        context: "InvocationContext",
         clip_field: ClipField,
         prompt: str,
         get_pooled: bool,
         lora_prefix: str,
         zero_on_empty: bool,
     ):
-        tokenizer_info = context.services.model_manager.get_model(
-            **clip_field.tokenizer.model_dump(),
-            context=context,
-        )
-        text_encoder_info = context.services.model_manager.get_model(
-            **clip_field.text_encoder.model_dump(),
-            context=context,
-        )
+        tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump())
+        text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
 
         # return zero on empty
         if prompt == "" and zero_on_empty:
@@ -196,14 +180,12 @@ class SDXLPromptInvocationBase:
 
         def _lora_loader():
             for lora in clip_field.loras:
-                lora_info = context.services.model_manager.get_model(
-                    **lora.model_dump(exclude={"weight"}), context=context
-                )
+                lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
                 yield (lora_info.context.model, lora.weight)
                 del lora_info
             return
 
-        # loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
+        # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
 
         ti_list = []
         for trigger in extract_ti_triggers_from_prompt(prompt):
@@ -212,11 +194,10 @@ class SDXLPromptInvocationBase:
             ti_list.append(
                 (
                     name,
-                    context.services.model_manager.get_model(
+                    context.models.load(
                         model_name=name,
                         base_model=clip_field.text_encoder.base_model,
                         model_type=ModelType.TextualInversion,
-                        context=context,
                     ).context.model,
                 )
             )
@@ -249,7 +230,7 @@ class SDXLPromptInvocationBase:
 
         conjunction = Compel.parse_prompt_string(prompt)
 
-        if context.services.configuration.log_tokenization:
+        if context.config.get().log_tokenization:
             # TODO: better logging for and syntax
             log_tokenization_for_conjunction(conjunction, tokenizer)
 
@@ -282,7 +263,7 @@ class SDXLPromptInvocationBase:
     title="SDXL Prompt",
     tags=["sdxl", "compel", "prompt"],
     category="conditioning",
-    version="1.0.0",
+    version="1.0.1",
 )
 class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     """Parse prompt using compel package to conditioning."""
@@ -307,7 +288,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
 
     @torch.no_grad()
-    def invoke(self, context: InvocationContext) -> ConditioningOutput:
+    def invoke(self, context) -> ConditioningOutput:
         c1, c1_pooled, ec1 = self.run_clip_compel(
             context, self.clip, self.prompt, False, "lora_te1_", zero_on_empty=True
         )
@@ -364,14 +345,9 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
             ]
         )
 
-        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
-        context.services.latents.save(conditioning_name, conditioning_data)
+        conditioning_name = context.conditioning.save(conditioning_data)
 
-        return ConditioningOutput(
-            conditioning=ConditioningField(
-                conditioning_name=conditioning_name,
-            ),
-        )
+        return ConditioningOutput.build(conditioning_name)
 
 
 @invocation(
@@ -379,7 +355,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     title="SDXL Refiner Prompt",
     tags=["sdxl", "compel", "prompt"],
     category="conditioning",
-    version="1.0.0",
+    version="1.0.1",
 )
 class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     """Parse prompt using compel package to conditioning."""
@@ -397,7 +373,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
     clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
 
     @torch.no_grad()
-    def invoke(self, context: InvocationContext) -> ConditioningOutput:
+    def invoke(self, context) -> ConditioningOutput:
         # TODO: if there will appear lora for refiner - write proper prefix
         c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>", zero_on_empty=False)
 
@@ -417,14 +393,9 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
             ]
         )
 
-        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
-        context.services.latents.save(conditioning_name, conditioning_data)
+        conditioning_name = context.conditioning.save(conditioning_data)
 
-        return ConditioningOutput(
-            conditioning=ConditioningField(
-                conditioning_name=conditioning_name,
-            ),
-        )
+        return ConditioningOutput.build(conditioning_name)
 
 
 @invocation_output("clip_skip_output")
@@ -447,7 +418,7 @@ class ClipSkipInvocation(BaseInvocation):
     clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
     skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers)
 
-    def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
+    def invoke(self, context) -> ClipSkipInvocationOutput:
         self.clip.skipped_layers += self.skipped_layers
         return ClipSkipInvocationOutput(
             clip=self.clip,
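Model loading and conditioning storage follow the same consolidation throughout compel.py; side by side (condensed from the hunks above):

    # Before: every load threaded the context back in, and the node minted its own storage name.
    tokenizer_info = context.services.model_manager.get_model(
        **self.clip.tokenizer.model_dump(),
        context=context,
    )
    conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
    context.services.latents.save(conditioning_name, conditioning_data)
    return ConditioningOutput(conditioning=ConditioningField(conditioning_name=conditioning_name))

    # After: loading needs no context argument, and the conditioning service
    # generates and returns the storage name itself.
    tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
    conditioning_name = context.conditioning.save(conditioning_data)
    return ConditioningOutput.build(conditioning_name)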
invokeai/app/invocations/controlnet_image_processors.py

@@ -25,18 +25,17 @@ from controlnet_aux.util import HWC3, ade_palette
 from PIL import Image
 from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
 
-from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, WithMetadata
-from invokeai.app.invocations.primitives import ImageField, ImageOutput
+from invokeai.app.invocations.baseinvocation import WithMetadata
+from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField
+from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
-from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
 from invokeai.backend.image_util.depth_anything import DepthAnythingDetector
 from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
+from invokeai.backend.model_management.models.base import BaseModelType
 
-from ...backend.model_management import BaseModelType
 from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
-    InvocationContext,
     invocation,
     invocation_output,
 )
@@ -121,7 +120,7 @@ class ControlNetInvocation(BaseInvocation):
         validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
         return self
 
-    def invoke(self, context: InvocationContext) -> ControlOutput:
+    def invoke(self, context) -> ControlOutput:
         return ControlOutput(
             control=ControlField(
                 image=self.image,
@@ -145,23 +144,14 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata):
         # superclass just passes through image without processing
         return image
 
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        raw_image = context.services.images.get_pil_image(self.image.image_name)
+    def invoke(self, context) -> ImageOutput:
+        raw_image = context.images.get_pil(self.image.image_name)
         # image type should be PIL.PngImagePlugin.PngImageFile ?
         processed_image = self.run_processor(raw_image)
 
         # currently can't see processed image in node UI without a showImage node,
         # so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
-        image_dto = context.services.images.create(
-            image=processed_image,
-            image_origin=ResourceOrigin.INTERNAL,
-            image_category=ImageCategory.CONTROL,
-            session_id=context.graph_execution_state_id,
-            node_id=self.id,
-            is_intermediate=self.is_intermediate,
-            metadata=self.metadata,
-            workflow=context.workflow,
-        )
+        image_dto = context.images.save(image=processed_image)
 
         """Builds an ImageOutput and its ImageField"""
         processed_image_field = ImageField(image_name=image_dto.image_name)
@@ -180,7 +170,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata):
     title="Canny Processor",
     tags=["controlnet", "canny"],
     category="controlnet",
-    version="1.2.0",
+    version="1.2.1",
 )
 class CannyImageProcessorInvocation(ImageProcessorInvocation):
     """Canny edge detection for ControlNet"""
@@ -203,7 +193,7 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation):
     title="HED (softedge) Processor",
     tags=["controlnet", "hed", "softedge"],
     category="controlnet",
-    version="1.2.0",
+    version="1.2.1",
 )
 class HedImageProcessorInvocation(ImageProcessorInvocation):
     """Applies HED edge detection to image"""
@@ -232,7 +222,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
     title="Lineart Processor",
     tags=["controlnet", "lineart"],
     category="controlnet",
-    version="1.2.0",
+    version="1.2.1",
 )
 class LineartImageProcessorInvocation(ImageProcessorInvocation):
     """Applies line art processing to image"""
@@ -254,7 +244,7 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
     title="Lineart Anime Processor",
     tags=["controlnet", "lineart", "anime"],
     category="controlnet",
-    version="1.2.0",
+    version="1.2.1",
 )
 class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
     """Applies line art anime processing to image"""
@@ -277,7 +267,7 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
     title="Midas Depth Processor",
     tags=["controlnet", "midas"],
     category="controlnet",
-    version="1.2.0",
+    version="1.2.1",
 )
 class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
     """Applies Midas depth processing to image"""
@@ -304,7 +294,7 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
     title="Normal BAE Processor",
     tags=["controlnet"],
     category="controlnet",
-    version="1.2.0",
+    version="1.2.1",
 )
 class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
     """Applies NormalBae processing to image"""
@@ -321,7 +311,7 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
 
 
 @invocation(
-    "mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.2.0"
+    "mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.2.1"
 )
 class MlsdImageProcessorInvocation(ImageProcessorInvocation):
     """Applies MLSD processing to image"""
@@ -344,7 +334,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
 
 
 @invocation(
-    "pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.2.0"
+    "pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.2.1"
 )
 class PidiImageProcessorInvocation(ImageProcessorInvocation):
     """Applies PIDI processing to image"""
@@ -371,7 +361,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
     title="Content Shuffle Processor",
     tags=["controlnet", "contentshuffle"],
     category="controlnet",
-    version="1.2.0",
+    version="1.2.1",
 )
 class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
     """Applies content shuffle processing to image"""
@@ -401,7 +391,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
     title="Zoe (Depth) Processor",
     tags=["controlnet", "zoe", "depth"],
     category="controlnet",
-    version="1.2.0",
+    version="1.2.1",
 )
 class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
     """Applies Zoe depth processing to image"""
@@ -417,7 +407,7 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
     title="Mediapipe Face Processor",
     tags=["controlnet", "mediapipe", "face"],
     category="controlnet",
-    version="1.2.0",
+    version="1.2.1",
 )
 class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
     """Applies mediapipe face processing to image"""
@@ -440,7 +430,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
     title="Leres (Depth) Processor",
     tags=["controlnet", "leres", "depth"],
     category="controlnet",
-    version="1.2.0",
+    version="1.2.1",
 )
 class LeresImageProcessorInvocation(ImageProcessorInvocation):
     """Applies leres processing to image"""
@@ -469,7 +459,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
     title="Tile Resample Processor",
     tags=["controlnet", "tile"],
     category="controlnet",
-    version="1.2.0",
+    version="1.2.1",
 )
 class TileResamplerProcessorInvocation(ImageProcessorInvocation):
     """Tile resampler processor"""
@@ -509,7 +499,7 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
     title="Segment Anything Processor",
     tags=["controlnet", "segmentanything"],
     category="controlnet",
-    version="1.2.0",
+    version="1.2.1",
 )
 class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
     """Applies segment anything processing to image"""
@@ -551,7 +541,7 @@ class SamDetectorReproducibleColors(SamDetector):
     title="Color Map Processor",
     tags=["controlnet"],
     category="controlnet",
-    version="1.2.0",
+    version="1.2.1",
 )
 class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
     """Generates a color map from the provided image"""
invokeai/app/invocations/cv.py

@@ -5,23 +5,23 @@ import cv2 as cv
 import numpy
 from PIL import Image, ImageOps
 
-from invokeai.app.invocations.primitives import ImageField, ImageOutput
-from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
+from invokeai.app.invocations.fields import ImageField
+from invokeai.app.invocations.primitives import ImageOutput
 
-from .baseinvocation import BaseInvocation, InvocationContext, invocation
+from .baseinvocation import BaseInvocation, invocation
 from .fields import InputField, WithMetadata
 
 
-@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.2.0")
+@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.2.1")
 class CvInpaintInvocation(BaseInvocation, WithMetadata):
     """Simple inpaint using opencv."""
 
     image: ImageField = InputField(description="The image to inpaint")
     mask: ImageField = InputField(description="The mask to use when inpainting")
 
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        image = context.services.images.get_pil_image(self.image.image_name)
-        mask = context.services.images.get_pil_image(self.mask.image_name)
+    def invoke(self, context) -> ImageOutput:
+        image = context.images.get_pil(self.image.image_name)
+        mask = context.images.get_pil(self.mask.image_name)
 
         # Convert to cv image/mask
         # TODO: consider making these utility functions
@@ -35,18 +35,6 @@ class CvInpaintInvocation(BaseInvocation, WithMetadata):
         # TODO: consider making a utility function
         image_inpainted = Image.fromarray(cv.cvtColor(cv_inpainted, cv.COLOR_BGR2RGB))
 
-        image_dto = context.services.images.create(
-            image=image_inpainted,
-            image_origin=ResourceOrigin.INTERNAL,
-            image_category=ImageCategory.GENERAL,
-            node_id=self.id,
-            session_id=context.graph_execution_state_id,
-            is_intermediate=self.is_intermediate,
-            workflow=context.workflow,
-        )
+        image_dto = context.images.save(image=image_inpainted)
 
-        return ImageOutput(
-            image=ImageField(image_name=image_dto.image_name),
-            width=image_dto.width,
-            height=image_dto.height,
-        )
+        return ImageOutput.build(image_dto)
@ -1,7 +1,7 @@
|
|||||||
import math
|
import math
|
||||||
import re
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, TypedDict
|
from typing import TYPE_CHECKING, Optional, TypedDict
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -13,13 +13,16 @@ from pydantic import field_validator
|
|||||||
import invokeai.assets.fonts as font_assets
|
import invokeai.assets.fonts as font_assets
|
||||||
from invokeai.app.invocations.baseinvocation import (
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
InvocationContext,
|
WithMetadata,
|
||||||
invocation,
|
invocation,
|
||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
from invokeai.app.invocations.fields import InputField, OutputField, WithMetadata
|
from invokeai.app.invocations.fields import ImageField, InputField, OutputField
|
||||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
from invokeai.app.invocations.primitives import ImageOutput
|
||||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
from invokeai.app.services.image_records.image_records_common import ImageCategory
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("face_mask_output")
|
@invocation_output("face_mask_output")
|
||||||
@ -174,7 +177,7 @@ def prepare_faces_list(
|
|||||||
|
|
||||||
|
|
||||||
def generate_face_box_mask(
|
def generate_face_box_mask(
|
||||||
context: InvocationContext,
|
context: "InvocationContext",
|
||||||
minimum_confidence: float,
|
minimum_confidence: float,
|
||||||
x_offset: float,
|
x_offset: float,
|
||||||
y_offset: float,
|
y_offset: float,
|
||||||
@ -273,7 +276,7 @@ def generate_face_box_mask(
|
|||||||
|
|
||||||
|
|
||||||
def extract_face(
|
def extract_face(
|
||||||
context: InvocationContext,
|
context: "InvocationContext",
|
||||||
image: ImageType,
|
image: ImageType,
|
||||||
face: FaceResultData,
|
face: FaceResultData,
|
||||||
padding: int,
|
padding: int,
|
||||||
@ -304,37 +307,37 @@ def extract_face(
|
|||||||
|
|
||||||
# Adjust the crop boundaries to stay within the original image's dimensions
|
# Adjust the crop boundaries to stay within the original image's dimensions
|
||||||
if x_min < 0:
|
if x_min < 0:
|
||||||
context.services.logger.warning("FaceTools --> -X-axis padding reached image edge.")
|
context.logger.warning("FaceTools --> -X-axis padding reached image edge.")
|
||||||
x_max -= x_min
|
x_max -= x_min
|
||||||
x_min = 0
|
x_min = 0
|
||||||
elif x_max > mask.width:
|
elif x_max > mask.width:
|
||||||
context.services.logger.warning("FaceTools --> +X-axis padding reached image edge.")
|
context.logger.warning("FaceTools --> +X-axis padding reached image edge.")
|
||||||
x_min -= x_max - mask.width
|
x_min -= x_max - mask.width
|
||||||
x_max = mask.width
|
x_max = mask.width
|
||||||
|
|
||||||
if y_min < 0:
|
if y_min < 0:
|
||||||
context.services.logger.warning("FaceTools --> +Y-axis padding reached image edge.")
|
context.logger.warning("FaceTools --> +Y-axis padding reached image edge.")
|
||||||
y_max -= y_min
|
y_max -= y_min
|
||||||
y_min = 0
|
y_min = 0
|
||||||
elif y_max > mask.height:
|
elif y_max > mask.height:
|
||||||
context.services.logger.warning("FaceTools --> -Y-axis padding reached image edge.")
|
context.logger.warning("FaceTools --> -Y-axis padding reached image edge.")
|
||||||
y_min -= y_max - mask.height
|
y_min -= y_max - mask.height
|
||||||
y_max = mask.height
|
y_max = mask.height
|
||||||
|
|
||||||
# Ensure the crop is square and adjust the boundaries if needed
|
# Ensure the crop is square and adjust the boundaries if needed
|
||||||
if x_max - x_min != crop_size:
|
if x_max - x_min != crop_size:
|
||||||
context.services.logger.warning("FaceTools --> Limiting x-axis padding to constrain bounding box to a square.")
|
context.logger.warning("FaceTools --> Limiting x-axis padding to constrain bounding box to a square.")
|
||||||
diff = crop_size - (x_max - x_min)
|
diff = crop_size - (x_max - x_min)
|
||||||
x_min -= diff // 2
|
x_min -= diff // 2
|
||||||
x_max += diff - diff // 2
|
x_max += diff - diff // 2
|
||||||
|
|
||||||
if y_max - y_min != crop_size:
|
if y_max - y_min != crop_size:
|
||||||
context.services.logger.warning("FaceTools --> Limiting y-axis padding to constrain bounding box to a square.")
|
context.logger.warning("FaceTools --> Limiting y-axis padding to constrain bounding box to a square.")
|
||||||
diff = crop_size - (y_max - y_min)
|
diff = crop_size - (y_max - y_min)
|
||||||
y_min -= diff // 2
|
y_min -= diff // 2
|
||||||
y_max += diff - diff // 2
|
y_max += diff - diff // 2
|
||||||
|
|
||||||
context.services.logger.info(f"FaceTools --> Calculated bounding box (8 multiple): {crop_size}")
|
context.logger.info(f"FaceTools --> Calculated bounding box (8 multiple): {crop_size}")
|
||||||
|
|
||||||
# Crop the output image to the specified size with the center of the face mesh as the center.
|
# Crop the output image to the specified size with the center of the face mesh as the center.
|
||||||
mask = mask.crop((x_min, y_min, x_max, y_max))
|
mask = mask.crop((x_min, y_min, x_max, y_max))
|
||||||
@ -354,7 +357,7 @@ def extract_face(
|
|||||||
|
|
||||||
|
|
||||||
def get_faces_list(
|
def get_faces_list(
|
||||||
context: InvocationContext,
|
context: "InvocationContext",
|
||||||
image: ImageType,
|
image: ImageType,
|
||||||
should_chunk: bool,
|
should_chunk: bool,
|
||||||
minimum_confidence: float,
|
minimum_confidence: float,
|
||||||
@ -366,7 +369,7 @@ def get_faces_list(
|
|||||||
|
|
||||||
# Generate the face box mask and get the center of the face.
|
# Generate the face box mask and get the center of the face.
|
||||||
if not should_chunk:
|
if not should_chunk:
|
||||||
context.services.logger.info("FaceTools --> Attempting full image face detection.")
|
context.logger.info("FaceTools --> Attempting full image face detection.")
|
||||||
result = generate_face_box_mask(
|
result = generate_face_box_mask(
|
||||||
context=context,
|
context=context,
|
||||||
minimum_confidence=minimum_confidence,
|
minimum_confidence=minimum_confidence,
|
||||||
@ -378,7 +381,7 @@ def get_faces_list(
|
|||||||
draw_mesh=draw_mesh,
|
draw_mesh=draw_mesh,
|
||||||
)
|
)
|
||||||
if should_chunk or len(result) == 0:
|
if should_chunk or len(result) == 0:
|
||||||
context.services.logger.info("FaceTools --> Chunking image (chunk toggled on, or no face found in full image).")
|
context.logger.info("FaceTools --> Chunking image (chunk toggled on, or no face found in full image).")
|
||||||
width, height = image.size
|
width, height = image.size
|
||||||
image_chunks = []
|
image_chunks = []
|
||||||
x_offsets = []
|
x_offsets = []
|
||||||
@ -397,7 +400,7 @@ def get_faces_list(
|
|||||||
x_offsets.append(x)
|
x_offsets.append(x)
|
||||||
y_offsets.append(0)
|
y_offsets.append(0)
|
||||||
fx += increment
|
fx += increment
|
||||||
context.services.logger.info(f"FaceTools --> Chunk starting at x = {x}")
|
context.logger.info(f"FaceTools --> Chunk starting at x = {x}")
|
||||||
elif height > width:
|
elif height > width:
|
||||||
# Portrait - slice the image vertically
|
# Portrait - slice the image vertically
|
||||||
fy = 0.0
|
fy = 0.0
|
||||||
@ -409,10 +412,10 @@ def get_faces_list(
|
|||||||
x_offsets.append(0)
|
x_offsets.append(0)
|
||||||
y_offsets.append(y)
|
y_offsets.append(y)
|
||||||
fy += increment
|
fy += increment
|
||||||
context.services.logger.info(f"FaceTools --> Chunk starting at y = {y}")
|
context.logger.info(f"FaceTools --> Chunk starting at y = {y}")
|
||||||
|
|
||||||
for idx in range(len(image_chunks)):
|
for idx in range(len(image_chunks)):
|
||||||
context.services.logger.info(f"FaceTools --> Evaluating faces in chunk {idx}")
|
context.logger.info(f"FaceTools --> Evaluating faces in chunk {idx}")
|
||||||
result = result + generate_face_box_mask(
|
result = result + generate_face_box_mask(
|
||||||
context=context,
|
context=context,
|
||||||
minimum_confidence=minimum_confidence,
|
minimum_confidence=minimum_confidence,
|
||||||
@ -426,7 +429,7 @@ def get_faces_list(
|
|||||||
|
|
||||||
if len(result) == 0:
|
if len(result) == 0:
|
||||||
# Give up
|
# Give up
|
||||||
context.services.logger.warning(
|
context.logger.warning(
|
||||||
"FaceTools --> No face detected in chunked input image. Passing through original image."
|
"FaceTools --> No face detected in chunked input image. Passing through original image."
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -435,7 +438,7 @@ def get_faces_list(
|
|||||||
return all_faces
|
return all_faces
|
||||||
|
|
||||||
|
|
||||||
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.2.0")
|
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.2.1")
|
||||||
class FaceOffInvocation(BaseInvocation, WithMetadata):
|
class FaceOffInvocation(BaseInvocation, WithMetadata):
|
||||||
"""Bound, extract, and mask a face from an image using MediaPipe detection"""
|
"""Bound, extract, and mask a face from an image using MediaPipe detection"""
|
||||||
|
|
||||||
@ -456,7 +459,7 @@ class FaceOffInvocation(BaseInvocation, WithMetadata):
|
|||||||
description="Whether to bypass full image face detection and default to image chunking. Chunking will occur if no faces are found in the full image.",
|
description="Whether to bypass full image face detection and default to image chunking. Chunking will occur if no faces are found in the full image.",
|
||||||
)
|
)
|
||||||
|
|
||||||
def faceoff(self, context: InvocationContext, image: ImageType) -> Optional[ExtractFaceData]:
|
def faceoff(self, context: "InvocationContext", image: ImageType) -> Optional[ExtractFaceData]:
|
||||||
all_faces = get_faces_list(
|
all_faces = get_faces_list(
|
||||||
context=context,
|
context=context,
|
||||||
image=image,
|
image=image,
|
||||||
@ -468,11 +471,11 @@ class FaceOffInvocation(BaseInvocation, WithMetadata):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if len(all_faces) == 0:
|
if len(all_faces) == 0:
|
||||||
context.services.logger.warning("FaceOff --> No faces detected. Passing through original image.")
|
context.logger.warning("FaceOff --> No faces detected. Passing through original image.")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if self.face_id > len(all_faces) - 1:
|
if self.face_id > len(all_faces) - 1:
|
||||||
context.services.logger.warning(
|
context.logger.warning(
|
||||||
f"FaceOff --> Face ID {self.face_id} is outside of the number of faces detected ({len(all_faces)}). Passing through original image."
|
f"FaceOff --> Face ID {self.face_id} is outside of the number of faces detected ({len(all_faces)}). Passing through original image."
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
@ -483,8 +486,8 @@ class FaceOffInvocation(BaseInvocation, WithMetadata):
|
|||||||
|
|
||||||
return face_data
|
return face_data
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> FaceOffOutput:
|
def invoke(self, context) -> FaceOffOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.images.get_pil(self.image.image_name)
|
||||||
result = self.faceoff(context=context, image=image)
|
result = self.faceoff(context=context, image=image)
|
||||||
|
|
||||||
if result is None:
|
if result is None:
|
||||||
@ -498,24 +501,9 @@ class FaceOffInvocation(BaseInvocation, WithMetadata):
|
|||||||
x = result["x_min"]
|
x = result["x_min"]
|
||||||
y = result["y_min"]
|
y = result["y_min"]
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.images.save(image=result_image)
|
||||||
image=result_image,
|
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
|
||||||
image_category=ImageCategory.GENERAL,
|
|
||||||
node_id=self.id,
|
|
||||||
session_id=context.graph_execution_state_id,
|
|
||||||
is_intermediate=self.is_intermediate,
|
|
||||||
workflow=context.workflow,
|
|
||||||
)
|
|
||||||
|
|
||||||
mask_dto = context.services.images.create(
|
mask_dto = context.images.save(image=result_mask, image_category=ImageCategory.MASK)
|
||||||
image=result_mask,
|
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
|
||||||
image_category=ImageCategory.MASK,
|
|
||||||
node_id=self.id,
|
|
||||||
session_id=context.graph_execution_state_id,
|
|
||||||
is_intermediate=self.is_intermediate,
|
|
||||||
)
|
|
||||||
|
|
||||||
output = FaceOffOutput(
|
output = FaceOffOutput(
|
||||||
image=ImageField(image_name=image_dto.image_name),
|
image=ImageField(image_name=image_dto.image_name),
|
||||||
@ -529,7 +517,7 @@ class FaceOffInvocation(BaseInvocation, WithMetadata):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.2.0")
|
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.2.1")
|
||||||
class FaceMaskInvocation(BaseInvocation, WithMetadata):
|
class FaceMaskInvocation(BaseInvocation, WithMetadata):
|
||||||
"""Face mask creation using mediapipe face detection"""
|
"""Face mask creation using mediapipe face detection"""
|
||||||
|
|
||||||
@ -556,7 +544,7 @@ class FaceMaskInvocation(BaseInvocation, WithMetadata):
|
|||||||
raise ValueError('Face IDs must be a comma-separated list of integers (e.g. "1,2,3")')
|
raise ValueError('Face IDs must be a comma-separated list of integers (e.g. "1,2,3")')
|
||||||
return v
|
return v
|
||||||
|
|
||||||
-    def facemask(self, context: InvocationContext, image: ImageType) -> FaceMaskResult:
+    def facemask(self, context: "InvocationContext", image: ImageType) -> FaceMaskResult:
         all_faces = get_faces_list(
             context=context,
             image=image,
@@ -578,7 +566,7 @@ class FaceMaskInvocation(BaseInvocation, WithMetadata):
         if len(intersected_face_ids) == 0:
             id_range_str = ",".join([str(id) for id in id_range])
-            context.services.logger.warning(
+            context.logger.warning(
                 f"Face IDs must be in range of detected faces - requested {self.face_ids}, detected {id_range_str}. Passing through original image."
             )
         return FaceMaskResult(
@@ -613,28 +601,13 @@ class FaceMaskInvocation(BaseInvocation, WithMetadata):
             mask=mask_pil,
         )
 
-    def invoke(self, context: InvocationContext) -> FaceMaskOutput:
-        image = context.services.images.get_pil_image(self.image.image_name)
+    def invoke(self, context) -> FaceMaskOutput:
+        image = context.images.get_pil(self.image.image_name)
         result = self.facemask(context=context, image=image)
 
-        image_dto = context.services.images.create(
-            image=result["image"],
-            image_origin=ResourceOrigin.INTERNAL,
-            image_category=ImageCategory.GENERAL,
-            node_id=self.id,
-            session_id=context.graph_execution_state_id,
-            is_intermediate=self.is_intermediate,
-            workflow=context.workflow,
-        )
+        image_dto = context.images.save(image=result["image"])
 
-        mask_dto = context.services.images.create(
-            image=result["mask"],
-            image_origin=ResourceOrigin.INTERNAL,
-            image_category=ImageCategory.MASK,
-            node_id=self.id,
-            session_id=context.graph_execution_state_id,
-            is_intermediate=self.is_intermediate,
-        )
+        mask_dto = context.images.save(image=result["mask"], image_category=ImageCategory.MASK)
 
         output = FaceMaskOutput(
             image=ImageField(image_name=image_dto.image_name),
@@ -647,7 +620,7 @@ class FaceMaskInvocation(BaseInvocation, WithMetadata):
 
 
 @invocation(
-    "face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.2.0"
+    "face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.2.1"
 )
 class FaceIdentifierInvocation(BaseInvocation, WithMetadata):
     """Outputs an image with detected face IDs printed on each face. For use with other FaceTools."""
@@ -661,7 +634,7 @@ class FaceIdentifierInvocation(BaseInvocation, WithMetadata):
         description="Whether to bypass full image face detection and default to image chunking. Chunking will occur if no faces are found in the full image.",
     )
 
-    def faceidentifier(self, context: InvocationContext, image: ImageType) -> ImageType:
+    def faceidentifier(self, context: "InvocationContext", image: ImageType) -> ImageType:
         image = image.copy()
 
         all_faces = get_faces_list(
@@ -702,22 +675,10 @@ class FaceIdentifierInvocation(BaseInvocation, WithMetadata):
 
         return image
 
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        image = context.services.images.get_pil_image(self.image.image_name)
+    def invoke(self, context) -> ImageOutput:
+        image = context.images.get_pil(self.image.image_name)
         result_image = self.faceidentifier(context=context, image=image)
 
-        image_dto = context.services.images.create(
-            image=result_image,
-            image_origin=ResourceOrigin.INTERNAL,
-            image_category=ImageCategory.GENERAL,
-            node_id=self.id,
-            session_id=context.graph_execution_state_id,
-            is_intermediate=self.is_intermediate,
-            workflow=context.workflow,
-        )
+        image_dto = context.images.save(image=result_image)
 
-        return ImageOutput(
-            image=ImageField(image_name=image_dto.image_name),
-            width=image_dto.width,
-            height=image_dto.height,
-        )
+        return ImageOutput.build(image_dto)
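Note: the same migration repeats across this commit. `context.services.images.get_pil_image` becomes `context.images.get_pil`, and the many-argument `context.services.images.create` collapses into `context.images.save`, with origin, category, session id, intermediate flag, and workflow now supplied by the wrapped context rather than by every caller. A minimal sketch of a migrated node, using the API exactly as it appears in this diff (the node id, title, and field are invented for illustration):

    @invocation("example_copy", title="Example Copy", tags=["image"], category="image", version="1.0.0")
    class ExampleCopyInvocation(BaseInvocation, WithMetadata):
        """Hypothetical node: loads an image and saves a copy of it."""

        image: ImageField = InputField(description="The image to copy")

        def invoke(self, context) -> ImageOutput:
            image = context.images.get_pil(self.image.image_name)  # load by name
            image_dto = context.images.save(image=image)           # bookkeeping handled by the context
            return ImageOutput.build(image_dto)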
@@ -1,11 +1,13 @@
+from dataclasses import dataclass
 from enum import Enum
-from typing import Any, Callable, Optional
+from typing import Any, Callable, List, Optional, Tuple
 
 from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter
 from pydantic.fields import _Unset
 from pydantic_core import PydanticUndefined
 
 from invokeai.app.util.metaenum import MetaEnum
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import BasicConditioningInfo
 from invokeai.backend.util.logging import InvokeAILogger
 
 logger = InvokeAILogger.get_logger()
@@ -255,6 +257,10 @@ class InputFieldJSONSchemaExtra(BaseModel):
 
 
 class WithMetadata(BaseModel):
+    """
+    Inherit from this class if your node needs a metadata input field.
+    """
+
     metadata: Optional[MetadataField] = Field(
         default=None,
         description=FieldDescriptions.metadata,
@@ -498,4 +504,53 @@ def OutputField(
             field_kind=FieldKind.Output,
         ).model_dump(exclude_none=True),
     )
 
 
+class ImageField(BaseModel):
+    """An image primitive field"""
+
+    image_name: str = Field(description="The name of the image")
+
+
+class BoardField(BaseModel):
+    """A board primitive field"""
+
+    board_id: str = Field(description="The id of the board")
+
+
+class DenoiseMaskField(BaseModel):
+    """An inpaint mask field"""
+
+    mask_name: str = Field(description="The name of the mask image")
+    masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents")
+
+
+class LatentsField(BaseModel):
+    """A latents tensor primitive field"""
+
+    latents_name: str = Field(description="The name of the latents")
+    seed: Optional[int] = Field(default=None, description="Seed used to generate this latents")
+
+
+class ColorField(BaseModel):
+    """A color primitive field"""
+
+    r: int = Field(ge=0, le=255, description="The red component")
+    g: int = Field(ge=0, le=255, description="The green component")
+    b: int = Field(ge=0, le=255, description="The blue component")
+    a: int = Field(ge=0, le=255, description="The alpha component")
+
+    def tuple(self) -> Tuple[int, int, int, int]:
+        return (self.r, self.g, self.b, self.a)
+
+
+@dataclass
+class ConditioningFieldData:
+    conditionings: List[BasicConditioningInfo]
+
+
+class ConditioningField(BaseModel):
+    """A conditioning tensor primitive value"""
+
+    conditioning_name: str = Field(description="The name of conditioning tensor")
 # endregion
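Note: with the primitive field types now defined in `fields.py`, invocation modules can import them without touching `primitives.py`, which holds the output classes; this is part of the import untangling the commit message describes. A quick usage sketch of the relocated types (all values are illustrative):

    color = ColorField(r=127, g=255, b=0, a=255)
    assert color.tuple() == (127, 255, 0, 255)  # RGBA tuple, e.g. for PIL fills

    image_ref = ImageField(image_name="abc123.png")            # hypothetical name
    latents_ref = LatentsField(latents_name="xyz789", seed=42)  # hypothetical name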
(File diff suppressed because it is too large.)
@@ -6,15 +6,15 @@ from typing import Literal, Optional, get_args
 import numpy as np
 from PIL import Image, ImageOps
 
-from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput
-from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
+from invokeai.app.invocations.fields import ColorField, ImageField
+from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.util.misc import SEED_MAX
 from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
 from invokeai.backend.image_util.lama import LaMA
 from invokeai.backend.image_util.patchmatch import PatchMatch
 
-from .baseinvocation import BaseInvocation, InvocationContext, invocation
-from .fields import InputField, WithMetadata
+from .baseinvocation import BaseInvocation, WithMetadata, invocation
+from .fields import InputField
 from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES
 
 
@@ -119,7 +119,7 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
     return si
 
 
-@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0")
+@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1")
 class InfillColorInvocation(BaseInvocation, WithMetadata):
     """Infills transparent areas of an image with a solid color"""
 
@@ -129,33 +129,20 @@ class InfillColorInvocation(BaseInvocation, WithMetadata):
         description="The color to use to infill",
     )
 
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        image = context.services.images.get_pil_image(self.image.image_name)
+    def invoke(self, context) -> ImageOutput:
+        image = context.images.get_pil(self.image.image_name)
 
         solid_bg = Image.new("RGBA", image.size, self.color.tuple())
         infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
 
         infilled.paste(image, (0, 0), image.split()[-1])
 
-        image_dto = context.services.images.create(
-            image=infilled,
-            image_origin=ResourceOrigin.INTERNAL,
-            image_category=ImageCategory.GENERAL,
-            node_id=self.id,
-            session_id=context.graph_execution_state_id,
-            is_intermediate=self.is_intermediate,
-            metadata=self.metadata,
-            workflow=context.workflow,
-        )
+        image_dto = context.images.save(image=infilled)
 
-        return ImageOutput(
-            image=ImageField(image_name=image_dto.image_name),
-            width=image_dto.width,
-            height=image_dto.height,
-        )
+        return ImageOutput.build(image_dto)
 
 
-@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1")
+@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
 class InfillTileInvocation(BaseInvocation, WithMetadata):
     """Infills transparent areas of an image with tiles of the image"""
 
@@ -168,32 +155,19 @@ class InfillTileInvocation(BaseInvocation, WithMetadata):
         description="The seed to use for tile generation (omit for random)",
     )
 
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        image = context.services.images.get_pil_image(self.image.image_name)
+    def invoke(self, context) -> ImageOutput:
+        image = context.images.get_pil(self.image.image_name)
 
         infilled = tile_fill_missing(image.copy(), seed=self.seed, tile_size=self.tile_size)
         infilled.paste(image, (0, 0), image.split()[-1])
 
-        image_dto = context.services.images.create(
-            image=infilled,
-            image_origin=ResourceOrigin.INTERNAL,
-            image_category=ImageCategory.GENERAL,
-            node_id=self.id,
-            session_id=context.graph_execution_state_id,
-            is_intermediate=self.is_intermediate,
-            metadata=self.metadata,
-            workflow=context.workflow,
-        )
+        image_dto = context.images.save(image=infilled)
 
-        return ImageOutput(
-            image=ImageField(image_name=image_dto.image_name),
-            width=image_dto.width,
-            height=image_dto.height,
-        )
+        return ImageOutput.build(image_dto)
 
 
 @invocation(
-    "infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0"
+    "infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1"
 )
 class InfillPatchMatchInvocation(BaseInvocation, WithMetadata):
     """Infills transparent areas of an image using the PatchMatch algorithm"""
@@ -202,8 +176,8 @@ class InfillPatchMatchInvocation(BaseInvocation, WithMetadata):
     downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speedup infill")
     resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
 
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        image = context.services.images.get_pil_image(self.image.image_name).convert("RGBA")
+    def invoke(self, context) -> ImageOutput:
+        image = context.images.get_pil(self.image.image_name).convert("RGBA")
 
         resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
 
@@ -228,77 +202,38 @@ class InfillPatchMatchInvocation(BaseInvocation, WithMetadata):
         infilled.paste(image, (0, 0), mask=image.split()[-1])
         # image.paste(infilled, (0, 0), mask=image.split()[-1])
 
-        image_dto = context.services.images.create(
-            image=infilled,
-            image_origin=ResourceOrigin.INTERNAL,
-            image_category=ImageCategory.GENERAL,
-            node_id=self.id,
-            session_id=context.graph_execution_state_id,
-            is_intermediate=self.is_intermediate,
-            metadata=self.metadata,
-            workflow=context.workflow,
-        )
+        image_dto = context.images.save(image=infilled)
 
-        return ImageOutput(
-            image=ImageField(image_name=image_dto.image_name),
-            width=image_dto.width,
-            height=image_dto.height,
-        )
+        return ImageOutput.build(image_dto)
 
 
-@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0")
+@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1")
 class LaMaInfillInvocation(BaseInvocation, WithMetadata):
     """Infills transparent areas of an image using the LaMa model"""
 
     image: ImageField = InputField(description="The image to infill")
 
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        image = context.services.images.get_pil_image(self.image.image_name)
+    def invoke(self, context) -> ImageOutput:
+        image = context.images.get_pil(self.image.image_name)
 
         infilled = infill_lama(image.copy())
 
-        image_dto = context.services.images.create(
-            image=infilled,
-            image_origin=ResourceOrigin.INTERNAL,
-            image_category=ImageCategory.GENERAL,
-            node_id=self.id,
-            session_id=context.graph_execution_state_id,
-            is_intermediate=self.is_intermediate,
-            metadata=self.metadata,
-            workflow=context.workflow,
-        )
+        image_dto = context.images.save(image=infilled)
 
-        return ImageOutput(
-            image=ImageField(image_name=image_dto.image_name),
-            width=image_dto.width,
-            height=image_dto.height,
-        )
+        return ImageOutput.build(image_dto)
 
 
-@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0")
+@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1")
 class CV2InfillInvocation(BaseInvocation, WithMetadata):
     """Infills transparent areas of an image using OpenCV Inpainting"""
 
     image: ImageField = InputField(description="The image to infill")
 
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        image = context.services.images.get_pil_image(self.image.image_name)
+    def invoke(self, context) -> ImageOutput:
+        image = context.images.get_pil(self.image.image_name)
 
         infilled = infill_cv2(image.copy())
 
-        image_dto = context.services.images.create(
-            image=infilled,
-            image_origin=ResourceOrigin.INTERNAL,
-            image_category=ImageCategory.GENERAL,
-            node_id=self.id,
-            session_id=context.graph_execution_state_id,
-            is_intermediate=self.is_intermediate,
-            metadata=self.metadata,
-            workflow=context.workflow,
-        )
+        image_dto = context.images.save(image=infilled)
 
-        return ImageOutput(
-            image=ImageField(image_name=image_dto.image_name),
-            width=image_dto.width,
-            height=image_dto.height,
-        )
+        return ImageOutput.build(image_dto)
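Note: every infill node above gets the same two-part treatment: a patch-level version bump (inputs and outputs are unchanged, so major and minor stay put, matching the "patch bump for all nodes that use the context" in the commit message) and removal of the explicit `metadata=self.metadata` / `workflow=context.workflow` arguments. The nodes still subclass `WithMetadata`; the reasonable reading, hedged since the new `save` implementation is not shown in this diff, is that the wrapped context now attaches metadata and workflow itself. The bump pattern in isolation:

    @invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"],
                category="inpaint", version="1.2.1")  # was 1.2.0: behavior-only change
    class InfillColorInvocation(BaseInvocation, WithMetadata): ...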
@@ -7,7 +7,6 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator, model_valida
 from invokeai.app.invocations.baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
-    InvocationContext,
     invocation,
     invocation_output,
 )
@@ -62,7 +61,7 @@ class IPAdapterOutput(BaseInvocationOutput):
     ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter")
 
 
-@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.1.1")
+@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.1.2")
 class IPAdapterInvocation(BaseInvocation):
     """Collects IP-Adapter info to pass to other nodes."""
 
@@ -93,9 +92,9 @@ class IPAdapterInvocation(BaseInvocation):
         validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
         return self
 
-    def invoke(self, context: InvocationContext) -> IPAdapterOutput:
+    def invoke(self, context) -> IPAdapterOutput:
         # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
-        ip_adapter_info = context.services.model_manager.model_info(
+        ip_adapter_info = context.models.get_info(
             self.ip_adapter_model.model_name, self.ip_adapter_model.base_model, ModelType.IPAdapter
         )
         # HACK(ryand): This is bad for a couple of reasons: 1) we are bypassing the model manager to read the model
@@ -104,7 +103,7 @@ class IPAdapterInvocation(BaseInvocation):
         # is currently messy due to differences between how the model info is generated when installing a model from
         # disk vs. downloading the model.
         image_encoder_model_id = get_ip_adapter_image_encoder_model_id(
-            os.path.join(context.services.configuration.get_config().models_path, ip_adapter_info["path"])
+            os.path.join(context.config.get().models_path, ip_adapter_info["path"])
         )
         image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
         image_encoder_model = CLIPVisionModelField(
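Note: model metadata and app configuration flatten the same way: `context.models.get_info(...)` stands in for `context.services.model_manager.model_info(...)`, and `context.config.get()` for `context.services.configuration.get_config()`. The lookup above, isolated into a helper (the helper name is invented; the calls are as in the diff):

    import os

    def ip_adapter_model_path(context, ip_adapter_model) -> str:
        # New-style model info lookup; the record carries a "path" entry.
        info = context.models.get_info(
            ip_adapter_model.model_name, ip_adapter_model.base_model, ModelType.IPAdapter
        )
        # App config is likewise one call away now.
        return os.path.join(context.config.get().models_path, info["path"])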
@@ -3,7 +3,7 @@
 import math
 from contextlib import ExitStack
 from functools import singledispatchmethod
-from typing import List, Literal, Optional, Union
+from typing import TYPE_CHECKING, List, Literal, Optional, Union
 
 import einops
 import numpy as np
@@ -23,21 +23,26 @@ from diffusers.schedulers import SchedulerMixin as Scheduler
 from pydantic import field_validator
 from torchvision.transforms.functional import resize as tv_resize
 
-from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType, WithMetadata
+from invokeai.app.invocations.fields import (
+    ConditioningField,
+    DenoiseMaskField,
+    FieldDescriptions,
+    ImageField,
+    Input,
+    InputField,
+    LatentsField,
+    OutputField,
+    UIType,
+    WithMetadata,
+)
 from invokeai.app.invocations.ip_adapter import IPAdapterField
 from invokeai.app.invocations.primitives import (
-    DenoiseMaskField,
     DenoiseMaskOutput,
-    ImageField,
     ImageOutput,
-    LatentsField,
     LatentsOutput,
-    build_latents_output,
 )
 from invokeai.app.invocations.t2i_adapter import T2IAdapterField
-from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
 from invokeai.app.util.controlnet_utils import prepare_control_image
-from invokeai.app.util.step_callback import stable_diffusion_step_callback
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
 from invokeai.backend.model_management.models import ModelType, SilenceWarnings
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
@@ -59,14 +64,15 @@ from ...backend.util.devices import choose_precision, choose_torch_device
 from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
-    InvocationContext,
     invocation,
     invocation_output,
 )
-from .compel import ConditioningField
 from .controlnet_image_processors import ControlField
 from .model import ModelInfo, UNetField, VaeField
 
+if TYPE_CHECKING:
+    from invokeai.app.services.shared.invocation_context import InvocationContext
+
 if choose_torch_device() == torch.device("mps"):
     from torch import mps
 
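Note: this `if TYPE_CHECKING:` guard is the "fanagling of imports" the commit message mentions. `invocation_context` itself imports from the invocations package, so the invocations modules cannot import it back at runtime without a cycle. The quoted annotations used throughout (`context: "InvocationContext"`) are resolved only by type checkers. The pattern in isolation:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Evaluated by mypy/pyright only; never at runtime, so no import cycle.
        from invokeai.app.services.shared.invocation_context import InvocationContext


    def log_via_context(context: "InvocationContext") -> None:
        # The string annotation defers resolution, so this is safe even though
        # InvocationContext is not imported at runtime. (Helper name is illustrative.)
        context.logger.info("called from a node helper")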
@@ -102,7 +108,7 @@ class SchedulerInvocation(BaseInvocation):
         ui_type=UIType.Scheduler,
     )
 
-    def invoke(self, context: InvocationContext) -> SchedulerOutput:
+    def invoke(self, context) -> SchedulerOutput:
         return SchedulerOutput(scheduler=self.scheduler)
 
 
@@ -111,7 +117,7 @@ class SchedulerInvocation(BaseInvocation):
     title="Create Denoise Mask",
     tags=["mask", "denoise"],
     category="latents",
-    version="1.0.0",
+    version="1.0.1",
 )
 class CreateDenoiseMaskInvocation(BaseInvocation):
     """Creates mask for denoising model run."""
@@ -137,9 +143,9 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
         return mask_tensor
 
     @torch.no_grad()
-    def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
+    def invoke(self, context) -> DenoiseMaskOutput:
         if self.image is not None:
-            image = context.services.images.get_pil_image(self.image.image_name)
+            image = context.images.get_pil(self.image.image_name)
             image = image_resized_to_grid_as_tensor(image.convert("RGB"))
             if image.dim() == 3:
                 image = image.unsqueeze(0)
@@ -147,47 +153,37 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
             image = None
 
         mask = self.prep_mask_tensor(
-            context.services.images.get_pil_image(self.mask.image_name),
+            context.images.get_pil(self.mask.image_name),
         )
 
         if image is not None:
-            vae_info = context.services.model_manager.get_model(
-                **self.vae.vae.model_dump(),
-                context=context,
-            )
+            vae_info = context.models.load(**self.vae.vae.model_dump())
 
             img_mask = tv_resize(mask, image.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
             masked_image = image * torch.where(img_mask < 0.5, 0.0, 1.0)
             # TODO:
             masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
 
-            masked_latents_name = f"{context.graph_execution_state_id}__{self.id}_masked_latents"
-            context.services.latents.save(masked_latents_name, masked_latents)
+            masked_latents_name = context.latents.save(tensor=masked_latents)
         else:
             masked_latents_name = None
 
-        mask_name = f"{context.graph_execution_state_id}__{self.id}_mask"
-        context.services.latents.save(mask_name, mask)
+        mask_name = context.latents.save(tensor=mask)
 
-        return DenoiseMaskOutput(
-            denoise_mask=DenoiseMaskField(
-                mask_name=mask_name,
-                masked_latents_name=masked_latents_name,
-            ),
+        return DenoiseMaskOutput.build(
+            mask_name=mask_name,
+            masked_latents_name=masked_latents_name,
         )
 
 
 def get_scheduler(
-    context: InvocationContext,
+    context: "InvocationContext",
     scheduler_info: ModelInfo,
     scheduler_name: str,
     seed: int,
 ) -> Scheduler:
     scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
-    orig_scheduler_info = context.services.model_manager.get_model(
-        **scheduler_info.model_dump(),
-        context=context,
-    )
+    orig_scheduler_info = context.models.load(**scheduler_info.model_dump())
     with orig_scheduler_info as orig_scheduler:
         scheduler_config = orig_scheduler.config
 
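Note: tensor bookkeeping moves behind the context as well. Instead of composing `f"{context.graph_execution_state_id}__{self.id}"` and calling `services.latents.save(name, tensor)`, a node calls `context.latents.save(tensor=...)` and receives the generated name back. Output classes gain `build` classmethods, so nodes no longer assemble nested field objects by hand. The new tail of `CreateDenoiseMaskInvocation.invoke`, reduced to its shape:

    mask_name = context.latents.save(tensor=mask)  # name is generated and returned
    masked_latents_name = context.latents.save(tensor=masked_latents)

    return DenoiseMaskOutput.build(
        mask_name=mask_name,
        masked_latents_name=masked_latents_name,
    )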
@@ -216,7 +212,7 @@
     title="Denoise Latents",
     tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
     category="latents",
-    version="1.5.1",
+    version="1.5.2",
 )
 class DenoiseLatentsInvocation(BaseInvocation):
     """Denoises noisy latents to decodable images"""
@@ -302,34 +298,18 @@ class DenoiseLatentsInvocation(BaseInvocation):
             raise ValueError("cfg_scale must be greater than 1")
         return v
 
-    # TODO: pass this an emitter method or something? or a session for dispatching?
-    def dispatch_progress(
-        self,
-        context: InvocationContext,
-        source_node_id: str,
-        intermediate_state: PipelineIntermediateState,
-        base_model: BaseModelType,
-    ) -> None:
-        stable_diffusion_step_callback(
-            context=context,
-            intermediate_state=intermediate_state,
-            node=self.model_dump(),
-            source_node_id=source_node_id,
-            base_model=base_model,
-        )
-
     def get_conditioning_data(
         self,
-        context: InvocationContext,
+        context: "InvocationContext",
         scheduler,
         unet,
         seed,
     ) -> ConditioningData:
-        positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
+        positive_cond_data = context.conditioning.get(self.positive_conditioning.conditioning_name)
         c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
         extra_conditioning_info = c.extra_conditioning
 
-        negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
+        negative_cond_data = context.conditioning.get(self.negative_conditioning.conditioning_name)
         uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
 
         conditioning_data = ConditioningData(
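Note: conditioning tensors get a purpose-specific accessor too. Previously they were stashed in the general latents store; `context.conditioning.get(...)` names the intent directly while the call shape stays the same:

    # Old: context.services.latents.get(self.positive_conditioning.conditioning_name)
    positive_cond_data = context.conditioning.get(self.positive_conditioning.conditioning_name)
    c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)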
@@ -389,7 +369,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
 
     def prep_control_data(
         self,
-        context: InvocationContext,
+        context: "InvocationContext",
         control_input: Union[ControlField, List[ControlField]],
         latents_shape: List[int],
         exit_stack: ExitStack,
@@ -417,17 +397,16 @@ class DenoiseLatentsInvocation(BaseInvocation):
         controlnet_data = []
         for control_info in control_list:
             control_model = exit_stack.enter_context(
-                context.services.model_manager.get_model(
+                context.models.load(
                     model_name=control_info.control_model.model_name,
                     model_type=ModelType.ControlNet,
                     base_model=control_info.control_model.base_model,
-                    context=context,
                 )
             )
 
             # control_models.append(control_model)
             control_image_field = control_info.image
-            input_image = context.services.images.get_pil_image(control_image_field.image_name)
+            input_image = context.images.get_pil(control_image_field.image_name)
             # self.image.image_type, self.image.image_name
             # FIXME: still need to test with different widths, heights, devices, dtypes
             # and add in batch_size, num_images_per_prompt?
@@ -463,7 +442,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
 
     def prep_ip_adapter_data(
         self,
-        context: InvocationContext,
+        context: "InvocationContext",
         ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]],
         conditioning_data: ConditioningData,
         exit_stack: ExitStack,
@@ -485,19 +464,17 @@ class DenoiseLatentsInvocation(BaseInvocation):
         conditioning_data.ip_adapter_conditioning = []
         for single_ip_adapter in ip_adapter:
             ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
-                context.services.model_manager.get_model(
+                context.models.load(
                     model_name=single_ip_adapter.ip_adapter_model.model_name,
                     model_type=ModelType.IPAdapter,
                     base_model=single_ip_adapter.ip_adapter_model.base_model,
-                    context=context,
                 )
             )
 
-            image_encoder_model_info = context.services.model_manager.get_model(
+            image_encoder_model_info = context.models.load(
                 model_name=single_ip_adapter.image_encoder_model.model_name,
                 model_type=ModelType.CLIPVision,
                 base_model=single_ip_adapter.image_encoder_model.base_model,
-                context=context,
             )
 
             # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
@@ -505,7 +482,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
             if not isinstance(single_ipa_images, list):
                 single_ipa_images = [single_ipa_images]
 
-            single_ipa_images = [context.services.images.get_pil_image(image.image_name) for image in single_ipa_images]
+            single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_images]
 
             # TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
             # models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
@@ -532,7 +509,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
 
     def run_t2i_adapters(
         self,
-        context: InvocationContext,
+        context: "InvocationContext",
         t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
         latents_shape: list[int],
         do_classifier_free_guidance: bool,
@@ -549,13 +526,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
 
         t2i_adapter_data = []
         for t2i_adapter_field in t2i_adapter:
-            t2i_adapter_model_info = context.services.model_manager.get_model(
+            t2i_adapter_model_info = context.models.load(
                 model_name=t2i_adapter_field.t2i_adapter_model.model_name,
                 model_type=ModelType.T2IAdapter,
                 base_model=t2i_adapter_field.t2i_adapter_model.base_model,
-                context=context,
             )
-            image = context.services.images.get_pil_image(t2i_adapter_field.image.image_name)
+            image = context.images.get_pil(t2i_adapter_field.image.image_name)
 
             # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
             if t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusion1:
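Note: `context.models.load(...)` drops the awkward `context=context` back-reference that `get_model` required. The returned info object still behaves as a context manager, so the `ExitStack` idiom for keeping several adapter models resident across the denoise loop carries over unchanged:

    from contextlib import ExitStack

    with ExitStack() as exit_stack:
        # The model stays loaded until the stack unwinds at the end of the block.
        control_model = exit_stack.enter_context(
            context.models.load(
                model_name=control_info.control_model.model_name,
                model_type=ModelType.ControlNet,
                base_model=control_info.control_model.base_model,
            )
        )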
@@ -642,30 +618,30 @@ class DenoiseLatentsInvocation(BaseInvocation):
 
         return num_inference_steps, timesteps, init_timestep
 
-    def prep_inpaint_mask(self, context, latents):
+    def prep_inpaint_mask(self, context: "InvocationContext", latents):
         if self.denoise_mask is None:
             return None, None
 
-        mask = context.services.latents.get(self.denoise_mask.mask_name)
+        mask = context.latents.get(self.denoise_mask.mask_name)
         mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
         if self.denoise_mask.masked_latents_name is not None:
-            masked_latents = context.services.latents.get(self.denoise_mask.masked_latents_name)
+            masked_latents = context.latents.get(self.denoise_mask.masked_latents_name)
         else:
             masked_latents = None
 
         return 1 - mask, masked_latents
 
     @torch.no_grad()
-    def invoke(self, context: InvocationContext) -> LatentsOutput:
+    def invoke(self, context) -> LatentsOutput:
         with SilenceWarnings():  # this quenches NSFW nag from diffusers
             seed = None
             noise = None
             if self.noise is not None:
-                noise = context.services.latents.get(self.noise.latents_name)
+                noise = context.latents.get(self.noise.latents_name)
                 seed = self.noise.seed
 
             if self.latents is not None:
-                latents = context.services.latents.get(self.latents.latents_name)
+                latents = context.latents.get(self.latents.latents_name)
                 if seed is None:
                     seed = self.latents.seed
 
@@ -691,27 +667,17 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 do_classifier_free_guidance=True,
             )
 
-            # Get the source node id (we are invoking the prepared node)
-            graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
-            source_node_id = graph_execution_state.prepared_source_mapping[self.id]
-
             def step_callback(state: PipelineIntermediateState):
-                self.dispatch_progress(context, source_node_id, state, self.unet.unet.base_model)
+                context.util.sd_step_callback(state, self.unet.unet.base_model)
 
             def _lora_loader():
                 for lora in self.unet.loras:
-                    lora_info = context.services.model_manager.get_model(
-                        **lora.model_dump(exclude={"weight"}),
-                        context=context,
-                    )
+                    lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
                     yield (lora_info.context.model, lora.weight)
                     del lora_info
                 return
 
-            unet_info = context.services.model_manager.get_model(
-                **self.unet.unet.model_dump(),
-                context=context,
-            )
+            unet_info = context.models.load(**self.unet.unet.model_dump())
             with (
                 ExitStack() as exit_stack,
                 ModelPatcher.apply_freeu(unet_info.context.model, self.unet.freeu_config),
@@ -787,9 +753,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
         if choose_torch_device() == torch.device("mps"):
             mps.empty_cache()
 
-        name = f"{context.graph_execution_state_id}__{self.id}"
-        context.services.latents.save(name, result_latents)
-        return build_latents_output(latents_name=name, latents=result_latents, seed=seed)
+        name = context.latents.save(tensor=result_latents)
+        return LatentsOutput.build(latents_name=name, latents=result_latents, seed=seed)
 
 
 @invocation(
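Note: the progress plumbing shrinks the most. The old `dispatch_progress` had to fetch the graph execution state, map the prepared node back to its source node, and forward a dict of the whole node; per the commit message, `EventServiceBase` now accepts the node's ID, and the context owns all of that mapping:

    def step_callback(state: PipelineIntermediateState):
        # The context already knows the session, node id, and event bus;
        # the node supplies only the intermediate state and base model.
        context.util.sd_step_callback(state, self.unet.unet.base_model)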
@ -797,7 +762,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
title="Latents to Image",
|
title="Latents to Image",
|
||||||
tags=["latents", "image", "vae", "l2i"],
|
tags=["latents", "image", "vae", "l2i"],
|
||||||
category="latents",
|
category="latents",
|
||||||
version="1.2.0",
|
version="1.2.1",
|
||||||
)
|
)
|
||||||
class LatentsToImageInvocation(BaseInvocation, WithMetadata):
|
class LatentsToImageInvocation(BaseInvocation, WithMetadata):
|
||||||
"""Generates an image from latents."""
|
"""Generates an image from latents."""
|
||||||
@ -814,13 +779,10 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata):
|
|||||||
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
|
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context) -> ImageOutput:
|
||||||
latents = context.services.latents.get(self.latents.latents_name)
|
latents = context.latents.get(self.latents.latents_name)
|
||||||
|
|
||||||
vae_info = context.services.model_manager.get_model(
|
vae_info = context.models.load(**self.vae.vae.model_dump())
|
||||||
**self.vae.vae.model_dump(),
|
|
||||||
context=context,
|
|
||||||
)
|
|
||||||
|
|
||||||
with set_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae:
|
with set_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae:
|
||||||
latents = latents.to(vae.device)
|
latents = latents.to(vae.device)
|
||||||
@ -849,7 +811,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata):
|
|||||||
vae.to(dtype=torch.float16)
|
vae.to(dtype=torch.float16)
|
||||||
latents = latents.half()
|
latents = latents.half()
|
||||||
|
|
||||||
if self.tiled or context.services.configuration.tiled_decode:
|
if self.tiled or context.config.get().tiled_decode:
|
||||||
vae.enable_tiling()
|
vae.enable_tiling()
|
||||||
else:
|
else:
|
||||||
vae.disable_tiling()
|
vae.disable_tiling()
|
||||||
@ -873,22 +835,9 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata):
|
|||||||
if choose_torch_device() == torch.device("mps"):
|
if choose_torch_device() == torch.device("mps"):
|
||||||
mps.empty_cache()
|
mps.empty_cache()
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.images.save(image=image)
|
||||||
image=image,
|
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
|
||||||
image_category=ImageCategory.GENERAL,
|
|
||||||
node_id=self.id,
|
|
||||||
session_id=context.graph_execution_state_id,
|
|
||||||
is_intermediate=self.is_intermediate,
|
|
||||||
metadata=self.metadata,
|
|
||||||
workflow=context.workflow,
|
|
||||||
)
|
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput.build(image_dto)
|
||||||
image=ImageField(image_name=image_dto.image_name),
|
|
||||||
width=image_dto.width,
|
|
||||||
height=image_dto.height,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
|
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
|
||||||
@ -899,7 +848,7 @@ LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic",
|
|||||||
title="Resize Latents",
|
title="Resize Latents",
|
||||||
tags=["latents", "resize"],
|
tags=["latents", "resize"],
|
||||||
category="latents",
|
category="latents",
|
||||||
version="1.0.0",
|
version="1.0.1",
|
||||||
)
|
)
|
||||||
class ResizeLatentsInvocation(BaseInvocation):
|
class ResizeLatentsInvocation(BaseInvocation):
|
||||||
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
|
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
|
||||||
@ -921,8 +870,8 @@ class ResizeLatentsInvocation(BaseInvocation):
|
|||||||
mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
|
mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
|
||||||
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
|
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context) -> LatentsOutput:
|
||||||
latents = context.services.latents.get(self.latents.latents_name)
|
latents = context.latents.get(self.latents.latents_name)
|
||||||
|
|
||||||
# TODO:
|
# TODO:
|
||||||
device = choose_torch_device()
|
device = choose_torch_device()
|
||||||
@ -940,10 +889,8 @@ class ResizeLatentsInvocation(BaseInvocation):
|
|||||||
if device == torch.device("mps"):
|
if device == torch.device("mps"):
|
||||||
mps.empty_cache()
|
mps.empty_cache()
|
||||||
|
|
||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
name = context.latents.save(tensor=resized_latents)
|
||||||
# context.services.latents.set(name, resized_latents)
|
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||||
context.services.latents.save(name, resized_latents)
|
|
||||||
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
@ -951,7 +898,7 @@ class ResizeLatentsInvocation(BaseInvocation):
|
|||||||
title="Scale Latents",
|
title="Scale Latents",
|
||||||
tags=["latents", "resize"],
|
tags=["latents", "resize"],
|
||||||
category="latents",
|
category="latents",
|
||||||
version="1.0.0",
|
version="1.0.1",
|
||||||
)
|
)
|
||||||
class ScaleLatentsInvocation(BaseInvocation):
|
class ScaleLatentsInvocation(BaseInvocation):
|
||||||
"""Scales latents by a given factor."""
|
"""Scales latents by a given factor."""
|
||||||
@ -964,8 +911,8 @@ class ScaleLatentsInvocation(BaseInvocation):
|
|||||||
mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
|
mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
|
||||||
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
|
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context) -> LatentsOutput:
|
||||||
latents = context.services.latents.get(self.latents.latents_name)
|
latents = context.latents.get(self.latents.latents_name)
|
||||||
|
|
||||||
# TODO:
|
# TODO:
|
||||||
device = choose_torch_device()
|
device = choose_torch_device()
|
||||||
@ -984,10 +931,8 @@ class ScaleLatentsInvocation(BaseInvocation):
|
|||||||
if device == torch.device("mps"):
|
if device == torch.device("mps"):
|
||||||
mps.empty_cache()
|
mps.empty_cache()
|
||||||
|
|
||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
name = context.latents.save(tensor=resized_latents)
|
||||||
# context.services.latents.set(name, resized_latents)
|
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||||
context.services.latents.save(name, resized_latents)
|
|
||||||
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
@ -995,7 +940,7 @@ class ScaleLatentsInvocation(BaseInvocation):
|
|||||||
title="Image to Latents",
|
title="Image to Latents",
|
||||||
tags=["latents", "image", "vae", "i2l"],
|
tags=["latents", "image", "vae", "i2l"],
|
||||||
category="latents",
|
category="latents",
|
||||||
version="1.0.0",
|
version="1.0.1",
|
||||||
)
|
)
|
||||||
class ImageToLatentsInvocation(BaseInvocation):
|
class ImageToLatentsInvocation(BaseInvocation):
|
||||||
"""Encodes an image into latents."""
|
"""Encodes an image into latents."""
|
||||||
@ -1055,13 +1000,10 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
return latents
|
return latents
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context) -> LatentsOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.images.get_pil(self.image.image_name)
|
||||||
|
|
||||||
vae_info = context.services.model_manager.get_model(
|
vae_info = context.models.load(**self.vae.vae.model_dump())
|
||||||
**self.vae.vae.model_dump(),
|
|
||||||
context=context,
|
|
||||||
)
|
|
||||||
|
|
||||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||||
if image_tensor.dim() == 3:
|
if image_tensor.dim() == 3:
|
||||||
@ -1069,10 +1011,9 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor)
|
latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor)
|
||||||
|
|
||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
|
||||||
latents = latents.to("cpu")
|
latents = latents.to("cpu")
|
||||||
context.services.latents.save(name, latents)
|
name = context.latents.save(tensor=latents)
|
||||||
return build_latents_output(latents_name=name, latents=latents, seed=None)
|
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
|
||||||
|
|
||||||
@singledispatchmethod
|
@singledispatchmethod
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -1092,7 +1033,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
title="Blend Latents",
|
title="Blend Latents",
|
||||||
tags=["latents", "blend"],
|
tags=["latents", "blend"],
|
||||||
category="latents",
|
category="latents",
|
||||||
version="1.0.0",
|
version="1.0.1",
|
||||||
)
|
)
|
||||||
 class BlendLatentsInvocation(BaseInvocation):
     """Blend two latents using a given alpha. Latents must have same size."""
@ -1107,9 +1048,9 @@ class BlendLatentsInvocation(BaseInvocation):
     )
     alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)
 
-    def invoke(self, context: InvocationContext) -> LatentsOutput:
-        latents_a = context.services.latents.get(self.latents_a.latents_name)
-        latents_b = context.services.latents.get(self.latents_b.latents_name)
+    def invoke(self, context) -> LatentsOutput:
+        latents_a = context.latents.get(self.latents_a.latents_name)
+        latents_b = context.latents.get(self.latents_b.latents_name)
 
         if latents_a.shape != latents_b.shape:
             raise Exception("Latents to blend must be the same size.")
@ -1163,10 +1104,8 @@ class BlendLatentsInvocation(BaseInvocation):
         if device == torch.device("mps"):
             mps.empty_cache()
 
-        name = f"{context.graph_execution_state_id}__{self.id}"
-        # context.services.latents.set(name, resized_latents)
-        context.services.latents.save(name, blended_latents)
-        return build_latents_output(latents_name=name, latents=blended_latents)
+        name = context.latents.save(tensor=blended_latents)
+        return LatentsOutput.build(latents_name=name, latents=blended_latents)
 
 
 # The Crop Latents node was copied from @skunkworxdark's implementation here:
@ -1176,7 +1115,7 @@ class BlendLatentsInvocation(BaseInvocation):
     title="Crop Latents",
     tags=["latents", "crop"],
     category="latents",
-    version="1.0.0",
+    version="1.0.1",
 )
 # TODO(ryand): Named `CropLatentsCoreInvocation` to prevent a conflict with custom node `CropLatentsInvocation`.
 # Currently, if the class names conflict then 'GET /openapi.json' fails.
@ -1210,8 +1149,8 @@ class CropLatentsCoreInvocation(BaseInvocation):
         description="The height (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
     )
 
-    def invoke(self, context: InvocationContext) -> LatentsOutput:
-        latents = context.services.latents.get(self.latents.latents_name)
+    def invoke(self, context) -> LatentsOutput:
+        latents = context.latents.get(self.latents.latents_name)
 
         x1 = self.x // LATENT_SCALE_FACTOR
         y1 = self.y // LATENT_SCALE_FACTOR
@ -1220,10 +1159,9 @@ class CropLatentsCoreInvocation(BaseInvocation):
 
         cropped_latents = latents[..., y1:y2, x1:x2]
 
-        name = f"{context.graph_execution_state_id}__{self.id}"
-        context.services.latents.save(name, cropped_latents)
+        name = context.latents.save(tensor=cropped_latents)
 
-        return build_latents_output(latents_name=name, latents=cropped_latents)
+        return LatentsOutput.build(latents_name=name, latents=cropped_latents)
 
 
 @invocation_output("ideal_size_output")

@ -8,7 +8,7 @@ from pydantic import ValidationInfo, field_validator
 from invokeai.app.invocations.fields import FieldDescriptions, InputField
 from invokeai.app.invocations.primitives import FloatOutput, IntegerOutput
 
-from .baseinvocation import BaseInvocation, InvocationContext, invocation
+from .baseinvocation import BaseInvocation, invocation
 
 
 @invocation("add", title="Add Integers", tags=["math", "add"], category="math", version="1.0.0")
@ -18,7 +18,7 @@ class AddInvocation(BaseInvocation):
     a: int = InputField(default=0, description=FieldDescriptions.num_1)
     b: int = InputField(default=0, description=FieldDescriptions.num_2)
 
-    def invoke(self, context: InvocationContext) -> IntegerOutput:
+    def invoke(self, context) -> IntegerOutput:
         return IntegerOutput(value=self.a + self.b)
 
 
@ -29,7 +29,7 @@ class SubtractInvocation(BaseInvocation):
     a: int = InputField(default=0, description=FieldDescriptions.num_1)
     b: int = InputField(default=0, description=FieldDescriptions.num_2)
 
-    def invoke(self, context: InvocationContext) -> IntegerOutput:
+    def invoke(self, context) -> IntegerOutput:
         return IntegerOutput(value=self.a - self.b)
 
 
@ -40,7 +40,7 @@ class MultiplyInvocation(BaseInvocation):
     a: int = InputField(default=0, description=FieldDescriptions.num_1)
     b: int = InputField(default=0, description=FieldDescriptions.num_2)
 
-    def invoke(self, context: InvocationContext) -> IntegerOutput:
+    def invoke(self, context) -> IntegerOutput:
         return IntegerOutput(value=self.a * self.b)
 
 
@ -51,7 +51,7 @@ class DivideInvocation(BaseInvocation):
     a: int = InputField(default=0, description=FieldDescriptions.num_1)
     b: int = InputField(default=0, description=FieldDescriptions.num_2)
 
-    def invoke(self, context: InvocationContext) -> IntegerOutput:
+    def invoke(self, context) -> IntegerOutput:
         return IntegerOutput(value=int(self.a / self.b))
 
 
@ -69,7 +69,7 @@ class RandomIntInvocation(BaseInvocation):
     low: int = InputField(default=0, description=FieldDescriptions.inclusive_low)
     high: int = InputField(default=np.iinfo(np.int32).max, description=FieldDescriptions.exclusive_high)
 
-    def invoke(self, context: InvocationContext) -> IntegerOutput:
+    def invoke(self, context) -> IntegerOutput:
         return IntegerOutput(value=np.random.randint(self.low, self.high))
 
 
@ -88,7 +88,7 @@ class RandomFloatInvocation(BaseInvocation):
     high: float = InputField(default=1.0, description=FieldDescriptions.exclusive_high)
     decimals: int = InputField(default=2, description=FieldDescriptions.decimal_places)
 
-    def invoke(self, context: InvocationContext) -> FloatOutput:
+    def invoke(self, context) -> FloatOutput:
         random_float = np.random.uniform(self.low, self.high)
         rounded_float = round(random_float, self.decimals)
         return FloatOutput(value=rounded_float)
@ -110,7 +110,7 @@ class FloatToIntegerInvocation(BaseInvocation):
         default="Nearest", description="The method to use for rounding"
     )
 
-    def invoke(self, context: InvocationContext) -> IntegerOutput:
+    def invoke(self, context) -> IntegerOutput:
         if self.method == "Nearest":
             return IntegerOutput(value=round(self.value / self.multiple) * self.multiple)
         elif self.method == "Floor":
@ -128,7 +128,7 @@ class RoundInvocation(BaseInvocation):
     value: float = InputField(default=0, description="The float value")
     decimals: int = InputField(default=0, description="The number of decimal places")
 
-    def invoke(self, context: InvocationContext) -> FloatOutput:
+    def invoke(self, context) -> FloatOutput:
         return FloatOutput(value=round(self.value, self.decimals))
 
 
@ -196,7 +196,7 @@ class IntegerMathInvocation(BaseInvocation):
             raise ValueError("Result of exponentiation is not an integer")
         return v
 
-    def invoke(self, context: InvocationContext) -> IntegerOutput:
+    def invoke(self, context) -> IntegerOutput:
         # Python doesn't support switch statements until 3.10, but InvokeAI supports back to 3.9
         if self.operation == "ADD":
             return IntegerOutput(value=self.a + self.b)
@ -270,7 +270,7 @@ class FloatMathInvocation(BaseInvocation):
             raise ValueError("Root operation resulted in a complex number")
         return v
 
-    def invoke(self, context: InvocationContext) -> FloatOutput:
+    def invoke(self, context) -> FloatOutput:
         # Python doesn't support switch statements until 3.10, but InvokeAI supports back to 3.9
         if self.operation == "ADD":
             return FloatOutput(value=self.a + self.b)

@ -5,15 +5,20 @@ from pydantic import BaseModel, ConfigDict, Field
 from invokeai.app.invocations.baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
-    InvocationContext,
     invocation,
     invocation_output,
 )
 from invokeai.app.invocations.controlnet_image_processors import ControlField
-from invokeai.app.invocations.fields import FieldDescriptions, InputField, MetadataField, OutputField, UIType
+from invokeai.app.invocations.fields import (
+    FieldDescriptions,
+    ImageField,
+    InputField,
+    MetadataField,
+    OutputField,
+    UIType,
+)
 from invokeai.app.invocations.ip_adapter import IPAdapterModelField
 from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
-from invokeai.app.invocations.primitives import ImageField
 from invokeai.app.invocations.t2i_adapter import T2IAdapterField
 
 from ...version import __version__
@ -59,7 +64,7 @@ class MetadataItemInvocation(BaseInvocation):
     label: str = InputField(description=FieldDescriptions.metadata_item_label)
     value: Any = InputField(description=FieldDescriptions.metadata_item_value, ui_type=UIType.Any)
 
-    def invoke(self, context: InvocationContext) -> MetadataItemOutput:
+    def invoke(self, context) -> MetadataItemOutput:
         return MetadataItemOutput(item=MetadataItemField(label=self.label, value=self.value))
 
 
@ -76,7 +81,7 @@ class MetadataInvocation(BaseInvocation):
         description=FieldDescriptions.metadata_item_polymorphic
     )
 
-    def invoke(self, context: InvocationContext) -> MetadataOutput:
+    def invoke(self, context) -> MetadataOutput:
         if isinstance(self.items, MetadataItemField):
             # single metadata item
             data = {self.items.label: self.items.value}
@ -95,7 +100,7 @@ class MergeMetadataInvocation(BaseInvocation):
 
     collection: list[MetadataField] = InputField(description=FieldDescriptions.metadata_collection)
 
-    def invoke(self, context: InvocationContext) -> MetadataOutput:
+    def invoke(self, context) -> MetadataOutput:
         data = {}
         for item in self.collection:
             data.update(item.model_dump())
@ -213,7 +218,7 @@ class CoreMetadataInvocation(BaseInvocation):
         description="The start value used for refiner denoising",
     )
 
-    def invoke(self, context: InvocationContext) -> MetadataOutput:
+    def invoke(self, context) -> MetadataOutput:
         """Collects and outputs a CoreMetadata object"""
 
         return MetadataOutput(

@ -10,7 +10,6 @@ from ...backend.model_management import BaseModelType, ModelType, SubModelType
 from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
-    InvocationContext,
     invocation,
     invocation_output,
 )
@ -102,7 +101,7 @@ class LoRAModelField(BaseModel):
     title="Main Model",
     tags=["model"],
     category="model",
-    version="1.0.0",
+    version="1.0.1",
 )
 class MainModelLoaderInvocation(BaseInvocation):
     """Loads a main model, outputting its submodels."""
@ -110,13 +109,13 @@ class MainModelLoaderInvocation(BaseInvocation):
     model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct)
     # TODO: precision?
 
-    def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
+    def invoke(self, context) -> ModelLoaderOutput:
         base_model = self.model.base_model
         model_name = self.model.model_name
         model_type = ModelType.Main
 
         # TODO: not found exceptions
-        if not context.services.model_manager.model_exists(
+        if not context.models.exists(
             model_name=model_name,
             base_model=base_model,
             model_type=model_type,
@ -203,7 +202,7 @@ class LoraLoaderOutput(BaseInvocationOutput):
     clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
 
 
-@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.0")
+@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.1")
 class LoraLoaderInvocation(BaseInvocation):
     """Apply selected lora to unet and text_encoder."""
 
@ -222,14 +221,14 @@ class LoraLoaderInvocation(BaseInvocation):
         title="CLIP",
     )
 
-    def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
+    def invoke(self, context) -> LoraLoaderOutput:
         if self.lora is None:
             raise Exception("No LoRA provided")
 
         base_model = self.lora.base_model
         lora_name = self.lora.model_name
 
-        if not context.services.model_manager.model_exists(
+        if not context.models.exists(
             base_model=base_model,
             model_name=lora_name,
             model_type=ModelType.Lora,
@ -285,7 +284,7 @@ class SDXLLoraLoaderOutput(BaseInvocationOutput):
     title="SDXL LoRA",
     tags=["lora", "model"],
     category="model",
-    version="1.0.0",
+    version="1.0.1",
 )
 class SDXLLoraLoaderInvocation(BaseInvocation):
     """Apply selected lora to unet and text_encoder."""
@ -311,14 +310,14 @@ class SDXLLoraLoaderInvocation(BaseInvocation):
         title="CLIP 2",
     )
 
-    def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
+    def invoke(self, context) -> SDXLLoraLoaderOutput:
         if self.lora is None:
             raise Exception("No LoRA provided")
 
         base_model = self.lora.base_model
         lora_name = self.lora.model_name
 
-        if not context.services.model_manager.model_exists(
+        if not context.models.exists(
             base_model=base_model,
             model_name=lora_name,
             model_type=ModelType.Lora,
@ -384,7 +383,7 @@ class VAEModelField(BaseModel):
     model_config = ConfigDict(protected_namespaces=())
 
 
-@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.0")
+@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.1")
 class VaeLoaderInvocation(BaseInvocation):
     """Loads a VAE model, outputting a VaeLoaderOutput"""
 
@ -394,12 +393,12 @@ class VaeLoaderInvocation(BaseInvocation):
         title="VAE",
     )
 
-    def invoke(self, context: InvocationContext) -> VAEOutput:
+    def invoke(self, context) -> VAEOutput:
         base_model = self.vae_model.base_model
         model_name = self.vae_model.model_name
         model_type = ModelType.Vae
 
-        if not context.services.model_manager.model_exists(
+        if not context.models.exists(
             base_model=base_model,
             model_name=model_name,
             model_type=model_type,
@ -449,7 +448,7 @@ class SeamlessModeInvocation(BaseInvocation):
     seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless")
     seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless")
 
-    def invoke(self, context: InvocationContext) -> SeamlessModeOutput:
+    def invoke(self, context) -> SeamlessModeOutput:
         # Conditionally append 'x' and 'y' based on seamless_x and seamless_y
         unet = copy.deepcopy(self.unet)
         vae = copy.deepcopy(self.vae)
@ -485,6 +484,6 @@ class FreeUInvocation(BaseInvocation):
     s1: float = InputField(default=0.9, ge=-1, le=3, description=FieldDescriptions.freeu_s1)
     s2: float = InputField(default=0.2, ge=-1, le=3, description=FieldDescriptions.freeu_s2)
 
-    def invoke(self, context: InvocationContext) -> UNetOutput:
+    def invoke(self, context) -> UNetOutput:
         self.unet.freeu_config = FreeUConfig(s1=self.s1, s2=self.s2, b1=self.b1, b2=self.b2)
         return UNetOutput(unet=self.unet)

@ -4,15 +4,13 @@
 import torch
 from pydantic import field_validator
 
-from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField
-from invokeai.app.invocations.latent import LatentsField
+from invokeai.app.invocations.fields import FieldDescriptions, InputField, LatentsField, OutputField
 from invokeai.app.util.misc import SEED_MAX
 
 from ...backend.util.devices import choose_torch_device, torch_dtype
 from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
-    InvocationContext,
     invocation,
     invocation_output,
 )
@ -67,9 +65,9 @@ class NoiseOutput(BaseInvocationOutput):
     width: int = OutputField(description=FieldDescriptions.width)
     height: int = OutputField(description=FieldDescriptions.height)
 
-    def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
-        return NoiseOutput(
+    @classmethod
+    def build(cls, latents_name: str, latents: torch.Tensor, seed: int) -> "NoiseOutput":
+        return cls(
             noise=LatentsField(latents_name=latents_name, seed=seed),
             width=latents.size()[3] * 8,
             height=latents.size()[2] * 8,
@ -114,7 +112,7 @@ class NoiseInvocation(BaseInvocation):
         """Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
         return v % (SEED_MAX + 1)
 
-    def invoke(self, context: InvocationContext) -> NoiseOutput:
+    def invoke(self, context) -> NoiseOutput:
         noise = get_noise(
             width=self.width,
             height=self.height,
@ -122,6 +120,5 @@ class NoiseInvocation(BaseInvocation):
             seed=self.seed,
             use_cpu=self.use_cpu,
         )
-        name = f"{context.graph_execution_state_id}__{self.id}"
-        context.services.latents.save(name, noise)
-        return build_noise_output(latents_name=name, latents=noise, seed=self.seed)
+        name = context.latents.save(tensor=noise)
+        return NoiseOutput.build(latents_name=name, latents=noise, seed=self.seed)

@ -37,7 +37,7 @@ from .baseinvocation import (
     invocation_output,
 )
 from .controlnet_image_processors import ControlField
-from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, build_latents_output, get_scheduler
+from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, get_scheduler
 from .model import ClipField, ModelInfo, UNetField, VaeField
 
 ORT_TO_NP_TYPE = {
@ -63,7 +63,7 @@ class ONNXPromptInvocation(BaseInvocation):
     prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea)
     clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
 
-    def invoke(self, context: InvocationContext) -> ConditioningOutput:
+    def invoke(self, context) -> ConditioningOutput:
         tokenizer_info = context.services.model_manager.get_model(
             **self.clip.tokenizer.model_dump(),
         )
@ -201,7 +201,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
 
     # based on
    # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
-    def invoke(self, context: InvocationContext) -> LatentsOutput:
+    def invoke(self, context) -> LatentsOutput:
         c, _ = context.services.latents.get(self.positive_conditioning.conditioning_name)
         uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
         graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
@ -342,7 +342,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata):
     )
     # tiled: bool = InputField(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
 
-    def invoke(self, context: InvocationContext) -> ImageOutput:
+    def invoke(self, context) -> ImageOutput:
         latents = context.services.latents.get(self.latents.latents_name)
 
         if self.vae.vae.submodel != SubModelType.VaeDecoder:
@ -417,7 +417,7 @@ class OnnxModelLoaderInvocation(BaseInvocation):
         description=FieldDescriptions.onnx_main_model, input=Input.Direct, ui_type=UIType.ONNXModel
     )
 
-    def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
+    def invoke(self, context) -> ONNXModelLoaderOutput:
         base_model = self.model.base_model
         model_name = self.model.model_name
         model_type = ModelType.ONNX

@ -41,7 +41,7 @@ from matplotlib.ticker import MaxNLocator
 
 from invokeai.app.invocations.primitives import FloatCollectionOutput
 
-from .baseinvocation import BaseInvocation, InvocationContext, invocation
+from .baseinvocation import BaseInvocation, invocation
 from .fields import InputField
 
 
@ -62,7 +62,7 @@ class FloatLinearRangeInvocation(BaseInvocation):
         description="number of values to interpolate over (including start and stop)",
     )
 
-    def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
+    def invoke(self, context) -> FloatCollectionOutput:
         param_list = list(np.linspace(self.start, self.stop, self.steps))
         return FloatCollectionOutput(collection=param_list)
 
@ -110,7 +110,7 @@ EASING_FUNCTION_KEYS = Literal[tuple(EASING_FUNCTIONS_MAP.keys())]
     title="Step Param Easing",
     tags=["step", "easing"],
     category="step",
-    version="1.0.0",
+    version="1.0.1",
 )
 class StepParamEasingInvocation(BaseInvocation):
     """Experimental per-step parameter easing for denoising steps"""
@ -130,7 +130,7 @@ class StepParamEasingInvocation(BaseInvocation):
     # alt_mirror: bool = InputField(default=False, description="alternative mirroring by dual easing")
     show_easing_plot: bool = InputField(default=False, description="show easing plot")
 
-    def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
+    def invoke(self, context) -> FloatCollectionOutput:
         log_diagnostics = False
         # convert from start_step_percent to nearest step <= (steps * start_step_percent)
         # start_step = int(np.floor(self.num_steps * self.start_step_percent))
@ -149,19 +149,19 @@ class StepParamEasingInvocation(BaseInvocation):
         postlist = list(num_poststeps * [self.post_end_value])
 
         if log_diagnostics:
-            context.services.logger.debug("start_step: " + str(start_step))
-            context.services.logger.debug("end_step: " + str(end_step))
-            context.services.logger.debug("num_easing_steps: " + str(num_easing_steps))
-            context.services.logger.debug("num_presteps: " + str(num_presteps))
-            context.services.logger.debug("num_poststeps: " + str(num_poststeps))
-            context.services.logger.debug("prelist size: " + str(len(prelist)))
-            context.services.logger.debug("postlist size: " + str(len(postlist)))
-            context.services.logger.debug("prelist: " + str(prelist))
-            context.services.logger.debug("postlist: " + str(postlist))
+            context.logger.debug("start_step: " + str(start_step))
+            context.logger.debug("end_step: " + str(end_step))
+            context.logger.debug("num_easing_steps: " + str(num_easing_steps))
+            context.logger.debug("num_presteps: " + str(num_presteps))
+            context.logger.debug("num_poststeps: " + str(num_poststeps))
+            context.logger.debug("prelist size: " + str(len(prelist)))
+            context.logger.debug("postlist size: " + str(len(postlist)))
+            context.logger.debug("prelist: " + str(prelist))
+            context.logger.debug("postlist: " + str(postlist))
 
         easing_class = EASING_FUNCTIONS_MAP[self.easing]
         if log_diagnostics:
-            context.services.logger.debug("easing class: " + str(easing_class))
+            context.logger.debug("easing class: " + str(easing_class))
         easing_list = []
         if self.mirror:  # "expected" mirroring
             # if number of steps is even, squeeze duration down to (number_of_steps)/2
@ -172,7 +172,7 @@ class StepParamEasingInvocation(BaseInvocation):
 
             base_easing_duration = int(np.ceil(num_easing_steps / 2.0))
             if log_diagnostics:
-                context.services.logger.debug("base easing duration: " + str(base_easing_duration))
+                context.logger.debug("base easing duration: " + str(base_easing_duration))
             even_num_steps = num_easing_steps % 2 == 0  # even number of steps
             easing_function = easing_class(
                 start=self.start_value,
@ -184,14 +184,14 @@ class StepParamEasingInvocation(BaseInvocation):
                 easing_val = easing_function.ease(step_index)
                 base_easing_vals.append(easing_val)
                 if log_diagnostics:
-                    context.services.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val))
+                    context.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val))
             if even_num_steps:
                 mirror_easing_vals = list(reversed(base_easing_vals))
             else:
                 mirror_easing_vals = list(reversed(base_easing_vals[0:-1]))
             if log_diagnostics:
-                context.services.logger.debug("base easing vals: " + str(base_easing_vals))
-                context.services.logger.debug("mirror easing vals: " + str(mirror_easing_vals))
+                context.logger.debug("base easing vals: " + str(base_easing_vals))
+                context.logger.debug("mirror easing vals: " + str(mirror_easing_vals))
             easing_list = base_easing_vals + mirror_easing_vals
 
         # FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
@ -226,12 +226,12 @@ class StepParamEasingInvocation(BaseInvocation):
                 step_val = easing_function.ease(step_index)
                 easing_list.append(step_val)
                 if log_diagnostics:
-                    context.services.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val))
+                    context.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val))
 
         if log_diagnostics:
-            context.services.logger.debug("prelist size: " + str(len(prelist)))
-            context.services.logger.debug("easing_list size: " + str(len(easing_list)))
-            context.services.logger.debug("postlist size: " + str(len(postlist)))
+            context.logger.debug("prelist size: " + str(len(prelist)))
+            context.logger.debug("easing_list size: " + str(len(easing_list)))
+            context.logger.debug("postlist size: " + str(len(postlist)))
 
         param_list = prelist + easing_list + postlist
 

@ -1,16 +1,26 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
 
-from typing import Optional, Tuple
+from typing import Optional
 
 import torch
-from pydantic import BaseModel, Field
 
-from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent
+from invokeai.app.invocations.fields import (
+    ColorField,
+    ConditioningField,
+    DenoiseMaskField,
+    FieldDescriptions,
+    ImageField,
+    Input,
+    InputField,
+    LatentsField,
+    OutputField,
+    UIComponent,
+)
+from invokeai.app.services.images.images_common import ImageDTO
 
 from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
-    InvocationContext,
     invocation,
     invocation_output,
 )
@ -49,7 +59,7 @@ class BooleanInvocation(BaseInvocation):
 
     value: bool = InputField(default=False, description="The boolean value")
 
-    def invoke(self, context: InvocationContext) -> BooleanOutput:
+    def invoke(self, context) -> BooleanOutput:
         return BooleanOutput(value=self.value)
 
 
@ -65,7 +75,7 @@ class BooleanCollectionInvocation(BaseInvocation):
 
     collection: list[bool] = InputField(default=[], description="The collection of boolean values")
 
-    def invoke(self, context: InvocationContext) -> BooleanCollectionOutput:
+    def invoke(self, context) -> BooleanCollectionOutput:
         return BooleanCollectionOutput(collection=self.collection)
 
 
@ -98,7 +108,7 @@ class IntegerInvocation(BaseInvocation):
 
     value: int = InputField(default=0, description="The integer value")
 
-    def invoke(self, context: InvocationContext) -> IntegerOutput:
+    def invoke(self, context) -> IntegerOutput:
         return IntegerOutput(value=self.value)
 
 
@ -114,7 +124,7 @@ class IntegerCollectionInvocation(BaseInvocation):
 
     collection: list[int] = InputField(default=[], description="The collection of integer values")
 
-    def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
+    def invoke(self, context) -> IntegerCollectionOutput:
         return IntegerCollectionOutput(collection=self.collection)
 
 
@ -145,7 +155,7 @@ class FloatInvocation(BaseInvocation):
 
     value: float = InputField(default=0.0, description="The float value")
 
-    def invoke(self, context: InvocationContext) -> FloatOutput:
+    def invoke(self, context) -> FloatOutput:
         return FloatOutput(value=self.value)
 
 
@ -161,7 +171,7 @@ class FloatCollectionInvocation(BaseInvocation):
 
     collection: list[float] = InputField(default=[], description="The collection of float values")
 
-    def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
+    def invoke(self, context) -> FloatCollectionOutput:
         return FloatCollectionOutput(collection=self.collection)
 
 
@ -192,7 +202,7 @@ class StringInvocation(BaseInvocation):
 
     value: str = InputField(default="", description="The string value", ui_component=UIComponent.Textarea)
 
-    def invoke(self, context: InvocationContext) -> StringOutput:
+    def invoke(self, context) -> StringOutput:
         return StringOutput(value=self.value)
 
 
@ -208,7 +218,7 @@ class StringCollectionInvocation(BaseInvocation):
 
     collection: list[str] = InputField(default=[], description="The collection of string values")
 
-    def invoke(self, context: InvocationContext) -> StringCollectionOutput:
+    def invoke(self, context) -> StringCollectionOutput:
         return StringCollectionOutput(collection=self.collection)
 
 
@ -217,18 +227,6 @@ class StringCollectionInvocation(BaseInvocation):
 # region Image
 
 
-class ImageField(BaseModel):
-    """An image primitive field"""
-
-    image_name: str = Field(description="The name of the image")
-
-
-class BoardField(BaseModel):
-    """A board primitive field"""
-
-    board_id: str = Field(description="The id of the board")
-
-
 @invocation_output("image_output")
 class ImageOutput(BaseInvocationOutput):
     """Base class for nodes that output a single image"""
@ -237,6 +235,14 @@ class ImageOutput(BaseInvocationOutput):
     width: int = OutputField(description="The width of the image in pixels")
     height: int = OutputField(description="The height of the image in pixels")
 
+    @classmethod
+    def build(cls, image_dto: ImageDTO) -> "ImageOutput":
+        return cls(
+            image=ImageField(image_name=image_dto.image_name),
+            width=image_dto.width,
+            height=image_dto.height,
+        )
+
 
 @invocation_output("image_collection_output")
 class ImageCollectionOutput(BaseInvocationOutput):
@ -247,7 +253,7 @@ class ImageCollectionOutput(BaseInvocationOutput):
     )
 
 
-@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives", version="1.0.0")
+@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives", version="1.0.1")
 class ImageInvocation(
     BaseInvocation,
 ):
@ -255,8 +261,8 @@ class ImageInvocation(
 
     image: ImageField = InputField(description="The image to load")
 
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        image = context.services.images.get_pil_image(self.image.image_name)
+    def invoke(self, context) -> ImageOutput:
+        image = context.images.get_pil(self.image.image_name)
 
         return ImageOutput(
             image=ImageField(image_name=self.image.image_name),
@ -277,7 +283,7 @@ class ImageCollectionInvocation(BaseInvocation):
 
     collection: list[ImageField] = InputField(description="The collection of image values")
 
-    def invoke(self, context: InvocationContext) -> ImageCollectionOutput:
+    def invoke(self, context) -> ImageCollectionOutput:
         return ImageCollectionOutput(collection=self.collection)
 
 
@ -286,32 +292,24 @@ class ImageCollectionInvocation(BaseInvocation):
 # region DenoiseMask
 
 
-class DenoiseMaskField(BaseModel):
-    """An inpaint mask field"""
-
-    mask_name: str = Field(description="The name of the mask image")
-    masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents")
-
-
 @invocation_output("denoise_mask_output")
 class DenoiseMaskOutput(BaseInvocationOutput):
     """Base class for nodes that output a single image"""
 
     denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
 
+    @classmethod
+    def build(cls, mask_name: str, masked_latents_name: Optional[str] = None) -> "DenoiseMaskOutput":
+        return cls(
+            denoise_mask=DenoiseMaskField(mask_name=mask_name, masked_latents_name=masked_latents_name),
+        )
+
 
 # endregion
 
 # region Latents
 
 
-class LatentsField(BaseModel):
-    """A latents tensor primitive field"""
-
-    latents_name: str = Field(description="The name of the latents")
-    seed: Optional[int] = Field(default=None, description="Seed used to generate this latents")
-
-
 @invocation_output("latents_output")
 class LatentsOutput(BaseInvocationOutput):
     """Base class for nodes that output a single latents tensor"""
@ -322,6 +320,14 @@ class LatentsOutput(BaseInvocationOutput):
     width: int = OutputField(description=FieldDescriptions.width)
     height: int = OutputField(description=FieldDescriptions.height)
 
+    @classmethod
+    def build(cls, latents_name: str, latents: torch.Tensor, seed: Optional[int] = None) -> "LatentsOutput":
+        return cls(
+            latents=LatentsField(latents_name=latents_name, seed=seed),
+            width=latents.size()[3] * 8,
+            height=latents.size()[2] * 8,
+        )
+
 
 @invocation_output("latents_collection_output")
 class LatentsCollectionOutput(BaseInvocationOutput):
@ -333,17 +339,17 @@ class LatentsCollectionOutput(BaseInvocationOutput):
 
 
 @invocation(
-    "latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives", version="1.0.0"
+    "latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives", version="1.0.1"
 )
 class LatentsInvocation(BaseInvocation):
     """A latents tensor primitive value"""
 
     latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection)
 
-    def invoke(self, context: InvocationContext) -> LatentsOutput:
-        latents = context.services.latents.get(self.latents.latents_name)
+    def invoke(self, context) -> LatentsOutput:
+        latents = context.latents.get(self.latents.latents_name)
 
-        return build_latents_output(self.latents.latents_name, latents)
+        return LatentsOutput.build(self.latents.latents_name, latents)
 
 
 @invocation(
@ -360,35 +366,15 @@ class LatentsCollectionInvocation(BaseInvocation):
         description="The collection of latents tensors",
     )
 
-    def invoke(self, context: InvocationContext) -> LatentsCollectionOutput:
+    def invoke(self, context) -> LatentsCollectionOutput:
         return LatentsCollectionOutput(collection=self.collection)
 
 
-def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int] = None):
-    return LatentsOutput(
-        latents=LatentsField(latents_name=latents_name, seed=seed),
-        width=latents.size()[3] * 8,
-        height=latents.size()[2] * 8,
-    )
-
-
 # endregion
 
 # region Color
 
 
-class ColorField(BaseModel):
-    """A color primitive field"""
-
-    r: int = Field(ge=0, le=255, description="The red component")
-    g: int = Field(ge=0, le=255, description="The green component")
-    b: int = Field(ge=0, le=255, description="The blue component")
-    a: int = Field(ge=0, le=255, description="The alpha component")
-
-    def tuple(self) -> Tuple[int, int, int, int]:
-        return (self.r, self.g, self.b, self.a)
-
-
 @invocation_output("color_output")
 class ColorOutput(BaseInvocationOutput):
     """Base class for nodes that output a single color"""
@ -411,7 +397,7 @@ class ColorInvocation(BaseInvocation):
 
     color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=255), description="The color value")
 
-    def invoke(self, context: InvocationContext) -> ColorOutput:
+    def invoke(self, context) -> ColorOutput:
         return ColorOutput(color=self.color)
 
 
@ -420,18 +406,16 @@ class ColorInvocation(BaseInvocation):
 # region Conditioning
 
 
-class ConditioningField(BaseModel):
-    """A conditioning tensor primitive value"""
-
-    conditioning_name: str = Field(description="The name of conditioning tensor")
-
-
 @invocation_output("conditioning_output")
 class ConditioningOutput(BaseInvocationOutput):
     """Base class for nodes that output a single conditioning tensor"""
 
     conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond)
 
+    @classmethod
+    def build(cls, conditioning_name: str) -> "ConditioningOutput":
+        return cls(conditioning=ConditioningField(conditioning_name=conditioning_name))
+
 
 @invocation_output("conditioning_collection_output")
 class ConditioningCollectionOutput(BaseInvocationOutput):
@ -454,7 +438,7 @@ class ConditioningInvocation(BaseInvocation):
 
     conditioning: ConditioningField = InputField(description=FieldDescriptions.cond, input=Input.Connection)
 
-    def invoke(self, context: InvocationContext) -> ConditioningOutput:
+    def invoke(self, context) -> ConditioningOutput:
         return ConditioningOutput(conditioning=self.conditioning)
 
 
@ -473,7 +457,7 @@ class ConditioningCollectionInvocation(BaseInvocation):
         description="The collection of conditioning tensors",
    )
 
-    def invoke(self, context: InvocationContext) -> ConditioningCollectionOutput:
+    def invoke(self, context) -> ConditioningCollectionOutput:
         return ConditioningCollectionOutput(collection=self.collection)
 

@ -7,7 +7,7 @@ from pydantic import field_validator
 
 from invokeai.app.invocations.primitives import StringCollectionOutput
 
-from .baseinvocation import BaseInvocation, InvocationContext, invocation
+from .baseinvocation import BaseInvocation, invocation
 from .fields import InputField, UIComponent
 
 
@ -29,7 +29,7 @@ class DynamicPromptInvocation(BaseInvocation):
     max_prompts: int = InputField(default=1, description="The number of prompts to generate")
     combinatorial: bool = InputField(default=False, description="Whether to use the combinatorial generator")
 
-    def invoke(self, context: InvocationContext) -> StringCollectionOutput:
+    def invoke(self, context) -> StringCollectionOutput:
         if self.combinatorial:
             generator = CombinatorialPromptGenerator()
             prompts = generator.generate(self.prompt, max_prompts=self.max_prompts)
@ -91,7 +91,7 @@ class PromptsFromFileInvocation(BaseInvocation):
                     break
         return prompts
 
-    def invoke(self, context: InvocationContext) -> StringCollectionOutput:
+    def invoke(self, context) -> StringCollectionOutput:
         prompts = self.promptsFromFile(
             self.file_path,
             self.pre_prompt,

@ -4,7 +4,6 @@ from ...backend.model_management import ModelType, SubModelType
 from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
-    InvocationContext,
     invocation,
     invocation_output,
 )
@ -30,7 +29,7 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
     vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
 
 
-@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.0")
+@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.1")
 class SDXLModelLoaderInvocation(BaseInvocation):
     """Loads an sdxl base model, outputting its submodels."""
 
@ -39,13 +38,13 @@ class SDXLModelLoaderInvocation(BaseInvocation):
     )
     # TODO: precision?
 
-    def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
+    def invoke(self, context) -> SDXLModelLoaderOutput:
         base_model = self.model.base_model
         model_name = self.model.model_name
         model_type = ModelType.Main
 
         # TODO: not found exceptions
-        if not context.services.model_manager.model_exists(
+        if not context.models.exists(
             model_name=model_name,
             base_model=base_model,
             model_type=model_type,
@ -116,7 +115,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
     title="SDXL Refiner Model",
     tags=["model", "sdxl", "refiner"],
     category="model",
-    version="1.0.0",
+    version="1.0.1",
 )
 class SDXLRefinerModelLoaderInvocation(BaseInvocation):
     """Loads an sdxl refiner model, outputting its submodels."""
@ -128,13 +127,13 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
     )
     # TODO: precision?
 
-    def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput:
+    def invoke(self, context) -> SDXLRefinerModelLoaderOutput:
         base_model = self.model.base_model
         model_name = self.model.model_name
         model_type = ModelType.Main
 
         # TODO: not found exceptions
-        if not context.services.model_manager.model_exists(
+        if not context.models.exists(
             model_name=model_name,
             base_model=base_model,
             model_type=model_type,
|
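These two loaders show the patch-level version bump applied to every touched node (`1.0.0` to `1.0.1`) alongside the API change itself: the model lookup moves from the raw service (`context.services.model_manager.model_exists(...)`) to the narrower wrapper (`context.models.exists(...)`). A minimal before/after sketch, assuming only the calls visible in the hunks above; `UnknownModelException` is a placeholder for whatever the `# TODO: not found exceptions` comment eventually resolves to:

# Before: the node reaches through the whole services bag.
if not context.services.model_manager.model_exists(
    model_name=model_name,
    base_model=base_model,
    model_type=model_type,
):
    raise UnknownModelException(model_name)  # placeholder exception type

# After: the wrapped context exposes only the model queries a node needs.
if not context.models.exists(
    model_name=model_name,
    base_model=base_model,
    model_type=model_type,
):
    raise UnknownModelException(model_name)  # placeholder exception type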
@@ -5,7 +5,6 @@ import re
 from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
-    InvocationContext,
     invocation,
     invocation_output,
 )
@@ -33,7 +32,7 @@ class StringSplitNegInvocation(BaseInvocation):
 
     string: str = InputField(default="", description="String to split", ui_component=UIComponent.Textarea)
 
-    def invoke(self, context: InvocationContext) -> StringPosNegOutput:
+    def invoke(self, context) -> StringPosNegOutput:
         p_string = ""
         n_string = ""
         brackets_depth = 0
@@ -77,7 +76,7 @@ class StringSplitInvocation(BaseInvocation):
         default="", description="Delimiter to spilt with. blank will split on the first whitespace"
     )
 
-    def invoke(self, context: InvocationContext) -> String2Output:
+    def invoke(self, context) -> String2Output:
         result = self.string.split(self.delimiter, 1)
         if len(result) == 2:
             part1, part2 = result
@@ -95,7 +94,7 @@ class StringJoinInvocation(BaseInvocation):
     string_left: str = InputField(default="", description="String Left", ui_component=UIComponent.Textarea)
     string_right: str = InputField(default="", description="String Right", ui_component=UIComponent.Textarea)
 
-    def invoke(self, context: InvocationContext) -> StringOutput:
+    def invoke(self, context) -> StringOutput:
         return StringOutput(value=((self.string_left or "") + (self.string_right or "")))
 
 
@@ -107,7 +106,7 @@ class StringJoinThreeInvocation(BaseInvocation):
     string_middle: str = InputField(default="", description="String Middle", ui_component=UIComponent.Textarea)
     string_right: str = InputField(default="", description="String Right", ui_component=UIComponent.Textarea)
 
-    def invoke(self, context: InvocationContext) -> StringOutput:
+    def invoke(self, context) -> StringOutput:
         return StringOutput(value=((self.string_left or "") + (self.string_middle or "") + (self.string_right or "")))
 
 
@@ -126,7 +125,7 @@ class StringReplaceInvocation(BaseInvocation):
         default=False, description="Use search string as a regex expression (non regex is case insensitive)"
     )
 
-    def invoke(self, context: InvocationContext) -> StringOutput:
+    def invoke(self, context) -> StringOutput:
         pattern = self.search_string or ""
         new_string = self.string or ""
         if len(pattern) > 0:
@@ -5,13 +5,11 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator, model_valida
 from invokeai.app.invocations.baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
-    InvocationContext,
     invocation,
     invocation_output,
 )
 from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
-from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
-from invokeai.app.invocations.primitives import ImageField
+from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField
 from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
 from invokeai.backend.model_management.models.base import BaseModelType
 
@@ -91,7 +89,7 @@ class T2IAdapterInvocation(BaseInvocation):
         validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
         return self
 
-    def invoke(self, context: InvocationContext) -> T2IAdapterOutput:
+    def invoke(self, context) -> T2IAdapterOutput:
         return T2IAdapterOutput(
             t2i_adapter=T2IAdapterField(
                 image=self.image,
@@ -8,13 +8,12 @@ from invokeai.app.invocations.baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
     Classification,
-    InvocationContext,
+    WithMetadata,
     invocation,
     invocation_output,
 )
-from invokeai.app.invocations.fields import Input, InputField, OutputField, WithMetadata
-from invokeai.app.invocations.primitives import ImageField, ImageOutput
-from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
+from invokeai.app.invocations.fields import ImageField, Input, InputField, OutputField
+from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.backend.tiles.tiles import (
     calc_tiles_even_split,
     calc_tiles_min_overlap,
@@ -58,7 +57,7 @@ class CalculateImageTilesInvocation(BaseInvocation):
         description="The target overlap, in pixels, between adjacent tiles. Adjacent tiles will overlap by at least this amount",
     )
 
-    def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput:
+    def invoke(self, context) -> CalculateImageTilesOutput:
         tiles = calc_tiles_with_overlap(
             image_height=self.image_height,
             image_width=self.image_width,
@@ -101,7 +100,7 @@ class CalculateImageTilesEvenSplitInvocation(BaseInvocation):
         description="The overlap, in pixels, between adjacent tiles.",
     )
 
-    def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput:
+    def invoke(self, context) -> CalculateImageTilesOutput:
         tiles = calc_tiles_even_split(
             image_height=self.image_height,
             image_width=self.image_width,
@@ -131,7 +130,7 @@ class CalculateImageTilesMinimumOverlapInvocation(BaseInvocation):
     tile_height: int = InputField(ge=1, default=576, description="The tile height, in pixels.")
     min_overlap: int = InputField(default=128, ge=0, description="Minimum overlap between adjacent tiles, in pixels.")
 
-    def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput:
+    def invoke(self, context) -> CalculateImageTilesOutput:
         tiles = calc_tiles_min_overlap(
             image_height=self.image_height,
             image_width=self.image_width,
@@ -176,7 +175,7 @@ class TileToPropertiesInvocation(BaseInvocation):
 
     tile: Tile = InputField(description="The tile to split into properties.")
 
-    def invoke(self, context: InvocationContext) -> TileToPropertiesOutput:
+    def invoke(self, context) -> TileToPropertiesOutput:
         return TileToPropertiesOutput(
             coords_left=self.tile.coords.left,
             coords_right=self.tile.coords.right,
@@ -213,7 +212,7 @@ class PairTileImageInvocation(BaseInvocation):
     image: ImageField = InputField(description="The tile image.")
     tile: Tile = InputField(description="The tile properties.")
 
-    def invoke(self, context: InvocationContext) -> PairTileImageOutput:
+    def invoke(self, context) -> PairTileImageOutput:
         return PairTileImageOutput(
             tile_with_image=TileWithImage(
                 tile=self.tile,
@@ -249,7 +248,7 @@ class MergeTilesToImageInvocation(BaseInvocation, WithMetadata):
         description="The amount to blend adjacent tiles in pixels. Must be <= the amount of overlap between adjacent tiles.",
     )
 
-    def invoke(self, context: InvocationContext) -> ImageOutput:
+    def invoke(self, context) -> ImageOutput:
         images = [twi.image for twi in self.tiles_with_images]
         tiles = [twi.tile for twi in self.tiles_with_images]
 
@@ -265,7 +264,7 @@ class MergeTilesToImageInvocation(BaseInvocation, WithMetadata):
         # existed in memory at an earlier point in the graph.
         tile_np_images: list[np.ndarray] = []
         for image in images:
-            pil_image = context.services.images.get_pil_image(image.image_name)
+            pil_image = context.images.get_pil(image.image_name)
             pil_image = pil_image.convert("RGB")
             tile_np_images.append(np.array(pil_image))
 
@@ -288,18 +287,5 @@ class MergeTilesToImageInvocation(BaseInvocation, WithMetadata):
         # Convert into a PIL image and save
         pil_image = Image.fromarray(np_image)
 
-        image_dto = context.services.images.create(
-            image=pil_image,
-            image_origin=ResourceOrigin.INTERNAL,
-            image_category=ImageCategory.GENERAL,
-            node_id=self.id,
-            session_id=context.graph_execution_state_id,
-            is_intermediate=self.is_intermediate,
-            metadata=self.metadata,
-            workflow=context.workflow,
-        )
-        return ImageOutput(
-            image=ImageField(image_name=image_dto.image_name),
-            width=image_dto.width,
-            height=image_dto.height,
-        )
+        image_dto = context.images.save(image=pil_image)
+        return ImageOutput.build(image_dto)
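This last hunk is the clearest picture of what the wrapped context buys: the node no longer supplies origin, category, node id, session id, intermediate status, metadata, or workflow when saving an image, because the context was built from an `InvocationContextData` that already carries them. Condensed from the hunk above, with nothing assumed beyond the calls it shows:

# Before: eight lines of bookkeeping per saved image.
image_dto = context.services.images.create(
    image=pil_image,
    image_origin=ResourceOrigin.INTERNAL,
    image_category=ImageCategory.GENERAL,
    node_id=self.id,
    session_id=context.graph_execution_state_id,
    is_intermediate=self.is_intermediate,
    metadata=self.metadata,
    workflow=context.workflow,
)
return ImageOutput(
    image=ImageField(image_name=image_dto.image_name),
    width=image_dto.width,
    height=image_dto.height,
)

# After: the context already knows which node, session, and workflow it serves.
image_dto = context.images.save(image=pil_image)
return ImageOutput.build(image_dto)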
@@ -8,13 +8,13 @@ import torch
 from PIL import Image
 from pydantic import ConfigDict
 
-from invokeai.app.invocations.primitives import ImageField, ImageOutput
-from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
+from invokeai.app.invocations.fields import ImageField
+from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
 from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
 from invokeai.backend.util.devices import choose_torch_device
 
-from .baseinvocation import BaseInvocation, InvocationContext, invocation
+from .baseinvocation import BaseInvocation, invocation
 from .fields import InputField, WithMetadata
 
 # TODO: Populate this from disk?
@@ -30,7 +30,7 @@ if choose_torch_device() == torch.device("mps"):
     from torch import mps
 
 
-@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.0")
+@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.1")
 class ESRGANInvocation(BaseInvocation, WithMetadata):
     """Upscales an image using RealESRGAN."""
 
@@ -42,9 +42,9 @@ class ESRGANInvocation(BaseInvocation, WithMetadata):
 
     model_config = ConfigDict(protected_namespaces=())
 
-    def invoke(self, context: InvocationContext) -> ImageOutput:
-        image = context.services.images.get_pil_image(self.image.image_name)
-        models_path = context.services.configuration.models_path
+    def invoke(self, context) -> ImageOutput:
+        image = context.images.get_pil(self.image.image_name)
+        models_path = context.config.get().models_path
 
         rrdbnet_model = None
         netscale = None
@@ -88,7 +88,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata):
             netscale = 2
         else:
             msg = f"Invalid RealESRGAN model: {self.model_name}"
-            context.services.logger.error(msg)
+            context.logger.error(msg)
             raise ValueError(msg)
 
         esrgan_model_path = Path(f"core/upscaling/realesrgan/{self.model_name}")
@@ -111,19 +111,6 @@ class ESRGANInvocation(BaseInvocation, WithMetadata):
         if choose_torch_device() == torch.device("mps"):
             mps.empty_cache()
 
-        image_dto = context.services.images.create(
-            image=pil_image,
-            image_origin=ResourceOrigin.INTERNAL,
-            image_category=ImageCategory.GENERAL,
-            node_id=self.id,
-            session_id=context.graph_execution_state_id,
-            is_intermediate=self.is_intermediate,
-            metadata=self.metadata,
-            workflow=context.workflow,
-        )
+        image_dto = context.images.save(image=pil_image)
 
-        return ImageOutput(
-            image=ImageField(image_name=image_dto.image_name),
-            width=image_dto.width,
-            height=image_dto.height,
-        )
+        return ImageOutput.build(image_dto)
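Beyond image saving, the ESRGAN hunks show two more surfaces of the wrapper: configuration and logging. The raw service attributes give way to dedicated accessors. Condensed from the hunks above:

# Before: services exposed directly on the context.
image = context.services.images.get_pil_image(self.image.image_name)
models_path = context.services.configuration.models_path
context.services.logger.error(msg)

# After: the wrapped context provides each concern through its own accessor.
image = context.images.get_pil(self.image.image_name)
models_path = context.config.get().models_path
context.logger.error(msg)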
@@ -55,7 +55,7 @@ class EventServiceBase:
         queue_item_id: int,
         queue_batch_id: str,
         graph_execution_state_id: str,
-        node: dict,
+        node_id: str,
         source_node_id: str,
         progress_image: Optional[ProgressImage],
         step: int,
@@ -70,7 +70,7 @@ class EventServiceBase:
                 "queue_item_id": queue_item_id,
                 "queue_batch_id": queue_batch_id,
                 "graph_execution_state_id": graph_execution_state_id,
-                "node_id": node.get("id"),
+                "node_id": node_id,
                 "source_node_id": source_node_id,
                 "progress_image": progress_image.model_dump() if progress_image is not None else None,
                 "step": step,
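`EventServiceBase` now accepts a node's ID rather than the dict form of the node, so callers no longer have to serialize an entire invocation just to surface one field in the event payload. A toy, self-contained illustration of the payload change (not InvokeAI code; both functions are hypothetical stand-ins for the emitter):

# Toy illustration: the emitter used to receive the serialized node and pluck out its id...
def emit_before(node: dict) -> dict:
    return {"node_id": node.get("id")}

# ...and now receives the id directly, so callers skip the model_dump() round-trip.
def emit_after(node_id: str) -> dict:
    return {"node_id": node_id}

assert emit_before({"id": "abc"}) == emit_after("abc")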
@@ -5,11 +5,11 @@ from threading import BoundedSemaphore, Event, Thread
 from typing import Optional
 
 import invokeai.backend.util.logging as logger
-from invokeai.app.invocations.baseinvocation import InvocationContext
 from invokeai.app.services.invocation_queue.invocation_queue_common import InvocationQueueItem
 from invokeai.app.services.invocation_stats.invocation_stats_common import (
     GESStatsNotFoundError,
 )
+from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
 from invokeai.app.util.profiler import Profiler
 
 from ..invoker import Invoker
@@ -131,16 +131,20 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                     # which handles a few things:
                     # - nodes that require a value, but get it only from a connection
                     # - referencing the invocation cache instead of executing the node
-                    outputs = invocation.invoke_internal(
-                        InvocationContext(
-                            services=self.__invoker.services,
-                            graph_execution_state_id=graph_execution_state.id,
-                            queue_item_id=queue_item.session_queue_item_id,
-                            queue_id=queue_item.session_queue_id,
-                            queue_batch_id=queue_item.session_queue_batch_id,
-                            workflow=queue_item.workflow,
-                        )
-                    )
+                    context_data = InvocationContextData(
+                        invocation=invocation,
+                        session_id=graph_id,
+                        workflow=queue_item.workflow,
+                        source_node_id=source_node_id,
+                        queue_id=queue_item.session_queue_id,
+                        queue_item_id=queue_item.session_queue_item_id,
+                        batch_id=queue_item.session_queue_batch_id,
+                    )
+                    context = build_invocation_context(
+                        services=self.__invoker.services,
+                        context_data=context_data,
+                    )
+                    outputs = invocation.invoke_internal(context=context, services=self.__invoker.services)
 
                     # Check queue to see if this is canceled, and skip if so
                     if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
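This is the producing side of the new API: the processor first freezes everything identifying the current execution into an `InvocationContextData`, then combines it with the services via `build_invocation_context`, and hands the result to the node; note that `invoke_internal` still receives the raw services alongside the wrapped context. A hedged sketch of driving a node by hand, for example from a test harness; `services`, `node`, and every literal below are placeholders:

# Sketch only: assumes `services` is a fully constructed InvocationServices
# and `node` is some BaseInvocation instance.
context_data = InvocationContextData(
    invocation=node,
    session_id="session-123",   # placeholder
    workflow=None,
    source_node_id=node.id,
    queue_id="default",         # placeholder
    queue_item_id=0,            # placeholder
    batch_id="batch-123",       # placeholder
)
context = build_invocation_context(services=services, context_data=context_data)
output = node.invoke_internal(context=context, services=services)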
@@ -5,11 +5,12 @@ from __future__ import annotations
 from abc import ABC, abstractmethod
 from logging import Logger
 from pathlib import Path
-from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union
+from typing import Callable, List, Literal, Optional, Tuple, Union
 
 from pydantic import Field
 
 from invokeai.app.services.config.config_default import InvokeAIAppConfig
+from invokeai.app.services.shared.invocation_context import InvocationContextData
 from invokeai.backend.model_management import (
     AddModelResult,
     BaseModelType,
@@ -21,9 +22,6 @@ from invokeai.backend.model_management import (
 )
 from invokeai.backend.model_management.model_cache import CacheStats
 
-if TYPE_CHECKING:
-    from invokeai.app.invocations.baseinvocation import BaseInvocation, InvocationContext
-
 
 class ModelManagerServiceBase(ABC):
     """Responsible for managing models on disk and in memory"""
@@ -49,8 +47,7 @@ class ModelManagerServiceBase(ABC):
         base_model: BaseModelType,
         model_type: ModelType,
         submodel: Optional[SubModelType] = None,
-        node: Optional[BaseInvocation] = None,
-        context: Optional[InvocationContext] = None,
+        context_data: Optional[InvocationContextData] = None,
     ) -> ModelInfo:
         """Retrieve the indicated model with name and type.
         submodel can be used to get a part (such as the vae)
@@ -11,6 +11,8 @@ from pydantic import Field
 
 from invokeai.app.services.config.config_default import InvokeAIAppConfig
 from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
+from invokeai.app.services.invoker import Invoker
+from invokeai.app.services.shared.invocation_context import InvocationContextData
 from invokeai.backend.model_management import (
     AddModelResult,
     BaseModelType,
@@ -30,7 +32,7 @@ from invokeai.backend.util import choose_precision, choose_torch_device
 from .model_manager_base import ModelManagerServiceBase
 
 if TYPE_CHECKING:
-    from invokeai.app.invocations.baseinvocation import InvocationContext
+    pass
 
 
 # simple implementation
@@ -86,13 +88,16 @@ class ModelManagerService(ModelManagerServiceBase):
         )
         logger.info("Model manager service initialized")
 
+    def start(self, invoker: Invoker) -> None:
+        self._invoker: Optional[Invoker] = invoker
+
     def get_model(
         self,
         model_name: str,
         base_model: BaseModelType,
         model_type: ModelType,
         submodel: Optional[SubModelType] = None,
-        context: Optional[InvocationContext] = None,
+        context_data: Optional[InvocationContextData] = None,
     ) -> ModelInfo:
         """
         Retrieve the indicated model. submodel can be used to get a
@@ -100,9 +105,9 @@ class ModelManagerService(ModelManagerServiceBase):
         """
 
         # we can emit model loading events if we are executing with access to the invocation context
-        if context:
+        if context_data is not None:
             self._emit_load_event(
-                context=context,
+                context_data=context_data,
                 model_name=model_name,
                 base_model=base_model,
                 model_type=model_type,
@@ -116,9 +121,9 @@ class ModelManagerService(ModelManagerServiceBase):
             submodel,
         )
 
-        if context:
+        if context_data is not None:
             self._emit_load_event(
-                context=context,
+                context_data=context_data,
                 model_name=model_name,
                 base_model=base_model,
                 model_type=model_type,
@@ -263,22 +268,25 @@ class ModelManagerService(ModelManagerServiceBase):
 
     def _emit_load_event(
         self,
-        context: InvocationContext,
+        context_data: InvocationContextData,
         model_name: str,
         base_model: BaseModelType,
        model_type: ModelType,
         submodel: Optional[SubModelType] = None,
         model_info: Optional[ModelInfo] = None,
     ):
-        if context.services.queue.is_canceled(context.graph_execution_state_id):
+        if self._invoker is None:
+            return
+
+        if self._invoker.services.queue.is_canceled(context_data.session_id):
             raise CanceledException()
 
         if model_info:
-            context.services.events.emit_model_load_completed(
-                queue_id=context.queue_id,
-                queue_item_id=context.queue_item_id,
-                queue_batch_id=context.queue_batch_id,
-                graph_execution_state_id=context.graph_execution_state_id,
+            self._invoker.services.events.emit_model_load_completed(
+                queue_id=context_data.queue_id,
+                queue_item_id=context_data.queue_item_id,
+                queue_batch_id=context_data.batch_id,
+                graph_execution_state_id=context_data.session_id,
                 model_name=model_name,
                 base_model=base_model,
                 model_type=model_type,
@@ -286,11 +294,11 @@ class ModelManagerService(ModelManagerServiceBase):
                 model_info=model_info,
             )
         else:
-            context.services.events.emit_model_load_started(
-                queue_id=context.queue_id,
-                queue_item_id=context.queue_item_id,
-                queue_batch_id=context.queue_batch_id,
-                graph_execution_state_id=context.graph_execution_state_id,
+            self._invoker.services.events.emit_model_load_started(
+                queue_id=context_data.queue_id,
+                queue_item_id=context_data.queue_item_id,
+                queue_batch_id=context_data.batch_id,
+                graph_execution_state_id=context_data.session_id,
                 model_name=model_name,
                 base_model=base_model,
                 model_type=model_type,
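The model manager supports the wrapped context by accepting a plain `InvocationContextData` instead of a live `InvocationContext`, and by capturing the `Invoker` at startup so it can reach the queue and event bus on its own. Since `start()` may not have run, or the service may be used with no invoker at all, `_emit_load_event` now degrades to a no-op rather than assuming one exists. A condensed sketch of the guard, using only what the hunks above show (the elided parameters are replaced by `**kwargs` here):

# Captured once when the invoker starts the service:
def start(self, invoker: Invoker) -> None:
    self._invoker: Optional[Invoker] = invoker

# Every emission then guards on it; with no invoker there is no queue or
# event bus to talk to, so the event is silently skipped:
def _emit_load_event(self, context_data: InvocationContextData, **kwargs):
    if self._invoker is None:
        return
    if self._invoker.services.queue.is_canceled(context_data.session_id):
        raise CanceledException()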
@@ -13,7 +13,6 @@ from invokeai.app.invocations import * # noqa: F401 F403
 from invokeai.app.invocations.baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
-    InvocationContext,
     invocation,
     invocation_output,
 )
@@ -202,7 +201,7 @@ class GraphInvocation(BaseInvocation):
     # TODO: figure out how to create a default here
     graph: "Graph" = InputField(description="The graph to run", default=None)
 
-    def invoke(self, context: InvocationContext) -> GraphInvocationOutput:
+    def invoke(self, context) -> GraphInvocationOutput:
         """Invoke with provided services and return outputs."""
         return GraphInvocationOutput()
 
@@ -228,7 +227,7 @@ class IterateInvocation(BaseInvocation):
     )
     index: int = InputField(description="The index, will be provided on executed iterators", default=0, ui_hidden=True)
 
-    def invoke(self, context: InvocationContext) -> IterateInvocationOutput:
+    def invoke(self, context) -> IterateInvocationOutput:
         """Produces the outputs as values"""
         return IterateInvocationOutput(item=self.collection[self.index], index=self.index, total=len(self.collection))
 
@@ -255,7 +254,7 @@ class CollectInvocation(BaseInvocation):
         description="The collection, will be provided on execution", default=[], ui_hidden=True
     )
 
-    def invoke(self, context: InvocationContext) -> CollectInvocationOutput:
+    def invoke(self, context) -> CollectInvocationOutput:
         """Invoke with provided services and return outputs."""
         return CollectInvocationOutput(collection=copy.copy(self.collection))
 
@@ -6,8 +6,7 @@ from PIL.Image import Image
 from pydantic import ConfigDict
 from torch import Tensor
 
-from invokeai.app.invocations.compel import ConditioningFieldData
-from invokeai.app.invocations.fields import MetadataField, WithMetadata
+from invokeai.app.invocations.fields import ConditioningFieldData, MetadataField, WithMetadata
 from invokeai.app.services.config.config_default import InvokeAIAppConfig
 from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
 from invokeai.app.services.images.images_common import ImageDTO
@@ -245,13 +244,15 @@ class ConditioningInterface:
             )
             return name
 
-        def get(conditioning_name: str) -> Tensor:
+        def get(conditioning_name: str) -> ConditioningFieldData:
             """
             Gets conditioning data by name.
 
             :param conditioning_name: The name of the conditioning data to get.
             """
-            return services.latents.get(conditioning_name)
+            # TODO(sm): We are (ab)using the latents storage service as a general pickle storage
+            # service, but it is typed as returning tensors, so we need to ignore the type here.
+            return services.latents.get(conditioning_name)  # type: ignore [return-value]
 
         self.save = save
         self.get = get
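One subtlety in the context wrapper itself: conditioning data is persisted through the latents storage service, which is typed as storing tensors, so the newly honest `ConditioningFieldData` return type needs a `type: ignore` until conditioning gets its own store. Assuming the interface is exposed on the context as `context.conditioning` (the attribute name is not visible in this hunk) and that `save` returns the storage name, a round-trip would look like:

# Sketch only: `context.conditioning` and the save() signature are assumptions.
name = context.conditioning.save(conditioning_data)  # returns the storage name
data = context.conditioning.get(name)                # now typed as ConditioningFieldData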
@@ -1,25 +1,18 @@
-from typing import Protocol
+from typing import TYPE_CHECKING
 
 import torch
 from PIL import Image
 
-from invokeai.app.services.events.events_base import EventServiceBase
 from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException, ProgressImage
-from invokeai.app.services.invocation_queue.invocation_queue_base import InvocationQueueABC
-from invokeai.app.services.shared.invocation_context import InvocationContextData
 
 from ...backend.model_management.models import BaseModelType
 from ...backend.stable_diffusion import PipelineIntermediateState
 from ...backend.util.util import image_to_dataURL
 
-
-class StepCallback(Protocol):
-    def __call__(
-        self,
-        intermediate_state: PipelineIntermediateState,
-        base_model: BaseModelType,
-    ) -> None:
-        ...
+if TYPE_CHECKING:
+    from invokeai.app.services.events.events_base import EventServiceBase
+    from invokeai.app.services.invocation_queue.invocation_queue_base import InvocationQueueABC
+    from invokeai.app.services.shared.invocation_context import InvocationContextData
 
 
 def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None):
@@ -38,11 +31,11 @@ def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=
 
 
 def stable_diffusion_step_callback(
-    context_data: InvocationContextData,
+    context_data: "InvocationContextData",
     intermediate_state: PipelineIntermediateState,
     base_model: BaseModelType,
-    invocation_queue: InvocationQueueABC,
-    events: EventServiceBase,
+    invocation_queue: "InvocationQueueABC",
+    events: "EventServiceBase",
 ) -> None:
     if invocation_queue.is_canceled(context_data.session_id):
         raise CanceledException
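This file shows the import shuffling that keeps the new wiring free of circular dependencies: the service imports move under `TYPE_CHECKING` and the annotations become strings, so the names are available to the type checker without ever being imported at runtime. The pattern in isolation (generic Python; the module and class below are hypothetical, not InvokeAI code):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers; at runtime this import would
    # complete a dependency cycle, so it never executes.
    from some_package.heavy_module import HeavyService  # hypothetical module


def do_work(service: "HeavyService") -> None:
    # The quoted annotation defers name resolution, so no runtime import is needed.
    print(f"working with {type(service).__name__}")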