feat(nodes): refactor image types

- Remove `ImageType` entirely, it is confusing
- Create `ResourceOrigin`, may be `internal` or `external`
- Revamp `ImageCategory`, may be `general`, `mask`, `control`, `user`, `other`. Expect to add more as time goes on
- Update images `list` route to accept `include_categories` OR `exclude_categories` query parameters to afford finer-grained querying. All services are updated to accommodate this change.

The new setup should account for our types of images, including the combinations we couldn't really handle until now:
- Canvas init and masks
- Canvas when saved-to-gallery or merged
This commit is contained in:
psychedelicious
2023-05-27 21:39:20 +10:00
committed by Kent Keirsey
parent fd47e70c92
commit 160267c71a
17 changed files with 291 additions and 311 deletions

View File

@ -7,7 +7,7 @@ import numpy
from PIL import Image, ImageOps
from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageCategory, ImageField, ImageType
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput
@ -37,10 +37,10 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
self.image.image_origin, self.image.image_name
)
mask = context.services.images.get_pil_image(
self.mask.image_type, self.mask.image_name
self.mask.image_origin, self.mask.image_name
)
# Convert to cv image/mask
@ -57,7 +57,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
image_dto = context.services.images.create(
image=image_inpainted,
image_type=ImageType.RESULT,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
@ -67,7 +67,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,

View File

@ -10,9 +10,9 @@ import torch
from pydantic import BaseModel, Field
from invokeai.app.models.image import ColorField, ImageField, ImageType
from invokeai.app.models.image import ColorField, ImageField, ResourceOrigin
from invokeai.app.invocations.util.choose_model import choose_model
from invokeai.app.models.image import ImageCategory, ImageType
from invokeai.app.models.image import ImageCategory, ResourceOrigin
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.generator.inpaint import infill_methods
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
@ -120,7 +120,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
image_dto = context.services.images.create(
image=generate_output.image,
image_type=ImageType.RESULT,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id,
node_id=self.id,
@ -130,7 +130,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
@ -170,7 +170,7 @@ class ImageToImageInvocation(TextToImageInvocation):
None
if self.image is None
else context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
self.image.image_origin, self.image.image_name
)
)
@ -201,7 +201,7 @@ class ImageToImageInvocation(TextToImageInvocation):
image_dto = context.services.images.create(
image=generator_output.image,
image_type=ImageType.RESULT,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id,
node_id=self.id,
@ -211,7 +211,7 @@ class ImageToImageInvocation(TextToImageInvocation):
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
@ -283,13 +283,13 @@ class InpaintInvocation(ImageToImageInvocation):
None
if self.image is None
else context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
self.image.image_origin, self.image.image_name
)
)
mask = (
None
if self.mask is None
else context.services.images.get_pil_image(self.mask.image_type, self.mask.image_name)
else context.services.images.get_pil_image(self.mask.image_origin, self.mask.image_name)
)
# Handle invalid model parameter
@ -317,7 +317,7 @@ class InpaintInvocation(ImageToImageInvocation):
image_dto = context.services.images.create(
image=generator_output.image,
image_type=ImageType.RESULT,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id,
node_id=self.id,
@ -327,7 +327,7 @@ class InpaintInvocation(ImageToImageInvocation):
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,

View File

@ -7,7 +7,7 @@ import numpy
from PIL import Image, ImageFilter, ImageOps, ImageChops
from pydantic import BaseModel, Field
from ..models.image import ImageCategory, ImageField, ImageType
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@ -72,12 +72,12 @@ class LoadImageInvocation(BaseInvocation):
)
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_type, self.image.image_name)
image = context.services.images.get_pil_image(self.image.image_origin, self.image.image_name)
return ImageOutput(
image=ImageField(
image_name=self.image.image_name,
image_type=self.image.image_type,
image_origin=self.image.image_origin,
),
width=image.width,
height=image.height,
@ -96,7 +96,7 @@ class ShowImageInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
self.image.image_origin, self.image.image_name
)
if image:
image.show()
@ -106,7 +106,7 @@ class ShowImageInvocation(BaseInvocation):
return ImageOutput(
image=ImageField(
image_name=self.image.image_name,
image_type=self.image.image_type,
image_origin=self.image.image_origin,
),
width=image.width,
height=image.height,
@ -129,7 +129,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
self.image.image_origin, self.image.image_name
)
image_crop = Image.new(
@ -139,7 +139,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
image_dto = context.services.images.create(
image=image_crop,
image_type=ImageType.RESULT,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
@ -149,7 +149,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
@ -172,17 +172,17 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput:
base_image = context.services.images.get_pil_image(
self.base_image.image_type, self.base_image.image_name
self.base_image.image_origin, self.base_image.image_name
)
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
self.image.image_origin, self.image.image_name
)
mask = (
None
if self.mask is None
else ImageOps.invert(
context.services.images.get_pil_image(
self.mask.image_type, self.mask.image_name
self.mask.image_origin, self.mask.image_name
)
)
)
@ -201,7 +201,7 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
image_dto = context.services.images.create(
image=new_image,
image_type=ImageType.RESULT,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
@ -211,7 +211,7 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
@ -231,7 +231,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> MaskOutput:
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
self.image.image_origin, self.image.image_name
)
image_mask = image.split()[-1]
@ -240,7 +240,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
image_dto = context.services.images.create(
image=image_mask,
image_type=ImageType.RESULT,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.MASK,
node_id=self.id,
session_id=context.graph_execution_state_id,
@ -249,7 +249,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
return MaskOutput(
mask=ImageField(
image_type=image_dto.image_type, image_name=image_dto.image_name
image_origin=image_dto.image_origin, image_name=image_dto.image_name
),
width=image_dto.width,
height=image_dto.height,
@ -269,17 +269,17 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput:
image1 = context.services.images.get_pil_image(
self.image1.image_type, self.image1.image_name
self.image1.image_origin, self.image1.image_name
)
image2 = context.services.images.get_pil_image(
self.image2.image_type, self.image2.image_name
self.image2.image_origin, self.image2.image_name
)
multiply_image = ImageChops.multiply(image1, image2)
image_dto = context.services.images.create(
image=multiply_image,
image_type=ImageType.RESULT,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
@ -288,7 +288,7 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
return ImageOutput(
image=ImageField(
image_type=image_dto.image_type, image_name=image_dto.image_name
image_origin=image_dto.image_origin, image_name=image_dto.image_name
),
width=image_dto.width,
height=image_dto.height,
@ -311,14 +311,14 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
self.image.image_origin, self.image.image_name
)
channel_image = image.getchannel(self.channel)
image_dto = context.services.images.create(
image=channel_image,
image_type=ImageType.RESULT,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
@ -327,7 +327,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
return ImageOutput(
image=ImageField(
image_type=image_dto.image_type, image_name=image_dto.image_name
image_origin=image_dto.image_origin, image_name=image_dto.image_name
),
width=image_dto.width,
height=image_dto.height,
@ -350,14 +350,14 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
self.image.image_origin, self.image.image_name
)
converted_image = image.convert(self.mode)
image_dto = context.services.images.create(
image=converted_image,
image_type=ImageType.RESULT,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
@ -366,7 +366,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
return ImageOutput(
image=ImageField(
image_type=image_dto.image_type, image_name=image_dto.image_name
image_origin=image_dto.image_origin, image_name=image_dto.image_name
),
width=image_dto.width,
height=image_dto.height,
@ -387,7 +387,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
self.image.image_origin, self.image.image_name
)
blur = (
@ -399,7 +399,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
image_dto = context.services.images.create(
image=blur_image,
image_type=ImageType.RESULT,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
@ -409,7 +409,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
@ -430,7 +430,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
self.image.image_origin, self.image.image_name
)
image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
@ -440,7 +440,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
image_dto = context.services.images.create(
image=lerp_image,
image_type=ImageType.RESULT,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
@ -450,7 +450,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
@ -471,7 +471,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
self.image.image_origin, self.image.image_name
)
image_arr = numpy.asarray(image, dtype=numpy.float32)
@ -486,7 +486,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
image_dto = context.services.images.create(
image=ilerp_image,
image_type=ImageType.RESULT,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
@ -496,7 +496,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,

View File

@ -11,7 +11,7 @@ from invokeai.app.invocations.image import ImageOutput
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.image_util.patchmatch import PatchMatch
from ..models.image import ColorField, ImageCategory, ImageField, ImageType
from ..models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import (
BaseInvocation,
InvocationContext,
@ -135,7 +135,7 @@ class InfillColorInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
self.image.image_origin, self.image.image_name
)
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
@ -145,7 +145,7 @@ class InfillColorInvocation(BaseInvocation):
image_dto = context.services.images.create(
image=infilled,
image_type=ImageType.RESULT,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
@ -155,7 +155,7 @@ class InfillColorInvocation(BaseInvocation):
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
@ -180,7 +180,7 @@ class InfillTileInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
self.image.image_origin, self.image.image_name
)
infilled = tile_fill_missing(
@ -190,7 +190,7 @@ class InfillTileInvocation(BaseInvocation):
image_dto = context.services.images.create(
image=infilled,
image_type=ImageType.RESULT,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
@ -200,7 +200,7 @@ class InfillTileInvocation(BaseInvocation):
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
@ -218,7 +218,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
self.image.image_origin, self.image.image_name
)
if PatchMatch.patchmatch_available():
@ -228,7 +228,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
image_dto = context.services.images.create(
image=infilled,
image_type=ImageType.RESULT,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
@ -238,7 +238,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,

View File

@ -28,7 +28,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import ControlNetData
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
import numpy as np
from ..services.image_file_storage import ImageType
from ..services.image_file_storage import ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
from .compel import ConditioningField
@ -468,7 +468,7 @@ class LatentsToImageInvocation(BaseInvocation):
# and generate unique image_name
image_dto = context.services.images.create(
image=image,
image_type=ImageType.RESULT,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id,
node_id=self.id,
@ -478,7 +478,7 @@ class LatentsToImageInvocation(BaseInvocation):
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
@ -576,7 +576,7 @@ class ImageToLatentsInvocation(BaseInvocation):
# self.image.image_type, self.image.image_name
# )
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
self.image.image_origin, self.image.image_name
)
# TODO: this only really needs the vae

View File

@ -2,7 +2,7 @@ from typing import Literal, Union
from pydantic import Field
from invokeai.app.models.image import ImageCategory, ImageField, ImageType
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput
@ -29,7 +29,7 @@ class RestoreFaceInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
self.image.image_origin, self.image.image_name
)
results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]],
@ -43,7 +43,7 @@ class RestoreFaceInvocation(BaseInvocation):
# TODO: can this return multiple results?
image_dto = context.services.images.create(
image=results[0][0],
image_type=ImageType.RESULT,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
@ -53,7 +53,7 @@ class RestoreFaceInvocation(BaseInvocation):
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,

View File

@ -4,7 +4,7 @@ from typing import Literal, Union
from pydantic import Field
from invokeai.app.models.image import ImageCategory, ImageField, ImageType
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput
@ -31,7 +31,7 @@ class UpscaleInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
self.image.image_origin, self.image.image_name
)
results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]],
@ -45,7 +45,7 @@ class UpscaleInvocation(BaseInvocation):
# TODO: can this return multiple results?
image_dto = context.services.images.create(
image=results[0][0],
image_type=ImageType.RESULT,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
@ -55,7 +55,7 @@ class UpscaleInvocation(BaseInvocation):
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,