feat(nodes): move fully* to new images service

* except i haven't rebuilt inpaint in latents
This commit is contained in:
psychedelicious 2023-05-24 15:50:55 +10:00 committed by Kent Keirsey
parent dd16f788ed
commit d2c223de8f
9 changed files with 273 additions and 258 deletions

View File

@ -55,16 +55,6 @@ class ApiDependencies:
os.path.join(os.path.dirname(__file__), "../../../../outputs") os.path.join(os.path.dirname(__file__), "../../../../outputs")
) )
latents = ForwardCacheLatentsStorage(
DiskLatentsStorage(f"{output_folder}/latents")
)
metadata = CoreMetadataService()
urls = LocalUrlService()
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
# TODO: build a file/path manager? # TODO: build a file/path manager?
db_location = os.path.join(output_folder, "invokeai.db") db_location = os.path.join(output_folder, "invokeai.db")
@ -72,9 +62,16 @@ class ApiDependencies:
filename=db_location, table_name="graph_executions" filename=db_location, table_name="graph_executions"
) )
urls = LocalUrlService()
metadata = CoreMetadataService()
image_record_storage = SqliteImageRecordStorage(db_location) image_record_storage = SqliteImageRecordStorage(db_location)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
images_new = ImageService( latents = ForwardCacheLatentsStorage(
DiskLatentsStorage(f"{output_folder}/latents")
)
images = ImageService(
image_record_storage=image_record_storage, image_record_storage=image_record_storage,
image_file_storage=image_file_storage, image_file_storage=image_file_storage,
metadata=metadata, metadata=metadata,
@ -87,8 +84,7 @@ class ApiDependencies:
model_manager=get_model_manager(config, logger), model_manager=get_model_manager(config, logger),
events=events, events=events,
latents=latents, latents=latents,
images=image_file_storage, images=images,
images_new=images_new,
queue=MemoryInvocationQueue(), queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph]( graph_library=SqliteItemStorage[LibraryGraph](
filename=db_location, table_name="graphs" filename=db_location, table_name="graphs"

View File

@ -45,7 +45,7 @@ async def upload_image(
raise HTTPException(status_code=415, detail="Failed to read image") raise HTTPException(status_code=415, detail="Failed to read image")
try: try:
image_dto = ApiDependencies.invoker.services.images_new.create( image_dto = ApiDependencies.invoker.services.images.create(
pil_image, pil_image,
image_type, image_type,
image_category, image_category,
@ -67,7 +67,7 @@ async def delete_image(
"""Deletes an image""" """Deletes an image"""
try: try:
ApiDependencies.invoker.services.images_new.delete(image_type, image_name) ApiDependencies.invoker.services.images.delete(image_type, image_name)
except Exception as e: except Exception as e:
# TODO: Does this need any exception handling at all? # TODO: Does this need any exception handling at all?
pass pass
@ -85,7 +85,7 @@ async def get_image_metadata(
"""Gets an image's metadata""" """Gets an image's metadata"""
try: try:
return ApiDependencies.invoker.services.images_new.get_dto( return ApiDependencies.invoker.services.images.get_dto(
image_type, image_name image_type, image_name
) )
except Exception as e: except Exception as e:
@ -113,11 +113,11 @@ async def get_image_full(
"""Gets a full-resolution image file""" """Gets a full-resolution image file"""
try: try:
path = ApiDependencies.invoker.services.images_new.get_path( path = ApiDependencies.invoker.services.images.get_path(
image_type, image_name image_type, image_name
) )
if not ApiDependencies.invoker.services.images_new.validate_path(path): if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404) raise HTTPException(status_code=404)
return FileResponse( return FileResponse(
@ -149,10 +149,10 @@ async def get_image_thumbnail(
"""Gets a thumbnail image file""" """Gets a thumbnail image file"""
try: try:
path = ApiDependencies.invoker.services.images_new.get_path( path = ApiDependencies.invoker.services.images.get_path(
image_type, image_name, thumbnail=True image_type, image_name, thumbnail=True
) )
if not ApiDependencies.invoker.services.images_new.validate_path(path): if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404) raise HTTPException(status_code=404)
return FileResponse( return FileResponse(
@ -174,10 +174,10 @@ async def get_image_urls(
"""Gets an image and thumbnail URL""" """Gets an image and thumbnail URL"""
try: try:
image_url = ApiDependencies.invoker.services.images_new.get_url( image_url = ApiDependencies.invoker.services.images.get_url(
image_type, image_name image_type, image_name
) )
thumbnail_url = ApiDependencies.invoker.services.images_new.get_url( thumbnail_url = ApiDependencies.invoker.services.images.get_url(
image_type, image_name, thumbnail=True image_type, image_name, thumbnail=True
) )
return ImageUrlsDTO( return ImageUrlsDTO(
@ -205,7 +205,7 @@ async def list_images_with_metadata(
) -> PaginatedResults[ImageDTO]: ) -> PaginatedResults[ImageDTO]:
"""Gets a list of images with metadata""" """Gets a list of images with metadata"""
image_dtos = ApiDependencies.invoker.services.images_new.get_many( image_dtos = ApiDependencies.invoker.services.images.get_many(
image_type, image_type,
image_category, image_category,
page, page,

View File

@ -7,9 +7,9 @@ import numpy
from PIL import Image, ImageOps from PIL import Image, ImageOps
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageField, ImageType from invokeai.app.models.image import ImageCategory, ImageField, ImageType
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output from .image import ImageOutput
class CvInvocationConfig(BaseModel): class CvInvocationConfig(BaseModel):
@ -26,24 +26,27 @@ class CvInvocationConfig(BaseModel):
class CvInpaintInvocation(BaseInvocation, CvInvocationConfig): class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
"""Simple inpaint using opencv.""" """Simple inpaint using opencv."""
#fmt: off
# fmt: off
type: Literal["cv_inpaint"] = "cv_inpaint" type: Literal["cv_inpaint"] = "cv_inpaint"
# Inputs # Inputs
image: ImageField = Field(default=None, description="The image to inpaint") image: ImageField = Field(default=None, description="The image to inpaint")
mask: ImageField = Field(default=None, description="The mask to use when inpainting") mask: ImageField = Field(default=None, description="The mask to use when inpainting")
#fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_type, self.image.image_name
) )
mask = context.services.images.get(self.mask.image_type, self.mask.image_name) mask = context.services.images.get_pil_image(
self.mask.image_type, self.mask.image_name
)
# Convert to cv image/mask # Convert to cv image/mask
# TODO: consider making these utility functions # TODO: consider making these utility functions
cv_image = cv.cvtColor(numpy.array(image.convert("RGB")), cv.COLOR_RGB2BGR) cv_image = cv.cvtColor(numpy.array(image.convert("RGB")), cv.COLOR_RGB2BGR)
cv_mask = numpy.array(ImageOps.invert(mask)) cv_mask = numpy.array(ImageOps.invert(mask.convert("L")))
# Inpaint # Inpaint
cv_inpainted = cv.inpaint(cv_image, cv_mask, 3, cv.INPAINT_TELEA) cv_inpainted = cv.inpaint(cv_image, cv_mask, 3, cv.INPAINT_TELEA)
@ -52,18 +55,19 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
# TODO: consider making a utility function # TODO: consider making a utility function
image_inpainted = Image.fromarray(cv.cvtColor(cv_inpainted, cv.COLOR_BGR2RGB)) image_inpainted = Image.fromarray(cv.cvtColor(cv_inpainted, cv.COLOR_BGR2RGB))
image_type = ImageType.INTERMEDIATE image_dto = context.services.images.create(
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, image_inpainted, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image_inpainted, image=image_inpainted,
image_type=ImageType.INTERMEDIATE,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
),
width=image_dto.width,
height=image_dto.height,
) )

View File

@ -1,13 +1,13 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import io import io
from typing import Literal, Optional from typing import Literal, Optional, Union
import numpy import numpy
from PIL import Image, ImageFilter, ImageOps from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from ..models.image import ImageField, ImageType from ..models.image import ImageCategory, ImageField, ImageType
from .baseinvocation import ( from .baseinvocation import (
BaseInvocation, BaseInvocation,
BaseInvocationOutput, BaseInvocationOutput,
@ -41,27 +41,14 @@ class ImageOutput(BaseInvocationOutput):
schema_extra = {"required": ["type", "image", "width", "height"]} schema_extra = {"required": ["type", "image", "width", "height"]}
def build_image_output(
image_type: ImageType, image_name: str, image: Image.Image
) -> ImageOutput:
"""Builds an ImageOutput and its ImageField"""
image_field = ImageField(
image_name=image_name,
image_type=image_type,
)
return ImageOutput(
image=image_field,
width=image.width,
height=image.height,
)
class MaskOutput(BaseInvocationOutput): class MaskOutput(BaseInvocationOutput):
"""Base class for invocations that output a mask""" """Base class for invocations that output a mask"""
# fmt: off # fmt: off
type: Literal["mask"] = "mask" type: Literal["mask"] = "mask"
mask: ImageField = Field(default=None, description="The output mask") mask: ImageField = Field(default=None, description="The output mask")
width: int = Field(description="The width of the mask in pixels")
height: int = Field(description="The height of the mask in pixels")
# fmt: on # fmt: on
class Config: class Config:
@ -84,12 +71,15 @@ class LoadImageInvocation(BaseInvocation):
image_name: str = Field(description="The name of the image") image_name: str = Field(description="The name of the image")
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(self.image_type, self.image_name) image = context.services.images.get_pil_image(self.image_type, self.image_name)
return build_image_output( return ImageOutput(
image_type=self.image_type, image=ImageField(
image_name=self.image_name, image_name=self.image_name,
image=image, image_type=self.image_type,
),
width=image.width,
height=image.height,
) )
@ -99,10 +89,12 @@ class ShowImageInvocation(BaseInvocation):
type: Literal["show_image"] = "show_image" type: Literal["show_image"] = "show_image"
# Inputs # Inputs
image: ImageField = Field(default=None, description="The image to show") image: Union[ImageField, None] = Field(
default=None, description="The image to show"
)
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_type, self.image.image_name
) )
if image: if image:
@ -110,10 +102,13 @@ class ShowImageInvocation(BaseInvocation):
# TODO: how to handle failure? # TODO: how to handle failure?
return build_image_output( return ImageOutput(
image_type=self.image.image_type, image=ImageField(
image_name=self.image.image_name, image_name=self.image.image_name,
image=image, image_type=self.image.image_type,
),
width=image.width,
height=image.height,
) )
@ -124,7 +119,7 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["crop"] = "crop" type: Literal["crop"] = "crop"
# Inputs # Inputs
image: ImageField = Field(default=None, description="The image to crop") image: Union[ImageField, None] = Field(default=None, description="The image to crop")
x: int = Field(default=0, description="The left x coordinate of the crop rectangle") x: int = Field(default=0, description="The left x coordinate of the crop rectangle")
y: int = Field(default=0, description="The top y coordinate of the crop rectangle") y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
width: int = Field(default=512, gt=0, description="The width of the crop rectangle") width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
@ -132,7 +127,7 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_type, self.image.image_name
) )
@ -141,20 +136,21 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
) )
image_crop.paste(image, (-self.x, -self.y)) image_crop.paste(image, (-self.x, -self.y))
image_type = ImageType.INTERMEDIATE image_dto = context.services.images.create(
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, image_crop, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image_crop, image=image_crop,
image_type=ImageType.INTERMEDIATE,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
),
width=image_dto.width,
height=image_dto.height,
) )
@ -165,25 +161,27 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["paste"] = "paste" type: Literal["paste"] = "paste"
# Inputs # Inputs
base_image: ImageField = Field(default=None, description="The base image") base_image: Union[ImageField, None] = Field(default=None, description="The base image")
image: ImageField = Field(default=None, description="The image to paste") image: Union[ImageField, None] = Field(default=None, description="The image to paste")
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting") mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
x: int = Field(default=0, description="The left x coordinate at which to paste the image") x: int = Field(default=0, description="The left x coordinate at which to paste the image")
y: int = Field(default=0, description="The top y coordinate at which to paste the image") y: int = Field(default=0, description="The top y coordinate at which to paste the image")
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
base_image = context.services.images.get( base_image = context.services.images.get_pil_image(
self.base_image.image_type, self.base_image.image_name self.base_image.image_type, self.base_image.image_name
) )
image = context.services.images.get( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_type, self.image.image_name
) )
mask = ( mask = (
None None
if self.mask is None if self.mask is None
else ImageOps.invert( else ImageOps.invert(
context.services.images.get(self.mask.image_type, self.mask.image_name) context.services.images.get_pil_image(
self.mask.image_type, self.mask.image_name
)
) )
) )
# TODO: probably shouldn't invert mask here... should user be required to do it? # TODO: probably shouldn't invert mask here... should user be required to do it?
@ -199,20 +197,21 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
new_image.paste(base_image, (abs(min_x), abs(min_y))) new_image.paste(base_image, (abs(min_x), abs(min_y)))
new_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask) new_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask)
image_type = ImageType.RESULT image_dto = context.services.images.create(
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, new_image, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=new_image, image=new_image,
image_type=ImageType.RESULT,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
),
width=image_dto.width,
height=image_dto.height,
) )
@ -223,12 +222,12 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["tomask"] = "tomask" type: Literal["tomask"] = "tomask"
# Inputs # Inputs
image: ImageField = Field(default=None, description="The image to create the mask from") image: Union[ImageField, None] = Field(default=None, description="The image to create the mask from")
invert: bool = Field(default=False, description="Whether or not to invert the mask") invert: bool = Field(default=False, description="Whether or not to invert the mask")
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> MaskOutput: def invoke(self, context: InvocationContext) -> MaskOutput:
image = context.services.images.get( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_type, self.image.image_name
) )
@ -236,18 +235,22 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
if self.invert: if self.invert:
image_mask = ImageOps.invert(image_mask) image_mask = ImageOps.invert(image_mask)
image_type = ImageType.INTERMEDIATE image_dto = context.services.images.create(
image_name = context.services.images.create_name( image=image_mask,
context.graph_execution_state_id, self.id image_type=ImageType.INTERMEDIATE,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
) )
metadata = context.services.metadata.build_metadata( return MaskOutput(
session_id=context.graph_execution_state_id, node=self mask=ImageField(
image_type=image_dto.image_type, image_name=image_dto.image_name
),
width=image_dto.width,
height=image_dto.height,
) )
context.services.images.save(image_type, image_name, image_mask, metadata)
return MaskOutput(mask=ImageField(image_type=image_type, image_name=image_name))
class BlurInvocation(BaseInvocation, PILInvocationConfig): class BlurInvocation(BaseInvocation, PILInvocationConfig):
"""Blurs an image""" """Blurs an image"""
@ -256,13 +259,13 @@ class BlurInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["blur"] = "blur" type: Literal["blur"] = "blur"
# Inputs # Inputs
image: ImageField = Field(default=None, description="The image to blur") image: Union[ImageField, None] = Field(default=None, description="The image to blur")
radius: float = Field(default=8.0, ge=0, description="The blur radius") radius: float = Field(default=8.0, ge=0, description="The blur radius")
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur") blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_type, self.image.image_name
) )
@ -273,18 +276,21 @@ class BlurInvocation(BaseInvocation, PILInvocationConfig):
) )
blur_image = image.filter(blur) blur_image = image.filter(blur)
image_type = ImageType.INTERMEDIATE image_dto = context.services.images.create(
image_name = context.services.images.create_name( image=blur_image,
context.graph_execution_state_id, self.id image_type=ImageType.INTERMEDIATE,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
) )
metadata = context.services.metadata.build_metadata( return ImageOutput(
session_id=context.graph_execution_state_id, node=self image=ImageField(
) image_name=image_dto.image_name,
image_type=image_dto.image_type,
context.services.images.save(image_type, image_name, blur_image, metadata) ),
return build_image_output( width=image_dto.width,
image_type=image_type, image_name=image_name, image=blur_image height=image_dto.height,
) )
@ -295,13 +301,13 @@ class LerpInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["lerp"] = "lerp" type: Literal["lerp"] = "lerp"
# Inputs # Inputs
image: ImageField = Field(default=None, description="The image to lerp") image: Union[ImageField, None] = Field(default=None, description="The image to lerp")
min: int = Field(default=0, ge=0, le=255, description="The minimum output value") min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
max: int = Field(default=255, ge=0, le=255, description="The maximum output value") max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_type, self.image.image_name
) )
@ -310,18 +316,21 @@ class LerpInvocation(BaseInvocation, PILInvocationConfig):
lerp_image = Image.fromarray(numpy.uint8(image_arr)) lerp_image = Image.fromarray(numpy.uint8(image_arr))
image_type = ImageType.INTERMEDIATE image_dto = context.services.images.create(
image_name = context.services.images.create_name( image=lerp_image,
context.graph_execution_state_id, self.id image_type=ImageType.INTERMEDIATE,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
) )
metadata = context.services.metadata.build_metadata( return ImageOutput(
session_id=context.graph_execution_state_id, node=self image=ImageField(
) image_name=image_dto.image_name,
image_type=image_dto.image_type,
context.services.images.save(image_type, image_name, lerp_image, metadata) ),
return build_image_output( width=image_dto.width,
image_type=image_type, image_name=image_name, image=lerp_image height=image_dto.height,
) )
@ -332,13 +341,13 @@ class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["ilerp"] = "ilerp" type: Literal["ilerp"] = "ilerp"
# Inputs # Inputs
image: ImageField = Field(default=None, description="The image to lerp") image: Union[ImageField, None] = Field(default=None, description="The image to lerp")
min: int = Field(default=0, ge=0, le=255, description="The minimum input value") min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
max: int = Field(default=255, ge=0, le=255, description="The maximum input value") max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_type, self.image.image_name
) )
@ -352,16 +361,19 @@ class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
ilerp_image = Image.fromarray(numpy.uint8(image_arr)) ilerp_image = Image.fromarray(numpy.uint8(image_arr))
image_type = ImageType.INTERMEDIATE image_dto = context.services.images.create(
image_name = context.services.images.create_name( image=ilerp_image,
context.graph_execution_state_id, self.id image_type=ImageType.INTERMEDIATE,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
) )
metadata = context.services.metadata.build_metadata( return ImageOutput(
session_id=context.graph_execution_state_id, node=self image=ImageField(
) image_name=image_dto.image_name,
image_type=image_dto.image_type,
context.services.images.save(image_type, image_name, ilerp_image, metadata) ),
return build_image_output( width=image_dto.width,
image_type=image_type, image_name=image_name, image=ilerp_image height=image_dto.height,
) )

