refactor(diffusers_pipeline): remove unused precision 🚮

Kevin Turner 2023-08-05 20:41:47 -07:00
parent b80abdd101
commit 77033eabd3
3 changed files with 21 additions and 32 deletions

invokeai/app/invocations/generate.py

@@ -1,26 +1,23 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
+from contextlib import contextmanager, ContextDecorator
 from functools import partial
 from typing import Literal, Optional, get_args
-import torch
 from pydantic import Field
 from invokeai.app.models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
 from invokeai.app.util.misc import SEED_MAX, get_random_seed
 from invokeai.backend.generator.inpaint import infill_methods
-from ...backend.generator import Inpaint, InvokeAIGenerator
-from ...backend.stable_diffusion import PipelineIntermediateState
-from ..util.step_callback import stable_diffusion_step_callback
 from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext
-from .image import ImageOutput
-from ...backend.model_management.lora import ModelPatcher
-from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
-from .model import UNetField, VaeField
 from .compel import ConditioningField
-from contextlib import contextmanager, ExitStack, ContextDecorator
+from .image import ImageOutput
+from .model import UNetField, VaeField
+from ..util.step_callback import stable_diffusion_step_callback
+from ...backend.generator import Inpaint, InvokeAIGenerator
+from ...backend.model_management.lora import ModelPatcher
+from ...backend.stable_diffusion import PipelineIntermediateState
+from ...backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline

 SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
 INFILL_METHODS = Literal[tuple(infill_methods())]
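A side note on the unchanged context line SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]: it builds a pydantic-validatable Literal type from a runtime list of names. A self-contained sketch of the pattern, with a hypothetical schedulers() standing in for the real method:

from typing import Literal, get_args

def schedulers() -> list[str]:
    # Hypothetical stand-in for InvokeAIGenerator.schedulers():
    # the scheduler names known at import time.
    return ["ddim", "euler", "lms"]

# At runtime, Literal accepts a tuple subscript and expands it to
# Literal["ddim", "euler", "lms"]; static checkers need an ignore.
SAMPLER_NAME_VALUES = Literal[tuple(schedulers())]  # type: ignore[valid-type]

print(get_args(SAMPLER_NAME_VALUES))  # ('ddim', 'euler', 'lms')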
@@ -193,7 +190,6 @@ class InpaintInvocation(BaseInvocation):
             safety_checker=None,
             feature_extractor=None,
             requires_safety_checker=False,
-            precision="float16" if dtype == torch.float16 else "float32",
             execution_device=device,
         )
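The deleted keyword shows why precision was dead weight: the caller derived a string from a dtype that the pipeline can read off its own modules. A minimal sketch of that derivation (the helper name is illustrative, not part of this commit):

import torch

def precision_from_module(module: torch.nn.Module) -> str:
    # The parameter dtype fully determines the working precision, so a
    # separately-passed "precision" string is redundant and can drift
    # out of sync with the weights actually loaded.
    dtype = next(module.parameters()).dtype
    return "float16" if dtype == torch.float16 else "float32"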

invokeai/app/invocations/latent.py

@@ -5,15 +5,26 @@ from typing import List, Literal, Optional, Union
 import einops
 import torch
+from diffusers import ControlNetModel
 from diffusers.image_processor import VaeImageProcessor
+from diffusers.models.attention_processor import (
+    AttnProcessor2_0,
+    LoRAAttnProcessor2_0,
+    LoRAXFormersAttnProcessor,
+    XFormersAttnProcessor,
+)
 from diffusers.schedulers import SchedulerMixin as Scheduler
 from pydantic import BaseModel, Field, validator
 from invokeai.app.invocations.metadata import CoreMetadata
+from invokeai.app.util.controlnet_utils import prepare_control_image
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
 from invokeai.backend.model_management.models import ModelType, SilenceWarnings
+from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
+from .compel import ConditioningField
+from .controlnet_image_processors import ControlField
+from .image import ImageOutput
+from .model import ModelInfo, UNetField, VaeField
+from ..models.image import ImageCategory, ImageField, ResourceOrigin
 from ...backend.model_management import ModelPatcher
 from ...backend.stable_diffusion import PipelineIntermediateState
 from ...backend.stable_diffusion.diffusers_pipeline import (
@@ -24,23 +35,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
 )
 from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
 from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
-from ...backend.model_management import ModelPatcher
 from ...backend.util.devices import choose_torch_device, torch_dtype, choose_precision
-from ..models.image import ImageCategory, ImageField, ResourceOrigin
-from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
-from .compel import ConditioningField
-from .controlnet_image_processors import ControlField
-from .image import ImageOutput
-from .model import ModelInfo, UNetField, VaeField
-from invokeai.app.util.controlnet_utils import prepare_control_image
-from diffusers.models.attention_processor import (
-    AttnProcessor2_0,
-    LoRAAttnProcessor2_0,
-    LoRAXFormersAttnProcessor,
-    XFormersAttnProcessor,
-)

 DEFAULT_PRECISION = choose_precision(choose_torch_device())
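Note that DEFAULT_PRECISION survives the commit; only the per-pipeline precision argument is removed, while the module-level default presumably still feeds invocation fields. Roughly, such a helper picks half precision when the device can use it; a hedged approximation, not InvokeAI's exact logic:

import torch

def choose_torch_device_sketch() -> torch.device:
    # Approximation: prefer CUDA, then Apple MPS, else CPU.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

def choose_precision_sketch(device: torch.device) -> str:
    # Approximation: half precision only pays off on GPU backends.
    return "float16" if device.type in ("cuda", "mps") else "float32"

DEFAULT_PRECISION_SKETCH = choose_precision_sketch(choose_torch_device_sketch())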
@@ -231,7 +226,6 @@ class TextToLatentsInvocation(BaseInvocation):
             safety_checker=None,
             feature_extractor=None,
             requires_safety_checker=False,
-            precision="float16" if unet.dtype == torch.float16 else "float32",
         )

     def prep_control_data(
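With the keyword gone, the invocation builds the pipeline with no precision bookkeeping at all. A hedged sketch of the resulting call shape (argument list abridged; the component parameters are assumed to be loaded elsewhere):

from invokeai.backend.stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline

def build_pipeline(vae, text_encoder, tokenizer, unet, scheduler):
    # No precision= argument: the working dtype is carried by the
    # modules themselves (unet.dtype and friends).
    return StableDiffusionGeneratorPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=scheduler,
        safety_checker=None,
        feature_extractor=None,
        requires_safety_checker=False,
    )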

invokeai/backend/stable_diffusion/diffusers_pipeline.py

@@ -300,7 +300,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        safety_checker: Optional[StableDiffusionSafetyChecker],
        feature_extractor: Optional[CLIPFeatureExtractor],
        requires_safety_checker: bool = False,
-        precision: str = "float32",
        control_model: ControlNetModel = None,
        execution_device: Optional[torch.device] = None,
    ):
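Inside the pipeline, anything that previously consulted the stored precision string can ask the weights directly: diffusers models expose a .dtype accessor. A sketch of the pattern (the subclass and property names are invented for illustration):

import torch
from diffusers import StableDiffusionPipeline

class PrecisionFreePipeline(StableDiffusionPipeline):
    @property
    def working_dtype(self) -> torch.dtype:
        # The UNet weights are the single source of truth; a cached
        # "precision" field could silently disagree with them.
        return self.unet.dtype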