Merge branch 'main' into release/make-web-dist-startable

This commit is contained in:
Lincoln Stein 2023-05-25 19:06:09 -04:00 committed by GitHub
commit 9110838fe4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
188 changed files with 3446 additions and 4400 deletions

View File

@ -125,6 +125,7 @@ jobs:
--no-nsfw_checker
--precision=float32
--always_use_cpu
--use_memory_db
--outdir ${{ env.INVOKEAI_OUTDIR }}/${{ matrix.python-version }}/${{ matrix.pytorch }}
--from_file ${{ env.TEST_PROMPTS }}

View File

@ -216,7 +216,7 @@ manager, please follow these steps:
9. Run the command-line- or the web- interface:
From within INVOKEAI_ROOT, activate the environment
(with `source .venv/bin/activate` or `.venv\scripts\activate), and then run
(with `source .venv/bin/activate` or `.venv\scripts\activate`), and then run
the script `invokeai`. If the virtual environment you selected is NOT inside
INVOKEAI_ROOT, then you must specify the path to the root directory by adding
`--root_dir \path\to\invokeai` to the commands below:

View File

@ -1,22 +1,24 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from logging import Logger
import os
import invokeai.backend.util.logging as logger
from typing import types
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService
from invokeai.app.services.metadata import CoreMetadataService
from invokeai.app.services.urls import LocalUrlService
from invokeai.backend.util.logging import InvokeAILogger
from ..services.default_graphs import create_system_graphs
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from ..services.model_manager_initializer import get_model_manager
from ..services.restoration_services import RestorationServices
from ..services.graph import GraphExecutionState, LibraryGraph
from ..services.image_storage import DiskImageStorage
from ..services.image_file_storage import DiskImageFileStorage
from ..services.invocation_queue import MemoryInvocationQueue
from ..services.invocation_services import InvocationServices
from ..services.invoker import Invoker
from ..services.processor import DefaultInvocationProcessor
from ..services.sqlite import SqliteItemStorage
from ..services.metadata import PngMetadataService
from .events import FastAPIEventService
@ -36,42 +38,59 @@ def check_internet() -> bool:
return False
logger = InvokeAILogger.getLogger()
class ApiDependencies:
"""Contains and initializes all dependencies for the API"""
invoker: Invoker = None
def initialize(config, event_handler_id: int, logger: types.ModuleType=logger):
@staticmethod
def initialize(config, event_handler_id: int, logger: Logger = logger):
logger.info(f"Internet connectivity is {config.internet_available}")
events = FastAPIEventService(event_handler_id)
output_folder = config.output_path
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents'))
metadata = PngMetadataService()
images = DiskImageStorage(f'{output_folder}/images', metadata_service=metadata)
# TODO: build a file/path manager?
db_location = os.path.join(output_folder, "invokeai.db")
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
)
urls = LocalUrlService()
metadata = CoreMetadataService()
image_record_storage = SqliteImageRecordStorage(db_location)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
latents = ForwardCacheLatentsStorage(
DiskLatentsStorage(f"{output_folder}/latents")
)
images = ImageService(
image_record_storage=image_record_storage,
image_file_storage=image_file_storage,
metadata=metadata,
url=urls,
logger=logger,
graph_execution_manager=graph_execution_manager,
)
services = InvocationServices(
model_manager=get_model_manager(config,logger),
model_manager=get_model_manager(config, logger),
events=events,
latents=latents,
images=images,
metadata=metadata,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=db_location, table_name="graphs"
),
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
),
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
restoration=RestorationServices(config,logger),
restoration=RestorationServices(config, logger),
configuration=config,
logger=logger,
)

View File

@ -2,7 +2,6 @@ from typing import Optional
from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageType
from invokeai.app.services.metadata import InvokeAIMetadata
class ImageResponseMetadata(BaseModel):
@ -11,9 +10,9 @@ class ImageResponseMetadata(BaseModel):
created: int = Field(description="The creation timestamp of the image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
invokeai: Optional[InvokeAIMetadata] = Field(
description="The image's InvokeAI-specific metadata"
)
# invokeai: Optional[InvokeAIMetadata] = Field(
# description="The image's InvokeAI-specific metadata"
# )
class ImageResponse(BaseModel):

View File

@ -1,148 +1,215 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import io
from datetime import datetime, timezone
import json
import os
from typing import Any
import uuid
from fastapi import Body, HTTPException, Path, Query, Request, UploadFile
from fastapi.responses import FileResponse, Response
from fastapi import HTTPException, Path, Query, Request, Response, UploadFile
from fastapi.routing import APIRouter
from fastapi.responses import FileResponse
from PIL import Image
from invokeai.app.api.models.images import (
ImageResponse,
ImageResponseMetadata,
from invokeai.app.models.image import (
ImageCategory,
ImageType,
)
from invokeai.app.services.models.image_record import ImageDTO, ImageUrlsDTO
from invokeai.app.services.item_storage import PaginatedResults
from ...services.image_storage import ImageType
from ..dependencies import ApiDependencies
images_router = APIRouter(prefix="/v1/images", tags=["images"])
@images_router.get("/{image_type}/{image_name}", operation_id="get_image")
async def get_image(
image_type: ImageType = Path(description="The type of image to get"),
image_name: str = Path(description="The name of the image to get"),
) -> FileResponse:
"""Gets an image"""
path = ApiDependencies.invoker.services.images.get_path(
image_type=image_type, image_name=image_name
)
if ApiDependencies.invoker.services.images.validate_path(path):
return FileResponse(path)
else:
raise HTTPException(status_code=404)
@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
async def delete_image(
image_type: ImageType = Path(description="The type of image to delete"),
image_name: str = Path(description="The name of the image to delete"),
) -> None:
"""Deletes an image and its thumbnail"""
ApiDependencies.invoker.services.images.delete(
image_type=image_type, image_name=image_name
)
@images_router.get(
"/{thumbnail_type}/thumbnails/{thumbnail_name}", operation_id="get_thumbnail"
)
async def get_thumbnail(
thumbnail_type: ImageType = Path(description="The type of thumbnail to get"),
thumbnail_name: str = Path(description="The name of the thumbnail to get"),
) -> FileResponse | Response:
"""Gets a thumbnail"""
path = ApiDependencies.invoker.services.images.get_path(
image_type=thumbnail_type, image_name=thumbnail_name, is_thumbnail=True
)
if ApiDependencies.invoker.services.images.validate_path(path):
return FileResponse(path)
else:
raise HTTPException(status_code=404)
@images_router.post(
"/uploads/",
"/",
operation_id="upload_image",
responses={
201: {
"description": "The image was uploaded successfully",
"model": ImageResponse,
},
201: {"description": "The image was uploaded successfully"},
415: {"description": "Image upload failed"},
},
status_code=201,
response_model=ImageDTO,
)
async def upload_image(
file: UploadFile, image_type: ImageType, request: Request, response: Response
) -> ImageResponse:
file: UploadFile,
image_type: ImageType,
request: Request,
response: Response,
image_category: ImageCategory = ImageCategory.GENERAL,
) -> ImageDTO:
"""Uploads an image"""
if not file.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
contents = await file.read()
try:
img = Image.open(io.BytesIO(contents))
pil_image = Image.open(io.BytesIO(contents))
except:
# Error opening the image
raise HTTPException(status_code=415, detail="Failed to read image")
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
try:
image_dto = ApiDependencies.invoker.services.images.create(
pil_image,
image_type,
image_category,
)
saved_image = ApiDependencies.invoker.services.images.save(
image_type, filename, img
)
response.status_code = 201
response.headers["Location"] = image_dto.image_url
invokeai_metadata = ApiDependencies.invoker.services.metadata.get_metadata(img)
return image_dto
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to create image")
image_url = ApiDependencies.invoker.services.images.get_uri(
image_type, saved_image.image_name
)
thumbnail_url = ApiDependencies.invoker.services.images.get_uri(
image_type, saved_image.image_name, True
)
@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
async def delete_image(
image_type: ImageType = Query(description="The type of image to delete"),
image_name: str = Path(description="The name of the image to delete"),
) -> None:
"""Deletes an image"""
res = ImageResponse(
image_type=image_type,
image_name=saved_image.image_name,
image_url=image_url,
thumbnail_url=thumbnail_url,
metadata=ImageResponseMetadata(
created=saved_image.created,
width=img.width,
height=img.height,
invokeai=invokeai_metadata,
),
)
try:
ApiDependencies.invoker.services.images.delete(image_type, image_name)
except Exception as e:
# TODO: Does this need any exception handling at all?
pass
response.status_code = 201
response.headers["Location"] = image_url
return res
@images_router.get(
"/{image_type}/{image_name}/metadata",
operation_id="get_image_metadata",
response_model=ImageDTO,
)
async def get_image_metadata(
image_type: ImageType = Path(description="The type of image to get"),
image_name: str = Path(description="The name of image to get"),
) -> ImageDTO:
"""Gets an image's metadata"""
try:
return ApiDependencies.invoker.services.images.get_dto(
image_type, image_name
)
except Exception as e:
raise HTTPException(status_code=404)
@images_router.get(
"/{image_type}/{image_name}",
operation_id="get_image_full",
response_class=Response,
responses={
200: {
"description": "Return the full-resolution image",
"content": {"image/png": {}},
},
404: {"description": "Image not found"},
},
)
async def get_image_full(
image_type: ImageType = Path(
description="The type of full-resolution image file to get"
),
image_name: str = Path(description="The name of full-resolution image file to get"),
) -> FileResponse:
"""Gets a full-resolution image file"""
try:
path = ApiDependencies.invoker.services.images.get_path(
image_type, image_name
)
if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404)
return FileResponse(
path,
media_type="image/png",
filename=image_name,
content_disposition_type="inline",
)
except Exception as e:
raise HTTPException(status_code=404)
@images_router.get(
"/{image_type}/{image_name}/thumbnail",
operation_id="get_image_thumbnail",
response_class=Response,
responses={
200: {
"description": "Return the image thumbnail",
"content": {"image/webp": {}},
},
404: {"description": "Image not found"},
},
)
async def get_image_thumbnail(
image_type: ImageType = Path(description="The type of thumbnail image file to get"),
image_name: str = Path(description="The name of thumbnail image file to get"),
) -> FileResponse:
"""Gets a thumbnail image file"""
try:
path = ApiDependencies.invoker.services.images.get_path(
image_type, image_name, thumbnail=True
)
if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404)
return FileResponse(
path, media_type="image/webp", content_disposition_type="inline"
)
except Exception as e:
raise HTTPException(status_code=404)
@images_router.get(
"/{image_type}/{image_name}/urls",
operation_id="get_image_urls",
response_model=ImageUrlsDTO,
)
async def get_image_urls(
image_type: ImageType = Path(description="The type of the image whose URL to get"),
image_name: str = Path(description="The name of the image whose URL to get"),
) -> ImageUrlsDTO:
"""Gets an image and thumbnail URL"""
try:
image_url = ApiDependencies.invoker.services.images.get_url(
image_type, image_name
)
thumbnail_url = ApiDependencies.invoker.services.images.get_url(
image_type, image_name, thumbnail=True
)
return ImageUrlsDTO(
image_type=image_type,
image_name=image_name,
image_url=image_url,
thumbnail_url=thumbnail_url,
)
except Exception as e:
raise HTTPException(status_code=404)
@images_router.get(
"/",
operation_id="list_images",
responses={200: {"model": PaginatedResults[ImageResponse]}},
operation_id="list_images_with_metadata",
response_model=PaginatedResults[ImageDTO],
)
async def list_images(
image_type: ImageType = Query(
default=ImageType.RESULT, description="The type of images to get"
async def list_images_with_metadata(
image_type: ImageType = Query(description="The type of images to list"),
image_category: ImageCategory = Query(description="The kind of images to list"),
page: int = Query(default=0, description="The page of image metadata to get"),
per_page: int = Query(
default=10, description="The number of image metadata per page"
),
page: int = Query(default=0, description="The page of images to get"),
per_page: int = Query(default=10, description="The number of images per page"),
) -> PaginatedResults[ImageResponse]:
"""Gets a list of images"""
result = ApiDependencies.invoker.services.images.list(image_type, page, per_page)
return result
) -> PaginatedResults[ImageDTO]:
"""Gets a list of images with metadata"""
image_dtos = ApiDependencies.invoker.services.images.get_many(
image_type,
image_category,
page,
per_page,
)
return image_dtos

View File

@ -3,8 +3,8 @@ import asyncio
from inspect import signature
import uvicorn
import invokeai.backend.util.logging as logger
import invokeai.frontend.web as web_dir
from invokeai.backend.util.logging import InvokeAILogger
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
@ -16,11 +16,13 @@ from pathlib import Path
from pydantic.schema import schema
from .api.dependencies import ApiDependencies
from .api.routers import images, sessions, models
from .api.routers import sessions, models, images
from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation
from .services.config import InvokeAIAppConfig
logger = InvokeAILogger.getLogger()
# Create the app
# TODO: create this all in a method so configuration/etc. can be passed in?
app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None)
@ -71,10 +73,9 @@ async def shutdown_event():
app.include_router(sessions.session_router, prefix="/api")
app.include_router(images.images_router, prefix="/api")
app.include_router(models.models_router, prefix="/api")
app.include_router(images.images_router, prefix="/api")
# Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow?
@ -123,6 +124,7 @@ app.openapi = custom_openapi
# Override API doc favicons
app.mount("/static", StaticFiles(directory=Path(web_dir.__path__[0], 'static/dream_web')), name="static")
@app.get("/docs", include_in_schema=False)
def overridden_swagger():
return get_swagger_ui_html(
@ -140,8 +142,13 @@ def overridden_redoc():
redoc_favicon_url="/static/favicon.ico",
)
# Must mount *after* the other routes else it borks em
app.mount("/", StaticFiles(directory=Path(web_dir.__path__[0],"dist"), html=True), name="ui")
app.mount("/",
StaticFiles(directory=Path(web_dir.__path__[0],"dist"),
html=True
), name="ui"
)
def invoke_api():
# Start our own event loop for eventing usage

View File

@ -13,10 +13,13 @@ from typing import (
from pydantic import BaseModel, ValidationError
from pydantic.fields import Field
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService
from invokeai.app.services.metadata import CoreMetadataService
from invokeai.app.services.urls import LocalUrlService
import invokeai.backend.util.logging as logger
from invokeai.app.services.metadata import PngMetadataService
from .services.default_graphs import create_system_graphs
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
@ -28,7 +31,7 @@ from .services.model_manager_initializer import get_model_manager
from .services.restoration_services import RestorationServices
from .services.graph import Edge, EdgeConnection, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
from .services.default_graphs import default_text_to_image_graph_id
from .services.image_storage import DiskImageStorage
from .services.image_file_storage import DiskImageFileStorage
from .services.invocation_queue import MemoryInvocationQueue
from .services.invocation_services import InvocationServices
from .services.invoker import Invoker
@ -188,6 +191,9 @@ def invoke_all(context: CliContext):
raise SessionError()
logger = logger.InvokeAILogger.getLogger()
def invoke_cli():
# this gets the basic configuration
config = get_invokeai_config()
@ -206,24 +212,43 @@ def invoke_cli():
events = EventServiceBase()
output_folder = config.output_path
metadata = PngMetadataService()
# TODO: build a file/path manager?
db_location = os.path.join(output_folder, "invokeai.db")
if config.use_memory_db:
db_location = ":memory:"
else:
db_location = os.path.join(output_folder, "invokeai.db")
logger.info(f'InvokeAI database location is "{db_location}"')
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
)
urls = LocalUrlService()
metadata = CoreMetadataService()
image_record_storage = SqliteImageRecordStorage(db_location)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
images = ImageService(
image_record_storage=image_record_storage,
image_file_storage=image_file_storage,
metadata=metadata,
url=urls,
logger=logger,
graph_execution_manager=graph_execution_manager,
)
services = InvocationServices(
model_manager=model_manager,
events=events,
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
images=DiskImageStorage(f'{output_folder}/images', metadata_service=metadata),
metadata=metadata,
images=images,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=db_location, table_name="graphs"
),
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
),
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
restoration=RestorationServices(config,logger=logger),
logger=logger,

View File

@ -1,12 +1,15 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from __future__ import annotations
from abc import ABC, abstractmethod
from inspect import signature
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict, TYPE_CHECKING
from pydantic import BaseModel, Field
from ..services.invocation_services import InvocationServices
if TYPE_CHECKING:
from ..services.invocation_services import InvocationServices
class InvocationContext:

View File

@ -1,9 +1,9 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from typing import Literal, Optional
from typing import Literal
import numpy as np
from pydantic import Field
from pydantic import Field, validator
from invokeai.app.util.misc import SEED_MAX, get_random_seed
@ -24,7 +24,7 @@ class IntCollectionOutput(BaseInvocationOutput):
class RangeInvocation(BaseInvocation):
"""Creates a range"""
"""Creates a range of numbers from start to stop with step"""
type: Literal["range"] = "range"
@ -33,12 +33,34 @@ class RangeInvocation(BaseInvocation):
stop: int = Field(default=10, description="The stop of the range")
step: int = Field(default=1, description="The step of the range")
@validator("stop")
def stop_gt_start(cls, v, values):
if "start" in values and v <= values["start"]:
raise ValueError("stop must be greater than start")
return v
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
return IntCollectionOutput(
collection=list(range(self.start, self.stop, self.step))
)
class RangeOfSizeInvocation(BaseInvocation):
"""Creates a range from start to start + size with step"""
type: Literal["range_of_size"] = "range_of_size"
# Inputs
start: int = Field(default=0, description="The start of the range")
size: int = Field(default=1, description="The number of values")
step: int = Field(default=1, description="The step of the range")
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
return IntCollectionOutput(
collection=list(range(self.start, self.start + self.size, self.step))
)
class RandomRangeInvocation(BaseInvocation):
"""Creates a collection of random numbers"""

View File

@ -118,7 +118,7 @@ class CompelInvocation(BaseInvocation):
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
# TODO: hacky but works ;D maybe rename latents somehow?
context.services.latents.set(conditioning_name, (c, ec))
context.services.latents.save(conditioning_name, (c, ec))
return CompelOutput(
conditioning=ConditioningField(

View File

@ -7,9 +7,9 @@ import numpy
from PIL import Image, ImageOps
from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageField, ImageType
from invokeai.app.models.image import ImageCategory, ImageField, ImageType
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output
from .image import ImageOutput
class CvInvocationConfig(BaseModel):
@ -26,24 +26,27 @@ class CvInvocationConfig(BaseModel):
class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
"""Simple inpaint using opencv."""
#fmt: off
# fmt: off
type: Literal["cv_inpaint"] = "cv_inpaint"
# Inputs
image: ImageField = Field(default=None, description="The image to inpaint")
mask: ImageField = Field(default=None, description="The mask to use when inpainting")
#fmt: on
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
mask = context.services.images.get(self.mask.image_type, self.mask.image_name)
mask = context.services.images.get_pil_image(
self.mask.image_type, self.mask.image_name
)
# Convert to cv image/mask
# TODO: consider making these utility functions
cv_image = cv.cvtColor(numpy.array(image.convert("RGB")), cv.COLOR_RGB2BGR)
cv_mask = numpy.array(ImageOps.invert(mask))
cv_mask = numpy.array(ImageOps.invert(mask.convert("L")))
# Inpaint
cv_inpainted = cv.inpaint(cv_image, cv_mask, 3, cv.INPAINT_TELEA)
@ -52,18 +55,19 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
# TODO: consider making a utility function
image_inpainted = Image.fromarray(cv.cvtColor(cv_inpainted, cv.COLOR_BGR2RGB))
image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
image_dto = context.services.images.create(
image=image_inpainted,
image_type=ImageType.INTERMEDIATE,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
),
width=image_dto.width,
height=image_dto.height,
)
context.services.images.save(image_type, image_name, image_inpainted, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image_inpainted,
)

View File

@ -10,17 +10,21 @@ from pydantic import BaseModel, Field
from invokeai.app.models.image import ColorField, ImageField, ImageType
from invokeai.app.invocations.util.choose_model import choose_model
from invokeai.app.models.image import ImageCategory, ImageType
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.generator.inpaint import infill_methods
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output
from .image import ImageOutput
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
from ...backend.stable_diffusion import PipelineIntermediateState
from ..util.step_callback import stable_diffusion_step_callback
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
INFILL_METHODS = Literal[tuple(infill_methods())]
DEFAULT_INFILL_METHOD = 'patchmatch' if 'patchmatch' in get_args(INFILL_METHODS) else 'tile'
DEFAULT_INFILL_METHOD = (
"patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
)
class SDImageInvocation(BaseModel):
"""Helper class to provide all Stable Diffusion raster image invocations with additional config"""
@ -91,25 +95,21 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
# each time it is called. We only need the first one.
generate_output = next(outputs)
# Results are image and seed, unwrap for now and ignore the seed
# TODO: pre-seed?
# TODO: can this return multiple results? Should it?
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(
image_type, image_name, generate_output.image, metadata
)
return build_image_output(
image_type=image_type,
image_name=image_name,
image_dto = context.services.images.create(
image=generate_output.image,
image_type=ImageType.RESULT,
image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id,
node_id=self.id,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
),
width=image_dto.width,
height=image_dto.height,
)
@ -145,7 +145,7 @@ class ImageToImageInvocation(TextToImageInvocation):
image = (
None
if self.image is None
else context.services.images.get(
else context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
)
@ -175,26 +175,23 @@ class ImageToImageInvocation(TextToImageInvocation):
# each time it is called. We only need the first one.
generator_output = next(outputs)
result_image = generator_output.image
# Results are image and seed, unwrap for now and ignore the seed
# TODO: pre-seed?
# TODO: can this return multiple results? Should it?
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
image_dto = context.services.images.create(
image=generator_output.image,
image_type=ImageType.RESULT,
image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id,
node_id=self.id,
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
),
width=image_dto.width,
height=image_dto.height,
)
context.services.images.save(image_type, image_name, result_image, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=result_image,
)
class InpaintInvocation(ImageToImageInvocation):
"""Generates an image using inpaint."""
@ -204,16 +201,38 @@ class InpaintInvocation(ImageToImageInvocation):
# Inputs
mask: Union[ImageField, None] = Field(description="The mask")
seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
seam_blur: int = Field(default=16, ge=0, description="The seam inpaint blur radius (px)")
seam_blur: int = Field(
default=16, ge=0, description="The seam inpaint blur radius (px)"
)
seam_strength: float = Field(
default=0.75, gt=0, le=1, description="The seam inpaint strength"
)
seam_steps: int = Field(default=30, ge=1, description="The number of steps to use for seam inpaint")
tile_size: int = Field(default=32, ge=1, description="The tile infill method size (px)")
infill_method: INFILL_METHODS = Field(default=DEFAULT_INFILL_METHOD, description="The method used to infill empty regions (px)")
inpaint_width: Optional[int] = Field(default=None, multiple_of=8, gt=0, description="The width of the inpaint region (px)")
inpaint_height: Optional[int] = Field(default=None, multiple_of=8, gt=0, description="The height of the inpaint region (px)")
inpaint_fill: Optional[ColorField] = Field(default=ColorField(r=127, g=127, b=127, a=255), description="The solid infill method color")
seam_steps: int = Field(
default=30, ge=1, description="The number of steps to use for seam inpaint"
)
tile_size: int = Field(
default=32, ge=1, description="The tile infill method size (px)"
)
infill_method: INFILL_METHODS = Field(
default=DEFAULT_INFILL_METHOD,
description="The method used to infill empty regions (px)",
)
inpaint_width: Optional[int] = Field(
default=None,
multiple_of=8,
gt=0,
description="The width of the inpaint region (px)",
)
inpaint_height: Optional[int] = Field(
default=None,
multiple_of=8,
gt=0,
description="The height of the inpaint region (px)",
)
inpaint_fill: Optional[ColorField] = Field(
default=ColorField(r=127, g=127, b=127, a=255),
description="The solid infill method color",
)
inpaint_replace: float = Field(
default=0.0,
ge=0.0,
@ -238,14 +257,14 @@ class InpaintInvocation(ImageToImageInvocation):
image = (
None
if self.image is None
else context.services.images.get(
else context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
)
mask = (
None
if self.mask is None
else context.services.images.get(self.mask.image_type, self.mask.image_name)
else context.services.images.get_pil_image(self.mask.image_type, self.mask.image_name)
)
# Handle invalid model parameter
@ -271,23 +290,19 @@ class InpaintInvocation(ImageToImageInvocation):
# each time it is called. We only need the first one.
generator_output = next(outputs)
result_image = generator_output.image
# Results are image and seed, unwrap for now and ignore the seed
# TODO: pre-seed?
# TODO: can this return multiple results? Should it?
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
image_dto = context.services.images.create(
image=generator_output.image,
image_type=ImageType.RESULT,
image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id,
node_id=self.id,
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, result_image, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=result_image,
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
),
width=image_dto.width,
height=image_dto.height,
)

View File

@ -1,13 +1,13 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import io
from typing import Literal, Optional
from typing import Literal, Optional, Union
import numpy
from PIL import Image, ImageFilter, ImageOps
from PIL import Image, ImageFilter, ImageOps, ImageChops
from pydantic import BaseModel, Field
from ..models.image import ImageField, ImageType
from ..models.image import ImageCategory, ImageField, ImageType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@ -31,7 +31,7 @@ class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
# fmt: off
type: Literal["image"] = "image"
type: Literal["image_output"] = "image_output"
image: ImageField = Field(default=None, description="The output image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
@ -41,27 +41,14 @@ class ImageOutput(BaseInvocationOutput):
schema_extra = {"required": ["type", "image", "width", "height"]}
def build_image_output(
image_type: ImageType, image_name: str, image: Image.Image
) -> ImageOutput:
"""Builds an ImageOutput and its ImageField"""
image_field = ImageField(
image_name=image_name,
image_type=image_type,
)
return ImageOutput(
image=image_field,
width=image.width,
height=image.height,
)
class MaskOutput(BaseInvocationOutput):
"""Base class for invocations that output a mask"""
# fmt: off
type: Literal["mask"] = "mask"
mask: ImageField = Field(default=None, description="The output mask")
width: int = Field(description="The width of the mask in pixels")
height: int = Field(description="The height of the mask in pixels")
# fmt: on
class Config:
@ -80,16 +67,20 @@ class LoadImageInvocation(BaseInvocation):
type: Literal["load_image"] = "load_image"
# Inputs
image_type: ImageType = Field(description="The type of the image")
image_name: str = Field(description="The name of the image")
image: Union[ImageField, None] = Field(
default=None, description="The image to load"
)
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(self.image_type, self.image_name)
image = context.services.images.get_pil_image(self.image.image_type, self.image.image_name)
return build_image_output(
image_type=self.image_type,
image_name=self.image_name,
image=image,
return ImageOutput(
image=ImageField(
image_name=self.image.image_name,
image_type=self.image.image_type,
),
width=image.width,
height=image.height,
)
@ -99,10 +90,12 @@ class ShowImageInvocation(BaseInvocation):
type: Literal["show_image"] = "show_image"
# Inputs
image: ImageField = Field(default=None, description="The image to show")
image: Union[ImageField, None] = Field(
default=None, description="The image to show"
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
if image:
@ -110,21 +103,24 @@ class ShowImageInvocation(BaseInvocation):
# TODO: how to handle failure?
return build_image_output(
image_type=self.image.image_type,
image_name=self.image.image_name,
image=image,
return ImageOutput(
image=ImageField(
image_name=self.image.image_name,
image_type=self.image.image_type,
),
width=image.width,
height=image.height,
)
class CropImageInvocation(BaseInvocation, PILInvocationConfig):
class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
"""Crops an image to a specified box. The box can be outside of the image."""
# fmt: off
type: Literal["crop"] = "crop"
type: Literal["img_crop"] = "img_crop"
# Inputs
image: ImageField = Field(default=None, description="The image to crop")
image: Union[ImageField, None] = Field(default=None, description="The image to crop")
x: int = Field(default=0, description="The left x coordinate of the crop rectangle")
y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
@ -132,7 +128,7 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
@ -141,49 +137,52 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
)
image_crop.paste(image, (-self.x, -self.y))
image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, image_crop, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image_dto = context.services.images.create(
image=image_crop,
image_type=ImageType.RESULT,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
),
width=image_dto.width,
height=image_dto.height,
)
class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
"""Pastes an image into another image."""
# fmt: off
type: Literal["paste"] = "paste"
type: Literal["img_paste"] = "img_paste"
# Inputs
base_image: ImageField = Field(default=None, description="The base image")
image: ImageField = Field(default=None, description="The image to paste")
base_image: Union[ImageField, None] = Field(default=None, description="The base image")
image: Union[ImageField, None] = Field(default=None, description="The image to paste")
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
x: int = Field(default=0, description="The left x coordinate at which to paste the image")
y: int = Field(default=0, description="The top y coordinate at which to paste the image")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
base_image = context.services.images.get(
base_image = context.services.images.get_pil_image(
self.base_image.image_type, self.base_image.image_name
)
image = context.services.images.get(
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
mask = (
None
if self.mask is None
else ImageOps.invert(
context.services.images.get(self.mask.image_type, self.mask.image_name)
context.services.images.get_pil_image(
self.mask.image_type, self.mask.image_name
)
)
)
# TODO: probably shouldn't invert mask here... should user be required to do it?
@ -199,20 +198,21 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
new_image.paste(base_image, (abs(min_x), abs(min_y)))
new_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask)
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, new_image, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image_dto = context.services.images.create(
image=new_image,
image_type=ImageType.RESULT,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
),
width=image_dto.width,
height=image_dto.height,
)
@ -223,12 +223,12 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["tomask"] = "tomask"
# Inputs
image: ImageField = Field(default=None, description="The image to create the mask from")
image: Union[ImageField, None] = Field(default=None, description="The image to create the mask from")
invert: bool = Field(default=False, description="Whether or not to invert the mask")
# fmt: on
def invoke(self, context: InvocationContext) -> MaskOutput:
image = context.services.images.get(
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
@ -236,33 +236,151 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
if self.invert:
image_mask = ImageOps.invert(image_mask)
image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
image_dto = context.services.images.create(
image=image_mask,
image_type=ImageType.RESULT,
image_category=ImageCategory.MASK,
node_id=self.id,
session_id=context.graph_execution_state_id,
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
return MaskOutput(
mask=ImageField(
image_type=image_dto.image_type, image_name=image_dto.image_name
),
width=image_dto.width,
height=image_dto.height,
)
context.services.images.save(image_type, image_name, image_mask, metadata)
return MaskOutput(mask=ImageField(image_type=image_type, image_name=image_name))
class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
# fmt: off
type: Literal["img_mul"] = "img_mul"
# Inputs
image1: Union[ImageField, None] = Field(default=None, description="The first image to multiply")
image2: Union[ImageField, None] = Field(default=None, description="The second image to multiply")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image1 = context.services.images.get_pil_image(
self.image1.image_type, self.image1.image_name
)
image2 = context.services.images.get_pil_image(
self.image2.image_type, self.image2.image_name
)
multiply_image = ImageChops.multiply(image1, image2)
image_dto = context.services.images.create(
image=multiply_image,
image_type=ImageType.RESULT,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
)
return ImageOutput(
image=ImageField(
image_type=image_dto.image_type, image_name=image_dto.image_name
),
width=image_dto.width,
height=image_dto.height,
)
class BlurInvocation(BaseInvocation, PILInvocationConfig):
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
"""Gets a channel from an image."""
# fmt: off
type: Literal["img_chan"] = "img_chan"
# Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to get the channel from")
channel: IMAGE_CHANNELS = Field(default="A", description="The channel to get")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
channel_image = image.getchannel(self.channel)
image_dto = context.services.images.create(
image=channel_image,
image_type=ImageType.RESULT,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
)
return ImageOutput(
image=ImageField(
image_type=image_dto.image_type, image_name=image_dto.image_name
),
width=image_dto.width,
height=image_dto.height,
)
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
"""Converts an image to a different mode."""
# fmt: off
type: Literal["img_conv"] = "img_conv"
# Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to convert")
mode: IMAGE_MODES = Field(default="L", description="The mode to convert to")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
converted_image = image.convert(self.mode)
image_dto = context.services.images.create(
image=converted_image,
image_type=ImageType.RESULT,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
)
return ImageOutput(
image=ImageField(
image_type=image_dto.image_type, image_name=image_dto.image_name
),
width=image_dto.width,
height=image_dto.height,
)
class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
"""Blurs an image"""
# fmt: off
type: Literal["blur"] = "blur"
type: Literal["img_blur"] = "img_blur"
# Inputs
image: ImageField = Field(default=None, description="The image to blur")
image: Union[ImageField, None] = Field(default=None, description="The image to blur")
radius: float = Field(default=8.0, ge=0, description="The blur radius")
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
@ -273,35 +391,38 @@ class BlurInvocation(BaseInvocation, PILInvocationConfig):
)
blur_image = image.filter(blur)
image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
image_dto = context.services.images.create(
image=blur_image,
image_type=ImageType.RESULT,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, blur_image, metadata)
return build_image_output(
image_type=image_type, image_name=image_name, image=blur_image
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
),
width=image_dto.width,
height=image_dto.height,
)
class LerpInvocation(BaseInvocation, PILInvocationConfig):
class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
"""Linear interpolation of all pixels of an image"""
# fmt: off
type: Literal["lerp"] = "lerp"
type: Literal["img_lerp"] = "img_lerp"
# Inputs
image: ImageField = Field(default=None, description="The image to lerp")
image: Union[ImageField, None] = Field(default=None, description="The image to lerp")
min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
@ -310,35 +431,38 @@ class LerpInvocation(BaseInvocation, PILInvocationConfig):
lerp_image = Image.fromarray(numpy.uint8(image_arr))
image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
image_dto = context.services.images.create(
image=lerp_image,
image_type=ImageType.RESULT,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, lerp_image, metadata)
return build_image_output(
image_type=image_type, image_name=image_name, image=lerp_image
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
),
width=image_dto.width,
height=image_dto.height,
)
class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
"""Inverse linear interpolation of all pixels of an image"""
# fmt: off
type: Literal["ilerp"] = "ilerp"
type: Literal["img_ilerp"] = "img_ilerp"
# Inputs
image: ImageField = Field(default=None, description="The image to lerp")
image: Union[ImageField, None] = Field(default=None, description="The image to lerp")
min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
@ -352,16 +476,19 @@ class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
ilerp_image = Image.fromarray(numpy.uint8(image_arr))
image_type = ImageType.INTERMEDIATE
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
image_dto = context.services.images.create(
image=ilerp_image,
image_type=ImageType.RESULT,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, ilerp_image, metadata)
return build_image_output(
image_type=image_type, image_name=image_name, image=ilerp_image
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
),
width=image_dto.width,
height=image_dto.height,
)

View File

@ -1,17 +1,17 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from typing import Literal, Optional, Union, get_args
from typing import Literal, Union, get_args
import numpy as np
import math
from PIL import Image, ImageOps
from pydantic import Field
from invokeai.app.invocations.image import ImageOutput, build_image_output
from invokeai.app.invocations.image import ImageOutput
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.image_util.patchmatch import PatchMatch
from ..models.image import ColorField, ImageField, ImageType
from ..models.image import ColorField, ImageCategory, ImageField, ImageType
from .baseinvocation import (
BaseInvocation,
InvocationContext,
@ -125,36 +125,39 @@ class InfillColorInvocation(BaseInvocation):
"""Infills transparent areas of an image with a solid color"""
type: Literal["infill_rgba"] = "infill_rgba"
image: Optional[ImageField] = Field(default=None, description="The image to infill")
color: Optional[ColorField] = Field(
image: Union[ImageField, None] = Field(
default=None, description="The image to infill"
)
color: ColorField = Field(
default=ColorField(r=127, g=127, b=127, a=255),
description="The color to use to infill",
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
infilled = Image.alpha_composite(solid_bg, image)
infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
infilled.paste(image, (0, 0), image.split()[-1])
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
image_dto = context.services.images.create(
image=infilled,
image_type=ImageType.RESULT,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, infilled, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image,
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
),
width=image_dto.width,
height=image_dto.height,
)
@ -163,7 +166,9 @@ class InfillTileInvocation(BaseInvocation):
type: Literal["infill_tile"] = "infill_tile"
image: Optional[ImageField] = Field(default=None, description="The image to infill")
image: Union[ImageField, None] = Field(
default=None, description="The image to infill"
)
tile_size: int = Field(default=32, ge=1, description="The tile size (px)")
seed: int = Field(
ge=0,
@ -173,7 +178,7 @@ class InfillTileInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
@ -182,20 +187,21 @@ class InfillTileInvocation(BaseInvocation):
)
infilled.paste(image, (0, 0), image.split()[-1])
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
image_dto = context.services.images.create(
image=infilled,
image_type=ImageType.RESULT,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, infilled, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image,
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
),
width=image_dto.width,
height=image_dto.height,
)
@ -204,10 +210,12 @@ class InfillPatchMatchInvocation(BaseInvocation):
type: Literal["infill_patchmatch"] = "infill_patchmatch"
image: Optional[ImageField] = Field(default=None, description="The image to infill")
image: Union[ImageField, None] = Field(
default=None, description="The image to infill"
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
@ -216,18 +224,19 @@ class InfillPatchMatchInvocation(BaseInvocation):
else:
raise ValueError("PatchMatch is not available on this system")
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
image_dto = context.services.images.create(
image=infilled,
image_type=ImageType.RESULT,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
context.services.images.save(image_type, image_name, infilled, metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=image,
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
),
width=image_dto.width,
height=image_dto.height,
)

View File

@ -3,10 +3,11 @@
import random
from typing import Literal, Optional, Union
import einops
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, validator
import torch
from invokeai.app.invocations.util.choose_model import choose_model
from invokeai.app.models.image import ImageCategory
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.app.util.step_callback import stable_diffusion_step_callback
@ -20,9 +21,9 @@ from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, Sta
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
import numpy as np
from ..services.image_storage import ImageType
from ..services.image_file_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput, build_image_output
from .image import ImageField, ImageOutput
from .compel import ConditioningField
from ...backend.stable_diffusion import PipelineIntermediateState
from diffusers.schedulers import SchedulerMixin as Scheduler
@ -139,12 +140,17 @@ class NoiseInvocation(BaseInvocation):
},
}
@validator("seed", pre=True)
def modulo_seed(cls, v):
"""Returns the seed modulo SEED_MAX to ensure it is within the valid range."""
return v % SEED_MAX
def invoke(self, context: InvocationContext) -> NoiseOutput:
device = torch.device(choose_torch_device())
noise = get_noise(self.width, self.height, device, self.seed)
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, noise)
context.services.latents.save(name, noise)
return build_noise_output(latents_name=name, latents=noise)
@ -163,8 +169,8 @@ class TextToLatentsInvocation(BaseInvocation):
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="lms", description="The scheduler to use" )
model: str = Field(default="", description="The model to use (currently ignored)")
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
# fmt: on
# Schema customisation
@ -199,17 +205,17 @@ class TextToLatentsInvocation(BaseInvocation):
scheduler_name=self.scheduler
)
if isinstance(model, DiffusionPipeline):
for component in [model.unet, model.vae]:
configure_model_padding(component,
self.seamless,
self.seamless_axes
)
else:
configure_model_padding(model,
self.seamless,
self.seamless_axes
)
# if isinstance(model, DiffusionPipeline):
# for component in [model.unet, model.vae]:
# configure_model_padding(component,
# self.seamless,
# self.seamless_axes
# )
# else:
# configure_model_padding(model,
# self.seamless,
# self.seamless_axes
# )
return model
@ -260,7 +266,7 @@ class TextToLatentsInvocation(BaseInvocation):
torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, result_latents)
context.services.latents.save(name, result_latents)
return build_latents_output(latents_name=name, latents=result_latents)
@ -319,7 +325,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, result_latents)
context.services.latents.save(name, result_latents)
return build_latents_output(latents_name=name, latents=result_latents)
@ -356,20 +362,23 @@ class LatentsToImageInvocation(BaseInvocation):
np_image = model.decode_latents(latents)
image = model.numpy_to_pil(np_image)[0]
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
torch.cuda.empty_cache()
context.services.images.save(image_type, image_name, image, metadata)
return build_image_output(
image_type=image_type, image_name=image_name, image=image
image_dto = context.services.images.create(
image=image,
image_type=ImageType.RESULT,
image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id,
node_id=self.id,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
),
width=image_dto.width,
height=image_dto.height,
)
@ -404,7 +413,7 @@ class ResizeLatentsInvocation(BaseInvocation):
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.set(name, resized_latents)
context.services.latents.save(name, resized_latents)
return build_latents_output(latents_name=name, latents=resized_latents)
@ -434,7 +443,7 @@ class ScaleLatentsInvocation(BaseInvocation):
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.set(name, resized_latents)
context.services.latents.save(name, resized_latents)
return build_latents_output(latents_name=name, latents=resized_latents)
@ -458,7 +467,7 @@ class ImageToLatentsInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.services.images.get(
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
@ -478,5 +487,6 @@ class ImageToLatentsInvocation(BaseInvocation):
)
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.set(name, latents)
context.services.latents.save(name, latents)
return build_latents_output(latents_name=name, latents=latents)

View File

@ -2,21 +2,23 @@ from typing import Literal, Union
from pydantic import Field
from invokeai.app.models.image import ImageField, ImageType
from invokeai.app.models.image import ImageCategory, ImageField, ImageType
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output
from .image import ImageOutput
class RestoreFaceInvocation(BaseInvocation):
"""Restores faces in an image."""
#fmt: off
# fmt: off
type: Literal["restore_face"] = "restore_face"
# Inputs
image: Union[ImageField, None] = Field(description="The input image")
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" )
#fmt: on
# fmt: on
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
@ -26,7 +28,7 @@ class RestoreFaceInvocation(BaseInvocation):
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
results = context.services.restoration.upscale_and_reconstruct(
@ -39,18 +41,19 @@ class RestoreFaceInvocation(BaseInvocation):
# Results are image and seed, unwrap for now
# TODO: can this return multiple results?
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
image_dto = context.services.images.create(
image=results[0][0],
image_type=ImageType.INTERMEDIATE,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
),
width=image_dto.width,
height=image_dto.height,
)
context.services.images.save(image_type, image_name, results[0][0], metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=results[0][0]
)

View File

@ -4,22 +4,22 @@ from typing import Literal, Union
from pydantic import Field
from invokeai.app.models.image import ImageField, ImageType
from invokeai.app.models.image import ImageCategory, ImageField, ImageType
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput, build_image_output
from .image import ImageOutput
class UpscaleInvocation(BaseInvocation):
"""Upscales an image."""
#fmt: off
# fmt: off
type: Literal["upscale"] = "upscale"
# Inputs
image: Union[ImageField, None] = Field(description="The input image", default=None)
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
level: Literal[2, 4] = Field(default=2, description="The upscale level")
#fmt: on
# fmt: on
# Schema customisation
class Config(InvocationConfig):
@ -30,7 +30,7 @@ class UpscaleInvocation(BaseInvocation):
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get(
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
results = context.services.restoration.upscale_and_reconstruct(
@ -43,18 +43,19 @@ class UpscaleInvocation(BaseInvocation):
# Results are image and seed, unwrap for now
# TODO: can this return multiple results?
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
image_dto = context.services.images.create(
image=results[0][0],
image_type=ImageType.RESULT,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
)
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_type=image_dto.image_type,
),
width=image_dto.width,
height=image_dto.height,
)
context.services.images.save(image_type, image_name, results[0][0], metadata)
return build_image_output(
image_type=image_type,
image_name=image_name,
image=results[0][0]
)

View File

@ -2,19 +2,44 @@ from enum import Enum
from typing import Optional, Tuple
from pydantic import BaseModel, Field
from invokeai.app.util.metaenum import MetaEnum
class ImageType(str, Enum, metaclass=MetaEnum):
"""The type of an image."""
class ImageType(str, Enum):
RESULT = "results"
INTERMEDIATE = "intermediates"
UPLOAD = "uploads"
INTERMEDIATE = "intermediates"
def is_image_type(obj):
try:
ImageType(obj)
except ValueError:
return False
return True
class InvalidImageTypeException(ValueError):
"""Raised when a provided value is not a valid ImageType.
Subclasses `ValueError`.
"""
def __init__(self, message="Invalid image type."):
super().__init__(message)
class ImageCategory(str, Enum, metaclass=MetaEnum):
"""The category of an image. Use ImageCategory.OTHER for non-default categories."""
GENERAL = "general"
CONTROL = "control"
MASK = "mask"
OTHER = "other"
class InvalidImageCategoryException(ValueError):
"""Raised when a provided value is not a valid ImageCategory.
Subclasses `ValueError`.
"""
def __init__(self, message="Invalid image category."):
super().__init__(message)
class ImageField(BaseModel):

View File

@ -0,0 +1,91 @@
from typing import Optional
from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr
class ImageMetadata(BaseModel):
"""
Core generation metadata for an image/tensor generated in InvokeAI.
Also includes any metadata from the image's PNG tEXt chunks.
Generated by traversing the execution graph, collecting the parameters of the nearest ancestors
of a given node.
Full metadata may be accessed by querying for the session in the `graph_executions` table.
"""
class Config:
extra = Extra.allow
"""
This lets the ImageMetadata class accept arbitrary additional fields. The CoreMetadataService
won't add any fields that are not already defined, but other a different metadata service
implementation might.
"""
type: Optional[StrictStr] = Field(
default=None,
description="The type of the ancestor node of the image output node.",
)
"""The type of the ancestor node of the image output node."""
positive_conditioning: Optional[StrictStr] = Field(
default=None, description="The positive conditioning."
)
"""The positive conditioning"""
negative_conditioning: Optional[StrictStr] = Field(
default=None, description="The negative conditioning."
)
"""The negative conditioning"""
width: Optional[StrictInt] = Field(
default=None, description="Width of the image/latents in pixels."
)
"""Width of the image/latents in pixels"""
height: Optional[StrictInt] = Field(
default=None, description="Height of the image/latents in pixels."
)
"""Height of the image/latents in pixels"""
seed: Optional[StrictInt] = Field(
default=None, description="The seed used for noise generation."
)
"""The seed used for noise generation"""
cfg_scale: Optional[StrictFloat] = Field(
default=None, description="The classifier-free guidance scale."
)
"""The classifier-free guidance scale"""
steps: Optional[StrictInt] = Field(
default=None, description="The number of steps used for inference."
)
"""The number of steps used for inference"""
scheduler: Optional[StrictStr] = Field(
default=None, description="The scheduler used for inference."
)
"""The scheduler used for inference"""
model: Optional[StrictStr] = Field(
default=None, description="The model used for inference."
)
"""The model used for inference"""
strength: Optional[StrictFloat] = Field(
default=None,
description="The strength used for image-to-image/latents-to-latents.",
)
"""The strength used for image-to-image/latents-to-latents."""
latents: Optional[StrictStr] = Field(
default=None, description="The ID of the initial latents."
)
"""The ID of the initial latents"""
vae: Optional[StrictStr] = Field(
default=None, description="The VAE used for decoding."
)
"""The VAE used for decoding"""
unet: Optional[StrictStr] = Field(
default=None, description="The UNet used dor inference."
)
"""The UNet used dor inference"""
clip: Optional[StrictStr] = Field(
default=None, description="The CLIP Encoder used for conditioning."
)
"""The CLIP Encoder used for conditioning"""
extra: Optional[StrictStr] = Field(
default=None,
description="Uploaded image metadata, extracted from the PNG tEXt chunk.",
)
"""Uploaded image metadata, extracted from the PNG tEXt chunk."""

View File

@ -353,6 +353,7 @@ setting environment variables INVOKEAI_<setting>.
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths')
autoconvert_dir : Path = Field(default=None, description='Path to a directory of ckpt files to be converted into diffusers and imported on startup.', category='Paths')
conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
@ -362,6 +363,7 @@ setting environment variables INVOKEAI_<setting>.
lora_dir : Path = Field(default='loras', description='Path to InvokeAI LoRA model directory', category='Paths')
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths')
model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')
embeddings : bool = Field(default=True, description='Load contents of embeddings directory', category='Models')
@ -511,7 +513,7 @@ class PagingArgumentParser(argparse.ArgumentParser):
text = self.format_help()
pydoc.pager(text)
def get_invokeai_config(cls:Type[InvokeAISettings]=InvokeAIAppConfig,**kwargs)->InvokeAISettings:
def get_invokeai_config(cls:Type[InvokeAISettings]=InvokeAIAppConfig,**kwargs)->InvokeAIAppConfig:
'''
This returns a singleton InvokeAIAppConfig configuration object.
'''

View File

@ -1,6 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any
from typing import Any, Optional
from invokeai.app.api.models.images import ProgressImage
from invokeai.app.util.misc import get_timestamp

View File

@ -713,6 +713,13 @@ class Graph(BaseModel):
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
return g
def nx_graph_with_data(self) -> nx.DiGraph:
"""Returns a NetworkX DiGraph representing the data and layout of this graph"""
g = nx.DiGraph()
g.add_nodes_from([n for n in self.nodes.items()])
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
return g
def nx_graph_flat(
self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None
) -> nx.DiGraph:

View File

@ -0,0 +1,204 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
import os
from abc import ABC, abstractmethod
from pathlib import Path
from queue import Queue
from typing import Dict, Optional
from PIL.Image import Image as PILImageType
from PIL import Image, PngImagePlugin
from send2trash import send2trash
from invokeai.app.models.image import ImageType
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
# TODO: Should these excpetions subclass existing python exceptions?
class ImageFileNotFoundException(Exception):
"""Raised when an image file is not found in storage."""
def __init__(self, message="Image file not found"):
super().__init__(message)
class ImageFileSaveException(Exception):
"""Raised when an image cannot be saved."""
def __init__(self, message="Image file not saved"):
super().__init__(message)
class ImageFileDeleteException(Exception):
"""Raised when an image cannot be deleted."""
def __init__(self, message="Image file not deleted"):
super().__init__(message)
class ImageFileStorageBase(ABC):
"""Low-level service responsible for storing and retrieving image files."""
@abstractmethod
def get(self, image_type: ImageType, image_name: str) -> PILImageType:
"""Retrieves an image as PIL Image."""
pass
@abstractmethod
def get_path(
self, image_type: ImageType, image_name: str, thumbnail: bool = False
) -> str:
"""Gets the internal path to an image or thumbnail."""
pass
# TODO: We need to validate paths before starlette makes the FileResponse, else we get a
# 500 internal server error. I don't like having this method on the service.
@abstractmethod
def validate_path(self, path: str) -> bool:
"""Validates the path given for an image or thumbnail."""
pass
@abstractmethod
def save(
self,
image: PILImageType,
image_type: ImageType,
image_name: str,
metadata: Optional[ImageMetadata] = None,
thumbnail_size: int = 256,
) -> None:
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
pass
@abstractmethod
def delete(self, image_type: ImageType, image_name: str) -> None:
"""Deletes an image and its thumbnail (if one exists)."""
pass
class DiskImageFileStorage(ImageFileStorageBase):
"""Stores images on disk"""
__output_folder: str
__cache_ids: Queue # TODO: this is an incredibly naive cache
__cache: Dict[str, PILImageType]
__max_cache_size: int
def __init__(self, output_folder: str):
self.__output_folder = output_folder
self.__cache = dict()
self.__cache_ids = Queue()
self.__max_cache_size = 10 # TODO: get this from config
Path(output_folder).mkdir(parents=True, exist_ok=True)
# TODO: don't hard-code. get/save/delete should maybe take subpath?
for image_type in ImageType:
Path(os.path.join(output_folder, image_type)).mkdir(
parents=True, exist_ok=True
)
Path(os.path.join(output_folder, image_type, "thumbnails")).mkdir(
parents=True, exist_ok=True
)
def get(self, image_type: ImageType, image_name: str) -> PILImageType:
try:
image_path = self.get_path(image_type, image_name)
cache_item = self.__get_cache(image_path)
if cache_item:
return cache_item
image = Image.open(image_path)
self.__set_cache(image_path, image)
return image
except FileNotFoundError as e:
raise ImageFileNotFoundException from e
def save(
self,
image: PILImageType,
image_type: ImageType,
image_name: str,
metadata: Optional[ImageMetadata] = None,
thumbnail_size: int = 256,
) -> None:
try:
image_path = self.get_path(image_type, image_name)
if metadata is not None:
pnginfo = PngImagePlugin.PngInfo()
pnginfo.add_text("invokeai", metadata.json())
image.save(image_path, "PNG", pnginfo=pnginfo)
else:
image.save(image_path, "PNG")
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_type, thumbnail_name, thumbnail=True)
thumbnail_image = make_thumbnail(image, thumbnail_size)
thumbnail_image.save(thumbnail_path)
self.__set_cache(image_path, image)
self.__set_cache(thumbnail_path, thumbnail_image)
except Exception as e:
raise ImageFileSaveException from e
def delete(self, image_type: ImageType, image_name: str) -> None:
try:
basename = os.path.basename(image_name)
image_path = self.get_path(image_type, basename)
if os.path.exists(image_path):
send2trash(image_path)
if image_path in self.__cache:
del self.__cache[image_path]
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_type, thumbnail_name, True)
if os.path.exists(thumbnail_path):
send2trash(thumbnail_path)
if thumbnail_path in self.__cache:
del self.__cache[thumbnail_path]
except Exception as e:
raise ImageFileDeleteException from e
# TODO: make this a bit more flexible for e.g. cloud storage
def get_path(
self, image_type: ImageType, image_name: str, thumbnail: bool = False
) -> str:
# strip out any relative path shenanigans
basename = os.path.basename(image_name)
if thumbnail:
thumbnail_name = get_thumbnail_name(basename)
path = os.path.join(
self.__output_folder, image_type, "thumbnails", thumbnail_name
)
else:
path = os.path.join(self.__output_folder, image_type, basename)
abspath = os.path.abspath(path)
return abspath
def validate_path(self, path: str) -> bool:
"""Validates the path given for an image or thumbnail."""
try:
os.stat(path)
return True
except:
return False
def __get_cache(self, image_name: str) -> PILImageType | None:
return None if image_name not in self.__cache else self.__cache[image_name]
def __set_cache(self, image_name: str, image: PILImageType):
if not image_name in self.__cache:
self.__cache[image_name] = image
self.__cache_ids.put(
image_name
) # TODO: this should refresh position for LRU cache
if len(self.__cache) > self.__max_cache_size:
cache_id = self.__cache_ids.get()
if cache_id in self.__cache:
del self.__cache[cache_id]

View File

@ -0,0 +1,317 @@
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Optional, cast
import sqlite3
import threading
from typing import Optional, Union
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.models.image import (
ImageCategory,
ImageType,
)
from invokeai.app.services.models.image_record import (
ImageRecord,
deserialize_image_record,
)
from invokeai.app.services.item_storage import PaginatedResults
# TODO: Should these excpetions subclass existing python exceptions?
class ImageRecordNotFoundException(Exception):
"""Raised when an image record is not found."""
def __init__(self, message="Image record not found"):
super().__init__(message)
class ImageRecordSaveException(Exception):
"""Raised when an image record cannot be saved."""
def __init__(self, message="Image record not saved"):
super().__init__(message)
class ImageRecordDeleteException(Exception):
"""Raised when an image record cannot be deleted."""
def __init__(self, message="Image record not deleted"):
super().__init__(message)
class ImageRecordStorageBase(ABC):
"""Low-level service responsible for interfacing with the image record store."""
# TODO: Implement an `update()` method
@abstractmethod
def get(self, image_type: ImageType, image_name: str) -> ImageRecord:
"""Gets an image record."""
pass
@abstractmethod
def get_many(
self,
image_type: ImageType,
image_category: ImageCategory,
page: int = 0,
per_page: int = 10,
) -> PaginatedResults[ImageRecord]:
"""Gets a page of image records."""
pass
# TODO: The database has a nullable `deleted_at` column, currently unused.
# Should we implement soft deletes? Would need coordination with ImageFileStorage.
@abstractmethod
def delete(self, image_type: ImageType, image_name: str) -> None:
"""Deletes an image record."""
pass
@abstractmethod
def save(
self,
image_name: str,
image_type: ImageType,
image_category: ImageCategory,
width: int,
height: int,
session_id: Optional[str],
node_id: Optional[str],
metadata: Optional[ImageMetadata],
) -> datetime:
"""Saves an image record."""
pass
class SqliteImageRecordStorage(ImageRecordStorageBase):
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, filename: str) -> None:
super().__init__()
self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()
self._lock = threading.Lock()
try:
self._lock.acquire()
# Enable foreign keys
self._conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def _create_tables(self) -> None:
"""Creates the tables for the `images` database."""
# Create the `images` table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS images (
image_name TEXT NOT NULL PRIMARY KEY,
-- This is an enum in python, unrestricted string here for flexibility
image_type TEXT NOT NULL,
-- This is an enum in python, unrestricted string here for flexibility
image_category TEXT NOT NULL,
width INTEGER NOT NULL,
height INTEGER NOT NULL,
session_id TEXT,
node_id TEXT,
metadata TEXT,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
-- Soft delete, currently unused
deleted_at DATETIME
);
"""
)
# Create the `images` table indices.
self._cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_images_image_name ON images(image_name);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_image_type ON images(image_type);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_created_at ON images(created_at);
"""
)
# Add trigger for `updated_at`.
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_images_updated_at
AFTER UPDATE
ON images FOR EACH ROW
BEGIN
UPDATE images SET updated_at = current_timestamp
WHERE image_name = old.image_name;
END;
"""
)
def get(self, image_type: ImageType, image_name: str) -> Union[ImageRecord, None]:
try:
self._lock.acquire()
self._cursor.execute(
f"""--sql
SELECT * FROM images
WHERE image_name = ?;
""",
(image_name,),
)
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordNotFoundException from e
finally:
self._lock.release()
if not result:
raise ImageRecordNotFoundException
return deserialize_image_record(dict(result))
def get_many(
self,
image_type: ImageType,
image_category: ImageCategory,
page: int = 0,
per_page: int = 10,
) -> PaginatedResults[ImageRecord]:
try:
self._lock.acquire()
self._cursor.execute(
f"""--sql
SELECT * FROM images
WHERE image_type = ? AND image_category = ?
ORDER BY created_at DESC
LIMIT ? OFFSET ?;
""",
(image_type.value, image_category.value, per_page, page * per_page),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
self._cursor.execute(
"""--sql
SELECT count(*) FROM images
WHERE image_type = ? AND image_category = ?
""",
(image_type.value, image_category.value),
)
count = self._cursor.fetchone()[0]
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
pageCount = int(count / per_page) + 1
return PaginatedResults(
items=images, page=page, pages=pageCount, per_page=per_page, total=count
)
def delete(self, image_type: ImageType, image_name: str) -> None:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE FROM images
WHERE image_name = ?;
""",
(image_name,),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordDeleteException from e
finally:
self._lock.release()
def save(
self,
image_name: str,
image_type: ImageType,
image_category: ImageCategory,
session_id: Optional[str],
width: int,
height: int,
node_id: Optional[str],
metadata: Optional[ImageMetadata],
) -> datetime:
try:
metadata_json = (
None if metadata is None else metadata.json(exclude_none=True)
)
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO images (
image_name,
image_type,
image_category,
width,
height,
node_id,
session_id,
metadata
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?);
""",
(
image_name,
image_type.value,
image_category.value,
width,
height,
node_id,
session_id,
metadata_json,
),
)
self._conn.commit()
self._cursor.execute(
"""--sql
SELECT created_at
FROM images
WHERE image_name = ?;
""",
(image_name,),
)
created_at = datetime.fromisoformat(self._cursor.fetchone()[0])
return created_at
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordSaveException from e
finally:
self._lock.release()

View File

@ -1,274 +0,0 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import os
from glob import glob
from abc import ABC, abstractmethod
from pathlib import Path
from queue import Queue
from typing import Dict, List
from PIL.Image import Image
import PIL.Image as PILImage
from send2trash import send2trash
from invokeai.app.api.models.images import (
ImageResponse,
ImageResponseMetadata,
SavedImage,
)
from invokeai.app.models.image import ImageType
from invokeai.app.services.metadata import (
InvokeAIMetadata,
MetadataServiceBase,
build_invokeai_metadata_pnginfo,
)
from invokeai.app.services.item_storage import PaginatedResults
from invokeai.app.util.misc import get_timestamp
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
class ImageStorageBase(ABC):
"""Responsible for storing and retrieving images."""
@abstractmethod
def get(self, image_type: ImageType, image_name: str) -> Image:
"""Retrieves an image as PIL Image."""
pass
@abstractmethod
def list(
self, image_type: ImageType, page: int = 0, per_page: int = 10
) -> PaginatedResults[ImageResponse]:
"""Gets a paginated list of images."""
pass
# TODO: make this a bit more flexible for e.g. cloud storage
@abstractmethod
def get_path(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
) -> str:
"""Gets the internal path to an image or its thumbnail."""
pass
# TODO: make this a bit more flexible for e.g. cloud storage
@abstractmethod
def get_uri(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
) -> str:
"""Gets the external URI to an image or its thumbnail."""
pass
# TODO: make this a bit more flexible for e.g. cloud storage
@abstractmethod
def validate_path(self, path: str) -> bool:
"""Validates an image path."""
pass
@abstractmethod
def save(
self,
image_type: ImageType,
image_name: str,
image: Image,
metadata: InvokeAIMetadata | None = None,
) -> SavedImage:
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
pass
@abstractmethod
def delete(self, image_type: ImageType, image_name: str) -> None:
"""Deletes an image and its thumbnail (if one exists)."""
pass
def create_name(self, context_id: str, node_id: str) -> str:
"""Creates a unique contextual image filename."""
return f"{context_id}_{node_id}_{str(get_timestamp())}.png"
class DiskImageStorage(ImageStorageBase):
"""Stores images on disk"""
__output_folder: str
__cache_ids: Queue # TODO: this is an incredibly naive cache
__cache: Dict[str, Image]
__max_cache_size: int
__metadata_service: MetadataServiceBase
def __init__(self, output_folder: str, metadata_service: MetadataServiceBase):
self.__output_folder = output_folder
self.__cache = dict()
self.__cache_ids = Queue()
self.__max_cache_size = 10 # TODO: get this from config
self.__metadata_service = metadata_service
Path(output_folder).mkdir(parents=True, exist_ok=True)
# TODO: don't hard-code. get/save/delete should maybe take subpath?
for image_type in ImageType:
Path(os.path.join(output_folder, image_type)).mkdir(
parents=True, exist_ok=True
)
Path(os.path.join(output_folder, image_type, "thumbnails")).mkdir(
parents=True, exist_ok=True
)
def list(
self, image_type: ImageType, page: int = 0, per_page: int = 10
) -> PaginatedResults[ImageResponse]:
dir_path = os.path.join(self.__output_folder, image_type)
image_paths = glob(f"{dir_path}/*.png")
count = len(image_paths)
sorted_image_paths = sorted(
glob(f"{dir_path}/*.png"), key=os.path.getctime, reverse=True
)
page_of_image_paths = sorted_image_paths[
page * per_page : (page + 1) * per_page
]
page_of_images: List[ImageResponse] = []
for path in page_of_image_paths:
filename = os.path.basename(path)
img = PILImage.open(path)
invokeai_metadata = self.__metadata_service.get_metadata(img)
page_of_images.append(
ImageResponse(
image_type=image_type.value,
image_name=filename,
# TODO: DiskImageStorage should not be building URLs...?
image_url=self.get_uri(image_type, filename),
thumbnail_url=self.get_uri(image_type, filename, True),
# TODO: Creation of this object should happen elsewhere (?), just making it fit here so it works
metadata=ImageResponseMetadata(
created=int(os.path.getctime(path)),
width=img.width,
height=img.height,
invokeai=invokeai_metadata,
),
)
)
page_count_trunc = int(count / per_page)
page_count_mod = count % per_page
page_count = page_count_trunc if page_count_mod == 0 else page_count_trunc + 1
return PaginatedResults[ImageResponse](
items=page_of_images,
page=page,
pages=page_count,
per_page=per_page,
total=count,
)
def get(self, image_type: ImageType, image_name: str) -> Image:
image_path = self.get_path(image_type, image_name)
cache_item = self.__get_cache(image_path)
if cache_item:
return cache_item
image = PILImage.open(image_path)
self.__set_cache(image_path, image)
return image
# TODO: make this a bit more flexible for e.g. cloud storage
def get_path(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
) -> str:
# strip out any relative path shenanigans
basename = os.path.basename(image_name)
if is_thumbnail:
path = os.path.join(
self.__output_folder, image_type, "thumbnails", basename
)
else:
path = os.path.join(self.__output_folder, image_type, basename)
abspath = os.path.abspath(path)
return abspath
def get_uri(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
) -> str:
# strip out any relative path shenanigans
basename = os.path.basename(image_name)
if is_thumbnail:
thumbnail_basename = get_thumbnail_name(basename)
uri = f"api/v1/images/{image_type.value}/thumbnails/{thumbnail_basename}"
else:
uri = f"api/v1/images/{image_type.value}/{basename}"
return uri
def validate_path(self, path: str) -> bool:
try:
os.stat(path)
return True
except Exception:
return False
def save(
self,
image_type: ImageType,
image_name: str,
image: Image,
metadata: InvokeAIMetadata | None = None,
) -> SavedImage:
image_path = self.get_path(image_type, image_name)
# TODO: Reading the image and then saving it strips the metadata...
if metadata:
pnginfo = build_invokeai_metadata_pnginfo(metadata=metadata)
image.save(image_path, "PNG", pnginfo=pnginfo)
else:
image.save(image_path) # this saved image has an empty info
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_type, thumbnail_name, is_thumbnail=True)
thumbnail_image = make_thumbnail(image)
thumbnail_image.save(thumbnail_path)
self.__set_cache(image_path, image)
self.__set_cache(thumbnail_path, thumbnail_image)
return SavedImage(
image_name=image_name,
thumbnail_name=thumbnail_name,
created=int(os.path.getctime(image_path)),
)
def delete(self, image_type: ImageType, image_name: str) -> None:
basename = os.path.basename(image_name)
image_path = self.get_path(image_type, basename)
if os.path.exists(image_path):
send2trash(image_path)
if image_path in self.__cache:
del self.__cache[image_path]
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_type, thumbnail_name, True)
if os.path.exists(thumbnail_path):
send2trash(thumbnail_path)
if thumbnail_path in self.__cache:
del self.__cache[thumbnail_path]
def __get_cache(self, image_name: str) -> Image | None:
return None if image_name not in self.__cache else self.__cache[image_name]
def __set_cache(self, image_name: str, image: Image):
if not image_name in self.__cache:
self.__cache[image_name] = image
self.__cache_ids.put(
image_name
) # TODO: this should refresh position for LRU cache
if len(self.__cache) > self.__max_cache_size:
cache_id = self.__cache_ids.get()
if cache_id in self.__cache:
del self.__cache[cache_id]

View File

@ -0,0 +1,375 @@
from abc import ABC, abstractmethod
from logging import Logger
from typing import Optional, TYPE_CHECKING, Union
import uuid
from PIL.Image import Image as PILImageType
from invokeai.app.models.image import (
ImageCategory,
ImageType,
InvalidImageCategoryException,
InvalidImageTypeException,
)
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.image_record_storage import (
ImageRecordDeleteException,
ImageRecordNotFoundException,
ImageRecordSaveException,
ImageRecordStorageBase,
)
from invokeai.app.services.models.image_record import (
ImageRecord,
ImageDTO,
image_record_to_dto,
)
from invokeai.app.services.image_file_storage import (
ImageFileDeleteException,
ImageFileNotFoundException,
ImageFileSaveException,
ImageFileStorageBase,
)
from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults
from invokeai.app.services.metadata import MetadataServiceBase
from invokeai.app.services.urls import UrlServiceBase
from invokeai.app.util.misc import get_iso_timestamp
if TYPE_CHECKING:
from invokeai.app.services.graph import GraphExecutionState
class ImageServiceABC(ABC):
"""High-level service for image management."""
@abstractmethod
def create(
self,
image: PILImageType,
image_type: ImageType,
image_category: ImageCategory,
node_id: Optional[str] = None,
session_id: Optional[str] = None,
metadata: Optional[ImageMetadata] = None,
) -> ImageDTO:
"""Creates an image, storing the file and its metadata."""
pass
@abstractmethod
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
"""Gets an image as a PIL image."""
pass
@abstractmethod
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord:
"""Gets an image record."""
pass
@abstractmethod
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
"""Gets an image DTO."""
pass
@abstractmethod
def get_path(self, image_type: ImageType, image_name: str) -> str:
"""Gets an image's path."""
pass
@abstractmethod
def validate_path(self, path: str) -> bool:
"""Validates an image's path."""
pass
@abstractmethod
def get_url(
self, image_type: ImageType, image_name: str, thumbnail: bool = False
) -> str:
"""Gets an image's or thumbnail's URL."""
pass
@abstractmethod
def get_many(
self,
image_type: ImageType,
image_category: ImageCategory,
page: int = 0,
per_page: int = 10,
) -> PaginatedResults[ImageDTO]:
"""Gets a paginated list of image DTOs."""
pass
@abstractmethod
def delete(self, image_type: ImageType, image_name: str):
"""Deletes an image."""
pass
class ImageServiceDependencies:
"""Service dependencies for the ImageService."""
records: ImageRecordStorageBase
files: ImageFileStorageBase
metadata: MetadataServiceBase
urls: UrlServiceBase
logger: Logger
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
def __init__(
self,
image_record_storage: ImageRecordStorageBase,
image_file_storage: ImageFileStorageBase,
metadata: MetadataServiceBase,
url: UrlServiceBase,
logger: Logger,
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
):
self.records = image_record_storage
self.files = image_file_storage
self.metadata = metadata
self.urls = url
self.logger = logger
self.graph_execution_manager = graph_execution_manager
class ImageService(ImageServiceABC):
_services: ImageServiceDependencies
def __init__(
self,
image_record_storage: ImageRecordStorageBase,
image_file_storage: ImageFileStorageBase,
metadata: MetadataServiceBase,
url: UrlServiceBase,
logger: Logger,
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
):
self._services = ImageServiceDependencies(
image_record_storage=image_record_storage,
image_file_storage=image_file_storage,
metadata=metadata,
url=url,
logger=logger,
graph_execution_manager=graph_execution_manager,
)
def create(
self,
image: PILImageType,
image_type: ImageType,
image_category: ImageCategory,
node_id: Optional[str] = None,
session_id: Optional[str] = None,
) -> ImageDTO:
if image_type not in ImageType:
raise InvalidImageTypeException
if image_category not in ImageCategory:
raise InvalidImageCategoryException
image_name = self._create_image_name(
image_type=image_type,
image_category=image_category,
node_id=node_id,
session_id=session_id,
)
metadata = self._get_metadata(session_id, node_id)
(width, height) = image.size
try:
# TODO: Consider using a transaction here to ensure consistency between storage and database
created_at = self._services.records.save(
# Non-nullable fields
image_name=image_name,
image_type=image_type,
image_category=image_category,
width=width,
height=height,
# Nullable fields
node_id=node_id,
session_id=session_id,
metadata=metadata,
)
self._services.files.save(
image_type=image_type,
image_name=image_name,
image=image,
metadata=metadata,
)
image_url = self._services.urls.get_image_url(image_type, image_name)
thumbnail_url = self._services.urls.get_image_url(
image_type, image_name, True
)
return ImageDTO(
# Non-nullable fields
image_name=image_name,
image_type=image_type,
image_category=image_category,
width=width,
height=height,
# Nullable fields
node_id=node_id,
session_id=session_id,
metadata=metadata,
# Meta fields
created_at=created_at,
updated_at=created_at, # this is always the same as the created_at at this time
deleted_at=None,
# Extra non-nullable fields for DTO
image_url=image_url,
thumbnail_url=thumbnail_url,
)
except ImageRecordSaveException:
self._services.logger.error("Failed to save image record")
raise
except ImageFileSaveException:
self._services.logger.error("Failed to save image file")
raise
except Exception as e:
self._services.logger.error("Problem saving image record and file")
raise e
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
try:
return self._services.files.get(image_type, image_name)
except ImageFileNotFoundException:
self._services.logger.error("Failed to get image file")
raise
except Exception as e:
self._services.logger.error("Problem getting image file")
raise e
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord:
try:
return self._services.records.get(image_type, image_name)
except ImageRecordNotFoundException:
self._services.logger.error("Image record not found")
raise
except Exception as e:
self._services.logger.error("Problem getting image record")
raise e
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
try:
image_record = self._services.records.get(image_type, image_name)
image_dto = image_record_to_dto(
image_record,
self._services.urls.get_image_url(image_type, image_name),
self._services.urls.get_image_url(image_type, image_name, True),
)
return image_dto
except ImageRecordNotFoundException:
self._services.logger.error("Image record not found")
raise
except Exception as e:
self._services.logger.error("Problem getting image DTO")
raise e
def get_path(
self, image_type: ImageType, image_name: str, thumbnail: bool = False
) -> str:
try:
return self._services.files.get_path(image_type, image_name, thumbnail)
except Exception as e:
self._services.logger.error("Problem getting image path")
raise e
def validate_path(self, path: str) -> bool:
try:
return self._services.files.validate_path(path)
except Exception as e:
self._services.logger.error("Problem validating image path")
raise e
def get_url(
self, image_type: ImageType, image_name: str, thumbnail: bool = False
) -> str:
try:
return self._services.urls.get_image_url(image_type, image_name, thumbnail)
except Exception as e:
self._services.logger.error("Problem getting image path")
raise e
def get_many(
self,
image_type: ImageType,
image_category: ImageCategory,
page: int = 0,
per_page: int = 10,
) -> PaginatedResults[ImageDTO]:
try:
results = self._services.records.get_many(
image_type,
image_category,
page,
per_page,
)
image_dtos = list(
map(
lambda r: image_record_to_dto(
r,
self._services.urls.get_image_url(image_type, r.image_name),
self._services.urls.get_image_url(
image_type, r.image_name, True
),
),
results.items,
)
)
return PaginatedResults[ImageDTO](
items=image_dtos,
page=results.page,
pages=results.pages,
per_page=results.per_page,
total=results.total,
)
except Exception as e:
self._services.logger.error("Problem getting paginated image DTOs")
raise e
def delete(self, image_type: ImageType, image_name: str):
try:
self._services.files.delete(image_type, image_name)
self._services.records.delete(image_type, image_name)
except ImageRecordDeleteException:
self._services.logger.error(f"Failed to delete image record")
raise
except ImageFileDeleteException:
self._services.logger.error(f"Failed to delete image file")
raise
except Exception as e:
self._services.logger.error("Problem deleting image record and file")
raise e
def _create_image_name(
self,
image_type: ImageType,
image_category: ImageCategory,
node_id: Optional[str] = None,
session_id: Optional[str] = None,
) -> str:
"""Create a unique image name."""
uuid_str = str(uuid.uuid4())
if node_id is not None and session_id is not None:
return f"{image_type.value}_{image_category.value}_{session_id}_{node_id}_{uuid_str}.png"
return f"{image_type.value}_{image_category.value}_{uuid_str}.png"
def _get_metadata(
self, session_id: Optional[str] = None, node_id: Optional[str] = None
) -> Union[ImageMetadata, None]:
"""Get the metadata for a node."""
metadata = None
if node_id is not None and session_id is not None:
session = self._services.graph_execution_manager.get(session_id)
metadata = self._services.metadata.create_image_metadata(session, node_id)
return metadata

View File

@ -1,55 +1,57 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import types
from invokeai.app.services.metadata import MetadataServiceBase
from invokeai.backend import ModelManager
if TYPE_CHECKING:
from logging import Logger
from invokeai.app.services.images import ImageService
from invokeai.backend import ModelManager
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.latent_storage import LatentsStorageBase
from invokeai.app.services.restoration_services import RestorationServices
from invokeai.app.services.invocation_queue import InvocationQueueABC
from invokeai.app.services.item_storage import ItemStorageABC
from invokeai.app.services.config import InvokeAISettings
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
from invokeai.app.services.invoker import InvocationProcessorABC
from .events import EventServiceBase
from .latent_storage import LatentsStorageBase
from .image_storage import ImageStorageBase
from .restoration_services import RestorationServices
from .invocation_queue import InvocationQueueABC
from .item_storage import ItemStorageABC
from .config import InvokeAISettings
class InvocationServices:
"""Services that can be used by invocations"""
events: EventServiceBase
latents: LatentsStorageBase
images: ImageStorageBase
metadata: MetadataServiceBase
queue: InvocationQueueABC
model_manager: ModelManager
restoration: RestorationServices
configuration: InvokeAISettings
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
events: "EventServiceBase"
latents: "LatentsStorageBase"
queue: "InvocationQueueABC"
model_manager: "ModelManager"
restoration: "RestorationServices"
configuration: "InvokeAISettings"
images: "ImageService"
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
graph_library: ItemStorageABC["LibraryGraph"]
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
graph_library: "ItemStorageABC"["LibraryGraph"]
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
processor: "InvocationProcessorABC"
def __init__(
self,
model_manager: ModelManager,
events: EventServiceBase,
logger: types.ModuleType,
latents: LatentsStorageBase,
images: ImageStorageBase,
metadata: MetadataServiceBase,
queue: InvocationQueueABC,
graph_library: ItemStorageABC["LibraryGraph"],
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
processor: "InvocationProcessorABC",
restoration: RestorationServices,
configuration: InvokeAISettings=None,
self,
model_manager: "ModelManager",
events: "EventServiceBase",
logger: "Logger",
latents: "LatentsStorageBase",
images: "ImageService",
queue: "InvocationQueueABC",
graph_library: "ItemStorageABC"["LibraryGraph"],
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
processor: "InvocationProcessorABC",
restoration: "RestorationServices",
configuration: "InvokeAISettings",
):
self.model_manager = model_manager
self.events = events
self.logger = logger
self.latents = latents
self.images = images
self.metadata = metadata
self.queue = queue
self.graph_library = graph_library
self.graph_execution_manager = graph_execution_manager

View File

@ -16,7 +16,7 @@ class LatentsStorageBase(ABC):
pass
@abstractmethod
def set(self, name: str, data: torch.Tensor) -> None:
def save(self, name: str, data: torch.Tensor) -> None:
pass
@abstractmethod
@ -47,8 +47,8 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
self.__set_cache(name, latent)
return latent
def set(self, name: str, data: torch.Tensor) -> None:
self.__underlying_storage.set(name, data)
def save(self, name: str, data: torch.Tensor) -> None:
self.__underlying_storage.save(name, data)
self.__set_cache(name, data)
def delete(self, name: str) -> None:
@ -80,7 +80,7 @@ class DiskLatentsStorage(LatentsStorageBase):
latent_path = self.get_path(name)
return torch.load(latent_path)
def set(self, name: str, data: torch.Tensor) -> None:
def save(self, name: str, data: torch.Tensor) -> None:
latent_path = self.get_path(name)
torch.save(data, latent_path)

View File

@ -1,105 +1,142 @@
import json
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, TypedDict
from PIL import Image, PngImagePlugin
from pydantic import BaseModel
from typing import Any, Union
import networkx as nx
from invokeai.app.models.image import ImageType, is_image_type
class MetadataImageField(TypedDict):
"""Pydantic-less ImageField, used for metadata parsing."""
image_type: ImageType
image_name: str
class MetadataLatentsField(TypedDict):
"""Pydantic-less LatentsField, used for metadata parsing."""
latents_name: str
class MetadataColorField(TypedDict):
"""Pydantic-less ColorField, used for metadata parsing"""
r: int
g: int
b: int
a: int
# TODO: This is a placeholder for `InvocationsUnion` pending resolution of circular imports
NodeMetadata = Dict[
str, None | str | int | float | bool | MetadataImageField | MetadataLatentsField | MetadataColorField
]
class InvokeAIMetadata(TypedDict, total=False):
"""InvokeAI-specific metadata format."""
session_id: Optional[str]
node: Optional[NodeMetadata]
def build_invokeai_metadata_pnginfo(
metadata: InvokeAIMetadata | None,
) -> PngImagePlugin.PngInfo:
"""Builds a PngInfo object with key `"invokeai"` and value `metadata`"""
pnginfo = PngImagePlugin.PngInfo()
if metadata is not None:
pnginfo.add_text("invokeai", json.dumps(metadata))
return pnginfo
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.graph import Graph, GraphExecutionState
class MetadataServiceBase(ABC):
@abstractmethod
def get_metadata(self, image: Image.Image) -> InvokeAIMetadata | None:
"""Gets the InvokeAI metadata from a PIL Image, skipping invalid values"""
pass
"""Handles building metadata for nodes, images, and outputs."""
@abstractmethod
def build_metadata(
self, session_id: str, node: BaseModel
) -> InvokeAIMetadata | None:
"""Builds an InvokeAIMetadata object"""
def create_image_metadata(
self, session: GraphExecutionState, node_id: str
) -> ImageMetadata:
"""Builds an ImageMetadata object for a node."""
pass
class PngMetadataService(MetadataServiceBase):
"""Handles loading and building metadata for images."""
class CoreMetadataService(MetadataServiceBase):
_ANCESTOR_TYPES = ["t2l", "l2l"]
"""The ancestor types that contain the core metadata"""
# TODO: Use `InvocationsUnion` to **validate** metadata as representing a fully-functioning node
def _load_metadata(self, image: Image.Image) -> dict | None:
"""Loads a specific info entry from a PIL Image."""
_ANCESTOR_PARAMS = ["type", "steps", "model", "cfg_scale", "scheduler", "strength"]
"""The core metadata parameters in the ancestor types"""
try:
info = image.info.get("invokeai")
_NOISE_FIELDS = ["seed", "width", "height"]
"""The core metadata parameters in the noise node"""
if type(info) is not str:
return None
loaded_metadata = json.loads(info)
if type(loaded_metadata) is not dict:
return None
if len(loaded_metadata.items()) == 0:
return None
return loaded_metadata
except:
return None
def get_metadata(self, image: Image.Image) -> dict | None:
"""Retrieves an image's metadata as a dict"""
loaded_metadata = self._load_metadata(image)
return loaded_metadata
def build_metadata(self, session_id: str, node: BaseModel) -> InvokeAIMetadata:
metadata = InvokeAIMetadata(session_id=session_id, node=node.dict())
def create_image_metadata(
self, session: GraphExecutionState, node_id: str
) -> ImageMetadata:
metadata = self._build_metadata_from_graph(session, node_id)
return metadata
def _find_nearest_ancestor(self, G: nx.DiGraph, node_id: str) -> Union[str, None]:
"""
Finds the id of the nearest ancestor (of a valid type) of a given node.
Parameters:
G (nx.DiGraph): The execution graph, converted in to a networkx DiGraph. Its nodes must
have the same data as the execution graph.
node_id (str): The ID of the node.
Returns:
str | None: The ID of the nearest ancestor, or None if there are no valid ancestors.
"""
# Retrieve the node from the graph
node = G.nodes[node_id]
# If the node type is one of the core metadata node types, return its id
if node.get("type") in self._ANCESTOR_TYPES:
return node.get("id")
# Else, look for the ancestor in the predecessor nodes
for predecessor in G.predecessors(node_id):
result = self._find_nearest_ancestor(G, predecessor)
if result:
return result
# If there are no valid ancestors, return None
return None
def _get_additional_metadata(
self, graph: Graph, node_id: str
) -> Union[dict[str, Any], None]:
"""
Returns additional metadata for a given node.
Parameters:
graph (Graph): The execution graph.
node_id (str): The ID of the node.
Returns:
dict[str, Any] | None: A dictionary of additional metadata.
"""
metadata = {}
# Iterate over all edges in the graph
for edge in graph.edges:
dest_node_id = edge.destination.node_id
dest_field = edge.destination.field
source_node_dict = graph.nodes[edge.source.node_id].dict()
# If the destination node ID matches the given node ID, gather necessary metadata
if dest_node_id == node_id:
# Prompt
if dest_field == "positive_conditioning":
metadata["positive_conditioning"] = source_node_dict.get("prompt")
# Negative prompt
if dest_field == "negative_conditioning":
metadata["negative_conditioning"] = source_node_dict.get("prompt")
# Seed, width and height
if dest_field == "noise":
for field in self._NOISE_FIELDS:
metadata[field] = source_node_dict.get(field)
return metadata
def _build_metadata_from_graph(
self, session: GraphExecutionState, node_id: str
) -> ImageMetadata:
"""
Builds an ImageMetadata object for a node.
Parameters:
session (GraphExecutionState): The session.
node_id (str): The ID of the node.
Returns:
ImageMetadata: The metadata for the node.
"""
# We need to do all the traversal on the execution graph
graph = session.execution_graph
# Find the nearest `t2l`/`l2l` ancestor of the given node
ancestor_id = self._find_nearest_ancestor(graph.nx_graph_with_data(), node_id)
# If no ancestor was found, return an empty ImageMetadata object
if ancestor_id is None:
return ImageMetadata()
ancestor_node = graph.get_node(ancestor_id)
# Grab all the core metadata from the ancestor node
ancestor_metadata = {
param: val
for param, val in ancestor_node.dict().items()
if param in self._ANCESTOR_PARAMS
}
# Get this image's prompts and noise parameters
addl_metadata = self._get_additional_metadata(graph, ancestor_id)
# If additional metadata was found, add it to the main metadata
if addl_metadata is not None:
ancestor_metadata.update(addl_metadata)
return ImageMetadata(**ancestor_metadata)

View File

@ -0,0 +1,118 @@
import datetime
from typing import Optional, Union
from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageCategory, ImageType
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.util.misc import get_iso_timestamp
class ImageRecord(BaseModel):
"""Deserialized image record."""
image_name: str = Field(description="The unique name of the image.")
"""The unique name of the image."""
image_type: ImageType = Field(description="The type of the image.")
"""The type of the image."""
image_category: ImageCategory = Field(description="The category of the image.")
"""The category of the image."""
width: int = Field(description="The width of the image in px.")
"""The actual width of the image in px. This may be different from the width in metadata."""
height: int = Field(description="The height of the image in px.")
"""The actual height of the image in px. This may be different from the height in metadata."""
created_at: Union[datetime.datetime, str] = Field(
description="The created timestamp of the image."
)
"""The created timestamp of the image."""
updated_at: Union[datetime.datetime, str] = Field(
description="The updated timestamp of the image."
)
"""The updated timestamp of the image."""
deleted_at: Union[datetime.datetime, str, None] = Field(
description="The deleted timestamp of the image."
)
"""The deleted timestamp of the image."""
session_id: Optional[str] = Field(
default=None,
description="The session ID that generated this image, if it is a generated image.",
)
"""The session ID that generated this image, if it is a generated image."""
node_id: Optional[str] = Field(
default=None,
description="The node ID that generated this image, if it is a generated image.",
)
"""The node ID that generated this image, if it is a generated image."""
metadata: Optional[ImageMetadata] = Field(
default=None,
description="A limited subset of the image's generation metadata. Retrieve the image's session for full metadata.",
)
"""A limited subset of the image's generation metadata. Retrieve the image's session for full metadata."""
class ImageUrlsDTO(BaseModel):
"""The URLs for an image and its thumbnail."""
image_name: str = Field(description="The unique name of the image.")
"""The unique name of the image."""
image_type: ImageType = Field(description="The type of the image.")
"""The type of the image."""
image_url: str = Field(description="The URL of the image.")
"""The URL of the image."""
thumbnail_url: str = Field(description="The URL of the image's thumbnail.")
"""The URL of the image's thumbnail."""
class ImageDTO(ImageRecord, ImageUrlsDTO):
"""Deserialized image record, enriched for the frontend with URLs."""
pass
def image_record_to_dto(
image_record: ImageRecord, image_url: str, thumbnail_url: str
) -> ImageDTO:
"""Converts an image record to an image DTO."""
return ImageDTO(
**image_record.dict(),
image_url=image_url,
thumbnail_url=thumbnail_url,
)
def deserialize_image_record(image_dict: dict) -> ImageRecord:
"""Deserializes an image record."""
# Retrieve all the values, setting "reasonable" defaults if they are not present.
image_name = image_dict.get("image_name", "unknown")
image_type = ImageType(image_dict.get("image_type", ImageType.RESULT.value))
image_category = ImageCategory(
image_dict.get("image_category", ImageCategory.GENERAL.value)
)
width = image_dict.get("width", 0)
height = image_dict.get("height", 0)
session_id = image_dict.get("session_id", None)
node_id = image_dict.get("node_id", None)
created_at = image_dict.get("created_at", get_iso_timestamp())
updated_at = image_dict.get("updated_at", get_iso_timestamp())
deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
raw_metadata = image_dict.get("metadata")
if raw_metadata is not None:
metadata = ImageMetadata.parse_raw(raw_metadata)
else:
metadata = None
return ImageRecord(
image_name=image_name,
image_type=image_type,
image_category=image_category,
width=width,
height=height,
session_id=session_id,
node_id=node_id,
metadata=metadata,
created_at=created_at,
updated_at=updated_at,
deleted_at=deleted_at,
)

View File

@ -0,0 +1,34 @@
import os
from abc import ABC, abstractmethod
from invokeai.app.models.image import ImageType
from invokeai.app.util.thumbnails import get_thumbnail_name
class UrlServiceBase(ABC):
"""Responsible for building URLs for resources."""
@abstractmethod
def get_image_url(
self, image_type: ImageType, image_name: str, thumbnail: bool = False
) -> str:
"""Gets the URL for an image or thumbnail."""
pass
class LocalUrlService(UrlServiceBase):
def __init__(self, base_url: str = "api/v1"):
self._base_url = base_url
def get_image_url(
self, image_type: ImageType, image_name: str, thumbnail: bool = False
) -> str:
image_basename = os.path.basename(image_name)
# These paths are determined by the routes in invokeai/app/api/routers/images.py
if thumbnail:
return (
f"{self._base_url}/images/{image_type.value}/{image_basename}/thumbnail"
)
return f"{self._base_url}/images/{image_type.value}/{image_basename}"

View File

@ -0,0 +1,15 @@
from enum import EnumMeta
class MetaEnum(EnumMeta):
"""Metaclass to support additional features in Enums.
- `in` operator support: `'value' in MyEnum -> bool`
"""
def __contains__(cls, item):
try:
cls(item)
except ValueError:
return False
return True

View File

@ -6,6 +6,14 @@ def get_timestamp():
return int(datetime.datetime.now(datetime.timezone.utc).timestamp())
def get_iso_timestamp() -> str:
return datetime.datetime.utcnow().isoformat()
def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime:
return datetime.datetime.fromisoformat(iso_timestamp)
SEED_MAX = np.iinfo(np.int32).max

View File

@ -196,7 +196,7 @@ class Inpaint(Img2Img):
seam_noise = self.get_noise(im.width, im.height)
result = make_image(seam_noise, seed)
result = make_image(seam_noise, seed=None)
return result

View File

@ -76,16 +76,16 @@ class InvokeAILogFormatter(logging.Formatter):
reset = "\x1b[0m"
# Log Format
format = "[%(asctime)s]::[%(name)s]::%(levelname)s --> %(message)s"
log_format = "[%(asctime)s]::[%(name)s]::%(levelname)s --> %(message)s"
## More Formatting Options: %(pathname)s, %(filename)s, %(module)s, %(lineno)d
# Format Map
FORMATS = {
logging.DEBUG: cyan + format + reset,
logging.INFO: grey + format + reset,
logging.WARNING: yellow + format + reset,
logging.ERROR: red + format + reset,
logging.CRITICAL: bold_red + format + reset
logging.DEBUG: cyan + log_format + reset,
logging.INFO: grey + log_format + reset,
logging.WARNING: yellow + log_format + reset,
logging.ERROR: red + log_format + reset,
logging.CRITICAL: bold_red + log_format + reset
}
def format(self, record):
@ -98,13 +98,13 @@ class InvokeAILogger(object):
loggers = dict()
@classmethod
def getLogger(self, name: str = 'InvokeAI') -> logging.Logger:
if name not in self.loggers:
def getLogger(cls, name: str = 'InvokeAI') -> logging.Logger:
if name not in cls.loggers:
logger = logging.getLogger(name)
logger.setLevel(logging.DEBUG)
ch = logging.StreamHandler()
fmt = InvokeAILogFormatter()
ch.setFormatter(fmt)
logger.addHandler(ch)
self.loggers[name] = logger
return self.loggers[name]
cls.loggers[name] = logger
return cls.loggers[name]

View File

@ -23,8 +23,8 @@
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
"dev:host": "concurrently \"vite dev --host\" \"yarn run theme:watch\"",
"build": "yarn run lint && vite build",
"api:web": "openapi -i http://localhost:9090/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --exportSchemas true --indent 2 --request src/services/fixtures/request.ts",
"api:file": "openapi -i src/services/fixtures/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --exportSchemas true --indent 2 --request src/services/fixtures/request.ts",
"api:web": "openapi -i http://localhost:9090/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --indent 2 --request src/services/fixtures/request.ts",
"api:file": "openapi -i src/services/fixtures/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --indent 2 --request src/services/fixtures/request.ts",
"preview": "vite preview",
"lint:madge": "madge --circular src/main.tsx",
"lint:eslint": "eslint --max-warnings=0 .",

View File

@ -10,7 +10,7 @@ export const readinessSelector = createSelector(
[generationSelector, systemSelector, activeTabNameSelector],
(generation, system, activeTabName) => {
const {
prompt,
positivePrompt: prompt,
shouldGenerateVariations,
seedWeights,
initialImage,

View File

@ -5,7 +5,6 @@ import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { addToast } from 'features/system/store/systemSlice';
import { imageUploaded } from 'services/thunks/image';
import { v4 as uuidv4 } from 'uuid';
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
import { setMergedCanvas } from 'features/canvas/store/canvasSlice';
import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider';
@ -66,7 +65,7 @@ export const addCanvasMergedListener = () => {
action.meta.arg.formData.file.name === filename
);
const mergedCanvasImage = deserializeImageResponse(payload.response);
const mergedCanvasImage = payload.response;
dispatch(
setMergedCanvas({

View File

@ -17,24 +17,24 @@ export const addRequestedImageDeletionListener = () => {
return;
}
const { name, type } = image;
const { image_name, image_type } = image;
if (type !== 'uploads' && type !== 'results') {
moduleLog.warn({ data: image }, `Invalid image type ${type}`);
if (image_type !== 'uploads' && image_type !== 'results') {
moduleLog.warn({ data: image }, `Invalid image type ${image_type}`);
return;
}
const selectedImageName = getState().gallery.selectedImage?.name;
const selectedImageName = getState().gallery.selectedImage?.image_name;
if (selectedImageName === name) {
const allIds = getState()[type].ids;
const allEntities = getState()[type].entities;
if (selectedImageName === image_name) {
const allIds = getState()[image_type].ids;
const allEntities = getState()[image_type].entities;
const deletedImageIndex = allIds.findIndex(
(result) => result.toString() === name
(result) => result.toString() === image_name
);
const filteredIds = allIds.filter((id) => id.toString() !== name);
const filteredIds = allIds.filter((id) => id.toString() !== image_name);
const newSelectedImageIndex = clamp(
deletedImageIndex,
@ -53,7 +53,7 @@ export const addRequestedImageDeletionListener = () => {
}
}
dispatch(imageDeleted({ imageName: name, imageType: type }));
dispatch(imageDeleted({ imageName: image_name, imageType: image_type }));
},
});
};

View File

@ -1,4 +1,3 @@
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
import { startAppListening } from '..';
import { uploadAdded } from 'features/gallery/store/uploadsSlice';
import { imageSelected } from 'features/gallery/store/gallerySlice';
@ -7,6 +6,7 @@ import { addToast } from 'features/system/store/systemSlice';
import { initialImageSelected } from 'features/parameters/store/actions';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { resultAdded } from 'features/gallery/store/resultsSlice';
import { isResultsImageDTO, isUploadsImageDTO } from 'services/types/guards';
export const addImageUploadedListener = () => {
startAppListening({
@ -14,13 +14,11 @@ export const addImageUploadedListener = () => {
imageUploaded.fulfilled.match(action) &&
action.payload.response.image_type !== 'intermediates',
effect: (action, { dispatch, getState }) => {
const { response } = action.payload;
const { imageType } = action.meta.arg;
const { response: image } = action.payload;
const state = getState();
const image = deserializeImageResponse(response);
if (imageType === 'uploads') {
if (isUploadsImageDTO(image)) {
dispatch(uploadAdded(image));
dispatch(addToast({ title: 'Image Uploaded', status: 'success' }));
@ -38,7 +36,7 @@ export const addImageUploadedListener = () => {
}
}
if (imageType === 'results') {
if (isResultsImageDTO(image)) {
dispatch(resultAdded(image));
}
},

View File

@ -1,12 +1,15 @@
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { Image, isInvokeAIImage } from 'app/types/invokeai';
import { selectResultsById } from 'features/gallery/store/resultsSlice';
import { selectUploadsById } from 'features/gallery/store/uploadsSlice';
import { t } from 'i18next';
import { addToast } from 'features/system/store/systemSlice';
import { startAppListening } from '..';
import { initialImageSelected } from 'features/parameters/store/actions';
import {
initialImageSelected,
isImageDTO,
} from 'features/parameters/store/actions';
import { makeToast } from 'app/components/Toaster';
import { ImageDTO } from 'services/api';
export const addInitialImageSelectedListener = () => {
startAppListening({
@ -21,21 +24,21 @@ export const addInitialImageSelectedListener = () => {
return;
}
if (isInvokeAIImage(action.payload)) {
if (isImageDTO(action.payload)) {
dispatch(initialImageChanged(action.payload));
dispatch(addToast(makeToast(t('toast.sentToImageToImage'))));
return;
}
const { name, type } = action.payload;
const { image_name, image_type } = action.payload;
let image: Image | undefined;
let image: ImageDTO | undefined;
const state = getState();
if (type === 'results') {
image = selectResultsById(state, name);
} else if (type === 'uploads') {
image = selectUploadsById(state, name);
if (image_type === 'results') {
image = selectResultsById(state, image_name);
} else if (image_type === 'uploads') {
image = selectUploadsById(state, image_name);
}
if (!image) {

View File

@ -1,14 +1,10 @@
import { invocationComplete } from 'services/events/actions';
import { isImageOutput } from 'services/types/guards';
import {
buildImageUrls,
extractTimestampFromImageName,
} from 'services/util/deserializeImageField';
import { Image } from 'app/types/invokeai';
import { resultAdded } from 'features/gallery/store/resultsSlice';
import { imageReceived, thumbnailReceived } from 'services/thunks/image';
imageMetadataReceived,
imageUrlsReceived,
} from 'services/thunks/image';
import { startAppListening } from '..';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
const nodeDenylist = ['dataURL_image'];
@ -24,62 +20,40 @@ export const addImageResultReceivedListener = () => {
}
return false;
},
effect: (action, { getState, dispatch }) => {
effect: async (action, { getState, dispatch, take }) => {
if (!invocationComplete.match(action)) {
return;
}
const { data, shouldFetchImages } = action.payload;
const { data } = action.payload;
const { result, node, graph_execution_state_id } = data;
if (isImageOutput(result) && !nodeDenylist.includes(node.type)) {
const name = result.image.image_name;
const type = result.image.image_type;
const state = getState();
const { image_name, image_type } = result.image;
// if we need to refetch, set URLs to placeholder for now
const { url, thumbnail } = shouldFetchImages
? { url: '', thumbnail: '' }
: buildImageUrls(type, name);
dispatch(
imageUrlsReceived({ imageName: image_name, imageType: image_type })
);
const timestamp = extractTimestampFromImageName(name);
const image: Image = {
name,
type,
url,
thumbnail,
metadata: {
created: timestamp,
width: result.width,
height: result.height,
invokeai: {
session_id: graph_execution_state_id,
...(node ? { node } : {}),
},
},
};
dispatch(resultAdded(image));
if (state.gallery.shouldAutoSwitchToNewImages) {
dispatch(imageSelected(image));
}
if (state.config.shouldFetchImages) {
dispatch(imageReceived({ imageName: name, imageType: type }));
dispatch(
thumbnailReceived({
thumbnailName: name,
thumbnailType: type,
})
);
}
dispatch(
imageMetadataReceived({
imageName: image_name,
imageType: image_type,
})
);
// Handle canvas image
if (
graph_execution_state_id ===
state.canvas.layerState.stagingArea.sessionId
getState().canvas.layerState.stagingArea.sessionId
) {
const [{ payload: image }] = await take(
(
action
): action is ReturnType<typeof imageMetadataReceived.fulfilled> =>
imageMetadataReceived.fulfilled.match(action) &&
action.payload.image_name === image_name
);
dispatch(addImageToStagingArea(image));
}
}

View File

@ -122,21 +122,21 @@ export type PostProcessedImageMetadata = ESRGANMetadata | FacetoolMetadata;
/**
* ResultImage
*/
export type Image = {
name: string;
type: ImageType;
url: string;
thumbnail: string;
metadata: ImageResponseMetadata;
};
// export ty`pe Image = {
// name: string;
// type: ImageType;
// url: string;
// thumbnail: string;
// metadata: ImageResponseMetadata;
// };
export const isInvokeAIImage = (obj: Image | SelectedImage): obj is Image => {
if ('url' in obj && 'thumbnail' in obj) {
return true;
}
// export const isInvokeAIImage = (obj: Image | SelectedImage): obj is Image => {
// if ('url' in obj && 'thumbnail' in obj) {
// return true;
// }
return false;
};
// return false;
// };
/**
* Types related to the system status.
@ -346,7 +346,6 @@ export type AppConfig = {
/**
* Whether or not we need to re-fetch images
*/
shouldFetchImages: boolean;
disabledTabs: InvokeTabName[];
disabledFeatures: AppFeature[];
disabledSDFeatures: SDFeature[];

View File

@ -1,10 +1,10 @@
import { Badge, Flex } from '@chakra-ui/react';
import { Image } from 'app/types/invokeai';
import { isNumber, isString } from 'lodash-es';
import { useMemo } from 'react';
import { ImageDTO } from 'services/api';
type ImageMetadataOverlayProps = {
image: Image;
image: ImageDTO;
};
const ImageMetadataOverlay = ({ image }: ImageMetadataOverlayProps) => {
@ -17,11 +17,11 @@ const ImageMetadataOverlay = ({ image }: ImageMetadataOverlayProps) => {
}, [image.metadata]);
const model = useMemo(() => {
if (!isString(image.metadata?.invokeai?.node?.model)) {
if (!isString(image.metadata?.model)) {
return;
}
return image.metadata?.invokeai?.node?.model;
return image.metadata?.model;
}, [image.metadata]);
return (

View File

@ -0,0 +1,12 @@
/**
* Comparator function for sorting dates in ascending order
*/
export const dateComparator = (a: string, b: string) => {
const dateA = new Date(a);
const dateB = new Date(b);
// sort in ascending order
if (dateA > dateB) return 1;
if (dateA < dateB) return -1;
return 0;
};

View File

@ -46,7 +46,7 @@ const IAICanvasObjectRenderer = () => {
key={i}
x={obj.x}
y={obj.y}
url={getUrl(obj.image.url)}
url={getUrl(obj.image.image_url)}
/>
);
} else if (isCanvasBaseLine(obj)) {

View File

@ -62,7 +62,7 @@ const IAICanvasStagingArea = (props: Props) => {
<Group {...rest}>
{shouldShowStagingImage && currentStagingAreaImage && (
<IAICanvasImage
url={getUrl(currentStagingAreaImage.image.url)}
url={getUrl(currentStagingAreaImage.image.image_url)}
x={x}
y={y}
/>

View File

@ -157,17 +157,19 @@ const IAICanvasStagingAreaToolbar = () => {
}
colorScheme="accent"
/>
<IAIIconButton
{/* <IAIIconButton
tooltip={t('unifiedCanvas.saveToGallery')}
aria-label={t('unifiedCanvas.saveToGallery')}
icon={<FaSave />}
onClick={() =>
dispatch(
saveStagingAreaImageToGallery(currentStagingAreaImage.image.url)
saveStagingAreaImageToGallery(
currentStagingAreaImage.image.image_url
)
)
}
colorScheme="accent"
/>
/> */}
<IAIIconButton
tooltip={t('unifiedCanvas.discardAll')}
aria-label={t('unifiedCanvas.discardAll')}

View File

@ -1,6 +1,5 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/types/invokeai';
import {
roundDownToMultiple,
roundToMultiple,
@ -29,6 +28,7 @@ import {
isCanvasBaseImage,
isCanvasMaskLine,
} from './canvasTypes';
import { ImageDTO } from 'services/api';
export const initialLayerState: CanvasLayerState = {
objects: [],
@ -157,9 +157,9 @@ export const canvasSlice = createSlice({
setCursorPosition: (state, action: PayloadAction<Vector2d | null>) => {
state.cursorPosition = action.payload;
},
setInitialCanvasImage: (state, action: PayloadAction<InvokeAI.Image>) => {
setInitialCanvasImage: (state, action: PayloadAction<ImageDTO>) => {
const image = action.payload;
const { width, height } = image.metadata;
const { width, height } = image;
const { stageDimensions } = state;
const newBoundingBoxDimensions = {
@ -302,7 +302,7 @@ export const canvasSlice = createSlice({
selectedImageIndex: -1,
};
},
addImageToStagingArea: (state, action: PayloadAction<InvokeAI.Image>) => {
addImageToStagingArea: (state, action: PayloadAction<ImageDTO>) => {
const image = action.payload;
if (!image || !state.layerState.stagingArea.boundingBox) {

View File

@ -1,6 +1,7 @@
import * as InvokeAI from 'app/types/invokeai';
import { IRect, Vector2d } from 'konva/lib/types';
import { RgbaColor } from 'react-colorful';
import { ImageDTO } from 'services/api';
export const LAYER_NAMES_DICT = [
{ key: 'Base', value: 'base' },
@ -37,7 +38,7 @@ export type CanvasImage = {
y: number;
width: number;
height: number;
image: InvokeAI.Image;
image: ImageDTO;
};
export type CanvasMaskLine = {

View File

@ -195,14 +195,14 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
}
if (shouldTransformUrls) {
return getUrl(image.url);
return getUrl(image.image_url);
}
if (image.url.startsWith('http')) {
return image.url;
if (image.image_url.startsWith('http')) {
return image.image_url;
}
return window.location.toString() + image.url;
return window.location.toString() + image.image_url;
};
const url = getImageUrl();

View File

@ -61,8 +61,8 @@ const CurrentImagePreview = () => {
if (!image) {
return;
}
e.dataTransfer.setData('invokeai/imageName', image.name);
e.dataTransfer.setData('invokeai/imageType', image.type);
e.dataTransfer.setData('invokeai/imageName', image.image_name);
e.dataTransfer.setData('invokeai/imageType', image.image_type);
e.dataTransfer.effectAllowed = 'move';
},
[image]
@ -108,7 +108,7 @@ const CurrentImagePreview = () => {
image && (
<>
<Image
src={getUrl(image.url)}
src={getUrl(image.image_url)}
fallbackStrategy="beforeLoadOrError"
fallback={<ImageFallbackSpinner />}
onDragStart={handleDragStart}

View File

@ -13,7 +13,6 @@ import { DragEvent, MouseEvent, memo, useCallback, useState } from 'react';
import { FaCheck, FaExpand, FaImage, FaShare, FaTrash } from 'react-icons/fa';
import DeleteImageModal from './DeleteImageModal';
import { ContextMenu } from 'chakra-ui-contextmenu';
import * as InvokeAI from 'app/types/invokeai';
import {
resizeAndScaleCanvas,
setInitialCanvasImage,
@ -39,6 +38,7 @@ import {
sentImageToImg2Img,
} from '../store/actions';
import { useAppToaster } from 'app/components/Toaster';
import { ImageDTO } from 'services/api';
export const selector = createSelector(
[gallerySelector, systemSelector, lightboxSelector, activeTabNameSelector],
@ -70,14 +70,16 @@ export const selector = createSelector(
);
interface HoverableImageProps {
image: InvokeAI.Image;
image: ImageDTO;
isSelected: boolean;
}
const memoEqualityCheck = (
prev: HoverableImageProps,
next: HoverableImageProps
) => prev.image.name === next.image.name && prev.isSelected === next.isSelected;
) =>
prev.image.image_name === next.image.image_name &&
prev.isSelected === next.isSelected;
/**
* Gallery image component with delete/use all/use seed buttons on hover.
@ -100,7 +102,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
} = useDisclosure();
const { image, isSelected } = props;
const { url, thumbnail, name } = image;
const { image_url, thumbnail_url, image_name } = image;
const { getUrl } = useGetUrl();
const [isHovered, setIsHovered] = useState<boolean>(false);
@ -144,8 +146,8 @@ const HoverableImage = memo((props: HoverableImageProps) => {
const handleDragStart = useCallback(
(e: DragEvent<HTMLDivElement>) => {
e.dataTransfer.setData('invokeai/imageName', image.name);
e.dataTransfer.setData('invokeai/imageType', image.type);
e.dataTransfer.setData('invokeai/imageName', image.image_name);
e.dataTransfer.setData('invokeai/imageType', image.image_type);
e.dataTransfer.effectAllowed = 'move';
},
[image]
@ -153,11 +155,11 @@ const HoverableImage = memo((props: HoverableImageProps) => {
// Recall parameters handlers
const handleRecallPrompt = useCallback(() => {
recallPrompt(image.metadata?.invokeai?.node?.prompt);
recallPrompt(image.metadata?.positive_conditioning);
}, [image, recallPrompt]);
const handleRecallSeed = useCallback(() => {
recallSeed(image.metadata.invokeai?.node?.seed);
recallSeed(image.metadata?.seed);
}, [image, recallSeed]);
const handleSendToImageToImage = useCallback(() => {
@ -165,9 +167,9 @@ const HoverableImage = memo((props: HoverableImageProps) => {
dispatch(initialImageSelected(image));
}, [dispatch, image]);
const handleRecallInitialImage = useCallback(() => {
recallInitialImage(image.metadata.invokeai?.node?.image);
}, [image, recallInitialImage]);
// const handleRecallInitialImage = useCallback(() => {
// recallInitialImage(image.metadata.invokeai?.node?.image);
// }, [image, recallInitialImage]);
/**
* TODO: the rest of these
@ -200,7 +202,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
};
const handleOpenInNewTab = () => {
window.open(getUrl(image.url), '_blank');
window.open(getUrl(image.image_url), '_blank');
};
return (
@ -223,7 +225,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
<MenuItem
icon={<IoArrowUndoCircleOutline />}
onClickCapture={handleRecallPrompt}
isDisabled={image?.metadata?.invokeai?.node?.prompt === undefined}
isDisabled={image?.metadata?.positive_conditioning === undefined}
>
{t('parameters.usePrompt')}
</MenuItem>
@ -231,23 +233,23 @@ const HoverableImage = memo((props: HoverableImageProps) => {
<MenuItem
icon={<IoArrowUndoCircleOutline />}
onClickCapture={handleRecallSeed}
isDisabled={image?.metadata?.invokeai?.node?.seed === undefined}
isDisabled={image?.metadata?.seed === undefined}
>
{t('parameters.useSeed')}
</MenuItem>
<MenuItem
{/* <MenuItem
icon={<IoArrowUndoCircleOutline />}
onClickCapture={handleRecallInitialImage}
isDisabled={image?.metadata?.invokeai?.node?.type !== 'img2img'}
isDisabled={image?.metadata?.type !== 'img2img'}
>
{t('parameters.useInitImg')}
</MenuItem>
</MenuItem> */}
<MenuItem
icon={<IoArrowUndoCircleOutline />}
onClickCapture={handleUseAllParameters}
isDisabled={
!['txt2img', 'img2img', 'inpaint'].includes(
String(image?.metadata?.invokeai?.node?.type)
String(image?.metadata?.type)
)
}
>
@ -278,7 +280,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
{(ref) => (
<Box
position="relative"
key={name}
key={image_name}
onMouseOver={handleMouseOver}
onMouseOut={handleMouseOut}
userSelect="none"
@ -303,7 +305,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
shouldUseSingleGalleryColumn ? 'contain' : galleryImageObjectFit
}
rounded="md"
src={getUrl(thumbnail || url)}
src={getUrl(thumbnail_url || image_url)}
fallback={<FaImage />}
sx={{
width: '100%',

View File

@ -12,7 +12,7 @@ import { memo, useCallback } from 'react';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import DeleteImageModal from '../DeleteImageModal';
import { requestedImageDeletion } from 'features/gallery/store/actions';
import { Image } from 'app/types/invokeai';
import { ImageDTO } from 'services/api';
const selector = createSelector(
[systemSelector],
@ -30,7 +30,7 @@ const selector = createSelector(
);
type DeleteImageButtonProps = {
image: Image | undefined;
image: ImageDTO | undefined;
};
const DeleteImageButton = (props: DeleteImageButtonProps) => {

View File

@ -5,7 +5,6 @@ import {
FlexProps,
Grid,
Icon,
Image,
Text,
forwardRef,
} from '@chakra-ui/react';
@ -51,10 +50,10 @@ import { uploadsAdapter } from '../store/uploadsSlice';
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { Virtuoso, VirtuosoGrid } from 'react-virtuoso';
import { Image as ImageType } from 'app/types/invokeai';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import GalleryProgressImage from './GalleryProgressImage';
import { uiSelector } from 'features/ui/store/uiSelectors';
import { ImageDTO } from 'services/api';
const GALLERY_SHOW_BUTTONS_MIN_WIDTH = 290;
const PROGRESS_IMAGE_PLACEHOLDER = 'PROGRESS_IMAGE_PLACEHOLDER';
@ -66,7 +65,7 @@ const categorySelector = createSelector(
const { currentCategory } = gallery;
if (currentCategory === 'results') {
const tempImages: (ImageType | typeof PROGRESS_IMAGE_PLACEHOLDER)[] = [];
const tempImages: (ImageDTO | typeof PROGRESS_IMAGE_PLACEHOLDER)[] = [];
if (system.progressImage) {
tempImages.push(PROGRESS_IMAGE_PLACEHOLDER);
@ -352,7 +351,7 @@ const ImageGalleryContent = () => {
const isSelected =
image === PROGRESS_IMAGE_PLACEHOLDER
? false
: selectedImage?.name === image?.name;
: selectedImage?.image_name === image?.image_name;
return (
<Flex sx={{ pb: 2 }}>
@ -362,7 +361,7 @@ const ImageGalleryContent = () => {
/>
) : (
<HoverableImage
key={`${image.name}-${image.thumbnail}`}
key={`${image.image_name}-${image.thumbnail_url}`}
image={image}
isSelected={isSelected}
/>
@ -385,13 +384,13 @@ const ImageGalleryContent = () => {
const isSelected =
image === PROGRESS_IMAGE_PLACEHOLDER
? false
: selectedImage?.name === image?.name;
: selectedImage?.image_name === image?.image_name;
return image === PROGRESS_IMAGE_PLACEHOLDER ? (
<GalleryProgressImage key={PROGRESS_IMAGE_PLACEHOLDER} />
) : (
<HoverableImage
key={`${image.name}-${image.thumbnail}`}
key={`${image.image_name}-${image.thumbnail_url}`}
image={image}
isSelected={isSelected}
/>

View File

@ -18,7 +18,9 @@ import {
setCfgScale,
setHeight,
setImg2imgStrength,
setNegativePrompt,
setPerlin,
setPositivePrompt,
setScheduler,
setSeamless,
setSeed,
@ -36,6 +38,9 @@ import { useTranslation } from 'react-i18next';
import { FaCopy } from 'react-icons/fa';
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { ImageDTO } from 'services/api';
import { filter } from 'lodash-es';
import { Scheduler } from 'app/constants';
type MetadataItemProps = {
isLink?: boolean;
@ -58,7 +63,6 @@ const MetadataItem = ({
withCopy = false,
}: MetadataItemProps) => {
const { t } = useTranslation();
return (
<Flex gap={2}>
{onClick && (
@ -104,14 +108,14 @@ const MetadataItem = ({
};
type ImageMetadataViewerProps = {
image: InvokeAI.Image;
image: ImageDTO;
};
// TODO: I don't know if this is needed.
const memoEqualityCheck = (
prev: ImageMetadataViewerProps,
next: ImageMetadataViewerProps
) => prev.image.name === next.image.name;
) => prev.image.image_name === next.image.image_name;
// TODO: Show more interesting information in this component.
@ -128,8 +132,9 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
dispatch(setShouldShowImageDetails(false));
});
const sessionId = image.metadata.invokeai?.session_id;
const node = image.metadata.invokeai?.node as Record<string, any>;
const sessionId = image?.session_id;
const metadata = image?.metadata;
const { t } = useTranslation();
const { getUrl } = useGetUrl();
@ -154,110 +159,131 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
>
<Flex gap={2}>
<Text fontWeight="semibold">File:</Text>
<Link href={getUrl(image.url)} isExternal maxW="calc(100% - 3rem)">
{image.url.length > 64
? image.url.substring(0, 64).concat('...')
: image.url}
<Link
href={getUrl(image.image_url)}
isExternal
maxW="calc(100% - 3rem)"
>
{image.image_name}
<ExternalLinkIcon mx="2px" />
</Link>
</Flex>
{node && Object.keys(node).length > 0 ? (
{metadata && Object.keys(metadata).length > 0 ? (
<>
{node.type && (
<MetadataItem label="Invocation type" value={node.type} />
{metadata.type && (
<MetadataItem label="Invocation type" value={metadata.type} />
)}
{node.model && <MetadataItem label="Model" value={node.model} />}
{node.prompt && (
{metadata.width && (
<MetadataItem
label="Width"
value={metadata.width}
onClick={() => dispatch(setWidth(Number(metadata.width)))}
/>
)}
{metadata.height && (
<MetadataItem
label="Height"
value={metadata.height}
onClick={() => dispatch(setHeight(Number(metadata.height)))}
/>
)}
{metadata.model && (
<MetadataItem label="Model" value={metadata.model} />
)}
{metadata.positive_conditioning && (
<MetadataItem
label="Prompt"
labelPosition="top"
value={
typeof node.prompt === 'string'
? node.prompt
: promptToString(node.prompt)
typeof metadata.positive_conditioning === 'string'
? metadata.positive_conditioning
: promptToString(metadata.positive_conditioning)
}
onClick={() => setBothPrompts(node.prompt)}
onClick={() => setPositivePrompt(metadata.positive_conditioning!)}
/>
)}
{node.seed !== undefined && (
{metadata.negative_conditioning && (
<MetadataItem
label="Prompt"
labelPosition="top"
value={
typeof metadata.negative_conditioning === 'string'
? metadata.negative_conditioning
: promptToString(metadata.negative_conditioning)
}
onClick={() => setNegativePrompt(metadata.negative_conditioning!)}
/>
)}
{metadata.seed !== undefined && (
<MetadataItem
label="Seed"
value={node.seed}
onClick={() => dispatch(setSeed(Number(node.seed)))}
value={metadata.seed}
onClick={() => dispatch(setSeed(Number(metadata.seed)))}
/>
)}
{node.threshold !== undefined && (
{/* {metadata.threshold !== undefined && (
<MetadataItem
label="Noise Threshold"
value={node.threshold}
onClick={() => dispatch(setThreshold(Number(node.threshold)))}
value={metadata.threshold}
onClick={() => dispatch(setThreshold(Number(metadata.threshold)))}
/>
)}
{node.perlin !== undefined && (
{metadata.perlin !== undefined && (
<MetadataItem
label="Perlin Noise"
value={node.perlin}
onClick={() => dispatch(setPerlin(Number(node.perlin)))}
value={metadata.perlin}
onClick={() => dispatch(setPerlin(Number(metadata.perlin)))}
/>
)}
{node.scheduler && (
)} */}
{metadata.scheduler && (
<MetadataItem
label="Scheduler"
value={node.scheduler}
onClick={() => dispatch(setScheduler(node.scheduler))}
/>
)}
{node.steps && (
<MetadataItem
label="Steps"
value={node.steps}
onClick={() => dispatch(setSteps(Number(node.steps)))}
/>
)}
{node.cfg_scale !== undefined && (
<MetadataItem
label="CFG scale"
value={node.cfg_scale}
onClick={() => dispatch(setCfgScale(Number(node.cfg_scale)))}
/>
)}
{node.variations && node.variations.length > 0 && (
<MetadataItem
label="Seed-weight pairs"
value={seedWeightsToString(node.variations)}
value={metadata.scheduler}
onClick={() =>
dispatch(setSeedWeights(seedWeightsToString(node.variations)))
dispatch(setScheduler(metadata.scheduler as Scheduler))
}
/>
)}
{node.seamless && (
{metadata.steps && (
<MetadataItem
label="Steps"
value={metadata.steps}
onClick={() => dispatch(setSteps(Number(metadata.steps)))}
/>
)}
{metadata.cfg_scale !== undefined && (
<MetadataItem
label="CFG scale"
value={metadata.cfg_scale}
onClick={() => dispatch(setCfgScale(Number(metadata.cfg_scale)))}
/>
)}
{/* {metadata.variations && metadata.variations.length > 0 && (
<MetadataItem
label="Seed-weight pairs"
value={seedWeightsToString(metadata.variations)}
onClick={() =>
dispatch(
setSeedWeights(seedWeightsToString(metadata.variations))
)
}
/>
)}
{metadata.seamless && (
<MetadataItem
label="Seamless"
value={node.seamless}
onClick={() => dispatch(setSeamless(node.seamless))}
value={metadata.seamless}
onClick={() => dispatch(setSeamless(metadata.seamless))}
/>
)}
{node.hires_fix && (
{metadata.hires_fix && (
<MetadataItem
label="High Resolution Optimization"
value={node.hires_fix}
onClick={() => dispatch(setHiresFix(node.hires_fix))}
value={metadata.hires_fix}
onClick={() => dispatch(setHiresFix(metadata.hires_fix))}
/>
)}
{node.width && (
<MetadataItem
label="Width"
value={node.width}
onClick={() => dispatch(setWidth(Number(node.width)))}
/>
)}
{node.height && (
<MetadataItem
label="Height"
value={node.height}
onClick={() => dispatch(setHeight(Number(node.height)))}
/>
)}
)} */}
{/* {init_image_path && (
<MetadataItem
label="Initial image"
@ -266,22 +292,22 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
onClick={() => dispatch(setInitialImage(init_image_path))}
/>
)} */}
{node.strength && (
{metadata.strength && (
<MetadataItem
label="Image to image strength"
value={node.strength}
value={metadata.strength}
onClick={() =>
dispatch(setImg2imgStrength(Number(node.strength)))
dispatch(setImg2imgStrength(Number(metadata.strength)))
}
/>
)}
{node.fit && (
{/* {metadata.fit && (
<MetadataItem
label="Image to image fit"
value={node.fit}
onClick={() => dispatch(setShouldFitToWidthHeight(node.fit))}
value={metadata.fit}
onClick={() => dispatch(setShouldFitToWidthHeight(metadata.fit))}
/>
)}
)} */}
</>
) : (
<Center width="100%" pt={10}>

View File

@ -1,470 +0,0 @@
import { ExternalLinkIcon } from '@chakra-ui/icons';
import {
Box,
Center,
Flex,
Heading,
IconButton,
Link,
Text,
Tooltip,
} from '@chakra-ui/react';
import * as InvokeAI from 'app/types/invokeai';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGetUrl } from 'common/util/getUrl';
import promptToString from 'common/util/promptToString';
import { seedWeightsToString } from 'common/util/seedWeightPairs';
import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
import {
setCfgScale,
setHeight,
setImg2imgStrength,
// setInitialImage,
setMaskPath,
setPerlin,
setSampler,
setSeamless,
setSeed,
setSeedWeights,
setShouldFitToWidthHeight,
setSteps,
setThreshold,
setWidth,
} from 'features/parameters/store/generationSlice';
import {
setCodeformerFidelity,
setFacetoolStrength,
setFacetoolType,
setHiresFix,
setUpscalingDenoising,
setUpscalingLevel,
setUpscalingStrength,
} from 'features/parameters/store/postprocessingSlice';
import { setShouldShowImageDetails } from 'features/ui/store/uiSlice';
import { memo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { FaCopy } from 'react-icons/fa';
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
import * as png from '@stevebel/png';
type MetadataItemProps = {
isLink?: boolean;
label: string;
onClick?: () => void;
value: number | string | boolean;
labelPosition?: string;
withCopy?: boolean;
};
/**
* Component to display an individual metadata item or parameter.
*/
const MetadataItem = ({
label,
value,
onClick,
isLink,
labelPosition,
withCopy = false,
}: MetadataItemProps) => {
const { t } = useTranslation();
return (
<Flex gap={2}>
{onClick && (
<Tooltip label={`Recall ${label}`}>
<IconButton
aria-label={t('accessibility.useThisParameter')}
icon={<IoArrowUndoCircleOutline />}
size="xs"
variant="ghost"
fontSize={20}
onClick={onClick}
/>
</Tooltip>
)}
{withCopy && (
<Tooltip label={`Copy ${label}`}>
<IconButton
aria-label={`Copy ${label}`}
icon={<FaCopy />}
size="xs"
variant="ghost"
fontSize={14}
onClick={() => navigator.clipboard.writeText(value.toString())}
/>
</Tooltip>
)}
<Flex direction={labelPosition ? 'column' : 'row'}>
<Text fontWeight="semibold" whiteSpace="pre-wrap" pr={2}>
{label}:
</Text>
{isLink ? (
<Link href={value.toString()} isExternal wordBreak="break-all">
{value.toString()} <ExternalLinkIcon mx="2px" />
</Link>
) : (
<Text overflowY="scroll" wordBreak="break-all">
{value.toString()}
</Text>
)}
</Flex>
</Flex>
);
};
type ImageMetadataViewerProps = {
image: InvokeAI.Image;
};
// TODO: I don't know if this is needed.
const memoEqualityCheck = (
prev: ImageMetadataViewerProps,
next: ImageMetadataViewerProps
) => prev.image.name === next.image.name;
// TODO: Show more interesting information in this component.
/**
* Image metadata viewer overlays currently selected image and provides
* access to any of its metadata for use in processing.
*/
const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
const dispatch = useAppDispatch();
const setBothPrompts = useSetBothPrompts();
useHotkeys('esc', () => {
dispatch(setShouldShowImageDetails(false));
});
const metadata = image?.metadata.sd_metadata || {};
const dreamPrompt = image?.metadata.sd_metadata?.dreamPrompt;
const {
cfg_scale,
fit,
height,
hires_fix,
init_image_path,
mask_image_path,
orig_path,
perlin,
postprocessing,
prompt,
sampler,
seamless,
seed,
steps,
strength,
threshold,
type,
variations,
width,
model_weights,
} = metadata;
const { t } = useTranslation();
const { getUrl } = useGetUrl();
const metadataJSON = JSON.stringify(image, null, 2);
// fetch(getUrl(image.url))
// .then((r) => r.arrayBuffer())
// .then((buffer) => {
// const { text } = png.decode(buffer);
// const metadata = text?.['sd-metadata']
// ? JSON.parse(text['sd-metadata'] ?? {})
// : {};
// console.log(metadata);
// });
return (
<Flex
sx={{
padding: 4,
gap: 1,
flexDirection: 'column',
width: 'full',
height: 'full',
backdropFilter: 'blur(20px)',
bg: 'whiteAlpha.600',
_dark: {
bg: 'blackAlpha.600',
},
}}
>
<Flex gap={2}>
<Text fontWeight="semibold">File:</Text>
<Link href={getUrl(image.url)} isExternal maxW="calc(100% - 3rem)">
{image.url.length > 64
? image.url.substring(0, 64).concat('...')
: image.url}
<ExternalLinkIcon mx="2px" />
</Link>
</Flex>
<Flex gap={2} direction="column">
<Flex gap={2}>
<Tooltip label="Copy metadata JSON">
<IconButton
aria-label={t('accessibility.copyMetadataJson')}
icon={<FaCopy />}
size="xs"
variant="ghost"
fontSize={14}
onClick={() => navigator.clipboard.writeText(metadataJSON)}
/>
</Tooltip>
<Text fontWeight="semibold">Metadata JSON:</Text>
</Flex>
<Box
sx={{
mt: 0,
mr: 2,
mb: 4,
ml: 2,
padding: 4,
borderRadius: 'base',
overflowX: 'scroll',
wordBreak: 'break-all',
bg: 'whiteAlpha.500',
_dark: { bg: 'blackAlpha.500' },
}}
>
<pre>{metadataJSON}</pre>
</Box>
</Flex>
{Object.keys(metadata).length > 0 ? (
<>
{type && <MetadataItem label="Generation type" value={type} />}
{model_weights && (
<MetadataItem label="Model" value={model_weights} />
)}
{['esrgan', 'gfpgan'].includes(type) && (
<MetadataItem label="Original image" value={orig_path} />
)}
{prompt && (
<MetadataItem
label="Prompt"
labelPosition="top"
value={
typeof prompt === 'string' ? prompt : promptToString(prompt)
}
onClick={() => setBothPrompts(prompt)}
/>
)}
{seed !== undefined && (
<MetadataItem
label="Seed"
value={seed}
onClick={() => dispatch(setSeed(seed))}
/>
)}
{threshold !== undefined && (
<MetadataItem
label="Noise Threshold"
value={threshold}
onClick={() => dispatch(setThreshold(threshold))}
/>
)}
{perlin !== undefined && (
<MetadataItem
label="Perlin Noise"
value={perlin}
onClick={() => dispatch(setPerlin(perlin))}
/>
)}
{sampler && (
<MetadataItem
label="Sampler"
value={sampler}
onClick={() => dispatch(setSampler(sampler))}
/>
)}
{steps && (
<MetadataItem
label="Steps"
value={steps}
onClick={() => dispatch(setSteps(steps))}
/>
)}
{cfg_scale !== undefined && (
<MetadataItem
label="CFG scale"
value={cfg_scale}
onClick={() => dispatch(setCfgScale(cfg_scale))}
/>
)}
{variations && variations.length > 0 && (
<MetadataItem
label="Seed-weight pairs"
value={seedWeightsToString(variations)}
onClick={() =>
dispatch(setSeedWeights(seedWeightsToString(variations)))
}
/>
)}
{seamless && (
<MetadataItem
label="Seamless"
value={seamless}
onClick={() => dispatch(setSeamless(seamless))}
/>
)}
{hires_fix && (
<MetadataItem
label="High Resolution Optimization"
value={hires_fix}
onClick={() => dispatch(setHiresFix(hires_fix))}
/>
)}
{width && (
<MetadataItem
label="Width"
value={width}
onClick={() => dispatch(setWidth(width))}
/>
)}
{height && (
<MetadataItem
label="Height"
value={height}
onClick={() => dispatch(setHeight(height))}
/>
)}
{/* {init_image_path && (
<MetadataItem
label="Initial image"
value={init_image_path}
isLink
onClick={() => dispatch(setInitialImage(init_image_path))}
/>
)} */}
{mask_image_path && (
<MetadataItem
label="Mask image"
value={mask_image_path}
isLink
onClick={() => dispatch(setMaskPath(mask_image_path))}
/>
)}
{type === 'img2img' && strength && (
<MetadataItem
label="Image to image strength"
value={strength}
onClick={() => dispatch(setImg2imgStrength(strength))}
/>
)}
{fit && (
<MetadataItem
label="Image to image fit"
value={fit}
onClick={() => dispatch(setShouldFitToWidthHeight(fit))}
/>
)}
{postprocessing && postprocessing.length > 0 && (
<>
<Heading size="sm">Postprocessing</Heading>
{postprocessing.map(
(
postprocess: InvokeAI.PostProcessedImageMetadata,
i: number
) => {
if (postprocess.type === 'esrgan') {
const { scale, strength, denoise_str } = postprocess;
return (
<Flex key={i} pl={8} gap={1} direction="column">
<Text size="md">{`${i + 1}: Upscale (ESRGAN)`}</Text>
<MetadataItem
label="Scale"
value={scale}
onClick={() => dispatch(setUpscalingLevel(scale))}
/>
<MetadataItem
label="Strength"
value={strength}
onClick={() =>
dispatch(setUpscalingStrength(strength))
}
/>
{denoise_str !== undefined && (
<MetadataItem
label="Denoising strength"
value={denoise_str}
onClick={() =>
dispatch(setUpscalingDenoising(denoise_str))
}
/>
)}
</Flex>
);
} else if (postprocess.type === 'gfpgan') {
const { strength } = postprocess;
return (
<Flex key={i} pl={8} gap={1} direction="column">
<Text size="md">{`${
i + 1
}: Face restoration (GFPGAN)`}</Text>
<MetadataItem
label="Strength"
value={strength}
onClick={() => {
dispatch(setFacetoolStrength(strength));
dispatch(setFacetoolType('gfpgan'));
}}
/>
</Flex>
);
} else if (postprocess.type === 'codeformer') {
const { strength, fidelity } = postprocess;
return (
<Flex key={i} pl={8} gap={1} direction="column">
<Text size="md">{`${
i + 1
}: Face restoration (Codeformer)`}</Text>
<MetadataItem
label="Strength"
value={strength}
onClick={() => {
dispatch(setFacetoolStrength(strength));
dispatch(setFacetoolType('codeformer'));
}}
/>
{fidelity && (
<MetadataItem
label="Fidelity"
value={fidelity}
onClick={() => {
dispatch(setCodeformerFidelity(fidelity));
dispatch(setFacetoolType('codeformer'));
}}
/>
)}
</Flex>
);
}
}
)}
</>
)}
{dreamPrompt && (
<MetadataItem withCopy label="Dream Prompt" value={dreamPrompt} />
)}
</>
) : (
<Center width="100%" pt={10}>
<Text fontSize="lg" fontWeight="semibold">
No metadata available
</Text>
</Center>
)}
</Flex>
);
}, memoEqualityCheck);
ImageMetadataViewer.displayName = 'ImageMetadataViewer';
export default ImageMetadataViewer;

View File

@ -13,11 +13,9 @@ const useGetImageByNameSelector = createSelector(
const useGetImageByNameAndType = () => {
const { allResults, allUploads } = useAppSelector(useGetImageByNameSelector);
return (name: string, type: ImageType) => {
if (type === 'results') {
const resultImagesResult = allResults[name];
if (resultImagesResult) {
return resultImagesResult;
}

View File

@ -1,9 +1,9 @@
import { createAction } from '@reduxjs/toolkit';
import { Image } from 'app/types/invokeai';
import { SelectedImage } from 'features/parameters/store/actions';
import { ImageNameAndType } from 'features/parameters/store/actions';
import { ImageDTO } from 'services/api';
export const requestedImageDeletion = createAction<
Image | SelectedImage | undefined
ImageDTO | ImageNameAndType | undefined
>('gallery/requestedImageDeletion');
export const sentImageToCanvas = createAction('gallery/sentImageToCanvas');

View File

@ -1,16 +1,15 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { Image } from 'app/types/invokeai';
import { imageReceived, thumbnailReceived } from 'services/thunks/image';
import {
receivedResultImagesPage,
receivedUploadImagesPage,
} from '../../../services/thunks/gallery';
import { ImageDTO } from 'services/api';
type GalleryImageObjectFitType = 'contain' | 'cover';
export interface GalleryState {
selectedImage?: Image;
selectedImage?: ImageDTO;
galleryImageMinimumWidth: number;
galleryImageObjectFit: GalleryImageObjectFitType;
shouldAutoSwitchToNewImages: boolean;
@ -30,7 +29,7 @@ export const gallerySlice = createSlice({
name: 'gallery',
initialState: initialGalleryState,
reducers: {
imageSelected: (state, action: PayloadAction<Image | undefined>) => {
imageSelected: (state, action: PayloadAction<ImageDTO | undefined>) => {
state.selectedImage = action.payload;
// TODO: if the user selects an image, disable the auto switch?
// state.shouldAutoSwitchToNewImages = false;
@ -61,37 +60,18 @@ export const gallerySlice = createSlice({
},
},
extraReducers(builder) {
builder.addCase(imageReceived.fulfilled, (state, action) => {
// When we get an updated URL for an image, we need to update the selectedImage in gallery,
// which is currently its own object (instead of a reference to an image in results/uploads)
const { imagePath } = action.payload;
const { imageName } = action.meta.arg;
if (state.selectedImage?.name === imageName) {
state.selectedImage.url = imagePath;
}
});
builder.addCase(thumbnailReceived.fulfilled, (state, action) => {
// When we get an updated URL for an image, we need to update the selectedImage in gallery,
// which is currently its own object (instead of a reference to an image in results/uploads)
const { thumbnailPath } = action.payload;
const { thumbnailName } = action.meta.arg;
if (state.selectedImage?.name === thumbnailName) {
state.selectedImage.thumbnail = thumbnailPath;
}
});
builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => {
// rehydrate selectedImage URL when results list comes in
// solves case when outdated URL is in local storage
const selectedImage = state.selectedImage;
if (selectedImage) {
const selectedImageInResults = action.payload.items.find(
(image) => image.image_name === selectedImage.name
(image) => image.image_name === selectedImage.image_name
);
if (selectedImageInResults) {
selectedImage.url = selectedImageInResults.image_url;
selectedImage.image_url = selectedImageInResults.image_url;
selectedImage.thumbnail_url = selectedImageInResults.thumbnail_url;
state.selectedImage = selectedImage;
}
}
@ -102,10 +82,12 @@ export const gallerySlice = createSlice({
const selectedImage = state.selectedImage;
if (selectedImage) {
const selectedImageInResults = action.payload.items.find(
(image) => image.image_name === selectedImage.name
(image) => image.image_name === selectedImage.image_name
);
if (selectedImageInResults) {
selectedImage.url = selectedImageInResults.image_url;
selectedImage.image_url = selectedImageInResults.image_url;
selectedImage.thumbnail_url = selectedImageInResults.thumbnail_url;
state.selectedImage = selectedImage;
}
}

View File

@ -1,21 +1,24 @@
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
import { Image } from 'app/types/invokeai';
import { RootState } from 'app/store/store';
import {
receivedResultImagesPage,
IMAGES_PER_PAGE,
} from 'services/thunks/gallery';
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
import {
imageDeleted,
imageReceived,
thumbnailReceived,
imageMetadataReceived,
imageUrlsReceived,
} from 'services/thunks/image';
import { ImageDTO } from 'services/api';
import { dateComparator } from 'common/util/dateComparator';
export const resultsAdapter = createEntityAdapter<Image>({
selectId: (image) => image.name,
sortComparer: (a, b) => b.metadata.created - a.metadata.created,
export type ResultsImageDTO = Omit<ImageDTO, 'image_type'> & {
image_type: 'results';
};
export const resultsAdapter = createEntityAdapter<ResultsImageDTO>({
selectId: (image) => image.image_name,
sortComparer: (a, b) => dateComparator(b.created_at, a.created_at),
});
type AdditionalResultsState = {
@ -53,13 +56,12 @@ const resultsSlice = createSlice({
* Received Result Images Page - FULFILLED
*/
builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => {
const { items, page, pages } = action.payload;
const { page, pages } = action.payload;
const resultImages = items.map((image) =>
deserializeImageResponse(image)
);
// We know these will all be of the results type, but it's not represented in the API types
const items = action.payload.items as ResultsImageDTO[];
resultsAdapter.setMany(state, resultImages);
resultsAdapter.setMany(state, items);
state.page = page;
state.pages = pages;
@ -68,33 +70,32 @@ const resultsSlice = createSlice({
});
/**
* Image Received - FULFILLED
* Image Metadata Received - FULFILLED
*/
builder.addCase(imageReceived.fulfilled, (state, action) => {
const { imagePath } = action.payload;
const { imageName } = action.meta.arg;
builder.addCase(imageMetadataReceived.fulfilled, (state, action) => {
const { image_type } = action.payload;
resultsAdapter.updateOne(state, {
id: imageName,
changes: {
url: imagePath,
},
});
if (image_type === 'results') {
resultsAdapter.upsertOne(state, action.payload as ResultsImageDTO);
}
});
/**
* Thumbnail Received - FULFILLED
* Image URLs Received - FULFILLED
*/
builder.addCase(thumbnailReceived.fulfilled, (state, action) => {
const { thumbnailPath } = action.payload;
const { thumbnailName } = action.meta.arg;
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_type, image_url, thumbnail_url } =
action.payload;
resultsAdapter.updateOne(state, {
id: thumbnailName,
changes: {
thumbnail: thumbnailPath,
},
});
if (image_type === 'results') {
resultsAdapter.updateOne(state, {
id: image_name,
changes: {
image_url: image_url,
thumbnail_url: thumbnail_url,
},
});
}
});
/**

View File

@ -1,17 +1,21 @@
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
import { Image } from 'app/types/invokeai';
import { RootState } from 'app/store/store';
import {
receivedUploadImagesPage,
IMAGES_PER_PAGE,
} from 'services/thunks/gallery';
import { imageDeleted } from 'services/thunks/image';
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
import { imageDeleted, imageUrlsReceived } from 'services/thunks/image';
import { ImageDTO } from 'services/api';
import { dateComparator } from 'common/util/dateComparator';
export const uploadsAdapter = createEntityAdapter<Image>({
selectId: (image) => image.name,
sortComparer: (a, b) => b.metadata.created - a.metadata.created,
export type UploadsImageDTO = Omit<ImageDTO, 'image_type'> & {
image_type: 'uploads';
};
export const uploadsAdapter = createEntityAdapter<UploadsImageDTO>({
selectId: (image) => image.image_name,
sortComparer: (a, b) => dateComparator(b.created_at, a.created_at),
});
type AdditionalUploadsState = {
@ -49,11 +53,12 @@ const uploadsSlice = createSlice({
* Received Upload Images Page - FULFILLED
*/
builder.addCase(receivedUploadImagesPage.fulfilled, (state, action) => {
const { items, page, pages } = action.payload;
const { page, pages } = action.payload;
const images = items.map((image) => deserializeImageResponse(image));
// We know these will all be of the uploads type, but it's not represented in the API types
const items = action.payload.items as UploadsImageDTO[];
uploadsAdapter.setMany(state, images);
uploadsAdapter.setMany(state, items);
state.page = page;
state.pages = pages;
@ -61,6 +66,24 @@ const uploadsSlice = createSlice({
state.isLoading = false;
});
/**
* Image URLs Received - FULFILLED
*/
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_type, image_url, thumbnail_url } =
action.payload;
if (image_type === 'uploads') {
uploadsAdapter.updateOne(state, {
id: image_name,
changes: {
image_url: image_url,
thumbnail_url: thumbnail_url,
},
});
}
});
/**
* Delete Image - pending
* Pre-emptively remove the image from the gallery

View File

@ -1,10 +1,10 @@
import * as React from 'react';
import { TransformComponent, useTransformContext } from 'react-zoom-pan-pinch';
import * as InvokeAI from 'app/types/invokeai';
import { useGetUrl } from 'common/util/getUrl';
import { ImageDTO } from 'services/api';
type ReactPanZoomProps = {
image: InvokeAI.Image;
image: ImageDTO;
styleClass?: string;
alt?: string;
ref?: React.Ref<HTMLImageElement>;
@ -37,7 +37,7 @@ export default function ReactPanZoomImage({
transform: `rotate(${rotation}deg) scaleX(${scaleX}) scaleY(${scaleY})`,
width: '100%',
}}
src={getUrl(image.url)}
src={getUrl(image.image_url)}
alt={alt}
ref={ref}
className={styleClass ? styleClass : ''}

View File

@ -21,7 +21,7 @@ const ImageInputFieldComponent = (
const getImageByNameAndType = useGetImageByNameAndType();
const dispatch = useAppDispatch();
const [url, setUrl] = useState<string>();
const [url, setUrl] = useState<string | undefined>(field.value?.image_url);
const { getUrl } = useGetUrl();
const handleDrop = useCallback(
@ -39,16 +39,13 @@ const ImageInputFieldComponent = (
return;
}
setUrl(image.url);
setUrl(image.image_url);
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
value: {
image_name: name,
image_type: type,
},
value: image,
})
);
},

View File

@ -11,7 +11,7 @@ import {
NodeChange,
OnConnectStartParams,
} from 'reactflow';
import { ImageField } from 'services/api';
import { ImageDTO } from 'services/api';
import { receivedOpenAPISchema } from 'services/thunks/schema';
import { InvocationTemplate, InvocationValue } from '../types/types';
import { parseSchema } from '../util/parseSchema';
@ -65,13 +65,7 @@ const nodesSlice = createSlice({
action: PayloadAction<{
nodeId: string;
fieldName: string;
value:
| string
| number
| boolean
| Pick<ImageField, 'image_name' | 'image_type'>
| RgbaColor
| undefined;
value: string | number | boolean | ImageDTO | RgbaColor | undefined;
}>
) => {
const { nodeId, fieldName, value } = action.payload;

View File

@ -1,7 +1,10 @@
import { OpenAPIV3 } from 'openapi-types';
import { RgbaColor } from 'react-colorful';
import { ImageField } from 'services/api';
import { Graph, ImageDTO } from 'services/api';
import { AnyInvocationType } from 'services/events/types';
import { O } from 'ts-toolbelt';
export type NonNullableGraph = O.Required<Graph, 'nodes' | 'edges'>;
export type InvocationValue = {
id: string;
@ -179,7 +182,7 @@ export type ConditioningInputFieldValue = FieldValueBase & {
export type ImageInputFieldValue = FieldValueBase & {
type: 'image';
value?: Pick<ImageField, 'image_name' | 'image_type'>;
value?: ImageDTO;
};
export type ModelInputFieldValue = FieldValueBase & {
@ -245,7 +248,7 @@ export type BooleanInputFieldTemplate = InputFieldTemplateBase & {
};
export type ImageInputFieldTemplate = InputFieldTemplateBase & {
default: Pick<ImageField, 'image_name' | 'image_type'>;
default: ImageDTO;
type: 'image';
};

View File

@ -1,35 +1,131 @@
import { RootState } from 'app/store/store';
import { Graph } from 'services/api';
import { buildImg2ImgNode } from '../nodeBuilders/buildImageToImageNode';
import { buildRangeNode } from '../nodeBuilders/buildRangeNode';
import { buildIterateNode } from '../nodeBuilders/buildIterateNode';
import { buildEdges } from '../edgeBuilders/buildEdges';
import {
CompelInvocation,
Graph,
ImageToLatentsInvocation,
LatentsToImageInvocation,
LatentsToLatentsInvocation,
} from 'services/api';
import { NonNullableGraph } from 'features/nodes/types/types';
import { addNoiseNodes } from '../nodeBuilders/addNoiseNodes';
import { log } from 'app/logging/useLogger';
const moduleLog = log.child({ namespace: 'buildImageToImageGraph' });
const POSITIVE_CONDITIONING = 'positive_conditioning';
const NEGATIVE_CONDITIONING = 'negative_conditioning';
const IMAGE_TO_LATENTS = 'image_to_latents';
const LATENTS_TO_LATENTS = 'latents_to_latents';
const LATENTS_TO_IMAGE = 'latents_to_image';
/**
* Builds the Linear workflow graph.
* Builds the Image to Image tab graph.
*/
export const buildImageToImageGraph = (state: RootState): Graph => {
const baseNode = buildImg2ImgNode(state);
const {
positivePrompt,
negativePrompt,
model,
cfgScale: cfg_scale,
scheduler,
steps,
initialImage,
img2imgStrength: strength,
} = state.generation;
// We always range and iterate nodes, no matter the iteration count
// This is required to provide the correct seeds to the backend engine
const rangeNode = buildRangeNode(state);
const iterateNode = buildIterateNode();
if (!initialImage) {
moduleLog.error('No initial image found in state');
throw new Error('No initial image found in state');
}
// Build the edges for the nodes selected.
const edges = buildEdges(baseNode, rangeNode, iterateNode);
// Assemble!
const graph = {
nodes: {
[rangeNode.id]: rangeNode,
[iterateNode.id]: iterateNode,
[baseNode.id]: baseNode,
},
edges,
let graph: NonNullableGraph = {
nodes: {},
edges: [],
};
// TODO: hires fix requires latent space upscaling; we don't have nodes for this yet
// Create the conditioning, t2l and l2i nodes
const positiveConditioningNode: CompelInvocation = {
id: POSITIVE_CONDITIONING,
type: 'compel',
prompt: positivePrompt,
model,
};
const negativeConditioningNode: CompelInvocation = {
id: NEGATIVE_CONDITIONING,
type: 'compel',
prompt: negativePrompt,
model,
};
const imageToLatentsNode: ImageToLatentsInvocation = {
id: IMAGE_TO_LATENTS,
type: 'i2l',
model,
image: {
image_name: initialImage?.image_name,
image_type: initialImage?.image_type,
},
};
const latentsToLatentsNode: LatentsToLatentsInvocation = {
id: LATENTS_TO_LATENTS,
type: 'l2l',
cfg_scale,
model,
scheduler,
steps,
strength,
};
const latentsToImageNode: LatentsToImageInvocation = {
id: LATENTS_TO_IMAGE,
type: 'l2i',
model,
};
// Add to the graph
graph.nodes[POSITIVE_CONDITIONING] = positiveConditioningNode;
graph.nodes[NEGATIVE_CONDITIONING] = negativeConditioningNode;
graph.nodes[IMAGE_TO_LATENTS] = imageToLatentsNode;
graph.nodes[LATENTS_TO_LATENTS] = latentsToLatentsNode;
graph.nodes[LATENTS_TO_IMAGE] = latentsToImageNode;
// Connect them
graph.edges.push({
source: { node_id: POSITIVE_CONDITIONING, field: 'conditioning' },
destination: {
node_id: LATENTS_TO_LATENTS,
field: 'positive_conditioning',
},
});
graph.edges.push({
source: { node_id: NEGATIVE_CONDITIONING, field: 'conditioning' },
destination: {
node_id: LATENTS_TO_LATENTS,
field: 'negative_conditioning',
},
});
graph.edges.push({
source: { node_id: IMAGE_TO_LATENTS, field: 'latents' },
destination: {
node_id: LATENTS_TO_LATENTS,
field: 'latents',
},
});
graph.edges.push({
source: { node_id: LATENTS_TO_LATENTS, field: 'latents' },
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'latents',
},
});
// Create and add the noise nodes
graph = addNoiseNodes(graph, latentsToLatentsNode.id, state);
return graph;
};

View File

@ -1,35 +1,99 @@
import { RootState } from 'app/store/store';
import { Graph } from 'services/api';
import { buildTxt2ImgNode } from '../nodeBuilders/buildTextToImageNode';
import { buildRangeNode } from '../nodeBuilders/buildRangeNode';
import { buildIterateNode } from '../nodeBuilders/buildIterateNode';
import { buildEdges } from '../edgeBuilders/buildEdges';
import {
CompelInvocation,
Graph,
LatentsToImageInvocation,
TextToLatentsInvocation,
} from 'services/api';
import { NonNullableGraph } from 'features/nodes/types/types';
import { addNoiseNodes } from '../nodeBuilders/addNoiseNodes';
const POSITIVE_CONDITIONING = 'positive_conditioning';
const NEGATIVE_CONDITIONING = 'negative_conditioning';
const TEXT_TO_LATENTS = 'text_to_latents';
const LATENTS_TO_IMAGE = 'latnets_to_image';
/**
* Builds the Linear workflow graph.
* Builds the Text to Image tab graph.
*/
export const buildTextToImageGraph = (state: RootState): Graph => {
const baseNode = buildTxt2ImgNode(state);
const {
positivePrompt,
negativePrompt,
model,
cfgScale: cfg_scale,
scheduler,
steps,
} = state.generation;
// We always range and iterate nodes, no matter the iteration count
// This is required to provide the correct seeds to the backend engine
const rangeNode = buildRangeNode(state);
const iterateNode = buildIterateNode();
// Build the edges for the nodes selected.
const edges = buildEdges(baseNode, rangeNode, iterateNode);
// Assemble!
const graph = {
nodes: {
[rangeNode.id]: rangeNode,
[iterateNode.id]: iterateNode,
[baseNode.id]: baseNode,
},
edges,
let graph: NonNullableGraph = {
nodes: {},
edges: [],
};
// TODO: hires fix requires latent space upscaling; we don't have nodes for this yet
// Create the conditioning, t2l and l2i nodes
const positiveConditioningNode: CompelInvocation = {
id: POSITIVE_CONDITIONING,
type: 'compel',
prompt: positivePrompt,
model,
};
const negativeConditioningNode: CompelInvocation = {
id: NEGATIVE_CONDITIONING,
type: 'compel',
prompt: negativePrompt,
model,
};
const textToLatentsNode: TextToLatentsInvocation = {
id: TEXT_TO_LATENTS,
type: 't2l',
cfg_scale,
model,
scheduler,
steps,
};
const latentsToImageNode: LatentsToImageInvocation = {
id: LATENTS_TO_IMAGE,
type: 'l2i',
model,
};
// Add to the graph
graph.nodes[POSITIVE_CONDITIONING] = positiveConditioningNode;
graph.nodes[NEGATIVE_CONDITIONING] = negativeConditioningNode;
graph.nodes[TEXT_TO_LATENTS] = textToLatentsNode;
graph.nodes[LATENTS_TO_IMAGE] = latentsToImageNode;
// Connect them
graph.edges.push({
source: { node_id: POSITIVE_CONDITIONING, field: 'conditioning' },
destination: {
node_id: TEXT_TO_LATENTS,
field: 'positive_conditioning',
},
});
graph.edges.push({
source: { node_id: NEGATIVE_CONDITIONING, field: 'conditioning' },
destination: {
node_id: TEXT_TO_LATENTS,
field: 'negative_conditioning',
},
});
graph.edges.push({
source: { node_id: TEXT_TO_LATENTS, field: 'latents' },
destination: {
node_id: LATENTS_TO_IMAGE,
field: 'latents',
},
});
// Create and add the noise nodes
graph = addNoiseNodes(graph, TEXT_TO_LATENTS, state);
return graph;
};

View File

@ -0,0 +1,208 @@
import { RootState } from 'app/store/store';
import {
IterateInvocation,
NoiseInvocation,
RandomIntInvocation,
RangeOfSizeInvocation,
} from 'services/api';
import { NonNullableGraph } from 'features/nodes/types/types';
import { cloneDeep } from 'lodash-es';
const NOISE = 'noise';
const RANDOM_INT = 'rand_int';
const RANGE_OF_SIZE = 'range_of_size';
const ITERATE = 'iterate';
/**
* Adds the appropriate noise nodes to a linear UI t2l or l2l graph.
*
* @param graph The graph to add the noise nodes to.
* @param baseNodeId The id of the base node to connect the noise nodes to.
* @param state The app state..
*/
export const addNoiseNodes = (
graph: NonNullableGraph,
baseNodeId: string,
state: RootState
): NonNullableGraph => {
const graphClone = cloneDeep(graph);
// Create and add the noise nodes
const { width, height, seed, iterations, shouldRandomizeSeed } =
state.generation;
// Single iteration, explicit seed
if (!shouldRandomizeSeed && iterations === 1) {
const noiseNode: NoiseInvocation = {
id: NOISE,
type: 'noise',
seed: seed,
width,
height,
};
graphClone.nodes[NOISE] = noiseNode;
// Connect them
graphClone.edges.push({
source: { node_id: NOISE, field: 'noise' },
destination: {
node_id: baseNodeId,
field: 'noise',
},
});
}
// Single iteration, random seed
if (shouldRandomizeSeed && iterations === 1) {
// TODO: This assumes the `high` value is the max seed value
const randomIntNode: RandomIntInvocation = {
id: RANDOM_INT,
type: 'rand_int',
};
const noiseNode: NoiseInvocation = {
id: NOISE,
type: 'noise',
width,
height,
};
graphClone.nodes[RANDOM_INT] = randomIntNode;
graphClone.nodes[NOISE] = noiseNode;
graphClone.edges.push({
source: { node_id: RANDOM_INT, field: 'a' },
destination: {
node_id: NOISE,
field: 'seed',
},
});
graphClone.edges.push({
source: { node_id: NOISE, field: 'noise' },
destination: {
node_id: baseNodeId,
field: 'noise',
},
});
}
// Multiple iterations, explicit seed
if (!shouldRandomizeSeed && iterations > 1) {
const rangeOfSizeNode: RangeOfSizeInvocation = {
id: RANGE_OF_SIZE,
type: 'range_of_size',
start: seed,
size: iterations,
};
const iterateNode: IterateInvocation = {
id: ITERATE,
type: 'iterate',
};
const noiseNode: NoiseInvocation = {
id: NOISE,
type: 'noise',
width,
height,
};
graphClone.nodes[RANGE_OF_SIZE] = rangeOfSizeNode;
graphClone.nodes[ITERATE] = iterateNode;
graphClone.nodes[NOISE] = noiseNode;
graphClone.edges.push({
source: { node_id: RANGE_OF_SIZE, field: 'collection' },
destination: {
node_id: ITERATE,
field: 'collection',
},
});
graphClone.edges.push({
source: {
node_id: ITERATE,
field: 'item',
},
destination: {
node_id: NOISE,
field: 'seed',
},
});
graphClone.edges.push({
source: { node_id: NOISE, field: 'noise' },
destination: {
node_id: baseNodeId,
field: 'noise',
},
});
}
// Multiple iterations, random seed
if (shouldRandomizeSeed && iterations > 1) {
// TODO: This assumes the `high` value is the max seed value
const randomIntNode: RandomIntInvocation = {
id: RANDOM_INT,
type: 'rand_int',
};
const rangeOfSizeNode: RangeOfSizeInvocation = {
id: RANGE_OF_SIZE,
type: 'range_of_size',
size: iterations,
};
const iterateNode: IterateInvocation = {
id: ITERATE,
type: 'iterate',
};
const noiseNode: NoiseInvocation = {
id: NOISE,
type: 'noise',
width,
height,
};
graphClone.nodes[RANDOM_INT] = randomIntNode;
graphClone.nodes[RANGE_OF_SIZE] = rangeOfSizeNode;
graphClone.nodes[ITERATE] = iterateNode;
graphClone.nodes[NOISE] = noiseNode;
graphClone.edges.push({
source: { node_id: RANDOM_INT, field: 'a' },
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
});
graphClone.edges.push({
source: { node_id: RANGE_OF_SIZE, field: 'collection' },
destination: {
node_id: ITERATE,
field: 'collection',
},
});
graphClone.edges.push({
source: {
node_id: ITERATE,
field: 'item',
},
destination: {
node_id: NOISE,
field: 'seed',
},
});
graphClone.edges.push({
source: { node_id: NOISE, field: 'noise' },
destination: {
node_id: baseNodeId,
field: 'noise',
},
});
}
return graphClone;
};

View File

@ -0,0 +1,26 @@
import { v4 as uuidv4 } from 'uuid';
import { RootState } from 'app/store/store';
import { CompelInvocation } from 'services/api';
import { O } from 'ts-toolbelt';
export const buildCompelNode = (
prompt: string,
state: RootState,
overrides: O.Partial<CompelInvocation, 'deep'> = {}
): CompelInvocation => {
const nodeId = uuidv4();
const { generation } = state;
const { model } = generation;
const compelNode: CompelInvocation = {
id: nodeId,
type: 'compel',
prompt,
model,
};
Object.assign(compelNode, overrides);
return compelNode;
};

View File

@ -18,8 +18,8 @@ export const buildImg2ImgNode = (
const activeTabName = activeTabNameSelector(state);
const {
prompt,
negativePrompt,
positivePrompt: prompt,
negativePrompt: negativePrompt,
seed,
steps,
width,

View File

@ -13,8 +13,8 @@ export const buildInpaintNode = (
const activeTabName = activeTabNameSelector(state);
const {
prompt,
negativePrompt,
positivePrompt: prompt,
negativePrompt: negativePrompt,
seed,
steps,
width,

View File

@ -11,8 +11,8 @@ export const buildTxt2ImgNode = (
const { generation } = state;
const {
prompt,
negativePrompt,
positivePrompt: prompt,
negativePrompt: negativePrompt,
seed,
steps,
width,

View File

@ -13,7 +13,7 @@ import {
buildOutputFieldTemplates,
} from './fieldTemplateBuilders';
const invocationDenylist = ['Graph', 'LoadImage'];
const invocationDenylist = ['Graph'];
export const parseSchema = (openAPI: OpenAPIV3.Document) => {
// filter out non-invocation schemas, plus some tricky invocations for now

View File

@ -8,7 +8,7 @@ import { readinessSelector } from 'app/selectors/readinessSelector';
import {
GenerationState,
clampSymmetrySteps,
setPrompt,
setPositivePrompt,
} from 'features/parameters/store/generationSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
@ -22,7 +22,7 @@ const promptInputSelector = createSelector(
[(state: RootState) => state.generation, activeTabNameSelector],
(parameters: GenerationState, activeTabName) => {
return {
prompt: parameters.prompt,
prompt: parameters.positivePrompt,
activeTabName,
};
},
@ -46,7 +46,7 @@ const ParamPositiveConditioning = () => {
const { t } = useTranslation();
const handleChangePrompt = (e: ChangeEvent<HTMLTextAreaElement>) => {
dispatch(setPrompt(e.target.value));
dispatch(setPositivePrompt(e.target.value));
};
useHotkeys(

View File

@ -57,7 +57,7 @@ const InitialImagePreview = () => {
const name = e.dataTransfer.getData('invokeai/imageName');
const type = e.dataTransfer.getData('invokeai/imageType') as ImageType;
dispatch(initialImageSelected({ name, type }));
dispatch(initialImageSelected({ image_name: name, image_type: type }));
},
[dispatch]
);
@ -73,10 +73,10 @@ const InitialImagePreview = () => {
}}
onDrop={handleDrop}
>
{initialImage?.url && (
{initialImage?.image_url && (
<>
<Image
src={getUrl(initialImage?.url)}
src={getUrl(initialImage?.image_url)}
fallbackStrategy="beforeLoadOrError"
fallback={<ImageFallbackSpinner />}
onError={handleError}
@ -92,7 +92,7 @@ const InitialImagePreview = () => {
<ImageMetadataOverlay image={initialImage} />
</>
)}
{!initialImage?.url && (
{!initialImage?.image_url && (
<Icon
as={FaImage}
sx={{

View File

@ -7,9 +7,9 @@ import { allParametersSet, setSeed } from '../store/generationSlice';
import { isImageField } from 'services/types/guards';
import { NUMPY_RAND_MAX } from 'app/constants';
import { initialImageSelected } from '../store/actions';
import { Image } from 'app/types/invokeai';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { useAppToaster } from 'app/components/Toaster';
import { ImageDTO } from 'services/api';
export const useParameters = () => {
const dispatch = useAppDispatch();
@ -88,9 +88,7 @@ export const useParameters = () => {
return;
}
dispatch(
initialImageSelected({ name: image.image_name, type: image.image_type })
);
dispatch(initialImageSelected(image));
toaster({
title: t('toast.initialImageSet'),
status: 'info',
@ -105,21 +103,21 @@ export const useParameters = () => {
* Sets image as initial image with toast
*/
const sendToImageToImage = useCallback(
(image: Image) => {
dispatch(initialImageSelected({ name: image.name, type: image.type }));
(image: ImageDTO) => {
dispatch(initialImageSelected(image));
},
[dispatch]
);
const recallAllParameters = useCallback(
(image: Image | undefined) => {
const type = image?.metadata?.invokeai?.node?.type;
(image: ImageDTO | undefined) => {
const type = image?.metadata?.type;
if (['txt2img', 'img2img', 'inpaint'].includes(String(type))) {
dispatch(allParametersSet(image));
if (image?.metadata?.invokeai?.node?.type === 'img2img') {
if (image?.metadata?.type === 'img2img') {
dispatch(setActiveTab('img2img'));
} else if (image?.metadata?.invokeai?.node?.type === 'txt2img') {
} else if (image?.metadata?.type === 'txt2img') {
dispatch(setActiveTab('txt2img'));
}

View File

@ -3,7 +3,7 @@ import { getPromptAndNegative } from 'common/util/getPromptAndNegative';
import * as InvokeAI from 'app/types/invokeai';
import promptToString from 'common/util/promptToString';
import { useAppDispatch } from 'app/store/storeHooks';
import { setNegativePrompt, setPrompt } from '../store/generationSlice';
import { setNegativePrompt, setPositivePrompt } from '../store/generationSlice';
import { useCallback } from 'react';
// TECHDEBT: We have two metadata prompt formats and need to handle recalling either of them.
@ -20,7 +20,7 @@ const useSetBothPrompts = () => {
const [prompt, negativePrompt] = getPromptAndNegative(promptString);
dispatch(setPrompt(prompt));
dispatch(setPositivePrompt(prompt));
dispatch(setNegativePrompt(negativePrompt));
},
[dispatch]

View File

@ -1,12 +1,31 @@
import { createAction } from '@reduxjs/toolkit';
import { Image } from 'app/types/invokeai';
import { ImageType } from 'services/api';
import { isObject } from 'lodash-es';
import { ImageDTO, ImageType } from 'services/api';
export type SelectedImage = {
name: string;
type: ImageType;
export type ImageNameAndType = {
image_name: string;
image_type: ImageType;
};
export const isImageDTO = (image: any): image is ImageDTO => {
return (
image &&
isObject(image) &&
'image_name' in image &&
image?.image_name !== undefined &&
'image_type' in image &&
image?.image_type !== undefined &&
'image_url' in image &&
image?.image_url !== undefined &&
'thumbnail_url' in image &&
image?.thumbnail_url !== undefined &&
'image_category' in image &&
image?.image_category !== undefined &&
'created_at' in image &&
image?.created_at !== undefined
);
};
export const initialImageSelected = createAction<
Image | SelectedImage | undefined
ImageDTO | ImageNameAndType | undefined
>('generation/initialImageSelected');

View File

@ -6,16 +6,17 @@ import { clamp, sample } from 'lodash-es';
import { setAllParametersReducer } from './setAllParametersReducer';
import { receivedModels } from 'services/thunks/model';
import { Scheduler } from 'app/constants';
import { ImageDTO } from 'services/api';
export interface GenerationState {
cfgScale: number;
height: number;
img2imgStrength: number;
infillMethod: string;
initialImage?: InvokeAI.Image;
initialImage?: ImageDTO;
iterations: number;
perlin: number;
prompt: string;
positivePrompt: string;
negativePrompt: string;
scheduler: Scheduler;
seamBlur: number;
@ -49,7 +50,7 @@ export const initialGenerationState: GenerationState = {
infillMethod: 'patchmatch',
iterations: 1,
perlin: 0,
prompt: '',
positivePrompt: '',
negativePrompt: '',
scheduler: 'lms',
seamBlur: 16,
@ -82,12 +83,15 @@ export const generationSlice = createSlice({
name: 'generation',
initialState,
reducers: {
setPrompt: (state, action: PayloadAction<string | InvokeAI.Prompt>) => {
setPositivePrompt: (
state,
action: PayloadAction<string | InvokeAI.Prompt>
) => {
const newPrompt = action.payload;
if (typeof newPrompt === 'string') {
state.prompt = newPrompt;
state.positivePrompt = newPrompt;
} else {
state.prompt = promptToString(newPrompt);
state.positivePrompt = promptToString(newPrompt);
}
},
setNegativePrompt: (
@ -213,7 +217,7 @@ export const generationSlice = createSlice({
setShouldUseNoiseSettings: (state, action: PayloadAction<boolean>) => {
state.shouldUseNoiseSettings = action.payload;
},
initialImageChanged: (state, action: PayloadAction<InvokeAI.Image>) => {
initialImageChanged: (state, action: PayloadAction<ImageDTO>) => {
state.initialImage = action.payload;
},
modelSelected: (state, action: PayloadAction<string>) => {
@ -243,7 +247,7 @@ export const {
setInfillMethod,
setIterations,
setPerlin,
setPrompt,
setPositivePrompt,
setNegativePrompt,
setScheduler,
setSeamBlur,

View File

@ -1,12 +1,11 @@
import { Draft, PayloadAction } from '@reduxjs/toolkit';
import { Image } from 'app/types/invokeai';
import { GenerationState } from './generationSlice';
import { ImageToImageInvocation } from 'services/api';
import { ImageDTO, ImageToImageInvocation } from 'services/api';
import { isScheduler } from 'app/constants';
export const setAllParametersReducer = (
state: Draft<GenerationState>,
action: PayloadAction<Image | undefined>
action: PayloadAction<ImageDTO | undefined>
) => {
const node = action.payload?.metadata.invokeai?.node;
@ -32,7 +31,7 @@ export const setAllParametersReducer = (
state.model = String(model);
}
if (prompt !== undefined) {
state.prompt = String(prompt);
state.positivePrompt = String(prompt);
}
if (scheduler !== undefined) {
const schedulerString = String(scheduler);

View File

@ -5,7 +5,6 @@ import { merge } from 'lodash-es';
export const initialConfigState: AppConfig = {
shouldTransformUrls: false,
shouldFetchImages: false,
disabledTabs: [],
disabledFeatures: [],
disabledSDFeatures: [],

View File

@ -3,4 +3,4 @@ import { UIState } from './uiTypes';
/**
* UI slice persist denylist
*/
export const uiPersistDenylist: (keyof UIState)[] = [];
export const uiPersistDenylist: (keyof UIState)[] = ['shouldShowImageDetails'];

View File

@ -28,13 +28,15 @@ export type { GraphExecutionState } from './models/GraphExecutionState';
export type { GraphInvocation } from './models/GraphInvocation';
export type { GraphInvocationOutput } from './models/GraphInvocationOutput';
export type { HTTPValidationError } from './models/HTTPValidationError';
export type { ImageCategory } from './models/ImageCategory';
export type { ImageDTO } from './models/ImageDTO';
export type { ImageField } from './models/ImageField';
export type { ImageMetadata } from './models/ImageMetadata';
export type { ImageOutput } from './models/ImageOutput';
export type { ImageResponse } from './models/ImageResponse';
export type { ImageResponseMetadata } from './models/ImageResponseMetadata';
export type { ImageToImageInvocation } from './models/ImageToImageInvocation';
export type { ImageToLatentsInvocation } from './models/ImageToLatentsInvocation';
export type { ImageType } from './models/ImageType';
export type { ImageUrlsDTO } from './models/ImageUrlsDTO';
export type { InfillColorInvocation } from './models/InfillColorInvocation';
export type { InfillPatchMatchInvocation } from './models/InfillPatchMatchInvocation';
export type { InfillTileInvocation } from './models/InfillTileInvocation';
@ -42,7 +44,6 @@ export type { InpaintInvocation } from './models/InpaintInvocation';
export type { IntCollectionOutput } from './models/IntCollectionOutput';
export type { IntOutput } from './models/IntOutput';
export type { InverseLerpInvocation } from './models/InverseLerpInvocation';
export type { InvokeAIMetadata } from './models/InvokeAIMetadata';
export type { IterateInvocation } from './models/IterateInvocation';
export type { IterateInvocationOutput } from './models/IterateInvocationOutput';
export type { LatentsField } from './models/LatentsField';
@ -53,21 +54,19 @@ export type { LerpInvocation } from './models/LerpInvocation';
export type { LoadImageInvocation } from './models/LoadImageInvocation';
export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation';
export type { MaskOutput } from './models/MaskOutput';
export type { MetadataColorField } from './models/MetadataColorField';
export type { MetadataImageField } from './models/MetadataImageField';
export type { MetadataLatentsField } from './models/MetadataLatentsField';
export type { ModelsList } from './models/ModelsList';
export type { MultiplyInvocation } from './models/MultiplyInvocation';
export type { NoiseInvocation } from './models/NoiseInvocation';
export type { NoiseOutput } from './models/NoiseOutput';
export type { PaginatedResults_GraphExecutionState_ } from './models/PaginatedResults_GraphExecutionState_';
export type { PaginatedResults_ImageResponse_ } from './models/PaginatedResults_ImageResponse_';
export type { PaginatedResults_ImageDTO_ } from './models/PaginatedResults_ImageDTO_';
export type { ParamIntInvocation } from './models/ParamIntInvocation';
export type { PasteImageInvocation } from './models/PasteImageInvocation';
export type { PromptOutput } from './models/PromptOutput';
export type { RandomIntInvocation } from './models/RandomIntInvocation';
export type { RandomRangeInvocation } from './models/RandomRangeInvocation';
export type { RangeInvocation } from './models/RangeInvocation';
export type { RangeOfSizeInvocation } from './models/RangeOfSizeInvocation';
export type { ResizeLatentsInvocation } from './models/ResizeLatentsInvocation';
export type { RestoreFaceInvocation } from './models/RestoreFaceInvocation';
export type { ScaleLatentsInvocation } from './models/ScaleLatentsInvocation';
@ -79,79 +78,6 @@ export type { UpscaleInvocation } from './models/UpscaleInvocation';
export type { VaeRepo } from './models/VaeRepo';
export type { ValidationError } from './models/ValidationError';
export { $AddInvocation } from './schemas/$AddInvocation';
export { $BlurInvocation } from './schemas/$BlurInvocation';
export { $Body_upload_image } from './schemas/$Body_upload_image';
export { $CkptModelInfo } from './schemas/$CkptModelInfo';
export { $CollectInvocation } from './schemas/$CollectInvocation';
export { $CollectInvocationOutput } from './schemas/$CollectInvocationOutput';
export { $ColorField } from './schemas/$ColorField';
export { $CompelInvocation } from './schemas/$CompelInvocation';
export { $CompelOutput } from './schemas/$CompelOutput';
export { $ConditioningField } from './schemas/$ConditioningField';
export { $CreateModelRequest } from './schemas/$CreateModelRequest';
export { $CropImageInvocation } from './schemas/$CropImageInvocation';
export { $CvInpaintInvocation } from './schemas/$CvInpaintInvocation';
export { $DiffusersModelInfo } from './schemas/$DiffusersModelInfo';
export { $DivideInvocation } from './schemas/$DivideInvocation';
export { $Edge } from './schemas/$Edge';
export { $EdgeConnection } from './schemas/$EdgeConnection';
export { $Graph } from './schemas/$Graph';
export { $GraphExecutionState } from './schemas/$GraphExecutionState';
export { $GraphInvocation } from './schemas/$GraphInvocation';
export { $GraphInvocationOutput } from './schemas/$GraphInvocationOutput';
export { $HTTPValidationError } from './schemas/$HTTPValidationError';
export { $ImageField } from './schemas/$ImageField';
export { $ImageOutput } from './schemas/$ImageOutput';
export { $ImageResponse } from './schemas/$ImageResponse';
export { $ImageResponseMetadata } from './schemas/$ImageResponseMetadata';
export { $ImageToImageInvocation } from './schemas/$ImageToImageInvocation';
export { $ImageToLatentsInvocation } from './schemas/$ImageToLatentsInvocation';
export { $ImageType } from './schemas/$ImageType';
export { $InfillColorInvocation } from './schemas/$InfillColorInvocation';
export { $InfillPatchMatchInvocation } from './schemas/$InfillPatchMatchInvocation';
export { $InfillTileInvocation } from './schemas/$InfillTileInvocation';
export { $InpaintInvocation } from './schemas/$InpaintInvocation';
export { $IntCollectionOutput } from './schemas/$IntCollectionOutput';
export { $IntOutput } from './schemas/$IntOutput';
export { $InverseLerpInvocation } from './schemas/$InverseLerpInvocation';
export { $InvokeAIMetadata } from './schemas/$InvokeAIMetadata';
export { $IterateInvocation } from './schemas/$IterateInvocation';
export { $IterateInvocationOutput } from './schemas/$IterateInvocationOutput';
export { $LatentsField } from './schemas/$LatentsField';
export { $LatentsOutput } from './schemas/$LatentsOutput';
export { $LatentsToImageInvocation } from './schemas/$LatentsToImageInvocation';
export { $LatentsToLatentsInvocation } from './schemas/$LatentsToLatentsInvocation';
export { $LerpInvocation } from './schemas/$LerpInvocation';
export { $LoadImageInvocation } from './schemas/$LoadImageInvocation';
export { $MaskFromAlphaInvocation } from './schemas/$MaskFromAlphaInvocation';
export { $MaskOutput } from './schemas/$MaskOutput';
export { $MetadataColorField } from './schemas/$MetadataColorField';
export { $MetadataImageField } from './schemas/$MetadataImageField';
export { $MetadataLatentsField } from './schemas/$MetadataLatentsField';
export { $ModelsList } from './schemas/$ModelsList';
export { $MultiplyInvocation } from './schemas/$MultiplyInvocation';
export { $NoiseInvocation } from './schemas/$NoiseInvocation';
export { $NoiseOutput } from './schemas/$NoiseOutput';
export { $PaginatedResults_GraphExecutionState_ } from './schemas/$PaginatedResults_GraphExecutionState_';
export { $PaginatedResults_ImageResponse_ } from './schemas/$PaginatedResults_ImageResponse_';
export { $ParamIntInvocation } from './schemas/$ParamIntInvocation';
export { $PasteImageInvocation } from './schemas/$PasteImageInvocation';
export { $PromptOutput } from './schemas/$PromptOutput';
export { $RandomIntInvocation } from './schemas/$RandomIntInvocation';
export { $RandomRangeInvocation } from './schemas/$RandomRangeInvocation';
export { $RangeInvocation } from './schemas/$RangeInvocation';
export { $ResizeLatentsInvocation } from './schemas/$ResizeLatentsInvocation';
export { $RestoreFaceInvocation } from './schemas/$RestoreFaceInvocation';
export { $ScaleLatentsInvocation } from './schemas/$ScaleLatentsInvocation';
export { $ShowImageInvocation } from './schemas/$ShowImageInvocation';
export { $SubtractInvocation } from './schemas/$SubtractInvocation';
export { $TextToImageInvocation } from './schemas/$TextToImageInvocation';
export { $TextToLatentsInvocation } from './schemas/$TextToLatentsInvocation';
export { $UpscaleInvocation } from './schemas/$UpscaleInvocation';
export { $VaeRepo } from './schemas/$VaeRepo';
export { $ValidationError } from './schemas/$ValidationError';
export { ImagesService } from './services/ImagesService';
export { ModelsService } from './services/ModelsService';
export { SessionsService } from './services/SessionsService';

View File

@ -31,6 +31,7 @@ import type { PasteImageInvocation } from './PasteImageInvocation';
import type { RandomIntInvocation } from './RandomIntInvocation';
import type { RandomRangeInvocation } from './RandomRangeInvocation';
import type { RangeInvocation } from './RangeInvocation';
import type { RangeOfSizeInvocation } from './RangeOfSizeInvocation';
import type { ResizeLatentsInvocation } from './ResizeLatentsInvocation';
import type { RestoreFaceInvocation } from './RestoreFaceInvocation';
import type { ScaleLatentsInvocation } from './ScaleLatentsInvocation';
@ -48,7 +49,7 @@ export type Graph = {
/**
* The nodes in this graph
*/
nodes?: Record<string, (LoadImageInvocation | ShowImageInvocation | CropImageInvocation | PasteImageInvocation | MaskFromAlphaInvocation | BlurInvocation | LerpInvocation | InverseLerpInvocation | CompelInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | CvInpaintInvocation | RangeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation)>;
nodes?: Record<string, (LoadImageInvocation | ShowImageInvocation | CropImageInvocation | PasteImageInvocation | MaskFromAlphaInvocation | BlurInvocation | LerpInvocation | InverseLerpInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation)>;
/**
* The connections between nodes and their fields in this graph
*/

View File

@ -42,7 +42,7 @@ export type GraphExecutionState = {
/**
* The results of node executions
*/
results: Record<string, (ImageOutput | MaskOutput | CompelOutput | LatentsOutput | NoiseOutput | IntOutput | PromptOutput | IntCollectionOutput | GraphInvocationOutput | IterateInvocationOutput | CollectInvocationOutput)>;
results: Record<string, (ImageOutput | MaskOutput | PromptOutput | CompelOutput | IntOutput | LatentsOutput | NoiseOutput | IntCollectionOutput | GraphInvocationOutput | IterateInvocationOutput | CollectInvocationOutput)>;
/**
* Errors raised when executing nodes
*/

View File

@ -0,0 +1,8 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
/**
* The category of an image. Use ImageCategory.OTHER for non-default categories.
*/
export type ImageCategory = 'general' | 'control' | 'other';

View File

@ -0,0 +1,66 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { ImageCategory } from './ImageCategory';
import type { ImageMetadata } from './ImageMetadata';
import type { ImageType } from './ImageType';
/**
* Deserialized image record, enriched for the frontend with URLs.
*/
export type ImageDTO = {
/**
* The unique name of the image.
*/
image_name: string;
/**
* The type of the image.
*/
image_type: ImageType;
/**
* The URL of the image.
*/
image_url: string;
/**
* The URL of the image's thumbnail.
*/
thumbnail_url: string;
/**
* The category of the image.
*/
image_category: ImageCategory;
/**
* The width of the image in px.
*/
width: number;
/**
* The height of the image in px.
*/
height: number;
/**
* The created timestamp of the image.
*/
created_at: string;
/**
* The updated timestamp of the image.
*/
updated_at: string;
/**
* The deleted timestamp of the image.
*/
deleted_at?: string;
/**
* The session ID that generated this image, if it is a generated image.
*/
session_id?: string;
/**
* The node ID that generated this image, if it is a generated image.
*/
node_id?: string;
/**
* A limited subset of the image's generation metadata. Retrieve the image's session for full metadata.
*/
metadata?: ImageMetadata;
};

View File

@ -0,0 +1,81 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
/**
* Core generation metadata for an image/tensor generated in InvokeAI.
*
* Also includes any metadata from the image's PNG tEXt chunks.
*
* Generated by traversing the execution graph, collecting the parameters of the nearest ancestors
* of a given node.
*
* Full metadata may be accessed by querying for the session in the `graph_executions` table.
*/
export type ImageMetadata = {
/**
* The type of the ancestor node of the image output node.
*/
type?: string;
/**
* The positive conditioning.
*/
positive_conditioning?: string;
/**
* The negative conditioning.
*/
negative_conditioning?: string;
/**
* Width of the image/latents in pixels.
*/
width?: number;
/**
* Height of the image/latents in pixels.
*/
height?: number;
/**
* The seed used for noise generation.
*/
seed?: number;
/**
* The classifier-free guidance scale.
*/
cfg_scale?: number;
/**
* The number of steps used for inference.
*/
steps?: number;
/**
* The scheduler used for inference.
*/
scheduler?: string;
/**
* The model used for inference.
*/
model?: string;
/**
* The strength used for image-to-image/latents-to-latents.
*/
strength?: number;
/**
* The ID of the initial latents.
*/
latents?: string;
/**
* The VAE used for decoding.
*/
vae?: string;
/**
* The UNet used dor inference.
*/
unet?: string;
/**
* The CLIP Encoder used for conditioning.
*/
clip?: string;
/**
* Uploaded image metadata, extracted from the PNG tEXt chunk.
*/
extra?: string;
};

View File

@ -8,7 +8,7 @@ import type { ImageField } from './ImageField';
* Base class for invocations that output an image
*/
export type ImageOutput = {
type: 'image';
type: 'image_output';
/**
* The output image
*/

View File

@ -1,33 +0,0 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { ImageResponseMetadata } from './ImageResponseMetadata';
import type { ImageType } from './ImageType';
/**
* The response type for images
*/
export type ImageResponse = {
/**
* The type of the image
*/
image_type: ImageType;
/**
* The name of the image
*/
image_name: string;
/**
* The url of the image
*/
image_url: string;
/**
* The url of the image's thumbnail
*/
thumbnail_url: string;
/**
* The image's metadata
*/
metadata: ImageResponseMetadata;
};

View File

@ -1,28 +0,0 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { InvokeAIMetadata } from './InvokeAIMetadata';
/**
* An image's metadata. Used only in HTTP responses.
*/
export type ImageResponseMetadata = {
/**
* The creation timestamp of the image
*/
created: number;
/**
* The width of the image in pixels
*/
width: number;
/**
* The height of the image in pixels
*/
height: number;
/**
* The image's InvokeAI-specific metadata
*/
invokeai?: InvokeAIMetadata;
};

View File

@ -3,6 +3,6 @@
/* eslint-disable */
/**
* An enumeration.
* The type of an image.
*/
export type ImageType = 'results' | 'intermediates' | 'uploads';
export type ImageType = 'results' | 'uploads' | 'intermediates';

View File

@ -0,0 +1,28 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { ImageType } from './ImageType';
/**
* The URLs for an image and its thumbnail.
*/
export type ImageUrlsDTO = {
/**
* The unique name of the image.
*/
image_name: string;
/**
* The type of the image.
*/
image_type: ImageType;
/**
* The URL of the image.
*/
image_url: string;
/**
* The URL of the image's thumbnail.
*/
thumbnail_url: string;
};

View File

@ -1,13 +0,0 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { MetadataColorField } from './MetadataColorField';
import type { MetadataImageField } from './MetadataImageField';
import type { MetadataLatentsField } from './MetadataLatentsField';
export type InvokeAIMetadata = {
session_id?: string;
node?: Record<string, (string | number | boolean | MetadataImageField | MetadataLatentsField | MetadataColorField)>;
};

View File

@ -13,5 +13,13 @@ export type MaskOutput = {
* The output mask
*/
mask: ImageField;
/**
* The width of the mask in pixels
*/
width?: number;
/**
* The height of the mask in pixels
*/
height?: number;
};

View File

@ -1,11 +0,0 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export type MetadataColorField = {
'r': number;
'g': number;
'b': number;
'a': number;
};

View File

@ -1,11 +0,0 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { ImageType } from './ImageType';
export type MetadataImageField = {
image_type: ImageType;
image_name: string;
};

View File

@ -1,8 +0,0 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export type MetadataLatentsField = {
latents_name: string;
};

View File

@ -2,16 +2,16 @@
/* tslint:disable */
/* eslint-disable */
import type { ImageResponse } from './ImageResponse';
import type { ImageDTO } from './ImageDTO';
/**
* Paginated results
*/
export type PaginatedResults_ImageResponse_ = {
export type PaginatedResults_ImageDTO_ = {
/**
* Items
*/
items: Array<ImageResponse>;
items: Array<ImageDTO>;
/**
* Current Page
*/

Some files were not shown because too many files have changed in this diff Show More