Partial migration of UI to nodes API (#3195)

* feat(ui): add axios client generator and simple example

* fix(ui): update client & nodes test code w/ new Edge type

* chore(ui): organize generated files

* chore(ui): update .eslintignore, .prettierignore

* chore(ui): update openapi.json

* feat(backend): fixes for nodes/generator

* feat(ui): generate object args for api client

* feat(ui): more nodes api prototyping

* feat(ui): nodes cancel

* chore(ui): regenerate api client

* fix(ui): disable OG web server socket connection

* fix(ui): fix scrollbar styles typing and prop

just noticed the typo, and made the types stronger.

* feat(ui): add socketio types

* feat(ui): wip nodes

- extract api client method arg types instead of manually declaring them
- update example to display images
- general tidy up

* start building out node translations from frontend state and add notes about missing features

* use reference to sampler_name

* use reference to sampler_name

* add optional apiUrl prop

* feat(ui): start hooking up dynamic txt2img node generation, create middleware for session invocation

* feat(ui): write separate nodes socket layer, txt2img generating and rendering w single node

* feat(ui): img2img implementation

* feat(ui): get intermediate images working but types are stubbed out

* chore(ui): add support for package mode

* feat(ui): add nodes mode script

* feat(ui): handle random seeds

* fix(ui): fix middleware types

* feat(ui): add rtk action type guard

* feat(ui): disable NodeAPITest

This was polluting the network/socket logs.

* feat(ui): fix parameters panel border color

This commit should be elsewhere but I don't want to break my flow

* feat(ui): make thunk types more consistent

* feat(ui): add type guards for outputs

* feat(ui): load images on socket connect

Rudimentary

* chore(ui): bump redux-toolkit

* docs(ui): update readme

* chore(ui): regenerate api client

* chore(ui): add typescript as dev dependency

I am having trouble with TS versions after vscode updated and now uses TS 5. `madge` has installed 3.9.10 and for whatever reason my vscode wants to use that. Manually specifying 4.9.5 and then setting vscode to use that as the workspace TS fixes the issue.

* feat(ui): begin migrating gallery to nodes

Along the way, migrate to use RTK `createEntityAdapter` for gallery images, and separate `results` and `uploads` into separate slices. Much cleaner this way.

* feat(ui): clean up & comment results slice

* fix(ui): separate thunk for initial gallery load so it properly gets index 0

* feat(ui): POST upload working

* fix(ui): restore removed type

* feat(ui): patch api generation for headers access

* chore(ui): regenerate api

* feat(ui): wip gallery migration

* feat(ui): wip gallery migration

* chore(ui): regenerate api

* feat(ui): wip refactor socket events

* feat(ui): disable panels based on app props

* feat(ui): invert logic to be disabled

* disable panels when app mounts

* feat(ui): add support to disableTabs

* docs(ui): organise and update docs

* lang(ui): add toast strings

* feat(ui): wip events, comments, and general refactoring

* feat(ui): add optional token for auth

* feat(ui): export StatusIndicator and ModelSelect for header use

* feat(ui) working on making socket URL dynamic

* feat(ui): dynamic middleware loading

* feat(ui): prep for socket jwt

* feat(ui): migrate cancelation

also updated action names to be event-like instead of declaration-like

sorry, i was scattered and this commit has a lot of unrelated stuff in it.

* fix(ui): fix img2img type

* chore(ui): regenerate api client

* feat(ui): improve InvocationCompleteEvent types

* feat(ui): increase StatusIndicator font size

* fix(ui): fix middleware order for multi-node graphs

* feat(ui): add exampleGraphs object w/ iterations example

* feat(ui): generate iterations graph

* feat(ui): update ModelSelect for nodes API

* feat(ui): add hi-res functionality for txt2img generations

* feat(ui): "subscribe" to particular nodes

feels like a dirty hack but oh well it works

* feat(ui): first steps to node editor ui

* fix(ui): disable event subscription

it is not fully baked just yet

* feat(ui): wip node editor

* feat(ui): remove extraneous field types

* feat(ui): nodes before deleting stuff

* feat(ui): cleanup nodes ui stuff

* feat(ui): hook up nodes to redux

* fix(ui): fix handle

* fix(ui): add basic node edges & connection validation

* feat(ui): add connection validation styling

* feat(ui): increase edge width

* feat(ui): it blends

* feat(ui): wip model handling and graph topology validation

* feat(ui): validation connections w/ graphlib

* docs(ui): update nodes doc

* feat(ui): wip node editor

* chore(ui): rebuild api, update types

* add redux-dynamic-middlewares as a dependency

* feat(ui): add url host transformation

* feat(ui): handle already-connected fields

* feat(ui): rewrite SqliteItemStore in sqlalchemy

* fix(ui): fix sqlalchemy dynamic model instantiation

* feat(ui, nodes): metadata wip

* feat(ui, nodes): models

* feat(ui, nodes): more metadata wip

* feat(ui): wip range/iterate

* fix(nodes): fix sqlite typing

* feat(ui): export new type for invoke component

* tests(nodes): fix test instantiation of ImageField

* feat(nodes): fix LoadImageInvocation

* feat(nodes): add `title` ui hint

* feat(nodes): make ImageField attrs optional

* feat(ui): wip nodes etc

* feat(nodes): roll back sqlalchemy

* fix(nodes): partially address feedback

* fix(backend): roll back changes to pngwriter

* feat(nodes): wip address metadata feedback

* feat(nodes): add seeded rng to RandomRange

* feat(nodes): address feedback

* feat(nodes): move GET images error handling to DiskImageStorage

* feat(nodes): move GET images error handling to DiskImageStorage

* fix(nodes): fix image output schema customization

* feat(ui): img2img/txt2img -> linear

- remove txt2img and img2img tabs
- add linear tab
- add initial image selection to linear parameters accordion

* feat(ui): tidy graph builders

* feat(ui): tidy misc

* feat(ui): improve invocation union types

* feat(ui): wip metadata viewer recall

* feat(ui): move fonts to normal deps

* feat(nodes): fix broken upload

* feat(nodes): add metadata module + tests, thumbnails

- `MetadataModule` is stateless and needed in places where the `InvocationContext` is not available, so have not made it a `service`
- Handles loading/parsing/building metadata, and creating png info objects
- added tests for MetadataModule
- Lifted thumbnail stuff to util

* fix(nodes): revert change to RandomRangeInvocation

* feat(nodes): address feedback

- make metadata a service
- rip out pydantic validation, implement metadata parsing as simple functions
- update tests
- address other minor feedback items

* fix(nodes): fix other tests

* fix(nodes): add metadata service to cli

* fix(nodes): fix latents/image field parsing

* feat(nodes): customise LatentsField schema

* feat(nodes): move metadata parsing to frontend

* fix(nodes): fix metadata test

---------

Co-authored-by: maryhipp <maryhipp@gmail.com>
Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
This commit is contained in:
psychedelicious
2023-04-22 13:10:20 +10:00
committed by GitHub
parent fdad62e88b
commit 5f498e10bd
324 changed files with 13051 additions and 1400 deletions

View File

@ -3,6 +3,8 @@
import os
from argparse import Namespace
from invokeai.app.services.metadata import PngMetadataService, MetadataServiceBase
from ..services.default_graphs import create_system_graphs
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
@ -60,7 +62,9 @@ class ApiDependencies:
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents'))
images = DiskImageStorage(f'{output_folder}/images')
metadata = PngMetadataService()
images = DiskImageStorage(f'{output_folder}/images', metadata_service=metadata)
# TODO: build a file/path manager?
db_location = os.path.join(output_folder, "invokeai.db")
@ -70,6 +74,7 @@ class ApiDependencies:
events=events,
latents=latents,
images=images,
metadata=metadata,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=db_location, table_name="graphs"

View File

@ -1,7 +1,19 @@
from typing import Optional
from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageType
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.metadata import InvokeAIMetadata
class ImageResponseMetadata(BaseModel):
"""An image's metadata. Used only in HTTP responses."""
created: int = Field(description="The creation timestamp of the image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
invokeai: Optional[InvokeAIMetadata] = Field(
description="The image's InvokeAI-specific metadata"
)
class ImageResponse(BaseModel):
@ -11,4 +23,12 @@ class ImageResponse(BaseModel):
image_name: str = Field(description="The name of the image")
image_url: str = Field(description="The url of the image")
thumbnail_url: str = Field(description="The url of the image's thumbnail")
metadata: ImageMetadata = Field(description="The image's metadata")
metadata: ImageResponseMetadata = Field(description="The image's metadata")
class ProgressImage(BaseModel):
"""The progress image sent intermittently during processing"""
width: int = Field(description="The effective width of the image in pixels")
height: int = Field(description="The effective height of the image in pixels")
dataURL: str = Field(description="The image data as a b64 data URL")

View File

@ -1,13 +1,17 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import io
from datetime import datetime, timezone
import json
import os
from typing import Any
import uuid
from fastapi import Path, Query, Request, UploadFile
from fastapi import HTTPException, Path, Query, Request, UploadFile
from fastapi.responses import FileResponse, Response
from fastapi.routing import APIRouter
from PIL import Image
from invokeai.app.api.models.images import ImageResponse
from invokeai.app.api.models.images import ImageResponse, ImageResponseMetadata
from invokeai.app.services.metadata import InvokeAIMetadata
from invokeai.app.services.item_storage import PaginatedResults
from ...services.image_storage import ImageType
@ -15,70 +19,110 @@ from ..dependencies import ApiDependencies
images_router = APIRouter(prefix="/v1/images", tags=["images"])
@images_router.get("/{image_type}/{image_name}", operation_id="get_image")
async def get_image(
image_type: ImageType = Path(description="The type of image to get"),
image_name: str = Path(description="The name of the image to get"),
):
) -> FileResponse | Response:
"""Gets a result"""
# TODO: This is not really secure at all. At least make sure only output results are served
filename = ApiDependencies.invoker.services.images.get_path(image_type, image_name)
return FileResponse(filename)
@images_router.get("/{image_type}/thumbnails/{image_name}", operation_id="get_thumbnail")
path = ApiDependencies.invoker.services.images.get_path(
image_type=image_type, image_name=image_name
)
if ApiDependencies.invoker.services.images.validate_path(path):
return FileResponse(path)
else:
raise HTTPException(status_code=404)
@images_router.get(
"/{image_type}/thumbnails/{image_name}", operation_id="get_thumbnail"
)
async def get_thumbnail(
image_type: ImageType = Path(description="The type of image to get"),
image_name: str = Path(description="The name of the image to get"),
):
) -> FileResponse | Response:
"""Gets a thumbnail"""
# TODO: This is not really secure at all. At least make sure only output results are served
filename = ApiDependencies.invoker.services.images.get_path(image_type, 'thumbnails/' + image_name)
return FileResponse(filename)
path = ApiDependencies.invoker.services.images.get_path(
image_type=image_type, image_name=image_name, is_thumbnail=True
)
if ApiDependencies.invoker.services.images.validate_path(path):
return FileResponse(path)
else:
raise HTTPException(status_code=404)
@images_router.post(
"/uploads/",
operation_id="upload_image",
responses={
201: {"description": "The image was uploaded successfully"},
404: {"description": "Session not found"},
201: {
"description": "The image was uploaded successfully",
"model": ImageResponse,
},
415: {"description": "Image upload failed"},
},
status_code=201,
)
async def upload_image(file: UploadFile, request: Request):
async def upload_image(
file: UploadFile, request: Request, response: Response
) -> ImageResponse:
if not file.content_type.startswith("image"):
return Response(status_code=415)
raise HTTPException(status_code=415, detail="Not an image")
contents = await file.read()
try:
im = Image.open(contents)
img = Image.open(io.BytesIO(contents))
except:
# Error opening the image
return Response(status_code=415)
raise HTTPException(status_code=415, detail="Failed to read image")
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
ApiDependencies.invoker.services.images.save(ImageType.UPLOAD, filename, im)
return Response(
status_code=201,
headers={
"Location": request.url_for(
"get_image", image_type=ImageType.UPLOAD.value, image_name=filename
)
},
(image_path, thumbnail_path, ctime) = ApiDependencies.invoker.services.images.save(
ImageType.UPLOAD, filename, img
)
invokeai_metadata = ApiDependencies.invoker.services.metadata.get_metadata(img)
res = ImageResponse(
image_type=ImageType.UPLOAD,
image_name=filename,
image_url=f"api/v1/images/{ImageType.UPLOAD.value}/{filename}",
thumbnail_url=f"api/v1/images/{ImageType.UPLOAD.value}/thumbnails/{os.path.splitext(filename)[0]}.webp",
metadata=ImageResponseMetadata(
created=ctime,
width=img.width,
height=img.height,
invokeai=invokeai_metadata,
),
)
response.status_code = 201
response.headers["Location"] = request.url_for(
"get_image", image_type=ImageType.UPLOAD.value, image_name=filename
)
return res
@images_router.get(
"/",
operation_id="list_images",
responses={200: {"model": PaginatedResults[ImageResponse]}},
)
async def list_images(
image_type: ImageType = Query(default=ImageType.RESULT, description="The type of images to get"),
image_type: ImageType = Query(
default=ImageType.RESULT, description="The type of images to get"
),
page: int = Query(default=0, description="The page of images to get"),
per_page: int = Query(default=10, description="The number of images per page"),
) -> PaginatedResults[ImageResponse]:
"""Gets a list of images"""
result = ApiDependencies.invoker.services.images.list(
image_type, page, per_page
)
result = ApiDependencies.invoker.services.images.list(image_type, page, per_page)
return result

View File

@ -13,6 +13,8 @@ from typing import (
from pydantic import BaseModel
from pydantic.fields import Field
from invokeai.app.services.metadata import PngMetadataService
from .services.default_graphs import create_system_graphs
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
@ -200,6 +202,8 @@ def invoke_cli():
events = EventServiceBase()
metadata = PngMetadataService()
output_folder = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../../../outputs")
)
@ -211,7 +215,8 @@ def invoke_cli():
model_manager=model_manager,
events=events,
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
images=DiskImageStorage(f'{output_folder}/images'),
images=DiskImageStorage(f'{output_folder}/images', metadata_service=metadata),
metadata=metadata,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=db_location, table_name="graphs"

View File

@ -95,7 +95,7 @@ class UIConfig(TypedDict, total=False):
],
]
tags: List[str]
title: str
class CustomisedSchemaExtra(TypedDict):
ui: UIConfig

