mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): update all invocations to use new invocation context
Update all invocations to use the new context. The changes are all fairly simple, but there are a lot of them. Supporting minor changes: - Patch bump for all nodes that use the context - Update invocation processor to provide new context - Minor change to `EventServiceBase` to accept a node's ID instead of the dict version of a node - Minor change to `ModelManagerService` to support the new wrapped context - Fanagling of imports to avoid circular dependencies
This commit is contained in:
@ -16,10 +16,16 @@ from pydantic import BaseModel, ConfigDict, Field, create_model
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from invokeai.app.invocations.fields import FieldKind, Input
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
FieldKind,
|
||||
Input,
|
||||
InputFieldJSONSchemaExtra,
|
||||
MetadataField,
|
||||
logger,
|
||||
)
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.metaenum import MetaEnum
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
@ -219,7 +225,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
"""Invoke with provided context and return outputs."""
|
||||
pass
|
||||
|
||||
def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput:
|
||||
def invoke_internal(self, context: InvocationContext, services: "InvocationServices") -> BaseInvocationOutput:
|
||||
"""
|
||||
Internal invoke method, calls `invoke()` after some prep.
|
||||
Handles optional fields that are required to call `invoke()` and invocation cache.
|
||||
@ -244,23 +250,23 @@ class BaseInvocation(ABC, BaseModel):
|
||||
raise MissingInputException(self.model_fields["type"].default, field_name)
|
||||
|
||||
# skip node cache codepath if it's disabled
|
||||
if context.services.configuration.node_cache_size == 0:
|
||||
if services.configuration.node_cache_size == 0:
|
||||
return self.invoke(context)
|
||||
|
||||
output: BaseInvocationOutput
|
||||
if self.use_cache:
|
||||
key = context.services.invocation_cache.create_key(self)
|
||||
cached_value = context.services.invocation_cache.get(key)
|
||||
key = services.invocation_cache.create_key(self)
|
||||
cached_value = services.invocation_cache.get(key)
|
||||
if cached_value is None:
|
||||
context.services.logger.debug(f'Invocation cache miss for type "{self.get_type()}": {self.id}')
|
||||
services.logger.debug(f'Invocation cache miss for type "{self.get_type()}": {self.id}')
|
||||
output = self.invoke(context)
|
||||
context.services.invocation_cache.save(key, output)
|
||||
services.invocation_cache.save(key, output)
|
||||
return output
|
||||
else:
|
||||
context.services.logger.debug(f'Invocation cache hit for type "{self.get_type()}": {self.id}')
|
||||
services.logger.debug(f'Invocation cache hit for type "{self.get_type()}": {self.id}')
|
||||
return cached_value
|
||||
else:
|
||||
context.services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}')
|
||||
services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}')
|
||||
return self.invoke(context)
|
||||
|
||||
id: str = Field(
|
||||
@ -513,3 +519,29 @@ def invocation_output(
|
||||
return cls
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class WithMetadata(BaseModel):
|
||||
"""
|
||||
Inherit from this class if your node needs a metadata input field.
|
||||
"""
|
||||
|
||||
metadata: Optional[MetadataField] = Field(
|
||||
default=None,
|
||||
description=FieldDescriptions.metadata,
|
||||
json_schema_extra=InputFieldJSONSchemaExtra(
|
||||
field_kind=FieldKind.Internal,
|
||||
input=Input.Connection,
|
||||
orig_required=False,
|
||||
).model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
|
||||
class WithWorkflow:
|
||||
workflow = None
|
||||
|
||||
def __init_subclass__(cls) -> None:
|
||||
logger.warn(
|
||||
f"{cls.__module__.split('.')[0]}.{cls.__name__}: WithWorkflow is deprecated. Use `context.workflow` to access the workflow."
|
||||
)
|
||||
super().__init_subclass__()
|
||||
|
Reference in New Issue
Block a user