mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Make latent generation nodes use conditions instead of prompt
This commit is contained in:
parent
d99a08a441
commit
8f460b92f1
@ -13,13 +13,13 @@ from ...backend.model_management.model_manager import ModelManager
|
|||||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||||
from ...backend.image_util.seamless import configure_model_padding
|
from ...backend.image_util.seamless import configure_model_padding
|
||||||
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
|
|
||||||
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline
|
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from ..services.image_storage import ImageType
|
from ..services.image_storage import ImageType
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext
|
from .baseinvocation import BaseInvocation, InvocationContext
|
||||||
from .image import ImageField, ImageOutput, build_image_output
|
from .image import ImageField, ImageOutput, build_image_output
|
||||||
|
from .compel import ConditioningField
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||||
import diffusers
|
import diffusers
|
||||||
@ -143,9 +143,9 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
type: Literal["t2l"] = "t2l"
|
type: Literal["t2l"] = "t2l"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
# TODO: consider making prompt optional to enable providing prompt through a link
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
prompt: Optional[str] = Field(description="The prompt to generate an image from")
|
positive: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
|
||||||
|
negative: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
|
||||||
seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
|
seed: int = Field(default=-1,ge=-1, le=np.iinfo(np.uint32).max, description="The seed to use (-1 for a random seed)", )
|
||||||
noise: Optional[LatentsField] = Field(description="The noise to use")
|
noise: Optional[LatentsField] = Field(description="The noise to use")
|
||||||
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
|
||||||
@ -206,8 +206,10 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def get_conditioning_data(self, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
|
def get_conditioning_data(self, context: InvocationContext, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
|
||||||
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(self.prompt, model=model)
|
c, extra_conditioning_info = context.services.latents.get(self.positive.conditioning_name)
|
||||||
|
uc, _ = context.services.latents.get(self.negative.conditioning_name)
|
||||||
|
|
||||||
conditioning_data = ConditioningData(
|
conditioning_data = ConditioningData(
|
||||||
uc,
|
uc,
|
||||||
c,
|
c,
|
||||||
@ -234,7 +236,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
self.dispatch_progress(context, source_node_id, state)
|
self.dispatch_progress(context, source_node_id, state)
|
||||||
|
|
||||||
model = self.get_model(context.services.model_manager)
|
model = self.get_model(context.services.model_manager)
|
||||||
conditioning_data = self.get_conditioning_data(model)
|
conditioning_data = self.get_conditioning_data(context, model)
|
||||||
|
|
||||||
# TODO: Verify the noise is the right size
|
# TODO: Verify the noise is the right size
|
||||||
|
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from ..invocations.latent import LatentsToImageInvocation, NoiseInvocation, TextToLatentsInvocation
|
from ..invocations.latent import LatentsToImageInvocation, NoiseInvocation, TextToLatentsInvocation
|
||||||
|
from ..invocations.compel import CompelInvocation
|
||||||
from ..invocations.params import ParamIntInvocation
|
from ..invocations.params import ParamIntInvocation
|
||||||
from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph
|
from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph
|
||||||
from .item_storage import ItemStorageABC
|
from .item_storage import ItemStorageABC
|
||||||
@ -17,25 +18,28 @@ def create_text_to_image() -> LibraryGraph:
|
|||||||
'width': ParamIntInvocation(id='width', a=512),
|
'width': ParamIntInvocation(id='width', a=512),
|
||||||
'height': ParamIntInvocation(id='height', a=512),
|
'height': ParamIntInvocation(id='height', a=512),
|
||||||
'3': NoiseInvocation(id='3'),
|
'3': NoiseInvocation(id='3'),
|
||||||
'4': TextToLatentsInvocation(id='4'),
|
'4': CompelInvocation(id='4'),
|
||||||
'5': LatentsToImageInvocation(id='5')
|
'5': TextToLatentsInvocation(id='5'),
|
||||||
|
'6': LatentsToImageInvocation(id='6'),
|
||||||
},
|
},
|
||||||
edges=[
|
edges=[
|
||||||
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')),
|
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')),
|
||||||
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')),
|
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')),
|
||||||
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='4', field='width')),
|
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='5', field='width')),
|
||||||
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='4', field='height')),
|
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='5', field='height')),
|
||||||
Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='4', field='noise')),
|
Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='5', field='noise')),
|
||||||
Edge(source=EdgeConnection(node_id='4', field='latents'), destination=EdgeConnection(node_id='5', field='latents')),
|
Edge(source=EdgeConnection(node_id='5', field='latents'), destination=EdgeConnection(node_id='6', field='latents')),
|
||||||
|
Edge(source=EdgeConnection(node_id='4', field='positive'), destination=EdgeConnection(node_id='5', field='positive')),
|
||||||
|
Edge(source=EdgeConnection(node_id='4', field='negative'), destination=EdgeConnection(node_id='5', field='negative')),
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
exposed_inputs=[
|
exposed_inputs=[
|
||||||
ExposedNodeInput(node_path='4', field='prompt', alias='prompt'),
|
ExposedNodeInput(node_path='4', field='positive_prompt', alias='prompt'), # TODO: cli uses concatenated prompt
|
||||||
ExposedNodeInput(node_path='width', field='a', alias='width'),
|
ExposedNodeInput(node_path='width', field='a', alias='width'),
|
||||||
ExposedNodeInput(node_path='height', field='a', alias='height')
|
ExposedNodeInput(node_path='height', field='a', alias='height')
|
||||||
],
|
],
|
||||||
exposed_outputs=[
|
exposed_outputs=[
|
||||||
ExposedNodeOutput(node_path='5', field='image', alias='image')
|
ExposedNodeOutput(node_path='6', field='image', alias='image')
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user