fix(nodes): fix conflicts with controlnet

This commit is contained in:
psychedelicious 2023-05-27 21:55:29 +10:00 committed by Kent Keirsey
parent 29fcc92da9
commit 08a14ee6d5
3 changed files with 8 additions and 8 deletions

View File

@ -7,7 +7,7 @@ from typing import Literal, Optional, Union, List
from PIL import Image, ImageFilter, ImageOps from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from ..models.image import ImageField, ImageType, ImageCategory from ..models.image import ImageField, ImageCategory, ResourceOrigin
from .baseinvocation import ( from .baseinvocation import (
BaseInvocation, BaseInvocation,
BaseInvocationOutput, BaseInvocationOutput,
@ -163,7 +163,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
raw_image = context.services.images.get_pil_image( raw_image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
# image type should be PIL.PngImagePlugin.PngImageFile ? # image type should be PIL.PngImagePlugin.PngImageFile ?
processed_image = self.run_processor(raw_image) processed_image = self.run_processor(raw_image)
@ -177,8 +177,8 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
# so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery # so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=processed_image, image=processed_image,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.CONTROL,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
node_id=self.id, node_id=self.id,
is_intermediate=self.is_intermediate is_intermediate=self.is_intermediate
@ -187,7 +187,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
"""Builds an ImageOutput and its ImageField""" """Builds an ImageOutput and its ImageField"""
processed_image_field = ImageField( processed_image_field = ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
) )
return ImageOutput( return ImageOutput(
image=processed_image_field, image=processed_image_field,

View File

@ -86,8 +86,8 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
# loading controlnet image (currently requires pre-processed image) # loading controlnet image (currently requires pre-processed image)
control_image = ( control_image = (
None if self.control_image is None None if self.control_image is None
else context.services.images.get( else context.services.images.get_pil_image(
self.control_image.image_type, self.control_image.image_name self.control_image.image_origin, self.control_image.image_name
) )
) )
# loading controlnet model # loading controlnet model

View File

@ -297,7 +297,7 @@ class TextToLatentsInvocation(BaseInvocation):
torch_dtype=model.unet.dtype).to(model.device) torch_dtype=model.unet.dtype).to(model.device)
control_models.append(control_model) control_models.append(control_model)
control_image_field = control_info.image control_image_field = control_info.image
input_image = context.services.images.get_pil_image(control_image_field.image_type, input_image = context.services.images.get_pil_image(control_image_field.image_origin,
control_image_field.image_name) control_image_field.image_name)
# self.image.image_type, self.image.image_name # self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes # FIXME: still need to test with different widths, heights, devices, dtypes