mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Add seed to latents field

commit 5f29526a8e (parent 492bfe002a)
@@ -49,6 +49,7 @@ class LatentsField(BaseModel):
     """A latents field used for passing latents between invocations"""
 
     latents_name: Optional[str] = Field(default=None, description="The name of the latents")
+    seed: Optional[int] = Field(description="Seed used to generate this latents")
 
     class Config:
         schema_extra = {"required": ["latents_name"]}
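The new field is optional, so existing graphs that never set a seed still validate. A minimal standalone sketch of the behavior, assuming pydantic v1-style models (consistent with the `schema_extra` Config above; `default=None` is written out explicitly here, and this is an illustration, not the full class):

    from typing import Optional
    from pydantic import BaseModel, Field

    class LatentsField(BaseModel):
        """A latents field used for passing latents between invocations"""

        latents_name: Optional[str] = Field(default=None, description="The name of the latents")
        seed: Optional[int] = Field(default=None, description="Seed used to generate this latents")

    # The seed rides along with the latents reference; None means "unknown".
    print(LatentsField(latents_name="n1", seed=42).seed)  # 42
    print(LatentsField(latents_name="n1").seed)           # None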
@@ -67,9 +68,9 @@ class LatentsOutput(BaseInvocationOutput):
     # fmt: on
 
 
-def build_latents_output(latents_name: str, latents: torch.Tensor):
+def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int]):
     return LatentsOutput(
-        latents=LatentsField(latents_name=latents_name),
+        latents=LatentsField(latents_name=latents_name, seed=seed),
         width=latents.size()[3] * 8,
         height=latents.size()[2] * 8,
     )
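The `* 8` reflects the VAE's 8x spatial downsampling: SD latents are laid out as [batch, channels, height/8, width/8]. A quick standalone check of the shape arithmetic:

    import torch

    latents = torch.randn(1, 4, 64, 96)  # latent for a 768x512 (w x h) image
    print(latents.size()[3] * 8)         # width:  768
    print(latents.size()[2] * 8)         # height: 512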
@@ -175,6 +176,7 @@ class TextToLatentsInvocation(BaseInvocation):
         context: InvocationContext,
         scheduler,
         unet,
+        seed,
     ) -> ConditioningData:
         positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
         c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
@@ -201,7 +203,7 @@ class TextToLatentsInvocation(BaseInvocation):
             # for ddim scheduler
             eta=0.0,  # ddim_eta
             # for ancestral and sde schedulers
-            generator=torch.Generator(device=unet.device).manual_seed(0),
+            generator=torch.Generator(device=unet.device).manual_seed(seed),
         )
         return conditioning_data
 
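This hunk is the substantive fix: ancestral and SDE schedulers draw fresh noise during sampling from this generator, which was previously hard-seeded to 0 regardless of the seed a user requested. A small standalone sketch (plain torch, not the InvokeAI API) of why threading the real seed through matters:

    import torch

    def scheduler_noise(seed: int) -> torch.Tensor:
        # Stands in for the per-step noise an ancestral/SDE scheduler draws.
        gen = torch.Generator(device="cpu").manual_seed(seed)
        return torch.randn(4, generator=gen)

    # Same seed, same noise: runs are reproducible again.
    assert torch.equal(scheduler_noise(42), scheduler_noise(42))
    # With the old hard-coded manual_seed(0), every request got seed-0 noise.
    assert not torch.equal(scheduler_noise(42), scheduler_noise(0))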
@@ -336,6 +338,7 @@ class TextToLatentsInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         with SilenceWarnings():
             noise = context.services.latents.get(self.noise.latents_name)
+            seed = self.noise.seed or 0
 
             # Get the source node id (we are invoking the prepared node)
             graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
@@ -370,7 +373,7 @@ class TextToLatentsInvocation(BaseInvocation):
             )
 
             pipeline = self.create_pipeline(unet, scheduler)
-            conditioning_data = self.get_conditioning_data(context, scheduler, unet)
+            conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed)
 
             control_data = self.prep_control_data(
                 model=pipeline,
@@ -407,7 +410,7 @@ class TextToLatentsInvocation(BaseInvocation):
 
             name = f"{context.graph_execution_state_id}__{self.id}"
             context.services.latents.save(name, result_latents)
-            return build_latents_output(latents_name=name, latents=result_latents)
+            return build_latents_output(latents_name=name, latents=result_latents, seed=seed)
 
 
 class LatentsToLatentsInvocation(TextToLatentsInvocation):
@@ -440,10 +443,14 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         with SilenceWarnings():  # this quenches NSFW nag from diffusers
+            latent = context.services.latents.get(self.latents.latents_name)
+            seed = self.latents.seed or 0
+
             noise = None
             if self.noise is not None:
                 noise = context.services.latents.get(self.noise.latents_name)
-            latent = context.services.latents.get(self.latents.latents_name)
+                if self.noise.seed is not None:
+                    seed = self.noise.seed
 
             # Get the source node id (we are invoking the prepared node)
             graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
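Note the `latent` fetch also moves above the noise block so both inputs are read up front. The seed precedence is: start from the incoming latents' seed (falling back to 0), then let an explicit seed on the noise override it. Isolated as a pure function for clarity (a hypothetical helper, not code from this commit):

    from typing import Optional

    def resolve_seed(latents_seed: Optional[int], noise_seed: Optional[int]) -> int:
        seed = latents_seed or 0   # None (and 0) collapse to 0
        if noise_seed is not None:
            seed = noise_seed      # the noise seed wins when present
        return seed

    assert resolve_seed(None, None) == 0
    assert resolve_seed(7, None) == 7
    assert resolve_seed(7, 42) == 42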
@@ -480,7 +487,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
             )
 
             pipeline = self.create_pipeline(unet, scheduler)
-            conditioning_data = self.get_conditioning_data(context, scheduler, unet)
+            conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed)
 
             control_data = self.prep_control_data(
                 model=pipeline,
@@ -521,7 +528,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
 
             name = f"{context.graph_execution_state_id}__{self.id}"
             context.services.latents.save(name, result_latents)
-            return build_latents_output(latents_name=name, latents=result_latents)
+            return build_latents_output(latents_name=name, latents=result_latents, seed=seed)
 
 
 # Latent to image
@@ -663,7 +670,7 @@ class ResizeLatentsInvocation(BaseInvocation):
         name = f"{context.graph_execution_state_id}__{self.id}"
         # context.services.latents.set(name, resized_latents)
         context.services.latents.save(name, resized_latents)
-        return build_latents_output(latents_name=name, latents=resized_latents)
+        return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
 
 
 class ScaleLatentsInvocation(BaseInvocation):
@@ -705,7 +712,7 @@ class ScaleLatentsInvocation(BaseInvocation):
         name = f"{context.graph_execution_state_id}__{self.id}"
         # context.services.latents.set(name, resized_latents)
         context.services.latents.save(name, resized_latents)
-        return build_latents_output(latents_name=name, latents=resized_latents)
+        return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
 
 
 class ImageToLatentsInvocation(BaseInvocation):
@@ -786,4 +793,4 @@ class ImageToLatentsInvocation(BaseInvocation):
         name = f"{context.graph_execution_state_id}__{self.id}"
         latents = latents.to("cpu")
         context.services.latents.save(name, latents)
-        return build_latents_output(latents_name=name, latents=latents)
+        return build_latents_output(latents_name=name, latents=latents, seed=None)
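The pass-through invocations differ deliberately: resize and scale forward the upstream seed via `self.latents.seed`, while image-to-latents passes `seed=None` because latents encoded from an image have no generating seed. The hunk numbering restarts below (back to line 71), so the remaining hunks belong to a second file, the one defining NoiseOutput and NoiseInvocation; there, `seed` is a required `int` rather than `Optional[int]`, since generated noise always has a seed.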
@@ -71,9 +71,9 @@ class NoiseOutput(BaseInvocationOutput):
     # fmt: on
 
 
-def build_noise_output(latents_name: str, latents: torch.Tensor):
+def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
     return NoiseOutput(
-        noise=LatentsField(latents_name=latents_name),
+        noise=LatentsField(latents_name=latents_name, seed=seed),
         width=latents.size()[3] * 8,
         height=latents.size()[2] * 8,
     )
@@ -132,4 +132,4 @@ class NoiseInvocation(BaseInvocation):
         )
         name = f"{context.graph_execution_state_id}__{self.id}"
         context.services.latents.save(name, noise)
-        return build_noise_output(latents_name=name, latents=noise)
+        return build_noise_output(latents_name=name, latents=noise, seed=self.seed)
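Taken together, the change threads the seed end to end: NoiseInvocation stamps its seed onto the LatentsField it emits, the denoising invocations read it back (and feed it to the ancestral/SDE scheduler generator), and transforms pass it through unchanged. A hedged sketch of the flow, assuming `build_noise_output` and its types from the hunks above are in scope (simplified, not the real invocation plumbing):

    import torch

    noise = torch.randn(1, 4, 64, 64)
    noise_out = build_noise_output(latents_name="noise_1", latents=noise, seed=1234)
    assert noise_out.noise.seed == 1234

    # A downstream TextToLatentsInvocation then recovers it:
    seed = noise_out.noise.seed or 0            # 1234
    gen = torch.Generator().manual_seed(seed)   # reproducible scheduler noise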