mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Codesplit SCHEDULER_MAP for reusage
This commit is contained in:
parent
c1e7460d39
commit
9a383e456d
@ -17,6 +17,7 @@ from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import Post
|
|||||||
from ...backend.image_util.seamless import configure_model_padding
|
from ...backend.image_util.seamless import configure_model_padding
|
||||||
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
|
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
|
||||||
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
|
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
|
||||||
|
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from ..services.image_storage import ImageType
|
from ..services.image_storage import ImageType
|
||||||
@ -52,29 +53,13 @@ class NoiseOutput(BaseInvocationOutput):
|
|||||||
#fmt: on
|
#fmt: on
|
||||||
|
|
||||||
|
|
||||||
# TODO: this seems like a hack
|
|
||||||
scheduler_map = dict(
|
|
||||||
ddim=(diffusers.DDIMScheduler, dict(cpu_only=False)),
|
|
||||||
dpmpp_2=(diffusers.DPMSolverMultistepScheduler, dict(cpu_only=False)),
|
|
||||||
k_dpm_2=(diffusers.KDPM2DiscreteScheduler, dict(cpu_only=False)),
|
|
||||||
k_dpm_2_a=(diffusers.KDPM2AncestralDiscreteScheduler, dict(cpu_only=False)),
|
|
||||||
k_dpmpp_2=(diffusers.DPMSolverMultistepScheduler, dict(cpu_only=False)),
|
|
||||||
k_euler=(diffusers.EulerDiscreteScheduler, dict(cpu_only=False)),
|
|
||||||
k_euler_a=(diffusers.EulerAncestralDiscreteScheduler, dict(cpu_only=False)),
|
|
||||||
k_heun=(diffusers.HeunDiscreteScheduler, dict(cpu_only=False)),
|
|
||||||
k_lms=(diffusers.LMSDiscreteScheduler, dict(cpu_only=False)),
|
|
||||||
plms=(diffusers.PNDMScheduler, dict(cpu_only=False)),
|
|
||||||
unipc=(diffusers.UniPCMultistepScheduler, dict(cpu_only=True))
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
SAMPLER_NAME_VALUES = Literal[
|
SAMPLER_NAME_VALUES = Literal[
|
||||||
tuple(list(scheduler_map.keys()))
|
tuple(list(SCHEDULER_MAP.keys()))
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
|
def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
|
||||||
scheduler_class, scheduler_extra_config = scheduler_map.get(scheduler_name, scheduler_map['ddim'])
|
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
|
||||||
scheduler_config = {**model.scheduler.config, **scheduler_extra_config}
|
scheduler_config = {**model.scheduler.config, **scheduler_extra_config}
|
||||||
scheduler = scheduler_class.from_config(scheduler_config)
|
scheduler = scheduler_class.from_config(scheduler_config)
|
||||||
# hack copied over from generate.py
|
# hack copied over from generate.py
|
||||||
|
@ -37,6 +37,7 @@ from .safety_checker import SafetyChecker
|
|||||||
from .prompting import get_uc_and_c_and_ec
|
from .prompting import get_uc_and_c_and_ec
|
||||||
from .prompting.conditioning import log_tokenization
|
from .prompting.conditioning import log_tokenization
|
||||||
from .stable_diffusion import HuggingFaceConceptsLibrary
|
from .stable_diffusion import HuggingFaceConceptsLibrary
|
||||||
|
from .stable_diffusion.schedulers import SCHEDULER_MAP
|
||||||
from .util import choose_precision, choose_torch_device
|
from .util import choose_precision, choose_torch_device
|
||||||
|
|
||||||
def fix_func(orig):
|
def fix_func(orig):
|
||||||
@ -1047,26 +1048,8 @@ class Generate:
|
|||||||
def _set_scheduler(self):
|
def _set_scheduler(self):
|
||||||
default = self.model.scheduler
|
default = self.model.scheduler
|
||||||
|
|
||||||
# See https://github.com/huggingface/diffusers/issues/277#issuecomment-1371428672
|
if self.sampler_name in SCHEDULER_MAP:
|
||||||
scheduler_map = dict(
|
sampler_class, sampler_extra_config = SCHEDULER_MAP[self.sampler_name]
|
||||||
ddim=(diffusers.DDIMScheduler, dict(cpu_only=False)),
|
|
||||||
dpmpp_2=(diffusers.DPMSolverMultistepScheduler, dict(cpu_only=False)),
|
|
||||||
k_dpm_2=(diffusers.KDPM2DiscreteScheduler, dict(cpu_only=False)),
|
|
||||||
k_dpm_2_a=(diffusers.KDPM2AncestralDiscreteScheduler, dict(cpu_only=False)),
|
|
||||||
# DPMSolverMultistepScheduler is technically not `k_` anything, as it is neither
|
|
||||||
# the k-diffusers implementation nor included in EDM (Karras 2022), but we can
|
|
||||||
# provide an alias for compatibility.
|
|
||||||
k_dpmpp_2=(diffusers.DPMSolverMultistepScheduler, dict(cpu_only=False)),
|
|
||||||
k_euler=(diffusers.EulerDiscreteScheduler, dict(cpu_only=False)),
|
|
||||||
k_euler_a=(diffusers.EulerAncestralDiscreteScheduler, dict(cpu_only=False)),
|
|
||||||
k_heun=(diffusers.HeunDiscreteScheduler, dict(cpu_only=False)),
|
|
||||||
k_lms=(diffusers.LMSDiscreteScheduler, dict(cpu_only=False)),
|
|
||||||
plms=(diffusers.PNDMScheduler, dict(cpu_only=False)),
|
|
||||||
unipc=(diffusers.UniPCMultistepScheduler, dict(cpu_only=True))
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.sampler_name in scheduler_map:
|
|
||||||
sampler_class, sampler_extra_config = scheduler_map[self.sampler_name]
|
|
||||||
msg = (
|
msg = (
|
||||||
f"Setting Sampler to {self.sampler_name} ({sampler_class.__name__})"
|
f"Setting Sampler to {self.sampler_name} ({sampler_class.__name__})"
|
||||||
)
|
)
|
||||||
|
@ -31,6 +31,7 @@ from ..util.util import rand_perlin_2d
|
|||||||
from ..safety_checker import SafetyChecker
|
from ..safety_checker import SafetyChecker
|
||||||
from ..prompting.conditioning import get_uc_and_c_and_ec
|
from ..prompting.conditioning import get_uc_and_c_and_ec
|
||||||
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||||
|
from ..stable_diffusion.schedulers import SCHEDULER_MAP
|
||||||
|
|
||||||
downsampling = 8
|
downsampling = 8
|
||||||
|
|
||||||
@ -71,20 +72,6 @@ class InvokeAIGeneratorOutput:
|
|||||||
# we are interposing a wrapper around the original Generator classes so that
|
# we are interposing a wrapper around the original Generator classes so that
|
||||||
# old code that calls Generate will continue to work.
|
# old code that calls Generate will continue to work.
|
||||||
class InvokeAIGenerator(metaclass=ABCMeta):
|
class InvokeAIGenerator(metaclass=ABCMeta):
|
||||||
scheduler_map = dict(
|
|
||||||
ddim=(diffusers.DDIMScheduler, dict(cpu_only=False)),
|
|
||||||
dpmpp_2=(diffusers.DPMSolverMultistepScheduler, dict(cpu_only=False)),
|
|
||||||
k_dpm_2=(diffusers.KDPM2DiscreteScheduler, dict(cpu_only=False)),
|
|
||||||
k_dpm_2_a=(diffusers.KDPM2AncestralDiscreteScheduler, dict(cpu_only=False)),
|
|
||||||
k_dpmpp_2=(diffusers.DPMSolverMultistepScheduler, dict(cpu_only=False)),
|
|
||||||
k_euler=(diffusers.EulerDiscreteScheduler, dict(cpu_only=False)),
|
|
||||||
k_euler_a=(diffusers.EulerAncestralDiscreteScheduler, dict(cpu_only=False)),
|
|
||||||
k_heun=(diffusers.HeunDiscreteScheduler, dict(cpu_only=False)),
|
|
||||||
k_lms=(diffusers.LMSDiscreteScheduler, dict(cpu_only=False)),
|
|
||||||
plms=(diffusers.PNDMScheduler, dict(cpu_only=False)),
|
|
||||||
unipc=(diffusers.UniPCMultistepScheduler, dict(cpu_only=True))
|
|
||||||
)
|
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
model_info: dict,
|
model_info: dict,
|
||||||
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
|
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
|
||||||
@ -176,13 +163,13 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
|||||||
'''
|
'''
|
||||||
Return list of all the schedulers that we currently handle.
|
Return list of all the schedulers that we currently handle.
|
||||||
'''
|
'''
|
||||||
return list(self.scheduler_map.keys())
|
return list(SCHEDULER_MAP.keys())
|
||||||
|
|
||||||
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
|
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
|
||||||
return generator_class(model, self.params.precision)
|
return generator_class(model, self.params.precision)
|
||||||
|
|
||||||
def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
|
def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
|
||||||
scheduler_class, scheduler_extra_config = self.scheduler_map.get(scheduler_name, self.scheduler_map['ddim'])
|
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])
|
||||||
scheduler_config = {**model.scheduler.config, **scheduler_extra_config}
|
scheduler_config = {**model.scheduler.config, **scheduler_extra_config}
|
||||||
scheduler = scheduler_class.from_config(scheduler_config)
|
scheduler = scheduler_class.from_config(scheduler_config)
|
||||||
# hack copied over from generate.py
|
# hack copied over from generate.py
|
||||||
|
1
invokeai/backend/stable_diffusion/schedulers/__init__.py
Normal file
1
invokeai/backend/stable_diffusion/schedulers/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
from .schedulers import SCHEDULER_MAP
|
17
invokeai/backend/stable_diffusion/schedulers/schedulers.py
Normal file
17
invokeai/backend/stable_diffusion/schedulers/schedulers.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
from diffusers import DDIMScheduler, DPMSolverMultistepScheduler, KDPM2DiscreteScheduler, \
|
||||||
|
KDPM2AncestralDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, \
|
||||||
|
HeunDiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, UniPCMultistepScheduler
|
||||||
|
|
||||||
|
SCHEDULER_MAP = dict(
|
||||||
|
ddim=(DDIMScheduler, dict(cpu_only=False)),
|
||||||
|
dpmpp_2=(DPMSolverMultistepScheduler, dict(cpu_only=False)),
|
||||||
|
k_dpm_2=(KDPM2DiscreteScheduler, dict(cpu_only=False)),
|
||||||
|
k_dpm_2_a=(KDPM2AncestralDiscreteScheduler, dict(cpu_only=False)),
|
||||||
|
k_dpmpp_2=(DPMSolverMultistepScheduler, dict(cpu_only=False)),
|
||||||
|
k_euler=(EulerDiscreteScheduler, dict(cpu_only=False)),
|
||||||
|
k_euler_a=(EulerAncestralDiscreteScheduler, dict(cpu_only=False)),
|
||||||
|
k_heun=(HeunDiscreteScheduler, dict(cpu_only=False)),
|
||||||
|
k_lms=(LMSDiscreteScheduler, dict(cpu_only=False)),
|
||||||
|
plms=(PNDMScheduler, dict(cpu_only=False)),
|
||||||
|
unipc=(UniPCMultistepScheduler, dict(cpu_only=True))
|
||||||
|
)
|
Loading…
x
Reference in New Issue
Block a user