View File

@ -1,17 +1,17 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from typing import Literal, Optional, Union, get_args from typing import Literal, Union, get_args
import numpy as np import numpy as np
import math import math
from PIL import Image, ImageOps from PIL import Image, ImageOps
from pydantic import Field from pydantic import Field
from invokeai.app.invocations.image import ImageOutput, build_image_output from invokeai.app.invocations.image import ImageOutput
from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.image_util.patchmatch import PatchMatch from invokeai.backend.image_util.patchmatch import PatchMatch
from ..models.image import ColorField, ImageField, ImageType from ..models.image import ColorField, ImageCategory, ImageField, ImageType
from .baseinvocation import ( from .baseinvocation import (
BaseInvocation, BaseInvocation,
InvocationContext, InvocationContext,
@ -125,36 +125,39 @@ class InfillColorInvocation(BaseInvocation):
"""Infills transparent areas of an image with a solid color""" """Infills transparent areas of an image with a solid color"""
type: Literal["infill_rgba"] = "infill_rgba" type: Literal["infill_rgba"] = "infill_rgba"
image: Optional[ImageField] = Field(default=None, description="The image to infill") image: Union[ImageField, None] = Field(
color: Optional[ColorField] = Field( default=None, description="The image to infill"
)
color: ColorField = Field(
default=ColorField(r=127, g=127, b=127, a=255), default=ColorField(r=127, g=127, b=127, a=255),
description="The color to use to infill", description="The color to use to infill",
) )
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_type, self.image.image_name
) )
solid_bg = Image.new("RGBA", image.size, self.color.tuple()) solid_bg = Image.new("RGBA", image.size, self.color.tuple())
infilled = Image.alpha_composite(solid_bg, image) infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
infilled.paste(image, (0, 0), image.split()[-1]) infilled.paste(image, (0, 0), image.split()[-1])
image_type = ImageType.RESULT image_dto = context.services.images.create(
image_name = context.services.images.create_name( image=infilled,
context.graph_execution_state_id, self.id image_type=ImageType.RESULT,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
) )
metadata = context.services.metadata.build_metadata( return ImageOutput(
session_id=context.graph_execution_state_id, node=self image=ImageField(
) image_name=image_dto.image_name,
image_type=image_dto.image_type,
context.services.images.save(image_type, image_name, infilled, metadata) ),
return build_image_output( width=image_dto.width,
image_type=image_type, height=image_dto.height,
image_name=image_name,
image=image,
) )
@ -163,7 +166,9 @@ class InfillTileInvocation(BaseInvocation):
type: Literal["infill_tile"] = "infill_tile" type: Literal["infill_tile"] = "infill_tile"
image: Optional[ImageField] = Field(default=None, description="The image to infill") image: Union[ImageField, None] = Field(
default=None, description="The image to infill"
)
tile_size: int = Field(default=32, ge=1, description="The tile size (px)") tile_size: int = Field(default=32, ge=1, description="The tile size (px)")
seed: int = Field( seed: int = Field(
ge=0, ge=0,
@ -173,7 +178,7 @@ class InfillTileInvocation(BaseInvocation):
) )
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_type, self.image.image_name
) )
@ -182,20 +187,21 @@ class InfillTileInvocation(BaseInvocation):
) )
infilled.paste(image, (0, 0), image.split()[-1]) infilled.paste(image, (0, 0), image.split()[-1])
image_type = ImageType.RESULT image_dto = context.services.images.create(
image_name = context.services.images.create_name( image=infilled,
context.graph_execution_state_id, self.id image_type=ImageType.RESULT,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
) )
metadata = context.services.metadata.build_metadata( return ImageOutput(
session_id=context.graph_execution_state_id, node=self image=ImageField(
) image_name=image_dto.image_name,
image_type=image_dto.image_type,
context.services.images.save(image_type, image_name, infilled, metadata) ),
return build_image_output( width=image_dto.width,
image_type=image_type, height=image_dto.height,
image_name=image_name,
image=image,
) )
@ -204,10 +210,12 @@ class InfillPatchMatchInvocation(BaseInvocation):
type: Literal["infill_patchmatch"] = "infill_patchmatch" type: Literal["infill_patchmatch"] = "infill_patchmatch"
image: Optional[ImageField] = Field(default=None, description="The image to infill") image: Union[ImageField, None] = Field(
default=None, description="The image to infill"
)
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_type, self.image.image_name
) )
@ -216,18 +224,19 @@ class InfillPatchMatchInvocation(BaseInvocation):
else: else:
raise ValueError("PatchMatch is not available on this system") raise ValueError("PatchMatch is not available on this system")
image_type = ImageType.RESULT image_dto = context.services.images.create(
image_name = context.services.images.create_name( image=infilled,
context.graph_execution_state_id, self.id image_type=ImageType.RESULT,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
) )
metadata = context.services.metadata.build_metadata( return ImageOutput(
session_id=context.graph_execution_state_id, node=self image=ImageField(
) image_name=image_dto.image_name,
image_type=image_dto.image_type,
context.services.images.save(image_type, image_name, infilled, metadata) ),
return build_image_output( width=image_dto.width,
image_type=image_type, height=image_dto.height,
image_name=image_name,
image=image,
) )