View File

@ -1,16 +1,17 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
from typing import Literal
from typing import Literal, Optional
import cv2 as cv
import numpy as np
import numpy.random
from PIL import Image, ImageOps
from pydantic import Field
from ..services.image_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext, BaseInvocationOutput
from .image import ImageField, ImageOutput
from .baseinvocation import (
BaseInvocation,
InvocationConfig,
InvocationContext,
BaseInvocationOutput,
)
class IntCollectionOutput(BaseInvocationOutput):
@ -33,7 +34,9 @@ class RangeInvocation(BaseInvocation):
step: int = Field(default=1, description="The step of the range")
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
return IntCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
return IntCollectionOutput(
collection=list(range(self.start, self.stop, self.step))
)
class RandomRangeInvocation(BaseInvocation):
@ -43,8 +46,19 @@ class RandomRangeInvocation(BaseInvocation):
# Inputs
low: int = Field(default=0, description="The inclusive low value")
high: int = Field(default=np.iinfo(np.int32).max, description="The exclusive high value")
high: int = Field(
default=np.iinfo(np.int32).max, description="The exclusive high value"
)
size: int = Field(default=1, description="The number of values to generate")
seed: Optional[int] = Field(
ge=0,
le=np.iinfo(np.int32).max,
description="The seed for the RNG",
default_factory=lambda: numpy.random.randint(0, np.iinfo(np.int32).max),
)
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
return IntCollectionOutput(collection=list(numpy.random.randint(self.low, self.high, size=self.size)))
rng = np.random.default_rng(self.seed)
return IntCollectionOutput(
collection=list(rng.integers(low=self.low, high=self.high, size=self.size))
)

