mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Remove scheduler_args from ConditioningData structure.
This commit is contained in:
parent
cad3e5dbd7
commit
e7ec13f209
@ -1,5 +1,6 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
import inspect
|
||||||
import math
|
import math
|
||||||
from contextlib import ExitStack
|
from contextlib import ExitStack
|
||||||
from functools import singledispatchmethod
|
from functools import singledispatchmethod
|
||||||
@ -375,14 +376,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
|
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||||
)
|
)
|
||||||
|
|
||||||
conditioning_data = conditioning_data.add_scheduler_args_if_applicable(
|
|
||||||
scheduler,
|
|
||||||
# for ddim scheduler
|
|
||||||
eta=0.0, # ddim_eta
|
|
||||||
# for ancestral and sde schedulers
|
|
||||||
# flip all bits to have noise different from initial
|
|
||||||
generator=torch.Generator(device=unet.device).manual_seed(seed ^ 0xFFFFFFFF),
|
|
||||||
)
|
|
||||||
return conditioning_data
|
return conditioning_data
|
||||||
|
|
||||||
def create_pipeline(
|
def create_pipeline(
|
||||||
@ -642,7 +635,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# original idea by https://github.com/AmericanPresidentJimmyCarter
|
# original idea by https://github.com/AmericanPresidentJimmyCarter
|
||||||
# TODO: research more for second order schedulers timesteps
|
# TODO: research more for second order schedulers timesteps
|
||||||
def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_end):
|
def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_end, seed: int):
|
||||||
if scheduler.config.get("cpu_only", False):
|
if scheduler.config.get("cpu_only", False):
|
||||||
scheduler.set_timesteps(steps, device="cpu")
|
scheduler.set_timesteps(steps, device="cpu")
|
||||||
timesteps = scheduler.timesteps.to(device=device)
|
timesteps = scheduler.timesteps.to(device=device)
|
||||||
@ -669,7 +662,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
|
timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
|
||||||
num_inference_steps = len(timesteps) // scheduler.order
|
num_inference_steps = len(timesteps) // scheduler.order
|
||||||
|
|
||||||
return num_inference_steps, timesteps, init_timestep
|
scheduler_step_kwargs = {}
|
||||||
|
scheduler_step_signature = inspect.signature(scheduler.step)
|
||||||
|
if "generator" in scheduler_step_signature.parameters:
|
||||||
|
# At some point, someone decided that schedulers that accept a generator should use the original seed with
|
||||||
|
# all bits flipped. I don't know the original rationale for this, but now we must keep it like this for
|
||||||
|
# reproducibility.
|
||||||
|
scheduler_step_kwargs = {"generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)}
|
||||||
|
|
||||||
|
return num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs
|
||||||
|
|
||||||
def prep_inpaint_mask(self, context, latents):
|
def prep_inpaint_mask(self, context, latents):
|
||||||
if self.denoise_mask is None:
|
if self.denoise_mask is None:
|
||||||
@ -783,12 +784,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
exit_stack=exit_stack,
|
exit_stack=exit_stack,
|
||||||
)
|
)
|
||||||
|
|
||||||
num_inference_steps, timesteps, init_timestep = self.init_scheduler(
|
num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
|
||||||
scheduler,
|
scheduler,
|
||||||
device=unet.device,
|
device=unet.device,
|
||||||
steps=self.steps,
|
steps=self.steps,
|
||||||
denoising_start=self.denoising_start,
|
denoising_start=self.denoising_start,
|
||||||
denoising_end=self.denoising_end,
|
denoising_end=self.denoising_end,
|
||||||
|
seed=seed,
|
||||||
)
|
)
|
||||||
|
|
||||||
result_latents = pipeline.latents_from_embeddings(
|
result_latents = pipeline.latents_from_embeddings(
|
||||||
@ -800,6 +802,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
mask=mask,
|
mask=mask,
|
||||||
masked_latents=masked_latents,
|
masked_latents=masked_latents,
|
||||||
num_inference_steps=num_inference_steps,
|
num_inference_steps=num_inference_steps,
|
||||||
|
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||||
conditioning_data=conditioning_data,
|
conditioning_data=conditioning_data,
|
||||||
control_data=controlnet_data,
|
control_data=controlnet_data,
|
||||||
ip_adapter_data=ip_adapter_data,
|
ip_adapter_data=ip_adapter_data,
|
||||||
|
@ -309,6 +309,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
self,
|
self,
|
||||||
latents: torch.Tensor,
|
latents: torch.Tensor,
|
||||||
num_inference_steps: int,
|
num_inference_steps: int,
|
||||||
|
scheduler_step_kwargs: dict[str, Any],
|
||||||
conditioning_data: ConditioningData,
|
conditioning_data: ConditioningData,
|
||||||
*,
|
*,
|
||||||
noise: Optional[torch.Tensor],
|
noise: Optional[torch.Tensor],
|
||||||
@ -368,6 +369,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
latents,
|
latents,
|
||||||
timesteps,
|
timesteps,
|
||||||
conditioning_data,
|
conditioning_data,
|
||||||
|
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||||
additional_guidance=additional_guidance,
|
additional_guidance=additional_guidance,
|
||||||
control_data=control_data,
|
control_data=control_data,
|
||||||
ip_adapter_data=ip_adapter_data,
|
ip_adapter_data=ip_adapter_data,
|
||||||
@ -388,6 +390,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
latents: torch.Tensor,
|
latents: torch.Tensor,
|
||||||
timesteps,
|
timesteps,
|
||||||
conditioning_data: ConditioningData,
|
conditioning_data: ConditioningData,
|
||||||
|
scheduler_step_kwargs: dict[str, Any],
|
||||||
*,
|
*,
|
||||||
additional_guidance: List[Callable] = None,
|
additional_guidance: List[Callable] = None,
|
||||||
control_data: List[ControlNetData] = None,
|
control_data: List[ControlNetData] = None,
|
||||||
@ -454,6 +457,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
conditioning_data,
|
conditioning_data,
|
||||||
step_index=i,
|
step_index=i,
|
||||||
total_step_count=len(timesteps),
|
total_step_count=len(timesteps),
|
||||||
|
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||||
additional_guidance=additional_guidance,
|
additional_guidance=additional_guidance,
|
||||||
control_data=control_data,
|
control_data=control_data,
|
||||||
ip_adapter_data=ip_adapter_data,
|
ip_adapter_data=ip_adapter_data,
|
||||||
@ -485,6 +489,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
conditioning_data: ConditioningData,
|
conditioning_data: ConditioningData,
|
||||||
step_index: int,
|
step_index: int,
|
||||||
total_step_count: int,
|
total_step_count: int,
|
||||||
|
scheduler_step_kwargs: dict[str, Any],
|
||||||
additional_guidance: List[Callable] = None,
|
additional_guidance: List[Callable] = None,
|
||||||
control_data: List[ControlNetData] = None,
|
control_data: List[ControlNetData] = None,
|
||||||
ip_adapter_data: Optional[list[IPAdapterData]] = None,
|
ip_adapter_data: Optional[list[IPAdapterData]] = None,
|
||||||
@ -584,7 +589,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# compute the previous noisy sample x_t -> x_t-1
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
step_output = self.scheduler.step(noise_pred, timestep, latents, **conditioning_data.scheduler_args)
|
step_output = self.scheduler.step(noise_pred, timestep, latents, **scheduler_step_kwargs)
|
||||||
|
|
||||||
# TODO: issue to diffusers?
|
# TODO: issue to diffusers?
|
||||||
# undo internal counter increment done by scheduler.step, so timestep can be resolved as before call
|
# undo internal counter increment done by scheduler.step, so timestep can be resolved as before call
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
import dataclasses
|
from dataclasses import dataclass
|
||||||
import inspect
|
from typing import List, Optional, Union
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Any, List, Optional, Union
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -73,19 +71,5 @@ class ConditioningData:
|
|||||||
ref [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf)
|
ref [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf)
|
||||||
"""
|
"""
|
||||||
guidance_rescale_multiplier: float = 0
|
guidance_rescale_multiplier: float = 0
|
||||||
scheduler_args: dict[str, Any] = field(default_factory=dict)
|
|
||||||
|
|
||||||
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]] = None
|
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]] = None
|
||||||
|
|
||||||
def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
|
|
||||||
scheduler_args = dict(self.scheduler_args)
|
|
||||||
step_method = inspect.signature(scheduler.step)
|
|
||||||
for name, value in kwargs.items():
|
|
||||||
try:
|
|
||||||
step_method.bind_partial(**{name: value})
|
|
||||||
except TypeError:
|
|
||||||
# FIXME: don't silently discard arguments
|
|
||||||
pass # debug("%s does not accept argument named %r", scheduler, name)
|
|
||||||
else:
|
|
||||||
scheduler_args[name] = value
|
|
||||||
return dataclasses.replace(self, scheduler_args=scheduler_args)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user