Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Make dpmpp_sde(_k) use a non-random seed
This commit is contained in:
parent 096333ba3f
commit d63bb39475
@@ -13,7 +13,7 @@ from diffusers.models.attention_processor import (
     LoRAXFormersAttnProcessor,
     XFormersAttnProcessor,
 )
-from diffusers.schedulers import SchedulerMixin as Scheduler
+from diffusers.schedulers import DPMSolverSDEScheduler, SchedulerMixin as Scheduler
 from pydantic import BaseModel, Field, validator
 from torchvision.transforms.functional import resize as tv_resize

@@ -81,6 +81,7 @@ def get_scheduler(
     context: InvocationContext,
     scheduler_info: ModelInfo,
     scheduler_name: str,
+    seed: int,
 ) -> Scheduler:
     scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
     orig_scheduler_info = context.services.model_manager.get_model(
@@ -97,6 +98,11 @@ def get_scheduler(
         **scheduler_extra_config,
         "_backup": scheduler_config,
     }
+
+    # make dpmpp_sde reproducible (the seed can only be passed to the initializer)
+    if scheduler_class is DPMSolverSDEScheduler:
+        scheduler_config["noise_sampler_seed"] = seed
+
     scheduler = scheduler_class.from_config(scheduler_config)

     # hack copied over from generate.py
@@ -421,6 +427,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
             context=context,
             scheduler_info=self.unet.scheduler,
             scheduler_name=self.scheduler,
+            seed=seed,
         )

         pipeline = self.create_pipeline(unet, scheduler)
@@ -212,6 +212,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
             context=context,
             scheduler_info=self.unet.scheduler,
             scheduler_name=self.scheduler,
+            seed=0,  # TODO: refactor this node
         )

         def torch2numpy(latent: torch.Tensor):
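For context: the diff above makes dpmpp_sde(_k) deterministic by writing the generation seed into the scheduler config before from_config(), because diffusers' DPMSolverSDEScheduler only accepts its noise-sampler seed at construction time. Below is a minimal, standalone sketch of that idea, not the repository's code: the helper name make_seeded_scheduler is made up, and it assumes a diffusers install that ships DPMSolverSDEScheduler (with torchsde available).

from diffusers.schedulers import DPMSolverSDEScheduler, SchedulerMixin as Scheduler


def make_seeded_scheduler(scheduler_class: type, scheduler_config: dict, seed: int) -> Scheduler:
    """Build a scheduler, seeding DPMSolverSDEScheduler's noise sampler when needed."""
    # DPMSolverSDEScheduler draws its SDE noise from a sampler that is seeded
    # only in the constructor, so the seed has to be injected into the config
    # dict before from_config(); it cannot be supplied later, per step.
    if scheduler_class is DPMSolverSDEScheduler:
        scheduler_config = {**scheduler_config, "noise_sampler_seed": seed}
    return scheduler_class.from_config(scheduler_config)


if __name__ == "__main__":
    base_config = dict(DPMSolverSDEScheduler().config)  # default config as a stand-in
    a = make_seeded_scheduler(DPMSolverSDEScheduler, dict(base_config), seed=1234)
    b = make_seeded_scheduler(DPMSolverSDEScheduler, dict(base_config), seed=1234)
    # Both schedulers now carry the same noise_sampler_seed, which is what makes
    # dpmpp_sde(_k) reproducible for a fixed generation seed.
    assert a.config.noise_sampler_seed == b.config.noise_sampler_seed == 1234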