View File

@ -9,7 +9,7 @@ from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageField, ImageType
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput
from .image import ImageOutput, build_image_output
class CvInvocationConfig(BaseModel):
@ -56,7 +56,14 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, image_inpainted)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, image_inpainted, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image_inpainted,
)

View File

@ -9,13 +9,12 @@ from torch import Tensor
from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageField, ImageType
from invokeai.app.invocations.util.get_model import choose_model
from invokeai.app.invocations.util.choose_model import choose_model
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput
from .image import ImageOutput, build_image_output
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
from ...backend.stable_diffusion import PipelineIntermediateState
from ..models.exceptions import CanceledException
from ..util.step_callback import diffusers_step_callback_adapter
from ..util.step_callback import stable_diffusion_step_callback
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
@ -58,28 +57,31 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
# TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress(
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
self,
context: InvocationContext,
source_node_id: str,
intermediate_state: PipelineIntermediateState,
) -> None:
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
raise CanceledException
step = intermediate_state.step
if intermediate_state.predicted_original is not None:
# Some schedulers report not only the noisy latents at the current timestep,
# but also their estimate so far of what the de-noised latents will be.
sample = intermediate_state.predicted_original
else:
sample = intermediate_state.latents
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
stable_diffusion_step_callback(
context=context,
intermediate_state=intermediate_state,
node=self.dict(),
source_node_id=source_node_id,
)
def invoke(self, context: InvocationContext) -> ImageOutput:
# Handle invalid model parameter
model = choose_model(context.services.model_manager, self.model)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
outputs = Txt2Img(model).generate(
prompt=self.prompt,
step_callback=partial(self.dispatch_progress, context),
step_callback=partial(self.dispatch_progress, context, source_node_id),
**self.dict(
exclude={"prompt"}
), # Shorthand for passing all of the parameters above manually
@ -95,9 +97,18 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, generate_output.image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(
image_type, image_name, generate_output.image, metadata
)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=generate_output.image,
)
@ -117,20 +128,17 @@ class ImageToImageInvocation(TextToImageInvocation):
)
def dispatch_progress(
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
) -> None:
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
raise CanceledException
step = intermediate_state.step
if intermediate_state.predicted_original is not None:
# Some schedulers report not only the noisy latents at the current timestep,
# but also their estimate so far of what the de-noised latents will be.
sample = intermediate_state.predicted_original
else:
sample = intermediate_state.latents
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
self,
context: InvocationContext,
source_node_id: str,
intermediate_state: PipelineIntermediateState,
) -> None:
stable_diffusion_step_callback(
context=context,
intermediate_state=intermediate_state,
node=self.dict(),
source_node_id=source_node_id,
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = (
@ -145,15 +153,21 @@ class ImageToImageInvocation(TextToImageInvocation):
# Handle invalid model parameter
model = choose_model(context.services.model_manager, self.model)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
outputs = Img2Img(model).generate(
prompt=self.prompt,
init_image=image,
init_mask=mask,
step_callback=partial(self.dispatch_progress, context),
**self.dict(
exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
)
prompt=self.prompt,
init_image=image,
init_mask=mask,
step_callback=partial(self.dispatch_progress, context, source_node_id),
**self.dict(
exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
)
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
# each time it is called. We only need the first one.
@ -168,11 +182,19 @@ class ImageToImageInvocation(TextToImageInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, result_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, result_image, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=result_image,
)
class InpaintInvocation(ImageToImageInvocation):
"""Generates an image using inpaint."""
@ -188,20 +210,17 @@ class InpaintInvocation(ImageToImageInvocation):
)
def dispatch_progress(
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
) -> None:
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
raise CanceledException
step = intermediate_state.step
if intermediate_state.predicted_original is not None:
# Some schedulers report not only the noisy latents at the current timestep,
# but also their estimate so far of what the de-noised latents will be.
sample = intermediate_state.predicted_original
else:
sample = intermediate_state.latents
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
self,
context: InvocationContext,
source_node_id: str,
intermediate_state: PipelineIntermediateState,
) -> None:
stable_diffusion_step_callback(
context=context,
intermediate_state=intermediate_state,
node=self.dict(),
source_node_id=source_node_id,
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = (
@ -218,17 +237,23 @@ class InpaintInvocation(ImageToImageInvocation):
)
# Handle invalid model parameter
model = choose_model(context.services.model_manager, self.model)
model = choose_model(context.services.model_manager, self.model)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
outputs = Inpaint(model).generate(
prompt=self.prompt,
init_img=image,
init_mask=mask,
step_callback=partial(self.dispatch_progress, context),
**self.dict(
exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
)
prompt=self.prompt,
init_img=image,
init_mask=mask,
step_callback=partial(self.dispatch_progress, context, source_node_id),
**self.dict(
exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
)
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
# each time it is called. We only need the first one.
@ -243,7 +268,14 @@ class InpaintInvocation(ImageToImageInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, result_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, result_image, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=result_image,
)

View File

@ -1,6 +1,5 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from datetime import datetime, timezone
from typing import Literal, Optional
import numpy
@ -8,8 +7,12 @@ from PIL import Image, ImageFilter, ImageOps
from pydantic import BaseModel, Field
from ..models.image import ImageField, ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationContext,
InvocationConfig,
)
class PILInvocationConfig(BaseModel):
@ -22,50 +25,73 @@ class PILInvocationConfig(BaseModel):
},
}
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
#fmt: off
# fmt: off
type: Literal["image"] = "image"
image: ImageField = Field(default=None, description="The output image")
#fmt: on
width: Optional[int] = Field(default=None, description="The width of the image in pixels")
height: Optional[int] = Field(default=None, description="The height of the image in pixels")
# fmt: on
class Config:
schema_extra = {
'required': [
'type',
'image',
]
"required": ["type", "image", "width", "height", "mode"]
}
def build_image_output(
image_type: ImageType, image_name: str, image: Image.Image
) -> ImageOutput:
"""Builds an ImageOutput and its ImageField"""
image_field = ImageField(
image_name=image_name,
image_type=image_type,
)
return ImageOutput(
image=image_field,
width=image.width,
height=image.height,
mode=image.mode,
)
class MaskOutput(BaseInvocationOutput):
"""Base class for invocations that output a mask"""
#fmt: off
# fmt: off
type: Literal["mask"] = "mask"
mask: ImageField = Field(default=None, description="The output mask")
#fmt: on
# fmt: on
class Config:
schema_extra = {
'required': [
'type',
'mask',
"required": [
"type",
"mask",
]
}
# TODO: this isn't really necessary anymore
class LoadImageInvocation(BaseInvocation):
"""Load an image from a filename and provide it as output."""
#fmt: off
"""Load an image and provide it as output."""
# fmt: off
type: Literal["load_image"] = "load_image"
# Inputs
image_type: ImageType = Field(description="The type of the image")
image_name: str = Field(description="The name of the image")
#fmt: on
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
return ImageOutput(
image=ImageField(image_type=self.image_type, image_name=self.image_name)
image = context.services.images.get(self.image_type, self.image_name)
return build_image_output(
image_type=self.image_type,
image_name=self.image_name,
image=image,
)
@ -86,16 +112,17 @@ class ShowImageInvocation(BaseInvocation):
# TODO: how to handle failure?
return ImageOutput(
image=ImageField(
image_type=self.image.image_type, image_name=self.image.image_name
)
return build_image_output(
image_type=self.image.image_type,
image_name=self.image.image_name,
image=image,
)
class CropImageInvocation(BaseInvocation, PILInvocationConfig):
"""Crops an image to a specified box. The box can be outside of the image."""
#fmt: off
# fmt: off
type: Literal["crop"] = "crop"
# Inputs
@ -104,7 +131,7 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
height: int = Field(default=512, gt=0, description="The height of the crop rectangle")
#fmt: on
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
@ -120,15 +147,23 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, image_crop)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, image_crop, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image_crop,
)
class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
"""Pastes an image into another image."""
#fmt: off
# fmt: off
type: Literal["paste"] = "paste"
# Inputs
@ -137,7 +172,7 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
x: int = Field(default=0, description="The left x coordinate at which to paste the image")
y: int = Field(default=0, description="The top y coordinate at which to paste the image")
#fmt: on
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
base_image = context.services.images.get(
@ -170,21 +205,29 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, new_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, new_image, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=new_image,
)
class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
"""Extracts the alpha channel of an image as a mask."""
#fmt: off
# fmt: off
type: Literal["tomask"] = "tomask"
# Inputs
image: ImageField = Field(default=None, description="The image to create the mask from")
invert: bool = Field(default=False, description="Whether or not to invert the mask")
#fmt: on
# fmt: on
def invoke(self, context: InvocationContext) -> MaskOutput:
image = context.services.images.get(
@ -199,22 +242,27 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, image_mask)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, image_mask, metadata)
return MaskOutput(mask=ImageField(image_type=image_type, image_name=image_name))
class BlurInvocation(BaseInvocation, PILInvocationConfig):
"""Blurs an image"""
#fmt: off
# fmt: off
type: Literal["blur"] = "blur"
# Inputs
image: ImageField = Field(default=None, description="The image to blur")
radius: float = Field(default=8.0, ge=0, description="The blur radius")
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
#fmt: on
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
self.image.image_type, self.image.image_name
@ -231,22 +279,28 @@ class BlurInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, blur_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, blur_image, metadata)
return build_image_output(
image_type=image_type, image_name=image_name, image=blur_image
)
class LerpInvocation(BaseInvocation, PILInvocationConfig):
"""Linear interpolation of all pixels of an image"""
#fmt: off
# fmt: off
type: Literal["lerp"] = "lerp"
# Inputs
image: ImageField = Field(default=None, description="The image to lerp")
min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
#fmt: on
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
@ -262,23 +316,29 @@ class LerpInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, lerp_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, lerp_image, metadata)
return build_image_output(
image_type=image_type, image_name=image_name, image=lerp_image
)
class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
"""Inverse linear interpolation of all pixels of an image"""
#fmt: off
# fmt: off
type: Literal["ilerp"] = "ilerp"
# Inputs
image: ImageField = Field(default=None, description="The image to lerp")
min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
#fmt: on
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
self.image.image_type, self.image.image_name
@ -298,7 +358,12 @@ class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, ilerp_image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, ilerp_image, metadata)
return build_image_output(
image_type=image_type, image_name=image_name, image=ilerp_image
)

