Make dpmpp_sde(_k) use not random seed

This commit is contained in:
Sergey Borisov 2023-08-14 00:24:38 +03:00
parent 096333ba3f
commit d63bb39475
2 changed files with 9 additions and 1 deletions

View File

@ -13,7 +13,7 @@ from diffusers.models.attention_processor import (
LoRAXFormersAttnProcessor, LoRAXFormersAttnProcessor,
XFormersAttnProcessor, XFormersAttnProcessor,
) )
from diffusers.schedulers import SchedulerMixin as Scheduler from diffusers.schedulers import DPMSolverSDEScheduler, SchedulerMixin as Scheduler
from pydantic import BaseModel, Field, validator from pydantic import BaseModel, Field, validator
from torchvision.transforms.functional import resize as tv_resize from torchvision.transforms.functional import resize as tv_resize
@ -81,6 +81,7 @@ def get_scheduler(
context: InvocationContext, context: InvocationContext,
scheduler_info: ModelInfo, scheduler_info: ModelInfo,
scheduler_name: str, scheduler_name: str,
seed: int,
) -> Scheduler: ) -> Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"]) scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
orig_scheduler_info = context.services.model_manager.get_model( orig_scheduler_info = context.services.model_manager.get_model(
@ -97,6 +98,11 @@ def get_scheduler(
**scheduler_extra_config, **scheduler_extra_config,
"_backup": scheduler_config, "_backup": scheduler_config,
} }
# make dpmpp_sde reproducable(seed can be passed only in initializer)
if scheduler_class is DPMSolverSDEScheduler:
scheduler_config["noise_sampler_seed"] = seed
scheduler = scheduler_class.from_config(scheduler_config) scheduler = scheduler_class.from_config(scheduler_config)
# hack copied over from generate.py # hack copied over from generate.py
@ -421,6 +427,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
context=context, context=context,
scheduler_info=self.unet.scheduler, scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler, scheduler_name=self.scheduler,
seed=seed,
) )
pipeline = self.create_pipeline(unet, scheduler) pipeline = self.create_pipeline(unet, scheduler)

View File

@ -212,6 +212,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
context=context, context=context,
scheduler_info=self.unet.scheduler, scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler, scheduler_name=self.scheduler,
seed=0, # TODO: refactor this node
) )
def torch2numpy(latent: torch.Tensor): def torch2numpy(latent: torch.Tensor):