mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Make dpmpp_sde(_k) use not random seed
This commit is contained in:
parent
096333ba3f
commit
d63bb39475
@ -13,7 +13,7 @@ from diffusers.models.attention_processor import (
|
||||
LoRAXFormersAttnProcessor,
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
from diffusers.schedulers import DPMSolverSDEScheduler, SchedulerMixin as Scheduler
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
|
||||
@ -81,6 +81,7 @@ def get_scheduler(
|
||||
context: InvocationContext,
|
||||
scheduler_info: ModelInfo,
|
||||
scheduler_name: str,
|
||||
seed: int,
|
||||
) -> Scheduler:
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
||||
orig_scheduler_info = context.services.model_manager.get_model(
|
||||
@ -97,6 +98,11 @@ def get_scheduler(
|
||||
**scheduler_extra_config,
|
||||
"_backup": scheduler_config,
|
||||
}
|
||||
|
||||
# make dpmpp_sde reproducable(seed can be passed only in initializer)
|
||||
if scheduler_class is DPMSolverSDEScheduler:
|
||||
scheduler_config["noise_sampler_seed"] = seed
|
||||
|
||||
scheduler = scheduler_class.from_config(scheduler_config)
|
||||
|
||||
# hack copied over from generate.py
|
||||
@ -421,6 +427,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
context=context,
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
pipeline = self.create_pipeline(unet, scheduler)
|
||||
|
@ -212,6 +212,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
|
||||
context=context,
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
seed=0, # TODO: refactor this node
|
||||
)
|
||||
|
||||
def torch2numpy(latent: torch.Tensor):
|
||||
|
Loading…
Reference in New Issue
Block a user