View File

@ -5,9 +5,9 @@ from typing import Literal, Optional
from pydantic import BaseModel, Field
import torch
from invokeai.app.models.exceptions import CanceledException
from invokeai.app.invocations.util.get_model import choose_model
from invokeai.app.util.step_callback import diffusers_step_callback_adapter
from invokeai.app.invocations.util.choose_model import choose_model
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from ...backend.model_management.model_manager import ModelManager
from ...backend.util.devices import choose_torch_device, torch_dtype
@ -19,7 +19,7 @@ from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationCont
import numpy as np
from ..services.image_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
from .image import ImageField, ImageOutput, build_image_output
from ...backend.stable_diffusion import PipelineIntermediateState
from diffusers.schedulers import SchedulerMixin as Scheduler
import diffusers
@ -31,6 +31,8 @@ class LatentsField(BaseModel):
latents_name: Optional[str] = Field(default=None, description="The name of the latents")
class Config:
schema_extra = {"required": ["latents_name"]}
class LatentsOutput(BaseInvocationOutput):
"""Base class for invocations that output latents"""
@ -170,21 +172,14 @@ class TextToLatentsInvocation(BaseInvocation):
# TODO: pass this an emitter method or something? or a session for dispatching?
def dispatch_progress(
self, context: InvocationContext, intermediate_state: PipelineIntermediateState
self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState
) -> None:
if (context.services.queue.is_canceled(context.graph_execution_state_id)):
raise CanceledException
step = intermediate_state.step
if intermediate_state.predicted_original is not None:
# Some schedulers report not only the noisy latents at the current timestep,
# but also their estimate so far of what the de-noised latents will be.
sample = intermediate_state.predicted_original
else:
sample = intermediate_state.latents
diffusers_step_callback_adapter(sample, step, steps=self.steps, id=self.id, context=context)
stable_diffusion_step_callback(
context=context,
intermediate_state=intermediate_state,
node=self.dict(),
source_node_id=source_node_id,
)
def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
model_info = choose_model(model_manager, self.model)
@ -231,8 +226,12 @@ class TextToLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, state)
self.dispatch_progress(context, source_node_id, state)
model = self.get_model(context.services.model_manager)
conditioning_data = self.get_conditioning_data(model)
@ -281,8 +280,12 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
noise = context.services.latents.get(self.noise.latents_name)
latent = context.services.latents.get(self.latents.latents_name)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, state)
self.dispatch_progress(context, source_node_id, state)
model = self.get_model(context.services.model_manager)
conditioning_data = self.get_conditioning_data(model)
@ -355,7 +358,14 @@ class LatentsToImageInvocation(BaseInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, image, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image
)

