feat(nodes): move ConditioningFieldData to conditioning_data.py

This commit is contained in:
psychedelicious 2024-01-15 10:41:25 +11:00
parent 6452c706e1
commit 05fb485d33
4 changed files with 9 additions and 10 deletions

View File

@ -5,7 +5,6 @@ from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
from invokeai.app.invocations.fields import (
ConditioningFieldData,
FieldDescriptions,
Input,
InputField,
@ -15,6 +14,7 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
ConditioningFieldData,
ExtraConditioningInfo,
SDXLConditioningInfo,
)

View File

@ -1,13 +1,11 @@
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, List, Optional, Tuple
from typing import Any, Callable, Optional, Tuple
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter
from pydantic.fields import _Unset
from pydantic_core import PydanticUndefined
from invokeai.app.util.metaenum import MetaEnum
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import BasicConditioningInfo
from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.get_logger()
@ -544,11 +542,6 @@ class ColorField(BaseModel):
return (self.r, self.g, self.b, self.a)
@dataclass
class ConditioningFieldData:
conditionings: List[BasicConditioningInfo]
class ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""

View File

@ -6,7 +6,7 @@ from PIL.Image import Image
from pydantic import ConfigDict
from torch import Tensor
from invokeai.app.invocations.fields import ConditioningFieldData, MetadataField, WithMetadata
from invokeai.app.invocations.fields import MetadataField, WithMetadata
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
from invokeai.app.services.images.images_common import ImageDTO
@ -17,6 +17,7 @@ from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.model_manager import ModelInfo
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
if TYPE_CHECKING:
from invokeai.app.invocations.baseinvocation import BaseInvocation

View File

@ -32,6 +32,11 @@ class BasicConditioningInfo:
return self
@dataclass
class ConditioningFieldData:
conditionings: List[BasicConditioningInfo]
@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
pooled_embeds: torch.Tensor