mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(upscale_sdx4): upgrade for v3.1 nodes API
This commit is contained in:
parent
caf52cfcce
commit
e06024d8ed
@ -1,31 +1,29 @@
|
|||||||
from typing import Literal, Union, List
|
from typing import List, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers import StableDiffusionUpscalePipeline
|
from diffusers import StableDiffusionUpscalePipeline
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import (
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
InvocationContext,
|
|
||||||
title,
|
|
||||||
tags,
|
|
||||||
InputField,
|
|
||||||
FieldDescriptions,
|
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
|
FieldDescriptions,
|
||||||
Input,
|
Input,
|
||||||
|
InputField,
|
||||||
|
InvocationContext,
|
||||||
UIType,
|
UIType,
|
||||||
|
invocation,
|
||||||
)
|
)
|
||||||
from invokeai.app.invocations.image import ImageOutput
|
from invokeai.app.invocations.image import ImageOutput
|
||||||
from invokeai.app.invocations.latent import get_scheduler, SAMPLER_NAME_VALUES
|
from invokeai.app.invocations.latent import SAMPLER_NAME_VALUES, get_scheduler
|
||||||
from invokeai.app.invocations.metadata import CoreMetadata
|
from invokeai.app.invocations.metadata import CoreMetadata
|
||||||
from invokeai.app.invocations.model import VaeField, UNetField
|
from invokeai.app.invocations.model import UNetField, VaeField
|
||||||
from invokeai.app.invocations.primitives import ImageField, ConditioningField
|
from invokeai.app.invocations.primitives import ConditioningField, ImageField
|
||||||
from invokeai.app.models.image import ResourceOrigin, ImageCategory
|
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||||
from invokeai.backend import BaseModelType
|
from invokeai.backend import BaseModelType
|
||||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState, ConditioningData, PostprocessingSettings
|
from invokeai.backend.stable_diffusion import ConditioningData, PipelineIntermediateState, PostprocessingSettings
|
||||||
|
|
||||||
|
|
||||||
@title("Upscale (Stable Diffusion x4)")
|
@invocation("upscale_sdx4", title="Upscale (Stable Diffusion x4)", tags=["upscale"], version="0.1.0")
|
||||||
@tags("upscale")
|
|
||||||
class UpscaleLatentsInvocation(BaseInvocation):
|
class UpscaleLatentsInvocation(BaseInvocation):
|
||||||
"""Upscales an image using an upscaling diffusion model.
|
"""Upscales an image using an upscaling diffusion model.
|
||||||
|
|
||||||
@ -35,8 +33,6 @@ class UpscaleLatentsInvocation(BaseInvocation):
|
|||||||
models. We don't have ControlNet or LoRA support for it. It has its own VAE.
|
models. We don't have ControlNet or LoRA support for it. It has its own VAE.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: Literal["upscale_sdx4"] = "upscale_sdx4"
|
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: ImageField = InputField(description="The image to upscale")
|
image: ImageField = InputField(description="The image to upscale")
|
||||||
|
|
||||||
@ -58,8 +54,6 @@ class UpscaleLatentsInvocation(BaseInvocation):
|
|||||||
metadata: CoreMetadata = InputField(default=None, description=FieldDescriptions.core_metadata)
|
metadata: CoreMetadata = InputField(default=None, description=FieldDescriptions.core_metadata)
|
||||||
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
|
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
|
||||||
|
|
||||||
# TODO: fp32: bool = Field(DEFAULT_PRECISION=='float32', description="Decode in full precision")
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
@ -113,6 +107,7 @@ class UpscaleLatentsInvocation(BaseInvocation):
|
|||||||
session_id=context.graph_execution_state_id,
|
session_id=context.graph_execution_state_id,
|
||||||
is_intermediate=self.is_intermediate,
|
is_intermediate=self.is_intermediate,
|
||||||
metadata=self.metadata.dict() if self.metadata else None,
|
metadata=self.metadata.dict() if self.metadata else None,
|
||||||
|
workflow=self.workflow,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
|
Loading…
Reference in New Issue
Block a user