View File

@ -1,12 +1,11 @@
from datetime import datetime, timezone
from typing import Literal, Union
from pydantic import Field
from invokeai.app.models.image import ImageField, ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput
from .image import ImageOutput, build_image_output
class RestoreFaceInvocation(BaseInvocation):
"""Restores faces in an image."""
@ -44,7 +43,14 @@ class RestoreFaceInvocation(BaseInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, results[0][0])
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, results[0][0], metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=results[0][0]
)

View File

@ -1,14 +1,12 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from datetime import datetime, timezone
from typing import Literal, Union
from pydantic import Field
from invokeai.app.models.image import ImageField, ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput
from .image import ImageOutput, build_image_output
class UpscaleInvocation(BaseInvocation):
@ -49,7 +47,14 @@ class UpscaleInvocation(BaseInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, results[0][0])
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, results[0][0], metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=results[0][0]
)

View File

@ -1,11 +1,14 @@
from invokeai.app.invocations.baseinvocation import InvocationContext
from invokeai.backend.model_management.model_manager import ModelManager
def choose_model(model_manager: ModelManager, model_name: str):
"""Returns the default model if the `model_name` not a valid model, else returns the selected model."""
if model_manager.valid_model(model_name):
return model_manager.get_model(model_name)
model = model_manager.get_model(model_name)
else:
print(f"* Warning: '{model_name}' is not a valid model name. Using default model instead.")
return model_manager.get_model()
model = model_manager.get_model()
print(
f"* Warning: '{model_name}' is not a valid model name. Using default model \'{model['model_name']}\' instead."
)
return model

View File

@ -9,6 +9,14 @@ class ImageType(str, Enum):
UPLOAD = "uploads"
def is_image_type(obj):
try:
ImageType(obj)
except ValueError:
return False
return True
class ImageField(BaseModel):
"""An image field used for passing image objects between invocations"""
@ -18,9 +26,4 @@ class ImageField(BaseModel):
image_name: Optional[str] = Field(default=None, description="The name of the image")
class Config:
schema_extra = {
"required": [
"image_type",
"image_name",
]
}
schema_extra = {"required": ["image_type", "image_name"]}

View File

