Make latent generation nodes use conditions instead of prompt

StAlKeR7779 committed 2023-04-25 04:21:03 +03:00
parent d99a08a441
commit 8f460b92f1
2 changed files with 20 additions and 14 deletions
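In short: TextToLatentsInvocation no longer compiles a prompt string itself via get_uc_and_c_and_ec; it now takes two ConditioningField inputs and resolves each by name at invoke time, and the default text-to-image graph gains a CompelInvocation node that produces that conditioning. The ConditioningField type itself is defined in .compel and is not part of this diff; a minimal stand-in consistent with the self.positive.conditioning_name access below might look like this (an assumption, not the actual definition):

    from typing import Optional

    from pydantic import BaseModel, Field

    class ConditioningField(BaseModel):
        # Name under which the producer node stored the conditioning tensors.
        conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data")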

invokeai/app/invocations/latent.py

@@ -13,13 +13,13 @@ from ...backend.model_management.model_manager import ModelManager
 from ...backend.util.devices import choose_torch_device, torch_dtype
 from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
 from ...backend.image_util.seamless import configure_model_padding
-from ...backend.prompting.conditioning import get_uc_and_c_and_ec
 from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
 import numpy as np
 from ..services.image_storage import ImageType
 from .baseinvocation import BaseInvocation, InvocationContext
 from .image import ImageField, ImageOutput, build_image_output
+from .compel import ConditioningField
 from ...backend.stable_diffusion import PipelineIntermediateState
 from diffusers.schedulers import SchedulerMixin as Scheduler
 import diffusers
@@ -143,9 +143,9 @@ class TextToLatentsInvocation(BaseInvocation):
     type: Literal["t2l"] = "t2l"
 
     # Inputs
-    # TODO: consider making prompt optional to enable providing prompt through a link
     # fmt: off
-    prompt: Optional[str] = Field(description="The prompt to generate an image from")
+    positive: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
+    negative: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
     seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
     noise: Optional[LatentsField] = Field(description="The noise to use")
     steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
@@ -206,8 +206,10 @@ class TextToLatentsInvocation(BaseInvocation):
         return model
 
-    def get_conditioning_data(self, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
-        uc, c, extra_conditioning_info = get_uc_and_c_and_ec(self.prompt, model=model)
+    def get_conditioning_data(self, context: InvocationContext, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
+        c, extra_conditioning_info = context.services.latents.get(self.positive.conditioning_name)
+        uc, _ = context.services.latents.get(self.negative.conditioning_name)
         conditioning_data = ConditioningData(
             uc,
             c,
@@ -234,7 +236,7 @@ class TextToLatentsInvocation(BaseInvocation):
             self.dispatch_progress(context, source_node_id, state)
 
         model = self.get_model(context.services.model_manager)
-        conditioning_data = self.get_conditioning_data(model)
+        conditioning_data = self.get_conditioning_data(context, model)
 
         # TODO: Verify the noise is the right size
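The new get_conditioning_data implies a storage contract: some upstream node must have saved a (conditioning, extra_info) tuple under each conditioning_name so that the two get() calls can unpack it. A minimal round-trip sketch of that contract, using a dict-backed stand-in for context.services.latents rather than InvokeAI's real service, and an invented naming scheme:

    class InMemoryLatentsStore:
        """Dict-backed stand-in for context.services.latents (illustrative only)."""

        def __init__(self):
            self._items = {}

        def set(self, name, item):
            self._items[name] = item

        def get(self, name):
            return self._items[name]

    store = InMemoryLatentsStore()

    # A Compel-style producer saves (embeddings, extra cross-attention info)
    # under unique names; real tensors are replaced by strings here.
    store.set("exec1_4_positive", ("positive_embeddings", "extra_conditioning_info"))
    store.set("exec1_4_negative", ("negative_embeddings", None))

    # TextToLatentsInvocation then unpacks them exactly as in the hunk above.
    c, extra_conditioning_info = store.get("exec1_4_positive")
    uc, _ = store.get("exec1_4_negative")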

invokeai/app/services/default_graphs.py

@@ -1,4 +1,5 @@
 from ..invocations.latent import LatentsToImageInvocation, NoiseInvocation, TextToLatentsInvocation
+from ..invocations.compel import CompelInvocation
 from ..invocations.params import ParamIntInvocation
 from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph
 from .item_storage import ItemStorageABC
@@ -17,25 +18,28 @@ def create_text_to_image() -> LibraryGraph:
                 'width': ParamIntInvocation(id='width', a=512),
                 'height': ParamIntInvocation(id='height', a=512),
                 '3': NoiseInvocation(id='3'),
-                '4': TextToLatentsInvocation(id='4'),
-                '5': LatentsToImageInvocation(id='5')
+                '4': CompelInvocation(id='4'),
+                '5': TextToLatentsInvocation(id='5'),
+                '6': LatentsToImageInvocation(id='6'),
             },
             edges=[
                 Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')),
                 Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')),
-                Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='4', field='width')),
-                Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='4', field='height')),
-                Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='4', field='noise')),
-                Edge(source=EdgeConnection(node_id='4', field='latents'), destination=EdgeConnection(node_id='5', field='latents')),
+                Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='5', field='width')),
+                Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='5', field='height')),
+                Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='5', field='noise')),
+                Edge(source=EdgeConnection(node_id='5', field='latents'), destination=EdgeConnection(node_id='6', field='latents')),
+                Edge(source=EdgeConnection(node_id='4', field='positive'), destination=EdgeConnection(node_id='5', field='positive')),
+                Edge(source=EdgeConnection(node_id='4', field='negative'), destination=EdgeConnection(node_id='5', field='negative')),
             ]
         ),
         exposed_inputs=[
-            ExposedNodeInput(node_path='4', field='prompt', alias='prompt'),
+            ExposedNodeInput(node_path='4', field='positive_prompt', alias='prompt'), # TODO: cli uses concatenated prompt
             ExposedNodeInput(node_path='width', field='a', alias='width'),
             ExposedNodeInput(node_path='height', field='a', alias='height')
         ],
         exposed_outputs=[
-            ExposedNodeOutput(node_path='5', field='image', alias='image')
+            ExposedNodeOutput(node_path='6', field='image', alias='image')
         ])