mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Codesplit SCHEDULER_MAP for reusage
This commit is contained in:
parent
c1e7460d39
commit
9a383e456d
@ -17,6 +17,7 @@ from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import Post
|
||||
from ...backend.image_util.seamless import configure_model_padding
|
||||
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
|
||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||
import numpy as np
|
||||
from ..services.image_storage import ImageType
|
||||
@ -52,29 +53,13 @@ class NoiseOutput(BaseInvocationOutput):
|
||||
#fmt: on
|
||||
|
||||
|
||||
# TODO: this seems like a hack
|
||||
scheduler_map = dict(
|
||||
ddim=(diffusers.DDIMScheduler, dict(cpu_only=False)),
|
||||
dpmpp_2=(diffusers.DPMSolverMultistepScheduler, dict(cpu_only=False)),
|
||||
k_dpm_2=(diffusers.KDPM2DiscreteScheduler, dict(cpu_only=False)),
|
||||
k_dpm_2_a=(diffusers.KDPM2AncestralDiscreteScheduler, dict(cpu_only=False)),
|
||||
k_dpmpp_2=(diffusers.DPMSolverMultistepScheduler, dict(cpu_only=False)),
|
||||
k_euler=(diffusers.EulerDiscreteScheduler, dict(cpu_only=False)),
|
||||
k_euler_a=(diffusers.EulerAncestralDiscreteScheduler, dict(cpu_only=False)),
|
||||
k_heun=(diffusers.HeunDiscreteScheduler, dict(cpu_only=False)),
|
||||
k_lms=(diffusers.LMSDiscreteScheduler, dict(cpu_only=False)),
|
||||
plms=(diffusers.PNDMScheduler, dict(cpu_only=False)),
|
||||
unipc=(diffusers.UniPCMultistepScheduler, dict(cpu_only=True))
|
||||
)
|
||||
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[
|
||||
tuple(list(scheduler_map.keys()))
|
||||
tuple(list(SCHEDULER_MAP.keys()))
|
||||
]
|
||||
|
||||
|
||||
def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
|
||||
scheduler_class, scheduler_extra_config = scheduler_map.get(scheduler_name, scheduler_map['ddim'])
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
|
||||
scheduler_config = {**model.scheduler.config, **scheduler_extra_config}
|
||||
scheduler = scheduler_class.from_config(scheduler_config)
|
||||
# hack copied over from generate.py
|
||||
|
@ -37,6 +37,7 @@ from .safety_checker import SafetyChecker
|
||||
from .prompting import get_uc_and_c_and_ec
|
||||
from .prompting.conditioning import log_tokenization
|
||||
from .stable_diffusion import HuggingFaceConceptsLibrary
|
||||
from .stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from .util import choose_precision, choose_torch_device
|
||||
|
||||
def fix_func(orig):
|
||||
@ -1047,26 +1048,8 @@ class Generate:
|
||||
def _set_scheduler(self):
|
||||
default = self.model.scheduler
|
||||
|
||||
# See https://github.com/huggingface/diffusers/issues/277#issuecomment-1371428672
|
||||
scheduler_map = dict(
|
||||
ddim=(diffusers.DDIMScheduler, dict(cpu_only=False)),
|
||||
dpmpp_2=(diffusers.DPMSolverMultistepScheduler, dict(cpu_only=False)),
|
||||
k_dpm_2=(diffusers.KDPM2DiscreteScheduler, dict(cpu_only=False)),
|
||||
k_dpm_2_a=(diffusers.KDPM2AncestralDiscreteScheduler, dict(cpu_only=False)),
|
||||
# DPMSolverMultistepScheduler is technically not `k_` anything, as it is neither
|
||||
# the k-diffusers implementation nor included in EDM (Karras 2022), but we can
|
||||
# provide an alias for compatibility.
|
||||
k_dpmpp_2=(diffusers.DPMSolverMultistepScheduler, dict(cpu_only=False)),
|
||||
k_euler=(diffusers.EulerDiscreteScheduler, dict(cpu_only=False)),
|
||||
k_euler_a=(diffusers.EulerAncestralDiscreteScheduler, dict(cpu_only=False)),
|
||||
k_heun=(diffusers.HeunDiscreteScheduler, dict(cpu_only=False)),
|
||||
k_lms=(diffusers.LMSDiscreteScheduler, dict(cpu_only=False)),
|
||||
plms=(diffusers.PNDMScheduler, dict(cpu_only=False)),
|
||||
unipc=(diffusers.UniPCMultistepScheduler, dict(cpu_only=True))
|
||||
)
|
||||
|
||||
if self.sampler_name in scheduler_map:
|
||||
sampler_class, sampler_extra_config = scheduler_map[self.sampler_name]
|
||||
if self.sampler_name in SCHEDULER_MAP:
|
||||
sampler_class, sampler_extra_config = SCHEDULER_MAP[self.sampler_name]
|
||||
msg = (
|
||||
f"Setting Sampler to {self.sampler_name} ({sampler_class.__name__})"
|
||||
)
|
||||
|
@ -31,6 +31,7 @@ from ..util.util import rand_perlin_2d
|
||||
from ..safety_checker import SafetyChecker
|
||||
from ..prompting.conditioning import get_uc_and_c_and_ec
|
||||
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||
from ..stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
|
||||
downsampling = 8
|
||||
|
||||
@ -71,20 +72,6 @@ class InvokeAIGeneratorOutput:
|
||||
# we are interposing a wrapper around the original Generator classes so that
|
||||
# old code that calls Generate will continue to work.
|
||||
class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
scheduler_map = dict(
|
||||
ddim=(diffusers.DDIMScheduler, dict(cpu_only=False)),
|
||||
dpmpp_2=(diffusers.DPMSolverMultistepScheduler, dict(cpu_only=False)),
|
||||
k_dpm_2=(diffusers.KDPM2DiscreteScheduler, dict(cpu_only=False)),
|
||||
k_dpm_2_a=(diffusers.KDPM2AncestralDiscreteScheduler, dict(cpu_only=False)),
|
||||
k_dpmpp_2=(diffusers.DPMSolverMultistepScheduler, dict(cpu_only=False)),
|
||||
k_euler=(diffusers.EulerDiscreteScheduler, dict(cpu_only=False)),
|
||||
k_euler_a=(diffusers.EulerAncestralDiscreteScheduler, dict(cpu_only=False)),
|
||||
k_heun=(diffusers.HeunDiscreteScheduler, dict(cpu_only=False)),
|
||||
k_lms=(diffusers.LMSDiscreteScheduler, dict(cpu_only=False)),
|
||||
plms=(diffusers.PNDMScheduler, dict(cpu_only=False)),
|
||||
unipc=(diffusers.UniPCMultistepScheduler, dict(cpu_only=True))
|
||||
)
|
||||
|
||||
def __init__(self,
|
||||
model_info: dict,
|
||||
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
|
||||
@ -176,13 +163,13 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
'''
|
||||
Return list of all the schedulers that we currently handle.
|
||||
'''
|
||||
return list(self.scheduler_map.keys())
|
||||
return list(SCHEDULER_MAP.keys())
|
||||
|
||||
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
|
||||
return generator_class(model, self.params.precision)
|
||||
|
||||
def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
|
||||
scheduler_class, scheduler_extra_config = self.scheduler_map.get(scheduler_name, self.scheduler_map['ddim'])
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
|
||||
scheduler_config = {**model.scheduler.config, **scheduler_extra_config}
|
||||
scheduler = scheduler_class.from_config(scheduler_config)
|
||||
# hack copied over from generate.py
|
||||
|
1
invokeai/backend/stable_diffusion/schedulers/__init__.py
Normal file
1
invokeai/backend/stable_diffusion/schedulers/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .schedulers import SCHEDULER_MAP
|
17
invokeai/backend/stable_diffusion/schedulers/schedulers.py
Normal file
17
invokeai/backend/stable_diffusion/schedulers/schedulers.py
Normal file
@ -0,0 +1,17 @@
|
||||
from diffusers import DDIMScheduler, DPMSolverMultistepScheduler, KDPM2DiscreteScheduler, \
|
||||
KDPM2AncestralDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, \
|
||||
HeunDiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, UniPCMultistepScheduler
|
||||
|
||||
SCHEDULER_MAP = dict(
|
||||
ddim=(DDIMScheduler, dict(cpu_only=False)),
|
||||
dpmpp_2=(DPMSolverMultistepScheduler, dict(cpu_only=False)),
|
||||
k_dpm_2=(KDPM2DiscreteScheduler, dict(cpu_only=False)),
|
||||
k_dpm_2_a=(KDPM2AncestralDiscreteScheduler, dict(cpu_only=False)),
|
||||
k_dpmpp_2=(DPMSolverMultistepScheduler, dict(cpu_only=False)),
|
||||
k_euler=(EulerDiscreteScheduler, dict(cpu_only=False)),
|
||||
k_euler_a=(EulerAncestralDiscreteScheduler, dict(cpu_only=False)),
|
||||
k_heun=(HeunDiscreteScheduler, dict(cpu_only=False)),
|
||||
k_lms=(LMSDiscreteScheduler, dict(cpu_only=False)),
|
||||
plms=(PNDMScheduler, dict(cpu_only=False)),
|
||||
unipc=(UniPCMultistepScheduler, dict(cpu_only=True))
|
||||
)
|
Loading…
Reference in New Issue
Block a user