@ -1,11 +0,0 @@
from typing import Optional
from pydantic import BaseModel, Field
class ImageMetadata(BaseModel):
"""An image's metadata"""
timestamp: float = Field(description="The creation timestamp of the image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
# TODO: figure out metadata
sd_metadata: Optional[dict] = Field(default={}, description="The image's SD-specific metadata")

View File

@ -1,10 +1,9 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any, Dict, TypedDict
from typing import Any
from invokeai.app.api.models.images import ProgressImage
from invokeai.app.util.misc import get_timestamp
ProgressImage = TypedDict(
"ProgressImage", {"dataURL": str, "width": int, "height": int}
)
class EventServiceBase:
session_event: str = "session_event"
@ -14,7 +13,8 @@ class EventServiceBase:
def dispatch(self, event_name: str, payload: Any) -> None:
pass
def __emit_session_event(self, event_name: str, payload: Dict) -> None:
def __emit_session_event(self, event_name: str, payload: dict) -> None:
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.session_event,
payload=dict(event=event_name, data=payload),
@ -25,7 +25,8 @@ class EventServiceBase:
def emit_generator_progress(
self,
graph_execution_state_id: str,
invocation_id: str,
node: dict,
source_node_id: str,
progress_image: ProgressImage | None,
step: int,
total_steps: int,
@ -35,48 +36,60 @@ class EventServiceBase:
event_name="generator_progress",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
invocation_id=invocation_id,
progress_image=progress_image,
node=node,
source_node_id=source_node_id,
progress_image=progress_image.dict() if progress_image is not None else None,
step=step,
total_steps=total_steps,
),
)
def emit_invocation_complete(
self, graph_execution_state_id: str, invocation_id: str, result: Dict
self,
graph_execution_state_id: str,
result: dict,
node: dict,
source_node_id: str,
) -> None:
"""Emitted when an invocation has completed"""
self.__emit_session_event(
event_name="invocation_complete",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
invocation_id=invocation_id,
node=node,
source_node_id=source_node_id,
result=result,
),
)
def emit_invocation_error(
self, graph_execution_state_id: str, invocation_id: str, error: str
self,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
error: str,
) -> None:
"""Emitted when an invocation has completed"""
self.__emit_session_event(
event_name="invocation_error",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
invocation_id=invocation_id,
node=node,
source_node_id=source_node_id,
error=error,
),
)
def emit_invocation_started(
self, graph_execution_state_id: str, invocation_id: str
self, graph_execution_state_id: str, node: dict, source_node_id: str
) -> None:
"""Emitted when an invocation has started"""
self.__emit_session_event(
event_name="invocation_started",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
invocation_id=invocation_id,
node=node,
source_node_id=source_node_id,
),
)
@ -84,5 +97,7 @@ class EventServiceBase:
"""Emitted when a session has completed all invocations"""
self.__emit_session_event(
event_name="graph_execution_state_complete",
payload=dict(graph_execution_state_id=graph_execution_state_id),
payload=dict(
graph_execution_state_id=graph_execution_state_id,
),
)

View File

