chore: minor cleanup after merge & flake8

This commit is contained in:
psychedelicious 2023-08-18 16:05:39 +10:00
parent 3c43594c26
commit c49851e027

View File

@ -22,18 +22,18 @@ from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.invocations.primitives import ( from invokeai.app.invocations.primitives import (
ImageField, ImageField,
ImageOutput, ImageOutput,
LatentsField,
LatentsOutput,
InpaintMaskField, InpaintMaskField,
InpaintMaskOutput, InpaintMaskOutput,
LatentsField,
LatentsOutput,
build_latents_output, build_latents_output,
) )
from invokeai.app.util.controlnet_utils import prepare_control_image from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.models import ModelType, SilenceWarnings from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from ...backend.model_management.models import BaseModelType
from ...backend.model_management.lora import ModelPatcher from ...backend.model_management.lora import ModelPatcher
from ...backend.model_management.models import BaseModelType
from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import ( from ...backend.stable_diffusion.diffusers_pipeline import (
ConditioningData, ConditioningData,
@ -45,16 +45,7 @@ from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import Post
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import choose_precision, choose_torch_device from ...backend.util.devices import choose_precision, choose_torch_device
from ..models.image import ImageCategory, ResourceOrigin from ..models.image import ImageCategory, ResourceOrigin
from .baseinvocation import ( from .baseinvocation import BaseInvocation, FieldDescriptions, Input, InputField, InvocationContext, UIType, tags, title
BaseInvocation,
FieldDescriptions,
Input,
InputField,
InvocationContext,
UIType,
tags,
title,
)
from .compel import ConditioningField from .compel import ConditioningField
from .controlnet_image_processors import ControlField from .controlnet_image_processors import ControlField
from .model import ModelInfo, UNetField, VaeField from .model import ModelInfo, UNetField, VaeField
@ -65,7 +56,7 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device())
SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))] SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
@title("Create inpaint mask") @title("Create Inpaint Mask")
@tags("mask", "inpaint") @tags("mask", "inpaint")
class CreateInpaintMaskInvocation(BaseInvocation): class CreateInpaintMaskInvocation(BaseInvocation):
"""Creates mask for inpaint model run.""" """Creates mask for inpaint model run."""
@ -85,7 +76,6 @@ class CreateInpaintMaskInvocation(BaseInvocation):
def prep_mask_tensor(self, mask_image): def prep_mask_tensor(self, mask_image):
if mask_image.mode != "L": if mask_image.mode != "L":
# FIXME: why do we get passed an RGB image here? We can only use single-channel.
mask_image = mask_image.convert("L") mask_image = mask_image.convert("L")
mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False) mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
if mask_tensor.dim() == 3: if mask_tensor.dim() == 3:
@ -779,12 +769,8 @@ class ImageToLatentsInvocation(BaseInvocation):
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
# image = context.services.images.get(
# self.image.image_type, self.image.image_name
# )
image = context.services.images.get_pil_image(self.image.image_name) image = context.services.images.get_pil_image(self.image.image_name)
# vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
vae_info = context.services.model_manager.get_model( vae_info = context.services.model_manager.get_model(
**self.vae.vae.dict(), **self.vae.vae.dict(),
context=context, context=context,