View File

@ -3,7 +3,7 @@
import random import random
from typing import Literal, Optional, Union from typing import Literal, Optional, Union
import einops import einops
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, validator
import torch import torch
from invokeai.app.invocations.util.choose_model import choose_model from invokeai.app.invocations.util.choose_model import choose_model
@ -23,7 +23,7 @@ from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationCont
import numpy as np import numpy as np
from ..services.image_file_storage import ImageType from ..services.image_file_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput, build_image_output from .image import ImageField, ImageOutput
from .compel import ConditioningField from .compel import ConditioningField
from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion import PipelineIntermediateState
from diffusers.schedulers import SchedulerMixin as Scheduler from diffusers.schedulers import SchedulerMixin as Scheduler
@ -362,19 +362,9 @@ class LatentsToImageInvocation(BaseInvocation):
np_image = model.decode_latents(latents) np_image = model.decode_latents(latents)
image = model.numpy_to_pil(np_image)[0] image = model.numpy_to_pil(np_image)[0]
# image_type = ImageType.RESULT torch.cuda.empty_cache()
# image_name = context.services.images.create_name(
# context.graph_execution_state_id, self.id
# )
# metadata = context.services.metadata.build_metadata( image_dto = context.services.images.create(
# session_id=context.graph_execution_state_id, node=self
# )
# torch.cuda.empty_cache()
# context.services.images.save(image_type, image_name, image, metadata)
image_dto = context.services.images_new.create(
image=image, image=image,
image_type=ImageType.RESULT, image_type=ImageType.RESULT,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
@ -382,10 +372,13 @@ class LatentsToImageInvocation(BaseInvocation):
node_id=self.id, node_id=self.id,
) )
return build_image_output( return ImageOutput(
image_type=image_dto.image_type, image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image=image, image_type=image_dto.image_type,
),
width=image_dto.width,
height=image_dto.height,
) )
@ -474,7 +467,7 @@ class ImageToLatentsInvocation(BaseInvocation):
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.services.images.get( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_type, self.image.image_name
) )
@ -496,3 +489,4 @@ class ImageToLatentsInvocation(BaseInvocation):
name = f"{context.graph_execution_state_id}__{self.id}" name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.save(name, latents) context.services.latents.save(name, latents)
return build_latents_output(latents_name=name, latents=latents) return build_latents_output(latents_name=name, latents=latents)

