Add seed to latents field

This commit is contained in:
Sergey Borisov 2023-08-08 04:00:33 +03:00
parent 492bfe002a
commit 5f29526a8e
2 changed files with 21 additions and 14 deletions

View File

@ -49,6 +49,7 @@ class LatentsField(BaseModel):
"""A latents field used for passing latents between invocations""" """A latents field used for passing latents between invocations"""
latents_name: Optional[str] = Field(default=None, description="The name of the latents") latents_name: Optional[str] = Field(default=None, description="The name of the latents")
seed: Optional[int] = Field(description="Seed used to generate this latents")
class Config: class Config:
schema_extra = {"required": ["latents_name"]} schema_extra = {"required": ["latents_name"]}
@ -67,9 +68,9 @@ class LatentsOutput(BaseInvocationOutput):
# fmt: on # fmt: on
def build_latents_output(latents_name: str, latents: torch.Tensor): def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int]):
return LatentsOutput( return LatentsOutput(
latents=LatentsField(latents_name=latents_name), latents=LatentsField(latents_name=latents_name, seed=seed),
width=latents.size()[3] * 8, width=latents.size()[3] * 8,
height=latents.size()[2] * 8, height=latents.size()[2] * 8,
) )
@ -175,6 +176,7 @@ class TextToLatentsInvocation(BaseInvocation):
context: InvocationContext, context: InvocationContext,
scheduler, scheduler,
unet, unet,
seed,
) -> ConditioningData: ) -> ConditioningData:
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name) positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
@ -201,7 +203,7 @@ class TextToLatentsInvocation(BaseInvocation):
# for ddim scheduler # for ddim scheduler
eta=0.0, # ddim_eta eta=0.0, # ddim_eta
# for ancestral and sde schedulers # for ancestral and sde schedulers
generator=torch.Generator(device=unet.device).manual_seed(0), generator=torch.Generator(device=unet.device).manual_seed(seed),
) )
return conditioning_data return conditioning_data
@ -336,6 +338,7 @@ class TextToLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
with SilenceWarnings(): with SilenceWarnings():
noise = context.services.latents.get(self.noise.latents_name) noise = context.services.latents.get(self.noise.latents_name)
seed = self.noise.seed or 0
# Get the source node id (we are invoking the prepared node) # Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
@ -370,7 +373,7 @@ class TextToLatentsInvocation(BaseInvocation):
) )
pipeline = self.create_pipeline(unet, scheduler) pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler, unet) conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed)
control_data = self.prep_control_data( control_data = self.prep_control_data(
model=pipeline, model=pipeline,
@ -407,7 +410,7 @@ class TextToLatentsInvocation(BaseInvocation):
name = f"{context.graph_execution_state_id}__{self.id}" name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.save(name, result_latents) context.services.latents.save(name, result_latents)
return build_latents_output(latents_name=name, latents=result_latents) return build_latents_output(latents_name=name, latents=result_latents, seed=seed)
class LatentsToLatentsInvocation(TextToLatentsInvocation): class LatentsToLatentsInvocation(TextToLatentsInvocation):
@ -440,10 +443,14 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
with SilenceWarnings(): # this quenches NSFW nag from diffusers with SilenceWarnings(): # this quenches NSFW nag from diffusers
latent = context.services.latents.get(self.latents.latents_name)
seed = self.latents.seed or 0
noise = None noise = None
if self.noise is not None: if self.noise is not None:
noise = context.services.latents.get(self.noise.latents_name) noise = context.services.latents.get(self.noise.latents_name)
latent = context.services.latents.get(self.latents.latents_name) if self.noise.seed is not None:
seed = self.noise.seed
# Get the source node id (we are invoking the prepared node) # Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
@ -480,7 +487,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
) )
pipeline = self.create_pipeline(unet, scheduler) pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler, unet) conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed)
control_data = self.prep_control_data( control_data = self.prep_control_data(
model=pipeline, model=pipeline,
@ -521,7 +528,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
name = f"{context.graph_execution_state_id}__{self.id}" name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.save(name, result_latents) context.services.latents.save(name, result_latents)
return build_latents_output(latents_name=name, latents=result_latents) return build_latents_output(latents_name=name, latents=result_latents, seed=seed)
# Latent to image # Latent to image
@ -663,7 +670,7 @@ class ResizeLatentsInvocation(BaseInvocation):
name = f"{context.graph_execution_state_id}__{self.id}" name = f"{context.graph_execution_state_id}__{self.id}"
# context.services.latents.set(name, resized_latents) # context.services.latents.set(name, resized_latents)
context.services.latents.save(name, resized_latents) context.services.latents.save(name, resized_latents)
return build_latents_output(latents_name=name, latents=resized_latents) return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
class ScaleLatentsInvocation(BaseInvocation): class ScaleLatentsInvocation(BaseInvocation):
@ -705,7 +712,7 @@ class ScaleLatentsInvocation(BaseInvocation):
name = f"{context.graph_execution_state_id}__{self.id}" name = f"{context.graph_execution_state_id}__{self.id}"
# context.services.latents.set(name, resized_latents) # context.services.latents.set(name, resized_latents)
context.services.latents.save(name, resized_latents) context.services.latents.save(name, resized_latents)
return build_latents_output(latents_name=name, latents=resized_latents) return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
class ImageToLatentsInvocation(BaseInvocation): class ImageToLatentsInvocation(BaseInvocation):
@ -786,4 +793,4 @@ class ImageToLatentsInvocation(BaseInvocation):
name = f"{context.graph_execution_state_id}__{self.id}" name = f"{context.graph_execution_state_id}__{self.id}"
latents = latents.to("cpu") latents = latents.to("cpu")
context.services.latents.save(name, latents) context.services.latents.save(name, latents)
return build_latents_output(latents_name=name, latents=latents) return build_latents_output(latents_name=name, latents=latents, seed=None)

View File

@ -71,9 +71,9 @@ class NoiseOutput(BaseInvocationOutput):
# fmt: on # fmt: on
def build_noise_output(latents_name: str, latents: torch.Tensor): def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
return NoiseOutput( return NoiseOutput(
noise=LatentsField(latents_name=latents_name), noise=LatentsField(latents_name=latents_name, seed=seed),
width=latents.size()[3] * 8, width=latents.size()[3] * 8,
height=latents.size()[2] * 8, height=latents.size()[2] * 8,
) )
@ -132,4 +132,4 @@ class NoiseInvocation(BaseInvocation):
) )
name = f"{context.graph_execution_state_id}__{self.id}" name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.save(name, noise) context.services.latents.save(name, noise)
return build_noise_output(latents_name=name, latents=noise) return build_noise_output(latents_name=name, latents=noise, seed=self.seed)