fix(upscale_sdx4): upgrade for v3.1 nodes API

This commit is contained in:
Kevin Turner 2023-09-16 14:21:29 -07:00
parent caf52cfcce
commit e06024d8ed

View File

@ -1,31 +1,29 @@
from typing import Literal, Union, List
from typing import List, Union
import torch
from diffusers import StableDiffusionUpscalePipeline
from invokeai.app.invocations.baseinvocation import (
InvocationContext,
title,
tags,
InputField,
FieldDescriptions,
BaseInvocation,
FieldDescriptions,
Input,
InputField,
InvocationContext,
UIType,
invocation,
)
from invokeai.app.invocations.image import ImageOutput
from invokeai.app.invocations.latent import get_scheduler, SAMPLER_NAME_VALUES
from invokeai.app.invocations.latent import SAMPLER_NAME_VALUES, get_scheduler
from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.invocations.model import VaeField, UNetField
from invokeai.app.invocations.primitives import ImageField, ConditioningField
from invokeai.app.models.image import ResourceOrigin, ImageCategory
from invokeai.app.invocations.model import UNetField, VaeField
from invokeai.app.invocations.primitives import ConditioningField, ImageField
from invokeai.app.models.image import ImageCategory, ResourceOrigin
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend import BaseModelType
from invokeai.backend.stable_diffusion import PipelineIntermediateState, ConditioningData, PostprocessingSettings
from invokeai.backend.stable_diffusion import ConditioningData, PipelineIntermediateState, PostprocessingSettings
@title("Upscale (Stable Diffusion x4)")
@tags("upscale")
@invocation("upscale_sdx4", title="Upscale (Stable Diffusion x4)", tags=["upscale"], version="0.1.0")
class UpscaleLatentsInvocation(BaseInvocation):
"""Upscales an image using an upscaling diffusion model.
@ -35,8 +33,6 @@ class UpscaleLatentsInvocation(BaseInvocation):
models. We don't have ControlNet or LoRA support for it. It has its own VAE.
"""
type: Literal["upscale_sdx4"] = "upscale_sdx4"
# Inputs
image: ImageField = InputField(description="The image to upscale")
@ -58,8 +54,6 @@ class UpscaleLatentsInvocation(BaseInvocation):
metadata: CoreMetadata = InputField(default=None, description=FieldDescriptions.core_metadata)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
# TODO: fp32: bool = Field(DEFAULT_PRECISION=='float32', description="Decode in full precision")
@torch.inference_mode()
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
@ -113,6 +107,7 @@ class UpscaleLatentsInvocation(BaseInvocation):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
workflow=self.workflow,
)
return ImageOutput(