Add seed to latents field

This commit is contained in:
Sergey Borisov 2023-08-08 04:00:33 +03:00
parent 492bfe002a
commit 5f29526a8e
2 changed files with 21 additions and 14 deletions

View File

@ -49,6 +49,7 @@ class LatentsField(BaseModel):
"""A latents field used for passing latents between invocations"""
latents_name: Optional[str] = Field(default=None, description="The name of the latents")
seed: Optional[int] = Field(description="Seed used to generate this latents")
class Config:
schema_extra = {"required": ["latents_name"]}
@ -67,9 +68,9 @@ class LatentsOutput(BaseInvocationOutput):
# fmt: on
def build_latents_output(latents_name: str, latents: torch.Tensor):
def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int]):
return LatentsOutput(
latents=LatentsField(latents_name=latents_name),
latents=LatentsField(latents_name=latents_name, seed=seed),
width=latents.size()[3] * 8,
height=latents.size()[2] * 8,
)
@ -175,6 +176,7 @@ class TextToLatentsInvocation(BaseInvocation):
context: InvocationContext,
scheduler,
unet,
seed,
) -> ConditioningData:
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
@ -201,7 +203,7 @@ class TextToLatentsInvocation(BaseInvocation):
# for ddim scheduler
eta=0.0, # ddim_eta
# for ancestral and sde schedulers
generator=torch.Generator(device=unet.device).manual_seed(0),
generator=torch.Generator(device=unet.device).manual_seed(seed),
)
return conditioning_data
@ -336,6 +338,7 @@ class TextToLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput:
with SilenceWarnings():
noise = context.services.latents.get(self.noise.latents_name)
seed = self.noise.seed or 0
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
@ -370,7 +373,7 @@ class TextToLatentsInvocation(BaseInvocation):
)
pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler, unet)
conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed)
control_data = self.prep_control_data(
model=pipeline,
@ -407,7 +410,7 @@ class TextToLatentsInvocation(BaseInvocation):
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.save(name, result_latents)
return build_latents_output(latents_name=name, latents=result_latents)
return build_latents_output(latents_name=name, latents=result_latents, seed=seed)
class LatentsToLatentsInvocation(TextToLatentsInvocation):
@ -440,10 +443,14 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
with SilenceWarnings(): # this quenches NSFW nag from diffusers
latent = context.services.latents.get(self.latents.latents_name)
seed = self.latents.seed or 0
noise = None
if self.noise is not None:
noise = context.services.latents.get(self.noise.latents_name)
latent = context.services.latents.get(self.latents.latents_name)
if self.noise.seed is not None:
seed = self.noise.seed
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
@ -480,7 +487,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
)
pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler, unet)
conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed)
control_data = self.prep_control_data(
model=pipeline,
@ -521,7 +528,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.save(name, result_latents)
return build_latents_output(latents_name=name, latents=result_latents)
return build_latents_output(latents_name=name, latents=result_latents, seed=seed)
# Latent to image
@ -663,7 +670,7 @@ class ResizeLatentsInvocation(BaseInvocation):
name = f"{context.graph_execution_state_id}__{self.id}"
# context.services.latents.set(name, resized_latents)
context.services.latents.save(name, resized_latents)
return build_latents_output(latents_name=name, latents=resized_latents)
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
class ScaleLatentsInvocation(BaseInvocation):
@ -705,7 +712,7 @@ class ScaleLatentsInvocation(BaseInvocation):
name = f"{context.graph_execution_state_id}__{self.id}"
# context.services.latents.set(name, resized_latents)
context.services.latents.save(name, resized_latents)
return build_latents_output(latents_name=name, latents=resized_latents)
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
class ImageToLatentsInvocation(BaseInvocation):
@ -786,4 +793,4 @@ class ImageToLatentsInvocation(BaseInvocation):
name = f"{context.graph_execution_state_id}__{self.id}"
latents = latents.to("cpu")
context.services.latents.save(name, latents)
return build_latents_output(latents_name=name, latents=latents)
return build_latents_output(latents_name=name, latents=latents, seed=None)

View File

@ -71,9 +71,9 @@ class NoiseOutput(BaseInvocationOutput):
# fmt: on
def build_noise_output(latents_name: str, latents: torch.Tensor):
def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
return NoiseOutput(
noise=LatentsField(latents_name=latents_name),
noise=LatentsField(latents_name=latents_name, seed=seed),
width=latents.size()[3] * 8,
height=latents.size()[2] * 8,
)
@ -132,4 +132,4 @@ class NoiseInvocation(BaseInvocation):
)
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.save(name, noise)
return build_noise_output(latents_name=name, latents=noise)
return build_noise_output(latents_name=name, latents=noise, seed=self.seed)