@ -1,24 +1,24 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import datetime
import os
from glob import glob
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from queue import Queue
from typing import Callable, Dict, List
from typing import Dict, List, Tuple
from PIL.Image import Image
import PIL.Image as PILImage
from pydantic import BaseModel
from invokeai.app.api.models.images import ImageResponse
from invokeai.app.models.image import ImageField, ImageType
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.api.models.images import ImageResponse, ImageResponseMetadata
from invokeai.app.models.image import ImageType
from invokeai.app.services.metadata import (
InvokeAIMetadata,
MetadataServiceBase,
build_invokeai_metadata_pnginfo,
)
from invokeai.app.services.item_storage import PaginatedResults
from invokeai.app.util.save_thumbnail import save_thumbnail
from invokeai.backend.image_util import PngWriter
from invokeai.app.util.misc import get_timestamp
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
class ImageStorageBase(ABC):
@ -26,12 +26,14 @@ class ImageStorageBase(ABC):
@abstractmethod
def get(self, image_type: ImageType, image_name: str) -> Image:
"""Retrieves an image as PIL Image."""
pass
@abstractmethod
def list(
self, image_type: ImageType, page: int = 0, per_page: int = 10
) -> PaginatedResults[ImageResponse]:
"""Gets a paginated list of images."""
pass
# TODO: make this a bit more flexible for e.g. cloud storage
@ -39,35 +41,51 @@ class ImageStorageBase(ABC):
def get_path(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
) -> str:
"""Gets the path to an image or its thumbnail."""
pass
# TODO: make this a bit more flexible for e.g. cloud storage
@abstractmethod
def validate_path(self, path: str) -> bool:
"""Validates an image path."""
pass
@abstractmethod
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
def save(
self,
image_type: ImageType,
image_name: str,
image: Image,
metadata: InvokeAIMetadata | None = None,
) -> Tuple[str, str, int]:
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image path, thumbnail path, and created timestamp."""
pass
@abstractmethod
def delete(self, image_type: ImageType, image_name: str) -> None:
"""Deletes an image and its thumbnail (if one exists)."""
pass
def create_name(self, context_id: str, node_id: str) -> str:
return f"{context_id}_{node_id}_{str(int(datetime.datetime.now(datetime.timezone.utc).timestamp()))}.png"
"""Creates a unique contextual image filename."""
return f"{context_id}_{node_id}_{str(get_timestamp())}.png"
class DiskImageStorage(ImageStorageBase):
"""Stores images on disk"""
__output_folder: str
__pngWriter: PngWriter
__cache_ids: Queue # TODO: this is an incredibly naive cache
__cache: Dict[str, Image]
__max_cache_size: int
__metadata_service: MetadataServiceBase
def __init__(self, output_folder: str):
def __init__(self, output_folder: str, metadata_service: MetadataServiceBase):
self.__output_folder = output_folder
self.__pngWriter = PngWriter(output_folder)
self.__cache = dict()
self.__cache_ids = Queue()
self.__max_cache_size = 10 # TODO: get this from config
self.__metadata_service = metadata_service
Path(output_folder).mkdir(parents=True, exist_ok=True)
@ -100,6 +118,9 @@ class DiskImageStorage(ImageStorageBase):
for path in page_of_image_paths:
filename = os.path.basename(path)
img = PILImage.open(path)
invokeai_metadata = self.__metadata_service.get_metadata(img)
page_of_images.append(
ImageResponse(
image_type=image_type.value,
@ -107,11 +128,12 @@ class DiskImageStorage(ImageStorageBase):
# TODO: DiskImageStorage should not be building URLs...?
image_url=f"api/v1/images/{image_type.value}/{filename}",
thumbnail_url=f"api/v1/images/{image_type.value}/thumbnails/{os.path.splitext(filename)[0]}.webp",
# TODO: Creation of this object should happen elsewhere, just making it fit here so it works
metadata=ImageMetadata(
timestamp=os.path.getctime(path),
# TODO: Creation of this object should happen elsewhere (?), just making it fit here so it works
metadata=ImageResponseMetadata(
created=int(os.path.getctime(path)),
width=img.width,
height=img.height,
invokeai=invokeai_metadata,
),
)
)
@ -142,26 +164,50 @@ class DiskImageStorage(ImageStorageBase):
def get_path(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
) -> str:
# strip out any relative path shenanigans
basename = os.path.basename(image_name)
if is_thumbnail:
path = os.path.join(
self.__output_folder, image_type, "thumbnails", image_name
self.__output_folder, image_type, "thumbnails", basename
)
else:
path = os.path.join(self.__output_folder, image_type, image_name)
path = os.path.join(self.__output_folder, image_type, basename)
return path
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
image_subpath = os.path.join(image_type, image_name)
self.__pngWriter.save_image_and_prompt_to_png(
image, "", image_subpath, None
) # TODO: just pass full path to png writer
save_thumbnail(
image=image,
filename=image_name,
path=os.path.join(self.__output_folder, image_type, "thumbnails"),
)
def validate_path(self, path: str) -> bool:
try:
os.stat(path)
return True
except Exception:
return False
def save(
self,
image_type: ImageType,
image_name: str,
image: Image,
metadata: InvokeAIMetadata | None = None,
) -> Tuple[str, str, int]:
image_path = self.get_path(image_type, image_name)
# TODO: Reading the image and then saving it strips the metadata...
if metadata:
pnginfo = build_invokeai_metadata_pnginfo(metadata=metadata)
image.save(image_path, "PNG", pnginfo=pnginfo)
else:
image.save(image_path) # this saved image has an empty info
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_type, thumbnail_name, is_thumbnail=True)
thumbnail_image = make_thumbnail(image)
thumbnail_image.save(thumbnail_path)
self.__set_cache(image_path, image)
self.__set_cache(thumbnail_path, thumbnail_image)
return (image_path, thumbnail_path, int(os.path.getctime(image_path)))
def delete(self, image_type: ImageType, image_name: str) -> None:
image_path = self.get_path(image_type, image_name)

View File

@ -1,4 +1,5 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from invokeai.app.services.metadata import MetadataServiceBase
from invokeai.backend import ModelManager
from .events import EventServiceBase
@ -14,6 +15,7 @@ class InvocationServices:
events: EventServiceBase
latents: LatentsStorageBase
images: ImageStorageBase
metadata: MetadataServiceBase
queue: InvocationQueueABC
model_manager: ModelManager
restoration: RestorationServices
@ -29,6 +31,7 @@ class InvocationServices:
events: EventServiceBase,
latents: LatentsStorageBase,
images: ImageStorageBase,
metadata: MetadataServiceBase,
queue: InvocationQueueABC,
graph_library: ItemStorageABC["LibraryGraph"],
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
@ -39,6 +42,7 @@ class InvocationServices:
self.events = events
self.latents = latents
self.images = images
self.metadata = metadata
self.queue = queue
self.graph_library = graph_library
self.graph_execution_manager = graph_execution_manager

View File

@ -0,0 +1,96 @@
import json
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, TypedDict
from PIL import Image, PngImagePlugin
from pydantic import BaseModel
from invokeai.app.models.image import ImageType, is_image_type
class MetadataImageField(TypedDict):
"""Pydantic-less ImageField, used for metadata parsing."""
image_type: ImageType
image_name: str
class MetadataLatentsField(TypedDict):
"""Pydantic-less LatentsField, used for metadata parsing."""
latents_name: str
# TODO: This is a placeholder for `InvocationsUnion` pending resolution of circular imports
NodeMetadata = Dict[
str, str | int | float | bool | MetadataImageField | MetadataLatentsField
]
class InvokeAIMetadata(TypedDict, total=False):
"""InvokeAI-specific metadata format."""
session_id: Optional[str]
node: Optional[NodeMetadata]
def build_invokeai_metadata_pnginfo(
metadata: InvokeAIMetadata | None,
) -> PngImagePlugin.PngInfo:
"""Builds a PngInfo object with key `"invokeai"` and value `metadata`"""
pnginfo = PngImagePlugin.PngInfo()
if metadata is not None:
pnginfo.add_text("invokeai", json.dumps(metadata))
return pnginfo
class MetadataServiceBase(ABC):
@abstractmethod
def get_metadata(self, image: Image.Image) -> InvokeAIMetadata | None:
"""Gets the InvokeAI metadata from a PIL Image, skipping invalid values"""
pass
@abstractmethod
def build_metadata(
self, session_id: str, node: BaseModel
) -> InvokeAIMetadata | None:
"""Builds an InvokeAIMetadata object"""
pass
class PngMetadataService(MetadataServiceBase):
"""Handles loading and building metadata for images."""
# TODO: Use `InvocationsUnion` to **validate** metadata as representing a fully-functioning node
def _load_metadata(self, image: Image.Image) -> dict | None:
"""Loads a specific info entry from a PIL Image."""
try:
info = image.info.get("invokeai")
if type(info) is not str:
return None
loaded_metadata = json.loads(info)
if type(loaded_metadata) is not dict:
return None
if len(loaded_metadata.items()) == 0:
return None
return loaded_metadata
except:
return None
def get_metadata(self, image: Image.Image) -> dict | None:
"""Retrieves an image's metadata as a dict"""
loaded_metadata = self._load_metadata(image)
return loaded_metadata
def build_metadata(self, session_id: str, node: BaseModel) -> InvokeAIMetadata:
metadata = InvokeAIMetadata(session_id=session_id, node=node.dict())
return metadata

View File

@ -43,10 +43,14 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
queue_item.invocation_id
)
# get the source node id to provide to clients (the prepared node id is not as useful)
source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]
# Send starting event
self.__invoker.services.events.emit_invocation_started(
graph_execution_state_id=graph_execution_state.id,
invocation_id=invocation.id,
node=invocation.dict(),
source_node_id=source_node_id
)
# Invoke
@ -75,7 +79,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Send complete event
self.__invoker.services.events.emit_invocation_complete(
graph_execution_state_id=graph_execution_state.id,
invocation_id=invocation.id,
node=invocation.dict(),
source_node_id=source_node_id,
result=outputs.dict(),
)
@ -99,7 +104,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Send error event
self.__invoker.services.events.emit_invocation_error(
graph_execution_state_id=graph_execution_state.id,
invocation_id=invocation.id,
node=invocation.dict(),
source_node_id=source_node_id,
error=error,
)

