fix(nodes): fix conflicts with controlnet

This commit is contained in:
psychedelicious 2023-05-27 21:55:29 +10:00 committed by Kent Keirsey
parent 29fcc92da9
commit 08a14ee6d5
3 changed files with 8 additions and 8 deletions

View File

@ -7,7 +7,7 @@ from typing import Literal, Optional, Union, List
from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field
from ..models.image import ImageField, ImageType, ImageCategory
from ..models.image import ImageField, ImageCategory, ResourceOrigin
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@ -163,7 +163,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput:
raw_image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
self.image.image_origin, self.image.image_name
)
# image type should be PIL.PngImagePlugin.PngImageFile ?
processed_image = self.run_processor(raw_image)
@ -177,8 +177,8 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
# so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
image_dto = context.services.images.create(
image=processed_image,
image_type=ImageType.RESULT,
image_category=ImageCategory.GENERAL,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.CONTROL,
session_id=context.graph_execution_state_id,
node_id=self.id,
is_intermediate=self.is_intermediate
@ -187,7 +187,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
"""Builds an ImageOutput and its ImageField"""
processed_image_field = ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
image_origin=image_dto.image_origin,
)
return ImageOutput(
image=processed_image_field,

View File

@ -86,8 +86,8 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
# loading controlnet image (currently requires pre-processed image)
control_image = (
None if self.control_image is None
else context.services.images.get(
self.control_image.image_type, self.control_image.image_name
else context.services.images.get_pil_image(
self.control_image.image_origin, self.control_image.image_name
)
)
# loading controlnet model

View File

@ -297,7 +297,7 @@ class TextToLatentsInvocation(BaseInvocation):
torch_dtype=model.unet.dtype).to(model.device)
control_models.append(control_model)
control_image_field = control_info.image
input_image = context.services.images.get_pil_image(control_image_field.image_type,
input_image = context.services.images.get_pil_image(control_image_field.image_origin,
control_image_field.image_name)
# self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes