feat(nodes): move fully* to new images service

* except I haven't rebuilt inpaint in latents
This commit is contained in:
psychedelicious 2023-05-24 15:50:55 +10:00 committed by Kent Keirsey
parent dd16f788ed
commit d2c223de8f
9 changed files with 273 additions and 258 deletions

View File

@ -55,16 +55,6 @@ class ApiDependencies:
os.path.join(os.path.dirname(__file__), "../../../../outputs")
)
latents = ForwardCacheLatentsStorage(
DiskLatentsStorage(f"{output_folder}/latents")
)
metadata = CoreMetadataService()
urls = LocalUrlService()
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
# TODO: build a file/path manager?
db_location = os.path.join(output_folder, "invokeai.db")
@ -72,9 +62,16 @@ class ApiDependencies:
filename=db_location, table_name="graph_executions"
)
urls = LocalUrlService()
metadata = CoreMetadataService()
image_record_storage = SqliteImageRecordStorage(db_location)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
images_new = ImageService(
latents = ForwardCacheLatentsStorage(
DiskLatentsStorage(f"{output_folder}/latents")
)
images = ImageService(
image_record_storage=image_record_storage,
image_file_storage=image_file_storage,
metadata=metadata,
@ -87,8 +84,7 @@ class ApiDependencies:
model_manager=get_model_manager(config, logger),
events=events,
latents=latents,
images=image_file_storage,
images_new=images_new,
images=images,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=db_location, table_name="graphs"

View File

@ -45,7 +45,7 @@ async def upload_image(
raise HTTPException(status_code=415, detail="Failed to read image")
try:
image_dto = ApiDependencies.invoker.services.images_new.create(
image_dto = ApiDependencies.invoker.services.images.create(
pil_image,
image_type,
image_category,
@ -67,7 +67,7 @@ async def delete_image(
"""Deletes an image"""
try:
ApiDependencies.invoker.services.images_new.delete(image_type, image_name)
ApiDependencies.invoker.services.images.delete(image_type, image_name)
except Exception as e:
# TODO: Does this need any exception handling at all?
pass
@ -85,7 +85,7 @@ async def get_image_metadata(
"""Gets an image's metadata"""
try:
return ApiDependencies.invoker.services.images_new.get_dto(
return ApiDependencies.invoker.services.images.get_dto(
image_type, image_name
)
except Exception as e:
@ -113,11 +113,11 @@ async def get_image_full(
"""Gets a full-resolution image file"""
try:
path = ApiDependencies.invoker.services.images_new.get_path(
path = ApiDependencies.invoker.services.images.get_path(
image_type, image_name
)
if not ApiDependencies.invoker.services.images_new.validate_path(path):
if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404)
return FileResponse(
@ -149,10 +149,10 @@ async def get_image_thumbnail(
"""Gets a thumbnail image file"""
try:
path = ApiDependencies.invoker.services.images_new.get_path(
path = ApiDependencies.invoker.services.images.get_path(
image_type, image_name, thumbnail=True
)
if not ApiDependencies.invoker.services.images_new.validate_path(path):
if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404)
return FileResponse(
@ -174,10 +174,10 @@ async def get_image_urls(
"""Gets an image and thumbnail URL"""
try:
image_url = ApiDependencies.invoker.services.images_new.get_url(
image_url = ApiDependencies.invoker.services.images.get_url(
image_type, image_name
)
thumbnail_url = ApiDependencies.invoker.services.images_new.get_url(
thumbnail_url = ApiDependencies.invoker.services.images.get_url(
image_type, image_name, thumbnail=True
)
return ImageUrlsDTO(
@ -205,7 +205,7 @@ async def list_images_with_metadata(
) -> PaginatedResults[ImageDTO]:
"""Gets a list of images with metadata"""
image_dtos = ApiDependencies.invoker.services.images_new.get_many(
image_dtos = ApiDependencies.invoker.services.images.get_many(
image_type,
image_category,
page,

View File

@ -7,9 +7,9 @@ import numpy
from PIL import Image, ImageOps
from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageField, ImageType
from invokeai.app.models.image import ImageCategory, ImageField, ImageType
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output
from .image import ImageOutput
class CvInvocationConfig(BaseModel):
@ -26,24 +26,27 @@ class CvInvocationConfig(BaseModel):
class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
"""Simple inpaint using opencv."""
#fmt: off
# fmt: off
type: Literal["cv_inpaint"] = "cv_inpaint"
# Inputs
image: ImageField = Field(default=None, description="The image to inpaint")
mask: ImageField = Field(default=None, description="The mask to use when inpainting")
#fmt: on
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
mask = context.services.images.get(self.mask.image_type, self.mask.image_name)
mask = context.services.images.get_pil_image(
self.mask.image_type, self.mask.image_name
)
# Convert to cv image/mask
# TODO: consider making these utility functions
cv_image = cv.cvtColor(numpy.array(image.convert("RGB")), cv.COLOR_RGB2BGR)
cv_mask = numpy.array(ImageOps.invert(mask))
cv_mask = numpy.array(ImageOps.invert(mask.convert("L")))
# Inpaint
cv_inpainted = cv.inpaint(cv_image, cv_mask, 3, cv.INPAINT_TELEA)
@ -52,18 +55,19 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
# TODO: consider making a utility function
image_inpainted = Image.fromarray(cv.cvtColor(cv_inpainted, cv.COLOR_BGR2RGB))
image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, image_inpainted, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image_dto = context.services.images.create(
image=image_inpainted,
image_type=ImageType.INTERMEDIATE,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
),
width=image_dto.width,
height=image_dto.height,
)

View File

@ -1,13 +1,13 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import io
from typing import Literal, Optional
from typing import Literal, Optional, Union
import numpy
from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field
from ..models.image import ImageField, ImageType
from ..models.image import ImageCategory, ImageField, ImageType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@ -41,27 +41,14 @@ class ImageOutput(BaseInvocationOutput):
schema_extra = {"required": ["type", "image", "width", "height"]}
def build_image_output(
image_type: ImageType, image_name: str, image: Image.Image
) -> ImageOutput:
"""Builds an ImageOutput and its ImageField"""
image_field = ImageField(
image_name=image_name,
image_type=image_type,
)
return ImageOutput(
image=image_field,
width=image.width,
height=image.height,
)
class MaskOutput(BaseInvocationOutput):
"""Base class for invocations that output a mask"""
# fmt: off
type: Literal["mask"] = "mask"
mask: ImageField = Field(default=None, description="The output mask")
width: int = Field(description="The width of the mask in pixels")
height: int = Field(description="The height of the mask in pixels")
# fmt: on
class Config:
@ -84,12 +71,15 @@ class LoadImageInvocation(BaseInvocation):
image_name: str = Field(description="The name of the image")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(self.image_type, self.image_name)
image = context.services.images.get_pil_image(self.image_type, self.image_name)
return build_image_output(
image_type=self.image_type,
return ImageOutput(
image=ImageField(
image_name=self.image_name,
image=image,
image_type=self.image_type,
),
width=image.width,
height=image.height,
)
@ -99,10 +89,12 @@ class ShowImageInvocation(BaseInvocation):
type: Literal["show_image"] = "show_image"
# Inputs
image: ImageField = Field(default=None, description="The image to show")
image: Union[ImageField, None] = Field(
default=None, description="The image to show"
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
if image:
@ -110,10 +102,13 @@ class ShowImageInvocation(BaseInvocation):
# TODO: how to handle failure?
return build_image_output(
image_type=self.image.image_type,
return ImageOutput(
image=ImageField(
image_name=self.image.image_name,
image=image,
image_type=self.image.image_type,
),
width=image.width,
height=image.height,
)
@ -124,7 +119,7 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["crop"] = "crop"
# Inputs
image: ImageField = Field(default=None, description="The image to crop")
image: Union[ImageField, None] = Field(default=None, description="The image to crop")
x: int = Field(default=0, description="The left x coordinate of the crop rectangle")
y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
@ -132,7 +127,7 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
@ -141,20 +136,21 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
)
image_crop.paste(image, (-self.x, -self.y))
image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, image_crop, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image_dto = context.services.images.create(
image=image_crop,
image_type=ImageType.INTERMEDIATE,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
),
width=image_dto.width,
height=image_dto.height,
)
@ -165,25 +161,27 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["paste"] = "paste"
# Inputs
base_image: ImageField = Field(default=None, description="The base image")
image: ImageField = Field(default=None, description="The image to paste")
base_image: Union[ImageField, None] = Field(default=None, description="The base image")
image: Union[ImageField, None] = Field(default=None, description="The image to paste")
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
x: int = Field(default=0, description="The left x coordinate at which to paste the image")
y: int = Field(default=0, description="The top y coordinate at which to paste the image")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
base_image = context.services.images.get(
base_image = context.services.images.get_pil_image(
self.base_image.image_type, self.base_image.image_name
)
image = context.services.images.get(
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
mask = (
None
if self.mask is None
else ImageOps.invert(
context.services.images.get(self.mask.image_type, self.mask.image_name)
context.services.images.get_pil_image(
self.mask.image_type, self.mask.image_name
)
)
)
# TODO: probably shouldn't invert mask here... should user be required to do it?
@ -199,20 +197,21 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
new_image.paste(base_image, (abs(min_x), abs(min_y)))
new_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask)
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, new_image, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image_dto = context.services.images.create(
image=new_image,
image_type=ImageType.RESULT,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
),
width=image_dto.width,
height=image_dto.height,
)
@ -223,12 +222,12 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["tomask"] = "tomask"
# Inputs
image: ImageField = Field(default=None, description="The image to create the mask from")
image: Union[ImageField, None] = Field(default=None, description="The image to create the mask from")
invert: bool = Field(default=False, description="Whether or not to invert the mask")
# fmt: on
def invoke(self, context: InvocationContext) -> MaskOutput:
image = context.services.images.get(
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
@ -236,18 +235,22 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
if self.invert:
image_mask = ImageOps.invert(image_mask)
image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
image_dto = context.services.images.create(
image=image_mask,
image_type=ImageType.INTERMEDIATE,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
return MaskOutput(
mask=ImageField(
image_type=image_dto.image_type, image_name=image_dto.image_name
),
width=image_dto.width,
height=image_dto.height,
)
context.services.images.save(image_type, image_name, image_mask, metadata)
return MaskOutput(mask=ImageField(image_type=image_type, image_name=image_name))
class BlurInvocation(BaseInvocation, PILInvocationConfig):
"""Blurs an image"""
@ -256,13 +259,13 @@ class BlurInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["blur"] = "blur"
# Inputs
image: ImageField = Field(default=None, description="The image to blur")
image: Union[ImageField, None] = Field(default=None, description="The image to blur")
radius: float = Field(default=8.0, ge=0, description="The blur radius")
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
@ -273,18 +276,21 @@ class BlurInvocation(BaseInvocation, PILInvocationConfig):
)
blur_image = image.filter(blur)
image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
image_dto = context.services.images.create(
image=blur_image,
image_type=ImageType.INTERMEDIATE,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, blur_image, metadata)
return build_image_output(
image_type=image_type, image_name=image_name, image=blur_image
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
),
width=image_dto.width,
height=image_dto.height,
)
@ -295,13 +301,13 @@ class LerpInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["lerp"] = "lerp"
# Inputs
image: ImageField = Field(default=None, description="The image to lerp")
image: Union[ImageField, None] = Field(default=None, description="The image to lerp")
min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
@ -310,18 +316,21 @@ class LerpInvocation(BaseInvocation, PILInvocationConfig):
lerp_image = Image.fromarray(numpy.uint8(image_arr))
image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
image_dto = context.services.images.create(
image=lerp_image,
image_type=ImageType.INTERMEDIATE,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, lerp_image, metadata)
return build_image_output(
image_type=image_type, image_name=image_name, image=lerp_image
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
),
width=image_dto.width,
height=image_dto.height,
)
@ -332,13 +341,13 @@ class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["ilerp"] = "ilerp"
# Inputs
image: ImageField = Field(default=None, description="The image to lerp")
image: Union[ImageField, None] = Field(default=None, description="The image to lerp")
min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
@ -352,16 +361,19 @@ class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
ilerp_image = Image.fromarray(numpy.uint8(image_arr))
image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
image_dto = context.services.images.create(
image=ilerp_image,
image_type=ImageType.INTERMEDIATE,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, ilerp_image, metadata)
return build_image_output(
image_type=image_type, image_name=image_name, image=ilerp_image
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
),
width=image_dto.width,
height=image_dto.height,
)

View File

@ -1,17 +1,17 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from typing import Literal, Optional, Union, get_args
from typing import Literal, Union, get_args
import numpy as np
import math
from PIL import Image, ImageOps
from pydantic import Field
from invokeai.app.invocations.image import ImageOutput, build_image_output
from invokeai.app.invocations.image import ImageOutput
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.image_util.patchmatch import PatchMatch
from ..models.image import ColorField, ImageField, ImageType
from ..models.image import ColorField, ImageCategory, ImageField, ImageType
from .baseinvocation import (
BaseInvocation,
InvocationContext,
@ -125,36 +125,39 @@ class InfillColorInvocation(BaseInvocation):
"""Infills transparent areas of an image with a solid color"""
type: Literal["infill_rgba"] = "infill_rgba"
image: Optional[ImageField] = Field(default=None, description="The image to infill")
color: Optional[ColorField] = Field(
image: Union[ImageField, None] = Field(
default=None, description="The image to infill"
)
color: ColorField = Field(
default=ColorField(r=127, g=127, b=127, a=255),
description="The color to use to infill",
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
infilled = Image.alpha_composite(solid_bg, image)
infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
infilled.paste(image, (0, 0), image.split()[-1])
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
image_dto = context.services.images.create(
image=infilled,
image_type=ImageType.RESULT,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, infilled, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image,
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
),
width=image_dto.width,
height=image_dto.height,
)
@ -163,7 +166,9 @@ class InfillTileInvocation(BaseInvocation):
type: Literal["infill_tile"] = "infill_tile"
image: Optional[ImageField] = Field(default=None, description="The image to infill")
image: Union[ImageField, None] = Field(
default=None, description="The image to infill"
)
tile_size: int = Field(default=32, ge=1, description="The tile size (px)")
seed: int = Field(
ge=0,
@ -173,7 +178,7 @@ class InfillTileInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
@ -182,20 +187,21 @@ class InfillTileInvocation(BaseInvocation):
)
infilled.paste(image, (0, 0), image.split()[-1])
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
image_dto = context.services.images.create(
image=infilled,
image_type=ImageType.RESULT,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, infilled, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image,
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
),
width=image_dto.width,
height=image_dto.height,
)
@ -204,10 +210,12 @@ class InfillPatchMatchInvocation(BaseInvocation):
type: Literal["infill_patchmatch"] = "infill_patchmatch"
image: Optional[ImageField] = Field(default=None, description="The image to infill")
image: Union[ImageField, None] = Field(
default=None, description="The image to infill"
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
@ -216,18 +224,19 @@ class InfillPatchMatchInvocation(BaseInvocation):
else:
raise ValueError("PatchMatch is not available on this system")
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
image_dto = context.services.images.create(
image=infilled,
image_type=ImageType.RESULT,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, infilled, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image,
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
),
width=image_dto.width,
height=image_dto.height,
)

View File

@ -3,7 +3,7 @@
import random
from typing import Literal, Optional, Union
import einops
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, validator
import torch
from invokeai.app.invocations.util.choose_model import choose_model
@ -23,7 +23,7 @@ from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationCont
import numpy as np
from ..services.image_file_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput, build_image_output
from .image import ImageField, ImageOutput
from .compel import ConditioningField
from ...backend.stable_diffusion import PipelineIntermediateState
from diffusers.schedulers import SchedulerMixin as Scheduler
@ -362,19 +362,9 @@ class LatentsToImageInvocation(BaseInvocation):
np_image = model.decode_latents(latents)
image = model.numpy_to_pil(np_image)[0]
# image_type = ImageType.RESULT
# image_name = context.services.images.create_name(
# context.graph_execution_state_id, self.id
# )
torch.cuda.empty_cache()
# metadata = context.services.metadata.build_metadata(
# session_id=context.graph_execution_state_id, node=self
# )
# torch.cuda.empty_cache()
# context.services.images.save(image_type, image_name, image, metadata)
image_dto = context.services.images_new.create(
image_dto = context.services.images.create(
image=image,
image_type=ImageType.RESULT,
image_category=ImageCategory.GENERAL,
@ -382,10 +372,13 @@ class LatentsToImageInvocation(BaseInvocation):
node_id=self.id,
)
return build_image_output(
image_type=image_dto.image_type,
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image=image,
image_type=image_dto.image_type,
),
width=image_dto.width,
height=image_dto.height,
)
@ -474,7 +467,7 @@ class ImageToLatentsInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.services.images.get(
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
@ -496,3 +489,4 @@ class ImageToLatentsInvocation(BaseInvocation):
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.save(name, latents)
return build_latents_output(latents_name=name, latents=latents)

View File

@ -2,20 +2,22 @@ from typing import Literal, Union
from pydantic import Field
from invokeai.app.models.image import ImageField, ImageType
from invokeai.app.models.image import ImageCategory, ImageField, ImageType
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output
from .image import ImageOutput
class RestoreFaceInvocation(BaseInvocation):
"""Restores faces in an image."""
#fmt: off
# fmt: off
type: Literal["restore_face"] = "restore_face"
# Inputs
image: Union[ImageField, None] = Field(description="The input image")
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" )
#fmt: on
# fmt: on
# Schema customisation
class Config(InvocationConfig):
@ -26,7 +28,7 @@ class RestoreFaceInvocation(BaseInvocation):
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
results = context.services.restoration.upscale_and_reconstruct(
@ -39,18 +41,19 @@ class RestoreFaceInvocation(BaseInvocation):
# Results are image and seed, unwrap for now
# TODO: can this return multiple results?
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
image_dto = context.services.images.create(
image=results[0][0],
image_type=ImageType.INTERMEDIATE,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, results[0][0], metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=results[0][0]
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
),
width=image_dto.width,
height=image_dto.height,
)

View File

@ -4,22 +4,22 @@ from typing import Literal, Union
from pydantic import Field
from invokeai.app.models.image import ImageField, ImageType
from invokeai.app.models.image import ImageCategory, ImageField, ImageType
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output
from .image import ImageOutput
class UpscaleInvocation(BaseInvocation):
"""Upscales an image."""
#fmt: off
# fmt: off
type: Literal["upscale"] = "upscale"
# Inputs
image: Union[ImageField, None] = Field(description="The input image", default=None)
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
level: Literal[2, 4] = Field(default=2, description="The upscale level")
#fmt: on
# fmt: on
# Schema customisation
class Config(InvocationConfig):
@ -30,7 +30,7 @@ class UpscaleInvocation(BaseInvocation):
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
results = context.services.restoration.upscale_and_reconstruct(
@ -43,18 +43,19 @@ class UpscaleInvocation(BaseInvocation):
# Results are image and seed, unwrap for now
# TODO: can this return multiple results?
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
image_dto = context.services.images.create(
image=results[0][0],
image_type=ImageType.RESULT,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, results[0][0], metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=results[0][0]
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
),
width=image_dto.width,
height=image_dto.height,
)

View File

@ -6,7 +6,6 @@ from invokeai.app.services.images import ImageService
from invokeai.backend import ModelManager
from .events import EventServiceBase
from .latent_storage import LatentsStorageBase
from .image_file_storage import ImageFileStorageBase
from .restoration_services import RestorationServices
from .invocation_queue import InvocationQueueABC
from .item_storage import ItemStorageABC
@ -23,12 +22,11 @@ class InvocationServices:
events: EventServiceBase
latents: LatentsStorageBase
images: ImageFileStorageBase
queue: InvocationQueueABC
model_manager: ModelManager
restoration: RestorationServices
configuration: InvokeAISettings
images_new: ImageService
images: ImageService
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
graph_library: ItemStorageABC["LibraryGraph"]
@ -41,9 +39,8 @@ class InvocationServices:
events: EventServiceBase,
logger: Logger,
latents: LatentsStorageBase,
images: ImageFileStorageBase,
images: ImageService,
queue: InvocationQueueABC,
images_new: ImageService,
graph_library: ItemStorageABC["LibraryGraph"],
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
processor: "InvocationProcessorABC",
@ -56,7 +53,6 @@ class InvocationServices:
self.latents = latents
self.images = images
self.queue = queue
self.images_new = images_new
self.graph_library = graph_library
self.graph_execution_manager = graph_execution_manager
self.processor = processor