mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into release/make-web-dist-startable
This commit is contained in:
commit
9110838fe4
1
.github/workflows/test-invoke-pip.yml
vendored
1
.github/workflows/test-invoke-pip.yml
vendored
@ -125,6 +125,7 @@ jobs:
|
|||||||
--no-nsfw_checker
|
--no-nsfw_checker
|
||||||
--precision=float32
|
--precision=float32
|
||||||
--always_use_cpu
|
--always_use_cpu
|
||||||
|
--use_memory_db
|
||||||
--outdir ${{ env.INVOKEAI_OUTDIR }}/${{ matrix.python-version }}/${{ matrix.pytorch }}
|
--outdir ${{ env.INVOKEAI_OUTDIR }}/${{ matrix.python-version }}/${{ matrix.pytorch }}
|
||||||
--from_file ${{ env.TEST_PROMPTS }}
|
--from_file ${{ env.TEST_PROMPTS }}
|
||||||
|
|
||||||
|
@ -216,7 +216,7 @@ manager, please follow these steps:
|
|||||||
9. Run the command-line- or the web- interface:
|
9. Run the command-line- or the web- interface:
|
||||||
|
|
||||||
From within INVOKEAI_ROOT, activate the environment
|
From within INVOKEAI_ROOT, activate the environment
|
||||||
(with `source .venv/bin/activate` or `.venv\scripts\activate), and then run
|
(with `source .venv/bin/activate` or `.venv\scripts\activate`), and then run
|
||||||
the script `invokeai`. If the virtual environment you selected is NOT inside
|
the script `invokeai`. If the virtual environment you selected is NOT inside
|
||||||
INVOKEAI_ROOT, then you must specify the path to the root directory by adding
|
INVOKEAI_ROOT, then you must specify the path to the root directory by adding
|
||||||
`--root_dir \path\to\invokeai` to the commands below:
|
`--root_dir \path\to\invokeai` to the commands below:
|
||||||
|
@ -1,22 +1,24 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
from logging import Logger
|
||||||
import os
|
import os
|
||||||
|
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||||
import invokeai.backend.util.logging as logger
|
from invokeai.app.services.images import ImageService
|
||||||
from typing import types
|
from invokeai.app.services.metadata import CoreMetadataService
|
||||||
|
from invokeai.app.services.urls import LocalUrlService
|
||||||
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
from ..services.default_graphs import create_system_graphs
|
from ..services.default_graphs import create_system_graphs
|
||||||
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||||
from ..services.model_manager_initializer import get_model_manager
|
from ..services.model_manager_initializer import get_model_manager
|
||||||
from ..services.restoration_services import RestorationServices
|
from ..services.restoration_services import RestorationServices
|
||||||
from ..services.graph import GraphExecutionState, LibraryGraph
|
from ..services.graph import GraphExecutionState, LibraryGraph
|
||||||
from ..services.image_storage import DiskImageStorage
|
from ..services.image_file_storage import DiskImageFileStorage
|
||||||
from ..services.invocation_queue import MemoryInvocationQueue
|
from ..services.invocation_queue import MemoryInvocationQueue
|
||||||
from ..services.invocation_services import InvocationServices
|
from ..services.invocation_services import InvocationServices
|
||||||
from ..services.invoker import Invoker
|
from ..services.invoker import Invoker
|
||||||
from ..services.processor import DefaultInvocationProcessor
|
from ..services.processor import DefaultInvocationProcessor
|
||||||
from ..services.sqlite import SqliteItemStorage
|
from ..services.sqlite import SqliteItemStorage
|
||||||
from ..services.metadata import PngMetadataService
|
|
||||||
from .events import FastAPIEventService
|
from .events import FastAPIEventService
|
||||||
|
|
||||||
|
|
||||||
@ -36,42 +38,59 @@ def check_internet() -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
logger = InvokeAILogger.getLogger()
|
||||||
|
|
||||||
|
|
||||||
class ApiDependencies:
|
class ApiDependencies:
|
||||||
"""Contains and initializes all dependencies for the API"""
|
"""Contains and initializes all dependencies for the API"""
|
||||||
|
|
||||||
invoker: Invoker = None
|
invoker: Invoker = None
|
||||||
|
|
||||||
def initialize(config, event_handler_id: int, logger: types.ModuleType=logger):
|
@staticmethod
|
||||||
|
def initialize(config, event_handler_id: int, logger: Logger = logger):
|
||||||
logger.info(f"Internet connectivity is {config.internet_available}")
|
logger.info(f"Internet connectivity is {config.internet_available}")
|
||||||
|
|
||||||
events = FastAPIEventService(event_handler_id)
|
events = FastAPIEventService(event_handler_id)
|
||||||
|
|
||||||
output_folder = config.output_path
|
output_folder = config.output_path
|
||||||
|
|
||||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents'))
|
|
||||||
|
|
||||||
metadata = PngMetadataService()
|
|
||||||
|
|
||||||
images = DiskImageStorage(f'{output_folder}/images', metadata_service=metadata)
|
|
||||||
|
|
||||||
# TODO: build a file/path manager?
|
# TODO: build a file/path manager?
|
||||||
db_location = os.path.join(output_folder, "invokeai.db")
|
db_location = os.path.join(output_folder, "invokeai.db")
|
||||||
|
|
||||||
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||||
|
filename=db_location, table_name="graph_executions"
|
||||||
|
)
|
||||||
|
|
||||||
|
urls = LocalUrlService()
|
||||||
|
metadata = CoreMetadataService()
|
||||||
|
image_record_storage = SqliteImageRecordStorage(db_location)
|
||||||
|
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||||
|
|
||||||
|
latents = ForwardCacheLatentsStorage(
|
||||||
|
DiskLatentsStorage(f"{output_folder}/latents")
|
||||||
|
)
|
||||||
|
|
||||||
|
images = ImageService(
|
||||||
|
image_record_storage=image_record_storage,
|
||||||
|
image_file_storage=image_file_storage,
|
||||||
|
metadata=metadata,
|
||||||
|
url=urls,
|
||||||
|
logger=logger,
|
||||||
|
graph_execution_manager=graph_execution_manager,
|
||||||
|
)
|
||||||
|
|
||||||
services = InvocationServices(
|
services = InvocationServices(
|
||||||
model_manager=get_model_manager(config,logger),
|
model_manager=get_model_manager(config, logger),
|
||||||
events=events,
|
events=events,
|
||||||
latents=latents,
|
latents=latents,
|
||||||
images=images,
|
images=images,
|
||||||
metadata=metadata,
|
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
graph_library=SqliteItemStorage[LibraryGraph](
|
graph_library=SqliteItemStorage[LibraryGraph](
|
||||||
filename=db_location, table_name="graphs"
|
filename=db_location, table_name="graphs"
|
||||||
),
|
),
|
||||||
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
|
graph_execution_manager=graph_execution_manager,
|
||||||
filename=db_location, table_name="graph_executions"
|
|
||||||
),
|
|
||||||
processor=DefaultInvocationProcessor(),
|
processor=DefaultInvocationProcessor(),
|
||||||
restoration=RestorationServices(config,logger),
|
restoration=RestorationServices(config, logger),
|
||||||
configuration=config,
|
configuration=config,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
)
|
)
|
||||||
|
@ -2,7 +2,6 @@ from typing import Optional
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageType
|
from invokeai.app.models.image import ImageType
|
||||||
from invokeai.app.services.metadata import InvokeAIMetadata
|
|
||||||
|
|
||||||
|
|
||||||
class ImageResponseMetadata(BaseModel):
|
class ImageResponseMetadata(BaseModel):
|
||||||
@ -11,9 +10,9 @@ class ImageResponseMetadata(BaseModel):
|
|||||||
created: int = Field(description="The creation timestamp of the image")
|
created: int = Field(description="The creation timestamp of the image")
|
||||||
width: int = Field(description="The width of the image in pixels")
|
width: int = Field(description="The width of the image in pixels")
|
||||||
height: int = Field(description="The height of the image in pixels")
|
height: int = Field(description="The height of the image in pixels")
|
||||||
invokeai: Optional[InvokeAIMetadata] = Field(
|
# invokeai: Optional[InvokeAIMetadata] = Field(
|
||||||
description="The image's InvokeAI-specific metadata"
|
# description="The image's InvokeAI-specific metadata"
|
||||||
)
|
# )
|
||||||
|
|
||||||
|
|
||||||
class ImageResponse(BaseModel):
|
class ImageResponse(BaseModel):
|
||||||
|
@ -1,148 +1,215 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
|
||||||
import io
|
import io
|
||||||
from datetime import datetime, timezone
|
from fastapi import HTTPException, Path, Query, Request, Response, UploadFile
|
||||||
import json
|
|
||||||
import os
|
|
||||||
from typing import Any
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
from fastapi import Body, HTTPException, Path, Query, Request, UploadFile
|
|
||||||
from fastapi.responses import FileResponse, Response
|
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
|
from fastapi.responses import FileResponse
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from invokeai.app.api.models.images import (
|
from invokeai.app.models.image import (
|
||||||
ImageResponse,
|
ImageCategory,
|
||||||
ImageResponseMetadata,
|
ImageType,
|
||||||
)
|
)
|
||||||
|
from invokeai.app.services.models.image_record import ImageDTO, ImageUrlsDTO
|
||||||
from invokeai.app.services.item_storage import PaginatedResults
|
from invokeai.app.services.item_storage import PaginatedResults
|
||||||
|
|
||||||
from ...services.image_storage import ImageType
|
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
images_router = APIRouter(prefix="/v1/images", tags=["images"])
|
images_router = APIRouter(prefix="/v1/images", tags=["images"])
|
||||||
|
|
||||||
|
|
||||||
@images_router.get("/{image_type}/{image_name}", operation_id="get_image")
|
|
||||||
async def get_image(
|
|
||||||
image_type: ImageType = Path(description="The type of image to get"),
|
|
||||||
image_name: str = Path(description="The name of the image to get"),
|
|
||||||
) -> FileResponse:
|
|
||||||
"""Gets an image"""
|
|
||||||
|
|
||||||
path = ApiDependencies.invoker.services.images.get_path(
|
|
||||||
image_type=image_type, image_name=image_name
|
|
||||||
)
|
|
||||||
|
|
||||||
if ApiDependencies.invoker.services.images.validate_path(path):
|
|
||||||
return FileResponse(path)
|
|
||||||
else:
|
|
||||||
raise HTTPException(status_code=404)
|
|
||||||
|
|
||||||
|
|
||||||
@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
|
|
||||||
async def delete_image(
|
|
||||||
image_type: ImageType = Path(description="The type of image to delete"),
|
|
||||||
image_name: str = Path(description="The name of the image to delete"),
|
|
||||||
) -> None:
|
|
||||||
"""Deletes an image and its thumbnail"""
|
|
||||||
|
|
||||||
ApiDependencies.invoker.services.images.delete(
|
|
||||||
image_type=image_type, image_name=image_name
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@images_router.get(
|
|
||||||
"/{thumbnail_type}/thumbnails/{thumbnail_name}", operation_id="get_thumbnail"
|
|
||||||
)
|
|
||||||
async def get_thumbnail(
|
|
||||||
thumbnail_type: ImageType = Path(description="The type of thumbnail to get"),
|
|
||||||
thumbnail_name: str = Path(description="The name of the thumbnail to get"),
|
|
||||||
) -> FileResponse | Response:
|
|
||||||
"""Gets a thumbnail"""
|
|
||||||
|
|
||||||
path = ApiDependencies.invoker.services.images.get_path(
|
|
||||||
image_type=thumbnail_type, image_name=thumbnail_name, is_thumbnail=True
|
|
||||||
)
|
|
||||||
|
|
||||||
if ApiDependencies.invoker.services.images.validate_path(path):
|
|
||||||
return FileResponse(path)
|
|
||||||
else:
|
|
||||||
raise HTTPException(status_code=404)
|
|
||||||
|
|
||||||
|
|
||||||
@images_router.post(
|
@images_router.post(
|
||||||
"/uploads/",
|
"/",
|
||||||
operation_id="upload_image",
|
operation_id="upload_image",
|
||||||
responses={
|
responses={
|
||||||
201: {
|
201: {"description": "The image was uploaded successfully"},
|
||||||
"description": "The image was uploaded successfully",
|
|
||||||
"model": ImageResponse,
|
|
||||||
},
|
|
||||||
415: {"description": "Image upload failed"},
|
415: {"description": "Image upload failed"},
|
||||||
},
|
},
|
||||||
status_code=201,
|
status_code=201,
|
||||||
|
response_model=ImageDTO,
|
||||||
)
|
)
|
||||||
async def upload_image(
|
async def upload_image(
|
||||||
file: UploadFile, image_type: ImageType, request: Request, response: Response
|
file: UploadFile,
|
||||||
) -> ImageResponse:
|
image_type: ImageType,
|
||||||
|
request: Request,
|
||||||
|
response: Response,
|
||||||
|
image_category: ImageCategory = ImageCategory.GENERAL,
|
||||||
|
) -> ImageDTO:
|
||||||
|
"""Uploads an image"""
|
||||||
if not file.content_type.startswith("image"):
|
if not file.content_type.startswith("image"):
|
||||||
raise HTTPException(status_code=415, detail="Not an image")
|
raise HTTPException(status_code=415, detail="Not an image")
|
||||||
|
|
||||||
contents = await file.read()
|
contents = await file.read()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
img = Image.open(io.BytesIO(contents))
|
pil_image = Image.open(io.BytesIO(contents))
|
||||||
except:
|
except:
|
||||||
# Error opening the image
|
# Error opening the image
|
||||||
raise HTTPException(status_code=415, detail="Failed to read image")
|
raise HTTPException(status_code=415, detail="Failed to read image")
|
||||||
|
|
||||||
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
|
try:
|
||||||
|
image_dto = ApiDependencies.invoker.services.images.create(
|
||||||
|
pil_image,
|
||||||
|
image_type,
|
||||||
|
image_category,
|
||||||
|
)
|
||||||
|
|
||||||
saved_image = ApiDependencies.invoker.services.images.save(
|
response.status_code = 201
|
||||||
image_type, filename, img
|
response.headers["Location"] = image_dto.image_url
|
||||||
)
|
|
||||||
|
|
||||||
invokeai_metadata = ApiDependencies.invoker.services.metadata.get_metadata(img)
|
return image_dto
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to create image")
|
||||||
|
|
||||||
image_url = ApiDependencies.invoker.services.images.get_uri(
|
|
||||||
image_type, saved_image.image_name
|
|
||||||
)
|
|
||||||
|
|
||||||
thumbnail_url = ApiDependencies.invoker.services.images.get_uri(
|
@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
|
||||||
image_type, saved_image.image_name, True
|
async def delete_image(
|
||||||
)
|
image_type: ImageType = Query(description="The type of image to delete"),
|
||||||
|
image_name: str = Path(description="The name of the image to delete"),
|
||||||
|
) -> None:
|
||||||
|
"""Deletes an image"""
|
||||||
|
|
||||||
res = ImageResponse(
|
try:
|
||||||
image_type=image_type,
|
ApiDependencies.invoker.services.images.delete(image_type, image_name)
|
||||||
image_name=saved_image.image_name,
|
except Exception as e:
|
||||||
image_url=image_url,
|
# TODO: Does this need any exception handling at all?
|
||||||
thumbnail_url=thumbnail_url,
|
pass
|
||||||
metadata=ImageResponseMetadata(
|
|
||||||
created=saved_image.created,
|
|
||||||
width=img.width,
|
|
||||||
height=img.height,
|
|
||||||
invokeai=invokeai_metadata,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
response.status_code = 201
|
|
||||||
response.headers["Location"] = image_url
|
|
||||||
|
|
||||||
return res
|
@images_router.get(
|
||||||
|
"/{image_type}/{image_name}/metadata",
|
||||||
|
operation_id="get_image_metadata",
|
||||||
|
response_model=ImageDTO,
|
||||||
|
)
|
||||||
|
async def get_image_metadata(
|
||||||
|
image_type: ImageType = Path(description="The type of image to get"),
|
||||||
|
image_name: str = Path(description="The name of image to get"),
|
||||||
|
) -> ImageDTO:
|
||||||
|
"""Gets an image's metadata"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
return ApiDependencies.invoker.services.images.get_dto(
|
||||||
|
image_type, image_name
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=404)
|
||||||
|
|
||||||
|
|
||||||
|
@images_router.get(
|
||||||
|
"/{image_type}/{image_name}",
|
||||||
|
operation_id="get_image_full",
|
||||||
|
response_class=Response,
|
||||||
|
responses={
|
||||||
|
200: {
|
||||||
|
"description": "Return the full-resolution image",
|
||||||
|
"content": {"image/png": {}},
|
||||||
|
},
|
||||||
|
404: {"description": "Image not found"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def get_image_full(
|
||||||
|
image_type: ImageType = Path(
|
||||||
|
description="The type of full-resolution image file to get"
|
||||||
|
),
|
||||||
|
image_name: str = Path(description="The name of full-resolution image file to get"),
|
||||||
|
) -> FileResponse:
|
||||||
|
"""Gets a full-resolution image file"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
path = ApiDependencies.invoker.services.images.get_path(
|
||||||
|
image_type, image_name
|
||||||
|
)
|
||||||
|
|
||||||
|
if not ApiDependencies.invoker.services.images.validate_path(path):
|
||||||
|
raise HTTPException(status_code=404)
|
||||||
|
|
||||||
|
return FileResponse(
|
||||||
|
path,
|
||||||
|
media_type="image/png",
|
||||||
|
filename=image_name,
|
||||||
|
content_disposition_type="inline",
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=404)
|
||||||
|
|
||||||
|
|
||||||
|
@images_router.get(
|
||||||
|
"/{image_type}/{image_name}/thumbnail",
|
||||||
|
operation_id="get_image_thumbnail",
|
||||||
|
response_class=Response,
|
||||||
|
responses={
|
||||||
|
200: {
|
||||||
|
"description": "Return the image thumbnail",
|
||||||
|
"content": {"image/webp": {}},
|
||||||
|
},
|
||||||
|
404: {"description": "Image not found"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def get_image_thumbnail(
|
||||||
|
image_type: ImageType = Path(description="The type of thumbnail image file to get"),
|
||||||
|
image_name: str = Path(description="The name of thumbnail image file to get"),
|
||||||
|
) -> FileResponse:
|
||||||
|
"""Gets a thumbnail image file"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
path = ApiDependencies.invoker.services.images.get_path(
|
||||||
|
image_type, image_name, thumbnail=True
|
||||||
|
)
|
||||||
|
if not ApiDependencies.invoker.services.images.validate_path(path):
|
||||||
|
raise HTTPException(status_code=404)
|
||||||
|
|
||||||
|
return FileResponse(
|
||||||
|
path, media_type="image/webp", content_disposition_type="inline"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=404)
|
||||||
|
|
||||||
|
|
||||||
|
@images_router.get(
|
||||||
|
"/{image_type}/{image_name}/urls",
|
||||||
|
operation_id="get_image_urls",
|
||||||
|
response_model=ImageUrlsDTO,
|
||||||
|
)
|
||||||
|
async def get_image_urls(
|
||||||
|
image_type: ImageType = Path(description="The type of the image whose URL to get"),
|
||||||
|
image_name: str = Path(description="The name of the image whose URL to get"),
|
||||||
|
) -> ImageUrlsDTO:
|
||||||
|
"""Gets an image and thumbnail URL"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
image_url = ApiDependencies.invoker.services.images.get_url(
|
||||||
|
image_type, image_name
|
||||||
|
)
|
||||||
|
thumbnail_url = ApiDependencies.invoker.services.images.get_url(
|
||||||
|
image_type, image_name, thumbnail=True
|
||||||
|
)
|
||||||
|
return ImageUrlsDTO(
|
||||||
|
image_type=image_type,
|
||||||
|
image_name=image_name,
|
||||||
|
image_url=image_url,
|
||||||
|
thumbnail_url=thumbnail_url,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=404)
|
||||||
|
|
||||||
|
|
||||||
@images_router.get(
|
@images_router.get(
|
||||||
"/",
|
"/",
|
||||||
operation_id="list_images",
|
operation_id="list_images_with_metadata",
|
||||||
responses={200: {"model": PaginatedResults[ImageResponse]}},
|
response_model=PaginatedResults[ImageDTO],
|
||||||
)
|
)
|
||||||
async def list_images(
|
async def list_images_with_metadata(
|
||||||
image_type: ImageType = Query(
|
image_type: ImageType = Query(description="The type of images to list"),
|
||||||
default=ImageType.RESULT, description="The type of images to get"
|
image_category: ImageCategory = Query(description="The kind of images to list"),
|
||||||
|
page: int = Query(default=0, description="The page of image metadata to get"),
|
||||||
|
per_page: int = Query(
|
||||||
|
default=10, description="The number of image metadata per page"
|
||||||
),
|
),
|
||||||
page: int = Query(default=0, description="The page of images to get"),
|
) -> PaginatedResults[ImageDTO]:
|
||||||
per_page: int = Query(default=10, description="The number of images per page"),
|
"""Gets a list of images with metadata"""
|
||||||
) -> PaginatedResults[ImageResponse]:
|
|
||||||
"""Gets a list of images"""
|
image_dtos = ApiDependencies.invoker.services.images.get_many(
|
||||||
result = ApiDependencies.invoker.services.images.list(image_type, page, per_page)
|
image_type,
|
||||||
return result
|
image_category,
|
||||||
|
page,
|
||||||
|
per_page,
|
||||||
|
)
|
||||||
|
|
||||||
|
return image_dtos
|
||||||
|
@ -3,8 +3,8 @@ import asyncio
|
|||||||
from inspect import signature
|
from inspect import signature
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
import invokeai.frontend.web as web_dir
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
||||||
@ -16,11 +16,13 @@ from pathlib import Path
|
|||||||
from pydantic.schema import schema
|
from pydantic.schema import schema
|
||||||
|
|
||||||
from .api.dependencies import ApiDependencies
|
from .api.dependencies import ApiDependencies
|
||||||
from .api.routers import images, sessions, models
|
from .api.routers import sessions, models, images
|
||||||
from .api.sockets import SocketIO
|
from .api.sockets import SocketIO
|
||||||
from .invocations.baseinvocation import BaseInvocation
|
from .invocations.baseinvocation import BaseInvocation
|
||||||
from .services.config import InvokeAIAppConfig
|
from .services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
|
logger = InvokeAILogger.getLogger()
|
||||||
|
|
||||||
# Create the app
|
# Create the app
|
||||||
# TODO: create this all in a method so configuration/etc. can be passed in?
|
# TODO: create this all in a method so configuration/etc. can be passed in?
|
||||||
app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None)
|
app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None)
|
||||||
@ -71,10 +73,9 @@ async def shutdown_event():
|
|||||||
|
|
||||||
app.include_router(sessions.session_router, prefix="/api")
|
app.include_router(sessions.session_router, prefix="/api")
|
||||||
|
|
||||||
app.include_router(images.images_router, prefix="/api")
|
|
||||||
|
|
||||||
app.include_router(models.models_router, prefix="/api")
|
app.include_router(models.models_router, prefix="/api")
|
||||||
|
|
||||||
|
app.include_router(images.images_router, prefix="/api")
|
||||||
|
|
||||||
# Build a custom OpenAPI to include all outputs
|
# Build a custom OpenAPI to include all outputs
|
||||||
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
||||||
@ -123,6 +124,7 @@ app.openapi = custom_openapi
|
|||||||
# Override API doc favicons
|
# Override API doc favicons
|
||||||
app.mount("/static", StaticFiles(directory=Path(web_dir.__path__[0], 'static/dream_web')), name="static")
|
app.mount("/static", StaticFiles(directory=Path(web_dir.__path__[0], 'static/dream_web')), name="static")
|
||||||
|
|
||||||
|
|
||||||
@app.get("/docs", include_in_schema=False)
|
@app.get("/docs", include_in_schema=False)
|
||||||
def overridden_swagger():
|
def overridden_swagger():
|
||||||
return get_swagger_ui_html(
|
return get_swagger_ui_html(
|
||||||
@ -140,8 +142,13 @@ def overridden_redoc():
|
|||||||
redoc_favicon_url="/static/favicon.ico",
|
redoc_favicon_url="/static/favicon.ico",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Must mount *after* the other routes else it borks em
|
# Must mount *after* the other routes else it borks em
|
||||||
app.mount("/", StaticFiles(directory=Path(web_dir.__path__[0],"dist"), html=True), name="ui")
|
app.mount("/",
|
||||||
|
StaticFiles(directory=Path(web_dir.__path__[0],"dist"),
|
||||||
|
html=True
|
||||||
|
), name="ui"
|
||||||
|
)
|
||||||
|
|
||||||
def invoke_api():
|
def invoke_api():
|
||||||
# Start our own event loop for eventing usage
|
# Start our own event loop for eventing usage
|
||||||
|
@ -13,10 +13,13 @@ from typing import (
|
|||||||
|
|
||||||
from pydantic import BaseModel, ValidationError
|
from pydantic import BaseModel, ValidationError
|
||||||
from pydantic.fields import Field
|
from pydantic.fields import Field
|
||||||
|
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||||
|
from invokeai.app.services.images import ImageService
|
||||||
|
from invokeai.app.services.metadata import CoreMetadataService
|
||||||
|
from invokeai.app.services.urls import LocalUrlService
|
||||||
|
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.metadata import PngMetadataService
|
|
||||||
from .services.default_graphs import create_system_graphs
|
from .services.default_graphs import create_system_graphs
|
||||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||||
|
|
||||||
@ -28,7 +31,7 @@ from .services.model_manager_initializer import get_model_manager
|
|||||||
from .services.restoration_services import RestorationServices
|
from .services.restoration_services import RestorationServices
|
||||||
from .services.graph import Edge, EdgeConnection, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
|
from .services.graph import Edge, EdgeConnection, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
|
||||||
from .services.default_graphs import default_text_to_image_graph_id
|
from .services.default_graphs import default_text_to_image_graph_id
|
||||||
from .services.image_storage import DiskImageStorage
|
from .services.image_file_storage import DiskImageFileStorage
|
||||||
from .services.invocation_queue import MemoryInvocationQueue
|
from .services.invocation_queue import MemoryInvocationQueue
|
||||||
from .services.invocation_services import InvocationServices
|
from .services.invocation_services import InvocationServices
|
||||||
from .services.invoker import Invoker
|
from .services.invoker import Invoker
|
||||||
@ -188,6 +191,9 @@ def invoke_all(context: CliContext):
|
|||||||
raise SessionError()
|
raise SessionError()
|
||||||
|
|
||||||
|
|
||||||
|
logger = logger.InvokeAILogger.getLogger()
|
||||||
|
|
||||||
|
|
||||||
def invoke_cli():
|
def invoke_cli():
|
||||||
# this gets the basic configuration
|
# this gets the basic configuration
|
||||||
config = get_invokeai_config()
|
config = get_invokeai_config()
|
||||||
@ -206,24 +212,43 @@ def invoke_cli():
|
|||||||
|
|
||||||
events = EventServiceBase()
|
events = EventServiceBase()
|
||||||
output_folder = config.output_path
|
output_folder = config.output_path
|
||||||
metadata = PngMetadataService()
|
|
||||||
|
|
||||||
# TODO: build a file/path manager?
|
# TODO: build a file/path manager?
|
||||||
db_location = os.path.join(output_folder, "invokeai.db")
|
if config.use_memory_db:
|
||||||
|
db_location = ":memory:"
|
||||||
|
else:
|
||||||
|
db_location = os.path.join(output_folder, "invokeai.db")
|
||||||
|
|
||||||
|
logger.info(f'InvokeAI database location is "{db_location}"')
|
||||||
|
|
||||||
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||||
|
filename=db_location, table_name="graph_executions"
|
||||||
|
)
|
||||||
|
|
||||||
|
urls = LocalUrlService()
|
||||||
|
metadata = CoreMetadataService()
|
||||||
|
image_record_storage = SqliteImageRecordStorage(db_location)
|
||||||
|
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||||
|
|
||||||
|
images = ImageService(
|
||||||
|
image_record_storage=image_record_storage,
|
||||||
|
image_file_storage=image_file_storage,
|
||||||
|
metadata=metadata,
|
||||||
|
url=urls,
|
||||||
|
logger=logger,
|
||||||
|
graph_execution_manager=graph_execution_manager,
|
||||||
|
)
|
||||||
|
|
||||||
services = InvocationServices(
|
services = InvocationServices(
|
||||||
model_manager=model_manager,
|
model_manager=model_manager,
|
||||||
events=events,
|
events=events,
|
||||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
|
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
|
||||||
images=DiskImageStorage(f'{output_folder}/images', metadata_service=metadata),
|
images=images,
|
||||||
metadata=metadata,
|
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
graph_library=SqliteItemStorage[LibraryGraph](
|
graph_library=SqliteItemStorage[LibraryGraph](
|
||||||
filename=db_location, table_name="graphs"
|
filename=db_location, table_name="graphs"
|
||||||
),
|
),
|
||||||
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
|
graph_execution_manager=graph_execution_manager,
|
||||||
filename=db_location, table_name="graph_executions"
|
|
||||||
),
|
|
||||||
processor=DefaultInvocationProcessor(),
|
processor=DefaultInvocationProcessor(),
|
||||||
restoration=RestorationServices(config,logger=logger),
|
restoration=RestorationServices(config,logger=logger),
|
||||||
logger=logger,
|
logger=logger,
|
||||||
|
@ -1,12 +1,15 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict
|
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict, TYPE_CHECKING
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from ..services.invocation_services import InvocationServices
|
if TYPE_CHECKING:
|
||||||
|
from ..services.invocation_services import InvocationServices
|
||||||
|
|
||||||
|
|
||||||
class InvocationContext:
|
class InvocationContext:
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||||
|
|
||||||
from typing import Literal, Optional
|
from typing import Literal
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import Field
|
from pydantic import Field, validator
|
||||||
|
|
||||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
|
|
||||||
@ -24,7 +24,7 @@ class IntCollectionOutput(BaseInvocationOutput):
|
|||||||
|
|
||||||
|
|
||||||
class RangeInvocation(BaseInvocation):
|
class RangeInvocation(BaseInvocation):
|
||||||
"""Creates a range"""
|
"""Creates a range of numbers from start to stop with step"""
|
||||||
|
|
||||||
type: Literal["range"] = "range"
|
type: Literal["range"] = "range"
|
||||||
|
|
||||||
@ -33,12 +33,34 @@ class RangeInvocation(BaseInvocation):
|
|||||||
stop: int = Field(default=10, description="The stop of the range")
|
stop: int = Field(default=10, description="The stop of the range")
|
||||||
step: int = Field(default=1, description="The step of the range")
|
step: int = Field(default=1, description="The step of the range")
|
||||||
|
|
||||||
|
@validator("stop")
|
||||||
|
def stop_gt_start(cls, v, values):
|
||||||
|
if "start" in values and v <= values["start"]:
|
||||||
|
raise ValueError("stop must be greater than start")
|
||||||
|
return v
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||||
return IntCollectionOutput(
|
return IntCollectionOutput(
|
||||||
collection=list(range(self.start, self.stop, self.step))
|
collection=list(range(self.start, self.stop, self.step))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RangeOfSizeInvocation(BaseInvocation):
|
||||||
|
"""Creates a range from start to start + size with step"""
|
||||||
|
|
||||||
|
type: Literal["range_of_size"] = "range_of_size"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
start: int = Field(default=0, description="The start of the range")
|
||||||
|
size: int = Field(default=1, description="The number of values")
|
||||||
|
step: int = Field(default=1, description="The step of the range")
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
|
||||||
|
return IntCollectionOutput(
|
||||||
|
collection=list(range(self.start, self.start + self.size, self.step))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RandomRangeInvocation(BaseInvocation):
|
class RandomRangeInvocation(BaseInvocation):
|
||||||
"""Creates a collection of random numbers"""
|
"""Creates a collection of random numbers"""
|
||||||
|
|
||||||
|
@ -118,7 +118,7 @@ class CompelInvocation(BaseInvocation):
|
|||||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||||
|
|
||||||
# TODO: hacky but works ;D maybe rename latents somehow?
|
# TODO: hacky but works ;D maybe rename latents somehow?
|
||||||
context.services.latents.set(conditioning_name, (c, ec))
|
context.services.latents.save(conditioning_name, (c, ec))
|
||||||
|
|
||||||
return CompelOutput(
|
return CompelOutput(
|
||||||
conditioning=ConditioningField(
|
conditioning=ConditioningField(
|
||||||
|
@ -7,9 +7,9 @@ import numpy
|
|||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageField, ImageType
|
from invokeai.app.models.image import ImageCategory, ImageField, ImageType
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||||
from .image import ImageOutput, build_image_output
|
from .image import ImageOutput
|
||||||
|
|
||||||
|
|
||||||
class CvInvocationConfig(BaseModel):
|
class CvInvocationConfig(BaseModel):
|
||||||
@ -26,24 +26,27 @@ class CvInvocationConfig(BaseModel):
|
|||||||
|
|
||||||
class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
|
class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
|
||||||
"""Simple inpaint using opencv."""
|
"""Simple inpaint using opencv."""
|
||||||
#fmt: off
|
|
||||||
|
# fmt: off
|
||||||
type: Literal["cv_inpaint"] = "cv_inpaint"
|
type: Literal["cv_inpaint"] = "cv_inpaint"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: ImageField = Field(default=None, description="The image to inpaint")
|
image: ImageField = Field(default=None, description="The image to inpaint")
|
||||||
mask: ImageField = Field(default=None, description="The mask to use when inpainting")
|
mask: ImageField = Field(default=None, description="The mask to use when inpainting")
|
||||||
#fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_type, self.image.image_name
|
||||||
)
|
)
|
||||||
mask = context.services.images.get(self.mask.image_type, self.mask.image_name)
|
mask = context.services.images.get_pil_image(
|
||||||
|
self.mask.image_type, self.mask.image_name
|
||||||
|
)
|
||||||
|
|
||||||
# Convert to cv image/mask
|
# Convert to cv image/mask
|
||||||
# TODO: consider making these utility functions
|
# TODO: consider making these utility functions
|
||||||
cv_image = cv.cvtColor(numpy.array(image.convert("RGB")), cv.COLOR_RGB2BGR)
|
cv_image = cv.cvtColor(numpy.array(image.convert("RGB")), cv.COLOR_RGB2BGR)
|
||||||
cv_mask = numpy.array(ImageOps.invert(mask))
|
cv_mask = numpy.array(ImageOps.invert(mask.convert("L")))
|
||||||
|
|
||||||
# Inpaint
|
# Inpaint
|
||||||
cv_inpainted = cv.inpaint(cv_image, cv_mask, 3, cv.INPAINT_TELEA)
|
cv_inpainted = cv.inpaint(cv_image, cv_mask, 3, cv.INPAINT_TELEA)
|
||||||
@ -52,18 +55,19 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
|
|||||||
# TODO: consider making a utility function
|
# TODO: consider making a utility function
|
||||||
image_inpainted = Image.fromarray(cv.cvtColor(cv_inpainted, cv.COLOR_BGR2RGB))
|
image_inpainted = Image.fromarray(cv.cvtColor(cv_inpainted, cv.COLOR_BGR2RGB))
|
||||||
|
|
||||||
image_type = ImageType.INTERMEDIATE
|
image_dto = context.services.images.create(
|
||||||
image_name = context.services.images.create_name(
|
image=image_inpainted,
|
||||||
context.graph_execution_state_id, self.id
|
image_type=ImageType.INTERMEDIATE,
|
||||||
|
image_category=ImageCategory.GENERAL,
|
||||||
|
node_id=self.id,
|
||||||
|
session_id=context.graph_execution_state_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata = context.services.metadata.build_metadata(
|
return ImageOutput(
|
||||||
session_id=context.graph_execution_state_id, node=self
|
image=ImageField(
|
||||||
|
image_name=image_dto.image_name,
|
||||||
|
image_type=image_dto.image_type,
|
||||||
|
),
|
||||||
|
width=image_dto.width,
|
||||||
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
|
|
||||||
context.services.images.save(image_type, image_name, image_inpainted, metadata)
|
|
||||||
return build_image_output(
|
|
||||||
image_type=image_type,
|
|
||||||
image_name=image_name,
|
|
||||||
image=image_inpainted,
|
|
||||||
)
|
|
@ -10,17 +10,21 @@ from pydantic import BaseModel, Field
|
|||||||
|
|
||||||
from invokeai.app.models.image import ColorField, ImageField, ImageType
|
from invokeai.app.models.image import ColorField, ImageField, ImageType
|
||||||
from invokeai.app.invocations.util.choose_model import choose_model
|
from invokeai.app.invocations.util.choose_model import choose_model
|
||||||
|
from invokeai.app.models.image import ImageCategory, ImageType
|
||||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
from invokeai.backend.generator.inpaint import infill_methods
|
from invokeai.backend.generator.inpaint import infill_methods
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||||
from .image import ImageOutput, build_image_output
|
from .image import ImageOutput
|
||||||
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
|
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
from ..util.step_callback import stable_diffusion_step_callback
|
from ..util.step_callback import stable_diffusion_step_callback
|
||||||
|
|
||||||
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
|
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
|
||||||
INFILL_METHODS = Literal[tuple(infill_methods())]
|
INFILL_METHODS = Literal[tuple(infill_methods())]
|
||||||
DEFAULT_INFILL_METHOD = 'patchmatch' if 'patchmatch' in get_args(INFILL_METHODS) else 'tile'
|
DEFAULT_INFILL_METHOD = (
|
||||||
|
"patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class SDImageInvocation(BaseModel):
|
class SDImageInvocation(BaseModel):
|
||||||
"""Helper class to provide all Stable Diffusion raster image invocations with additional config"""
|
"""Helper class to provide all Stable Diffusion raster image invocations with additional config"""
|
||||||
@ -91,25 +95,21 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
|
|||||||
# each time it is called. We only need the first one.
|
# each time it is called. We only need the first one.
|
||||||
generate_output = next(outputs)
|
generate_output = next(outputs)
|
||||||
|
|
||||||
# Results are image and seed, unwrap for now and ignore the seed
|
image_dto = context.services.images.create(
|
||||||
# TODO: pre-seed?
|
|
||||||
# TODO: can this return multiple results? Should it?
|
|
||||||
image_type = ImageType.RESULT
|
|
||||||
image_name = context.services.images.create_name(
|
|
||||||
context.graph_execution_state_id, self.id
|
|
||||||
)
|
|
||||||
|
|
||||||
metadata = context.services.metadata.build_metadata(
|
|
||||||
session_id=context.graph_execution_state_id, node=self
|
|
||||||
)
|
|
||||||
|
|
||||||
context.services.images.save(
|
|
||||||
image_type, image_name, generate_output.image, metadata
|
|
||||||
)
|
|
||||||
return build_image_output(
|
|
||||||
image_type=image_type,
|
|
||||||
image_name=image_name,
|
|
||||||
image=generate_output.image,
|
image=generate_output.image,
|
||||||
|
image_type=ImageType.RESULT,
|
||||||
|
image_category=ImageCategory.GENERAL,
|
||||||
|
session_id=context.graph_execution_state_id,
|
||||||
|
node_id=self.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ImageOutput(
|
||||||
|
image=ImageField(
|
||||||
|
image_name=image_dto.image_name,
|
||||||
|
image_type=image_dto.image_type,
|
||||||
|
),
|
||||||
|
width=image_dto.width,
|
||||||
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -145,7 +145,7 @@ class ImageToImageInvocation(TextToImageInvocation):
|
|||||||
image = (
|
image = (
|
||||||
None
|
None
|
||||||
if self.image is None
|
if self.image is None
|
||||||
else context.services.images.get(
|
else context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_type, self.image.image_name
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -175,26 +175,23 @@ class ImageToImageInvocation(TextToImageInvocation):
|
|||||||
# each time it is called. We only need the first one.
|
# each time it is called. We only need the first one.
|
||||||
generator_output = next(outputs)
|
generator_output = next(outputs)
|
||||||
|
|
||||||
result_image = generator_output.image
|
image_dto = context.services.images.create(
|
||||||
|
image=generator_output.image,
|
||||||
# Results are image and seed, unwrap for now and ignore the seed
|
image_type=ImageType.RESULT,
|
||||||
# TODO: pre-seed?
|
image_category=ImageCategory.GENERAL,
|
||||||
# TODO: can this return multiple results? Should it?
|
session_id=context.graph_execution_state_id,
|
||||||
image_type = ImageType.RESULT
|
node_id=self.id,
|
||||||
image_name = context.services.images.create_name(
|
|
||||||
context.graph_execution_state_id, self.id
|
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata = context.services.metadata.build_metadata(
|
return ImageOutput(
|
||||||
session_id=context.graph_execution_state_id, node=self
|
image=ImageField(
|
||||||
|
image_name=image_dto.image_name,
|
||||||
|
image_type=image_dto.image_type,
|
||||||
|
),
|
||||||
|
width=image_dto.width,
|
||||||
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
|
|
||||||
context.services.images.save(image_type, image_name, result_image, metadata)
|
|
||||||
return build_image_output(
|
|
||||||
image_type=image_type,
|
|
||||||
image_name=image_name,
|
|
||||||
image=result_image,
|
|
||||||
)
|
|
||||||
|
|
||||||
class InpaintInvocation(ImageToImageInvocation):
|
class InpaintInvocation(ImageToImageInvocation):
|
||||||
"""Generates an image using inpaint."""
|
"""Generates an image using inpaint."""
|
||||||
@ -204,16 +201,38 @@ class InpaintInvocation(ImageToImageInvocation):
|
|||||||
# Inputs
|
# Inputs
|
||||||
mask: Union[ImageField, None] = Field(description="The mask")
|
mask: Union[ImageField, None] = Field(description="The mask")
|
||||||
seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
|
seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
|
||||||
seam_blur: int = Field(default=16, ge=0, description="The seam inpaint blur radius (px)")
|
seam_blur: int = Field(
|
||||||
|
default=16, ge=0, description="The seam inpaint blur radius (px)"
|
||||||
|
)
|
||||||
seam_strength: float = Field(
|
seam_strength: float = Field(
|
||||||
default=0.75, gt=0, le=1, description="The seam inpaint strength"
|
default=0.75, gt=0, le=1, description="The seam inpaint strength"
|
||||||
)
|
)
|
||||||
seam_steps: int = Field(default=30, ge=1, description="The number of steps to use for seam inpaint")
|
seam_steps: int = Field(
|
||||||
tile_size: int = Field(default=32, ge=1, description="The tile infill method size (px)")
|
default=30, ge=1, description="The number of steps to use for seam inpaint"
|
||||||
infill_method: INFILL_METHODS = Field(default=DEFAULT_INFILL_METHOD, description="The method used to infill empty regions (px)")
|
)
|
||||||
inpaint_width: Optional[int] = Field(default=None, multiple_of=8, gt=0, description="The width of the inpaint region (px)")
|
tile_size: int = Field(
|
||||||
inpaint_height: Optional[int] = Field(default=None, multiple_of=8, gt=0, description="The height of the inpaint region (px)")
|
default=32, ge=1, description="The tile infill method size (px)"
|
||||||
inpaint_fill: Optional[ColorField] = Field(default=ColorField(r=127, g=127, b=127, a=255), description="The solid infill method color")
|
)
|
||||||
|
infill_method: INFILL_METHODS = Field(
|
||||||
|
default=DEFAULT_INFILL_METHOD,
|
||||||
|
description="The method used to infill empty regions (px)",
|
||||||
|
)
|
||||||
|
inpaint_width: Optional[int] = Field(
|
||||||
|
default=None,
|
||||||
|
multiple_of=8,
|
||||||
|
gt=0,
|
||||||
|
description="The width of the inpaint region (px)",
|
||||||
|
)
|
||||||
|
inpaint_height: Optional[int] = Field(
|
||||||
|
default=None,
|
||||||
|
multiple_of=8,
|
||||||
|
gt=0,
|
||||||
|
description="The height of the inpaint region (px)",
|
||||||
|
)
|
||||||
|
inpaint_fill: Optional[ColorField] = Field(
|
||||||
|
default=ColorField(r=127, g=127, b=127, a=255),
|
||||||
|
description="The solid infill method color",
|
||||||
|
)
|
||||||
inpaint_replace: float = Field(
|
inpaint_replace: float = Field(
|
||||||
default=0.0,
|
default=0.0,
|
||||||
ge=0.0,
|
ge=0.0,
|
||||||
@ -238,14 +257,14 @@ class InpaintInvocation(ImageToImageInvocation):
|
|||||||
image = (
|
image = (
|
||||||
None
|
None
|
||||||
if self.image is None
|
if self.image is None
|
||||||
else context.services.images.get(
|
else context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_type, self.image.image_name
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
mask = (
|
mask = (
|
||||||
None
|
None
|
||||||
if self.mask is None
|
if self.mask is None
|
||||||
else context.services.images.get(self.mask.image_type, self.mask.image_name)
|
else context.services.images.get_pil_image(self.mask.image_type, self.mask.image_name)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Handle invalid model parameter
|
# Handle invalid model parameter
|
||||||
@ -271,23 +290,19 @@ class InpaintInvocation(ImageToImageInvocation):
|
|||||||
# each time it is called. We only need the first one.
|
# each time it is called. We only need the first one.
|
||||||
generator_output = next(outputs)
|
generator_output = next(outputs)
|
||||||
|
|
||||||
result_image = generator_output.image
|
image_dto = context.services.images.create(
|
||||||
|
image=generator_output.image,
|
||||||
# Results are image and seed, unwrap for now and ignore the seed
|
image_type=ImageType.RESULT,
|
||||||
# TODO: pre-seed?
|
image_category=ImageCategory.GENERAL,
|
||||||
# TODO: can this return multiple results? Should it?
|
session_id=context.graph_execution_state_id,
|
||||||
image_type = ImageType.RESULT
|
node_id=self.id,
|
||||||
image_name = context.services.images.create_name(
|
|
||||||
context.graph_execution_state_id, self.id
|
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata = context.services.metadata.build_metadata(
|
return ImageOutput(
|
||||||
session_id=context.graph_execution_state_id, node=self
|
image=ImageField(
|
||||||
)
|
image_name=image_dto.image_name,
|
||||||
|
image_type=image_dto.image_type,
|
||||||
context.services.images.save(image_type, image_name, result_image, metadata)
|
),
|
||||||
return build_image_output(
|
width=image_dto.width,
|
||||||
image_type=image_type,
|
height=image_dto.height,
|
||||||
image_name=image_name,
|
|
||||||
image=result_image,
|
|
||||||
)
|
)
|
||||||
|
@ -1,13 +1,13 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
import io
|
import io
|
||||||
from typing import Literal, Optional
|
from typing import Literal, Optional, Union
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
from PIL import Image, ImageFilter, ImageOps
|
from PIL import Image, ImageFilter, ImageOps, ImageChops
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from ..models.image import ImageField, ImageType
|
from ..models.image import ImageCategory, ImageField, ImageType
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
@ -31,7 +31,7 @@ class ImageOutput(BaseInvocationOutput):
|
|||||||
"""Base class for invocations that output an image"""
|
"""Base class for invocations that output an image"""
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["image"] = "image"
|
type: Literal["image_output"] = "image_output"
|
||||||
image: ImageField = Field(default=None, description="The output image")
|
image: ImageField = Field(default=None, description="The output image")
|
||||||
width: int = Field(description="The width of the image in pixels")
|
width: int = Field(description="The width of the image in pixels")
|
||||||
height: int = Field(description="The height of the image in pixels")
|
height: int = Field(description="The height of the image in pixels")
|
||||||
@ -41,27 +41,14 @@ class ImageOutput(BaseInvocationOutput):
|
|||||||
schema_extra = {"required": ["type", "image", "width", "height"]}
|
schema_extra = {"required": ["type", "image", "width", "height"]}
|
||||||
|
|
||||||
|
|
||||||
def build_image_output(
|
|
||||||
image_type: ImageType, image_name: str, image: Image.Image
|
|
||||||
) -> ImageOutput:
|
|
||||||
"""Builds an ImageOutput and its ImageField"""
|
|
||||||
image_field = ImageField(
|
|
||||||
image_name=image_name,
|
|
||||||
image_type=image_type,
|
|
||||||
)
|
|
||||||
return ImageOutput(
|
|
||||||
image=image_field,
|
|
||||||
width=image.width,
|
|
||||||
height=image.height,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MaskOutput(BaseInvocationOutput):
|
class MaskOutput(BaseInvocationOutput):
|
||||||
"""Base class for invocations that output a mask"""
|
"""Base class for invocations that output a mask"""
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["mask"] = "mask"
|
type: Literal["mask"] = "mask"
|
||||||
mask: ImageField = Field(default=None, description="The output mask")
|
mask: ImageField = Field(default=None, description="The output mask")
|
||||||
|
width: int = Field(description="The width of the mask in pixels")
|
||||||
|
height: int = Field(description="The height of the mask in pixels")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
@ -80,16 +67,20 @@ class LoadImageInvocation(BaseInvocation):
|
|||||||
type: Literal["load_image"] = "load_image"
|
type: Literal["load_image"] = "load_image"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image_type: ImageType = Field(description="The type of the image")
|
image: Union[ImageField, None] = Field(
|
||||||
image_name: str = Field(description="The name of the image")
|
default=None, description="The image to load"
|
||||||
|
)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get(self.image_type, self.image_name)
|
image = context.services.images.get_pil_image(self.image.image_type, self.image.image_name)
|
||||||
|
|
||||||
return build_image_output(
|
return ImageOutput(
|
||||||
image_type=self.image_type,
|
image=ImageField(
|
||||||
image_name=self.image_name,
|
image_name=self.image.image_name,
|
||||||
image=image,
|
image_type=self.image.image_type,
|
||||||
|
),
|
||||||
|
width=image.width,
|
||||||
|
height=image.height,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -99,10 +90,12 @@ class ShowImageInvocation(BaseInvocation):
|
|||||||
type: Literal["show_image"] = "show_image"
|
type: Literal["show_image"] = "show_image"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: ImageField = Field(default=None, description="The image to show")
|
image: Union[ImageField, None] = Field(
|
||||||
|
default=None, description="The image to show"
|
||||||
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_type, self.image.image_name
|
||||||
)
|
)
|
||||||
if image:
|
if image:
|
||||||
@ -110,21 +103,24 @@ class ShowImageInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# TODO: how to handle failure?
|
# TODO: how to handle failure?
|
||||||
|
|
||||||
return build_image_output(
|
return ImageOutput(
|
||||||
image_type=self.image.image_type,
|
image=ImageField(
|
||||||
image_name=self.image.image_name,
|
image_name=self.image.image_name,
|
||||||
image=image,
|
image_type=self.image.image_type,
|
||||||
|
),
|
||||||
|
width=image.width,
|
||||||
|
height=image.height,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class CropImageInvocation(BaseInvocation, PILInvocationConfig):
|
class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
|
||||||
"""Crops an image to a specified box. The box can be outside of the image."""
|
"""Crops an image to a specified box. The box can be outside of the image."""
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["crop"] = "crop"
|
type: Literal["img_crop"] = "img_crop"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: ImageField = Field(default=None, description="The image to crop")
|
image: Union[ImageField, None] = Field(default=None, description="The image to crop")
|
||||||
x: int = Field(default=0, description="The left x coordinate of the crop rectangle")
|
x: int = Field(default=0, description="The left x coordinate of the crop rectangle")
|
||||||
y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
|
y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
|
||||||
width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
|
width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
|
||||||
@ -132,7 +128,7 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_type, self.image.image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -141,49 +137,52 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
image_crop.paste(image, (-self.x, -self.y))
|
image_crop.paste(image, (-self.x, -self.y))
|
||||||
|
|
||||||
image_type = ImageType.INTERMEDIATE
|
image_dto = context.services.images.create(
|
||||||
image_name = context.services.images.create_name(
|
|
||||||
context.graph_execution_state_id, self.id
|
|
||||||
)
|
|
||||||
|
|
||||||
metadata = context.services.metadata.build_metadata(
|
|
||||||
session_id=context.graph_execution_state_id, node=self
|
|
||||||
)
|
|
||||||
|
|
||||||
context.services.images.save(image_type, image_name, image_crop, metadata)
|
|
||||||
return build_image_output(
|
|
||||||
image_type=image_type,
|
|
||||||
image_name=image_name,
|
|
||||||
image=image_crop,
|
image=image_crop,
|
||||||
|
image_type=ImageType.RESULT,
|
||||||
|
image_category=ImageCategory.GENERAL,
|
||||||
|
node_id=self.id,
|
||||||
|
session_id=context.graph_execution_state_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ImageOutput(
|
||||||
|
image=ImageField(
|
||||||
|
image_name=image_dto.image_name,
|
||||||
|
image_type=image_dto.image_type,
|
||||||
|
),
|
||||||
|
width=image_dto.width,
|
||||||
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
|
class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
|
||||||
"""Pastes an image into another image."""
|
"""Pastes an image into another image."""
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["paste"] = "paste"
|
type: Literal["img_paste"] = "img_paste"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
base_image: ImageField = Field(default=None, description="The base image")
|
base_image: Union[ImageField, None] = Field(default=None, description="The base image")
|
||||||
image: ImageField = Field(default=None, description="The image to paste")
|
image: Union[ImageField, None] = Field(default=None, description="The image to paste")
|
||||||
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
|
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
|
||||||
x: int = Field(default=0, description="The left x coordinate at which to paste the image")
|
x: int = Field(default=0, description="The left x coordinate at which to paste the image")
|
||||||
y: int = Field(default=0, description="The top y coordinate at which to paste the image")
|
y: int = Field(default=0, description="The top y coordinate at which to paste the image")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
base_image = context.services.images.get(
|
base_image = context.services.images.get_pil_image(
|
||||||
self.base_image.image_type, self.base_image.image_name
|
self.base_image.image_type, self.base_image.image_name
|
||||||
)
|
)
|
||||||
image = context.services.images.get(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_type, self.image.image_name
|
||||||
)
|
)
|
||||||
mask = (
|
mask = (
|
||||||
None
|
None
|
||||||
if self.mask is None
|
if self.mask is None
|
||||||
else ImageOps.invert(
|
else ImageOps.invert(
|
||||||
context.services.images.get(self.mask.image_type, self.mask.image_name)
|
context.services.images.get_pil_image(
|
||||||
|
self.mask.image_type, self.mask.image_name
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
# TODO: probably shouldn't invert mask here... should user be required to do it?
|
# TODO: probably shouldn't invert mask here... should user be required to do it?
|
||||||
@ -199,20 +198,21 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
new_image.paste(base_image, (abs(min_x), abs(min_y)))
|
new_image.paste(base_image, (abs(min_x), abs(min_y)))
|
||||||
new_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask)
|
new_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask)
|
||||||
|
|
||||||
image_type = ImageType.RESULT
|
image_dto = context.services.images.create(
|
||||||
image_name = context.services.images.create_name(
|
|
||||||
context.graph_execution_state_id, self.id
|
|
||||||
)
|
|
||||||
|
|
||||||
metadata = context.services.metadata.build_metadata(
|
|
||||||
session_id=context.graph_execution_state_id, node=self
|
|
||||||
)
|
|
||||||
|
|
||||||
context.services.images.save(image_type, image_name, new_image, metadata)
|
|
||||||
return build_image_output(
|
|
||||||
image_type=image_type,
|
|
||||||
image_name=image_name,
|
|
||||||
image=new_image,
|
image=new_image,
|
||||||
|
image_type=ImageType.RESULT,
|
||||||
|
image_category=ImageCategory.GENERAL,
|
||||||
|
node_id=self.id,
|
||||||
|
session_id=context.graph_execution_state_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ImageOutput(
|
||||||
|
image=ImageField(
|
||||||
|
image_name=image_dto.image_name,
|
||||||
|
image_type=image_dto.image_type,
|
||||||
|
),
|
||||||
|
width=image_dto.width,
|
||||||
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -223,12 +223,12 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
type: Literal["tomask"] = "tomask"
|
type: Literal["tomask"] = "tomask"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: ImageField = Field(default=None, description="The image to create the mask from")
|
image: Union[ImageField, None] = Field(default=None, description="The image to create the mask from")
|
||||||
invert: bool = Field(default=False, description="Whether or not to invert the mask")
|
invert: bool = Field(default=False, description="Whether or not to invert the mask")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||||
image = context.services.images.get(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_type, self.image.image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -236,33 +236,151 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
if self.invert:
|
if self.invert:
|
||||||
image_mask = ImageOps.invert(image_mask)
|
image_mask = ImageOps.invert(image_mask)
|
||||||
|
|
||||||
image_type = ImageType.INTERMEDIATE
|
image_dto = context.services.images.create(
|
||||||
image_name = context.services.images.create_name(
|
image=image_mask,
|
||||||
context.graph_execution_state_id, self.id
|
image_type=ImageType.RESULT,
|
||||||
|
image_category=ImageCategory.MASK,
|
||||||
|
node_id=self.id,
|
||||||
|
session_id=context.graph_execution_state_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata = context.services.metadata.build_metadata(
|
return MaskOutput(
|
||||||
session_id=context.graph_execution_state_id, node=self
|
mask=ImageField(
|
||||||
|
image_type=image_dto.image_type, image_name=image_dto.image_name
|
||||||
|
),
|
||||||
|
width=image_dto.width,
|
||||||
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
|
|
||||||
context.services.images.save(image_type, image_name, image_mask, metadata)
|
|
||||||
return MaskOutput(mask=ImageField(image_type=image_type, image_name=image_name))
|
class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
|
||||||
|
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
type: Literal["img_mul"] = "img_mul"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
image1: Union[ImageField, None] = Field(default=None, description="The first image to multiply")
|
||||||
|
image2: Union[ImageField, None] = Field(default=None, description="The second image to multiply")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
|
image1 = context.services.images.get_pil_image(
|
||||||
|
self.image1.image_type, self.image1.image_name
|
||||||
|
)
|
||||||
|
image2 = context.services.images.get_pil_image(
|
||||||
|
self.image2.image_type, self.image2.image_name
|
||||||
|
)
|
||||||
|
|
||||||
|
multiply_image = ImageChops.multiply(image1, image2)
|
||||||
|
|
||||||
|
image_dto = context.services.images.create(
|
||||||
|
image=multiply_image,
|
||||||
|
image_type=ImageType.RESULT,
|
||||||
|
image_category=ImageCategory.GENERAL,
|
||||||
|
node_id=self.id,
|
||||||
|
session_id=context.graph_execution_state_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ImageOutput(
|
||||||
|
image=ImageField(
|
||||||
|
image_type=image_dto.image_type, image_name=image_dto.image_name
|
||||||
|
),
|
||||||
|
width=image_dto.width,
|
||||||
|
height=image_dto.height,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class BlurInvocation(BaseInvocation, PILInvocationConfig):
|
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
|
||||||
|
|
||||||
|
|
||||||
|
class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
|
||||||
|
"""Gets a channel from an image."""
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
type: Literal["img_chan"] = "img_chan"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
image: Union[ImageField, None] = Field(default=None, description="The image to get the channel from")
|
||||||
|
channel: IMAGE_CHANNELS = Field(default="A", description="The channel to get")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
|
image = context.services.images.get_pil_image(
|
||||||
|
self.image.image_type, self.image.image_name
|
||||||
|
)
|
||||||
|
|
||||||
|
channel_image = image.getchannel(self.channel)
|
||||||
|
|
||||||
|
image_dto = context.services.images.create(
|
||||||
|
image=channel_image,
|
||||||
|
image_type=ImageType.RESULT,
|
||||||
|
image_category=ImageCategory.GENERAL,
|
||||||
|
node_id=self.id,
|
||||||
|
session_id=context.graph_execution_state_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ImageOutput(
|
||||||
|
image=ImageField(
|
||||||
|
image_type=image_dto.image_type, image_name=image_dto.image_name
|
||||||
|
),
|
||||||
|
width=image_dto.width,
|
||||||
|
height=image_dto.height,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
|
||||||
|
|
||||||
|
|
||||||
|
class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
|
||||||
|
"""Converts an image to a different mode."""
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
type: Literal["img_conv"] = "img_conv"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
image: Union[ImageField, None] = Field(default=None, description="The image to convert")
|
||||||
|
mode: IMAGE_MODES = Field(default="L", description="The mode to convert to")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
|
image = context.services.images.get_pil_image(
|
||||||
|
self.image.image_type, self.image.image_name
|
||||||
|
)
|
||||||
|
|
||||||
|
converted_image = image.convert(self.mode)
|
||||||
|
|
||||||
|
image_dto = context.services.images.create(
|
||||||
|
image=converted_image,
|
||||||
|
image_type=ImageType.RESULT,
|
||||||
|
image_category=ImageCategory.GENERAL,
|
||||||
|
node_id=self.id,
|
||||||
|
session_id=context.graph_execution_state_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ImageOutput(
|
||||||
|
image=ImageField(
|
||||||
|
image_type=image_dto.image_type, image_name=image_dto.image_name
|
||||||
|
),
|
||||||
|
width=image_dto.width,
|
||||||
|
height=image_dto.height,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||||
"""Blurs an image"""
|
"""Blurs an image"""
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["blur"] = "blur"
|
type: Literal["img_blur"] = "img_blur"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: ImageField = Field(default=None, description="The image to blur")
|
image: Union[ImageField, None] = Field(default=None, description="The image to blur")
|
||||||
radius: float = Field(default=8.0, ge=0, description="The blur radius")
|
radius: float = Field(default=8.0, ge=0, description="The blur radius")
|
||||||
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
|
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_type, self.image.image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -273,35 +391,38 @@ class BlurInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
blur_image = image.filter(blur)
|
blur_image = image.filter(blur)
|
||||||
|
|
||||||
image_type = ImageType.INTERMEDIATE
|
image_dto = context.services.images.create(
|
||||||
image_name = context.services.images.create_name(
|
image=blur_image,
|
||||||
context.graph_execution_state_id, self.id
|
image_type=ImageType.RESULT,
|
||||||
|
image_category=ImageCategory.GENERAL,
|
||||||
|
node_id=self.id,
|
||||||
|
session_id=context.graph_execution_state_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata = context.services.metadata.build_metadata(
|
return ImageOutput(
|
||||||
session_id=context.graph_execution_state_id, node=self
|
image=ImageField(
|
||||||
)
|
image_name=image_dto.image_name,
|
||||||
|
image_type=image_dto.image_type,
|
||||||
context.services.images.save(image_type, image_name, blur_image, metadata)
|
),
|
||||||
return build_image_output(
|
width=image_dto.width,
|
||||||
image_type=image_type, image_name=image_name, image=blur_image
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class LerpInvocation(BaseInvocation, PILInvocationConfig):
|
class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||||
"""Linear interpolation of all pixels of an image"""
|
"""Linear interpolation of all pixels of an image"""
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["lerp"] = "lerp"
|
type: Literal["img_lerp"] = "img_lerp"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: ImageField = Field(default=None, description="The image to lerp")
|
image: Union[ImageField, None] = Field(default=None, description="The image to lerp")
|
||||||
min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
|
min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
|
||||||
max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
|
max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_type, self.image.image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -310,35 +431,38 @@ class LerpInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
lerp_image = Image.fromarray(numpy.uint8(image_arr))
|
lerp_image = Image.fromarray(numpy.uint8(image_arr))
|
||||||
|
|
||||||
image_type = ImageType.INTERMEDIATE
|
image_dto = context.services.images.create(
|
||||||
image_name = context.services.images.create_name(
|
image=lerp_image,
|
||||||
context.graph_execution_state_id, self.id
|
image_type=ImageType.RESULT,
|
||||||
|
image_category=ImageCategory.GENERAL,
|
||||||
|
node_id=self.id,
|
||||||
|
session_id=context.graph_execution_state_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata = context.services.metadata.build_metadata(
|
return ImageOutput(
|
||||||
session_id=context.graph_execution_state_id, node=self
|
image=ImageField(
|
||||||
)
|
image_name=image_dto.image_name,
|
||||||
|
image_type=image_dto.image_type,
|
||||||
context.services.images.save(image_type, image_name, lerp_image, metadata)
|
),
|
||||||
return build_image_output(
|
width=image_dto.width,
|
||||||
image_type=image_type, image_name=image_name, image=lerp_image
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||||
"""Inverse linear interpolation of all pixels of an image"""
|
"""Inverse linear interpolation of all pixels of an image"""
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["ilerp"] = "ilerp"
|
type: Literal["img_ilerp"] = "img_ilerp"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: ImageField = Field(default=None, description="The image to lerp")
|
image: Union[ImageField, None] = Field(default=None, description="The image to lerp")
|
||||||
min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
|
min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
|
||||||
max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
|
max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_type, self.image.image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -352,16 +476,19 @@ class InverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
ilerp_image = Image.fromarray(numpy.uint8(image_arr))
|
ilerp_image = Image.fromarray(numpy.uint8(image_arr))
|
||||||
|
|
||||||
image_type = ImageType.INTERMEDIATE
|
image_dto = context.services.images.create(
|
||||||
image_name = context.services.images.create_name(
|
image=ilerp_image,
|
||||||
context.graph_execution_state_id, self.id
|
image_type=ImageType.RESULT,
|
||||||
|
image_category=ImageCategory.GENERAL,
|
||||||
|
node_id=self.id,
|
||||||
|
session_id=context.graph_execution_state_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata = context.services.metadata.build_metadata(
|
return ImageOutput(
|
||||||
session_id=context.graph_execution_state_id, node=self
|
image=ImageField(
|
||||||
)
|
image_name=image_dto.image_name,
|
||||||
|
image_type=image_dto.image_type,
|
||||||
context.services.images.save(image_type, image_name, ilerp_image, metadata)
|
),
|
||||||
return build_image_output(
|
width=image_dto.width,
|
||||||
image_type=image_type, image_name=image_name, image=ilerp_image
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
|
@ -1,17 +1,17 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||||
|
|
||||||
from typing import Literal, Optional, Union, get_args
|
from typing import Literal, Union, get_args
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import math
|
import math
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from invokeai.app.invocations.image import ImageOutput, build_image_output
|
from invokeai.app.invocations.image import ImageOutput
|
||||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||||
|
|
||||||
from ..models.image import ColorField, ImageField, ImageType
|
from ..models.image import ColorField, ImageCategory, ImageField, ImageType
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
InvocationContext,
|
InvocationContext,
|
||||||
@ -125,36 +125,39 @@ class InfillColorInvocation(BaseInvocation):
|
|||||||
"""Infills transparent areas of an image with a solid color"""
|
"""Infills transparent areas of an image with a solid color"""
|
||||||
|
|
||||||
type: Literal["infill_rgba"] = "infill_rgba"
|
type: Literal["infill_rgba"] = "infill_rgba"
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to infill")
|
image: Union[ImageField, None] = Field(
|
||||||
color: Optional[ColorField] = Field(
|
default=None, description="The image to infill"
|
||||||
|
)
|
||||||
|
color: ColorField = Field(
|
||||||
default=ColorField(r=127, g=127, b=127, a=255),
|
default=ColorField(r=127, g=127, b=127, a=255),
|
||||||
description="The color to use to infill",
|
description="The color to use to infill",
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_type, self.image.image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
|
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
|
||||||
infilled = Image.alpha_composite(solid_bg, image)
|
infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
|
||||||
|
|
||||||
infilled.paste(image, (0, 0), image.split()[-1])
|
infilled.paste(image, (0, 0), image.split()[-1])
|
||||||
|
|
||||||
image_type = ImageType.RESULT
|
image_dto = context.services.images.create(
|
||||||
image_name = context.services.images.create_name(
|
image=infilled,
|
||||||
context.graph_execution_state_id, self.id
|
image_type=ImageType.RESULT,
|
||||||
|
image_category=ImageCategory.GENERAL,
|
||||||
|
node_id=self.id,
|
||||||
|
session_id=context.graph_execution_state_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata = context.services.metadata.build_metadata(
|
return ImageOutput(
|
||||||
session_id=context.graph_execution_state_id, node=self
|
image=ImageField(
|
||||||
)
|
image_name=image_dto.image_name,
|
||||||
|
image_type=image_dto.image_type,
|
||||||
context.services.images.save(image_type, image_name, infilled, metadata)
|
),
|
||||||
return build_image_output(
|
width=image_dto.width,
|
||||||
image_type=image_type,
|
height=image_dto.height,
|
||||||
image_name=image_name,
|
|
||||||
image=image,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -163,7 +166,9 @@ class InfillTileInvocation(BaseInvocation):
|
|||||||
|
|
||||||
type: Literal["infill_tile"] = "infill_tile"
|
type: Literal["infill_tile"] = "infill_tile"
|
||||||
|
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to infill")
|
image: Union[ImageField, None] = Field(
|
||||||
|
default=None, description="The image to infill"
|
||||||
|
)
|
||||||
tile_size: int = Field(default=32, ge=1, description="The tile size (px)")
|
tile_size: int = Field(default=32, ge=1, description="The tile size (px)")
|
||||||
seed: int = Field(
|
seed: int = Field(
|
||||||
ge=0,
|
ge=0,
|
||||||
@ -173,7 +178,7 @@ class InfillTileInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_type, self.image.image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -182,20 +187,21 @@ class InfillTileInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
infilled.paste(image, (0, 0), image.split()[-1])
|
infilled.paste(image, (0, 0), image.split()[-1])
|
||||||
|
|
||||||
image_type = ImageType.RESULT
|
image_dto = context.services.images.create(
|
||||||
image_name = context.services.images.create_name(
|
image=infilled,
|
||||||
context.graph_execution_state_id, self.id
|
image_type=ImageType.RESULT,
|
||||||
|
image_category=ImageCategory.GENERAL,
|
||||||
|
node_id=self.id,
|
||||||
|
session_id=context.graph_execution_state_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata = context.services.metadata.build_metadata(
|
return ImageOutput(
|
||||||
session_id=context.graph_execution_state_id, node=self
|
image=ImageField(
|
||||||
)
|
image_name=image_dto.image_name,
|
||||||
|
image_type=image_dto.image_type,
|
||||||
context.services.images.save(image_type, image_name, infilled, metadata)
|
),
|
||||||
return build_image_output(
|
width=image_dto.width,
|
||||||
image_type=image_type,
|
height=image_dto.height,
|
||||||
image_name=image_name,
|
|
||||||
image=image,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -204,10 +210,12 @@ class InfillPatchMatchInvocation(BaseInvocation):
|
|||||||
|
|
||||||
type: Literal["infill_patchmatch"] = "infill_patchmatch"
|
type: Literal["infill_patchmatch"] = "infill_patchmatch"
|
||||||
|
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to infill")
|
image: Union[ImageField, None] = Field(
|
||||||
|
default=None, description="The image to infill"
|
||||||
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_type, self.image.image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -216,18 +224,19 @@ class InfillPatchMatchInvocation(BaseInvocation):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("PatchMatch is not available on this system")
|
raise ValueError("PatchMatch is not available on this system")
|
||||||
|
|
||||||
image_type = ImageType.RESULT
|
image_dto = context.services.images.create(
|
||||||
image_name = context.services.images.create_name(
|
image=infilled,
|
||||||
context.graph_execution_state_id, self.id
|
image_type=ImageType.RESULT,
|
||||||
|
image_category=ImageCategory.GENERAL,
|
||||||
|
node_id=self.id,
|
||||||
|
session_id=context.graph_execution_state_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata = context.services.metadata.build_metadata(
|
return ImageOutput(
|
||||||
session_id=context.graph_execution_state_id, node=self
|
image=ImageField(
|
||||||
)
|
image_name=image_dto.image_name,
|
||||||
|
image_type=image_dto.image_type,
|
||||||
context.services.images.save(image_type, image_name, infilled, metadata)
|
),
|
||||||
return build_image_output(
|
width=image_dto.width,
|
||||||
image_type=image_type,
|
height=image_dto.height,
|
||||||
image_name=image_name,
|
|
||||||
image=image,
|
|
||||||
)
|
)
|
||||||
|
@ -3,10 +3,11 @@
|
|||||||
import random
|
import random
|
||||||
from typing import Literal, Optional, Union
|
from typing import Literal, Optional, Union
|
||||||
import einops
|
import einops
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, validator
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from invokeai.app.invocations.util.choose_model import choose_model
|
from invokeai.app.invocations.util.choose_model import choose_model
|
||||||
|
from invokeai.app.models.image import ImageCategory
|
||||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
|
|
||||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||||
@ -20,9 +21,9 @@ from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, Sta
|
|||||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from ..services.image_storage import ImageType
|
from ..services.image_file_storage import ImageType
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext
|
from .baseinvocation import BaseInvocation, InvocationContext
|
||||||
from .image import ImageField, ImageOutput, build_image_output
|
from .image import ImageField, ImageOutput
|
||||||
from .compel import ConditioningField
|
from .compel import ConditioningField
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||||
@ -139,12 +140,17 @@ class NoiseInvocation(BaseInvocation):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@validator("seed", pre=True)
|
||||||
|
def modulo_seed(cls, v):
|
||||||
|
"""Returns the seed modulo SEED_MAX to ensure it is within the valid range."""
|
||||||
|
return v % SEED_MAX
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> NoiseOutput:
|
def invoke(self, context: InvocationContext) -> NoiseOutput:
|
||||||
device = torch.device(choose_torch_device())
|
device = torch.device(choose_torch_device())
|
||||||
noise = get_noise(self.width, self.height, device, self.seed)
|
noise = get_noise(self.width, self.height, device, self.seed)
|
||||||
|
|
||||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||||
context.services.latents.set(name, noise)
|
context.services.latents.save(name, noise)
|
||||||
return build_noise_output(latents_name=name, latents=noise)
|
return build_noise_output(latents_name=name, latents=noise)
|
||||||
|
|
||||||
|
|
||||||
@ -163,8 +169,8 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
|
||||||
scheduler: SAMPLER_NAME_VALUES = Field(default="lms", description="The scheduler to use" )
|
scheduler: SAMPLER_NAME_VALUES = Field(default="lms", description="The scheduler to use" )
|
||||||
model: str = Field(default="", description="The model to use (currently ignored)")
|
model: str = Field(default="", description="The model to use (currently ignored)")
|
||||||
seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||||
seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
# Schema customisation
|
# Schema customisation
|
||||||
@ -199,17 +205,17 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
scheduler_name=self.scheduler
|
scheduler_name=self.scheduler
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(model, DiffusionPipeline):
|
# if isinstance(model, DiffusionPipeline):
|
||||||
for component in [model.unet, model.vae]:
|
# for component in [model.unet, model.vae]:
|
||||||
configure_model_padding(component,
|
# configure_model_padding(component,
|
||||||
self.seamless,
|
# self.seamless,
|
||||||
self.seamless_axes
|
# self.seamless_axes
|
||||||
)
|
# )
|
||||||
else:
|
# else:
|
||||||
configure_model_padding(model,
|
# configure_model_padding(model,
|
||||||
self.seamless,
|
# self.seamless,
|
||||||
self.seamless_axes
|
# self.seamless_axes
|
||||||
)
|
# )
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@ -260,7 +266,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||||
context.services.latents.set(name, result_latents)
|
context.services.latents.save(name, result_latents)
|
||||||
return build_latents_output(latents_name=name, latents=result_latents)
|
return build_latents_output(latents_name=name, latents=result_latents)
|
||||||
|
|
||||||
|
|
||||||
@ -319,7 +325,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
name = f'{context.graph_execution_state_id}__{self.id}'
|
||||||
context.services.latents.set(name, result_latents)
|
context.services.latents.save(name, result_latents)
|
||||||
return build_latents_output(latents_name=name, latents=result_latents)
|
return build_latents_output(latents_name=name, latents=result_latents)
|
||||||
|
|
||||||
|
|
||||||
@ -356,20 +362,23 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
np_image = model.decode_latents(latents)
|
np_image = model.decode_latents(latents)
|
||||||
image = model.numpy_to_pil(np_image)[0]
|
image = model.numpy_to_pil(np_image)[0]
|
||||||
|
|
||||||
image_type = ImageType.RESULT
|
|
||||||
image_name = context.services.images.create_name(
|
|
||||||
context.graph_execution_state_id, self.id
|
|
||||||
)
|
|
||||||
|
|
||||||
metadata = context.services.metadata.build_metadata(
|
|
||||||
session_id=context.graph_execution_state_id, node=self
|
|
||||||
)
|
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
context.services.images.save(image_type, image_name, image, metadata)
|
image_dto = context.services.images.create(
|
||||||
return build_image_output(
|
image=image,
|
||||||
image_type=image_type, image_name=image_name, image=image
|
image_type=ImageType.RESULT,
|
||||||
|
image_category=ImageCategory.GENERAL,
|
||||||
|
session_id=context.graph_execution_state_id,
|
||||||
|
node_id=self.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ImageOutput(
|
||||||
|
image=ImageField(
|
||||||
|
image_name=image_dto.image_name,
|
||||||
|
image_type=image_dto.image_type,
|
||||||
|
),
|
||||||
|
width=image_dto.width,
|
||||||
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -404,7 +413,7 @@ class ResizeLatentsInvocation(BaseInvocation):
|
|||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||||
context.services.latents.set(name, resized_latents)
|
context.services.latents.save(name, resized_latents)
|
||||||
return build_latents_output(latents_name=name, latents=resized_latents)
|
return build_latents_output(latents_name=name, latents=resized_latents)
|
||||||
|
|
||||||
|
|
||||||
@ -434,7 +443,7 @@ class ScaleLatentsInvocation(BaseInvocation):
|
|||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||||
context.services.latents.set(name, resized_latents)
|
context.services.latents.save(name, resized_latents)
|
||||||
return build_latents_output(latents_name=name, latents=resized_latents)
|
return build_latents_output(latents_name=name, latents=resized_latents)
|
||||||
|
|
||||||
|
|
||||||
@ -458,7 +467,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
image = context.services.images.get(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_type, self.image.image_name
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -478,5 +487,6 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||||
context.services.latents.set(name, latents)
|
context.services.latents.save(name, latents)
|
||||||
return build_latents_output(latents_name=name, latents=latents)
|
return build_latents_output(latents_name=name, latents=latents)
|
||||||
|
|
||||||
|
@ -2,21 +2,23 @@ from typing import Literal, Union
|
|||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageField, ImageType
|
from invokeai.app.models.image import ImageCategory, ImageField, ImageType
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||||
from .image import ImageOutput, build_image_output
|
from .image import ImageOutput
|
||||||
|
|
||||||
|
|
||||||
class RestoreFaceInvocation(BaseInvocation):
|
class RestoreFaceInvocation(BaseInvocation):
|
||||||
"""Restores faces in an image."""
|
"""Restores faces in an image."""
|
||||||
#fmt: off
|
|
||||||
|
# fmt: off
|
||||||
type: Literal["restore_face"] = "restore_face"
|
type: Literal["restore_face"] = "restore_face"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Union[ImageField, None] = Field(description="The input image")
|
image: Union[ImageField, None] = Field(description="The input image")
|
||||||
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" )
|
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" )
|
||||||
#fmt: on
|
# fmt: on
|
||||||
|
|
||||||
# Schema customisation
|
# Schema customisation
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
@ -26,7 +28,7 @@ class RestoreFaceInvocation(BaseInvocation):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_type, self.image.image_name
|
||||||
)
|
)
|
||||||
results = context.services.restoration.upscale_and_reconstruct(
|
results = context.services.restoration.upscale_and_reconstruct(
|
||||||
@ -39,18 +41,19 @@ class RestoreFaceInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# Results are image and seed, unwrap for now
|
# Results are image and seed, unwrap for now
|
||||||
# TODO: can this return multiple results?
|
# TODO: can this return multiple results?
|
||||||
image_type = ImageType.RESULT
|
image_dto = context.services.images.create(
|
||||||
image_name = context.services.images.create_name(
|
image=results[0][0],
|
||||||
context.graph_execution_state_id, self.id
|
image_type=ImageType.INTERMEDIATE,
|
||||||
|
image_category=ImageCategory.GENERAL,
|
||||||
|
node_id=self.id,
|
||||||
|
session_id=context.graph_execution_state_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata = context.services.metadata.build_metadata(
|
return ImageOutput(
|
||||||
session_id=context.graph_execution_state_id, node=self
|
image=ImageField(
|
||||||
|
image_name=image_dto.image_name,
|
||||||
|
image_type=image_dto.image_type,
|
||||||
|
),
|
||||||
|
width=image_dto.width,
|
||||||
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
|
|
||||||
context.services.images.save(image_type, image_name, results[0][0], metadata)
|
|
||||||
return build_image_output(
|
|
||||||
image_type=image_type,
|
|
||||||
image_name=image_name,
|
|
||||||
image=results[0][0]
|
|
||||||
)
|
|
@ -4,22 +4,22 @@ from typing import Literal, Union
|
|||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageField, ImageType
|
from invokeai.app.models.image import ImageCategory, ImageField, ImageType
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
|
||||||
from .image import ImageOutput, build_image_output
|
from .image import ImageOutput
|
||||||
|
|
||||||
|
|
||||||
class UpscaleInvocation(BaseInvocation):
|
class UpscaleInvocation(BaseInvocation):
|
||||||
"""Upscales an image."""
|
"""Upscales an image."""
|
||||||
#fmt: off
|
|
||||||
|
# fmt: off
|
||||||
type: Literal["upscale"] = "upscale"
|
type: Literal["upscale"] = "upscale"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Union[ImageField, None] = Field(description="The input image", default=None)
|
image: Union[ImageField, None] = Field(description="The input image", default=None)
|
||||||
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
|
strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
|
||||||
level: Literal[2, 4] = Field(default=2, description="The upscale level")
|
level: Literal[2, 4] = Field(default=2, description="The upscale level")
|
||||||
#fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
# Schema customisation
|
# Schema customisation
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
@ -30,7 +30,7 @@ class UpscaleInvocation(BaseInvocation):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get(
|
image = context.services.images.get_pil_image(
|
||||||
self.image.image_type, self.image.image_name
|
self.image.image_type, self.image.image_name
|
||||||
)
|
)
|
||||||
results = context.services.restoration.upscale_and_reconstruct(
|
results = context.services.restoration.upscale_and_reconstruct(
|
||||||
@ -43,18 +43,19 @@ class UpscaleInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# Results are image and seed, unwrap for now
|
# Results are image and seed, unwrap for now
|
||||||
# TODO: can this return multiple results?
|
# TODO: can this return multiple results?
|
||||||
image_type = ImageType.RESULT
|
image_dto = context.services.images.create(
|
||||||
image_name = context.services.images.create_name(
|
image=results[0][0],
|
||||||
context.graph_execution_state_id, self.id
|
image_type=ImageType.RESULT,
|
||||||
|
image_category=ImageCategory.GENERAL,
|
||||||
|
node_id=self.id,
|
||||||
|
session_id=context.graph_execution_state_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata = context.services.metadata.build_metadata(
|
return ImageOutput(
|
||||||
session_id=context.graph_execution_state_id, node=self
|
image=ImageField(
|
||||||
|
image_name=image_dto.image_name,
|
||||||
|
image_type=image_dto.image_type,
|
||||||
|
),
|
||||||
|
width=image_dto.width,
|
||||||
|
height=image_dto.height,
|
||||||
)
|
)
|
||||||
|
|
||||||
context.services.images.save(image_type, image_name, results[0][0], metadata)
|
|
||||||
return build_image_output(
|
|
||||||
image_type=image_type,
|
|
||||||
image_name=image_name,
|
|
||||||
image=results[0][0]
|
|
||||||
)
|
|
@ -2,19 +2,44 @@ from enum import Enum
|
|||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from invokeai.app.util.metaenum import MetaEnum
|
||||||
|
|
||||||
|
|
||||||
|
class ImageType(str, Enum, metaclass=MetaEnum):
|
||||||
|
"""The type of an image."""
|
||||||
|
|
||||||
class ImageType(str, Enum):
|
|
||||||
RESULT = "results"
|
RESULT = "results"
|
||||||
INTERMEDIATE = "intermediates"
|
|
||||||
UPLOAD = "uploads"
|
UPLOAD = "uploads"
|
||||||
|
INTERMEDIATE = "intermediates"
|
||||||
|
|
||||||
|
|
||||||
def is_image_type(obj):
|
class InvalidImageTypeException(ValueError):
|
||||||
try:
|
"""Raised when a provided value is not a valid ImageType.
|
||||||
ImageType(obj)
|
|
||||||
except ValueError:
|
Subclasses `ValueError`.
|
||||||
return False
|
"""
|
||||||
return True
|
|
||||||
|
def __init__(self, message="Invalid image type."):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageCategory(str, Enum, metaclass=MetaEnum):
|
||||||
|
"""The category of an image. Use ImageCategory.OTHER for non-default categories."""
|
||||||
|
|
||||||
|
GENERAL = "general"
|
||||||
|
CONTROL = "control"
|
||||||
|
MASK = "mask"
|
||||||
|
OTHER = "other"
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidImageCategoryException(ValueError):
|
||||||
|
"""Raised when a provided value is not a valid ImageCategory.
|
||||||
|
|
||||||
|
Subclasses `ValueError`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, message="Invalid image category."):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
class ImageField(BaseModel):
|
class ImageField(BaseModel):
|
||||||
|
91
invokeai/app/models/metadata.py
Normal file
91
invokeai/app/models/metadata.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
from typing import Optional
|
||||||
|
from pydantic import BaseModel, Extra, Field, StrictFloat, StrictInt, StrictStr
|
||||||
|
|
||||||
|
|
||||||
|
class ImageMetadata(BaseModel):
|
||||||
|
"""
|
||||||
|
Core generation metadata for an image/tensor generated in InvokeAI.
|
||||||
|
|
||||||
|
Also includes any metadata from the image's PNG tEXt chunks.
|
||||||
|
|
||||||
|
Generated by traversing the execution graph, collecting the parameters of the nearest ancestors
|
||||||
|
of a given node.
|
||||||
|
|
||||||
|
Full metadata may be accessed by querying for the session in the `graph_executions` table.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
extra = Extra.allow
|
||||||
|
"""
|
||||||
|
This lets the ImageMetadata class accept arbitrary additional fields. The CoreMetadataService
|
||||||
|
won't add any fields that are not already defined, but other a different metadata service
|
||||||
|
implementation might.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: Optional[StrictStr] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The type of the ancestor node of the image output node.",
|
||||||
|
)
|
||||||
|
"""The type of the ancestor node of the image output node."""
|
||||||
|
positive_conditioning: Optional[StrictStr] = Field(
|
||||||
|
default=None, description="The positive conditioning."
|
||||||
|
)
|
||||||
|
"""The positive conditioning"""
|
||||||
|
negative_conditioning: Optional[StrictStr] = Field(
|
||||||
|
default=None, description="The negative conditioning."
|
||||||
|
)
|
||||||
|
"""The negative conditioning"""
|
||||||
|
width: Optional[StrictInt] = Field(
|
||||||
|
default=None, description="Width of the image/latents in pixels."
|
||||||
|
)
|
||||||
|
"""Width of the image/latents in pixels"""
|
||||||
|
height: Optional[StrictInt] = Field(
|
||||||
|
default=None, description="Height of the image/latents in pixels."
|
||||||
|
)
|
||||||
|
"""Height of the image/latents in pixels"""
|
||||||
|
seed: Optional[StrictInt] = Field(
|
||||||
|
default=None, description="The seed used for noise generation."
|
||||||
|
)
|
||||||
|
"""The seed used for noise generation"""
|
||||||
|
cfg_scale: Optional[StrictFloat] = Field(
|
||||||
|
default=None, description="The classifier-free guidance scale."
|
||||||
|
)
|
||||||
|
"""The classifier-free guidance scale"""
|
||||||
|
steps: Optional[StrictInt] = Field(
|
||||||
|
default=None, description="The number of steps used for inference."
|
||||||
|
)
|
||||||
|
"""The number of steps used for inference"""
|
||||||
|
scheduler: Optional[StrictStr] = Field(
|
||||||
|
default=None, description="The scheduler used for inference."
|
||||||
|
)
|
||||||
|
"""The scheduler used for inference"""
|
||||||
|
model: Optional[StrictStr] = Field(
|
||||||
|
default=None, description="The model used for inference."
|
||||||
|
)
|
||||||
|
"""The model used for inference"""
|
||||||
|
strength: Optional[StrictFloat] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The strength used for image-to-image/latents-to-latents.",
|
||||||
|
)
|
||||||
|
"""The strength used for image-to-image/latents-to-latents."""
|
||||||
|
latents: Optional[StrictStr] = Field(
|
||||||
|
default=None, description="The ID of the initial latents."
|
||||||
|
)
|
||||||
|
"""The ID of the initial latents"""
|
||||||
|
vae: Optional[StrictStr] = Field(
|
||||||
|
default=None, description="The VAE used for decoding."
|
||||||
|
)
|
||||||
|
"""The VAE used for decoding"""
|
||||||
|
unet: Optional[StrictStr] = Field(
|
||||||
|
default=None, description="The UNet used dor inference."
|
||||||
|
)
|
||||||
|
"""The UNet used dor inference"""
|
||||||
|
clip: Optional[StrictStr] = Field(
|
||||||
|
default=None, description="The CLIP Encoder used for conditioning."
|
||||||
|
)
|
||||||
|
"""The CLIP Encoder used for conditioning"""
|
||||||
|
extra: Optional[StrictStr] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Uploaded image metadata, extracted from the PNG tEXt chunk.",
|
||||||
|
)
|
||||||
|
"""Uploaded image metadata, extracted from the PNG tEXt chunk."""
|
@ -353,6 +353,7 @@ setting environment variables INVOKEAI_<setting>.
|
|||||||
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
|
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
|
||||||
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
|
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
|
||||||
|
|
||||||
|
|
||||||
root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths')
|
root : Path = Field(default=_find_root(), description='InvokeAI runtime root directory', category='Paths')
|
||||||
autoconvert_dir : Path = Field(default=None, description='Path to a directory of ckpt files to be converted into diffusers and imported on startup.', category='Paths')
|
autoconvert_dir : Path = Field(default=None, description='Path to a directory of ckpt files to be converted into diffusers and imported on startup.', category='Paths')
|
||||||
conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
|
conf_path : Path = Field(default='configs/models.yaml', description='Path to models definition file', category='Paths')
|
||||||
@ -362,6 +363,7 @@ setting environment variables INVOKEAI_<setting>.
|
|||||||
lora_dir : Path = Field(default='loras', description='Path to InvokeAI LoRA model directory', category='Paths')
|
lora_dir : Path = Field(default='loras', description='Path to InvokeAI LoRA model directory', category='Paths')
|
||||||
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
|
outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths')
|
||||||
from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
|
from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths')
|
||||||
|
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths')
|
||||||
|
|
||||||
model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')
|
model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models')
|
||||||
embeddings : bool = Field(default=True, description='Load contents of embeddings directory', category='Models')
|
embeddings : bool = Field(default=True, description='Load contents of embeddings directory', category='Models')
|
||||||
@ -511,7 +513,7 @@ class PagingArgumentParser(argparse.ArgumentParser):
|
|||||||
text = self.format_help()
|
text = self.format_help()
|
||||||
pydoc.pager(text)
|
pydoc.pager(text)
|
||||||
|
|
||||||
def get_invokeai_config(cls:Type[InvokeAISettings]=InvokeAIAppConfig,**kwargs)->InvokeAISettings:
|
def get_invokeai_config(cls:Type[InvokeAISettings]=InvokeAIAppConfig,**kwargs)->InvokeAIAppConfig:
|
||||||
'''
|
'''
|
||||||
This returns a singleton InvokeAIAppConfig configuration object.
|
This returns a singleton InvokeAIAppConfig configuration object.
|
||||||
'''
|
'''
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any, Optional
|
||||||
from invokeai.app.api.models.images import ProgressImage
|
from invokeai.app.api.models.images import ProgressImage
|
||||||
from invokeai.app.util.misc import get_timestamp
|
from invokeai.app.util.misc import get_timestamp
|
||||||
|
|
||||||
|
@ -713,6 +713,13 @@ class Graph(BaseModel):
|
|||||||
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
|
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
|
||||||
return g
|
return g
|
||||||
|
|
||||||
|
def nx_graph_with_data(self) -> nx.DiGraph:
|
||||||
|
"""Returns a NetworkX DiGraph representing the data and layout of this graph"""
|
||||||
|
g = nx.DiGraph()
|
||||||
|
g.add_nodes_from([n for n in self.nodes.items()])
|
||||||
|
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
|
||||||
|
return g
|
||||||
|
|
||||||
def nx_graph_flat(
|
def nx_graph_flat(
|
||||||
self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None
|
self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None
|
||||||
) -> nx.DiGraph:
|
) -> nx.DiGraph:
|
||||||
|
204
invokeai/app/services/image_file_storage.py
Normal file
204
invokeai/app/services/image_file_storage.py
Normal file
@ -0,0 +1,204 @@
|
|||||||
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||||
|
import os
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from pathlib import Path
|
||||||
|
from queue import Queue
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from PIL.Image import Image as PILImageType
|
||||||
|
from PIL import Image, PngImagePlugin
|
||||||
|
from send2trash import send2trash
|
||||||
|
|
||||||
|
from invokeai.app.models.image import ImageType
|
||||||
|
from invokeai.app.models.metadata import ImageMetadata
|
||||||
|
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Should these excpetions subclass existing python exceptions?
|
||||||
|
class ImageFileNotFoundException(Exception):
|
||||||
|
"""Raised when an image file is not found in storage."""
|
||||||
|
|
||||||
|
def __init__(self, message="Image file not found"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageFileSaveException(Exception):
|
||||||
|
"""Raised when an image cannot be saved."""
|
||||||
|
|
||||||
|
def __init__(self, message="Image file not saved"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageFileDeleteException(Exception):
|
||||||
|
"""Raised when an image cannot be deleted."""
|
||||||
|
|
||||||
|
def __init__(self, message="Image file not deleted"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageFileStorageBase(ABC):
|
||||||
|
"""Low-level service responsible for storing and retrieving image files."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get(self, image_type: ImageType, image_name: str) -> PILImageType:
|
||||||
|
"""Retrieves an image as PIL Image."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_path(
|
||||||
|
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
||||||
|
) -> str:
|
||||||
|
"""Gets the internal path to an image or thumbnail."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
# TODO: We need to validate paths before starlette makes the FileResponse, else we get a
|
||||||
|
# 500 internal server error. I don't like having this method on the service.
|
||||||
|
@abstractmethod
|
||||||
|
def validate_path(self, path: str) -> bool:
|
||||||
|
"""Validates the path given for an image or thumbnail."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def save(
|
||||||
|
self,
|
||||||
|
image: PILImageType,
|
||||||
|
image_type: ImageType,
|
||||||
|
image_name: str,
|
||||||
|
metadata: Optional[ImageMetadata] = None,
|
||||||
|
thumbnail_size: int = 256,
|
||||||
|
) -> None:
|
||||||
|
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||||
|
"""Deletes an image and its thumbnail (if one exists)."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DiskImageFileStorage(ImageFileStorageBase):
|
||||||
|
"""Stores images on disk"""
|
||||||
|
|
||||||
|
__output_folder: str
|
||||||
|
__cache_ids: Queue # TODO: this is an incredibly naive cache
|
||||||
|
__cache: Dict[str, PILImageType]
|
||||||
|
__max_cache_size: int
|
||||||
|
|
||||||
|
def __init__(self, output_folder: str):
|
||||||
|
self.__output_folder = output_folder
|
||||||
|
self.__cache = dict()
|
||||||
|
self.__cache_ids = Queue()
|
||||||
|
self.__max_cache_size = 10 # TODO: get this from config
|
||||||
|
|
||||||
|
Path(output_folder).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# TODO: don't hard-code. get/save/delete should maybe take subpath?
|
||||||
|
for image_type in ImageType:
|
||||||
|
Path(os.path.join(output_folder, image_type)).mkdir(
|
||||||
|
parents=True, exist_ok=True
|
||||||
|
)
|
||||||
|
Path(os.path.join(output_folder, image_type, "thumbnails")).mkdir(
|
||||||
|
parents=True, exist_ok=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def get(self, image_type: ImageType, image_name: str) -> PILImageType:
|
||||||
|
try:
|
||||||
|
image_path = self.get_path(image_type, image_name)
|
||||||
|
cache_item = self.__get_cache(image_path)
|
||||||
|
if cache_item:
|
||||||
|
return cache_item
|
||||||
|
|
||||||
|
image = Image.open(image_path)
|
||||||
|
self.__set_cache(image_path, image)
|
||||||
|
return image
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
raise ImageFileNotFoundException from e
|
||||||
|
|
||||||
|
def save(
|
||||||
|
self,
|
||||||
|
image: PILImageType,
|
||||||
|
image_type: ImageType,
|
||||||
|
image_name: str,
|
||||||
|
metadata: Optional[ImageMetadata] = None,
|
||||||
|
thumbnail_size: int = 256,
|
||||||
|
) -> None:
|
||||||
|
try:
|
||||||
|
image_path = self.get_path(image_type, image_name)
|
||||||
|
|
||||||
|
if metadata is not None:
|
||||||
|
pnginfo = PngImagePlugin.PngInfo()
|
||||||
|
pnginfo.add_text("invokeai", metadata.json())
|
||||||
|
image.save(image_path, "PNG", pnginfo=pnginfo)
|
||||||
|
else:
|
||||||
|
image.save(image_path, "PNG")
|
||||||
|
|
||||||
|
thumbnail_name = get_thumbnail_name(image_name)
|
||||||
|
thumbnail_path = self.get_path(image_type, thumbnail_name, thumbnail=True)
|
||||||
|
thumbnail_image = make_thumbnail(image, thumbnail_size)
|
||||||
|
thumbnail_image.save(thumbnail_path)
|
||||||
|
|
||||||
|
self.__set_cache(image_path, image)
|
||||||
|
self.__set_cache(thumbnail_path, thumbnail_image)
|
||||||
|
except Exception as e:
|
||||||
|
raise ImageFileSaveException from e
|
||||||
|
|
||||||
|
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||||
|
try:
|
||||||
|
basename = os.path.basename(image_name)
|
||||||
|
image_path = self.get_path(image_type, basename)
|
||||||
|
|
||||||
|
if os.path.exists(image_path):
|
||||||
|
send2trash(image_path)
|
||||||
|
if image_path in self.__cache:
|
||||||
|
del self.__cache[image_path]
|
||||||
|
|
||||||
|
thumbnail_name = get_thumbnail_name(image_name)
|
||||||
|
thumbnail_path = self.get_path(image_type, thumbnail_name, True)
|
||||||
|
|
||||||
|
if os.path.exists(thumbnail_path):
|
||||||
|
send2trash(thumbnail_path)
|
||||||
|
if thumbnail_path in self.__cache:
|
||||||
|
del self.__cache[thumbnail_path]
|
||||||
|
except Exception as e:
|
||||||
|
raise ImageFileDeleteException from e
|
||||||
|
|
||||||
|
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||||
|
def get_path(
|
||||||
|
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
||||||
|
) -> str:
|
||||||
|
# strip out any relative path shenanigans
|
||||||
|
basename = os.path.basename(image_name)
|
||||||
|
|
||||||
|
if thumbnail:
|
||||||
|
thumbnail_name = get_thumbnail_name(basename)
|
||||||
|
path = os.path.join(
|
||||||
|
self.__output_folder, image_type, "thumbnails", thumbnail_name
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
path = os.path.join(self.__output_folder, image_type, basename)
|
||||||
|
|
||||||
|
abspath = os.path.abspath(path)
|
||||||
|
|
||||||
|
return abspath
|
||||||
|
|
||||||
|
def validate_path(self, path: str) -> bool:
|
||||||
|
"""Validates the path given for an image or thumbnail."""
|
||||||
|
try:
|
||||||
|
os.stat(path)
|
||||||
|
return True
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def __get_cache(self, image_name: str) -> PILImageType | None:
|
||||||
|
return None if image_name not in self.__cache else self.__cache[image_name]
|
||||||
|
|
||||||
|
def __set_cache(self, image_name: str, image: PILImageType):
|
||||||
|
if not image_name in self.__cache:
|
||||||
|
self.__cache[image_name] = image
|
||||||
|
self.__cache_ids.put(
|
||||||
|
image_name
|
||||||
|
) # TODO: this should refresh position for LRU cache
|
||||||
|
if len(self.__cache) > self.__max_cache_size:
|
||||||
|
cache_id = self.__cache_ids.get()
|
||||||
|
if cache_id in self.__cache:
|
||||||
|
del self.__cache[cache_id]
|
317
invokeai/app/services/image_record_storage.py
Normal file
317
invokeai/app/services/image_record_storage.py
Normal file
@ -0,0 +1,317 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, cast
|
||||||
|
import sqlite3
|
||||||
|
import threading
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
from invokeai.app.models.metadata import ImageMetadata
|
||||||
|
from invokeai.app.models.image import (
|
||||||
|
ImageCategory,
|
||||||
|
ImageType,
|
||||||
|
)
|
||||||
|
from invokeai.app.services.models.image_record import (
|
||||||
|
ImageRecord,
|
||||||
|
deserialize_image_record,
|
||||||
|
)
|
||||||
|
from invokeai.app.services.item_storage import PaginatedResults
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Should these excpetions subclass existing python exceptions?
|
||||||
|
class ImageRecordNotFoundException(Exception):
|
||||||
|
"""Raised when an image record is not found."""
|
||||||
|
|
||||||
|
def __init__(self, message="Image record not found"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageRecordSaveException(Exception):
|
||||||
|
"""Raised when an image record cannot be saved."""
|
||||||
|
|
||||||
|
def __init__(self, message="Image record not saved"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageRecordDeleteException(Exception):
|
||||||
|
"""Raised when an image record cannot be deleted."""
|
||||||
|
|
||||||
|
def __init__(self, message="Image record not deleted"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageRecordStorageBase(ABC):
|
||||||
|
"""Low-level service responsible for interfacing with the image record store."""
|
||||||
|
|
||||||
|
# TODO: Implement an `update()` method
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get(self, image_type: ImageType, image_name: str) -> ImageRecord:
|
||||||
|
"""Gets an image record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_many(
|
||||||
|
self,
|
||||||
|
image_type: ImageType,
|
||||||
|
image_category: ImageCategory,
|
||||||
|
page: int = 0,
|
||||||
|
per_page: int = 10,
|
||||||
|
) -> PaginatedResults[ImageRecord]:
|
||||||
|
"""Gets a page of image records."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
# TODO: The database has a nullable `deleted_at` column, currently unused.
|
||||||
|
# Should we implement soft deletes? Would need coordination with ImageFileStorage.
|
||||||
|
@abstractmethod
|
||||||
|
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||||
|
"""Deletes an image record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def save(
|
||||||
|
self,
|
||||||
|
image_name: str,
|
||||||
|
image_type: ImageType,
|
||||||
|
image_category: ImageCategory,
|
||||||
|
width: int,
|
||||||
|
height: int,
|
||||||
|
session_id: Optional[str],
|
||||||
|
node_id: Optional[str],
|
||||||
|
metadata: Optional[ImageMetadata],
|
||||||
|
) -> datetime:
|
||||||
|
"""Saves an image record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||||
|
_filename: str
|
||||||
|
_conn: sqlite3.Connection
|
||||||
|
_cursor: sqlite3.Cursor
|
||||||
|
_lock: threading.Lock
|
||||||
|
|
||||||
|
def __init__(self, filename: str) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self._filename = filename
|
||||||
|
self._conn = sqlite3.connect(filename, check_same_thread=False)
|
||||||
|
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||||
|
self._conn.row_factory = sqlite3.Row
|
||||||
|
self._cursor = self._conn.cursor()
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
# Enable foreign keys
|
||||||
|
self._conn.execute("PRAGMA foreign_keys = ON;")
|
||||||
|
self._create_tables()
|
||||||
|
self._conn.commit()
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
def _create_tables(self) -> None:
|
||||||
|
"""Creates the tables for the `images` database."""
|
||||||
|
|
||||||
|
# Create the `images` table.
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE TABLE IF NOT EXISTS images (
|
||||||
|
image_name TEXT NOT NULL PRIMARY KEY,
|
||||||
|
-- This is an enum in python, unrestricted string here for flexibility
|
||||||
|
image_type TEXT NOT NULL,
|
||||||
|
-- This is an enum in python, unrestricted string here for flexibility
|
||||||
|
image_category TEXT NOT NULL,
|
||||||
|
width INTEGER NOT NULL,
|
||||||
|
height INTEGER NOT NULL,
|
||||||
|
session_id TEXT,
|
||||||
|
node_id TEXT,
|
||||||
|
metadata TEXT,
|
||||||
|
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
-- Updated via trigger
|
||||||
|
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
-- Soft delete, currently unused
|
||||||
|
deleted_at DATETIME
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create the `images` table indices.
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE UNIQUE INDEX IF NOT EXISTS idx_images_image_name ON images(image_name);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_images_image_type ON images(image_type);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_images_created_at ON images(created_at);
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add trigger for `updated_at`.
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
CREATE TRIGGER IF NOT EXISTS tg_images_updated_at
|
||||||
|
AFTER UPDATE
|
||||||
|
ON images FOR EACH ROW
|
||||||
|
BEGIN
|
||||||
|
UPDATE images SET updated_at = current_timestamp
|
||||||
|
WHERE image_name = old.image_name;
|
||||||
|
END;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
def get(self, image_type: ImageType, image_name: str) -> Union[ImageRecord, None]:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
|
||||||
|
self._cursor.execute(
|
||||||
|
f"""--sql
|
||||||
|
SELECT * FROM images
|
||||||
|
WHERE image_name = ?;
|
||||||
|
""",
|
||||||
|
(image_name,),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise ImageRecordNotFoundException from e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
if not result:
|
||||||
|
raise ImageRecordNotFoundException
|
||||||
|
|
||||||
|
return deserialize_image_record(dict(result))
|
||||||
|
|
||||||
|
def get_many(
|
||||||
|
self,
|
||||||
|
image_type: ImageType,
|
||||||
|
image_category: ImageCategory,
|
||||||
|
page: int = 0,
|
||||||
|
per_page: int = 10,
|
||||||
|
) -> PaginatedResults[ImageRecord]:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
|
||||||
|
self._cursor.execute(
|
||||||
|
f"""--sql
|
||||||
|
SELECT * FROM images
|
||||||
|
WHERE image_type = ? AND image_category = ?
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
LIMIT ? OFFSET ?;
|
||||||
|
""",
|
||||||
|
(image_type.value, image_category.value, per_page, page * per_page),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||||
|
|
||||||
|
images = list(map(lambda r: deserialize_image_record(dict(r)), result))
|
||||||
|
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT count(*) FROM images
|
||||||
|
WHERE image_type = ? AND image_category = ?
|
||||||
|
""",
|
||||||
|
(image_type.value, image_category.value),
|
||||||
|
)
|
||||||
|
|
||||||
|
count = self._cursor.fetchone()[0]
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
pageCount = int(count / per_page) + 1
|
||||||
|
|
||||||
|
return PaginatedResults(
|
||||||
|
items=images, page=page, pages=pageCount, per_page=per_page, total=count
|
||||||
|
)
|
||||||
|
|
||||||
|
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||||
|
try:
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
DELETE FROM images
|
||||||
|
WHERE image_name = ?;
|
||||||
|
""",
|
||||||
|
(image_name,),
|
||||||
|
)
|
||||||
|
self._conn.commit()
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise ImageRecordDeleteException from e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
|
def save(
|
||||||
|
self,
|
||||||
|
image_name: str,
|
||||||
|
image_type: ImageType,
|
||||||
|
image_category: ImageCategory,
|
||||||
|
session_id: Optional[str],
|
||||||
|
width: int,
|
||||||
|
height: int,
|
||||||
|
node_id: Optional[str],
|
||||||
|
metadata: Optional[ImageMetadata],
|
||||||
|
) -> datetime:
|
||||||
|
try:
|
||||||
|
metadata_json = (
|
||||||
|
None if metadata is None else metadata.json(exclude_none=True)
|
||||||
|
)
|
||||||
|
self._lock.acquire()
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
INSERT OR IGNORE INTO images (
|
||||||
|
image_name,
|
||||||
|
image_type,
|
||||||
|
image_category,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
node_id,
|
||||||
|
session_id,
|
||||||
|
metadata
|
||||||
|
)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?);
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
image_name,
|
||||||
|
image_type.value,
|
||||||
|
image_category.value,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
node_id,
|
||||||
|
session_id,
|
||||||
|
metadata_json,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self._conn.commit()
|
||||||
|
|
||||||
|
self._cursor.execute(
|
||||||
|
"""--sql
|
||||||
|
SELECT created_at
|
||||||
|
FROM images
|
||||||
|
WHERE image_name = ?;
|
||||||
|
""",
|
||||||
|
(image_name,),
|
||||||
|
)
|
||||||
|
|
||||||
|
created_at = datetime.fromisoformat(self._cursor.fetchone()[0])
|
||||||
|
|
||||||
|
return created_at
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
self._conn.rollback()
|
||||||
|
raise ImageRecordSaveException from e
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
@ -1,274 +0,0 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
|
||||||
|
|
||||||
import os
|
|
||||||
from glob import glob
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from pathlib import Path
|
|
||||||
from queue import Queue
|
|
||||||
from typing import Dict, List
|
|
||||||
|
|
||||||
from PIL.Image import Image
|
|
||||||
import PIL.Image as PILImage
|
|
||||||
from send2trash import send2trash
|
|
||||||
from invokeai.app.api.models.images import (
|
|
||||||
ImageResponse,
|
|
||||||
ImageResponseMetadata,
|
|
||||||
SavedImage,
|
|
||||||
)
|
|
||||||
from invokeai.app.models.image import ImageType
|
|
||||||
from invokeai.app.services.metadata import (
|
|
||||||
InvokeAIMetadata,
|
|
||||||
MetadataServiceBase,
|
|
||||||
build_invokeai_metadata_pnginfo,
|
|
||||||
)
|
|
||||||
from invokeai.app.services.item_storage import PaginatedResults
|
|
||||||
from invokeai.app.util.misc import get_timestamp
|
|
||||||
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
|
||||||
|
|
||||||
|
|
||||||
class ImageStorageBase(ABC):
|
|
||||||
"""Responsible for storing and retrieving images."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get(self, image_type: ImageType, image_name: str) -> Image:
|
|
||||||
"""Retrieves an image as PIL Image."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def list(
|
|
||||||
self, image_type: ImageType, page: int = 0, per_page: int = 10
|
|
||||||
) -> PaginatedResults[ImageResponse]:
|
|
||||||
"""Gets a paginated list of images."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
|
||||||
@abstractmethod
|
|
||||||
def get_path(
|
|
||||||
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
|
||||||
) -> str:
|
|
||||||
"""Gets the internal path to an image or its thumbnail."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
|
||||||
@abstractmethod
|
|
||||||
def get_uri(
|
|
||||||
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
|
||||||
) -> str:
|
|
||||||
"""Gets the external URI to an image or its thumbnail."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
|
||||||
@abstractmethod
|
|
||||||
def validate_path(self, path: str) -> bool:
|
|
||||||
"""Validates an image path."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def save(
|
|
||||||
self,
|
|
||||||
image_type: ImageType,
|
|
||||||
image_name: str,
|
|
||||||
image: Image,
|
|
||||||
metadata: InvokeAIMetadata | None = None,
|
|
||||||
) -> SavedImage:
|
|
||||||
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
|
||||||
"""Deletes an image and its thumbnail (if one exists)."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def create_name(self, context_id: str, node_id: str) -> str:
|
|
||||||
"""Creates a unique contextual image filename."""
|
|
||||||
return f"{context_id}_{node_id}_{str(get_timestamp())}.png"
|
|
||||||
|
|
||||||
|
|
||||||
class DiskImageStorage(ImageStorageBase):
|
|
||||||
"""Stores images on disk"""
|
|
||||||
|
|
||||||
__output_folder: str
|
|
||||||
__cache_ids: Queue # TODO: this is an incredibly naive cache
|
|
||||||
__cache: Dict[str, Image]
|
|
||||||
__max_cache_size: int
|
|
||||||
__metadata_service: MetadataServiceBase
|
|
||||||
|
|
||||||
def __init__(self, output_folder: str, metadata_service: MetadataServiceBase):
|
|
||||||
self.__output_folder = output_folder
|
|
||||||
self.__cache = dict()
|
|
||||||
self.__cache_ids = Queue()
|
|
||||||
self.__max_cache_size = 10 # TODO: get this from config
|
|
||||||
self.__metadata_service = metadata_service
|
|
||||||
|
|
||||||
Path(output_folder).mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# TODO: don't hard-code. get/save/delete should maybe take subpath?
|
|
||||||
for image_type in ImageType:
|
|
||||||
Path(os.path.join(output_folder, image_type)).mkdir(
|
|
||||||
parents=True, exist_ok=True
|
|
||||||
)
|
|
||||||
Path(os.path.join(output_folder, image_type, "thumbnails")).mkdir(
|
|
||||||
parents=True, exist_ok=True
|
|
||||||
)
|
|
||||||
|
|
||||||
def list(
|
|
||||||
self, image_type: ImageType, page: int = 0, per_page: int = 10
|
|
||||||
) -> PaginatedResults[ImageResponse]:
|
|
||||||
dir_path = os.path.join(self.__output_folder, image_type)
|
|
||||||
image_paths = glob(f"{dir_path}/*.png")
|
|
||||||
count = len(image_paths)
|
|
||||||
|
|
||||||
sorted_image_paths = sorted(
|
|
||||||
glob(f"{dir_path}/*.png"), key=os.path.getctime, reverse=True
|
|
||||||
)
|
|
||||||
|
|
||||||
page_of_image_paths = sorted_image_paths[
|
|
||||||
page * per_page : (page + 1) * per_page
|
|
||||||
]
|
|
||||||
|
|
||||||
page_of_images: List[ImageResponse] = []
|
|
||||||
|
|
||||||
for path in page_of_image_paths:
|
|
||||||
filename = os.path.basename(path)
|
|
||||||
img = PILImage.open(path)
|
|
||||||
|
|
||||||
invokeai_metadata = self.__metadata_service.get_metadata(img)
|
|
||||||
|
|
||||||
page_of_images.append(
|
|
||||||
ImageResponse(
|
|
||||||
image_type=image_type.value,
|
|
||||||
image_name=filename,
|
|
||||||
# TODO: DiskImageStorage should not be building URLs...?
|
|
||||||
image_url=self.get_uri(image_type, filename),
|
|
||||||
thumbnail_url=self.get_uri(image_type, filename, True),
|
|
||||||
# TODO: Creation of this object should happen elsewhere (?), just making it fit here so it works
|
|
||||||
metadata=ImageResponseMetadata(
|
|
||||||
created=int(os.path.getctime(path)),
|
|
||||||
width=img.width,
|
|
||||||
height=img.height,
|
|
||||||
invokeai=invokeai_metadata,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
page_count_trunc = int(count / per_page)
|
|
||||||
page_count_mod = count % per_page
|
|
||||||
page_count = page_count_trunc if page_count_mod == 0 else page_count_trunc + 1
|
|
||||||
|
|
||||||
return PaginatedResults[ImageResponse](
|
|
||||||
items=page_of_images,
|
|
||||||
page=page,
|
|
||||||
pages=page_count,
|
|
||||||
per_page=per_page,
|
|
||||||
total=count,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get(self, image_type: ImageType, image_name: str) -> Image:
|
|
||||||
image_path = self.get_path(image_type, image_name)
|
|
||||||
cache_item = self.__get_cache(image_path)
|
|
||||||
if cache_item:
|
|
||||||
return cache_item
|
|
||||||
|
|
||||||
image = PILImage.open(image_path)
|
|
||||||
self.__set_cache(image_path, image)
|
|
||||||
return image
|
|
||||||
|
|
||||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
|
||||||
def get_path(
|
|
||||||
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
|
||||||
) -> str:
|
|
||||||
# strip out any relative path shenanigans
|
|
||||||
basename = os.path.basename(image_name)
|
|
||||||
|
|
||||||
if is_thumbnail:
|
|
||||||
path = os.path.join(
|
|
||||||
self.__output_folder, image_type, "thumbnails", basename
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
path = os.path.join(self.__output_folder, image_type, basename)
|
|
||||||
|
|
||||||
abspath = os.path.abspath(path)
|
|
||||||
|
|
||||||
return abspath
|
|
||||||
|
|
||||||
def get_uri(
|
|
||||||
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
|
||||||
) -> str:
|
|
||||||
# strip out any relative path shenanigans
|
|
||||||
basename = os.path.basename(image_name)
|
|
||||||
|
|
||||||
if is_thumbnail:
|
|
||||||
thumbnail_basename = get_thumbnail_name(basename)
|
|
||||||
uri = f"api/v1/images/{image_type.value}/thumbnails/{thumbnail_basename}"
|
|
||||||
else:
|
|
||||||
uri = f"api/v1/images/{image_type.value}/{basename}"
|
|
||||||
|
|
||||||
return uri
|
|
||||||
|
|
||||||
def validate_path(self, path: str) -> bool:
|
|
||||||
try:
|
|
||||||
os.stat(path)
|
|
||||||
return True
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def save(
|
|
||||||
self,
|
|
||||||
image_type: ImageType,
|
|
||||||
image_name: str,
|
|
||||||
image: Image,
|
|
||||||
metadata: InvokeAIMetadata | None = None,
|
|
||||||
) -> SavedImage:
|
|
||||||
image_path = self.get_path(image_type, image_name)
|
|
||||||
|
|
||||||
# TODO: Reading the image and then saving it strips the metadata...
|
|
||||||
if metadata:
|
|
||||||
pnginfo = build_invokeai_metadata_pnginfo(metadata=metadata)
|
|
||||||
image.save(image_path, "PNG", pnginfo=pnginfo)
|
|
||||||
else:
|
|
||||||
image.save(image_path) # this saved image has an empty info
|
|
||||||
|
|
||||||
thumbnail_name = get_thumbnail_name(image_name)
|
|
||||||
thumbnail_path = self.get_path(image_type, thumbnail_name, is_thumbnail=True)
|
|
||||||
thumbnail_image = make_thumbnail(image)
|
|
||||||
thumbnail_image.save(thumbnail_path)
|
|
||||||
|
|
||||||
self.__set_cache(image_path, image)
|
|
||||||
self.__set_cache(thumbnail_path, thumbnail_image)
|
|
||||||
|
|
||||||
return SavedImage(
|
|
||||||
image_name=image_name,
|
|
||||||
thumbnail_name=thumbnail_name,
|
|
||||||
created=int(os.path.getctime(image_path)),
|
|
||||||
)
|
|
||||||
|
|
||||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
|
||||||
basename = os.path.basename(image_name)
|
|
||||||
image_path = self.get_path(image_type, basename)
|
|
||||||
|
|
||||||
if os.path.exists(image_path):
|
|
||||||
send2trash(image_path)
|
|
||||||
if image_path in self.__cache:
|
|
||||||
del self.__cache[image_path]
|
|
||||||
|
|
||||||
thumbnail_name = get_thumbnail_name(image_name)
|
|
||||||
thumbnail_path = self.get_path(image_type, thumbnail_name, True)
|
|
||||||
|
|
||||||
if os.path.exists(thumbnail_path):
|
|
||||||
send2trash(thumbnail_path)
|
|
||||||
if thumbnail_path in self.__cache:
|
|
||||||
del self.__cache[thumbnail_path]
|
|
||||||
|
|
||||||
def __get_cache(self, image_name: str) -> Image | None:
|
|
||||||
return None if image_name not in self.__cache else self.__cache[image_name]
|
|
||||||
|
|
||||||
def __set_cache(self, image_name: str, image: Image):
|
|
||||||
if not image_name in self.__cache:
|
|
||||||
self.__cache[image_name] = image
|
|
||||||
self.__cache_ids.put(
|
|
||||||
image_name
|
|
||||||
) # TODO: this should refresh position for LRU cache
|
|
||||||
if len(self.__cache) > self.__max_cache_size:
|
|
||||||
cache_id = self.__cache_ids.get()
|
|
||||||
if cache_id in self.__cache:
|
|
||||||
del self.__cache[cache_id]
|
|
375
invokeai/app/services/images.py
Normal file
375
invokeai/app/services/images.py
Normal file
@ -0,0 +1,375 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from logging import Logger
|
||||||
|
from typing import Optional, TYPE_CHECKING, Union
|
||||||
|
import uuid
|
||||||
|
from PIL.Image import Image as PILImageType
|
||||||
|
|
||||||
|
from invokeai.app.models.image import (
|
||||||
|
ImageCategory,
|
||||||
|
ImageType,
|
||||||
|
InvalidImageCategoryException,
|
||||||
|
InvalidImageTypeException,
|
||||||
|
)
|
||||||
|
from invokeai.app.models.metadata import ImageMetadata
|
||||||
|
from invokeai.app.services.image_record_storage import (
|
||||||
|
ImageRecordDeleteException,
|
||||||
|
ImageRecordNotFoundException,
|
||||||
|
ImageRecordSaveException,
|
||||||
|
ImageRecordStorageBase,
|
||||||
|
)
|
||||||
|
from invokeai.app.services.models.image_record import (
|
||||||
|
ImageRecord,
|
||||||
|
ImageDTO,
|
||||||
|
image_record_to_dto,
|
||||||
|
)
|
||||||
|
from invokeai.app.services.image_file_storage import (
|
||||||
|
ImageFileDeleteException,
|
||||||
|
ImageFileNotFoundException,
|
||||||
|
ImageFileSaveException,
|
||||||
|
ImageFileStorageBase,
|
||||||
|
)
|
||||||
|
from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults
|
||||||
|
from invokeai.app.services.metadata import MetadataServiceBase
|
||||||
|
from invokeai.app.services.urls import UrlServiceBase
|
||||||
|
from invokeai.app.util.misc import get_iso_timestamp
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from invokeai.app.services.graph import GraphExecutionState
|
||||||
|
|
||||||
|
|
||||||
|
class ImageServiceABC(ABC):
|
||||||
|
"""High-level service for image management."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create(
|
||||||
|
self,
|
||||||
|
image: PILImageType,
|
||||||
|
image_type: ImageType,
|
||||||
|
image_category: ImageCategory,
|
||||||
|
node_id: Optional[str] = None,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
metadata: Optional[ImageMetadata] = None,
|
||||||
|
) -> ImageDTO:
|
||||||
|
"""Creates an image, storing the file and its metadata."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
|
||||||
|
"""Gets an image as a PIL image."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord:
|
||||||
|
"""Gets an image record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
|
||||||
|
"""Gets an image DTO."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_path(self, image_type: ImageType, image_name: str) -> str:
|
||||||
|
"""Gets an image's path."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def validate_path(self, path: str) -> bool:
|
||||||
|
"""Validates an image's path."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_url(
|
||||||
|
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
||||||
|
) -> str:
|
||||||
|
"""Gets an image's or thumbnail's URL."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_many(
|
||||||
|
self,
|
||||||
|
image_type: ImageType,
|
||||||
|
image_category: ImageCategory,
|
||||||
|
page: int = 0,
|
||||||
|
per_page: int = 10,
|
||||||
|
) -> PaginatedResults[ImageDTO]:
|
||||||
|
"""Gets a paginated list of image DTOs."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete(self, image_type: ImageType, image_name: str):
|
||||||
|
"""Deletes an image."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ImageServiceDependencies:
|
||||||
|
"""Service dependencies for the ImageService."""
|
||||||
|
|
||||||
|
records: ImageRecordStorageBase
|
||||||
|
files: ImageFileStorageBase
|
||||||
|
metadata: MetadataServiceBase
|
||||||
|
urls: UrlServiceBase
|
||||||
|
logger: Logger
|
||||||
|
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
image_record_storage: ImageRecordStorageBase,
|
||||||
|
image_file_storage: ImageFileStorageBase,
|
||||||
|
metadata: MetadataServiceBase,
|
||||||
|
url: UrlServiceBase,
|
||||||
|
logger: Logger,
|
||||||
|
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||||
|
):
|
||||||
|
self.records = image_record_storage
|
||||||
|
self.files = image_file_storage
|
||||||
|
self.metadata = metadata
|
||||||
|
self.urls = url
|
||||||
|
self.logger = logger
|
||||||
|
self.graph_execution_manager = graph_execution_manager
|
||||||
|
|
||||||
|
|
||||||
|
class ImageService(ImageServiceABC):
|
||||||
|
_services: ImageServiceDependencies
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
image_record_storage: ImageRecordStorageBase,
|
||||||
|
image_file_storage: ImageFileStorageBase,
|
||||||
|
metadata: MetadataServiceBase,
|
||||||
|
url: UrlServiceBase,
|
||||||
|
logger: Logger,
|
||||||
|
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||||
|
):
|
||||||
|
self._services = ImageServiceDependencies(
|
||||||
|
image_record_storage=image_record_storage,
|
||||||
|
image_file_storage=image_file_storage,
|
||||||
|
metadata=metadata,
|
||||||
|
url=url,
|
||||||
|
logger=logger,
|
||||||
|
graph_execution_manager=graph_execution_manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
def create(
|
||||||
|
self,
|
||||||
|
image: PILImageType,
|
||||||
|
image_type: ImageType,
|
||||||
|
image_category: ImageCategory,
|
||||||
|
node_id: Optional[str] = None,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
) -> ImageDTO:
|
||||||
|
if image_type not in ImageType:
|
||||||
|
raise InvalidImageTypeException
|
||||||
|
|
||||||
|
if image_category not in ImageCategory:
|
||||||
|
raise InvalidImageCategoryException
|
||||||
|
|
||||||
|
image_name = self._create_image_name(
|
||||||
|
image_type=image_type,
|
||||||
|
image_category=image_category,
|
||||||
|
node_id=node_id,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = self._get_metadata(session_id, node_id)
|
||||||
|
|
||||||
|
(width, height) = image.size
|
||||||
|
|
||||||
|
try:
|
||||||
|
# TODO: Consider using a transaction here to ensure consistency between storage and database
|
||||||
|
created_at = self._services.records.save(
|
||||||
|
# Non-nullable fields
|
||||||
|
image_name=image_name,
|
||||||
|
image_type=image_type,
|
||||||
|
image_category=image_category,
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
# Nullable fields
|
||||||
|
node_id=node_id,
|
||||||
|
session_id=session_id,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._services.files.save(
|
||||||
|
image_type=image_type,
|
||||||
|
image_name=image_name,
|
||||||
|
image=image,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
image_url = self._services.urls.get_image_url(image_type, image_name)
|
||||||
|
thumbnail_url = self._services.urls.get_image_url(
|
||||||
|
image_type, image_name, True
|
||||||
|
)
|
||||||
|
|
||||||
|
return ImageDTO(
|
||||||
|
# Non-nullable fields
|
||||||
|
image_name=image_name,
|
||||||
|
image_type=image_type,
|
||||||
|
image_category=image_category,
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
# Nullable fields
|
||||||
|
node_id=node_id,
|
||||||
|
session_id=session_id,
|
||||||
|
metadata=metadata,
|
||||||
|
# Meta fields
|
||||||
|
created_at=created_at,
|
||||||
|
updated_at=created_at, # this is always the same as the created_at at this time
|
||||||
|
deleted_at=None,
|
||||||
|
# Extra non-nullable fields for DTO
|
||||||
|
image_url=image_url,
|
||||||
|
thumbnail_url=thumbnail_url,
|
||||||
|
)
|
||||||
|
except ImageRecordSaveException:
|
||||||
|
self._services.logger.error("Failed to save image record")
|
||||||
|
raise
|
||||||
|
except ImageFileSaveException:
|
||||||
|
self._services.logger.error("Failed to save image file")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self._services.logger.error("Problem saving image record and file")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
|
||||||
|
try:
|
||||||
|
return self._services.files.get(image_type, image_name)
|
||||||
|
except ImageFileNotFoundException:
|
||||||
|
self._services.logger.error("Failed to get image file")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self._services.logger.error("Problem getting image file")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord:
|
||||||
|
try:
|
||||||
|
return self._services.records.get(image_type, image_name)
|
||||||
|
except ImageRecordNotFoundException:
|
||||||
|
self._services.logger.error("Image record not found")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self._services.logger.error("Problem getting image record")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
|
||||||
|
try:
|
||||||
|
image_record = self._services.records.get(image_type, image_name)
|
||||||
|
|
||||||
|
image_dto = image_record_to_dto(
|
||||||
|
image_record,
|
||||||
|
self._services.urls.get_image_url(image_type, image_name),
|
||||||
|
self._services.urls.get_image_url(image_type, image_name, True),
|
||||||
|
)
|
||||||
|
|
||||||
|
return image_dto
|
||||||
|
except ImageRecordNotFoundException:
|
||||||
|
self._services.logger.error("Image record not found")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self._services.logger.error("Problem getting image DTO")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def get_path(
|
||||||
|
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
||||||
|
) -> str:
|
||||||
|
try:
|
||||||
|
return self._services.files.get_path(image_type, image_name, thumbnail)
|
||||||
|
except Exception as e:
|
||||||
|
self._services.logger.error("Problem getting image path")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def validate_path(self, path: str) -> bool:
|
||||||
|
try:
|
||||||
|
return self._services.files.validate_path(path)
|
||||||
|
except Exception as e:
|
||||||
|
self._services.logger.error("Problem validating image path")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def get_url(
|
||||||
|
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
||||||
|
) -> str:
|
||||||
|
try:
|
||||||
|
return self._services.urls.get_image_url(image_type, image_name, thumbnail)
|
||||||
|
except Exception as e:
|
||||||
|
self._services.logger.error("Problem getting image path")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def get_many(
|
||||||
|
self,
|
||||||
|
image_type: ImageType,
|
||||||
|
image_category: ImageCategory,
|
||||||
|
page: int = 0,
|
||||||
|
per_page: int = 10,
|
||||||
|
) -> PaginatedResults[ImageDTO]:
|
||||||
|
try:
|
||||||
|
results = self._services.records.get_many(
|
||||||
|
image_type,
|
||||||
|
image_category,
|
||||||
|
page,
|
||||||
|
per_page,
|
||||||
|
)
|
||||||
|
|
||||||
|
image_dtos = list(
|
||||||
|
map(
|
||||||
|
lambda r: image_record_to_dto(
|
||||||
|
r,
|
||||||
|
self._services.urls.get_image_url(image_type, r.image_name),
|
||||||
|
self._services.urls.get_image_url(
|
||||||
|
image_type, r.image_name, True
|
||||||
|
),
|
||||||
|
),
|
||||||
|
results.items,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return PaginatedResults[ImageDTO](
|
||||||
|
items=image_dtos,
|
||||||
|
page=results.page,
|
||||||
|
pages=results.pages,
|
||||||
|
per_page=results.per_page,
|
||||||
|
total=results.total,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
self._services.logger.error("Problem getting paginated image DTOs")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def delete(self, image_type: ImageType, image_name: str):
|
||||||
|
try:
|
||||||
|
self._services.files.delete(image_type, image_name)
|
||||||
|
self._services.records.delete(image_type, image_name)
|
||||||
|
except ImageRecordDeleteException:
|
||||||
|
self._services.logger.error(f"Failed to delete image record")
|
||||||
|
raise
|
||||||
|
except ImageFileDeleteException:
|
||||||
|
self._services.logger.error(f"Failed to delete image file")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self._services.logger.error("Problem deleting image record and file")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def _create_image_name(
|
||||||
|
self,
|
||||||
|
image_type: ImageType,
|
||||||
|
image_category: ImageCategory,
|
||||||
|
node_id: Optional[str] = None,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
) -> str:
|
||||||
|
"""Create a unique image name."""
|
||||||
|
uuid_str = str(uuid.uuid4())
|
||||||
|
|
||||||
|
if node_id is not None and session_id is not None:
|
||||||
|
return f"{image_type.value}_{image_category.value}_{session_id}_{node_id}_{uuid_str}.png"
|
||||||
|
|
||||||
|
return f"{image_type.value}_{image_category.value}_{uuid_str}.png"
|
||||||
|
|
||||||
|
def _get_metadata(
|
||||||
|
self, session_id: Optional[str] = None, node_id: Optional[str] = None
|
||||||
|
) -> Union[ImageMetadata, None]:
|
||||||
|
"""Get the metadata for a node."""
|
||||||
|
metadata = None
|
||||||
|
|
||||||
|
if node_id is not None and session_id is not None:
|
||||||
|
session = self._services.graph_execution_manager.get(session_id)
|
||||||
|
metadata = self._services.metadata.create_image_metadata(session, node_id)
|
||||||
|
|
||||||
|
return metadata
|
@ -1,55 +1,57 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||||
|
from __future__ import annotations
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from typing import types
|
if TYPE_CHECKING:
|
||||||
from invokeai.app.services.metadata import MetadataServiceBase
|
from logging import Logger
|
||||||
from invokeai.backend import ModelManager
|
from invokeai.app.services.images import ImageService
|
||||||
|
from invokeai.backend import ModelManager
|
||||||
|
from invokeai.app.services.events import EventServiceBase
|
||||||
|
from invokeai.app.services.latent_storage import LatentsStorageBase
|
||||||
|
from invokeai.app.services.restoration_services import RestorationServices
|
||||||
|
from invokeai.app.services.invocation_queue import InvocationQueueABC
|
||||||
|
from invokeai.app.services.item_storage import ItemStorageABC
|
||||||
|
from invokeai.app.services.config import InvokeAISettings
|
||||||
|
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
|
||||||
|
from invokeai.app.services.invoker import InvocationProcessorABC
|
||||||
|
|
||||||
from .events import EventServiceBase
|
|
||||||
from .latent_storage import LatentsStorageBase
|
|
||||||
from .image_storage import ImageStorageBase
|
|
||||||
from .restoration_services import RestorationServices
|
|
||||||
from .invocation_queue import InvocationQueueABC
|
|
||||||
from .item_storage import ItemStorageABC
|
|
||||||
from .config import InvokeAISettings
|
|
||||||
|
|
||||||
class InvocationServices:
|
class InvocationServices:
|
||||||
"""Services that can be used by invocations"""
|
"""Services that can be used by invocations"""
|
||||||
|
|
||||||
events: EventServiceBase
|
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
|
||||||
latents: LatentsStorageBase
|
events: "EventServiceBase"
|
||||||
images: ImageStorageBase
|
latents: "LatentsStorageBase"
|
||||||
metadata: MetadataServiceBase
|
queue: "InvocationQueueABC"
|
||||||
queue: InvocationQueueABC
|
model_manager: "ModelManager"
|
||||||
model_manager: ModelManager
|
restoration: "RestorationServices"
|
||||||
restoration: RestorationServices
|
configuration: "InvokeAISettings"
|
||||||
configuration: InvokeAISettings
|
images: "ImageService"
|
||||||
|
|
||||||
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
|
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
|
||||||
graph_library: ItemStorageABC["LibraryGraph"]
|
graph_library: "ItemStorageABC"["LibraryGraph"]
|
||||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"]
|
||||||
processor: "InvocationProcessorABC"
|
processor: "InvocationProcessorABC"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_manager: ModelManager,
|
model_manager: "ModelManager",
|
||||||
events: EventServiceBase,
|
events: "EventServiceBase",
|
||||||
logger: types.ModuleType,
|
logger: "Logger",
|
||||||
latents: LatentsStorageBase,
|
latents: "LatentsStorageBase",
|
||||||
images: ImageStorageBase,
|
images: "ImageService",
|
||||||
metadata: MetadataServiceBase,
|
queue: "InvocationQueueABC",
|
||||||
queue: InvocationQueueABC,
|
graph_library: "ItemStorageABC"["LibraryGraph"],
|
||||||
graph_library: ItemStorageABC["LibraryGraph"],
|
graph_execution_manager: "ItemStorageABC"["GraphExecutionState"],
|
||||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
processor: "InvocationProcessorABC",
|
||||||
processor: "InvocationProcessorABC",
|
restoration: "RestorationServices",
|
||||||
restoration: RestorationServices,
|
configuration: "InvokeAISettings",
|
||||||
configuration: InvokeAISettings=None,
|
|
||||||
):
|
):
|
||||||
self.model_manager = model_manager
|
self.model_manager = model_manager
|
||||||
self.events = events
|
self.events = events
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.latents = latents
|
self.latents = latents
|
||||||
self.images = images
|
self.images = images
|
||||||
self.metadata = metadata
|
|
||||||
self.queue = queue
|
self.queue = queue
|
||||||
self.graph_library = graph_library
|
self.graph_library = graph_library
|
||||||
self.graph_execution_manager = graph_execution_manager
|
self.graph_execution_manager = graph_execution_manager
|
||||||
|
@ -16,7 +16,7 @@ class LatentsStorageBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def set(self, name: str, data: torch.Tensor) -> None:
|
def save(self, name: str, data: torch.Tensor) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -47,8 +47,8 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
|
|||||||
self.__set_cache(name, latent)
|
self.__set_cache(name, latent)
|
||||||
return latent
|
return latent
|
||||||
|
|
||||||
def set(self, name: str, data: torch.Tensor) -> None:
|
def save(self, name: str, data: torch.Tensor) -> None:
|
||||||
self.__underlying_storage.set(name, data)
|
self.__underlying_storage.save(name, data)
|
||||||
self.__set_cache(name, data)
|
self.__set_cache(name, data)
|
||||||
|
|
||||||
def delete(self, name: str) -> None:
|
def delete(self, name: str) -> None:
|
||||||
@ -80,7 +80,7 @@ class DiskLatentsStorage(LatentsStorageBase):
|
|||||||
latent_path = self.get_path(name)
|
latent_path = self.get_path(name)
|
||||||
return torch.load(latent_path)
|
return torch.load(latent_path)
|
||||||
|
|
||||||
def set(self, name: str, data: torch.Tensor) -> None:
|
def save(self, name: str, data: torch.Tensor) -> None:
|
||||||
latent_path = self.get_path(name)
|
latent_path = self.get_path(name)
|
||||||
torch.save(data, latent_path)
|
torch.save(data, latent_path)
|
||||||
|
|
||||||
|
@ -1,105 +1,142 @@
|
|||||||
import json
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Dict, Optional, TypedDict
|
from typing import Any, Union
|
||||||
from PIL import Image, PngImagePlugin
|
import networkx as nx
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from invokeai.app.models.image import ImageType, is_image_type
|
from invokeai.app.models.metadata import ImageMetadata
|
||||||
|
from invokeai.app.services.graph import Graph, GraphExecutionState
|
||||||
|
|
||||||
class MetadataImageField(TypedDict):
|
|
||||||
"""Pydantic-less ImageField, used for metadata parsing."""
|
|
||||||
|
|
||||||
image_type: ImageType
|
|
||||||
image_name: str
|
|
||||||
|
|
||||||
|
|
||||||
class MetadataLatentsField(TypedDict):
|
|
||||||
"""Pydantic-less LatentsField, used for metadata parsing."""
|
|
||||||
|
|
||||||
latents_name: str
|
|
||||||
|
|
||||||
|
|
||||||
class MetadataColorField(TypedDict):
|
|
||||||
"""Pydantic-less ColorField, used for metadata parsing"""
|
|
||||||
r: int
|
|
||||||
g: int
|
|
||||||
b: int
|
|
||||||
a: int
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: This is a placeholder for `InvocationsUnion` pending resolution of circular imports
|
|
||||||
NodeMetadata = Dict[
|
|
||||||
str, None | str | int | float | bool | MetadataImageField | MetadataLatentsField | MetadataColorField
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class InvokeAIMetadata(TypedDict, total=False):
|
|
||||||
"""InvokeAI-specific metadata format."""
|
|
||||||
|
|
||||||
session_id: Optional[str]
|
|
||||||
node: Optional[NodeMetadata]
|
|
||||||
|
|
||||||
|
|
||||||
def build_invokeai_metadata_pnginfo(
|
|
||||||
metadata: InvokeAIMetadata | None,
|
|
||||||
) -> PngImagePlugin.PngInfo:
|
|
||||||
"""Builds a PngInfo object with key `"invokeai"` and value `metadata`"""
|
|
||||||
pnginfo = PngImagePlugin.PngInfo()
|
|
||||||
|
|
||||||
if metadata is not None:
|
|
||||||
pnginfo.add_text("invokeai", json.dumps(metadata))
|
|
||||||
|
|
||||||
return pnginfo
|
|
||||||
|
|
||||||
|
|
||||||
class MetadataServiceBase(ABC):
|
class MetadataServiceBase(ABC):
|
||||||
@abstractmethod
|
"""Handles building metadata for nodes, images, and outputs."""
|
||||||
def get_metadata(self, image: Image.Image) -> InvokeAIMetadata | None:
|
|
||||||
"""Gets the InvokeAI metadata from a PIL Image, skipping invalid values"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def build_metadata(
|
def create_image_metadata(
|
||||||
self, session_id: str, node: BaseModel
|
self, session: GraphExecutionState, node_id: str
|
||||||
) -> InvokeAIMetadata | None:
|
) -> ImageMetadata:
|
||||||
"""Builds an InvokeAIMetadata object"""
|
"""Builds an ImageMetadata object for a node."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class PngMetadataService(MetadataServiceBase):
|
class CoreMetadataService(MetadataServiceBase):
|
||||||
"""Handles loading and building metadata for images."""
|
_ANCESTOR_TYPES = ["t2l", "l2l"]
|
||||||
|
"""The ancestor types that contain the core metadata"""
|
||||||
|
|
||||||
# TODO: Use `InvocationsUnion` to **validate** metadata as representing a fully-functioning node
|
_ANCESTOR_PARAMS = ["type", "steps", "model", "cfg_scale", "scheduler", "strength"]
|
||||||
def _load_metadata(self, image: Image.Image) -> dict | None:
|
"""The core metadata parameters in the ancestor types"""
|
||||||
"""Loads a specific info entry from a PIL Image."""
|
|
||||||
|
|
||||||
try:
|
_NOISE_FIELDS = ["seed", "width", "height"]
|
||||||
info = image.info.get("invokeai")
|
"""The core metadata parameters in the noise node"""
|
||||||
|
|
||||||
if type(info) is not str:
|
def create_image_metadata(
|
||||||
return None
|
self, session: GraphExecutionState, node_id: str
|
||||||
|
) -> ImageMetadata:
|
||||||
loaded_metadata = json.loads(info)
|
metadata = self._build_metadata_from_graph(session, node_id)
|
||||||
|
|
||||||
if type(loaded_metadata) is not dict:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if len(loaded_metadata.items()) == 0:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return loaded_metadata
|
|
||||||
except:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_metadata(self, image: Image.Image) -> dict | None:
|
|
||||||
"""Retrieves an image's metadata as a dict"""
|
|
||||||
loaded_metadata = self._load_metadata(image)
|
|
||||||
|
|
||||||
return loaded_metadata
|
|
||||||
|
|
||||||
def build_metadata(self, session_id: str, node: BaseModel) -> InvokeAIMetadata:
|
|
||||||
metadata = InvokeAIMetadata(session_id=session_id, node=node.dict())
|
|
||||||
|
|
||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
|
def _find_nearest_ancestor(self, G: nx.DiGraph, node_id: str) -> Union[str, None]:
|
||||||
|
"""
|
||||||
|
Finds the id of the nearest ancestor (of a valid type) of a given node.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
G (nx.DiGraph): The execution graph, converted in to a networkx DiGraph. Its nodes must
|
||||||
|
have the same data as the execution graph.
|
||||||
|
node_id (str): The ID of the node.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str | None: The ID of the nearest ancestor, or None if there are no valid ancestors.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Retrieve the node from the graph
|
||||||
|
node = G.nodes[node_id]
|
||||||
|
|
||||||
|
# If the node type is one of the core metadata node types, return its id
|
||||||
|
if node.get("type") in self._ANCESTOR_TYPES:
|
||||||
|
return node.get("id")
|
||||||
|
|
||||||
|
# Else, look for the ancestor in the predecessor nodes
|
||||||
|
for predecessor in G.predecessors(node_id):
|
||||||
|
result = self._find_nearest_ancestor(G, predecessor)
|
||||||
|
if result:
|
||||||
|
return result
|
||||||
|
|
||||||
|
# If there are no valid ancestors, return None
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_additional_metadata(
|
||||||
|
self, graph: Graph, node_id: str
|
||||||
|
) -> Union[dict[str, Any], None]:
|
||||||
|
"""
|
||||||
|
Returns additional metadata for a given node.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
graph (Graph): The execution graph.
|
||||||
|
node_id (str): The ID of the node.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict[str, Any] | None: A dictionary of additional metadata.
|
||||||
|
"""
|
||||||
|
|
||||||
|
metadata = {}
|
||||||
|
|
||||||
|
# Iterate over all edges in the graph
|
||||||
|
for edge in graph.edges:
|
||||||
|
dest_node_id = edge.destination.node_id
|
||||||
|
dest_field = edge.destination.field
|
||||||
|
source_node_dict = graph.nodes[edge.source.node_id].dict()
|
||||||
|
|
||||||
|
# If the destination node ID matches the given node ID, gather necessary metadata
|
||||||
|
if dest_node_id == node_id:
|
||||||
|
# Prompt
|
||||||
|
if dest_field == "positive_conditioning":
|
||||||
|
metadata["positive_conditioning"] = source_node_dict.get("prompt")
|
||||||
|
# Negative prompt
|
||||||
|
if dest_field == "negative_conditioning":
|
||||||
|
metadata["negative_conditioning"] = source_node_dict.get("prompt")
|
||||||
|
# Seed, width and height
|
||||||
|
if dest_field == "noise":
|
||||||
|
for field in self._NOISE_FIELDS:
|
||||||
|
metadata[field] = source_node_dict.get(field)
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
def _build_metadata_from_graph(
|
||||||
|
self, session: GraphExecutionState, node_id: str
|
||||||
|
) -> ImageMetadata:
|
||||||
|
"""
|
||||||
|
Builds an ImageMetadata object for a node.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
session (GraphExecutionState): The session.
|
||||||
|
node_id (str): The ID of the node.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ImageMetadata: The metadata for the node.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# We need to do all the traversal on the execution graph
|
||||||
|
graph = session.execution_graph
|
||||||
|
|
||||||
|
# Find the nearest `t2l`/`l2l` ancestor of the given node
|
||||||
|
ancestor_id = self._find_nearest_ancestor(graph.nx_graph_with_data(), node_id)
|
||||||
|
|
||||||
|
# If no ancestor was found, return an empty ImageMetadata object
|
||||||
|
if ancestor_id is None:
|
||||||
|
return ImageMetadata()
|
||||||
|
|
||||||
|
ancestor_node = graph.get_node(ancestor_id)
|
||||||
|
|
||||||
|
# Grab all the core metadata from the ancestor node
|
||||||
|
ancestor_metadata = {
|
||||||
|
param: val
|
||||||
|
for param, val in ancestor_node.dict().items()
|
||||||
|
if param in self._ANCESTOR_PARAMS
|
||||||
|
}
|
||||||
|
|
||||||
|
# Get this image's prompts and noise parameters
|
||||||
|
addl_metadata = self._get_additional_metadata(graph, ancestor_id)
|
||||||
|
|
||||||
|
# If additional metadata was found, add it to the main metadata
|
||||||
|
if addl_metadata is not None:
|
||||||
|
ancestor_metadata.update(addl_metadata)
|
||||||
|
|
||||||
|
return ImageMetadata(**ancestor_metadata)
|
||||||
|
118
invokeai/app/services/models/image_record.py
Normal file
118
invokeai/app/services/models/image_record.py
Normal file
@ -0,0 +1,118 @@
|
|||||||
|
import datetime
|
||||||
|
from typing import Optional, Union
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from invokeai.app.models.image import ImageCategory, ImageType
|
||||||
|
from invokeai.app.models.metadata import ImageMetadata
|
||||||
|
from invokeai.app.util.misc import get_iso_timestamp
|
||||||
|
|
||||||
|
|
||||||
|
class ImageRecord(BaseModel):
|
||||||
|
"""Deserialized image record."""
|
||||||
|
|
||||||
|
image_name: str = Field(description="The unique name of the image.")
|
||||||
|
"""The unique name of the image."""
|
||||||
|
image_type: ImageType = Field(description="The type of the image.")
|
||||||
|
"""The type of the image."""
|
||||||
|
image_category: ImageCategory = Field(description="The category of the image.")
|
||||||
|
"""The category of the image."""
|
||||||
|
width: int = Field(description="The width of the image in px.")
|
||||||
|
"""The actual width of the image in px. This may be different from the width in metadata."""
|
||||||
|
height: int = Field(description="The height of the image in px.")
|
||||||
|
"""The actual height of the image in px. This may be different from the height in metadata."""
|
||||||
|
created_at: Union[datetime.datetime, str] = Field(
|
||||||
|
description="The created timestamp of the image."
|
||||||
|
)
|
||||||
|
"""The created timestamp of the image."""
|
||||||
|
updated_at: Union[datetime.datetime, str] = Field(
|
||||||
|
description="The updated timestamp of the image."
|
||||||
|
)
|
||||||
|
"""The updated timestamp of the image."""
|
||||||
|
deleted_at: Union[datetime.datetime, str, None] = Field(
|
||||||
|
description="The deleted timestamp of the image."
|
||||||
|
)
|
||||||
|
"""The deleted timestamp of the image."""
|
||||||
|
session_id: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The session ID that generated this image, if it is a generated image.",
|
||||||
|
)
|
||||||
|
"""The session ID that generated this image, if it is a generated image."""
|
||||||
|
node_id: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The node ID that generated this image, if it is a generated image.",
|
||||||
|
)
|
||||||
|
"""The node ID that generated this image, if it is a generated image."""
|
||||||
|
metadata: Optional[ImageMetadata] = Field(
|
||||||
|
default=None,
|
||||||
|
description="A limited subset of the image's generation metadata. Retrieve the image's session for full metadata.",
|
||||||
|
)
|
||||||
|
"""A limited subset of the image's generation metadata. Retrieve the image's session for full metadata."""
|
||||||
|
|
||||||
|
|
||||||
|
class ImageUrlsDTO(BaseModel):
|
||||||
|
"""The URLs for an image and its thumbnail."""
|
||||||
|
|
||||||
|
image_name: str = Field(description="The unique name of the image.")
|
||||||
|
"""The unique name of the image."""
|
||||||
|
image_type: ImageType = Field(description="The type of the image.")
|
||||||
|
"""The type of the image."""
|
||||||
|
image_url: str = Field(description="The URL of the image.")
|
||||||
|
"""The URL of the image."""
|
||||||
|
thumbnail_url: str = Field(description="The URL of the image's thumbnail.")
|
||||||
|
"""The URL of the image's thumbnail."""
|
||||||
|
|
||||||
|
|
||||||
|
class ImageDTO(ImageRecord, ImageUrlsDTO):
|
||||||
|
"""Deserialized image record, enriched for the frontend with URLs."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def image_record_to_dto(
|
||||||
|
image_record: ImageRecord, image_url: str, thumbnail_url: str
|
||||||
|
) -> ImageDTO:
|
||||||
|
"""Converts an image record to an image DTO."""
|
||||||
|
return ImageDTO(
|
||||||
|
**image_record.dict(),
|
||||||
|
image_url=image_url,
|
||||||
|
thumbnail_url=thumbnail_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
||||||
|
"""Deserializes an image record."""
|
||||||
|
|
||||||
|
# Retrieve all the values, setting "reasonable" defaults if they are not present.
|
||||||
|
|
||||||
|
image_name = image_dict.get("image_name", "unknown")
|
||||||
|
image_type = ImageType(image_dict.get("image_type", ImageType.RESULT.value))
|
||||||
|
image_category = ImageCategory(
|
||||||
|
image_dict.get("image_category", ImageCategory.GENERAL.value)
|
||||||
|
)
|
||||||
|
width = image_dict.get("width", 0)
|
||||||
|
height = image_dict.get("height", 0)
|
||||||
|
session_id = image_dict.get("session_id", None)
|
||||||
|
node_id = image_dict.get("node_id", None)
|
||||||
|
created_at = image_dict.get("created_at", get_iso_timestamp())
|
||||||
|
updated_at = image_dict.get("updated_at", get_iso_timestamp())
|
||||||
|
deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
|
||||||
|
|
||||||
|
raw_metadata = image_dict.get("metadata")
|
||||||
|
|
||||||
|
if raw_metadata is not None:
|
||||||
|
metadata = ImageMetadata.parse_raw(raw_metadata)
|
||||||
|
else:
|
||||||
|
metadata = None
|
||||||
|
|
||||||
|
return ImageRecord(
|
||||||
|
image_name=image_name,
|
||||||
|
image_type=image_type,
|
||||||
|
image_category=image_category,
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
session_id=session_id,
|
||||||
|
node_id=node_id,
|
||||||
|
metadata=metadata,
|
||||||
|
created_at=created_at,
|
||||||
|
updated_at=updated_at,
|
||||||
|
deleted_at=deleted_at,
|
||||||
|
)
|
34
invokeai/app/services/urls.py
Normal file
34
invokeai/app/services/urls.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
import os
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from invokeai.app.models.image import ImageType
|
||||||
|
from invokeai.app.util.thumbnails import get_thumbnail_name
|
||||||
|
|
||||||
|
|
||||||
|
class UrlServiceBase(ABC):
|
||||||
|
"""Responsible for building URLs for resources."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_image_url(
|
||||||
|
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
||||||
|
) -> str:
|
||||||
|
"""Gets the URL for an image or thumbnail."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class LocalUrlService(UrlServiceBase):
|
||||||
|
def __init__(self, base_url: str = "api/v1"):
|
||||||
|
self._base_url = base_url
|
||||||
|
|
||||||
|
def get_image_url(
|
||||||
|
self, image_type: ImageType, image_name: str, thumbnail: bool = False
|
||||||
|
) -> str:
|
||||||
|
image_basename = os.path.basename(image_name)
|
||||||
|
|
||||||
|
# These paths are determined by the routes in invokeai/app/api/routers/images.py
|
||||||
|
if thumbnail:
|
||||||
|
return (
|
||||||
|
f"{self._base_url}/images/{image_type.value}/{image_basename}/thumbnail"
|
||||||
|
)
|
||||||
|
|
||||||
|
return f"{self._base_url}/images/{image_type.value}/{image_basename}"
|
15
invokeai/app/util/metaenum.py
Normal file
15
invokeai/app/util/metaenum.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
from enum import EnumMeta
|
||||||
|
|
||||||
|
|
||||||
|
class MetaEnum(EnumMeta):
|
||||||
|
"""Metaclass to support additional features in Enums.
|
||||||
|
|
||||||
|
- `in` operator support: `'value' in MyEnum -> bool`
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __contains__(cls, item):
|
||||||
|
try:
|
||||||
|
cls(item)
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
return True
|
@ -6,6 +6,14 @@ def get_timestamp():
|
|||||||
return int(datetime.datetime.now(datetime.timezone.utc).timestamp())
|
return int(datetime.datetime.now(datetime.timezone.utc).timestamp())
|
||||||
|
|
||||||
|
|
||||||
|
def get_iso_timestamp() -> str:
|
||||||
|
return datetime.datetime.utcnow().isoformat()
|
||||||
|
|
||||||
|
|
||||||
|
def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime:
|
||||||
|
return datetime.datetime.fromisoformat(iso_timestamp)
|
||||||
|
|
||||||
|
|
||||||
SEED_MAX = np.iinfo(np.int32).max
|
SEED_MAX = np.iinfo(np.int32).max
|
||||||
|
|
||||||
|
|
||||||
|
@ -196,7 +196,7 @@ class Inpaint(Img2Img):
|
|||||||
|
|
||||||
seam_noise = self.get_noise(im.width, im.height)
|
seam_noise = self.get_noise(im.width, im.height)
|
||||||
|
|
||||||
result = make_image(seam_noise, seed)
|
result = make_image(seam_noise, seed=None)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@ -76,16 +76,16 @@ class InvokeAILogFormatter(logging.Formatter):
|
|||||||
reset = "\x1b[0m"
|
reset = "\x1b[0m"
|
||||||
|
|
||||||
# Log Format
|
# Log Format
|
||||||
format = "[%(asctime)s]::[%(name)s]::%(levelname)s --> %(message)s"
|
log_format = "[%(asctime)s]::[%(name)s]::%(levelname)s --> %(message)s"
|
||||||
## More Formatting Options: %(pathname)s, %(filename)s, %(module)s, %(lineno)d
|
## More Formatting Options: %(pathname)s, %(filename)s, %(module)s, %(lineno)d
|
||||||
|
|
||||||
# Format Map
|
# Format Map
|
||||||
FORMATS = {
|
FORMATS = {
|
||||||
logging.DEBUG: cyan + format + reset,
|
logging.DEBUG: cyan + log_format + reset,
|
||||||
logging.INFO: grey + format + reset,
|
logging.INFO: grey + log_format + reset,
|
||||||
logging.WARNING: yellow + format + reset,
|
logging.WARNING: yellow + log_format + reset,
|
||||||
logging.ERROR: red + format + reset,
|
logging.ERROR: red + log_format + reset,
|
||||||
logging.CRITICAL: bold_red + format + reset
|
logging.CRITICAL: bold_red + log_format + reset
|
||||||
}
|
}
|
||||||
|
|
||||||
def format(self, record):
|
def format(self, record):
|
||||||
@ -98,13 +98,13 @@ class InvokeAILogger(object):
|
|||||||
loggers = dict()
|
loggers = dict()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def getLogger(self, name: str = 'InvokeAI') -> logging.Logger:
|
def getLogger(cls, name: str = 'InvokeAI') -> logging.Logger:
|
||||||
if name not in self.loggers:
|
if name not in cls.loggers:
|
||||||
logger = logging.getLogger(name)
|
logger = logging.getLogger(name)
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
ch = logging.StreamHandler()
|
ch = logging.StreamHandler()
|
||||||
fmt = InvokeAILogFormatter()
|
fmt = InvokeAILogFormatter()
|
||||||
ch.setFormatter(fmt)
|
ch.setFormatter(fmt)
|
||||||
logger.addHandler(ch)
|
logger.addHandler(ch)
|
||||||
self.loggers[name] = logger
|
cls.loggers[name] = logger
|
||||||
return self.loggers[name]
|
return cls.loggers[name]
|
||||||
|
@ -23,8 +23,8 @@
|
|||||||
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
|
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
|
||||||
"dev:host": "concurrently \"vite dev --host\" \"yarn run theme:watch\"",
|
"dev:host": "concurrently \"vite dev --host\" \"yarn run theme:watch\"",
|
||||||
"build": "yarn run lint && vite build",
|
"build": "yarn run lint && vite build",
|
||||||
"api:web": "openapi -i http://localhost:9090/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --exportSchemas true --indent 2 --request src/services/fixtures/request.ts",
|
"api:web": "openapi -i http://localhost:9090/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --indent 2 --request src/services/fixtures/request.ts",
|
||||||
"api:file": "openapi -i src/services/fixtures/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --exportSchemas true --indent 2 --request src/services/fixtures/request.ts",
|
"api:file": "openapi -i src/services/fixtures/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --indent 2 --request src/services/fixtures/request.ts",
|
||||||
"preview": "vite preview",
|
"preview": "vite preview",
|
||||||
"lint:madge": "madge --circular src/main.tsx",
|
"lint:madge": "madge --circular src/main.tsx",
|
||||||
"lint:eslint": "eslint --max-warnings=0 .",
|
"lint:eslint": "eslint --max-warnings=0 .",
|
||||||
|
@ -10,7 +10,7 @@ export const readinessSelector = createSelector(
|
|||||||
[generationSelector, systemSelector, activeTabNameSelector],
|
[generationSelector, systemSelector, activeTabNameSelector],
|
||||||
(generation, system, activeTabName) => {
|
(generation, system, activeTabName) => {
|
||||||
const {
|
const {
|
||||||
prompt,
|
positivePrompt: prompt,
|
||||||
shouldGenerateVariations,
|
shouldGenerateVariations,
|
||||||
seedWeights,
|
seedWeights,
|
||||||
initialImage,
|
initialImage,
|
||||||
|
@ -5,7 +5,6 @@ import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
|
|||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
import { imageUploaded } from 'services/thunks/image';
|
import { imageUploaded } from 'services/thunks/image';
|
||||||
import { v4 as uuidv4 } from 'uuid';
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
|
|
||||||
import { setMergedCanvas } from 'features/canvas/store/canvasSlice';
|
import { setMergedCanvas } from 'features/canvas/store/canvasSlice';
|
||||||
import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider';
|
import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider';
|
||||||
|
|
||||||
@ -66,7 +65,7 @@ export const addCanvasMergedListener = () => {
|
|||||||
action.meta.arg.formData.file.name === filename
|
action.meta.arg.formData.file.name === filename
|
||||||
);
|
);
|
||||||
|
|
||||||
const mergedCanvasImage = deserializeImageResponse(payload.response);
|
const mergedCanvasImage = payload.response;
|
||||||
|
|
||||||
dispatch(
|
dispatch(
|
||||||
setMergedCanvas({
|
setMergedCanvas({
|
||||||
|
@ -17,24 +17,24 @@ export const addRequestedImageDeletionListener = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const { name, type } = image;
|
const { image_name, image_type } = image;
|
||||||
|
|
||||||
if (type !== 'uploads' && type !== 'results') {
|
if (image_type !== 'uploads' && image_type !== 'results') {
|
||||||
moduleLog.warn({ data: image }, `Invalid image type ${type}`);
|
moduleLog.warn({ data: image }, `Invalid image type ${image_type}`);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const selectedImageName = getState().gallery.selectedImage?.name;
|
const selectedImageName = getState().gallery.selectedImage?.image_name;
|
||||||
|
|
||||||
if (selectedImageName === name) {
|
if (selectedImageName === image_name) {
|
||||||
const allIds = getState()[type].ids;
|
const allIds = getState()[image_type].ids;
|
||||||
const allEntities = getState()[type].entities;
|
const allEntities = getState()[image_type].entities;
|
||||||
|
|
||||||
const deletedImageIndex = allIds.findIndex(
|
const deletedImageIndex = allIds.findIndex(
|
||||||
(result) => result.toString() === name
|
(result) => result.toString() === image_name
|
||||||
);
|
);
|
||||||
|
|
||||||
const filteredIds = allIds.filter((id) => id.toString() !== name);
|
const filteredIds = allIds.filter((id) => id.toString() !== image_name);
|
||||||
|
|
||||||
const newSelectedImageIndex = clamp(
|
const newSelectedImageIndex = clamp(
|
||||||
deletedImageIndex,
|
deletedImageIndex,
|
||||||
@ -53,7 +53,7 @@ export const addRequestedImageDeletionListener = () => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
dispatch(imageDeleted({ imageName: name, imageType: type }));
|
dispatch(imageDeleted({ imageName: image_name, imageType: image_type }));
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
|
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
import { uploadAdded } from 'features/gallery/store/uploadsSlice';
|
import { uploadAdded } from 'features/gallery/store/uploadsSlice';
|
||||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||||
@ -7,6 +6,7 @@ import { addToast } from 'features/system/store/systemSlice';
|
|||||||
import { initialImageSelected } from 'features/parameters/store/actions';
|
import { initialImageSelected } from 'features/parameters/store/actions';
|
||||||
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
||||||
import { resultAdded } from 'features/gallery/store/resultsSlice';
|
import { resultAdded } from 'features/gallery/store/resultsSlice';
|
||||||
|
import { isResultsImageDTO, isUploadsImageDTO } from 'services/types/guards';
|
||||||
|
|
||||||
export const addImageUploadedListener = () => {
|
export const addImageUploadedListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
@ -14,13 +14,11 @@ export const addImageUploadedListener = () => {
|
|||||||
imageUploaded.fulfilled.match(action) &&
|
imageUploaded.fulfilled.match(action) &&
|
||||||
action.payload.response.image_type !== 'intermediates',
|
action.payload.response.image_type !== 'intermediates',
|
||||||
effect: (action, { dispatch, getState }) => {
|
effect: (action, { dispatch, getState }) => {
|
||||||
const { response } = action.payload;
|
const { response: image } = action.payload;
|
||||||
const { imageType } = action.meta.arg;
|
|
||||||
|
|
||||||
const state = getState();
|
const state = getState();
|
||||||
const image = deserializeImageResponse(response);
|
|
||||||
|
|
||||||
if (imageType === 'uploads') {
|
if (isUploadsImageDTO(image)) {
|
||||||
dispatch(uploadAdded(image));
|
dispatch(uploadAdded(image));
|
||||||
|
|
||||||
dispatch(addToast({ title: 'Image Uploaded', status: 'success' }));
|
dispatch(addToast({ title: 'Image Uploaded', status: 'success' }));
|
||||||
@ -38,7 +36,7 @@ export const addImageUploadedListener = () => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (imageType === 'results') {
|
if (isResultsImageDTO(image)) {
|
||||||
dispatch(resultAdded(image));
|
dispatch(resultAdded(image));
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -1,12 +1,15 @@
|
|||||||
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
||||||
import { Image, isInvokeAIImage } from 'app/types/invokeai';
|
|
||||||
import { selectResultsById } from 'features/gallery/store/resultsSlice';
|
import { selectResultsById } from 'features/gallery/store/resultsSlice';
|
||||||
import { selectUploadsById } from 'features/gallery/store/uploadsSlice';
|
import { selectUploadsById } from 'features/gallery/store/uploadsSlice';
|
||||||
import { t } from 'i18next';
|
import { t } from 'i18next';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
import { initialImageSelected } from 'features/parameters/store/actions';
|
import {
|
||||||
|
initialImageSelected,
|
||||||
|
isImageDTO,
|
||||||
|
} from 'features/parameters/store/actions';
|
||||||
import { makeToast } from 'app/components/Toaster';
|
import { makeToast } from 'app/components/Toaster';
|
||||||
|
import { ImageDTO } from 'services/api';
|
||||||
|
|
||||||
export const addInitialImageSelectedListener = () => {
|
export const addInitialImageSelectedListener = () => {
|
||||||
startAppListening({
|
startAppListening({
|
||||||
@ -21,21 +24,21 @@ export const addInitialImageSelectedListener = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isInvokeAIImage(action.payload)) {
|
if (isImageDTO(action.payload)) {
|
||||||
dispatch(initialImageChanged(action.payload));
|
dispatch(initialImageChanged(action.payload));
|
||||||
dispatch(addToast(makeToast(t('toast.sentToImageToImage'))));
|
dispatch(addToast(makeToast(t('toast.sentToImageToImage'))));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const { name, type } = action.payload;
|
const { image_name, image_type } = action.payload;
|
||||||
|
|
||||||
let image: Image | undefined;
|
let image: ImageDTO | undefined;
|
||||||
const state = getState();
|
const state = getState();
|
||||||
|
|
||||||
if (type === 'results') {
|
if (image_type === 'results') {
|
||||||
image = selectResultsById(state, name);
|
image = selectResultsById(state, image_name);
|
||||||
} else if (type === 'uploads') {
|
} else if (image_type === 'uploads') {
|
||||||
image = selectUploadsById(state, name);
|
image = selectUploadsById(state, image_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!image) {
|
if (!image) {
|
||||||
|
@ -1,14 +1,10 @@
|
|||||||
import { invocationComplete } from 'services/events/actions';
|
import { invocationComplete } from 'services/events/actions';
|
||||||
import { isImageOutput } from 'services/types/guards';
|
import { isImageOutput } from 'services/types/guards';
|
||||||
import {
|
import {
|
||||||
buildImageUrls,
|
imageMetadataReceived,
|
||||||
extractTimestampFromImageName,
|
imageUrlsReceived,
|
||||||
} from 'services/util/deserializeImageField';
|
} from 'services/thunks/image';
|
||||||
import { Image } from 'app/types/invokeai';
|
|
||||||
import { resultAdded } from 'features/gallery/store/resultsSlice';
|
|
||||||
import { imageReceived, thumbnailReceived } from 'services/thunks/image';
|
|
||||||
import { startAppListening } from '..';
|
import { startAppListening } from '..';
|
||||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
|
||||||
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
|
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
|
||||||
|
|
||||||
const nodeDenylist = ['dataURL_image'];
|
const nodeDenylist = ['dataURL_image'];
|
||||||
@ -24,62 +20,40 @@ export const addImageResultReceivedListener = () => {
|
|||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
},
|
},
|
||||||
effect: (action, { getState, dispatch }) => {
|
effect: async (action, { getState, dispatch, take }) => {
|
||||||
if (!invocationComplete.match(action)) {
|
if (!invocationComplete.match(action)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const { data, shouldFetchImages } = action.payload;
|
const { data } = action.payload;
|
||||||
const { result, node, graph_execution_state_id } = data;
|
const { result, node, graph_execution_state_id } = data;
|
||||||
|
|
||||||
if (isImageOutput(result) && !nodeDenylist.includes(node.type)) {
|
if (isImageOutput(result) && !nodeDenylist.includes(node.type)) {
|
||||||
const name = result.image.image_name;
|
const { image_name, image_type } = result.image;
|
||||||
const type = result.image.image_type;
|
|
||||||
const state = getState();
|
|
||||||
|
|
||||||
// if we need to refetch, set URLs to placeholder for now
|
dispatch(
|
||||||
const { url, thumbnail } = shouldFetchImages
|
imageUrlsReceived({ imageName: image_name, imageType: image_type })
|
||||||
? { url: '', thumbnail: '' }
|
);
|
||||||
: buildImageUrls(type, name);
|
|
||||||
|
|
||||||
const timestamp = extractTimestampFromImageName(name);
|
dispatch(
|
||||||
|
imageMetadataReceived({
|
||||||
const image: Image = {
|
imageName: image_name,
|
||||||
name,
|
imageType: image_type,
|
||||||
type,
|
})
|
||||||
url,
|
);
|
||||||
thumbnail,
|
|
||||||
metadata: {
|
|
||||||
created: timestamp,
|
|
||||||
width: result.width,
|
|
||||||
height: result.height,
|
|
||||||
invokeai: {
|
|
||||||
session_id: graph_execution_state_id,
|
|
||||||
...(node ? { node } : {}),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
dispatch(resultAdded(image));
|
|
||||||
|
|
||||||
if (state.gallery.shouldAutoSwitchToNewImages) {
|
|
||||||
dispatch(imageSelected(image));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (state.config.shouldFetchImages) {
|
|
||||||
dispatch(imageReceived({ imageName: name, imageType: type }));
|
|
||||||
dispatch(
|
|
||||||
thumbnailReceived({
|
|
||||||
thumbnailName: name,
|
|
||||||
thumbnailType: type,
|
|
||||||
})
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
// Handle canvas image
|
||||||
if (
|
if (
|
||||||
graph_execution_state_id ===
|
graph_execution_state_id ===
|
||||||
state.canvas.layerState.stagingArea.sessionId
|
getState().canvas.layerState.stagingArea.sessionId
|
||||||
) {
|
) {
|
||||||
|
const [{ payload: image }] = await take(
|
||||||
|
(
|
||||||
|
action
|
||||||
|
): action is ReturnType<typeof imageMetadataReceived.fulfilled> =>
|
||||||
|
imageMetadataReceived.fulfilled.match(action) &&
|
||||||
|
action.payload.image_name === image_name
|
||||||
|
);
|
||||||
dispatch(addImageToStagingArea(image));
|
dispatch(addImageToStagingArea(image));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -122,21 +122,21 @@ export type PostProcessedImageMetadata = ESRGANMetadata | FacetoolMetadata;
|
|||||||
/**
|
/**
|
||||||
* ResultImage
|
* ResultImage
|
||||||
*/
|
*/
|
||||||
export type Image = {
|
// export ty`pe Image = {
|
||||||
name: string;
|
// name: string;
|
||||||
type: ImageType;
|
// type: ImageType;
|
||||||
url: string;
|
// url: string;
|
||||||
thumbnail: string;
|
// thumbnail: string;
|
||||||
metadata: ImageResponseMetadata;
|
// metadata: ImageResponseMetadata;
|
||||||
};
|
// };
|
||||||
|
|
||||||
export const isInvokeAIImage = (obj: Image | SelectedImage): obj is Image => {
|
// export const isInvokeAIImage = (obj: Image | SelectedImage): obj is Image => {
|
||||||
if ('url' in obj && 'thumbnail' in obj) {
|
// if ('url' in obj && 'thumbnail' in obj) {
|
||||||
return true;
|
// return true;
|
||||||
}
|
// }
|
||||||
|
|
||||||
return false;
|
// return false;
|
||||||
};
|
// };
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Types related to the system status.
|
* Types related to the system status.
|
||||||
@ -346,7 +346,6 @@ export type AppConfig = {
|
|||||||
/**
|
/**
|
||||||
* Whether or not we need to re-fetch images
|
* Whether or not we need to re-fetch images
|
||||||
*/
|
*/
|
||||||
shouldFetchImages: boolean;
|
|
||||||
disabledTabs: InvokeTabName[];
|
disabledTabs: InvokeTabName[];
|
||||||
disabledFeatures: AppFeature[];
|
disabledFeatures: AppFeature[];
|
||||||
disabledSDFeatures: SDFeature[];
|
disabledSDFeatures: SDFeature[];
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
import { Badge, Flex } from '@chakra-ui/react';
|
import { Badge, Flex } from '@chakra-ui/react';
|
||||||
import { Image } from 'app/types/invokeai';
|
|
||||||
import { isNumber, isString } from 'lodash-es';
|
import { isNumber, isString } from 'lodash-es';
|
||||||
import { useMemo } from 'react';
|
import { useMemo } from 'react';
|
||||||
|
import { ImageDTO } from 'services/api';
|
||||||
|
|
||||||
type ImageMetadataOverlayProps = {
|
type ImageMetadataOverlayProps = {
|
||||||
image: Image;
|
image: ImageDTO;
|
||||||
};
|
};
|
||||||
|
|
||||||
const ImageMetadataOverlay = ({ image }: ImageMetadataOverlayProps) => {
|
const ImageMetadataOverlay = ({ image }: ImageMetadataOverlayProps) => {
|
||||||
@ -17,11 +17,11 @@ const ImageMetadataOverlay = ({ image }: ImageMetadataOverlayProps) => {
|
|||||||
}, [image.metadata]);
|
}, [image.metadata]);
|
||||||
|
|
||||||
const model = useMemo(() => {
|
const model = useMemo(() => {
|
||||||
if (!isString(image.metadata?.invokeai?.node?.model)) {
|
if (!isString(image.metadata?.model)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
return image.metadata?.invokeai?.node?.model;
|
return image.metadata?.model;
|
||||||
}, [image.metadata]);
|
}, [image.metadata]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
12
invokeai/frontend/web/src/common/util/dateComparator.ts
Normal file
12
invokeai/frontend/web/src/common/util/dateComparator.ts
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
/**
|
||||||
|
* Comparator function for sorting dates in ascending order
|
||||||
|
*/
|
||||||
|
export const dateComparator = (a: string, b: string) => {
|
||||||
|
const dateA = new Date(a);
|
||||||
|
const dateB = new Date(b);
|
||||||
|
|
||||||
|
// sort in ascending order
|
||||||
|
if (dateA > dateB) return 1;
|
||||||
|
if (dateA < dateB) return -1;
|
||||||
|
return 0;
|
||||||
|
};
|
@ -46,7 +46,7 @@ const IAICanvasObjectRenderer = () => {
|
|||||||
key={i}
|
key={i}
|
||||||
x={obj.x}
|
x={obj.x}
|
||||||
y={obj.y}
|
y={obj.y}
|
||||||
url={getUrl(obj.image.url)}
|
url={getUrl(obj.image.image_url)}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
} else if (isCanvasBaseLine(obj)) {
|
} else if (isCanvasBaseLine(obj)) {
|
||||||
|
@ -62,7 +62,7 @@ const IAICanvasStagingArea = (props: Props) => {
|
|||||||
<Group {...rest}>
|
<Group {...rest}>
|
||||||
{shouldShowStagingImage && currentStagingAreaImage && (
|
{shouldShowStagingImage && currentStagingAreaImage && (
|
||||||
<IAICanvasImage
|
<IAICanvasImage
|
||||||
url={getUrl(currentStagingAreaImage.image.url)}
|
url={getUrl(currentStagingAreaImage.image.image_url)}
|
||||||
x={x}
|
x={x}
|
||||||
y={y}
|
y={y}
|
||||||
/>
|
/>
|
||||||
|
@ -157,17 +157,19 @@ const IAICanvasStagingAreaToolbar = () => {
|
|||||||
}
|
}
|
||||||
colorScheme="accent"
|
colorScheme="accent"
|
||||||
/>
|
/>
|
||||||
<IAIIconButton
|
{/* <IAIIconButton
|
||||||
tooltip={t('unifiedCanvas.saveToGallery')}
|
tooltip={t('unifiedCanvas.saveToGallery')}
|
||||||
aria-label={t('unifiedCanvas.saveToGallery')}
|
aria-label={t('unifiedCanvas.saveToGallery')}
|
||||||
icon={<FaSave />}
|
icon={<FaSave />}
|
||||||
onClick={() =>
|
onClick={() =>
|
||||||
dispatch(
|
dispatch(
|
||||||
saveStagingAreaImageToGallery(currentStagingAreaImage.image.url)
|
saveStagingAreaImageToGallery(
|
||||||
|
currentStagingAreaImage.image.image_url
|
||||||
|
)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
colorScheme="accent"
|
colorScheme="accent"
|
||||||
/>
|
/> */}
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
tooltip={t('unifiedCanvas.discardAll')}
|
tooltip={t('unifiedCanvas.discardAll')}
|
||||||
aria-label={t('unifiedCanvas.discardAll')}
|
aria-label={t('unifiedCanvas.discardAll')}
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||||
import { createSlice } from '@reduxjs/toolkit';
|
import { createSlice } from '@reduxjs/toolkit';
|
||||||
import * as InvokeAI from 'app/types/invokeai';
|
|
||||||
import {
|
import {
|
||||||
roundDownToMultiple,
|
roundDownToMultiple,
|
||||||
roundToMultiple,
|
roundToMultiple,
|
||||||
@ -29,6 +28,7 @@ import {
|
|||||||
isCanvasBaseImage,
|
isCanvasBaseImage,
|
||||||
isCanvasMaskLine,
|
isCanvasMaskLine,
|
||||||
} from './canvasTypes';
|
} from './canvasTypes';
|
||||||
|
import { ImageDTO } from 'services/api';
|
||||||
|
|
||||||
export const initialLayerState: CanvasLayerState = {
|
export const initialLayerState: CanvasLayerState = {
|
||||||
objects: [],
|
objects: [],
|
||||||
@ -157,9 +157,9 @@ export const canvasSlice = createSlice({
|
|||||||
setCursorPosition: (state, action: PayloadAction<Vector2d | null>) => {
|
setCursorPosition: (state, action: PayloadAction<Vector2d | null>) => {
|
||||||
state.cursorPosition = action.payload;
|
state.cursorPosition = action.payload;
|
||||||
},
|
},
|
||||||
setInitialCanvasImage: (state, action: PayloadAction<InvokeAI.Image>) => {
|
setInitialCanvasImage: (state, action: PayloadAction<ImageDTO>) => {
|
||||||
const image = action.payload;
|
const image = action.payload;
|
||||||
const { width, height } = image.metadata;
|
const { width, height } = image;
|
||||||
const { stageDimensions } = state;
|
const { stageDimensions } = state;
|
||||||
|
|
||||||
const newBoundingBoxDimensions = {
|
const newBoundingBoxDimensions = {
|
||||||
@ -302,7 +302,7 @@ export const canvasSlice = createSlice({
|
|||||||
selectedImageIndex: -1,
|
selectedImageIndex: -1,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
addImageToStagingArea: (state, action: PayloadAction<InvokeAI.Image>) => {
|
addImageToStagingArea: (state, action: PayloadAction<ImageDTO>) => {
|
||||||
const image = action.payload;
|
const image = action.payload;
|
||||||
|
|
||||||
if (!image || !state.layerState.stagingArea.boundingBox) {
|
if (!image || !state.layerState.stagingArea.boundingBox) {
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import * as InvokeAI from 'app/types/invokeai';
|
import * as InvokeAI from 'app/types/invokeai';
|
||||||
import { IRect, Vector2d } from 'konva/lib/types';
|
import { IRect, Vector2d } from 'konva/lib/types';
|
||||||
import { RgbaColor } from 'react-colorful';
|
import { RgbaColor } from 'react-colorful';
|
||||||
|
import { ImageDTO } from 'services/api';
|
||||||
|
|
||||||
export const LAYER_NAMES_DICT = [
|
export const LAYER_NAMES_DICT = [
|
||||||
{ key: 'Base', value: 'base' },
|
{ key: 'Base', value: 'base' },
|
||||||
@ -37,7 +38,7 @@ export type CanvasImage = {
|
|||||||
y: number;
|
y: number;
|
||||||
width: number;
|
width: number;
|
||||||
height: number;
|
height: number;
|
||||||
image: InvokeAI.Image;
|
image: ImageDTO;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type CanvasMaskLine = {
|
export type CanvasMaskLine = {
|
||||||
|
@ -195,14 +195,14 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (shouldTransformUrls) {
|
if (shouldTransformUrls) {
|
||||||
return getUrl(image.url);
|
return getUrl(image.image_url);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (image.url.startsWith('http')) {
|
if (image.image_url.startsWith('http')) {
|
||||||
return image.url;
|
return image.image_url;
|
||||||
}
|
}
|
||||||
|
|
||||||
return window.location.toString() + image.url;
|
return window.location.toString() + image.image_url;
|
||||||
};
|
};
|
||||||
|
|
||||||
const url = getImageUrl();
|
const url = getImageUrl();
|
||||||
|
@ -61,8 +61,8 @@ const CurrentImagePreview = () => {
|
|||||||
if (!image) {
|
if (!image) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
e.dataTransfer.setData('invokeai/imageName', image.name);
|
e.dataTransfer.setData('invokeai/imageName', image.image_name);
|
||||||
e.dataTransfer.setData('invokeai/imageType', image.type);
|
e.dataTransfer.setData('invokeai/imageType', image.image_type);
|
||||||
e.dataTransfer.effectAllowed = 'move';
|
e.dataTransfer.effectAllowed = 'move';
|
||||||
},
|
},
|
||||||
[image]
|
[image]
|
||||||
@ -108,7 +108,7 @@ const CurrentImagePreview = () => {
|
|||||||
image && (
|
image && (
|
||||||
<>
|
<>
|
||||||
<Image
|
<Image
|
||||||
src={getUrl(image.url)}
|
src={getUrl(image.image_url)}
|
||||||
fallbackStrategy="beforeLoadOrError"
|
fallbackStrategy="beforeLoadOrError"
|
||||||
fallback={<ImageFallbackSpinner />}
|
fallback={<ImageFallbackSpinner />}
|
||||||
onDragStart={handleDragStart}
|
onDragStart={handleDragStart}
|
||||||
|
@ -13,7 +13,6 @@ import { DragEvent, MouseEvent, memo, useCallback, useState } from 'react';
|
|||||||
import { FaCheck, FaExpand, FaImage, FaShare, FaTrash } from 'react-icons/fa';
|
import { FaCheck, FaExpand, FaImage, FaShare, FaTrash } from 'react-icons/fa';
|
||||||
import DeleteImageModal from './DeleteImageModal';
|
import DeleteImageModal from './DeleteImageModal';
|
||||||
import { ContextMenu } from 'chakra-ui-contextmenu';
|
import { ContextMenu } from 'chakra-ui-contextmenu';
|
||||||
import * as InvokeAI from 'app/types/invokeai';
|
|
||||||
import {
|
import {
|
||||||
resizeAndScaleCanvas,
|
resizeAndScaleCanvas,
|
||||||
setInitialCanvasImage,
|
setInitialCanvasImage,
|
||||||
@ -39,6 +38,7 @@ import {
|
|||||||
sentImageToImg2Img,
|
sentImageToImg2Img,
|
||||||
} from '../store/actions';
|
} from '../store/actions';
|
||||||
import { useAppToaster } from 'app/components/Toaster';
|
import { useAppToaster } from 'app/components/Toaster';
|
||||||
|
import { ImageDTO } from 'services/api';
|
||||||
|
|
||||||
export const selector = createSelector(
|
export const selector = createSelector(
|
||||||
[gallerySelector, systemSelector, lightboxSelector, activeTabNameSelector],
|
[gallerySelector, systemSelector, lightboxSelector, activeTabNameSelector],
|
||||||
@ -70,14 +70,16 @@ export const selector = createSelector(
|
|||||||
);
|
);
|
||||||
|
|
||||||
interface HoverableImageProps {
|
interface HoverableImageProps {
|
||||||
image: InvokeAI.Image;
|
image: ImageDTO;
|
||||||
isSelected: boolean;
|
isSelected: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
const memoEqualityCheck = (
|
const memoEqualityCheck = (
|
||||||
prev: HoverableImageProps,
|
prev: HoverableImageProps,
|
||||||
next: HoverableImageProps
|
next: HoverableImageProps
|
||||||
) => prev.image.name === next.image.name && prev.isSelected === next.isSelected;
|
) =>
|
||||||
|
prev.image.image_name === next.image.image_name &&
|
||||||
|
prev.isSelected === next.isSelected;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Gallery image component with delete/use all/use seed buttons on hover.
|
* Gallery image component with delete/use all/use seed buttons on hover.
|
||||||
@ -100,7 +102,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
} = useDisclosure();
|
} = useDisclosure();
|
||||||
|
|
||||||
const { image, isSelected } = props;
|
const { image, isSelected } = props;
|
||||||
const { url, thumbnail, name } = image;
|
const { image_url, thumbnail_url, image_name } = image;
|
||||||
const { getUrl } = useGetUrl();
|
const { getUrl } = useGetUrl();
|
||||||
|
|
||||||
const [isHovered, setIsHovered] = useState<boolean>(false);
|
const [isHovered, setIsHovered] = useState<boolean>(false);
|
||||||
@ -144,8 +146,8 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
|
|
||||||
const handleDragStart = useCallback(
|
const handleDragStart = useCallback(
|
||||||
(e: DragEvent<HTMLDivElement>) => {
|
(e: DragEvent<HTMLDivElement>) => {
|
||||||
e.dataTransfer.setData('invokeai/imageName', image.name);
|
e.dataTransfer.setData('invokeai/imageName', image.image_name);
|
||||||
e.dataTransfer.setData('invokeai/imageType', image.type);
|
e.dataTransfer.setData('invokeai/imageType', image.image_type);
|
||||||
e.dataTransfer.effectAllowed = 'move';
|
e.dataTransfer.effectAllowed = 'move';
|
||||||
},
|
},
|
||||||
[image]
|
[image]
|
||||||
@ -153,11 +155,11 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
|
|
||||||
// Recall parameters handlers
|
// Recall parameters handlers
|
||||||
const handleRecallPrompt = useCallback(() => {
|
const handleRecallPrompt = useCallback(() => {
|
||||||
recallPrompt(image.metadata?.invokeai?.node?.prompt);
|
recallPrompt(image.metadata?.positive_conditioning);
|
||||||
}, [image, recallPrompt]);
|
}, [image, recallPrompt]);
|
||||||
|
|
||||||
const handleRecallSeed = useCallback(() => {
|
const handleRecallSeed = useCallback(() => {
|
||||||
recallSeed(image.metadata.invokeai?.node?.seed);
|
recallSeed(image.metadata?.seed);
|
||||||
}, [image, recallSeed]);
|
}, [image, recallSeed]);
|
||||||
|
|
||||||
const handleSendToImageToImage = useCallback(() => {
|
const handleSendToImageToImage = useCallback(() => {
|
||||||
@ -165,9 +167,9 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
dispatch(initialImageSelected(image));
|
dispatch(initialImageSelected(image));
|
||||||
}, [dispatch, image]);
|
}, [dispatch, image]);
|
||||||
|
|
||||||
const handleRecallInitialImage = useCallback(() => {
|
// const handleRecallInitialImage = useCallback(() => {
|
||||||
recallInitialImage(image.metadata.invokeai?.node?.image);
|
// recallInitialImage(image.metadata.invokeai?.node?.image);
|
||||||
}, [image, recallInitialImage]);
|
// }, [image, recallInitialImage]);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* TODO: the rest of these
|
* TODO: the rest of these
|
||||||
@ -200,7 +202,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const handleOpenInNewTab = () => {
|
const handleOpenInNewTab = () => {
|
||||||
window.open(getUrl(image.url), '_blank');
|
window.open(getUrl(image.image_url), '_blank');
|
||||||
};
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@ -223,7 +225,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
<MenuItem
|
<MenuItem
|
||||||
icon={<IoArrowUndoCircleOutline />}
|
icon={<IoArrowUndoCircleOutline />}
|
||||||
onClickCapture={handleRecallPrompt}
|
onClickCapture={handleRecallPrompt}
|
||||||
isDisabled={image?.metadata?.invokeai?.node?.prompt === undefined}
|
isDisabled={image?.metadata?.positive_conditioning === undefined}
|
||||||
>
|
>
|
||||||
{t('parameters.usePrompt')}
|
{t('parameters.usePrompt')}
|
||||||
</MenuItem>
|
</MenuItem>
|
||||||
@ -231,23 +233,23 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
<MenuItem
|
<MenuItem
|
||||||
icon={<IoArrowUndoCircleOutline />}
|
icon={<IoArrowUndoCircleOutline />}
|
||||||
onClickCapture={handleRecallSeed}
|
onClickCapture={handleRecallSeed}
|
||||||
isDisabled={image?.metadata?.invokeai?.node?.seed === undefined}
|
isDisabled={image?.metadata?.seed === undefined}
|
||||||
>
|
>
|
||||||
{t('parameters.useSeed')}
|
{t('parameters.useSeed')}
|
||||||
</MenuItem>
|
</MenuItem>
|
||||||
<MenuItem
|
{/* <MenuItem
|
||||||
icon={<IoArrowUndoCircleOutline />}
|
icon={<IoArrowUndoCircleOutline />}
|
||||||
onClickCapture={handleRecallInitialImage}
|
onClickCapture={handleRecallInitialImage}
|
||||||
isDisabled={image?.metadata?.invokeai?.node?.type !== 'img2img'}
|
isDisabled={image?.metadata?.type !== 'img2img'}
|
||||||
>
|
>
|
||||||
{t('parameters.useInitImg')}
|
{t('parameters.useInitImg')}
|
||||||
</MenuItem>
|
</MenuItem> */}
|
||||||
<MenuItem
|
<MenuItem
|
||||||
icon={<IoArrowUndoCircleOutline />}
|
icon={<IoArrowUndoCircleOutline />}
|
||||||
onClickCapture={handleUseAllParameters}
|
onClickCapture={handleUseAllParameters}
|
||||||
isDisabled={
|
isDisabled={
|
||||||
!['txt2img', 'img2img', 'inpaint'].includes(
|
!['txt2img', 'img2img', 'inpaint'].includes(
|
||||||
String(image?.metadata?.invokeai?.node?.type)
|
String(image?.metadata?.type)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
>
|
>
|
||||||
@ -278,7 +280,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
{(ref) => (
|
{(ref) => (
|
||||||
<Box
|
<Box
|
||||||
position="relative"
|
position="relative"
|
||||||
key={name}
|
key={image_name}
|
||||||
onMouseOver={handleMouseOver}
|
onMouseOver={handleMouseOver}
|
||||||
onMouseOut={handleMouseOut}
|
onMouseOut={handleMouseOut}
|
||||||
userSelect="none"
|
userSelect="none"
|
||||||
@ -303,7 +305,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
|||||||
shouldUseSingleGalleryColumn ? 'contain' : galleryImageObjectFit
|
shouldUseSingleGalleryColumn ? 'contain' : galleryImageObjectFit
|
||||||
}
|
}
|
||||||
rounded="md"
|
rounded="md"
|
||||||
src={getUrl(thumbnail || url)}
|
src={getUrl(thumbnail_url || image_url)}
|
||||||
fallback={<FaImage />}
|
fallback={<FaImage />}
|
||||||
sx={{
|
sx={{
|
||||||
width: '100%',
|
width: '100%',
|
||||||
|
@ -12,7 +12,7 @@ import { memo, useCallback } from 'react';
|
|||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import DeleteImageModal from '../DeleteImageModal';
|
import DeleteImageModal from '../DeleteImageModal';
|
||||||
import { requestedImageDeletion } from 'features/gallery/store/actions';
|
import { requestedImageDeletion } from 'features/gallery/store/actions';
|
||||||
import { Image } from 'app/types/invokeai';
|
import { ImageDTO } from 'services/api';
|
||||||
|
|
||||||
const selector = createSelector(
|
const selector = createSelector(
|
||||||
[systemSelector],
|
[systemSelector],
|
||||||
@ -30,7 +30,7 @@ const selector = createSelector(
|
|||||||
);
|
);
|
||||||
|
|
||||||
type DeleteImageButtonProps = {
|
type DeleteImageButtonProps = {
|
||||||
image: Image | undefined;
|
image: ImageDTO | undefined;
|
||||||
};
|
};
|
||||||
|
|
||||||
const DeleteImageButton = (props: DeleteImageButtonProps) => {
|
const DeleteImageButton = (props: DeleteImageButtonProps) => {
|
||||||
|
@ -5,7 +5,6 @@ import {
|
|||||||
FlexProps,
|
FlexProps,
|
||||||
Grid,
|
Grid,
|
||||||
Icon,
|
Icon,
|
||||||
Image,
|
|
||||||
Text,
|
Text,
|
||||||
forwardRef,
|
forwardRef,
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
@ -51,10 +50,10 @@ import { uploadsAdapter } from '../store/uploadsSlice';
|
|||||||
import { createSelector } from '@reduxjs/toolkit';
|
import { createSelector } from '@reduxjs/toolkit';
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { Virtuoso, VirtuosoGrid } from 'react-virtuoso';
|
import { Virtuoso, VirtuosoGrid } from 'react-virtuoso';
|
||||||
import { Image as ImageType } from 'app/types/invokeai';
|
|
||||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||||
import GalleryProgressImage from './GalleryProgressImage';
|
import GalleryProgressImage from './GalleryProgressImage';
|
||||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||||
|
import { ImageDTO } from 'services/api';
|
||||||
|
|
||||||
const GALLERY_SHOW_BUTTONS_MIN_WIDTH = 290;
|
const GALLERY_SHOW_BUTTONS_MIN_WIDTH = 290;
|
||||||
const PROGRESS_IMAGE_PLACEHOLDER = 'PROGRESS_IMAGE_PLACEHOLDER';
|
const PROGRESS_IMAGE_PLACEHOLDER = 'PROGRESS_IMAGE_PLACEHOLDER';
|
||||||
@ -66,7 +65,7 @@ const categorySelector = createSelector(
|
|||||||
const { currentCategory } = gallery;
|
const { currentCategory } = gallery;
|
||||||
|
|
||||||
if (currentCategory === 'results') {
|
if (currentCategory === 'results') {
|
||||||
const tempImages: (ImageType | typeof PROGRESS_IMAGE_PLACEHOLDER)[] = [];
|
const tempImages: (ImageDTO | typeof PROGRESS_IMAGE_PLACEHOLDER)[] = [];
|
||||||
|
|
||||||
if (system.progressImage) {
|
if (system.progressImage) {
|
||||||
tempImages.push(PROGRESS_IMAGE_PLACEHOLDER);
|
tempImages.push(PROGRESS_IMAGE_PLACEHOLDER);
|
||||||
@ -352,7 +351,7 @@ const ImageGalleryContent = () => {
|
|||||||
const isSelected =
|
const isSelected =
|
||||||
image === PROGRESS_IMAGE_PLACEHOLDER
|
image === PROGRESS_IMAGE_PLACEHOLDER
|
||||||
? false
|
? false
|
||||||
: selectedImage?.name === image?.name;
|
: selectedImage?.image_name === image?.image_name;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex sx={{ pb: 2 }}>
|
<Flex sx={{ pb: 2 }}>
|
||||||
@ -362,7 +361,7 @@ const ImageGalleryContent = () => {
|
|||||||
/>
|
/>
|
||||||
) : (
|
) : (
|
||||||
<HoverableImage
|
<HoverableImage
|
||||||
key={`${image.name}-${image.thumbnail}`}
|
key={`${image.image_name}-${image.thumbnail_url}`}
|
||||||
image={image}
|
image={image}
|
||||||
isSelected={isSelected}
|
isSelected={isSelected}
|
||||||
/>
|
/>
|
||||||
@ -385,13 +384,13 @@ const ImageGalleryContent = () => {
|
|||||||
const isSelected =
|
const isSelected =
|
||||||
image === PROGRESS_IMAGE_PLACEHOLDER
|
image === PROGRESS_IMAGE_PLACEHOLDER
|
||||||
? false
|
? false
|
||||||
: selectedImage?.name === image?.name;
|
: selectedImage?.image_name === image?.image_name;
|
||||||
|
|
||||||
return image === PROGRESS_IMAGE_PLACEHOLDER ? (
|
return image === PROGRESS_IMAGE_PLACEHOLDER ? (
|
||||||
<GalleryProgressImage key={PROGRESS_IMAGE_PLACEHOLDER} />
|
<GalleryProgressImage key={PROGRESS_IMAGE_PLACEHOLDER} />
|
||||||
) : (
|
) : (
|
||||||
<HoverableImage
|
<HoverableImage
|
||||||
key={`${image.name}-${image.thumbnail}`}
|
key={`${image.image_name}-${image.thumbnail_url}`}
|
||||||
image={image}
|
image={image}
|
||||||
isSelected={isSelected}
|
isSelected={isSelected}
|
||||||
/>
|
/>
|
||||||
|
@ -18,7 +18,9 @@ import {
|
|||||||
setCfgScale,
|
setCfgScale,
|
||||||
setHeight,
|
setHeight,
|
||||||
setImg2imgStrength,
|
setImg2imgStrength,
|
||||||
|
setNegativePrompt,
|
||||||
setPerlin,
|
setPerlin,
|
||||||
|
setPositivePrompt,
|
||||||
setScheduler,
|
setScheduler,
|
||||||
setSeamless,
|
setSeamless,
|
||||||
setSeed,
|
setSeed,
|
||||||
@ -36,6 +38,9 @@ import { useTranslation } from 'react-i18next';
|
|||||||
import { FaCopy } from 'react-icons/fa';
|
import { FaCopy } from 'react-icons/fa';
|
||||||
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
|
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
|
||||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||||
|
import { ImageDTO } from 'services/api';
|
||||||
|
import { filter } from 'lodash-es';
|
||||||
|
import { Scheduler } from 'app/constants';
|
||||||
|
|
||||||
type MetadataItemProps = {
|
type MetadataItemProps = {
|
||||||
isLink?: boolean;
|
isLink?: boolean;
|
||||||
@ -58,7 +63,6 @@ const MetadataItem = ({
|
|||||||
withCopy = false,
|
withCopy = false,
|
||||||
}: MetadataItemProps) => {
|
}: MetadataItemProps) => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex gap={2}>
|
<Flex gap={2}>
|
||||||
{onClick && (
|
{onClick && (
|
||||||
@ -104,14 +108,14 @@ const MetadataItem = ({
|
|||||||
};
|
};
|
||||||
|
|
||||||
type ImageMetadataViewerProps = {
|
type ImageMetadataViewerProps = {
|
||||||
image: InvokeAI.Image;
|
image: ImageDTO;
|
||||||
};
|
};
|
||||||
|
|
||||||
// TODO: I don't know if this is needed.
|
// TODO: I don't know if this is needed.
|
||||||
const memoEqualityCheck = (
|
const memoEqualityCheck = (
|
||||||
prev: ImageMetadataViewerProps,
|
prev: ImageMetadataViewerProps,
|
||||||
next: ImageMetadataViewerProps
|
next: ImageMetadataViewerProps
|
||||||
) => prev.image.name === next.image.name;
|
) => prev.image.image_name === next.image.image_name;
|
||||||
|
|
||||||
// TODO: Show more interesting information in this component.
|
// TODO: Show more interesting information in this component.
|
||||||
|
|
||||||
@ -128,8 +132,9 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
|||||||
dispatch(setShouldShowImageDetails(false));
|
dispatch(setShouldShowImageDetails(false));
|
||||||
});
|
});
|
||||||
|
|
||||||
const sessionId = image.metadata.invokeai?.session_id;
|
const sessionId = image?.session_id;
|
||||||
const node = image.metadata.invokeai?.node as Record<string, any>;
|
|
||||||
|
const metadata = image?.metadata;
|
||||||
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const { getUrl } = useGetUrl();
|
const { getUrl } = useGetUrl();
|
||||||
@ -154,110 +159,131 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
|||||||
>
|
>
|
||||||
<Flex gap={2}>
|
<Flex gap={2}>
|
||||||
<Text fontWeight="semibold">File:</Text>
|
<Text fontWeight="semibold">File:</Text>
|
||||||
<Link href={getUrl(image.url)} isExternal maxW="calc(100% - 3rem)">
|
<Link
|
||||||
{image.url.length > 64
|
href={getUrl(image.image_url)}
|
||||||
? image.url.substring(0, 64).concat('...')
|
isExternal
|
||||||
: image.url}
|
maxW="calc(100% - 3rem)"
|
||||||
|
>
|
||||||
|
{image.image_name}
|
||||||
<ExternalLinkIcon mx="2px" />
|
<ExternalLinkIcon mx="2px" />
|
||||||
</Link>
|
</Link>
|
||||||
</Flex>
|
</Flex>
|
||||||
{node && Object.keys(node).length > 0 ? (
|
{metadata && Object.keys(metadata).length > 0 ? (
|
||||||
<>
|
<>
|
||||||
{node.type && (
|
{metadata.type && (
|
||||||
<MetadataItem label="Invocation type" value={node.type} />
|
<MetadataItem label="Invocation type" value={metadata.type} />
|
||||||
)}
|
)}
|
||||||
{node.model && <MetadataItem label="Model" value={node.model} />}
|
{metadata.width && (
|
||||||
{node.prompt && (
|
<MetadataItem
|
||||||
|
label="Width"
|
||||||
|
value={metadata.width}
|
||||||
|
onClick={() => dispatch(setWidth(Number(metadata.width)))}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{metadata.height && (
|
||||||
|
<MetadataItem
|
||||||
|
label="Height"
|
||||||
|
value={metadata.height}
|
||||||
|
onClick={() => dispatch(setHeight(Number(metadata.height)))}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{metadata.model && (
|
||||||
|
<MetadataItem label="Model" value={metadata.model} />
|
||||||
|
)}
|
||||||
|
{metadata.positive_conditioning && (
|
||||||
<MetadataItem
|
<MetadataItem
|
||||||
label="Prompt"
|
label="Prompt"
|
||||||
labelPosition="top"
|
labelPosition="top"
|
||||||
value={
|
value={
|
||||||
typeof node.prompt === 'string'
|
typeof metadata.positive_conditioning === 'string'
|
||||||
? node.prompt
|
? metadata.positive_conditioning
|
||||||
: promptToString(node.prompt)
|
: promptToString(metadata.positive_conditioning)
|
||||||
}
|
}
|
||||||
onClick={() => setBothPrompts(node.prompt)}
|
onClick={() => setPositivePrompt(metadata.positive_conditioning!)}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{node.seed !== undefined && (
|
{metadata.negative_conditioning && (
|
||||||
|
<MetadataItem
|
||||||
|
label="Prompt"
|
||||||
|
labelPosition="top"
|
||||||
|
value={
|
||||||
|
typeof metadata.negative_conditioning === 'string'
|
||||||
|
? metadata.negative_conditioning
|
||||||
|
: promptToString(metadata.negative_conditioning)
|
||||||
|
}
|
||||||
|
onClick={() => setNegativePrompt(metadata.negative_conditioning!)}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{metadata.seed !== undefined && (
|
||||||
<MetadataItem
|
<MetadataItem
|
||||||
label="Seed"
|
label="Seed"
|
||||||
value={node.seed}
|
value={metadata.seed}
|
||||||
onClick={() => dispatch(setSeed(Number(node.seed)))}
|
onClick={() => dispatch(setSeed(Number(metadata.seed)))}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{node.threshold !== undefined && (
|
{/* {metadata.threshold !== undefined && (
|
||||||
<MetadataItem
|
<MetadataItem
|
||||||
label="Noise Threshold"
|
label="Noise Threshold"
|
||||||
value={node.threshold}
|
value={metadata.threshold}
|
||||||
onClick={() => dispatch(setThreshold(Number(node.threshold)))}
|
onClick={() => dispatch(setThreshold(Number(metadata.threshold)))}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{node.perlin !== undefined && (
|
{metadata.perlin !== undefined && (
|
||||||
<MetadataItem
|
<MetadataItem
|
||||||
label="Perlin Noise"
|
label="Perlin Noise"
|
||||||
value={node.perlin}
|
value={metadata.perlin}
|
||||||
onClick={() => dispatch(setPerlin(Number(node.perlin)))}
|
onClick={() => dispatch(setPerlin(Number(metadata.perlin)))}
|
||||||
/>
|
/>
|
||||||
)}
|
)} */}
|
||||||
{node.scheduler && (
|
{metadata.scheduler && (
|
||||||
<MetadataItem
|
<MetadataItem
|
||||||
label="Scheduler"
|
label="Scheduler"
|
||||||
value={node.scheduler}
|
value={metadata.scheduler}
|
||||||
onClick={() => dispatch(setScheduler(node.scheduler))}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
{node.steps && (
|
|
||||||
<MetadataItem
|
|
||||||
label="Steps"
|
|
||||||
value={node.steps}
|
|
||||||
onClick={() => dispatch(setSteps(Number(node.steps)))}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
{node.cfg_scale !== undefined && (
|
|
||||||
<MetadataItem
|
|
||||||
label="CFG scale"
|
|
||||||
value={node.cfg_scale}
|
|
||||||
onClick={() => dispatch(setCfgScale(Number(node.cfg_scale)))}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
{node.variations && node.variations.length > 0 && (
|
|
||||||
<MetadataItem
|
|
||||||
label="Seed-weight pairs"
|
|
||||||
value={seedWeightsToString(node.variations)}
|
|
||||||
onClick={() =>
|
onClick={() =>
|
||||||
dispatch(setSeedWeights(seedWeightsToString(node.variations)))
|
dispatch(setScheduler(metadata.scheduler as Scheduler))
|
||||||
}
|
}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{node.seamless && (
|
{metadata.steps && (
|
||||||
|
<MetadataItem
|
||||||
|
label="Steps"
|
||||||
|
value={metadata.steps}
|
||||||
|
onClick={() => dispatch(setSteps(Number(metadata.steps)))}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{metadata.cfg_scale !== undefined && (
|
||||||
|
<MetadataItem
|
||||||
|
label="CFG scale"
|
||||||
|
value={metadata.cfg_scale}
|
||||||
|
onClick={() => dispatch(setCfgScale(Number(metadata.cfg_scale)))}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{/* {metadata.variations && metadata.variations.length > 0 && (
|
||||||
|
<MetadataItem
|
||||||
|
label="Seed-weight pairs"
|
||||||
|
value={seedWeightsToString(metadata.variations)}
|
||||||
|
onClick={() =>
|
||||||
|
dispatch(
|
||||||
|
setSeedWeights(seedWeightsToString(metadata.variations))
|
||||||
|
)
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{metadata.seamless && (
|
||||||
<MetadataItem
|
<MetadataItem
|
||||||
label="Seamless"
|
label="Seamless"
|
||||||
value={node.seamless}
|
value={metadata.seamless}
|
||||||
onClick={() => dispatch(setSeamless(node.seamless))}
|
onClick={() => dispatch(setSeamless(metadata.seamless))}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{node.hires_fix && (
|
{metadata.hires_fix && (
|
||||||
<MetadataItem
|
<MetadataItem
|
||||||
label="High Resolution Optimization"
|
label="High Resolution Optimization"
|
||||||
value={node.hires_fix}
|
value={metadata.hires_fix}
|
||||||
onClick={() => dispatch(setHiresFix(node.hires_fix))}
|
onClick={() => dispatch(setHiresFix(metadata.hires_fix))}
|
||||||
/>
|
/>
|
||||||
)}
|
)} */}
|
||||||
{node.width && (
|
|
||||||
<MetadataItem
|
|
||||||
label="Width"
|
|
||||||
value={node.width}
|
|
||||||
onClick={() => dispatch(setWidth(Number(node.width)))}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
{node.height && (
|
|
||||||
<MetadataItem
|
|
||||||
label="Height"
|
|
||||||
value={node.height}
|
|
||||||
onClick={() => dispatch(setHeight(Number(node.height)))}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
{/* {init_image_path && (
|
{/* {init_image_path && (
|
||||||
<MetadataItem
|
<MetadataItem
|
||||||
label="Initial image"
|
label="Initial image"
|
||||||
@ -266,22 +292,22 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
|||||||
onClick={() => dispatch(setInitialImage(init_image_path))}
|
onClick={() => dispatch(setInitialImage(init_image_path))}
|
||||||
/>
|
/>
|
||||||
)} */}
|
)} */}
|
||||||
{node.strength && (
|
{metadata.strength && (
|
||||||
<MetadataItem
|
<MetadataItem
|
||||||
label="Image to image strength"
|
label="Image to image strength"
|
||||||
value={node.strength}
|
value={metadata.strength}
|
||||||
onClick={() =>
|
onClick={() =>
|
||||||
dispatch(setImg2imgStrength(Number(node.strength)))
|
dispatch(setImg2imgStrength(Number(metadata.strength)))
|
||||||
}
|
}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{node.fit && (
|
{/* {metadata.fit && (
|
||||||
<MetadataItem
|
<MetadataItem
|
||||||
label="Image to image fit"
|
label="Image to image fit"
|
||||||
value={node.fit}
|
value={metadata.fit}
|
||||||
onClick={() => dispatch(setShouldFitToWidthHeight(node.fit))}
|
onClick={() => dispatch(setShouldFitToWidthHeight(metadata.fit))}
|
||||||
/>
|
/>
|
||||||
)}
|
)} */}
|
||||||
</>
|
</>
|
||||||
) : (
|
) : (
|
||||||
<Center width="100%" pt={10}>
|
<Center width="100%" pt={10}>
|
||||||
|
@ -1,470 +0,0 @@
|
|||||||
import { ExternalLinkIcon } from '@chakra-ui/icons';
|
|
||||||
import {
|
|
||||||
Box,
|
|
||||||
Center,
|
|
||||||
Flex,
|
|
||||||
Heading,
|
|
||||||
IconButton,
|
|
||||||
Link,
|
|
||||||
Text,
|
|
||||||
Tooltip,
|
|
||||||
} from '@chakra-ui/react';
|
|
||||||
import * as InvokeAI from 'app/types/invokeai';
|
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
|
||||||
import { useGetUrl } from 'common/util/getUrl';
|
|
||||||
import promptToString from 'common/util/promptToString';
|
|
||||||
import { seedWeightsToString } from 'common/util/seedWeightPairs';
|
|
||||||
import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
|
|
||||||
import {
|
|
||||||
setCfgScale,
|
|
||||||
setHeight,
|
|
||||||
setImg2imgStrength,
|
|
||||||
// setInitialImage,
|
|
||||||
setMaskPath,
|
|
||||||
setPerlin,
|
|
||||||
setSampler,
|
|
||||||
setSeamless,
|
|
||||||
setSeed,
|
|
||||||
setSeedWeights,
|
|
||||||
setShouldFitToWidthHeight,
|
|
||||||
setSteps,
|
|
||||||
setThreshold,
|
|
||||||
setWidth,
|
|
||||||
} from 'features/parameters/store/generationSlice';
|
|
||||||
import {
|
|
||||||
setCodeformerFidelity,
|
|
||||||
setFacetoolStrength,
|
|
||||||
setFacetoolType,
|
|
||||||
setHiresFix,
|
|
||||||
setUpscalingDenoising,
|
|
||||||
setUpscalingLevel,
|
|
||||||
setUpscalingStrength,
|
|
||||||
} from 'features/parameters/store/postprocessingSlice';
|
|
||||||
import { setShouldShowImageDetails } from 'features/ui/store/uiSlice';
|
|
||||||
import { memo } from 'react';
|
|
||||||
import { useHotkeys } from 'react-hotkeys-hook';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import { FaCopy } from 'react-icons/fa';
|
|
||||||
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
|
|
||||||
import * as png from '@stevebel/png';
|
|
||||||
|
|
||||||
type MetadataItemProps = {
|
|
||||||
isLink?: boolean;
|
|
||||||
label: string;
|
|
||||||
onClick?: () => void;
|
|
||||||
value: number | string | boolean;
|
|
||||||
labelPosition?: string;
|
|
||||||
withCopy?: boolean;
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Component to display an individual metadata item or parameter.
|
|
||||||
*/
|
|
||||||
const MetadataItem = ({
|
|
||||||
label,
|
|
||||||
value,
|
|
||||||
onClick,
|
|
||||||
isLink,
|
|
||||||
labelPosition,
|
|
||||||
withCopy = false,
|
|
||||||
}: MetadataItemProps) => {
|
|
||||||
const { t } = useTranslation();
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Flex gap={2}>
|
|
||||||
{onClick && (
|
|
||||||
<Tooltip label={`Recall ${label}`}>
|
|
||||||
<IconButton
|
|
||||||
aria-label={t('accessibility.useThisParameter')}
|
|
||||||
icon={<IoArrowUndoCircleOutline />}
|
|
||||||
size="xs"
|
|
||||||
variant="ghost"
|
|
||||||
fontSize={20}
|
|
||||||
onClick={onClick}
|
|
||||||
/>
|
|
||||||
</Tooltip>
|
|
||||||
)}
|
|
||||||
{withCopy && (
|
|
||||||
<Tooltip label={`Copy ${label}`}>
|
|
||||||
<IconButton
|
|
||||||
aria-label={`Copy ${label}`}
|
|
||||||
icon={<FaCopy />}
|
|
||||||
size="xs"
|
|
||||||
variant="ghost"
|
|
||||||
fontSize={14}
|
|
||||||
onClick={() => navigator.clipboard.writeText(value.toString())}
|
|
||||||
/>
|
|
||||||
</Tooltip>
|
|
||||||
)}
|
|
||||||
<Flex direction={labelPosition ? 'column' : 'row'}>
|
|
||||||
<Text fontWeight="semibold" whiteSpace="pre-wrap" pr={2}>
|
|
||||||
{label}:
|
|
||||||
</Text>
|
|
||||||
{isLink ? (
|
|
||||||
<Link href={value.toString()} isExternal wordBreak="break-all">
|
|
||||||
{value.toString()} <ExternalLinkIcon mx="2px" />
|
|
||||||
</Link>
|
|
||||||
) : (
|
|
||||||
<Text overflowY="scroll" wordBreak="break-all">
|
|
||||||
{value.toString()}
|
|
||||||
</Text>
|
|
||||||
)}
|
|
||||||
</Flex>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
type ImageMetadataViewerProps = {
|
|
||||||
image: InvokeAI.Image;
|
|
||||||
};
|
|
||||||
|
|
||||||
// TODO: I don't know if this is needed.
|
|
||||||
const memoEqualityCheck = (
|
|
||||||
prev: ImageMetadataViewerProps,
|
|
||||||
next: ImageMetadataViewerProps
|
|
||||||
) => prev.image.name === next.image.name;
|
|
||||||
|
|
||||||
// TODO: Show more interesting information in this component.
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Image metadata viewer overlays currently selected image and provides
|
|
||||||
* access to any of its metadata for use in processing.
|
|
||||||
*/
|
|
||||||
const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
|
|
||||||
const setBothPrompts = useSetBothPrompts();
|
|
||||||
|
|
||||||
useHotkeys('esc', () => {
|
|
||||||
dispatch(setShouldShowImageDetails(false));
|
|
||||||
});
|
|
||||||
|
|
||||||
const metadata = image?.metadata.sd_metadata || {};
|
|
||||||
const dreamPrompt = image?.metadata.sd_metadata?.dreamPrompt;
|
|
||||||
|
|
||||||
const {
|
|
||||||
cfg_scale,
|
|
||||||
fit,
|
|
||||||
height,
|
|
||||||
hires_fix,
|
|
||||||
init_image_path,
|
|
||||||
mask_image_path,
|
|
||||||
orig_path,
|
|
||||||
perlin,
|
|
||||||
postprocessing,
|
|
||||||
prompt,
|
|
||||||
sampler,
|
|
||||||
seamless,
|
|
||||||
seed,
|
|
||||||
steps,
|
|
||||||
strength,
|
|
||||||
threshold,
|
|
||||||
type,
|
|
||||||
variations,
|
|
||||||
width,
|
|
||||||
model_weights,
|
|
||||||
} = metadata;
|
|
||||||
|
|
||||||
const { t } = useTranslation();
|
|
||||||
const { getUrl } = useGetUrl();
|
|
||||||
|
|
||||||
const metadataJSON = JSON.stringify(image, null, 2);
|
|
||||||
|
|
||||||
// fetch(getUrl(image.url))
|
|
||||||
// .then((r) => r.arrayBuffer())
|
|
||||||
// .then((buffer) => {
|
|
||||||
// const { text } = png.decode(buffer);
|
|
||||||
// const metadata = text?.['sd-metadata']
|
|
||||||
// ? JSON.parse(text['sd-metadata'] ?? {})
|
|
||||||
// : {};
|
|
||||||
// console.log(metadata);
|
|
||||||
// });
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Flex
|
|
||||||
sx={{
|
|
||||||
padding: 4,
|
|
||||||
gap: 1,
|
|
||||||
flexDirection: 'column',
|
|
||||||
width: 'full',
|
|
||||||
height: 'full',
|
|
||||||
backdropFilter: 'blur(20px)',
|
|
||||||
bg: 'whiteAlpha.600',
|
|
||||||
_dark: {
|
|
||||||
bg: 'blackAlpha.600',
|
|
||||||
},
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<Flex gap={2}>
|
|
||||||
<Text fontWeight="semibold">File:</Text>
|
|
||||||
<Link href={getUrl(image.url)} isExternal maxW="calc(100% - 3rem)">
|
|
||||||
{image.url.length > 64
|
|
||||||
? image.url.substring(0, 64).concat('...')
|
|
||||||
: image.url}
|
|
||||||
<ExternalLinkIcon mx="2px" />
|
|
||||||
</Link>
|
|
||||||
</Flex>
|
|
||||||
<Flex gap={2} direction="column">
|
|
||||||
<Flex gap={2}>
|
|
||||||
<Tooltip label="Copy metadata JSON">
|
|
||||||
<IconButton
|
|
||||||
aria-label={t('accessibility.copyMetadataJson')}
|
|
||||||
icon={<FaCopy />}
|
|
||||||
size="xs"
|
|
||||||
variant="ghost"
|
|
||||||
fontSize={14}
|
|
||||||
onClick={() => navigator.clipboard.writeText(metadataJSON)}
|
|
||||||
/>
|
|
||||||
</Tooltip>
|
|
||||||
<Text fontWeight="semibold">Metadata JSON:</Text>
|
|
||||||
</Flex>
|
|
||||||
<Box
|
|
||||||
sx={{
|
|
||||||
mt: 0,
|
|
||||||
mr: 2,
|
|
||||||
mb: 4,
|
|
||||||
ml: 2,
|
|
||||||
padding: 4,
|
|
||||||
borderRadius: 'base',
|
|
||||||
overflowX: 'scroll',
|
|
||||||
wordBreak: 'break-all',
|
|
||||||
bg: 'whiteAlpha.500',
|
|
||||||
_dark: { bg: 'blackAlpha.500' },
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
<pre>{metadataJSON}</pre>
|
|
||||||
</Box>
|
|
||||||
</Flex>
|
|
||||||
{Object.keys(metadata).length > 0 ? (
|
|
||||||
<>
|
|
||||||
{type && <MetadataItem label="Generation type" value={type} />}
|
|
||||||
{model_weights && (
|
|
||||||
<MetadataItem label="Model" value={model_weights} />
|
|
||||||
)}
|
|
||||||
{['esrgan', 'gfpgan'].includes(type) && (
|
|
||||||
<MetadataItem label="Original image" value={orig_path} />
|
|
||||||
)}
|
|
||||||
{prompt && (
|
|
||||||
<MetadataItem
|
|
||||||
label="Prompt"
|
|
||||||
labelPosition="top"
|
|
||||||
value={
|
|
||||||
typeof prompt === 'string' ? prompt : promptToString(prompt)
|
|
||||||
}
|
|
||||||
onClick={() => setBothPrompts(prompt)}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
{seed !== undefined && (
|
|
||||||
<MetadataItem
|
|
||||||
label="Seed"
|
|
||||||
value={seed}
|
|
||||||
onClick={() => dispatch(setSeed(seed))}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
{threshold !== undefined && (
|
|
||||||
<MetadataItem
|
|
||||||
label="Noise Threshold"
|
|
||||||
value={threshold}
|
|
||||||
onClick={() => dispatch(setThreshold(threshold))}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
{perlin !== undefined && (
|
|
||||||
<MetadataItem
|
|
||||||
label="Perlin Noise"
|
|
||||||
value={perlin}
|
|
||||||
onClick={() => dispatch(setPerlin(perlin))}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
{sampler && (
|
|
||||||
<MetadataItem
|
|
||||||
label="Sampler"
|
|
||||||
value={sampler}
|
|
||||||
onClick={() => dispatch(setSampler(sampler))}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
{steps && (
|
|
||||||
<MetadataItem
|
|
||||||
label="Steps"
|
|
||||||
value={steps}
|
|
||||||
onClick={() => dispatch(setSteps(steps))}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
{cfg_scale !== undefined && (
|
|
||||||
<MetadataItem
|
|
||||||
label="CFG scale"
|
|
||||||
value={cfg_scale}
|
|
||||||
onClick={() => dispatch(setCfgScale(cfg_scale))}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
{variations && variations.length > 0 && (
|
|
||||||
<MetadataItem
|
|
||||||
label="Seed-weight pairs"
|
|
||||||
value={seedWeightsToString(variations)}
|
|
||||||
onClick={() =>
|
|
||||||
dispatch(setSeedWeights(seedWeightsToString(variations)))
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
{seamless && (
|
|
||||||
<MetadataItem
|
|
||||||
label="Seamless"
|
|
||||||
value={seamless}
|
|
||||||
onClick={() => dispatch(setSeamless(seamless))}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
{hires_fix && (
|
|
||||||
<MetadataItem
|
|
||||||
label="High Resolution Optimization"
|
|
||||||
value={hires_fix}
|
|
||||||
onClick={() => dispatch(setHiresFix(hires_fix))}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
{width && (
|
|
||||||
<MetadataItem
|
|
||||||
label="Width"
|
|
||||||
value={width}
|
|
||||||
onClick={() => dispatch(setWidth(width))}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
{height && (
|
|
||||||
<MetadataItem
|
|
||||||
label="Height"
|
|
||||||
value={height}
|
|
||||||
onClick={() => dispatch(setHeight(height))}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
{/* {init_image_path && (
|
|
||||||
<MetadataItem
|
|
||||||
label="Initial image"
|
|
||||||
value={init_image_path}
|
|
||||||
isLink
|
|
||||||
onClick={() => dispatch(setInitialImage(init_image_path))}
|
|
||||||
/>
|
|
||||||
)} */}
|
|
||||||
{mask_image_path && (
|
|
||||||
<MetadataItem
|
|
||||||
label="Mask image"
|
|
||||||
value={mask_image_path}
|
|
||||||
isLink
|
|
||||||
onClick={() => dispatch(setMaskPath(mask_image_path))}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
{type === 'img2img' && strength && (
|
|
||||||
<MetadataItem
|
|
||||||
label="Image to image strength"
|
|
||||||
value={strength}
|
|
||||||
onClick={() => dispatch(setImg2imgStrength(strength))}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
{fit && (
|
|
||||||
<MetadataItem
|
|
||||||
label="Image to image fit"
|
|
||||||
value={fit}
|
|
||||||
onClick={() => dispatch(setShouldFitToWidthHeight(fit))}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
{postprocessing && postprocessing.length > 0 && (
|
|
||||||
<>
|
|
||||||
<Heading size="sm">Postprocessing</Heading>
|
|
||||||
{postprocessing.map(
|
|
||||||
(
|
|
||||||
postprocess: InvokeAI.PostProcessedImageMetadata,
|
|
||||||
i: number
|
|
||||||
) => {
|
|
||||||
if (postprocess.type === 'esrgan') {
|
|
||||||
const { scale, strength, denoise_str } = postprocess;
|
|
||||||
return (
|
|
||||||
<Flex key={i} pl={8} gap={1} direction="column">
|
|
||||||
<Text size="md">{`${i + 1}: Upscale (ESRGAN)`}</Text>
|
|
||||||
<MetadataItem
|
|
||||||
label="Scale"
|
|
||||||
value={scale}
|
|
||||||
onClick={() => dispatch(setUpscalingLevel(scale))}
|
|
||||||
/>
|
|
||||||
<MetadataItem
|
|
||||||
label="Strength"
|
|
||||||
value={strength}
|
|
||||||
onClick={() =>
|
|
||||||
dispatch(setUpscalingStrength(strength))
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
{denoise_str !== undefined && (
|
|
||||||
<MetadataItem
|
|
||||||
label="Denoising strength"
|
|
||||||
value={denoise_str}
|
|
||||||
onClick={() =>
|
|
||||||
dispatch(setUpscalingDenoising(denoise_str))
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
} else if (postprocess.type === 'gfpgan') {
|
|
||||||
const { strength } = postprocess;
|
|
||||||
return (
|
|
||||||
<Flex key={i} pl={8} gap={1} direction="column">
|
|
||||||
<Text size="md">{`${
|
|
||||||
i + 1
|
|
||||||
}: Face restoration (GFPGAN)`}</Text>
|
|
||||||
|
|
||||||
<MetadataItem
|
|
||||||
label="Strength"
|
|
||||||
value={strength}
|
|
||||||
onClick={() => {
|
|
||||||
dispatch(setFacetoolStrength(strength));
|
|
||||||
dispatch(setFacetoolType('gfpgan'));
|
|
||||||
}}
|
|
||||||
/>
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
} else if (postprocess.type === 'codeformer') {
|
|
||||||
const { strength, fidelity } = postprocess;
|
|
||||||
return (
|
|
||||||
<Flex key={i} pl={8} gap={1} direction="column">
|
|
||||||
<Text size="md">{`${
|
|
||||||
i + 1
|
|
||||||
}: Face restoration (Codeformer)`}</Text>
|
|
||||||
|
|
||||||
<MetadataItem
|
|
||||||
label="Strength"
|
|
||||||
value={strength}
|
|
||||||
onClick={() => {
|
|
||||||
dispatch(setFacetoolStrength(strength));
|
|
||||||
dispatch(setFacetoolType('codeformer'));
|
|
||||||
}}
|
|
||||||
/>
|
|
||||||
{fidelity && (
|
|
||||||
<MetadataItem
|
|
||||||
label="Fidelity"
|
|
||||||
value={fidelity}
|
|
||||||
onClick={() => {
|
|
||||||
dispatch(setCodeformerFidelity(fidelity));
|
|
||||||
dispatch(setFacetoolType('codeformer'));
|
|
||||||
}}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)}
|
|
||||||
</>
|
|
||||||
)}
|
|
||||||
{dreamPrompt && (
|
|
||||||
<MetadataItem withCopy label="Dream Prompt" value={dreamPrompt} />
|
|
||||||
)}
|
|
||||||
</>
|
|
||||||
) : (
|
|
||||||
<Center width="100%" pt={10}>
|
|
||||||
<Text fontSize="lg" fontWeight="semibold">
|
|
||||||
No metadata available
|
|
||||||
</Text>
|
|
||||||
</Center>
|
|
||||||
)}
|
|
||||||
</Flex>
|
|
||||||
);
|
|
||||||
}, memoEqualityCheck);
|
|
||||||
|
|
||||||
ImageMetadataViewer.displayName = 'ImageMetadataViewer';
|
|
||||||
|
|
||||||
export default ImageMetadataViewer;
|
|
@ -13,11 +13,9 @@ const useGetImageByNameSelector = createSelector(
|
|||||||
|
|
||||||
const useGetImageByNameAndType = () => {
|
const useGetImageByNameAndType = () => {
|
||||||
const { allResults, allUploads } = useAppSelector(useGetImageByNameSelector);
|
const { allResults, allUploads } = useAppSelector(useGetImageByNameSelector);
|
||||||
|
|
||||||
return (name: string, type: ImageType) => {
|
return (name: string, type: ImageType) => {
|
||||||
if (type === 'results') {
|
if (type === 'results') {
|
||||||
const resultImagesResult = allResults[name];
|
const resultImagesResult = allResults[name];
|
||||||
|
|
||||||
if (resultImagesResult) {
|
if (resultImagesResult) {
|
||||||
return resultImagesResult;
|
return resultImagesResult;
|
||||||
}
|
}
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
import { createAction } from '@reduxjs/toolkit';
|
import { createAction } from '@reduxjs/toolkit';
|
||||||
import { Image } from 'app/types/invokeai';
|
import { ImageNameAndType } from 'features/parameters/store/actions';
|
||||||
import { SelectedImage } from 'features/parameters/store/actions';
|
import { ImageDTO } from 'services/api';
|
||||||
|
|
||||||
export const requestedImageDeletion = createAction<
|
export const requestedImageDeletion = createAction<
|
||||||
Image | SelectedImage | undefined
|
ImageDTO | ImageNameAndType | undefined
|
||||||
>('gallery/requestedImageDeletion');
|
>('gallery/requestedImageDeletion');
|
||||||
|
|
||||||
export const sentImageToCanvas = createAction('gallery/sentImageToCanvas');
|
export const sentImageToCanvas = createAction('gallery/sentImageToCanvas');
|
||||||
|
@ -1,16 +1,15 @@
|
|||||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||||
import { createSlice } from '@reduxjs/toolkit';
|
import { createSlice } from '@reduxjs/toolkit';
|
||||||
import { Image } from 'app/types/invokeai';
|
|
||||||
import { imageReceived, thumbnailReceived } from 'services/thunks/image';
|
|
||||||
import {
|
import {
|
||||||
receivedResultImagesPage,
|
receivedResultImagesPage,
|
||||||
receivedUploadImagesPage,
|
receivedUploadImagesPage,
|
||||||
} from '../../../services/thunks/gallery';
|
} from '../../../services/thunks/gallery';
|
||||||
|
import { ImageDTO } from 'services/api';
|
||||||
|
|
||||||
type GalleryImageObjectFitType = 'contain' | 'cover';
|
type GalleryImageObjectFitType = 'contain' | 'cover';
|
||||||
|
|
||||||
export interface GalleryState {
|
export interface GalleryState {
|
||||||
selectedImage?: Image;
|
selectedImage?: ImageDTO;
|
||||||
galleryImageMinimumWidth: number;
|
galleryImageMinimumWidth: number;
|
||||||
galleryImageObjectFit: GalleryImageObjectFitType;
|
galleryImageObjectFit: GalleryImageObjectFitType;
|
||||||
shouldAutoSwitchToNewImages: boolean;
|
shouldAutoSwitchToNewImages: boolean;
|
||||||
@ -30,7 +29,7 @@ export const gallerySlice = createSlice({
|
|||||||
name: 'gallery',
|
name: 'gallery',
|
||||||
initialState: initialGalleryState,
|
initialState: initialGalleryState,
|
||||||
reducers: {
|
reducers: {
|
||||||
imageSelected: (state, action: PayloadAction<Image | undefined>) => {
|
imageSelected: (state, action: PayloadAction<ImageDTO | undefined>) => {
|
||||||
state.selectedImage = action.payload;
|
state.selectedImage = action.payload;
|
||||||
// TODO: if the user selects an image, disable the auto switch?
|
// TODO: if the user selects an image, disable the auto switch?
|
||||||
// state.shouldAutoSwitchToNewImages = false;
|
// state.shouldAutoSwitchToNewImages = false;
|
||||||
@ -61,37 +60,18 @@ export const gallerySlice = createSlice({
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
extraReducers(builder) {
|
extraReducers(builder) {
|
||||||
builder.addCase(imageReceived.fulfilled, (state, action) => {
|
|
||||||
// When we get an updated URL for an image, we need to update the selectedImage in gallery,
|
|
||||||
// which is currently its own object (instead of a reference to an image in results/uploads)
|
|
||||||
const { imagePath } = action.payload;
|
|
||||||
const { imageName } = action.meta.arg;
|
|
||||||
|
|
||||||
if (state.selectedImage?.name === imageName) {
|
|
||||||
state.selectedImage.url = imagePath;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
builder.addCase(thumbnailReceived.fulfilled, (state, action) => {
|
|
||||||
// When we get an updated URL for an image, we need to update the selectedImage in gallery,
|
|
||||||
// which is currently its own object (instead of a reference to an image in results/uploads)
|
|
||||||
const { thumbnailPath } = action.payload;
|
|
||||||
const { thumbnailName } = action.meta.arg;
|
|
||||||
|
|
||||||
if (state.selectedImage?.name === thumbnailName) {
|
|
||||||
state.selectedImage.thumbnail = thumbnailPath;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => {
|
builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => {
|
||||||
// rehydrate selectedImage URL when results list comes in
|
// rehydrate selectedImage URL when results list comes in
|
||||||
// solves case when outdated URL is in local storage
|
// solves case when outdated URL is in local storage
|
||||||
const selectedImage = state.selectedImage;
|
const selectedImage = state.selectedImage;
|
||||||
if (selectedImage) {
|
if (selectedImage) {
|
||||||
const selectedImageInResults = action.payload.items.find(
|
const selectedImageInResults = action.payload.items.find(
|
||||||
(image) => image.image_name === selectedImage.name
|
(image) => image.image_name === selectedImage.image_name
|
||||||
);
|
);
|
||||||
|
|
||||||
if (selectedImageInResults) {
|
if (selectedImageInResults) {
|
||||||
selectedImage.url = selectedImageInResults.image_url;
|
selectedImage.image_url = selectedImageInResults.image_url;
|
||||||
|
selectedImage.thumbnail_url = selectedImageInResults.thumbnail_url;
|
||||||
state.selectedImage = selectedImage;
|
state.selectedImage = selectedImage;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -102,10 +82,12 @@ export const gallerySlice = createSlice({
|
|||||||
const selectedImage = state.selectedImage;
|
const selectedImage = state.selectedImage;
|
||||||
if (selectedImage) {
|
if (selectedImage) {
|
||||||
const selectedImageInResults = action.payload.items.find(
|
const selectedImageInResults = action.payload.items.find(
|
||||||
(image) => image.image_name === selectedImage.name
|
(image) => image.image_name === selectedImage.image_name
|
||||||
);
|
);
|
||||||
|
|
||||||
if (selectedImageInResults) {
|
if (selectedImageInResults) {
|
||||||
selectedImage.url = selectedImageInResults.image_url;
|
selectedImage.image_url = selectedImageInResults.image_url;
|
||||||
|
selectedImage.thumbnail_url = selectedImageInResults.thumbnail_url;
|
||||||
state.selectedImage = selectedImage;
|
state.selectedImage = selectedImage;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,21 +1,24 @@
|
|||||||
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
|
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
|
||||||
import { Image } from 'app/types/invokeai';
|
|
||||||
|
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import {
|
import {
|
||||||
receivedResultImagesPage,
|
receivedResultImagesPage,
|
||||||
IMAGES_PER_PAGE,
|
IMAGES_PER_PAGE,
|
||||||
} from 'services/thunks/gallery';
|
} from 'services/thunks/gallery';
|
||||||
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
|
|
||||||
import {
|
import {
|
||||||
imageDeleted,
|
imageDeleted,
|
||||||
imageReceived,
|
imageMetadataReceived,
|
||||||
thumbnailReceived,
|
imageUrlsReceived,
|
||||||
} from 'services/thunks/image';
|
} from 'services/thunks/image';
|
||||||
|
import { ImageDTO } from 'services/api';
|
||||||
|
import { dateComparator } from 'common/util/dateComparator';
|
||||||
|
|
||||||
export const resultsAdapter = createEntityAdapter<Image>({
|
export type ResultsImageDTO = Omit<ImageDTO, 'image_type'> & {
|
||||||
selectId: (image) => image.name,
|
image_type: 'results';
|
||||||
sortComparer: (a, b) => b.metadata.created - a.metadata.created,
|
};
|
||||||
|
|
||||||
|
export const resultsAdapter = createEntityAdapter<ResultsImageDTO>({
|
||||||
|
selectId: (image) => image.image_name,
|
||||||
|
sortComparer: (a, b) => dateComparator(b.created_at, a.created_at),
|
||||||
});
|
});
|
||||||
|
|
||||||
type AdditionalResultsState = {
|
type AdditionalResultsState = {
|
||||||
@ -53,13 +56,12 @@ const resultsSlice = createSlice({
|
|||||||
* Received Result Images Page - FULFILLED
|
* Received Result Images Page - FULFILLED
|
||||||
*/
|
*/
|
||||||
builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => {
|
builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => {
|
||||||
const { items, page, pages } = action.payload;
|
const { page, pages } = action.payload;
|
||||||
|
|
||||||
const resultImages = items.map((image) =>
|
// We know these will all be of the results type, but it's not represented in the API types
|
||||||
deserializeImageResponse(image)
|
const items = action.payload.items as ResultsImageDTO[];
|
||||||
);
|
|
||||||
|
|
||||||
resultsAdapter.setMany(state, resultImages);
|
resultsAdapter.setMany(state, items);
|
||||||
|
|
||||||
state.page = page;
|
state.page = page;
|
||||||
state.pages = pages;
|
state.pages = pages;
|
||||||
@ -68,33 +70,32 @@ const resultsSlice = createSlice({
|
|||||||
});
|
});
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Image Received - FULFILLED
|
* Image Metadata Received - FULFILLED
|
||||||
*/
|
*/
|
||||||
builder.addCase(imageReceived.fulfilled, (state, action) => {
|
builder.addCase(imageMetadataReceived.fulfilled, (state, action) => {
|
||||||
const { imagePath } = action.payload;
|
const { image_type } = action.payload;
|
||||||
const { imageName } = action.meta.arg;
|
|
||||||
|
|
||||||
resultsAdapter.updateOne(state, {
|
if (image_type === 'results') {
|
||||||
id: imageName,
|
resultsAdapter.upsertOne(state, action.payload as ResultsImageDTO);
|
||||||
changes: {
|
}
|
||||||
url: imagePath,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
});
|
});
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Thumbnail Received - FULFILLED
|
* Image URLs Received - FULFILLED
|
||||||
*/
|
*/
|
||||||
builder.addCase(thumbnailReceived.fulfilled, (state, action) => {
|
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
||||||
const { thumbnailPath } = action.payload;
|
const { image_name, image_type, image_url, thumbnail_url } =
|
||||||
const { thumbnailName } = action.meta.arg;
|
action.payload;
|
||||||
|
|
||||||
resultsAdapter.updateOne(state, {
|
if (image_type === 'results') {
|
||||||
id: thumbnailName,
|
resultsAdapter.updateOne(state, {
|
||||||
changes: {
|
id: image_name,
|
||||||
thumbnail: thumbnailPath,
|
changes: {
|
||||||
},
|
image_url: image_url,
|
||||||
});
|
thumbnail_url: thumbnail_url,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -1,17 +1,21 @@
|
|||||||
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
|
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
|
||||||
import { Image } from 'app/types/invokeai';
|
|
||||||
|
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import {
|
import {
|
||||||
receivedUploadImagesPage,
|
receivedUploadImagesPage,
|
||||||
IMAGES_PER_PAGE,
|
IMAGES_PER_PAGE,
|
||||||
} from 'services/thunks/gallery';
|
} from 'services/thunks/gallery';
|
||||||
import { imageDeleted } from 'services/thunks/image';
|
import { imageDeleted, imageUrlsReceived } from 'services/thunks/image';
|
||||||
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
|
import { ImageDTO } from 'services/api';
|
||||||
|
import { dateComparator } from 'common/util/dateComparator';
|
||||||
|
|
||||||
export const uploadsAdapter = createEntityAdapter<Image>({
|
export type UploadsImageDTO = Omit<ImageDTO, 'image_type'> & {
|
||||||
selectId: (image) => image.name,
|
image_type: 'uploads';
|
||||||
sortComparer: (a, b) => b.metadata.created - a.metadata.created,
|
};
|
||||||
|
|
||||||
|
export const uploadsAdapter = createEntityAdapter<UploadsImageDTO>({
|
||||||
|
selectId: (image) => image.image_name,
|
||||||
|
sortComparer: (a, b) => dateComparator(b.created_at, a.created_at),
|
||||||
});
|
});
|
||||||
|
|
||||||
type AdditionalUploadsState = {
|
type AdditionalUploadsState = {
|
||||||
@ -49,11 +53,12 @@ const uploadsSlice = createSlice({
|
|||||||
* Received Upload Images Page - FULFILLED
|
* Received Upload Images Page - FULFILLED
|
||||||
*/
|
*/
|
||||||
builder.addCase(receivedUploadImagesPage.fulfilled, (state, action) => {
|
builder.addCase(receivedUploadImagesPage.fulfilled, (state, action) => {
|
||||||
const { items, page, pages } = action.payload;
|
const { page, pages } = action.payload;
|
||||||
|
|
||||||
const images = items.map((image) => deserializeImageResponse(image));
|
// We know these will all be of the uploads type, but it's not represented in the API types
|
||||||
|
const items = action.payload.items as UploadsImageDTO[];
|
||||||
|
|
||||||
uploadsAdapter.setMany(state, images);
|
uploadsAdapter.setMany(state, items);
|
||||||
|
|
||||||
state.page = page;
|
state.page = page;
|
||||||
state.pages = pages;
|
state.pages = pages;
|
||||||
@ -61,6 +66,24 @@ const uploadsSlice = createSlice({
|
|||||||
state.isLoading = false;
|
state.isLoading = false;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Image URLs Received - FULFILLED
|
||||||
|
*/
|
||||||
|
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
||||||
|
const { image_name, image_type, image_url, thumbnail_url } =
|
||||||
|
action.payload;
|
||||||
|
|
||||||
|
if (image_type === 'uploads') {
|
||||||
|
uploadsAdapter.updateOne(state, {
|
||||||
|
id: image_name,
|
||||||
|
changes: {
|
||||||
|
image_url: image_url,
|
||||||
|
thumbnail_url: thumbnail_url,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Delete Image - pending
|
* Delete Image - pending
|
||||||
* Pre-emptively remove the image from the gallery
|
* Pre-emptively remove the image from the gallery
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
import * as React from 'react';
|
import * as React from 'react';
|
||||||
import { TransformComponent, useTransformContext } from 'react-zoom-pan-pinch';
|
import { TransformComponent, useTransformContext } from 'react-zoom-pan-pinch';
|
||||||
import * as InvokeAI from 'app/types/invokeai';
|
|
||||||
import { useGetUrl } from 'common/util/getUrl';
|
import { useGetUrl } from 'common/util/getUrl';
|
||||||
|
import { ImageDTO } from 'services/api';
|
||||||
|
|
||||||
type ReactPanZoomProps = {
|
type ReactPanZoomProps = {
|
||||||
image: InvokeAI.Image;
|
image: ImageDTO;
|
||||||
styleClass?: string;
|
styleClass?: string;
|
||||||
alt?: string;
|
alt?: string;
|
||||||
ref?: React.Ref<HTMLImageElement>;
|
ref?: React.Ref<HTMLImageElement>;
|
||||||
@ -37,7 +37,7 @@ export default function ReactPanZoomImage({
|
|||||||
transform: `rotate(${rotation}deg) scaleX(${scaleX}) scaleY(${scaleY})`,
|
transform: `rotate(${rotation}deg) scaleX(${scaleX}) scaleY(${scaleY})`,
|
||||||
width: '100%',
|
width: '100%',
|
||||||
}}
|
}}
|
||||||
src={getUrl(image.url)}
|
src={getUrl(image.image_url)}
|
||||||
alt={alt}
|
alt={alt}
|
||||||
ref={ref}
|
ref={ref}
|
||||||
className={styleClass ? styleClass : ''}
|
className={styleClass ? styleClass : ''}
|
||||||
|
@ -21,7 +21,7 @@ const ImageInputFieldComponent = (
|
|||||||
|
|
||||||
const getImageByNameAndType = useGetImageByNameAndType();
|
const getImageByNameAndType = useGetImageByNameAndType();
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const [url, setUrl] = useState<string>();
|
const [url, setUrl] = useState<string | undefined>(field.value?.image_url);
|
||||||
const { getUrl } = useGetUrl();
|
const { getUrl } = useGetUrl();
|
||||||
|
|
||||||
const handleDrop = useCallback(
|
const handleDrop = useCallback(
|
||||||
@ -39,16 +39,13 @@ const ImageInputFieldComponent = (
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
setUrl(image.url);
|
setUrl(image.image_url);
|
||||||
|
|
||||||
dispatch(
|
dispatch(
|
||||||
fieldValueChanged({
|
fieldValueChanged({
|
||||||
nodeId,
|
nodeId,
|
||||||
fieldName: field.name,
|
fieldName: field.name,
|
||||||
value: {
|
value: image,
|
||||||
image_name: name,
|
|
||||||
image_type: type,
|
|
||||||
},
|
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
},
|
},
|
||||||
|
@ -11,7 +11,7 @@ import {
|
|||||||
NodeChange,
|
NodeChange,
|
||||||
OnConnectStartParams,
|
OnConnectStartParams,
|
||||||
} from 'reactflow';
|
} from 'reactflow';
|
||||||
import { ImageField } from 'services/api';
|
import { ImageDTO } from 'services/api';
|
||||||
import { receivedOpenAPISchema } from 'services/thunks/schema';
|
import { receivedOpenAPISchema } from 'services/thunks/schema';
|
||||||
import { InvocationTemplate, InvocationValue } from '../types/types';
|
import { InvocationTemplate, InvocationValue } from '../types/types';
|
||||||
import { parseSchema } from '../util/parseSchema';
|
import { parseSchema } from '../util/parseSchema';
|
||||||
@ -65,13 +65,7 @@ const nodesSlice = createSlice({
|
|||||||
action: PayloadAction<{
|
action: PayloadAction<{
|
||||||
nodeId: string;
|
nodeId: string;
|
||||||
fieldName: string;
|
fieldName: string;
|
||||||
value:
|
value: string | number | boolean | ImageDTO | RgbaColor | undefined;
|
||||||
| string
|
|
||||||
| number
|
|
||||||
| boolean
|
|
||||||
| Pick<ImageField, 'image_name' | 'image_type'>
|
|
||||||
| RgbaColor
|
|
||||||
| undefined;
|
|
||||||
}>
|
}>
|
||||||
) => {
|
) => {
|
||||||
const { nodeId, fieldName, value } = action.payload;
|
const { nodeId, fieldName, value } = action.payload;
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
import { OpenAPIV3 } from 'openapi-types';
|
import { OpenAPIV3 } from 'openapi-types';
|
||||||
import { RgbaColor } from 'react-colorful';
|
import { RgbaColor } from 'react-colorful';
|
||||||
import { ImageField } from 'services/api';
|
import { Graph, ImageDTO } from 'services/api';
|
||||||
import { AnyInvocationType } from 'services/events/types';
|
import { AnyInvocationType } from 'services/events/types';
|
||||||
|
import { O } from 'ts-toolbelt';
|
||||||
|
|
||||||
|
export type NonNullableGraph = O.Required<Graph, 'nodes' | 'edges'>;
|
||||||
|
|
||||||
export type InvocationValue = {
|
export type InvocationValue = {
|
||||||
id: string;
|
id: string;
|
||||||
@ -179,7 +182,7 @@ export type ConditioningInputFieldValue = FieldValueBase & {
|
|||||||
|
|
||||||
export type ImageInputFieldValue = FieldValueBase & {
|
export type ImageInputFieldValue = FieldValueBase & {
|
||||||
type: 'image';
|
type: 'image';
|
||||||
value?: Pick<ImageField, 'image_name' | 'image_type'>;
|
value?: ImageDTO;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type ModelInputFieldValue = FieldValueBase & {
|
export type ModelInputFieldValue = FieldValueBase & {
|
||||||
@ -245,7 +248,7 @@ export type BooleanInputFieldTemplate = InputFieldTemplateBase & {
|
|||||||
};
|
};
|
||||||
|
|
||||||
export type ImageInputFieldTemplate = InputFieldTemplateBase & {
|
export type ImageInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
default: Pick<ImageField, 'image_name' | 'image_type'>;
|
default: ImageDTO;
|
||||||
type: 'image';
|
type: 'image';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -1,35 +1,131 @@
|
|||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { Graph } from 'services/api';
|
import {
|
||||||
import { buildImg2ImgNode } from '../nodeBuilders/buildImageToImageNode';
|
CompelInvocation,
|
||||||
import { buildRangeNode } from '../nodeBuilders/buildRangeNode';
|
Graph,
|
||||||
import { buildIterateNode } from '../nodeBuilders/buildIterateNode';
|
ImageToLatentsInvocation,
|
||||||
import { buildEdges } from '../edgeBuilders/buildEdges';
|
LatentsToImageInvocation,
|
||||||
|
LatentsToLatentsInvocation,
|
||||||
|
} from 'services/api';
|
||||||
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
|
import { addNoiseNodes } from '../nodeBuilders/addNoiseNodes';
|
||||||
|
import { log } from 'app/logging/useLogger';
|
||||||
|
|
||||||
|
const moduleLog = log.child({ namespace: 'buildImageToImageGraph' });
|
||||||
|
|
||||||
|
const POSITIVE_CONDITIONING = 'positive_conditioning';
|
||||||
|
const NEGATIVE_CONDITIONING = 'negative_conditioning';
|
||||||
|
const IMAGE_TO_LATENTS = 'image_to_latents';
|
||||||
|
const LATENTS_TO_LATENTS = 'latents_to_latents';
|
||||||
|
const LATENTS_TO_IMAGE = 'latents_to_image';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds the Linear workflow graph.
|
* Builds the Image to Image tab graph.
|
||||||
*/
|
*/
|
||||||
export const buildImageToImageGraph = (state: RootState): Graph => {
|
export const buildImageToImageGraph = (state: RootState): Graph => {
|
||||||
const baseNode = buildImg2ImgNode(state);
|
const {
|
||||||
|
positivePrompt,
|
||||||
|
negativePrompt,
|
||||||
|
model,
|
||||||
|
cfgScale: cfg_scale,
|
||||||
|
scheduler,
|
||||||
|
steps,
|
||||||
|
initialImage,
|
||||||
|
img2imgStrength: strength,
|
||||||
|
} = state.generation;
|
||||||
|
|
||||||
// We always range and iterate nodes, no matter the iteration count
|
if (!initialImage) {
|
||||||
// This is required to provide the correct seeds to the backend engine
|
moduleLog.error('No initial image found in state');
|
||||||
const rangeNode = buildRangeNode(state);
|
throw new Error('No initial image found in state');
|
||||||
const iterateNode = buildIterateNode();
|
}
|
||||||
|
|
||||||
// Build the edges for the nodes selected.
|
let graph: NonNullableGraph = {
|
||||||
const edges = buildEdges(baseNode, rangeNode, iterateNode);
|
nodes: {},
|
||||||
|
edges: [],
|
||||||
// Assemble!
|
|
||||||
const graph = {
|
|
||||||
nodes: {
|
|
||||||
[rangeNode.id]: rangeNode,
|
|
||||||
[iterateNode.id]: iterateNode,
|
|
||||||
[baseNode.id]: baseNode,
|
|
||||||
},
|
|
||||||
edges,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// TODO: hires fix requires latent space upscaling; we don't have nodes for this yet
|
// Create the conditioning, t2l and l2i nodes
|
||||||
|
const positiveConditioningNode: CompelInvocation = {
|
||||||
|
id: POSITIVE_CONDITIONING,
|
||||||
|
type: 'compel',
|
||||||
|
prompt: positivePrompt,
|
||||||
|
model,
|
||||||
|
};
|
||||||
|
|
||||||
|
const negativeConditioningNode: CompelInvocation = {
|
||||||
|
id: NEGATIVE_CONDITIONING,
|
||||||
|
type: 'compel',
|
||||||
|
prompt: negativePrompt,
|
||||||
|
model,
|
||||||
|
};
|
||||||
|
|
||||||
|
const imageToLatentsNode: ImageToLatentsInvocation = {
|
||||||
|
id: IMAGE_TO_LATENTS,
|
||||||
|
type: 'i2l',
|
||||||
|
model,
|
||||||
|
image: {
|
||||||
|
image_name: initialImage?.image_name,
|
||||||
|
image_type: initialImage?.image_type,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
const latentsToLatentsNode: LatentsToLatentsInvocation = {
|
||||||
|
id: LATENTS_TO_LATENTS,
|
||||||
|
type: 'l2l',
|
||||||
|
cfg_scale,
|
||||||
|
model,
|
||||||
|
scheduler,
|
||||||
|
steps,
|
||||||
|
strength,
|
||||||
|
};
|
||||||
|
|
||||||
|
const latentsToImageNode: LatentsToImageInvocation = {
|
||||||
|
id: LATENTS_TO_IMAGE,
|
||||||
|
type: 'l2i',
|
||||||
|
model,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Add to the graph
|
||||||
|
graph.nodes[POSITIVE_CONDITIONING] = positiveConditioningNode;
|
||||||
|
graph.nodes[NEGATIVE_CONDITIONING] = negativeConditioningNode;
|
||||||
|
graph.nodes[IMAGE_TO_LATENTS] = imageToLatentsNode;
|
||||||
|
graph.nodes[LATENTS_TO_LATENTS] = latentsToLatentsNode;
|
||||||
|
graph.nodes[LATENTS_TO_IMAGE] = latentsToImageNode;
|
||||||
|
|
||||||
|
// Connect them
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: POSITIVE_CONDITIONING, field: 'conditioning' },
|
||||||
|
destination: {
|
||||||
|
node_id: LATENTS_TO_LATENTS,
|
||||||
|
field: 'positive_conditioning',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: NEGATIVE_CONDITIONING, field: 'conditioning' },
|
||||||
|
destination: {
|
||||||
|
node_id: LATENTS_TO_LATENTS,
|
||||||
|
field: 'negative_conditioning',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: IMAGE_TO_LATENTS, field: 'latents' },
|
||||||
|
destination: {
|
||||||
|
node_id: LATENTS_TO_LATENTS,
|
||||||
|
field: 'latents',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: LATENTS_TO_LATENTS, field: 'latents' },
|
||||||
|
destination: {
|
||||||
|
node_id: LATENTS_TO_IMAGE,
|
||||||
|
field: 'latents',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
// Create and add the noise nodes
|
||||||
|
graph = addNoiseNodes(graph, latentsToLatentsNode.id, state);
|
||||||
|
|
||||||
return graph;
|
return graph;
|
||||||
};
|
};
|
||||||
|
@ -1,35 +1,99 @@
|
|||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { Graph } from 'services/api';
|
import {
|
||||||
import { buildTxt2ImgNode } from '../nodeBuilders/buildTextToImageNode';
|
CompelInvocation,
|
||||||
import { buildRangeNode } from '../nodeBuilders/buildRangeNode';
|
Graph,
|
||||||
import { buildIterateNode } from '../nodeBuilders/buildIterateNode';
|
LatentsToImageInvocation,
|
||||||
import { buildEdges } from '../edgeBuilders/buildEdges';
|
TextToLatentsInvocation,
|
||||||
|
} from 'services/api';
|
||||||
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
|
import { addNoiseNodes } from '../nodeBuilders/addNoiseNodes';
|
||||||
|
|
||||||
|
const POSITIVE_CONDITIONING = 'positive_conditioning';
|
||||||
|
const NEGATIVE_CONDITIONING = 'negative_conditioning';
|
||||||
|
const TEXT_TO_LATENTS = 'text_to_latents';
|
||||||
|
const LATENTS_TO_IMAGE = 'latnets_to_image';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builds the Linear workflow graph.
|
* Builds the Text to Image tab graph.
|
||||||
*/
|
*/
|
||||||
export const buildTextToImageGraph = (state: RootState): Graph => {
|
export const buildTextToImageGraph = (state: RootState): Graph => {
|
||||||
const baseNode = buildTxt2ImgNode(state);
|
const {
|
||||||
|
positivePrompt,
|
||||||
|
negativePrompt,
|
||||||
|
model,
|
||||||
|
cfgScale: cfg_scale,
|
||||||
|
scheduler,
|
||||||
|
steps,
|
||||||
|
} = state.generation;
|
||||||
|
|
||||||
// We always range and iterate nodes, no matter the iteration count
|
let graph: NonNullableGraph = {
|
||||||
// This is required to provide the correct seeds to the backend engine
|
nodes: {},
|
||||||
const rangeNode = buildRangeNode(state);
|
edges: [],
|
||||||
const iterateNode = buildIterateNode();
|
|
||||||
|
|
||||||
// Build the edges for the nodes selected.
|
|
||||||
const edges = buildEdges(baseNode, rangeNode, iterateNode);
|
|
||||||
|
|
||||||
// Assemble!
|
|
||||||
const graph = {
|
|
||||||
nodes: {
|
|
||||||
[rangeNode.id]: rangeNode,
|
|
||||||
[iterateNode.id]: iterateNode,
|
|
||||||
[baseNode.id]: baseNode,
|
|
||||||
},
|
|
||||||
edges,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// TODO: hires fix requires latent space upscaling; we don't have nodes for this yet
|
// Create the conditioning, t2l and l2i nodes
|
||||||
|
const positiveConditioningNode: CompelInvocation = {
|
||||||
|
id: POSITIVE_CONDITIONING,
|
||||||
|
type: 'compel',
|
||||||
|
prompt: positivePrompt,
|
||||||
|
model,
|
||||||
|
};
|
||||||
|
|
||||||
|
const negativeConditioningNode: CompelInvocation = {
|
||||||
|
id: NEGATIVE_CONDITIONING,
|
||||||
|
type: 'compel',
|
||||||
|
prompt: negativePrompt,
|
||||||
|
model,
|
||||||
|
};
|
||||||
|
|
||||||
|
const textToLatentsNode: TextToLatentsInvocation = {
|
||||||
|
id: TEXT_TO_LATENTS,
|
||||||
|
type: 't2l',
|
||||||
|
cfg_scale,
|
||||||
|
model,
|
||||||
|
scheduler,
|
||||||
|
steps,
|
||||||
|
};
|
||||||
|
|
||||||
|
const latentsToImageNode: LatentsToImageInvocation = {
|
||||||
|
id: LATENTS_TO_IMAGE,
|
||||||
|
type: 'l2i',
|
||||||
|
model,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Add to the graph
|
||||||
|
graph.nodes[POSITIVE_CONDITIONING] = positiveConditioningNode;
|
||||||
|
graph.nodes[NEGATIVE_CONDITIONING] = negativeConditioningNode;
|
||||||
|
graph.nodes[TEXT_TO_LATENTS] = textToLatentsNode;
|
||||||
|
graph.nodes[LATENTS_TO_IMAGE] = latentsToImageNode;
|
||||||
|
|
||||||
|
// Connect them
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: POSITIVE_CONDITIONING, field: 'conditioning' },
|
||||||
|
destination: {
|
||||||
|
node_id: TEXT_TO_LATENTS,
|
||||||
|
field: 'positive_conditioning',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: NEGATIVE_CONDITIONING, field: 'conditioning' },
|
||||||
|
destination: {
|
||||||
|
node_id: TEXT_TO_LATENTS,
|
||||||
|
field: 'negative_conditioning',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
graph.edges.push({
|
||||||
|
source: { node_id: TEXT_TO_LATENTS, field: 'latents' },
|
||||||
|
destination: {
|
||||||
|
node_id: LATENTS_TO_IMAGE,
|
||||||
|
field: 'latents',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
// Create and add the noise nodes
|
||||||
|
graph = addNoiseNodes(graph, TEXT_TO_LATENTS, state);
|
||||||
|
|
||||||
return graph;
|
return graph;
|
||||||
};
|
};
|
||||||
|
@ -0,0 +1,208 @@
|
|||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
import {
|
||||||
|
IterateInvocation,
|
||||||
|
NoiseInvocation,
|
||||||
|
RandomIntInvocation,
|
||||||
|
RangeOfSizeInvocation,
|
||||||
|
} from 'services/api';
|
||||||
|
import { NonNullableGraph } from 'features/nodes/types/types';
|
||||||
|
import { cloneDeep } from 'lodash-es';
|
||||||
|
|
||||||
|
const NOISE = 'noise';
|
||||||
|
const RANDOM_INT = 'rand_int';
|
||||||
|
const RANGE_OF_SIZE = 'range_of_size';
|
||||||
|
const ITERATE = 'iterate';
|
||||||
|
/**
|
||||||
|
* Adds the appropriate noise nodes to a linear UI t2l or l2l graph.
|
||||||
|
*
|
||||||
|
* @param graph The graph to add the noise nodes to.
|
||||||
|
* @param baseNodeId The id of the base node to connect the noise nodes to.
|
||||||
|
* @param state The app state..
|
||||||
|
*/
|
||||||
|
export const addNoiseNodes = (
|
||||||
|
graph: NonNullableGraph,
|
||||||
|
baseNodeId: string,
|
||||||
|
state: RootState
|
||||||
|
): NonNullableGraph => {
|
||||||
|
const graphClone = cloneDeep(graph);
|
||||||
|
|
||||||
|
// Create and add the noise nodes
|
||||||
|
const { width, height, seed, iterations, shouldRandomizeSeed } =
|
||||||
|
state.generation;
|
||||||
|
|
||||||
|
// Single iteration, explicit seed
|
||||||
|
if (!shouldRandomizeSeed && iterations === 1) {
|
||||||
|
const noiseNode: NoiseInvocation = {
|
||||||
|
id: NOISE,
|
||||||
|
type: 'noise',
|
||||||
|
seed: seed,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
};
|
||||||
|
|
||||||
|
graphClone.nodes[NOISE] = noiseNode;
|
||||||
|
|
||||||
|
// Connect them
|
||||||
|
graphClone.edges.push({
|
||||||
|
source: { node_id: NOISE, field: 'noise' },
|
||||||
|
destination: {
|
||||||
|
node_id: baseNodeId,
|
||||||
|
field: 'noise',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Single iteration, random seed
|
||||||
|
if (shouldRandomizeSeed && iterations === 1) {
|
||||||
|
// TODO: This assumes the `high` value is the max seed value
|
||||||
|
const randomIntNode: RandomIntInvocation = {
|
||||||
|
id: RANDOM_INT,
|
||||||
|
type: 'rand_int',
|
||||||
|
};
|
||||||
|
|
||||||
|
const noiseNode: NoiseInvocation = {
|
||||||
|
id: NOISE,
|
||||||
|
type: 'noise',
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
};
|
||||||
|
|
||||||
|
graphClone.nodes[RANDOM_INT] = randomIntNode;
|
||||||
|
graphClone.nodes[NOISE] = noiseNode;
|
||||||
|
|
||||||
|
graphClone.edges.push({
|
||||||
|
source: { node_id: RANDOM_INT, field: 'a' },
|
||||||
|
destination: {
|
||||||
|
node_id: NOISE,
|
||||||
|
field: 'seed',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
graphClone.edges.push({
|
||||||
|
source: { node_id: NOISE, field: 'noise' },
|
||||||
|
destination: {
|
||||||
|
node_id: baseNodeId,
|
||||||
|
field: 'noise',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Multiple iterations, explicit seed
|
||||||
|
if (!shouldRandomizeSeed && iterations > 1) {
|
||||||
|
const rangeOfSizeNode: RangeOfSizeInvocation = {
|
||||||
|
id: RANGE_OF_SIZE,
|
||||||
|
type: 'range_of_size',
|
||||||
|
start: seed,
|
||||||
|
size: iterations,
|
||||||
|
};
|
||||||
|
|
||||||
|
const iterateNode: IterateInvocation = {
|
||||||
|
id: ITERATE,
|
||||||
|
type: 'iterate',
|
||||||
|
};
|
||||||
|
|
||||||
|
const noiseNode: NoiseInvocation = {
|
||||||
|
id: NOISE,
|
||||||
|
type: 'noise',
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
};
|
||||||
|
|
||||||
|
graphClone.nodes[RANGE_OF_SIZE] = rangeOfSizeNode;
|
||||||
|
graphClone.nodes[ITERATE] = iterateNode;
|
||||||
|
graphClone.nodes[NOISE] = noiseNode;
|
||||||
|
|
||||||
|
graphClone.edges.push({
|
||||||
|
source: { node_id: RANGE_OF_SIZE, field: 'collection' },
|
||||||
|
destination: {
|
||||||
|
node_id: ITERATE,
|
||||||
|
field: 'collection',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
graphClone.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: ITERATE,
|
||||||
|
field: 'item',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: NOISE,
|
||||||
|
field: 'seed',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
graphClone.edges.push({
|
||||||
|
source: { node_id: NOISE, field: 'noise' },
|
||||||
|
destination: {
|
||||||
|
node_id: baseNodeId,
|
||||||
|
field: 'noise',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Multiple iterations, random seed
|
||||||
|
if (shouldRandomizeSeed && iterations > 1) {
|
||||||
|
// TODO: This assumes the `high` value is the max seed value
|
||||||
|
const randomIntNode: RandomIntInvocation = {
|
||||||
|
id: RANDOM_INT,
|
||||||
|
type: 'rand_int',
|
||||||
|
};
|
||||||
|
|
||||||
|
const rangeOfSizeNode: RangeOfSizeInvocation = {
|
||||||
|
id: RANGE_OF_SIZE,
|
||||||
|
type: 'range_of_size',
|
||||||
|
size: iterations,
|
||||||
|
};
|
||||||
|
|
||||||
|
const iterateNode: IterateInvocation = {
|
||||||
|
id: ITERATE,
|
||||||
|
type: 'iterate',
|
||||||
|
};
|
||||||
|
|
||||||
|
const noiseNode: NoiseInvocation = {
|
||||||
|
id: NOISE,
|
||||||
|
type: 'noise',
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
};
|
||||||
|
|
||||||
|
graphClone.nodes[RANDOM_INT] = randomIntNode;
|
||||||
|
graphClone.nodes[RANGE_OF_SIZE] = rangeOfSizeNode;
|
||||||
|
graphClone.nodes[ITERATE] = iterateNode;
|
||||||
|
graphClone.nodes[NOISE] = noiseNode;
|
||||||
|
|
||||||
|
graphClone.edges.push({
|
||||||
|
source: { node_id: RANDOM_INT, field: 'a' },
|
||||||
|
destination: { node_id: RANGE_OF_SIZE, field: 'start' },
|
||||||
|
});
|
||||||
|
|
||||||
|
graphClone.edges.push({
|
||||||
|
source: { node_id: RANGE_OF_SIZE, field: 'collection' },
|
||||||
|
destination: {
|
||||||
|
node_id: ITERATE,
|
||||||
|
field: 'collection',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
graphClone.edges.push({
|
||||||
|
source: {
|
||||||
|
node_id: ITERATE,
|
||||||
|
field: 'item',
|
||||||
|
},
|
||||||
|
destination: {
|
||||||
|
node_id: NOISE,
|
||||||
|
field: 'seed',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
graphClone.edges.push({
|
||||||
|
source: { node_id: NOISE, field: 'noise' },
|
||||||
|
destination: {
|
||||||
|
node_id: baseNodeId,
|
||||||
|
field: 'noise',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
return graphClone;
|
||||||
|
};
|
@ -0,0 +1,26 @@
|
|||||||
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
|
import { RootState } from 'app/store/store';
|
||||||
|
import { CompelInvocation } from 'services/api';
|
||||||
|
import { O } from 'ts-toolbelt';
|
||||||
|
|
||||||
|
export const buildCompelNode = (
|
||||||
|
prompt: string,
|
||||||
|
state: RootState,
|
||||||
|
overrides: O.Partial<CompelInvocation, 'deep'> = {}
|
||||||
|
): CompelInvocation => {
|
||||||
|
const nodeId = uuidv4();
|
||||||
|
const { generation } = state;
|
||||||
|
|
||||||
|
const { model } = generation;
|
||||||
|
|
||||||
|
const compelNode: CompelInvocation = {
|
||||||
|
id: nodeId,
|
||||||
|
type: 'compel',
|
||||||
|
prompt,
|
||||||
|
model,
|
||||||
|
};
|
||||||
|
|
||||||
|
Object.assign(compelNode, overrides);
|
||||||
|
|
||||||
|
return compelNode;
|
||||||
|
};
|
@ -18,8 +18,8 @@ export const buildImg2ImgNode = (
|
|||||||
const activeTabName = activeTabNameSelector(state);
|
const activeTabName = activeTabNameSelector(state);
|
||||||
|
|
||||||
const {
|
const {
|
||||||
prompt,
|
positivePrompt: prompt,
|
||||||
negativePrompt,
|
negativePrompt: negativePrompt,
|
||||||
seed,
|
seed,
|
||||||
steps,
|
steps,
|
||||||
width,
|
width,
|
||||||
|
@ -13,8 +13,8 @@ export const buildInpaintNode = (
|
|||||||
const activeTabName = activeTabNameSelector(state);
|
const activeTabName = activeTabNameSelector(state);
|
||||||
|
|
||||||
const {
|
const {
|
||||||
prompt,
|
positivePrompt: prompt,
|
||||||
negativePrompt,
|
negativePrompt: negativePrompt,
|
||||||
seed,
|
seed,
|
||||||
steps,
|
steps,
|
||||||
width,
|
width,
|
||||||
|
@ -11,8 +11,8 @@ export const buildTxt2ImgNode = (
|
|||||||
const { generation } = state;
|
const { generation } = state;
|
||||||
|
|
||||||
const {
|
const {
|
||||||
prompt,
|
positivePrompt: prompt,
|
||||||
negativePrompt,
|
negativePrompt: negativePrompt,
|
||||||
seed,
|
seed,
|
||||||
steps,
|
steps,
|
||||||
width,
|
width,
|
||||||
|
@ -13,7 +13,7 @@ import {
|
|||||||
buildOutputFieldTemplates,
|
buildOutputFieldTemplates,
|
||||||
} from './fieldTemplateBuilders';
|
} from './fieldTemplateBuilders';
|
||||||
|
|
||||||
const invocationDenylist = ['Graph', 'LoadImage'];
|
const invocationDenylist = ['Graph'];
|
||||||
|
|
||||||
export const parseSchema = (openAPI: OpenAPIV3.Document) => {
|
export const parseSchema = (openAPI: OpenAPIV3.Document) => {
|
||||||
// filter out non-invocation schemas, plus some tricky invocations for now
|
// filter out non-invocation schemas, plus some tricky invocations for now
|
||||||
|
@ -8,7 +8,7 @@ import { readinessSelector } from 'app/selectors/readinessSelector';
|
|||||||
import {
|
import {
|
||||||
GenerationState,
|
GenerationState,
|
||||||
clampSymmetrySteps,
|
clampSymmetrySteps,
|
||||||
setPrompt,
|
setPositivePrompt,
|
||||||
} from 'features/parameters/store/generationSlice';
|
} from 'features/parameters/store/generationSlice';
|
||||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||||
|
|
||||||
@ -22,7 +22,7 @@ const promptInputSelector = createSelector(
|
|||||||
[(state: RootState) => state.generation, activeTabNameSelector],
|
[(state: RootState) => state.generation, activeTabNameSelector],
|
||||||
(parameters: GenerationState, activeTabName) => {
|
(parameters: GenerationState, activeTabName) => {
|
||||||
return {
|
return {
|
||||||
prompt: parameters.prompt,
|
prompt: parameters.positivePrompt,
|
||||||
activeTabName,
|
activeTabName,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
@ -46,7 +46,7 @@ const ParamPositiveConditioning = () => {
|
|||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const handleChangePrompt = (e: ChangeEvent<HTMLTextAreaElement>) => {
|
const handleChangePrompt = (e: ChangeEvent<HTMLTextAreaElement>) => {
|
||||||
dispatch(setPrompt(e.target.value));
|
dispatch(setPositivePrompt(e.target.value));
|
||||||
};
|
};
|
||||||
|
|
||||||
useHotkeys(
|
useHotkeys(
|
||||||
|
@ -57,7 +57,7 @@ const InitialImagePreview = () => {
|
|||||||
const name = e.dataTransfer.getData('invokeai/imageName');
|
const name = e.dataTransfer.getData('invokeai/imageName');
|
||||||
const type = e.dataTransfer.getData('invokeai/imageType') as ImageType;
|
const type = e.dataTransfer.getData('invokeai/imageType') as ImageType;
|
||||||
|
|
||||||
dispatch(initialImageSelected({ name, type }));
|
dispatch(initialImageSelected({ image_name: name, image_type: type }));
|
||||||
},
|
},
|
||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
@ -73,10 +73,10 @@ const InitialImagePreview = () => {
|
|||||||
}}
|
}}
|
||||||
onDrop={handleDrop}
|
onDrop={handleDrop}
|
||||||
>
|
>
|
||||||
{initialImage?.url && (
|
{initialImage?.image_url && (
|
||||||
<>
|
<>
|
||||||
<Image
|
<Image
|
||||||
src={getUrl(initialImage?.url)}
|
src={getUrl(initialImage?.image_url)}
|
||||||
fallbackStrategy="beforeLoadOrError"
|
fallbackStrategy="beforeLoadOrError"
|
||||||
fallback={<ImageFallbackSpinner />}
|
fallback={<ImageFallbackSpinner />}
|
||||||
onError={handleError}
|
onError={handleError}
|
||||||
@ -92,7 +92,7 @@ const InitialImagePreview = () => {
|
|||||||
<ImageMetadataOverlay image={initialImage} />
|
<ImageMetadataOverlay image={initialImage} />
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
{!initialImage?.url && (
|
{!initialImage?.image_url && (
|
||||||
<Icon
|
<Icon
|
||||||
as={FaImage}
|
as={FaImage}
|
||||||
sx={{
|
sx={{
|
||||||
|
@ -7,9 +7,9 @@ import { allParametersSet, setSeed } from '../store/generationSlice';
|
|||||||
import { isImageField } from 'services/types/guards';
|
import { isImageField } from 'services/types/guards';
|
||||||
import { NUMPY_RAND_MAX } from 'app/constants';
|
import { NUMPY_RAND_MAX } from 'app/constants';
|
||||||
import { initialImageSelected } from '../store/actions';
|
import { initialImageSelected } from '../store/actions';
|
||||||
import { Image } from 'app/types/invokeai';
|
|
||||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||||
import { useAppToaster } from 'app/components/Toaster';
|
import { useAppToaster } from 'app/components/Toaster';
|
||||||
|
import { ImageDTO } from 'services/api';
|
||||||
|
|
||||||
export const useParameters = () => {
|
export const useParameters = () => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
@ -88,9 +88,7 @@ export const useParameters = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
dispatch(
|
dispatch(initialImageSelected(image));
|
||||||
initialImageSelected({ name: image.image_name, type: image.image_type })
|
|
||||||
);
|
|
||||||
toaster({
|
toaster({
|
||||||
title: t('toast.initialImageSet'),
|
title: t('toast.initialImageSet'),
|
||||||
status: 'info',
|
status: 'info',
|
||||||
@ -105,21 +103,21 @@ export const useParameters = () => {
|
|||||||
* Sets image as initial image with toast
|
* Sets image as initial image with toast
|
||||||
*/
|
*/
|
||||||
const sendToImageToImage = useCallback(
|
const sendToImageToImage = useCallback(
|
||||||
(image: Image) => {
|
(image: ImageDTO) => {
|
||||||
dispatch(initialImageSelected({ name: image.name, type: image.type }));
|
dispatch(initialImageSelected(image));
|
||||||
},
|
},
|
||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
|
|
||||||
const recallAllParameters = useCallback(
|
const recallAllParameters = useCallback(
|
||||||
(image: Image | undefined) => {
|
(image: ImageDTO | undefined) => {
|
||||||
const type = image?.metadata?.invokeai?.node?.type;
|
const type = image?.metadata?.type;
|
||||||
if (['txt2img', 'img2img', 'inpaint'].includes(String(type))) {
|
if (['txt2img', 'img2img', 'inpaint'].includes(String(type))) {
|
||||||
dispatch(allParametersSet(image));
|
dispatch(allParametersSet(image));
|
||||||
|
|
||||||
if (image?.metadata?.invokeai?.node?.type === 'img2img') {
|
if (image?.metadata?.type === 'img2img') {
|
||||||
dispatch(setActiveTab('img2img'));
|
dispatch(setActiveTab('img2img'));
|
||||||
} else if (image?.metadata?.invokeai?.node?.type === 'txt2img') {
|
} else if (image?.metadata?.type === 'txt2img') {
|
||||||
dispatch(setActiveTab('txt2img'));
|
dispatch(setActiveTab('txt2img'));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@ import { getPromptAndNegative } from 'common/util/getPromptAndNegative';
|
|||||||
import * as InvokeAI from 'app/types/invokeai';
|
import * as InvokeAI from 'app/types/invokeai';
|
||||||
import promptToString from 'common/util/promptToString';
|
import promptToString from 'common/util/promptToString';
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
import { setNegativePrompt, setPrompt } from '../store/generationSlice';
|
import { setNegativePrompt, setPositivePrompt } from '../store/generationSlice';
|
||||||
import { useCallback } from 'react';
|
import { useCallback } from 'react';
|
||||||
|
|
||||||
// TECHDEBT: We have two metadata prompt formats and need to handle recalling either of them.
|
// TECHDEBT: We have two metadata prompt formats and need to handle recalling either of them.
|
||||||
@ -20,7 +20,7 @@ const useSetBothPrompts = () => {
|
|||||||
|
|
||||||
const [prompt, negativePrompt] = getPromptAndNegative(promptString);
|
const [prompt, negativePrompt] = getPromptAndNegative(promptString);
|
||||||
|
|
||||||
dispatch(setPrompt(prompt));
|
dispatch(setPositivePrompt(prompt));
|
||||||
dispatch(setNegativePrompt(negativePrompt));
|
dispatch(setNegativePrompt(negativePrompt));
|
||||||
},
|
},
|
||||||
[dispatch]
|
[dispatch]
|
||||||
|
@ -1,12 +1,31 @@
|
|||||||
import { createAction } from '@reduxjs/toolkit';
|
import { createAction } from '@reduxjs/toolkit';
|
||||||
import { Image } from 'app/types/invokeai';
|
import { isObject } from 'lodash-es';
|
||||||
import { ImageType } from 'services/api';
|
import { ImageDTO, ImageType } from 'services/api';
|
||||||
|
|
||||||
export type SelectedImage = {
|
export type ImageNameAndType = {
|
||||||
name: string;
|
image_name: string;
|
||||||
type: ImageType;
|
image_type: ImageType;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const isImageDTO = (image: any): image is ImageDTO => {
|
||||||
|
return (
|
||||||
|
image &&
|
||||||
|
isObject(image) &&
|
||||||
|
'image_name' in image &&
|
||||||
|
image?.image_name !== undefined &&
|
||||||
|
'image_type' in image &&
|
||||||
|
image?.image_type !== undefined &&
|
||||||
|
'image_url' in image &&
|
||||||
|
image?.image_url !== undefined &&
|
||||||
|
'thumbnail_url' in image &&
|
||||||
|
image?.thumbnail_url !== undefined &&
|
||||||
|
'image_category' in image &&
|
||||||
|
image?.image_category !== undefined &&
|
||||||
|
'created_at' in image &&
|
||||||
|
image?.created_at !== undefined
|
||||||
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export const initialImageSelected = createAction<
|
export const initialImageSelected = createAction<
|
||||||
Image | SelectedImage | undefined
|
ImageDTO | ImageNameAndType | undefined
|
||||||
>('generation/initialImageSelected');
|
>('generation/initialImageSelected');
|
||||||
|
@ -6,16 +6,17 @@ import { clamp, sample } from 'lodash-es';
|
|||||||
import { setAllParametersReducer } from './setAllParametersReducer';
|
import { setAllParametersReducer } from './setAllParametersReducer';
|
||||||
import { receivedModels } from 'services/thunks/model';
|
import { receivedModels } from 'services/thunks/model';
|
||||||
import { Scheduler } from 'app/constants';
|
import { Scheduler } from 'app/constants';
|
||||||
|
import { ImageDTO } from 'services/api';
|
||||||
|
|
||||||
export interface GenerationState {
|
export interface GenerationState {
|
||||||
cfgScale: number;
|
cfgScale: number;
|
||||||
height: number;
|
height: number;
|
||||||
img2imgStrength: number;
|
img2imgStrength: number;
|
||||||
infillMethod: string;
|
infillMethod: string;
|
||||||
initialImage?: InvokeAI.Image;
|
initialImage?: ImageDTO;
|
||||||
iterations: number;
|
iterations: number;
|
||||||
perlin: number;
|
perlin: number;
|
||||||
prompt: string;
|
positivePrompt: string;
|
||||||
negativePrompt: string;
|
negativePrompt: string;
|
||||||
scheduler: Scheduler;
|
scheduler: Scheduler;
|
||||||
seamBlur: number;
|
seamBlur: number;
|
||||||
@ -49,7 +50,7 @@ export const initialGenerationState: GenerationState = {
|
|||||||
infillMethod: 'patchmatch',
|
infillMethod: 'patchmatch',
|
||||||
iterations: 1,
|
iterations: 1,
|
||||||
perlin: 0,
|
perlin: 0,
|
||||||
prompt: '',
|
positivePrompt: '',
|
||||||
negativePrompt: '',
|
negativePrompt: '',
|
||||||
scheduler: 'lms',
|
scheduler: 'lms',
|
||||||
seamBlur: 16,
|
seamBlur: 16,
|
||||||
@ -82,12 +83,15 @@ export const generationSlice = createSlice({
|
|||||||
name: 'generation',
|
name: 'generation',
|
||||||
initialState,
|
initialState,
|
||||||
reducers: {
|
reducers: {
|
||||||
setPrompt: (state, action: PayloadAction<string | InvokeAI.Prompt>) => {
|
setPositivePrompt: (
|
||||||
|
state,
|
||||||
|
action: PayloadAction<string | InvokeAI.Prompt>
|
||||||
|
) => {
|
||||||
const newPrompt = action.payload;
|
const newPrompt = action.payload;
|
||||||
if (typeof newPrompt === 'string') {
|
if (typeof newPrompt === 'string') {
|
||||||
state.prompt = newPrompt;
|
state.positivePrompt = newPrompt;
|
||||||
} else {
|
} else {
|
||||||
state.prompt = promptToString(newPrompt);
|
state.positivePrompt = promptToString(newPrompt);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
setNegativePrompt: (
|
setNegativePrompt: (
|
||||||
@ -213,7 +217,7 @@ export const generationSlice = createSlice({
|
|||||||
setShouldUseNoiseSettings: (state, action: PayloadAction<boolean>) => {
|
setShouldUseNoiseSettings: (state, action: PayloadAction<boolean>) => {
|
||||||
state.shouldUseNoiseSettings = action.payload;
|
state.shouldUseNoiseSettings = action.payload;
|
||||||
},
|
},
|
||||||
initialImageChanged: (state, action: PayloadAction<InvokeAI.Image>) => {
|
initialImageChanged: (state, action: PayloadAction<ImageDTO>) => {
|
||||||
state.initialImage = action.payload;
|
state.initialImage = action.payload;
|
||||||
},
|
},
|
||||||
modelSelected: (state, action: PayloadAction<string>) => {
|
modelSelected: (state, action: PayloadAction<string>) => {
|
||||||
@ -243,7 +247,7 @@ export const {
|
|||||||
setInfillMethod,
|
setInfillMethod,
|
||||||
setIterations,
|
setIterations,
|
||||||
setPerlin,
|
setPerlin,
|
||||||
setPrompt,
|
setPositivePrompt,
|
||||||
setNegativePrompt,
|
setNegativePrompt,
|
||||||
setScheduler,
|
setScheduler,
|
||||||
setSeamBlur,
|
setSeamBlur,
|
||||||
|
@ -1,12 +1,11 @@
|
|||||||
import { Draft, PayloadAction } from '@reduxjs/toolkit';
|
import { Draft, PayloadAction } from '@reduxjs/toolkit';
|
||||||
import { Image } from 'app/types/invokeai';
|
|
||||||
import { GenerationState } from './generationSlice';
|
import { GenerationState } from './generationSlice';
|
||||||
import { ImageToImageInvocation } from 'services/api';
|
import { ImageDTO, ImageToImageInvocation } from 'services/api';
|
||||||
import { isScheduler } from 'app/constants';
|
import { isScheduler } from 'app/constants';
|
||||||
|
|
||||||
export const setAllParametersReducer = (
|
export const setAllParametersReducer = (
|
||||||
state: Draft<GenerationState>,
|
state: Draft<GenerationState>,
|
||||||
action: PayloadAction<Image | undefined>
|
action: PayloadAction<ImageDTO | undefined>
|
||||||
) => {
|
) => {
|
||||||
const node = action.payload?.metadata.invokeai?.node;
|
const node = action.payload?.metadata.invokeai?.node;
|
||||||
|
|
||||||
@ -32,7 +31,7 @@ export const setAllParametersReducer = (
|
|||||||
state.model = String(model);
|
state.model = String(model);
|
||||||
}
|
}
|
||||||
if (prompt !== undefined) {
|
if (prompt !== undefined) {
|
||||||
state.prompt = String(prompt);
|
state.positivePrompt = String(prompt);
|
||||||
}
|
}
|
||||||
if (scheduler !== undefined) {
|
if (scheduler !== undefined) {
|
||||||
const schedulerString = String(scheduler);
|
const schedulerString = String(scheduler);
|
||||||
|
@ -5,7 +5,6 @@ import { merge } from 'lodash-es';
|
|||||||
|
|
||||||
export const initialConfigState: AppConfig = {
|
export const initialConfigState: AppConfig = {
|
||||||
shouldTransformUrls: false,
|
shouldTransformUrls: false,
|
||||||
shouldFetchImages: false,
|
|
||||||
disabledTabs: [],
|
disabledTabs: [],
|
||||||
disabledFeatures: [],
|
disabledFeatures: [],
|
||||||
disabledSDFeatures: [],
|
disabledSDFeatures: [],
|
||||||
|
@ -3,4 +3,4 @@ import { UIState } from './uiTypes';
|
|||||||
/**
|
/**
|
||||||
* UI slice persist denylist
|
* UI slice persist denylist
|
||||||
*/
|
*/
|
||||||
export const uiPersistDenylist: (keyof UIState)[] = [];
|
export const uiPersistDenylist: (keyof UIState)[] = ['shouldShowImageDetails'];
|
||||||
|
@ -28,13 +28,15 @@ export type { GraphExecutionState } from './models/GraphExecutionState';
|
|||||||
export type { GraphInvocation } from './models/GraphInvocation';
|
export type { GraphInvocation } from './models/GraphInvocation';
|
||||||
export type { GraphInvocationOutput } from './models/GraphInvocationOutput';
|
export type { GraphInvocationOutput } from './models/GraphInvocationOutput';
|
||||||
export type { HTTPValidationError } from './models/HTTPValidationError';
|
export type { HTTPValidationError } from './models/HTTPValidationError';
|
||||||
|
export type { ImageCategory } from './models/ImageCategory';
|
||||||
|
export type { ImageDTO } from './models/ImageDTO';
|
||||||
export type { ImageField } from './models/ImageField';
|
export type { ImageField } from './models/ImageField';
|
||||||
|
export type { ImageMetadata } from './models/ImageMetadata';
|
||||||
export type { ImageOutput } from './models/ImageOutput';
|
export type { ImageOutput } from './models/ImageOutput';
|
||||||
export type { ImageResponse } from './models/ImageResponse';
|
|
||||||
export type { ImageResponseMetadata } from './models/ImageResponseMetadata';
|
|
||||||
export type { ImageToImageInvocation } from './models/ImageToImageInvocation';
|
export type { ImageToImageInvocation } from './models/ImageToImageInvocation';
|
||||||
export type { ImageToLatentsInvocation } from './models/ImageToLatentsInvocation';
|
export type { ImageToLatentsInvocation } from './models/ImageToLatentsInvocation';
|
||||||
export type { ImageType } from './models/ImageType';
|
export type { ImageType } from './models/ImageType';
|
||||||
|
export type { ImageUrlsDTO } from './models/ImageUrlsDTO';
|
||||||
export type { InfillColorInvocation } from './models/InfillColorInvocation';
|
export type { InfillColorInvocation } from './models/InfillColorInvocation';
|
||||||
export type { InfillPatchMatchInvocation } from './models/InfillPatchMatchInvocation';
|
export type { InfillPatchMatchInvocation } from './models/InfillPatchMatchInvocation';
|
||||||
export type { InfillTileInvocation } from './models/InfillTileInvocation';
|
export type { InfillTileInvocation } from './models/InfillTileInvocation';
|
||||||
@ -42,7 +44,6 @@ export type { InpaintInvocation } from './models/InpaintInvocation';
|
|||||||
export type { IntCollectionOutput } from './models/IntCollectionOutput';
|
export type { IntCollectionOutput } from './models/IntCollectionOutput';
|
||||||
export type { IntOutput } from './models/IntOutput';
|
export type { IntOutput } from './models/IntOutput';
|
||||||
export type { InverseLerpInvocation } from './models/InverseLerpInvocation';
|
export type { InverseLerpInvocation } from './models/InverseLerpInvocation';
|
||||||
export type { InvokeAIMetadata } from './models/InvokeAIMetadata';
|
|
||||||
export type { IterateInvocation } from './models/IterateInvocation';
|
export type { IterateInvocation } from './models/IterateInvocation';
|
||||||
export type { IterateInvocationOutput } from './models/IterateInvocationOutput';
|
export type { IterateInvocationOutput } from './models/IterateInvocationOutput';
|
||||||
export type { LatentsField } from './models/LatentsField';
|
export type { LatentsField } from './models/LatentsField';
|
||||||
@ -53,21 +54,19 @@ export type { LerpInvocation } from './models/LerpInvocation';
|
|||||||
export type { LoadImageInvocation } from './models/LoadImageInvocation';
|
export type { LoadImageInvocation } from './models/LoadImageInvocation';
|
||||||
export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation';
|
export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation';
|
||||||
export type { MaskOutput } from './models/MaskOutput';
|
export type { MaskOutput } from './models/MaskOutput';
|
||||||
export type { MetadataColorField } from './models/MetadataColorField';
|
|
||||||
export type { MetadataImageField } from './models/MetadataImageField';
|
|
||||||
export type { MetadataLatentsField } from './models/MetadataLatentsField';
|
|
||||||
export type { ModelsList } from './models/ModelsList';
|
export type { ModelsList } from './models/ModelsList';
|
||||||
export type { MultiplyInvocation } from './models/MultiplyInvocation';
|
export type { MultiplyInvocation } from './models/MultiplyInvocation';
|
||||||
export type { NoiseInvocation } from './models/NoiseInvocation';
|
export type { NoiseInvocation } from './models/NoiseInvocation';
|
||||||
export type { NoiseOutput } from './models/NoiseOutput';
|
export type { NoiseOutput } from './models/NoiseOutput';
|
||||||
export type { PaginatedResults_GraphExecutionState_ } from './models/PaginatedResults_GraphExecutionState_';
|
export type { PaginatedResults_GraphExecutionState_ } from './models/PaginatedResults_GraphExecutionState_';
|
||||||
export type { PaginatedResults_ImageResponse_ } from './models/PaginatedResults_ImageResponse_';
|
export type { PaginatedResults_ImageDTO_ } from './models/PaginatedResults_ImageDTO_';
|
||||||
export type { ParamIntInvocation } from './models/ParamIntInvocation';
|
export type { ParamIntInvocation } from './models/ParamIntInvocation';
|
||||||
export type { PasteImageInvocation } from './models/PasteImageInvocation';
|
export type { PasteImageInvocation } from './models/PasteImageInvocation';
|
||||||
export type { PromptOutput } from './models/PromptOutput';
|
export type { PromptOutput } from './models/PromptOutput';
|
||||||
export type { RandomIntInvocation } from './models/RandomIntInvocation';
|
export type { RandomIntInvocation } from './models/RandomIntInvocation';
|
||||||
export type { RandomRangeInvocation } from './models/RandomRangeInvocation';
|
export type { RandomRangeInvocation } from './models/RandomRangeInvocation';
|
||||||
export type { RangeInvocation } from './models/RangeInvocation';
|
export type { RangeInvocation } from './models/RangeInvocation';
|
||||||
|
export type { RangeOfSizeInvocation } from './models/RangeOfSizeInvocation';
|
||||||
export type { ResizeLatentsInvocation } from './models/ResizeLatentsInvocation';
|
export type { ResizeLatentsInvocation } from './models/ResizeLatentsInvocation';
|
||||||
export type { RestoreFaceInvocation } from './models/RestoreFaceInvocation';
|
export type { RestoreFaceInvocation } from './models/RestoreFaceInvocation';
|
||||||
export type { ScaleLatentsInvocation } from './models/ScaleLatentsInvocation';
|
export type { ScaleLatentsInvocation } from './models/ScaleLatentsInvocation';
|
||||||
@ -79,79 +78,6 @@ export type { UpscaleInvocation } from './models/UpscaleInvocation';
|
|||||||
export type { VaeRepo } from './models/VaeRepo';
|
export type { VaeRepo } from './models/VaeRepo';
|
||||||
export type { ValidationError } from './models/ValidationError';
|
export type { ValidationError } from './models/ValidationError';
|
||||||
|
|
||||||
export { $AddInvocation } from './schemas/$AddInvocation';
|
|
||||||
export { $BlurInvocation } from './schemas/$BlurInvocation';
|
|
||||||
export { $Body_upload_image } from './schemas/$Body_upload_image';
|
|
||||||
export { $CkptModelInfo } from './schemas/$CkptModelInfo';
|
|
||||||
export { $CollectInvocation } from './schemas/$CollectInvocation';
|
|
||||||
export { $CollectInvocationOutput } from './schemas/$CollectInvocationOutput';
|
|
||||||
export { $ColorField } from './schemas/$ColorField';
|
|
||||||
export { $CompelInvocation } from './schemas/$CompelInvocation';
|
|
||||||
export { $CompelOutput } from './schemas/$CompelOutput';
|
|
||||||
export { $ConditioningField } from './schemas/$ConditioningField';
|
|
||||||
export { $CreateModelRequest } from './schemas/$CreateModelRequest';
|
|
||||||
export { $CropImageInvocation } from './schemas/$CropImageInvocation';
|
|
||||||
export { $CvInpaintInvocation } from './schemas/$CvInpaintInvocation';
|
|
||||||
export { $DiffusersModelInfo } from './schemas/$DiffusersModelInfo';
|
|
||||||
export { $DivideInvocation } from './schemas/$DivideInvocation';
|
|
||||||
export { $Edge } from './schemas/$Edge';
|
|
||||||
export { $EdgeConnection } from './schemas/$EdgeConnection';
|
|
||||||
export { $Graph } from './schemas/$Graph';
|
|
||||||
export { $GraphExecutionState } from './schemas/$GraphExecutionState';
|
|
||||||
export { $GraphInvocation } from './schemas/$GraphInvocation';
|
|
||||||
export { $GraphInvocationOutput } from './schemas/$GraphInvocationOutput';
|
|
||||||
export { $HTTPValidationError } from './schemas/$HTTPValidationError';
|
|
||||||
export { $ImageField } from './schemas/$ImageField';
|
|
||||||
export { $ImageOutput } from './schemas/$ImageOutput';
|
|
||||||
export { $ImageResponse } from './schemas/$ImageResponse';
|
|
||||||
export { $ImageResponseMetadata } from './schemas/$ImageResponseMetadata';
|
|
||||||
export { $ImageToImageInvocation } from './schemas/$ImageToImageInvocation';
|
|
||||||
export { $ImageToLatentsInvocation } from './schemas/$ImageToLatentsInvocation';
|
|
||||||
export { $ImageType } from './schemas/$ImageType';
|
|
||||||
export { $InfillColorInvocation } from './schemas/$InfillColorInvocation';
|
|
||||||
export { $InfillPatchMatchInvocation } from './schemas/$InfillPatchMatchInvocation';
|
|
||||||
export { $InfillTileInvocation } from './schemas/$InfillTileInvocation';
|
|
||||||
export { $InpaintInvocation } from './schemas/$InpaintInvocation';
|
|
||||||
export { $IntCollectionOutput } from './schemas/$IntCollectionOutput';
|
|
||||||
export { $IntOutput } from './schemas/$IntOutput';
|
|
||||||
export { $InverseLerpInvocation } from './schemas/$InverseLerpInvocation';
|
|
||||||
export { $InvokeAIMetadata } from './schemas/$InvokeAIMetadata';
|
|
||||||
export { $IterateInvocation } from './schemas/$IterateInvocation';
|
|
||||||
export { $IterateInvocationOutput } from './schemas/$IterateInvocationOutput';
|
|
||||||
export { $LatentsField } from './schemas/$LatentsField';
|
|
||||||
export { $LatentsOutput } from './schemas/$LatentsOutput';
|
|
||||||
export { $LatentsToImageInvocation } from './schemas/$LatentsToImageInvocation';
|
|
||||||
export { $LatentsToLatentsInvocation } from './schemas/$LatentsToLatentsInvocation';
|
|
||||||
export { $LerpInvocation } from './schemas/$LerpInvocation';
|
|
||||||
export { $LoadImageInvocation } from './schemas/$LoadImageInvocation';
|
|
||||||
export { $MaskFromAlphaInvocation } from './schemas/$MaskFromAlphaInvocation';
|
|
||||||
export { $MaskOutput } from './schemas/$MaskOutput';
|
|
||||||
export { $MetadataColorField } from './schemas/$MetadataColorField';
|
|
||||||
export { $MetadataImageField } from './schemas/$MetadataImageField';
|
|
||||||
export { $MetadataLatentsField } from './schemas/$MetadataLatentsField';
|
|
||||||
export { $ModelsList } from './schemas/$ModelsList';
|
|
||||||
export { $MultiplyInvocation } from './schemas/$MultiplyInvocation';
|
|
||||||
export { $NoiseInvocation } from './schemas/$NoiseInvocation';
|
|
||||||
export { $NoiseOutput } from './schemas/$NoiseOutput';
|
|
||||||
export { $PaginatedResults_GraphExecutionState_ } from './schemas/$PaginatedResults_GraphExecutionState_';
|
|
||||||
export { $PaginatedResults_ImageResponse_ } from './schemas/$PaginatedResults_ImageResponse_';
|
|
||||||
export { $ParamIntInvocation } from './schemas/$ParamIntInvocation';
|
|
||||||
export { $PasteImageInvocation } from './schemas/$PasteImageInvocation';
|
|
||||||
export { $PromptOutput } from './schemas/$PromptOutput';
|
|
||||||
export { $RandomIntInvocation } from './schemas/$RandomIntInvocation';
|
|
||||||
export { $RandomRangeInvocation } from './schemas/$RandomRangeInvocation';
|
|
||||||
export { $RangeInvocation } from './schemas/$RangeInvocation';
|
|
||||||
export { $ResizeLatentsInvocation } from './schemas/$ResizeLatentsInvocation';
|
|
||||||
export { $RestoreFaceInvocation } from './schemas/$RestoreFaceInvocation';
|
|
||||||
export { $ScaleLatentsInvocation } from './schemas/$ScaleLatentsInvocation';
|
|
||||||
export { $ShowImageInvocation } from './schemas/$ShowImageInvocation';
|
|
||||||
export { $SubtractInvocation } from './schemas/$SubtractInvocation';
|
|
||||||
export { $TextToImageInvocation } from './schemas/$TextToImageInvocation';
|
|
||||||
export { $TextToLatentsInvocation } from './schemas/$TextToLatentsInvocation';
|
|
||||||
export { $UpscaleInvocation } from './schemas/$UpscaleInvocation';
|
|
||||||
export { $VaeRepo } from './schemas/$VaeRepo';
|
|
||||||
export { $ValidationError } from './schemas/$ValidationError';
|
|
||||||
|
|
||||||
export { ImagesService } from './services/ImagesService';
|
export { ImagesService } from './services/ImagesService';
|
||||||
export { ModelsService } from './services/ModelsService';
|
export { ModelsService } from './services/ModelsService';
|
||||||
export { SessionsService } from './services/SessionsService';
|
export { SessionsService } from './services/SessionsService';
|
||||||
|
@ -31,6 +31,7 @@ import type { PasteImageInvocation } from './PasteImageInvocation';
|
|||||||
import type { RandomIntInvocation } from './RandomIntInvocation';
|
import type { RandomIntInvocation } from './RandomIntInvocation';
|
||||||
import type { RandomRangeInvocation } from './RandomRangeInvocation';
|
import type { RandomRangeInvocation } from './RandomRangeInvocation';
|
||||||
import type { RangeInvocation } from './RangeInvocation';
|
import type { RangeInvocation } from './RangeInvocation';
|
||||||
|
import type { RangeOfSizeInvocation } from './RangeOfSizeInvocation';
|
||||||
import type { ResizeLatentsInvocation } from './ResizeLatentsInvocation';
|
import type { ResizeLatentsInvocation } from './ResizeLatentsInvocation';
|
||||||
import type { RestoreFaceInvocation } from './RestoreFaceInvocation';
|
import type { RestoreFaceInvocation } from './RestoreFaceInvocation';
|
||||||
import type { ScaleLatentsInvocation } from './ScaleLatentsInvocation';
|
import type { ScaleLatentsInvocation } from './ScaleLatentsInvocation';
|
||||||
@ -48,7 +49,7 @@ export type Graph = {
|
|||||||
/**
|
/**
|
||||||
* The nodes in this graph
|
* The nodes in this graph
|
||||||
*/
|
*/
|
||||||
nodes?: Record<string, (LoadImageInvocation | ShowImageInvocation | CropImageInvocation | PasteImageInvocation | MaskFromAlphaInvocation | BlurInvocation | LerpInvocation | InverseLerpInvocation | CompelInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | CvInpaintInvocation | RangeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation)>;
|
nodes?: Record<string, (LoadImageInvocation | ShowImageInvocation | CropImageInvocation | PasteImageInvocation | MaskFromAlphaInvocation | BlurInvocation | LerpInvocation | InverseLerpInvocation | CompelInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | RandomIntInvocation | ParamIntInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | CvInpaintInvocation | RangeInvocation | RangeOfSizeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillColorInvocation | InfillTileInvocation | InfillPatchMatchInvocation | GraphInvocation | IterateInvocation | CollectInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation)>;
|
||||||
/**
|
/**
|
||||||
* The connections between nodes and their fields in this graph
|
* The connections between nodes and their fields in this graph
|
||||||
*/
|
*/
|
||||||
|
@ -42,7 +42,7 @@ export type GraphExecutionState = {
|
|||||||
/**
|
/**
|
||||||
* The results of node executions
|
* The results of node executions
|
||||||
*/
|
*/
|
||||||
results: Record<string, (ImageOutput | MaskOutput | CompelOutput | LatentsOutput | NoiseOutput | IntOutput | PromptOutput | IntCollectionOutput | GraphInvocationOutput | IterateInvocationOutput | CollectInvocationOutput)>;
|
results: Record<string, (ImageOutput | MaskOutput | PromptOutput | CompelOutput | IntOutput | LatentsOutput | NoiseOutput | IntCollectionOutput | GraphInvocationOutput | IterateInvocationOutput | CollectInvocationOutput)>;
|
||||||
/**
|
/**
|
||||||
* Errors raised when executing nodes
|
* Errors raised when executing nodes
|
||||||
*/
|
*/
|
||||||
|
@ -0,0 +1,8 @@
|
|||||||
|
/* istanbul ignore file */
|
||||||
|
/* tslint:disable */
|
||||||
|
/* eslint-disable */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The category of an image. Use ImageCategory.OTHER for non-default categories.
|
||||||
|
*/
|
||||||
|
export type ImageCategory = 'general' | 'control' | 'other';
|
66
invokeai/frontend/web/src/services/api/models/ImageDTO.ts
Normal file
66
invokeai/frontend/web/src/services/api/models/ImageDTO.ts
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
/* istanbul ignore file */
|
||||||
|
/* tslint:disable */
|
||||||
|
/* eslint-disable */
|
||||||
|
|
||||||
|
import type { ImageCategory } from './ImageCategory';
|
||||||
|
import type { ImageMetadata } from './ImageMetadata';
|
||||||
|
import type { ImageType } from './ImageType';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Deserialized image record, enriched for the frontend with URLs.
|
||||||
|
*/
|
||||||
|
export type ImageDTO = {
|
||||||
|
/**
|
||||||
|
* The unique name of the image.
|
||||||
|
*/
|
||||||
|
image_name: string;
|
||||||
|
/**
|
||||||
|
* The type of the image.
|
||||||
|
*/
|
||||||
|
image_type: ImageType;
|
||||||
|
/**
|
||||||
|
* The URL of the image.
|
||||||
|
*/
|
||||||
|
image_url: string;
|
||||||
|
/**
|
||||||
|
* The URL of the image's thumbnail.
|
||||||
|
*/
|
||||||
|
thumbnail_url: string;
|
||||||
|
/**
|
||||||
|
* The category of the image.
|
||||||
|
*/
|
||||||
|
image_category: ImageCategory;
|
||||||
|
/**
|
||||||
|
* The width of the image in px.
|
||||||
|
*/
|
||||||
|
width: number;
|
||||||
|
/**
|
||||||
|
* The height of the image in px.
|
||||||
|
*/
|
||||||
|
height: number;
|
||||||
|
/**
|
||||||
|
* The created timestamp of the image.
|
||||||
|
*/
|
||||||
|
created_at: string;
|
||||||
|
/**
|
||||||
|
* The updated timestamp of the image.
|
||||||
|
*/
|
||||||
|
updated_at: string;
|
||||||
|
/**
|
||||||
|
* The deleted timestamp of the image.
|
||||||
|
*/
|
||||||
|
deleted_at?: string;
|
||||||
|
/**
|
||||||
|
* The session ID that generated this image, if it is a generated image.
|
||||||
|
*/
|
||||||
|
session_id?: string;
|
||||||
|
/**
|
||||||
|
* The node ID that generated this image, if it is a generated image.
|
||||||
|
*/
|
||||||
|
node_id?: string;
|
||||||
|
/**
|
||||||
|
* A limited subset of the image's generation metadata. Retrieve the image's session for full metadata.
|
||||||
|
*/
|
||||||
|
metadata?: ImageMetadata;
|
||||||
|
};
|
||||||
|
|
@ -0,0 +1,81 @@
|
|||||||
|
/* istanbul ignore file */
|
||||||
|
/* tslint:disable */
|
||||||
|
/* eslint-disable */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Core generation metadata for an image/tensor generated in InvokeAI.
|
||||||
|
*
|
||||||
|
* Also includes any metadata from the image's PNG tEXt chunks.
|
||||||
|
*
|
||||||
|
* Generated by traversing the execution graph, collecting the parameters of the nearest ancestors
|
||||||
|
* of a given node.
|
||||||
|
*
|
||||||
|
* Full metadata may be accessed by querying for the session in the `graph_executions` table.
|
||||||
|
*/
|
||||||
|
export type ImageMetadata = {
|
||||||
|
/**
|
||||||
|
* The type of the ancestor node of the image output node.
|
||||||
|
*/
|
||||||
|
type?: string;
|
||||||
|
/**
|
||||||
|
* The positive conditioning.
|
||||||
|
*/
|
||||||
|
positive_conditioning?: string;
|
||||||
|
/**
|
||||||
|
* The negative conditioning.
|
||||||
|
*/
|
||||||
|
negative_conditioning?: string;
|
||||||
|
/**
|
||||||
|
* Width of the image/latents in pixels.
|
||||||
|
*/
|
||||||
|
width?: number;
|
||||||
|
/**
|
||||||
|
* Height of the image/latents in pixels.
|
||||||
|
*/
|
||||||
|
height?: number;
|
||||||
|
/**
|
||||||
|
* The seed used for noise generation.
|
||||||
|
*/
|
||||||
|
seed?: number;
|
||||||
|
/**
|
||||||
|
* The classifier-free guidance scale.
|
||||||
|
*/
|
||||||
|
cfg_scale?: number;
|
||||||
|
/**
|
||||||
|
* The number of steps used for inference.
|
||||||
|
*/
|
||||||
|
steps?: number;
|
||||||
|
/**
|
||||||
|
* The scheduler used for inference.
|
||||||
|
*/
|
||||||
|
scheduler?: string;
|
||||||
|
/**
|
||||||
|
* The model used for inference.
|
||||||
|
*/
|
||||||
|
model?: string;
|
||||||
|
/**
|
||||||
|
* The strength used for image-to-image/latents-to-latents.
|
||||||
|
*/
|
||||||
|
strength?: number;
|
||||||
|
/**
|
||||||
|
* The ID of the initial latents.
|
||||||
|
*/
|
||||||
|
latents?: string;
|
||||||
|
/**
|
||||||
|
* The VAE used for decoding.
|
||||||
|
*/
|
||||||
|
vae?: string;
|
||||||
|
/**
|
||||||
|
* The UNet used dor inference.
|
||||||
|
*/
|
||||||
|
unet?: string;
|
||||||
|
/**
|
||||||
|
* The CLIP Encoder used for conditioning.
|
||||||
|
*/
|
||||||
|
clip?: string;
|
||||||
|
/**
|
||||||
|
* Uploaded image metadata, extracted from the PNG tEXt chunk.
|
||||||
|
*/
|
||||||
|
extra?: string;
|
||||||
|
};
|
||||||
|
|
@ -8,7 +8,7 @@ import type { ImageField } from './ImageField';
|
|||||||
* Base class for invocations that output an image
|
* Base class for invocations that output an image
|
||||||
*/
|
*/
|
||||||
export type ImageOutput = {
|
export type ImageOutput = {
|
||||||
type: 'image';
|
type: 'image_output';
|
||||||
/**
|
/**
|
||||||
* The output image
|
* The output image
|
||||||
*/
|
*/
|
||||||
|
@ -1,33 +0,0 @@
|
|||||||
/* istanbul ignore file */
|
|
||||||
/* tslint:disable */
|
|
||||||
/* eslint-disable */
|
|
||||||
|
|
||||||
import type { ImageResponseMetadata } from './ImageResponseMetadata';
|
|
||||||
import type { ImageType } from './ImageType';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The response type for images
|
|
||||||
*/
|
|
||||||
export type ImageResponse = {
|
|
||||||
/**
|
|
||||||
* The type of the image
|
|
||||||
*/
|
|
||||||
image_type: ImageType;
|
|
||||||
/**
|
|
||||||
* The name of the image
|
|
||||||
*/
|
|
||||||
image_name: string;
|
|
||||||
/**
|
|
||||||
* The url of the image
|
|
||||||
*/
|
|
||||||
image_url: string;
|
|
||||||
/**
|
|
||||||
* The url of the image's thumbnail
|
|
||||||
*/
|
|
||||||
thumbnail_url: string;
|
|
||||||
/**
|
|
||||||
* The image's metadata
|
|
||||||
*/
|
|
||||||
metadata: ImageResponseMetadata;
|
|
||||||
};
|
|
||||||
|
|
@ -1,28 +0,0 @@
|
|||||||
/* istanbul ignore file */
|
|
||||||
/* tslint:disable */
|
|
||||||
/* eslint-disable */
|
|
||||||
|
|
||||||
import type { InvokeAIMetadata } from './InvokeAIMetadata';
|
|
||||||
|
|
||||||
/**
|
|
||||||
* An image's metadata. Used only in HTTP responses.
|
|
||||||
*/
|
|
||||||
export type ImageResponseMetadata = {
|
|
||||||
/**
|
|
||||||
* The creation timestamp of the image
|
|
||||||
*/
|
|
||||||
created: number;
|
|
||||||
/**
|
|
||||||
* The width of the image in pixels
|
|
||||||
*/
|
|
||||||
width: number;
|
|
||||||
/**
|
|
||||||
* The height of the image in pixels
|
|
||||||
*/
|
|
||||||
height: number;
|
|
||||||
/**
|
|
||||||
* The image's InvokeAI-specific metadata
|
|
||||||
*/
|
|
||||||
invokeai?: InvokeAIMetadata;
|
|
||||||
};
|
|
||||||
|
|
@ -3,6 +3,6 @@
|
|||||||
/* eslint-disable */
|
/* eslint-disable */
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An enumeration.
|
* The type of an image.
|
||||||
*/
|
*/
|
||||||
export type ImageType = 'results' | 'intermediates' | 'uploads';
|
export type ImageType = 'results' | 'uploads' | 'intermediates';
|
||||||
|
@ -0,0 +1,28 @@
|
|||||||
|
/* istanbul ignore file */
|
||||||
|
/* tslint:disable */
|
||||||
|
/* eslint-disable */
|
||||||
|
|
||||||
|
import type { ImageType } from './ImageType';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The URLs for an image and its thumbnail.
|
||||||
|
*/
|
||||||
|
export type ImageUrlsDTO = {
|
||||||
|
/**
|
||||||
|
* The unique name of the image.
|
||||||
|
*/
|
||||||
|
image_name: string;
|
||||||
|
/**
|
||||||
|
* The type of the image.
|
||||||
|
*/
|
||||||
|
image_type: ImageType;
|
||||||
|
/**
|
||||||
|
* The URL of the image.
|
||||||
|
*/
|
||||||
|
image_url: string;
|
||||||
|
/**
|
||||||
|
* The URL of the image's thumbnail.
|
||||||
|
*/
|
||||||
|
thumbnail_url: string;
|
||||||
|
};
|
||||||
|
|
@ -1,13 +0,0 @@
|
|||||||
/* istanbul ignore file */
|
|
||||||
/* tslint:disable */
|
|
||||||
/* eslint-disable */
|
|
||||||
|
|
||||||
import type { MetadataColorField } from './MetadataColorField';
|
|
||||||
import type { MetadataImageField } from './MetadataImageField';
|
|
||||||
import type { MetadataLatentsField } from './MetadataLatentsField';
|
|
||||||
|
|
||||||
export type InvokeAIMetadata = {
|
|
||||||
session_id?: string;
|
|
||||||
node?: Record<string, (string | number | boolean | MetadataImageField | MetadataLatentsField | MetadataColorField)>;
|
|
||||||
};
|
|
||||||
|
|
@ -13,5 +13,13 @@ export type MaskOutput = {
|
|||||||
* The output mask
|
* The output mask
|
||||||
*/
|
*/
|
||||||
mask: ImageField;
|
mask: ImageField;
|
||||||
|
/**
|
||||||
|
* The width of the mask in pixels
|
||||||
|
*/
|
||||||
|
width?: number;
|
||||||
|
/**
|
||||||
|
* The height of the mask in pixels
|
||||||
|
*/
|
||||||
|
height?: number;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -1,11 +0,0 @@
|
|||||||
/* istanbul ignore file */
|
|
||||||
/* tslint:disable */
|
|
||||||
/* eslint-disable */
|
|
||||||
|
|
||||||
export type MetadataColorField = {
|
|
||||||
'r': number;
|
|
||||||
'g': number;
|
|
||||||
'b': number;
|
|
||||||
'a': number;
|
|
||||||
};
|
|
||||||
|
|
@ -1,11 +0,0 @@
|
|||||||
/* istanbul ignore file */
|
|
||||||
/* tslint:disable */
|
|
||||||
/* eslint-disable */
|
|
||||||
|
|
||||||
import type { ImageType } from './ImageType';
|
|
||||||
|
|
||||||
export type MetadataImageField = {
|
|
||||||
image_type: ImageType;
|
|
||||||
image_name: string;
|
|
||||||
};
|
|
||||||
|
|
@ -1,8 +0,0 @@
|
|||||||
/* istanbul ignore file */
|
|
||||||
/* tslint:disable */
|
|
||||||
/* eslint-disable */
|
|
||||||
|
|
||||||
export type MetadataLatentsField = {
|
|
||||||
latents_name: string;
|
|
||||||
};
|
|
||||||
|
|
@ -2,16 +2,16 @@
|
|||||||
/* tslint:disable */
|
/* tslint:disable */
|
||||||
/* eslint-disable */
|
/* eslint-disable */
|
||||||
|
|
||||||
import type { ImageResponse } from './ImageResponse';
|
import type { ImageDTO } from './ImageDTO';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Paginated results
|
* Paginated results
|
||||||
*/
|
*/
|
||||||
export type PaginatedResults_ImageResponse_ = {
|
export type PaginatedResults_ImageDTO_ = {
|
||||||
/**
|
/**
|
||||||
* Items
|
* Items
|
||||||
*/
|
*/
|
||||||
items: Array<ImageResponse>;
|
items: Array<ImageDTO>;
|
||||||
/**
|
/**
|
||||||
* Current Page
|
* Current Page
|
||||||
*/
|
*/
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user