From 3de43907113c316f48fd96bcd9afd773675a4a00 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 15 Jan 2024 10:41:25 +1100 Subject: [PATCH] feat(nodes): move `ConditioningFieldData` to `conditioning_data.py` --- invokeai/app/invocations/compel.py | 2 +- invokeai/app/invocations/fields.py | 9 +-------- invokeai/app/services/shared/invocation_context.py | 3 ++- .../stable_diffusion/diffusion/conditioning_data.py | 5 +++++ 4 files changed, 9 insertions(+), 10 deletions(-) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index b4496031bc..94caf4128d 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -5,7 +5,6 @@ from compel import Compel, ReturnedEmbeddingsType from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment from invokeai.app.invocations.fields import ( - ConditioningFieldData, FieldDescriptions, Input, InputField, @@ -15,6 +14,7 @@ from invokeai.app.invocations.fields import ( from invokeai.app.invocations.primitives import ConditioningOutput from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( BasicConditioningInfo, + ConditioningFieldData, ExtraConditioningInfo, SDXLConditioningInfo, ) diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index 566babbb6b..8879f76077 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -1,13 +1,11 @@ -from dataclasses import dataclass from enum import Enum -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Callable, Optional, Tuple from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter from pydantic.fields import _Unset from pydantic_core import PydanticUndefined from invokeai.app.util.metaenum import MetaEnum -from invokeai.backend.stable_diffusion.diffusion.conditioning_data import BasicConditioningInfo from invokeai.backend.util.logging import InvokeAILogger logger = InvokeAILogger.get_logger() @@ -544,11 +542,6 @@ class ColorField(BaseModel): return (self.r, self.g, self.b, self.a) -@dataclass -class ConditioningFieldData: - conditionings: List[BasicConditioningInfo] - - class ConditioningField(BaseModel): """A conditioning tensor primitive value""" diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 023274d49f..3cf3952de0 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -6,7 +6,7 @@ from PIL.Image import Image from pydantic import ConfigDict from torch import Tensor -from invokeai.app.invocations.fields import ConditioningFieldData, MetadataField, WithMetadata +from invokeai.app.invocations.fields import MetadataField, WithMetadata from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin from invokeai.app.services.images.images_common import ImageDTO @@ -17,6 +17,7 @@ from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.backend.model_management.model_manager import ModelInfo from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData if TYPE_CHECKING: from invokeai.app.invocations.baseinvocation import BaseInvocation diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py index 3e38f9f78d..0676555f7a 100644 --- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py @@ -32,6 +32,11 @@ class BasicConditioningInfo: return self +@dataclass +class ConditioningFieldData: + conditionings: List[BasicConditioningInfo] + + @dataclass class SDXLConditioningInfo(BasicConditioningInfo): pooled_embeds: torch.Tensor