View File

@ -35,7 +35,8 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._create_table()
def _create_table(self):
with self._lock:
try:
self._lock.acquire()
self._cursor.execute(
f"""CREATE TABLE IF NOT EXISTS {self._table_name} (
item TEXT,
@ -44,27 +45,34 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._cursor.execute(
f"""CREATE UNIQUE INDEX IF NOT EXISTS {self._table_name}_id ON {self._table_name}(id);"""
)
self._conn.commit()
finally:
self._lock.release()
def _parse_item(self, item: str) -> T:
item_type = get_args(self.__orig_class__)[0]
return parse_raw_as(item_type, item)
def set(self, item: T):
with self._lock:
try:
self._lock.acquire()
self._cursor.execute(
f"""INSERT OR REPLACE INTO {self._table_name} (item) VALUES (?);""",
(item.json(),),
)
self._conn.commit()
finally:
self._lock.release()
self._on_changed(item)
def get(self, id: str) -> Union[T, None]:
with self._lock:
try:
self._lock.acquire()
self._cursor.execute(
f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
)
result = self._cursor.fetchone()
finally:
self._lock.release()
if not result:
return None
@ -72,15 +80,19 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
return self._parse_item(result[0])
def delete(self, id: str):
with self._lock:
try:
self._lock.acquire()
self._cursor.execute(
f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),)
)
self._conn.commit()
finally:
self._lock.release()
self._on_deleted(id)
def list(self, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
with self._lock:
try:
self._lock.acquire()
self._cursor.execute(
f"""SELECT item FROM {self._table_name} LIMIT ? OFFSET ?;""",
(per_page, page * per_page),
@ -91,6 +103,8 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
self._cursor.execute(f"""SELECT count(*) FROM {self._table_name};""")
count = self._cursor.fetchone()[0]
finally:
self._lock.release()
pageCount = int(count / per_page) + 1
@ -101,7 +115,8 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
def search(
self, query: str, page: int = 0, per_page: int = 10
) -> PaginatedResults[T]:
with self._lock:
try:
self._lock.acquire()
self._cursor.execute(
f"""SELECT item FROM {self._table_name} WHERE item LIKE ? LIMIT ? OFFSET ?;""",
(f"%{query}%", per_page, page * per_page),
@ -115,6 +130,8 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
(f"%{query}%",),
)
count = self._cursor.fetchone()[0]
finally:
self._lock.release()
pageCount = int(count / per_page) + 1

View File

@ -0,0 +1,5 @@
import datetime
def get_timestamp():
return int(datetime.datetime.now(datetime.timezone.utc).timestamp())

View File

@ -1,25 +0,0 @@
import os
from PIL import Image
def save_thumbnail(
image: Image.Image,
filename: str,
path: str,
size: int = 256,
) -> str:
"""
Saves a thumbnail of an image, returning its path.
"""
base_filename = os.path.splitext(filename)[0]
thumbnail_path = os.path.join(path, base_filename + ".webp")
if os.path.exists(thumbnail_path):
return thumbnail_path
image_copy = image.copy()
image_copy.thumbnail(size=(size, size))
image_copy.save(thumbnail_path, "WEBP")
return thumbnail_path

View File

@ -1,16 +1,41 @@
import torch
from invokeai.app.api.models.images import ProgressImage
from invokeai.app.models.exceptions import CanceledException
from ..invocations.baseinvocation import InvocationContext
from ...backend.util.util import image_to_dataURL
from ...backend.generator.base import Generator
from ...backend.stable_diffusion import PipelineIntermediateState
def fast_latents_step_callback(
sample: torch.Tensor,
step: int,
steps: int,
id: str,
def stable_diffusion_step_callback(
context: InvocationContext,
intermediate_state: PipelineIntermediateState,
node: dict,
source_node_id: str,
):
if context.services.queue.is_canceled(context.graph_execution_state_id):
raise CanceledException
# Some schedulers report not only the noisy latents at the current timestep,
# but also their estimate so far of what the de-noised latents will be. Use
# that estimate if it is available.
if intermediate_state.predicted_original is not None:
sample = intermediate_state.predicted_original
else:
sample = intermediate_state.latents
# TODO: This does not seem to be needed any more?
# # txt2img provides a Tensor in the step_callback
# # img2img provides a PipelineIntermediateState
# if isinstance(sample, PipelineIntermediateState):
# # this was an img2img
# print('img2img')
# latents = sample.latents
# step = sample.step
# else:
# print('txt2img')
# latents = sample
# step = intermediate_state.step
# TODO: only output a preview image when requested
image = Generator.sample_to_lowres_estimated_image(sample)
@ -21,23 +46,10 @@ def fast_latents_step_callback(
dataURL = image_to_dataURL(image, image_format="JPEG")
context.services.events.emit_generator_progress(
context.graph_execution_state_id,
id,
{"width": width, "height": height, "dataURL": dataURL},
step,
steps,
graph_execution_state_id=context.graph_execution_state_id,
node=node,
source_node_id=source_node_id,
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
step=intermediate_state.step,
total_steps=node["steps"],
)
def diffusers_step_callback_adapter(*cb_args, **kwargs):
"""
txt2img gives us a Tensor in the step_callbak, while img2img gives us a PipelineIntermediateState.
This adapter grabs the needed data and passes it along to the callback function.
"""
if isinstance(cb_args[0], PipelineIntermediateState):
progress_state: PipelineIntermediateState = cb_args[0]
return fast_latents_step_callback(
progress_state.latents, progress_state.step, **kwargs
)
else:
return fast_latents_step_callback(*cb_args, **kwargs)

View File

@ -0,0 +1,15 @@
import os
from PIL import Image
def get_thumbnail_name(image_name: str) -> str:
"""Formats given an image name, returns the appropriate thumbnail image name"""
thumbnail_name = os.path.splitext(image_name)[0] + ".webp"
return thumbnail_name
def make_thumbnail(image: Image.Image, size: int = 256) -> Image.Image:
"""Makes a thumbnail from a PIL Image"""
thumbnail = image.copy()
thumbnail.thumbnail(size=(size, size))
return thumbnail