View File

@ -2,20 +2,22 @@ from typing import Literal, Union
from pydantic import Field from pydantic import Field
from invokeai.app.models.image import ImageField, ImageType from invokeai.app.models.image import ImageCategory, ImageField, ImageType
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output from .image import ImageOutput
class RestoreFaceInvocation(BaseInvocation): class RestoreFaceInvocation(BaseInvocation):
"""Restores faces in an image.""" """Restores faces in an image."""
#fmt: off
# fmt: off
type: Literal["restore_face"] = "restore_face" type: Literal["restore_face"] = "restore_face"
# Inputs # Inputs
image: Union[ImageField, None] = Field(description="The input image") image: Union[ImageField, None] = Field(description="The input image")
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" ) strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" )
#fmt: on # fmt: on
# Schema customisation # Schema customisation
class Config(InvocationConfig): class Config(InvocationConfig):
@ -26,7 +28,7 @@ class RestoreFaceInvocation(BaseInvocation):
} }
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_type, self.image.image_name
) )
results = context.services.restoration.upscale_and_reconstruct( results = context.services.restoration.upscale_and_reconstruct(
@ -39,18 +41,19 @@ class RestoreFaceInvocation(BaseInvocation):
# Results are image and seed, unwrap for now # Results are image and seed, unwrap for now
# TODO: can this return multiple results? # TODO: can this return multiple results?
image_type = ImageType.RESULT image_dto = context.services.images.create(
image_name = context.services.images.create_name( image=results[0][0],
context.graph_execution_state_id, self.id image_type=ImageType.INTERMEDIATE,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
) )
metadata = context.services.metadata.build_metadata( return ImageOutput(
session_id=context.graph_execution_state_id, node=self image=ImageField(
) image_name=image_dto.image_name,
image_type=image_dto.image_type,
context.services.images.save(image_type, image_name, results[0][0], metadata) ),
return build_image_output( width=image_dto.width,
image_type=image_type, height=image_dto.height,
image_name=image_name,
image=results[0][0]
) )

View File

@ -4,22 +4,22 @@ from typing import Literal, Union
from pydantic import Field from pydantic import Field
from invokeai.app.models.image import ImageField, ImageType from invokeai.app.models.image import ImageCategory, ImageField, ImageType
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output from .image import ImageOutput
class UpscaleInvocation(BaseInvocation): class UpscaleInvocation(BaseInvocation):
"""Upscales an image.""" """Upscales an image."""
#fmt: off
# fmt: off
type: Literal["upscale"] = "upscale" type: Literal["upscale"] = "upscale"
# Inputs # Inputs
image: Union[ImageField, None] = Field(description="The input image", default=None) image: Union[ImageField, None] = Field(description="The input image", default=None)
strength: float = Field(default=0.75, gt=0, le=1, description="The strength") strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
level: Literal[2, 4] = Field(default=2, description="The upscale level") level: Literal[2, 4] = Field(default=2, description="The upscale level")
#fmt: on # fmt: on
# Schema customisation # Schema customisation
class Config(InvocationConfig): class Config(InvocationConfig):
@ -30,7 +30,7 @@ class UpscaleInvocation(BaseInvocation):
} }
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_type, self.image.image_name
) )
results = context.services.restoration.upscale_and_reconstruct( results = context.services.restoration.upscale_and_reconstruct(
@ -43,18 +43,19 @@ class UpscaleInvocation(BaseInvocation):
# Results are image and seed, unwrap for now # Results are image and seed, unwrap for now
# TODO: can this return multiple results? # TODO: can this return multiple results?
image_type = ImageType.RESULT image_dto = context.services.images.create(
image_name = context.services.images.create_name( image=results[0][0],
context.graph_execution_state_id, self.id image_type=ImageType.RESULT,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
) )
metadata = context.services.metadata.build_metadata( return ImageOutput(
session_id=context.graph_execution_state_id, node=self image=ImageField(
) image_name=image_dto.image_name,
image_type=image_dto.image_type,
context.services.images.save(image_type, image_name, results[0][0], metadata) ),
return build_image_output( width=image_dto.width,
image_type=image_type, height=image_dto.height,
image_name=image_name,
image=results[0][0]
) )

View File

@ -6,7 +6,6 @@ from invokeai.app.services.images import ImageService
from invokeai.backend import ModelManager from invokeai.backend import ModelManager
from .events import EventServiceBase from .events import EventServiceBase
from .latent_storage import LatentsStorageBase from .latent_storage import LatentsStorageBase
from .image_file_storage import ImageFileStorageBase
from .restoration_services import RestorationServices from .restoration_services import RestorationServices
from .invocation_queue import InvocationQueueABC from .invocation_queue import InvocationQueueABC
from .item_storage import ItemStorageABC from .item_storage import ItemStorageABC
@ -23,12 +22,11 @@ class InvocationServices:
events: EventServiceBase events: EventServiceBase
latents: LatentsStorageBase latents: LatentsStorageBase
images: ImageFileStorageBase
queue: InvocationQueueABC queue: InvocationQueueABC
model_manager: ModelManager model_manager: ModelManager
restoration: RestorationServices restoration: RestorationServices
configuration: InvokeAISettings configuration: InvokeAISettings
images_new: ImageService images: ImageService
# NOTE: we must forward-declare any types that include invocations, since invocations can use services # NOTE: we must forward-declare any types that include invocations, since invocations can use services
graph_library: ItemStorageABC["LibraryGraph"] graph_library: ItemStorageABC["LibraryGraph"]
@ -41,9 +39,8 @@ class InvocationServices:
events: EventServiceBase, events: EventServiceBase,
logger: Logger, logger: Logger,
latents: LatentsStorageBase, latents: LatentsStorageBase,
images: ImageFileStorageBase, images: ImageService,
queue: InvocationQueueABC, queue: InvocationQueueABC,
images_new: ImageService,
graph_library: ItemStorageABC["LibraryGraph"], graph_library: ItemStorageABC["LibraryGraph"],
graph_execution_manager: ItemStorageABC["GraphExecutionState"], graph_execution_manager: ItemStorageABC["GraphExecutionState"],
processor: "InvocationProcessorABC", processor: "InvocationProcessorABC",
@ -56,7 +53,6 @@ class InvocationServices:
self.latents = latents self.latents = latents
self.images = images self.images = images
self.queue = queue self.queue = queue
self.images_new = images_new
self.graph_library = graph_library self.graph_library = graph_library
self.graph_execution_manager = graph_execution_manager self.graph_execution_manager = graph_execution_manager
self.processor = processor self.processor = processor