Compare commits

...

29 Commits

Author SHA1 Message Date
07c9b598bd hack(nodes): hack to get image urls in the invocation complete event 2023-05-21 22:46:56 +10:00
4d37ce31fc feat(nodes): streamline urlservice 2023-05-21 22:44:16 +10:00
20e853084f fix(nodes): remove bad import 2023-05-21 22:43:04 +10:00
76fe1d0103 feat(nodes): it works 2023-05-21 22:15:44 +10:00
c4fad12ac1 feat(nodes): fix types for InvocationServices 2023-05-21 20:27:34 +10:00
db44d1431c feat(nodes): add logger to images service 2023-05-21 20:24:59 +10:00
8d79610be2 feat(logger): fix logger type issues 2023-05-21 20:24:37 +10:00
466980812b feat(nodes): wip image storage implementation 2023-05-21 20:05:33 +10:00
ca8f8b9162 fix(nodes): use save instead of set
`set` is a python builtin
2023-05-21 17:27:25 +10:00
8ae769fdad feat(nodes): image records router 2023-05-21 15:47:29 +10:00
039f2b00df feat(nodes): add high-level images service 2023-05-21 15:17:06 +10:00
0299f0c4c6 feat(nodes): update urlservice 2023-05-21 10:44:26 +10:00
df4abdcd81 feat(nodes): update image related names 2023-05-21 09:59:55 +10:00
ed8dfdf996 feat(nodes): wip images db & router 2023-05-20 22:20:49 +10:00
4004916af9 feat(nodes): images_db_service and resources router 2023-05-20 20:10:55 +10:00
24638a71da feat(nodes): wip latents db stuff 2023-05-20 18:01:31 +10:00
c7392e7948 feat(nodes): add design doc 2023-05-20 02:39:46 +10:00
f92afaac7c feat(nodes): add core metadata builder 2023-05-18 23:34:45 +10:00
6a30b8ec99 fix(ui): send to canvas in currentimagebuttons not working 2023-05-18 11:15:59 +10:00
40a30f8fee feat(ui): crude results router 2023-05-17 23:09:54 +10:00
fe07a0846e fix(nodes): Result class should use outputs classes, not fields 2023-05-17 21:56:04 +10:00
78acd185b5 feat(nodes): add results router
It doesn't work due to circular imports still
2023-05-17 21:51:02 +10:00
57db0816d6 fix(nodes): do not shadow list builtin 2023-05-17 20:55:08 +10:00
71347df765 feat(nodes): add result_type to results table, fix types 2023-05-17 20:51:37 +10:00
751b4f249d fix(ui): fix type guards 2023-05-17 19:42:42 +10:00
5136b83049 chore(ui): regen api 2023-05-17 19:42:10 +10:00
6e4e0fe29e fix(nodes): fix results service bugs 2023-05-17 19:35:34 +10:00
f0a9a4fb88 feat(nodes): add ResultsServiceABC & SqliteResultsService
**Doesn't actually work bc of circular imports. Can't even test it.**

- add a base class for ResultsService and SQLite implementation
- use `graph_execution_manager` `on_changed` callback to keep `results` table in sync
2023-05-17 19:16:04 +10:00
34b50e11b6 feat(nodes): change ImageOutput type to image_output 2023-05-17 19:13:53 +10:00
38 changed files with 3217 additions and 471 deletions

View File

@ -1,9 +1,11 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from logging import Logger
import os
import invokeai.backend.util.logging as logger
from typing import types
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService
from invokeai.app.services.urls import LocalUrlService
from invokeai.backend.util.logging import InvokeAILogger
from ..services.default_graphs import create_system_graphs
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
@ -11,7 +13,7 @@ from ...backend import Globals
from ..services.model_manager_initializer import get_model_manager
from ..services.restoration_services import RestorationServices
from ..services.graph import GraphExecutionState, LibraryGraph
from ..services.image_storage import DiskImageStorage
from ..services.image_file_storage import DiskImageFileStorage
from ..services.invocation_queue import MemoryInvocationQueue
from ..services.invocation_services import InvocationServices
from ..services.invoker import Invoker
@ -37,13 +39,16 @@ def check_internet() -> bool:
return False
logger = InvokeAILogger.getLogger()
class ApiDependencies:
"""Contains and initializes all dependencies for the API"""
invoker: Invoker = None
@staticmethod
def initialize(config, event_handler_id: int, logger: types.ModuleType=logger):
def initialize(config, event_handler_id: int, logger: Logger = logger):
Globals.try_patchmatch = config.patchmatch
Globals.always_use_cpu = config.always_use_cpu
Globals.internet_available = config.internet_available and check_internet()
@ -60,31 +65,47 @@ class ApiDependencies:
os.path.join(os.path.dirname(__file__), "../../../../outputs")
)
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents'))
latents = ForwardCacheLatentsStorage(
DiskLatentsStorage(f"{output_folder}/latents")
)
metadata = PngMetadataService()
images = DiskImageStorage(f'{output_folder}/images', metadata_service=metadata)
urls = LocalUrlService()
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
# TODO: build a file/path manager?
db_location = os.path.join(output_folder, "invokeai.db")
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
)
image_record_storage = SqliteImageRecordStorage(db_location)
images_new = ImageService(
image_record_storage=image_record_storage,
image_file_storage=image_file_storage,
metadata=metadata,
url=urls,
logger=logger,
)
services = InvocationServices(
model_manager=get_model_manager(config,logger),
model_manager=get_model_manager(config, logger),
events=events,
logger=logger,
latents=latents,
images=images,
metadata=metadata,
images=image_file_storage,
images_new=images_new,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](
filename=db_location, table_name="graphs"
),
graph_execution_manager=SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
),
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
restoration=RestorationServices(config,logger),
restoration=RestorationServices(config, logger),
)
create_system_graphs(services.graph_library)

View File

@ -0,0 +1,47 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from fastapi import HTTPException, Path
from fastapi.responses import FileResponse
from fastapi.routing import APIRouter
from invokeai.app.models.image import ImageType
from ..dependencies import ApiDependencies
image_files_router = APIRouter(prefix="/v1/files/images", tags=["images", "files"])
@image_files_router.get("/{image_type}/{image_name}", operation_id="get_image")
async def get_image(
image_type: ImageType = Path(description="The type of the image to get"),
image_name: str = Path(description="The id of the image to get"),
) -> FileResponse:
"""Gets an image"""
try:
path = ApiDependencies.invoker.services.images_new.get_path(
image_type=image_type, image_name=image_name
)
return FileResponse(path)
except Exception as e:
raise HTTPException(status_code=404)
@image_files_router.get(
"/{image_type}/{image_name}/thumbnail", operation_id="get_thumbnail"
)
async def get_thumbnail(
image_type: ImageType = Path(
description="The type of the image whose thumbnail to get"
),
image_name: str = Path(description="The id of the image whose thumbnail to get"),
) -> FileResponse:
"""Gets a thumbnail"""
try:
path = ApiDependencies.invoker.services.images_new.get_path(
image_type=image_type, image_name=image_name, thumbnail=True
)
return FileResponse(path)
except Exception as e:
raise HTTPException(status_code=404)

View File

@ -0,0 +1,71 @@
from fastapi import HTTPException, Path, Query
from fastapi.routing import APIRouter
from invokeai.app.models.image import (
ImageCategory,
ImageType,
)
from invokeai.app.services.item_storage import PaginatedResults
from invokeai.app.services.models.image_record import ImageDTO
from ..dependencies import ApiDependencies
image_records_router = APIRouter(
prefix="/v1/images/records", tags=["images", "records"]
)
@image_records_router.get("/{image_type}/{image_name}", operation_id="get_image_record")
async def get_image_record(
image_type: ImageType = Path(description="The type of the image record to get"),
image_name: str = Path(description="The id of the image record to get"),
) -> ImageDTO:
"""Gets an image record by id"""
try:
return ApiDependencies.invoker.services.images_new.get_dto(
image_type=image_type, image_name=image_name
)
except Exception as e:
raise HTTPException(status_code=404)
@image_records_router.get(
"/",
operation_id="list_image_records",
)
async def list_image_records(
image_type: ImageType = Query(description="The type of image records to get"),
image_category: ImageCategory = Query(
description="The kind of image records to get"
),
page: int = Query(default=0, description="The page of image records to get"),
per_page: int = Query(
default=10, description="The number of image records per page"
),
) -> PaginatedResults[ImageDTO]:
"""Gets a list of image records by type and category"""
image_dtos = ApiDependencies.invoker.services.images_new.get_many(
image_type=image_type,
image_category=image_category,
page=page,
per_page=per_page,
)
return image_dtos
@image_records_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
async def delete_image_record(
image_type: ImageType = Query(description="The type of image to delete"),
image_name: str = Path(description="The name of the image to delete"),
) -> None:
"""Deletes an image record"""
try:
ApiDependencies.invoker.services.images_new.delete(
image_type=image_type, image_name=image_name
)
except Exception as e:
# TODO: Does this need any exception handling at all?
pass

View File

@ -1,90 +1,39 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import io
from datetime import datetime, timezone
import json
import os
from typing import Any
import uuid
from fastapi import Body, HTTPException, Path, Query, Request, UploadFile
from fastapi.responses import FileResponse, Response
from fastapi import HTTPException, Path, Query, Request, Response, UploadFile
from fastapi.routing import APIRouter
from PIL import Image
from invokeai.app.api.models.images import (
ImageResponse,
ImageResponseMetadata,
from invokeai.app.models.image import (
ImageCategory,
ImageType,
)
from invokeai.app.services.image_record_storage import ImageRecordStorageBase
from invokeai.app.services.image_file_storage import ImageFileStorageBase
from invokeai.app.services.models.image_record import ImageRecord
from invokeai.app.services.item_storage import PaginatedResults
from ...services.image_storage import ImageType
from ..dependencies import ApiDependencies
images_router = APIRouter(prefix="/v1/images", tags=["images"])
@images_router.get("/{image_type}/{image_name}", operation_id="get_image")
async def get_image(
image_type: ImageType = Path(description="The type of image to get"),
image_name: str = Path(description="The name of the image to get"),
) -> FileResponse:
"""Gets an image"""
path = ApiDependencies.invoker.services.images.get_path(
image_type=image_type, image_name=image_name
)
if ApiDependencies.invoker.services.images.validate_path(path):
return FileResponse(path)
else:
raise HTTPException(status_code=404)
@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
async def delete_image(
image_type: ImageType = Path(description="The type of image to delete"),
image_name: str = Path(description="The name of the image to delete"),
) -> None:
"""Deletes an image and its thumbnail"""
ApiDependencies.invoker.services.images.delete(
image_type=image_type, image_name=image_name
)
@images_router.get(
"/{thumbnail_type}/thumbnails/{thumbnail_name}", operation_id="get_thumbnail"
)
async def get_thumbnail(
thumbnail_type: ImageType = Path(description="The type of thumbnail to get"),
thumbnail_name: str = Path(description="The name of the thumbnail to get"),
) -> FileResponse | Response:
"""Gets a thumbnail"""
path = ApiDependencies.invoker.services.images.get_path(
image_type=thumbnail_type, image_name=thumbnail_name, is_thumbnail=True
)
if ApiDependencies.invoker.services.images.validate_path(path):
return FileResponse(path)
else:
raise HTTPException(status_code=404)
@images_router.post(
"/uploads/",
"/",
operation_id="upload_image",
responses={
201: {
"description": "The image was uploaded successfully",
"model": ImageResponse,
},
201: {"description": "The image was uploaded successfully"},
415: {"description": "Image upload failed"},
},
status_code=201,
)
async def upload_image(
file: UploadFile, image_type: ImageType, request: Request, response: Response
) -> ImageResponse:
file: UploadFile,
image_type: ImageType,
request: Request,
response: Response,
image_category: ImageCategory = ImageCategory.IMAGE,
) -> ImageRecord:
"""Uploads an image"""
if not file.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
@ -96,53 +45,33 @@ async def upload_image(
# Error opening the image
raise HTTPException(status_code=415, detail="Failed to read image")
filename = f"{uuid.uuid4()}_{str(int(datetime.now(timezone.utc).timestamp()))}.png"
try:
image_record = ApiDependencies.invoker.services.images_new.create(
image=img,
image_type=image_type,
image_category=image_category,
)
saved_image = ApiDependencies.invoker.services.images.save(
image_type, filename, img
)
response.status_code = 201
response.headers["Location"] = image_record.image_url
invokeai_metadata = ApiDependencies.invoker.services.metadata.get_metadata(img)
image_url = ApiDependencies.invoker.services.images.get_uri(
image_type, saved_image.image_name
)
thumbnail_url = ApiDependencies.invoker.services.images.get_uri(
image_type, saved_image.image_name, True
)
res = ImageResponse(
image_type=image_type,
image_name=saved_image.image_name,
image_url=image_url,
thumbnail_url=thumbnail_url,
metadata=ImageResponseMetadata(
created=saved_image.created,
width=img.width,
height=img.height,
invokeai=invokeai_metadata,
),
)
response.status_code = 201
response.headers["Location"] = image_url
return res
return image_record
except Exception as e:
raise HTTPException(status_code=500)
@images_router.get(
"/",
operation_id="list_images",
responses={200: {"model": PaginatedResults[ImageResponse]}},
)
async def list_images(
image_type: ImageType = Query(
default=ImageType.RESULT, description="The type of images to get"
),
page: int = Query(default=0, description="The page of images to get"),
per_page: int = Query(default=10, description="The number of images per page"),
) -> PaginatedResults[ImageResponse]:
"""Gets a list of images"""
result = ApiDependencies.invoker.services.images.list(image_type, page, per_page)
return result
@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
async def delete_image_record(
image_type: ImageType = Query(description="The type of image to delete"),
image_name: str = Path(description="The name of the image to delete"),
) -> None:
"""Deletes an image record"""
try:
ApiDependencies.invoker.services.images_new.delete(
image_type=image_type, image_name=image_name
)
except Exception as e:
# TODO: Does this need any exception handling at all?
pass

View File

@ -0,0 +1,42 @@
from fastapi import HTTPException, Path, Query
from fastapi.routing import APIRouter
from invokeai.app.services.results import ResultType, ResultWithSession
from invokeai.app.services.item_storage import PaginatedResults
from ..dependencies import ApiDependencies
results_router = APIRouter(prefix="/v1/results", tags=["results"])
@results_router.get("/{result_type}/{result_name}", operation_id="get_result")
async def get_result(
result_type: ResultType = Path(description="The type of result to get"),
result_name: str = Path(description="The name of the result to get"),
) -> ResultWithSession:
"""Gets a result"""
result = ApiDependencies.invoker.services.results.get(
result_id=result_name, result_type=result_type
)
if result is not None:
return result
else:
raise HTTPException(status_code=404)
@results_router.get(
"/",
operation_id="list_results",
responses={200: {"model": PaginatedResults[ResultWithSession]}},
)
async def list_results(
result_type: ResultType = Query(description="The type of results to get"),
page: int = Query(default=0, description="The page of results to get"),
per_page: int = Query(default=10, description="The number of results per page"),
) -> PaginatedResults[ResultWithSession]:
"""Gets a list of results"""
results = ApiDependencies.invoker.services.results.get_many(
result_type=result_type, page=page, per_page=per_page
)
return results

View File

@ -3,6 +3,7 @@ import asyncio
from inspect import signature
import uvicorn
from invokeai.app.models import resources
import invokeai.backend.util.logging as logger
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
@ -15,10 +16,11 @@ from pydantic.schema import schema
from ..backend import Args
from .api.dependencies import ApiDependencies
from .api.routers import images, sessions, models
from .api.routers import image_files, image_records, sessions, models
from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation
# Create the app
# TODO: create this all in a method so configuration/etc. can be passed in?
app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None)
@ -74,10 +76,12 @@ async def shutdown_event():
app.include_router(sessions.session_router, prefix="/api")
app.include_router(images.images_router, prefix="/api")
app.include_router(image_files.image_files_router, prefix="/api")
app.include_router(models.models_router, prefix="/api")
app.include_router(image_records.image_records_router, prefix="/api")
# Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow?
@ -126,6 +130,7 @@ app.openapi = custom_openapi
# Override API doc favicons
app.mount("/static", StaticFiles(directory="static/dream_web"), name="static")
@app.get("/docs", include_in_schema=False)
def overridden_swagger():
return get_swagger_ui_html(
@ -143,8 +148,12 @@ def overridden_redoc():
redoc_favicon_url="/static/favicon.ico",
)
# Must mount *after* the other routes else it borks em
app.mount("/", StaticFiles(directory="invokeai/frontend/web/dist", html=True), name="ui")
app.mount(
"/", StaticFiles(directory="invokeai/frontend/web/dist", html=True), name="ui"
)
def invoke_api():
# Start our own event loop for eventing usage

View File

@ -28,7 +28,7 @@ from .services.model_manager_initializer import get_model_manager
from .services.restoration_services import RestorationServices
from .services.graph import Edge, EdgeConnection, GraphExecutionState, GraphInvocation, LibraryGraph, are_connection_types_compatible
from .services.default_graphs import default_text_to_image_graph_id
from .services.image_storage import DiskImageStorage
from .services.image_file_storage import DiskImageFileStorage
from .services.invocation_queue import MemoryInvocationQueue
from .services.invocation_services import InvocationServices
from .services.invoker import Invoker
@ -214,7 +214,7 @@ def invoke_cli():
model_manager=model_manager,
events=events,
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
images=DiskImageStorage(f'{output_folder}/images', metadata_service=metadata),
images=DiskImageFileStorage(f'{output_folder}/images', metadata_service=metadata),
metadata=metadata,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](

View File

@ -1,12 +1,15 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from __future__ import annotations
from abc import ABC, abstractmethod
from inspect import signature
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict
from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict, TYPE_CHECKING
from pydantic import BaseModel, Field
from ..services.invocation_services import InvocationServices
if TYPE_CHECKING:
from ..services.invocation_services import InvocationServices
class InvocationContext:

View File

@ -120,7 +120,7 @@ class CompelInvocation(BaseInvocation):
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
# TODO: hacky but works ;D maybe rename latents somehow?
context.services.latents.set(conditioning_name, (c, ec))
context.services.latents.save(conditioning_name, (c, ec))
return CompelOutput(
conditioning=ConditioningField(

View File

@ -10,6 +10,7 @@ from pydantic import BaseModel, Field
from invokeai.app.models.image import ColorField, ImageField, ImageType
from invokeai.app.invocations.util.choose_model import choose_model
from invokeai.app.models.image import ImageCategory, ImageType
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.generator.inpaint import infill_methods
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
@ -91,24 +92,42 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
# each time it is called. We only need the first one.
generate_output = next(outputs)
image_dto = context.services.images_new.create(
image=generate_output.image,
image_type=ImageType.RESULT,
image_category=ImageCategory.IMAGE,
session_id=context.graph_execution_state_id,
node_id=self.id,
)
# Results are image and seed, unwrap for now and ignore the seed
# TODO: pre-seed?
# TODO: can this return multiple results? Should it?
image_type = ImageType.RESULT
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
# image_type = ImageType.RESULT
# image_name = context.services.images.create_name(
# context.graph_execution_state_id, self.id
# )
metadata = context.services.metadata.build_metadata(
session_id=context.graph_execution_state_id, node=self
)
# metadata = context.services.metadata.build_metadata(
# session_id=context.graph_execution_state_id, node=self
# )
# context.services.images.save(
# image_type, image_name, generate_output.image, metadata
# )
# context.services.images_db.set(
# id=image_name,
# image_type=ImageType.RESULT,
# image_category=ImageCategory.IMAGE,
# session_id=context.graph_execution_state_id,
# node_id=self.id,
# metadata=GeneratedImageOrLatentsMetadata(),
# )
context.services.images.save(
image_type, image_name, generate_output.image, metadata
)
return build_image_output(
image_type=image_type,
image_name=image_name,
image_type=image_dto.image_type,
image_name=image_dto.image_name,
image=generate_output.image,
)

View File

@ -31,7 +31,7 @@ class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
# fmt: off
type: Literal["image"] = "image"
type: Literal["image_output"] = "image_output"
image: ImageField = Field(default=None, description="The output image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")

View File

@ -20,7 +20,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, Sta
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
import numpy as np
from ..services.image_storage import ImageType
from ..services.image_file_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput, build_image_output
from .compel import ConditioningField
@ -144,7 +144,7 @@ class NoiseInvocation(BaseInvocation):
noise = get_noise(self.width, self.height, device, self.seed)
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, noise)
context.services.latents.save(name, noise)
return build_noise_output(latents_name=name, latents=noise)
@ -260,7 +260,7 @@ class TextToLatentsInvocation(BaseInvocation):
torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, result_latents)
context.services.latents.save(name, result_latents)
return build_latents_output(latents_name=name, latents=result_latents)
@ -319,7 +319,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.set(name, result_latents)
context.services.latents.save(name, result_latents)
return build_latents_output(latents_name=name, latents=result_latents)
@ -404,7 +404,7 @@ class ResizeLatentsInvocation(BaseInvocation):
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.set(name, resized_latents)
context.services.latents.save(name, resized_latents)
return build_latents_output(latents_name=name, latents=resized_latents)
@ -434,7 +434,7 @@ class ScaleLatentsInvocation(BaseInvocation):
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.set(name, resized_latents)
context.services.latents.save(name, resized_latents)
return build_latents_output(latents_name=name, latents=resized_latents)
@ -478,5 +478,5 @@ class ImageToLatentsInvocation(BaseInvocation):
)
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.set(name, latents)
context.services.latents.save(name, latents)
return build_latents_output(latents_name=name, latents=latents)

View File

@ -2,11 +2,23 @@ from enum import Enum
from typing import Optional, Tuple
from pydantic import BaseModel, Field
from invokeai.app.util.enum import MetaEnum
class ImageType(str, Enum, metaclass=MetaEnum):
"""The type of an image."""
class ImageType(str, Enum):
RESULT = "results"
INTERMEDIATE = "intermediates"
UPLOAD = "uploads"
INTERMEDIATE = "intermediates"
class ImageCategory(str, Enum, metaclass=MetaEnum):
"""The category of an image. Use ImageCategory.OTHER for non-default categories."""
IMAGE = "image"
CONTROL_IMAGE = "control_image"
OTHER = "other"
def is_image_type(obj):

View File

@ -0,0 +1,59 @@
from typing import Optional
from pydantic import BaseModel, Field, StrictFloat, StrictInt, StrictStr
class ImageMetadata(BaseModel):
"""
Core generation metadata for an image/tensor generated in InvokeAI.
Also includes any metadata from the image's PNG tEXt chunks.
Generated by traversing the execution graph, collecting the parameters of the nearest ancestors of a given node.
Full metadata may be accessed by querying for the session in the `graph_executions` table.
"""
positive_conditioning: Optional[StrictStr] = Field(
default=None, description="The positive conditioning."
)
negative_conditioning: Optional[StrictStr] = Field(
default=None, description="The negative conditioning."
)
width: Optional[StrictInt] = Field(
default=None, description="Width of the image/tensor in pixels."
)
height: Optional[StrictInt] = Field(
default=None, description="Height of the image/tensor in pixels."
)
seed: Optional[StrictInt] = Field(
default=None, description="The seed used for noise generation."
)
cfg_scale: Optional[StrictFloat] = Field(
default=None, description="The classifier-free guidance scale."
)
steps: Optional[StrictInt] = Field(
default=None, description="The number of steps used for inference."
)
scheduler: Optional[StrictStr] = Field(
default=None, description="The scheduler used for inference."
)
model: Optional[StrictStr] = Field(
default=None, description="The model used for inference."
)
strength: Optional[StrictFloat] = Field(
default=None,
description="The strength used for image-to-image/tensor-to-tensor.",
)
image: Optional[StrictStr] = Field(
default=None, description="The ID of the initial image."
)
tensor: Optional[StrictStr] = Field(
default=None, description="The ID of the initial tensor."
)
# Pending model refactor:
# vae: Optional[str] = Field(default=None,description="The VAE used for decoding.")
# unet: Optional[str] = Field(default=None,description="The UNet used dor inference.")
# clip: Optional[str] = Field(default=None,description="The CLIP Encoder used for conditioning.")
extra: Optional[StrictStr] = Field(
default=None, description="Extra metadata, extracted from the PNG tEXt chunk."
)

View File

@ -0,0 +1,28 @@
# TODO: Make a new model for this
from enum import Enum
from invokeai.app.util.enum import MetaEnum
class ResourceType(str, Enum, metaclass=MetaEnum):
"""The type of a resource."""
IMAGES = "images"
TENSORS = "tensors"
# class ResourceOrigin(str, Enum, metaclass=MetaEnum):
# """The origin of a resource (eg image or tensor)."""
# RESULTS = "results"
# UPLOADS = "uploads"
# INTERMEDIATES = "intermediates"
class TensorKind(str, Enum, metaclass=MetaEnum):
"""The kind of a tensor. Use TensorKind.OTHER for non-default kinds."""
IMAGE_LATENTS = "image_latents"
CONDITIONING = "conditioning"
OTHER = "other"

View File

@ -0,0 +1,578 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"from abc import ABC, abstractmethod\n",
"from enum import Enum\n",
"import enum\n",
"import sqlite3\n",
"import threading\n",
"from typing import Optional, Type, TypeVar, Union\n",
"from PIL.Image import Image as PILImage\n",
"from pydantic import BaseModel, Field\n",
"from torch import Tensor"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"\n",
"class ResourceOrigin(str, Enum):\n",
" \"\"\"The origin of a resource (eg image or tensor).\"\"\"\n",
"\n",
" RESULTS = \"results\"\n",
" UPLOADS = \"uploads\"\n",
" INTERMEDIATES = \"intermediates\"\n",
"\n",
"\n",
"class ImageKind(str, Enum):\n",
" \"\"\"The kind of an image. Use ImageKind.OTHER for non-default kinds.\"\"\"\n",
"\n",
" IMAGE = \"image\"\n",
" CONTROL_IMAGE = \"control_image\"\n",
" OTHER = \"other\"\n",
"\n",
"\n",
"class TensorKind(str, Enum):\n",
" \"\"\"The kind of a tensor. Use TensorKind.OTHER for non-default kinds.\"\"\"\n",
"\n",
" IMAGE_LATENTS = \"image_latents\"\n",
" CONDITIONING = \"conditioning\"\n",
" OTHER = \"other\"\n"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"\n",
"def create_sql_values_string_from_string_enum(enum: Type[Enum]):\n",
" \"\"\"\n",
" Creates a string of the form \"('value1'), ('value2'), ..., ('valueN')\" from a StrEnum.\n",
" \"\"\"\n",
"\n",
" delimiter = \", \"\n",
" values = [f\"('{e.value}')\" for e in enum]\n",
" return delimiter.join(values)\n",
"\n",
"\n",
"def create_sql_table_from_enum(\n",
" enum: Type[Enum],\n",
" table_name: str,\n",
" primary_key_name: str,\n",
" conn: sqlite3.Connection,\n",
" cursor: sqlite3.Cursor,\n",
" lock: threading.Lock,\n",
"):\n",
" \"\"\"\n",
" Creates and populates a table to be used as a functional enum.\n",
" \"\"\"\n",
"\n",
" try:\n",
" lock.acquire()\n",
"\n",
" values_string = create_sql_values_string_from_string_enum(enum)\n",
"\n",
" cursor.execute(\n",
" f\"\"\"--sql\n",
" CREATE TABLE IF NOT EXISTS {table_name} (\n",
" {primary_key_name} TEXT PRIMARY KEY\n",
" );\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" f\"\"\"--sql\n",
" INSERT OR IGNORE INTO {table_name} ({primary_key_name}) VALUES {values_string};\n",
" \"\"\"\n",
" )\n",
" conn.commit()\n",
" finally:\n",
" lock.release()\n",
"\n",
"\n",
"\"\"\"\n",
"`resource_origins` functions as an enum for the ResourceOrigin model.\n",
"\"\"\"\n",
"\n",
"\n",
"# def create_resource_origins_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
"# create_sql_table_from_enum(\n",
"# enum=ResourceOrigin,\n",
"# table_name=\"resource_origins\",\n",
"# primary_key_name=\"origin_name\",\n",
"# conn=conn,\n",
"# cursor=cursor,\n",
"# lock=lock,\n",
"# )\n",
"\n",
"\n",
"\"\"\"\n",
"`image_kinds` functions as an enum for the ImageType model.\n",
"\"\"\"\n",
"\n",
"\n",
"# def create_image_kinds_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" # create_sql_table_from_enum(\n",
" # enum=ImageKind,\n",
" # table_name=\"image_kinds\",\n",
" # primary_key_name=\"kind_name\",\n",
" # conn=conn,\n",
" # cursor=cursor,\n",
" # lock=lock,\n",
" # )\n",
"\n",
"\n",
"\"\"\"\n",
"`tensor_kinds` functions as an enum for the TensorType model.\n",
"\"\"\"\n",
"\n",
"\n",
"# def create_tensor_kinds_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" # create_sql_table_from_enum(\n",
" # enum=TensorKind,\n",
" # table_name=\"tensor_kinds\",\n",
" # primary_key_name=\"kind_name\",\n",
" # conn=conn,\n",
" # cursor=cursor,\n",
" # lock=lock,\n",
" # )\n",
"\n",
"\n",
"\"\"\"\n",
"`images` stores all images, regardless of type\n",
"\"\"\"\n",
"\n",
"\n",
"def create_images_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" try:\n",
" lock.acquire()\n",
"\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE TABLE IF NOT EXISTS images (\n",
" id TEXT PRIMARY KEY,\n",
" origin TEXT,\n",
" image_kind TEXT,\n",
" created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,\n",
" FOREIGN KEY(origin) REFERENCES resource_origins(origin_name),\n",
" FOREIGN KEY(image_kind) REFERENCES image_kinds(kind_name)\n",
" );\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE UNIQUE INDEX IF NOT EXISTS idx_images_id ON images(id);\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE INDEX IF NOT EXISTS idx_images_origin ON images(origin);\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE INDEX IF NOT EXISTS idx_images_image_kind ON images(image_kind);\n",
" \"\"\"\n",
" )\n",
" conn.commit()\n",
" finally:\n",
" lock.release()\n",
"\n",
"\n",
"\"\"\"\n",
"`images_results` stores additional data specific to `results` images.\n",
"\"\"\"\n",
"\n",
"\n",
"def create_images_results_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" try:\n",
" lock.acquire()\n",
"\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE TABLE IF NOT EXISTS images_results (\n",
" images_id TEXT PRIMARY KEY,\n",
" session_id TEXT NOT NULL,\n",
" node_id TEXT NOT NULL,\n",
" FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE\n",
" );\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE UNIQUE INDEX IF NOT EXISTS idx_images_results_images_id ON images_results(images_id);\n",
" \"\"\"\n",
" )\n",
" conn.commit()\n",
" finally:\n",
" lock.release()\n",
"\n",
"\n",
"\"\"\"\n",
"`images_intermediates` stores additional data specific to `intermediates` images\n",
"\"\"\"\n",
"\n",
"\n",
"def create_images_intermediates_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" try:\n",
" lock.acquire()\n",
"\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE TABLE IF NOT EXISTS images_intermediates (\n",
" images_id TEXT PRIMARY KEY,\n",
" session_id TEXT NOT NULL,\n",
" node_id TEXT NOT NULL,\n",
" FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE\n",
" );\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE UNIQUE INDEX IF NOT EXISTS idx_images_intermediates_images_id ON images_intermediates(images_id);\n",
" \"\"\"\n",
" )\n",
" conn.commit()\n",
" finally:\n",
" lock.release()\n",
"\n",
"\n",
"\"\"\"\n",
"`images_metadata` stores basic metadata for any image type\n",
"\"\"\"\n",
"\n",
"\n",
"def create_images_metadata_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" try:\n",
" lock.acquire()\n",
"\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE TABLE IF NOT EXISTS images_metadata (\n",
" images_id TEXT PRIMARY KEY,\n",
" metadata TEXT,\n",
" FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE\n",
" );\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE UNIQUE INDEX IF NOT EXISTS idx_images_metadata_images_id ON images_metadata(images_id);\n",
" \"\"\"\n",
" )\n",
" conn.commit()\n",
" finally:\n",
" lock.release()\n",
"\n",
"\n",
"# `tensors` table: stores references to tensor\n",
"\n",
"\n",
"def create_tensors_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" try:\n",
" lock.acquire()\n",
"\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE TABLE IF NOT EXISTS tensors (\n",
" id TEXT PRIMARY KEY,\n",
" origin TEXT,\n",
" tensor_kind TEXT,\n",
" created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,\n",
" FOREIGN KEY(origin) REFERENCES resource_origins(origin_name),\n",
" FOREIGN KEY(tensor_kind) REFERENCES tensor_kinds(kind_name)\n",
" );\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_id ON tensors(id);\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE INDEX IF NOT EXISTS idx_tensors_origin ON tensors(origin);\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE INDEX IF NOT EXISTS idx_tensors_tensor_kind ON tensors(tensor_kind);\n",
" \"\"\"\n",
" )\n",
" conn.commit()\n",
" finally:\n",
" lock.release()\n",
"\n",
"\n",
"# `tensors_results` stores additional data specific to `result` tensor\n",
"\n",
"\n",
"def create_tensors_results_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" try:\n",
" lock.acquire()\n",
"\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE TABLE IF NOT EXISTS tensors_results (\n",
" tensors_id TEXT PRIMARY KEY,\n",
" session_id TEXT NOT NULL,\n",
" node_id TEXT NOT NULL,\n",
" FOREIGN KEY(tensors_id) REFERENCES tensors(id) ON DELETE CASCADE\n",
" );\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_results_tensors_id ON tensors_results(tensors_id);\n",
" \"\"\"\n",
" )\n",
" conn.commit()\n",
" finally:\n",
" lock.release()\n",
"\n",
"\n",
"# `tensors_intermediates` stores additional data specific to `intermediate` tensor\n",
"\n",
"\n",
"def create_tensors_intermediates_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" try:\n",
" lock.acquire()\n",
"\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE TABLE IF NOT EXISTS tensors_intermediates (\n",
" tensors_id TEXT PRIMARY KEY,\n",
" session_id TEXT NOT NULL,\n",
" node_id TEXT NOT NULL,\n",
" FOREIGN KEY(tensors_id) REFERENCES tensors(id) ON DELETE CASCADE\n",
" );\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_intermediates_tensors_id ON tensors_intermediates(tensors_id);\n",
" \"\"\"\n",
" )\n",
" conn.commit()\n",
" finally:\n",
" lock.release()\n",
"\n",
"\n",
"# `tensors_metadata` table: stores generated/transformed metadata for tensor\n",
"\n",
"\n",
"def create_tensors_metadata_table(conn: sqlite3.Connection, cursor: sqlite3.Cursor, lock: threading.Lock):\n",
" try:\n",
" lock.acquire()\n",
"\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE TABLE IF NOT EXISTS tensors_metadata (\n",
" tensors_id TEXT PRIMARY KEY,\n",
" metadata TEXT,\n",
" FOREIGN KEY(tensors_id) REFERENCES tensors(id) ON DELETE CASCADE\n",
" );\n",
" \"\"\"\n",
" )\n",
" cursor.execute(\n",
" \"\"\"--sql\n",
" CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_metadata_tensors_id ON tensors_metadata(tensors_id);\n",
" \"\"\"\n",
" )\n",
" conn.commit()\n",
" finally:\n",
" lock.release()\n"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"db_path = '/home/bat/Documents/Code/outputs/test.db'\n",
"if (os.path.exists(db_path)):\n",
" os.remove(db_path)\n",
"\n",
"conn = sqlite3.connect(\n",
" db_path, check_same_thread=False\n",
")\n",
"cursor = conn.cursor()\n",
"lock = threading.Lock()"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"create_sql_table_from_enum(\n",
" enum=ResourceOrigin,\n",
" table_name=\"resource_origins\",\n",
" primary_key_name=\"origin_name\",\n",
" conn=conn,\n",
" cursor=cursor,\n",
" lock=lock,\n",
")\n",
"\n",
"create_sql_table_from_enum(\n",
" enum=ImageKind,\n",
" table_name=\"image_kinds\",\n",
" primary_key_name=\"kind_name\",\n",
" conn=conn,\n",
" cursor=cursor,\n",
" lock=lock,\n",
")\n",
"\n",
"create_sql_table_from_enum(\n",
" enum=TensorKind,\n",
" table_name=\"tensor_kinds\",\n",
" primary_key_name=\"kind_name\",\n",
" conn=conn,\n",
" cursor=cursor,\n",
" lock=lock,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"create_images_table(conn, cursor, lock)\n",
"create_images_results_table(conn, cursor, lock)\n",
"create_images_intermediates_table(conn, cursor, lock)\n",
"create_images_metadata_table(conn, cursor, lock)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"create_tensors_table(conn, cursor, lock)\n",
"create_tensors_results_table(conn, cursor, lock)\n",
"create_tensors_intermediates_table(conn, cursor, lock)\n",
"create_tensors_metadata_table(conn, cursor, lock)"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [],
"source": [
"\n",
"from pydantic import StrictStr\n",
"\n",
"\n",
"class GeneratedImageOrLatentsMetadata(BaseModel):\n",
" \"\"\"Core generation metadata for an image/tensor generated in InvokeAI.\n",
"\n",
" Generated by traversing the execution graph, collecting the parameters of the nearest ancestors of a given node.\n",
"\n",
" Full metadata may be accessed by querying for the session in the `graph_executions` table.\n",
" \"\"\"\n",
"\n",
" positive_conditioning: Optional[StrictStr] = Field(\n",
" default=None, description=\"The positive conditioning.\"\n",
" )\n",
" negative_conditioning: Optional[str] = Field(\n",
" default=None, description=\"The negative conditioning.\"\n",
" )\n",
" width: Optional[int] = Field(\n",
" default=None, description=\"Width of the image/tensor in pixels.\"\n",
" )\n",
" height: Optional[int] = Field(\n",
" default=None, description=\"Height of the image/tensor in pixels.\"\n",
" )\n",
" seed: Optional[int] = Field(\n",
" default=None, description=\"The seed used for noise generation.\"\n",
" )\n",
" cfg_scale: Optional[float] = Field(\n",
" default=None, description=\"The classifier-free guidance scale.\"\n",
" )\n",
" steps: Optional[int] = Field(\n",
" default=None, description=\"The number of steps used for inference.\"\n",
" )\n",
" scheduler: Optional[str] = Field(\n",
" default=None, description=\"The scheduler used for inference.\"\n",
" )\n",
" model: Optional[str] = Field(\n",
" default=None, description=\"The model used for inference.\"\n",
" )\n",
" strength: Optional[float] = Field(\n",
" default=None,\n",
" description=\"The strength used for image-to-image/tensor-to-tensor.\",\n",
" )\n",
" image: Optional[str] = Field(\n",
" default=None, description=\"The ID of the initial image.\"\n",
" )\n",
" tensor: Optional[str] = Field(\n",
" default=None, description=\"The ID of the initial tensor.\"\n",
" )\n",
" # Pending model refactor:\n",
" # vae: Optional[str] = Field(default=None,description=\"The VAE used for decoding.\")\n",
" # unet: Optional[str] = Field(default=None,description=\"The UNet used dor inference.\")\n",
" # clip: Optional[str] = Field(default=None,description=\"The CLIP Encoder used for conditioning.\")\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"GeneratedImageOrLatentsMetadata(positive_conditioning='123', negative_conditioning=None, width=None, height=None, seed=None, cfg_scale=None, steps=None, scheduler=None, model=None, strength=None, image=None, tensor=None)"
]
},
"execution_count": 61,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"GeneratedImageOrLatentsMetadata(positive_conditioning='123')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -1,6 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any
from typing import Any, Optional
from invokeai.app.api.models.images import ProgressImage
from invokeai.app.util.misc import get_timestamp
@ -50,6 +50,8 @@ class EventServiceBase:
result: dict,
node: dict,
source_node_id: str,
image_url: Optional[str] = None,
thumbnail_url: Optional[str] = None,
) -> None:
"""Emitted when an invocation has completed"""
self.__emit_session_event(
@ -59,6 +61,8 @@ class EventServiceBase:
node=node,
source_node_id=source_node_id,
result=result,
image_url=image_url,
thumbnail_url=thumbnail_url
),
)

View File

@ -0,0 +1,180 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
import os
from abc import ABC, abstractmethod
from pathlib import Path
from queue import Queue
from typing import Dict, Optional
from PIL.Image import Image as PILImageType
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from send2trash import send2trash
from invokeai.app.models.image import ImageType
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
class ImageFileStorageBase(ABC):
"""Low-level service responsible for storing and retrieving image files."""
class ImageFileNotFoundException(Exception):
"""Raised when an image file is not found in storage."""
def __init__(self, message="Image file not found"):
super().__init__(message)
class ImageFileSaveException(Exception):
"""Raised when an image cannot be saved."""
def __init__(self, message="Image file not saved"):
super().__init__(message)
class ImageFileDeleteException(Exception):
"""Raised when an image cannot be deleted."""
def __init__(self, message="Image file not deleted"):
super().__init__(message)
@abstractmethod
def get(self, image_type: ImageType, image_name: str) -> PILImageType:
"""Retrieves an image as PIL Image."""
pass
# # TODO: make this a bit more flexible for e.g. cloud storage
@abstractmethod
def get_path(
self, image_type: ImageType, image_name: str, thumbnail: bool = False
) -> str:
"""Gets the internal path to an image or thumbnail."""
pass
@abstractmethod
def save(
self,
image: PILImageType,
image_type: ImageType,
image_name: str,
pnginfo: Optional[PngInfo] = None,
thumbnail_size: int = 256,
) -> None:
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
pass
@abstractmethod
def delete(self, image_type: ImageType, image_name: str) -> None:
"""Deletes an image and its thumbnail (if one exists)."""
pass
class DiskImageFileStorage(ImageFileStorageBase):
"""Stores images on disk"""
__output_folder: str
__cache_ids: Queue # TODO: this is an incredibly naive cache
__cache: Dict[str, PILImageType]
__max_cache_size: int
def __init__(self, output_folder: str):
self.__output_folder = output_folder
self.__cache = dict()
self.__cache_ids = Queue()
self.__max_cache_size = 10 # TODO: get this from config
Path(output_folder).mkdir(parents=True, exist_ok=True)
# TODO: don't hard-code. get/save/delete should maybe take subpath?
for image_type in ImageType:
Path(os.path.join(output_folder, image_type)).mkdir(
parents=True, exist_ok=True
)
Path(os.path.join(output_folder, image_type, "thumbnails")).mkdir(
parents=True, exist_ok=True
)
def get(self, image_type: ImageType, image_name: str) -> PILImageType:
try:
image_path = self.get_path(image_type, image_name)
cache_item = self.__get_cache(image_path)
if cache_item:
return cache_item
image = Image.open(image_path)
self.__set_cache(image_path, image)
return image
except FileNotFoundError as e:
raise ImageFileStorageBase.ImageFileNotFoundException from e
def save(
self,
image: PILImageType,
image_type: ImageType,
image_name: str,
pnginfo: Optional[PngInfo] = None,
thumbnail_size: int = 256,
) -> None:
try:
image_path = self.get_path(image_type, image_name)
image.save(image_path, "PNG", pnginfo=pnginfo)
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_type, thumbnail_name, thumbnail=True)
thumbnail_image = make_thumbnail(image, thumbnail_size)
thumbnail_image.save(thumbnail_path)
self.__set_cache(image_path, image)
self.__set_cache(thumbnail_path, thumbnail_image)
except Exception as e:
raise ImageFileStorageBase.ImageFileSaveException from e
def delete(self, image_type: ImageType, image_name: str) -> None:
try:
basename = os.path.basename(image_name)
image_path = self.get_path(image_type, basename)
if os.path.exists(image_path):
send2trash(image_path)
if image_path in self.__cache:
del self.__cache[image_path]
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_type, thumbnail_name, True)
if os.path.exists(thumbnail_path):
send2trash(thumbnail_path)
if thumbnail_path in self.__cache:
del self.__cache[thumbnail_path]
except Exception as e:
raise ImageFileStorageBase.ImageFileDeleteException from e
# TODO: make this a bit more flexible for e.g. cloud storage
def get_path(
self, image_type: ImageType, image_name: str, thumbnail: bool = False
) -> str:
# strip out any relative path shenanigans
basename = os.path.basename(image_name)
if thumbnail:
thumbnail_name = get_thumbnail_name(basename)
path = os.path.join(
self.__output_folder, image_type, "thumbnails", thumbnail_name
)
else:
path = os.path.join(self.__output_folder, image_type, basename)
abspath = os.path.abspath(path)
return abspath
def __get_cache(self, image_name: str) -> PILImageType | None:
return None if image_name not in self.__cache else self.__cache[image_name]
def __set_cache(self, image_name: str, image: PILImageType):
if not image_name in self.__cache:
self.__cache[image_name] = image
self.__cache_ids.put(
image_name
) # TODO: this should refresh position for LRU cache
if len(self.__cache) > self.__max_cache_size:
cache_id = self.__cache_ids.get()
if cache_id in self.__cache:
del self.__cache[cache_id]

View File

@ -0,0 +1,318 @@
from abc import ABC, abstractmethod
import datetime
from typing import Optional
import sqlite3
import threading
from typing import Optional, Union
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.models.image import (
ImageCategory,
ImageType,
)
from invokeai.app.services.util.create_enum_table import create_enum_table
from invokeai.app.services.models.image_record import (
ImageRecord,
deserialize_image_record,
)
from invokeai.app.services.item_storage import PaginatedResults
class ImageRecordStorageBase(ABC):
"""Low-level service responsible for interfacing with the image record store."""
class ImageRecordNotFoundException(Exception):
"""Raised when an image record is not found."""
def __init__(self, message="Image record not found"):
super().__init__(message)
class ImageRecordSaveException(Exception):
"""Raised when an image record cannot be saved."""
def __init__(self, message="Image record not saved"):
super().__init__(message)
class ImageRecordDeleteException(Exception):
"""Raised when an image record cannot be deleted."""
def __init__(self, message="Image record not deleted"):
super().__init__(message)
@abstractmethod
def get(self, image_type: ImageType, image_name: str) -> ImageRecord:
"""Gets an image record."""
pass
@abstractmethod
def get_many(
self,
image_type: ImageType,
image_category: ImageCategory,
page: int = 0,
per_page: int = 10,
) -> PaginatedResults[ImageRecord]:
"""Gets a page of image records."""
pass
@abstractmethod
def delete(self, image_type: ImageType, image_name: str) -> None:
"""Deletes an image record."""
pass
@abstractmethod
def save(
self,
image_name: str,
image_type: ImageType,
image_category: ImageCategory,
session_id: Optional[str],
node_id: Optional[str],
metadata: Optional[ImageMetadata],
created_at: str = datetime.datetime.utcnow().isoformat(),
) -> None:
"""Saves an image record."""
pass
class SqliteImageRecordStorage(ImageRecordStorageBase):
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, filename: str) -> None:
super().__init__()
self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()
self._lock = threading.Lock()
try:
self._lock.acquire()
# Enable foreign keys
self._conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
def _create_tables(self) -> None:
"""Creates the tables for the `images` database."""
# Create the `images` table.
self._cursor.execute(
f"""--sql
CREATE TABLE IF NOT EXISTS images (
id TEXT PRIMARY KEY,
image_type TEXT, -- non-nullable via foreign key constraint
image_category TEXT, -- non-nullable via foreign key constraint
session_id TEXT, -- nullable
node_id TEXT, -- nullable
metadata TEXT, -- nullable
created_at TEXT NOT NULL,
FOREIGN KEY(image_type) REFERENCES image_types(type_name),
FOREIGN KEY(image_category) REFERENCES image_categories(category_name)
);
"""
)
# Create the `images` table indices.
self._cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_images_id ON images(id);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_image_type ON images(image_type);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_image_category ON images(image_category);
"""
)
# Create the tables for image-related enums
create_enum_table(
enum=ImageType,
table_name="image_types",
primary_key_name="type_name",
cursor=self._cursor,
)
create_enum_table(
enum=ImageCategory,
table_name="image_categories",
primary_key_name="category_name",
cursor=self._cursor,
)
# Create the `tags` table. TODO: do this elsewhere, shouldn't be in images db service
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS tags (
id INTEGER PRIMARY KEY AUTOINCREMENT,
tag_name TEXT UNIQUE NOT NULL
);
"""
)
# Create the `images_tags` junction table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS images_tags (
image_id TEXT,
tag_id INTEGER,
PRIMARY KEY (image_id, tag_id),
FOREIGN KEY(image_id) REFERENCES images(id) ON DELETE CASCADE,
FOREIGN KEY(tag_id) REFERENCES tags(id) ON DELETE CASCADE
);
"""
)
# Create the `images_favorites` table.
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS images_favorites (
image_id TEXT PRIMARY KEY,
favorited_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY(image_id) REFERENCES images(id) ON DELETE CASCADE
);
"""
)
def get(self, image_type: ImageType, image_name: str) -> Union[ImageRecord, None]:
try:
self._lock.acquire()
self._cursor.execute(
f"""--sql
SELECT * FROM images
WHERE id = ?;
""",
(image_name,),
)
result = self._cursor.fetchone()
except sqlite3.Error as e:
self._conn.rollback()
raise self.ImageRecordNotFoundException from e
finally:
self._lock.release()
if not result:
raise self.ImageRecordNotFoundException
return deserialize_image_record(result)
def get_many(
self,
image_type: ImageType,
image_category: ImageCategory,
page: int = 0,
per_page: int = 10,
) -> PaginatedResults[ImageRecord]:
try:
self._lock.acquire()
self._cursor.execute(
f"""--sql
SELECT * FROM images
WHERE image_type = ? AND image_category = ?
LIMIT ? OFFSET ?;
""",
(image_type.value, image_category.value, per_page, page * per_page),
)
result = self._cursor.fetchall()
images = list(map(lambda r: deserialize_image_record(r), result))
self._cursor.execute(
"""--sql
SELECT count(*) FROM images
WHERE image_type = ? AND image_category = ?
""",
(image_type.value, image_category.value),
)
count = self._cursor.fetchone()[0]
except sqlite3.Error as e:
self._conn.rollback()
raise e
finally:
self._lock.release()
pageCount = int(count / per_page) + 1
return PaginatedResults(
items=images, page=page, pages=pageCount, per_page=per_page, total=count
)
def delete(self, image_type: ImageType, image_name: str) -> None:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
DELETE FROM images
WHERE id = ?;
""",
(image_name,),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordStorageBase.ImageRecordDeleteException from e
finally:
self._lock.release()
def save(
self,
image_name: str,
image_type: ImageType,
image_category: ImageCategory,
session_id: Optional[str],
node_id: Optional[str],
metadata: Optional[ImageMetadata],
created_at: str,
) -> None:
try:
metadata_json = (
None if metadata is None else metadata.json(exclude_none=True)
)
self._lock.acquire()
self._cursor.execute(
"""--sql
INSERT OR IGNORE INTO images (
id,
image_type,
image_category,
node_id,
session_id,
metadata,
created_at
)
VALUES (?, ?, ?, ?, ?, ?, ?);
""",
(
image_name,
image_type.value,
image_category.value,
node_id,
session_id,
metadata_json,
created_at,
),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordStorageBase.ImageRecordNotFoundException from e
finally:
self._lock.release()

View File

@ -1,274 +0,0 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import os
from glob import glob
from abc import ABC, abstractmethod
from pathlib import Path
from queue import Queue
from typing import Dict, List
from PIL.Image import Image
import PIL.Image as PILImage
from send2trash import send2trash
from invokeai.app.api.models.images import (
ImageResponse,
ImageResponseMetadata,
SavedImage,
)
from invokeai.app.models.image import ImageType
from invokeai.app.services.metadata import (
InvokeAIMetadata,
MetadataServiceBase,
build_invokeai_metadata_pnginfo,
)
from invokeai.app.services.item_storage import PaginatedResults
from invokeai.app.util.misc import get_timestamp
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
class ImageStorageBase(ABC):
"""Responsible for storing and retrieving images."""
@abstractmethod
def get(self, image_type: ImageType, image_name: str) -> Image:
"""Retrieves an image as PIL Image."""
pass
@abstractmethod
def list(
self, image_type: ImageType, page: int = 0, per_page: int = 10
) -> PaginatedResults[ImageResponse]:
"""Gets a paginated list of images."""
pass
# TODO: make this a bit more flexible for e.g. cloud storage
@abstractmethod
def get_path(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
) -> str:
"""Gets the internal path to an image or its thumbnail."""
pass
# TODO: make this a bit more flexible for e.g. cloud storage
@abstractmethod
def get_uri(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
) -> str:
"""Gets the external URI to an image or its thumbnail."""
pass
# TODO: make this a bit more flexible for e.g. cloud storage
@abstractmethod
def validate_path(self, path: str) -> bool:
"""Validates an image path."""
pass
@abstractmethod
def save(
self,
image_type: ImageType,
image_name: str,
image: Image,
metadata: InvokeAIMetadata | None = None,
) -> SavedImage:
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
pass
@abstractmethod
def delete(self, image_type: ImageType, image_name: str) -> None:
"""Deletes an image and its thumbnail (if one exists)."""
pass
def create_name(self, context_id: str, node_id: str) -> str:
"""Creates a unique contextual image filename."""
return f"{context_id}_{node_id}_{str(get_timestamp())}.png"
class DiskImageStorage(ImageStorageBase):
"""Stores images on disk"""
__output_folder: str
__cache_ids: Queue # TODO: this is an incredibly naive cache
__cache: Dict[str, Image]
__max_cache_size: int
__metadata_service: MetadataServiceBase
def __init__(self, output_folder: str, metadata_service: MetadataServiceBase):
self.__output_folder = output_folder
self.__cache = dict()
self.__cache_ids = Queue()
self.__max_cache_size = 10 # TODO: get this from config
self.__metadata_service = metadata_service
Path(output_folder).mkdir(parents=True, exist_ok=True)
# TODO: don't hard-code. get/save/delete should maybe take subpath?
for image_type in ImageType:
Path(os.path.join(output_folder, image_type)).mkdir(
parents=True, exist_ok=True
)
Path(os.path.join(output_folder, image_type, "thumbnails")).mkdir(
parents=True, exist_ok=True
)
def list(
self, image_type: ImageType, page: int = 0, per_page: int = 10
) -> PaginatedResults[ImageResponse]:
dir_path = os.path.join(self.__output_folder, image_type)
image_paths = glob(f"{dir_path}/*.png")
count = len(image_paths)
sorted_image_paths = sorted(
glob(f"{dir_path}/*.png"), key=os.path.getctime, reverse=True
)
page_of_image_paths = sorted_image_paths[
page * per_page : (page + 1) * per_page
]
page_of_images: List[ImageResponse] = []
for path in page_of_image_paths:
filename = os.path.basename(path)
img = PILImage.open(path)
invokeai_metadata = self.__metadata_service.get_metadata(img)
page_of_images.append(
ImageResponse(
image_type=image_type.value,
image_name=filename,
# TODO: DiskImageStorage should not be building URLs...?
image_url=self.get_uri(image_type, filename),
thumbnail_url=self.get_uri(image_type, filename, True),
# TODO: Creation of this object should happen elsewhere (?), just making it fit here so it works
metadata=ImageResponseMetadata(
created=int(os.path.getctime(path)),
width=img.width,
height=img.height,
invokeai=invokeai_metadata,
),
)
)
page_count_trunc = int(count / per_page)
page_count_mod = count % per_page
page_count = page_count_trunc if page_count_mod == 0 else page_count_trunc + 1
return PaginatedResults[ImageResponse](
items=page_of_images,
page=page,
pages=page_count,
per_page=per_page,
total=count,
)
def get(self, image_type: ImageType, image_name: str) -> Image:
image_path = self.get_path(image_type, image_name)
cache_item = self.__get_cache(image_path)
if cache_item:
return cache_item
image = PILImage.open(image_path)
self.__set_cache(image_path, image)
return image
# TODO: make this a bit more flexible for e.g. cloud storage
def get_path(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
) -> str:
# strip out any relative path shenanigans
basename = os.path.basename(image_name)
if is_thumbnail:
path = os.path.join(
self.__output_folder, image_type, "thumbnails", basename
)
else:
path = os.path.join(self.__output_folder, image_type, basename)
abspath = os.path.abspath(path)
return abspath
def get_uri(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
) -> str:
# strip out any relative path shenanigans
basename = os.path.basename(image_name)
if is_thumbnail:
thumbnail_basename = get_thumbnail_name(basename)
uri = f"api/v1/images/{image_type.value}/thumbnails/{thumbnail_basename}"
else:
uri = f"api/v1/images/{image_type.value}/{basename}"
return uri
def validate_path(self, path: str) -> bool:
try:
os.stat(path)
return True
except Exception:
return False
def save(
self,
image_type: ImageType,
image_name: str,
image: Image,
metadata: InvokeAIMetadata | None = None,
) -> SavedImage:
image_path = self.get_path(image_type, image_name)
# TODO: Reading the image and then saving it strips the metadata...
if metadata:
pnginfo = build_invokeai_metadata_pnginfo(metadata=metadata)
image.save(image_path, "PNG", pnginfo=pnginfo)
else:
image.save(image_path) # this saved image has an empty info
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_type, thumbnail_name, is_thumbnail=True)
thumbnail_image = make_thumbnail(image)
thumbnail_image.save(thumbnail_path)
self.__set_cache(image_path, image)
self.__set_cache(thumbnail_path, thumbnail_image)
return SavedImage(
image_name=image_name,
thumbnail_name=thumbnail_name,
created=int(os.path.getctime(image_path)),
)
def delete(self, image_type: ImageType, image_name: str) -> None:
basename = os.path.basename(image_name)
image_path = self.get_path(image_type, basename)
if os.path.exists(image_path):
send2trash(image_path)
if image_path in self.__cache:
del self.__cache[image_path]
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_type, thumbnail_name, True)
if os.path.exists(thumbnail_path):
send2trash(thumbnail_path)
if thumbnail_path in self.__cache:
del self.__cache[thumbnail_path]
def __get_cache(self, image_name: str) -> Image | None:
return None if image_name not in self.__cache else self.__cache[image_name]
def __set_cache(self, image_name: str, image: Image):
if not image_name in self.__cache:
self.__cache[image_name] = image
self.__cache_ids.put(
image_name
) # TODO: this should refresh position for LRU cache
if len(self.__cache) > self.__max_cache_size:
cache_id = self.__cache_ids.get()
if cache_id in self.__cache:
del self.__cache[cache_id]

View File

@ -0,0 +1,355 @@
from abc import ABC, abstractmethod
import json
from logging import Logger
from typing import Optional, Union
import uuid
from PIL.Image import Image as PILImageType
from PIL import PngImagePlugin
from invokeai.app.models.image import ImageCategory, ImageType
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.image_record_storage import (
ImageRecordStorageBase,
)
from invokeai.app.services.models.image_record import (
ImageRecord,
ImageDTO,
image_record_to_dto,
)
from invokeai.app.services.image_file_storage import ImageFileStorageBase
from invokeai.app.services.item_storage import PaginatedResults
from invokeai.app.services.metadata import MetadataServiceBase
from invokeai.app.services.urls import UrlServiceBase
from invokeai.app.util.misc import get_iso_timestamp
class ImageServiceABC(ABC):
"""
High-level service for image management.
Provides methods for creating, retrieving, and deleting images.
"""
@abstractmethod
def create(
self,
image: PILImageType,
image_type: ImageType,
image_category: ImageCategory,
node_id: Optional[str] = None,
session_id: Optional[str] = None,
metadata: Optional[ImageMetadata] = None,
) -> ImageDTO:
"""Creates an image, storing the file and its metadata."""
pass
@abstractmethod
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
"""Gets an image as a PIL image."""
pass
@abstractmethod
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord:
"""Gets an image record."""
pass
@abstractmethod
def get_path(self, image_type: ImageType, image_name: str) -> str:
"""Gets an image's path"""
pass
@abstractmethod
def get_url(self, image_type: ImageType, image_name: str, thumbnail: bool = False) -> str:
"""Gets an image's or thumbnail's URL"""
pass
@abstractmethod
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
"""Gets an image DTO."""
pass
@abstractmethod
def get_many(
self,
image_type: ImageType,
image_category: ImageCategory,
page: int = 0,
per_page: int = 10,
) -> PaginatedResults[ImageDTO]:
"""Gets a paginated list of image DTOs."""
pass
@abstractmethod
def delete(self, image_type: ImageType, image_name: str):
"""Deletes an image."""
pass
@abstractmethod
def add_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
"""Adds a tag to an image."""
pass
@abstractmethod
def remove_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
"""Removes a tag from an image."""
pass
@abstractmethod
def favorite(self, image_type: ImageType, image_id: str) -> None:
"""Favorites an image."""
pass
@abstractmethod
def unfavorite(self, image_type: ImageType, image_id: str) -> None:
"""Unfavorites an image."""
pass
class ImageServiceDependencies:
"""Service dependencies for the ImageService."""
records: ImageRecordStorageBase
files: ImageFileStorageBase
metadata: MetadataServiceBase
urls: UrlServiceBase
logger: Logger
def __init__(
self,
image_record_storage: ImageRecordStorageBase,
image_file_storage: ImageFileStorageBase,
metadata: MetadataServiceBase,
url: UrlServiceBase,
logger: Logger,
):
self.records = image_record_storage
self.files = image_file_storage
self.metadata = metadata
self.urls = url
self.logger = logger
class ImageService(ImageServiceABC):
_services: ImageServiceDependencies
def __init__(
self,
image_record_storage: ImageRecordStorageBase,
image_file_storage: ImageFileStorageBase,
metadata: MetadataServiceBase,
url: UrlServiceBase,
logger: Logger,
):
self._services = ImageServiceDependencies(
image_record_storage=image_record_storage,
image_file_storage=image_file_storage,
metadata=metadata,
url=url,
logger=logger,
)
def create(
self,
image: PILImageType,
image_type: ImageType,
image_category: ImageCategory,
node_id: Optional[str] = None,
session_id: Optional[str] = None,
metadata: Optional[ImageMetadata] = None,
) -> ImageDTO:
image_name = self._create_image_name(
image_type=image_type,
image_category=image_category,
node_id=node_id,
session_id=session_id,
)
timestamp = get_iso_timestamp()
if metadata is not None:
pnginfo = PngImagePlugin.PngInfo()
pnginfo.add_text("invokeai", json.dumps(metadata))
else:
pnginfo = None
try:
# TODO: Consider using a transaction here to ensure consistency between storage and database
self._services.files.save(
image_type=image_type,
image_name=image_name,
image=image,
pnginfo=pnginfo,
)
self._services.records.save(
image_name=image_name,
image_type=image_type,
image_category=image_category,
node_id=node_id,
session_id=session_id,
metadata=metadata,
created_at=timestamp,
)
image_url = self._services.urls.get_image_url(image_type, image_name)
thumbnail_url = self._services.urls.get_image_url(
image_type, image_name, True
)
return ImageDTO(
image_name=image_name,
image_type=image_type,
image_category=image_category,
node_id=node_id,
session_id=session_id,
metadata=metadata,
created_at=timestamp,
image_url=image_url,
thumbnail_url=thumbnail_url,
)
except ImageRecordStorageBase.ImageRecordSaveException:
self._services.logger.error("Failed to save image record")
raise
except ImageFileStorageBase.ImageFileSaveException:
self._services.logger.error("Failed to save image file")
raise
except Exception as e:
self._services.logger.error("Problem saving image record and file")
raise e
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
try:
return self._services.files.get(image_type, image_name)
except ImageFileStorageBase.ImageFileNotFoundException:
self._services.logger.error("Failed to get image file")
raise
except Exception as e:
self._services.logger.error("Problem getting image file")
raise e
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord:
try:
return self._services.records.get(image_type, image_name)
except ImageRecordStorageBase.ImageRecordNotFoundException:
self._services.logger.error("Image record not found")
raise
except Exception as e:
self._services.logger.error("Problem getting image record")
raise e
def get_path(
self, image_type: ImageType, image_name: str, thumbnail: bool = False
) -> str:
try:
return self._services.files.get_path(image_type, image_name, thumbnail)
except Exception as e:
self._services.logger.error("Problem getting image path")
raise e
def get_url(
self, image_type: ImageType, image_name: str, thumbnail: bool = False
) -> str:
try:
return self._services.urls.get_image_url(image_type, image_name, thumbnail)
except Exception as e:
self._services.logger.error("Problem getting image path")
raise e
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO:
try:
image_record = self._services.records.get(image_type, image_name)
image_dto = image_record_to_dto(
image_record,
self._services.urls.get_image_url(image_type, image_name),
self._services.urls.get_image_url(image_type, image_name, True),
)
return image_dto
except ImageRecordStorageBase.ImageRecordNotFoundException:
self._services.logger.error("Image record not found")
raise
except Exception as e:
self._services.logger.error("Problem getting image DTO")
raise e
def get_many(
self,
image_type: ImageType,
image_category: ImageCategory,
page: int = 0,
per_page: int = 10,
) -> PaginatedResults[ImageDTO]:
try:
results = self._services.records.get_many(
image_type,
image_category,
page,
per_page,
)
image_dtos = list(
map(
lambda r: image_record_to_dto(
r,
self._services.urls.get_image_url(image_type, r.image_name),
self._services.urls.get_image_url(
image_type, r.image_name, True
),
),
results.items,
)
)
return PaginatedResults[ImageDTO](
items=image_dtos,
page=results.page,
pages=results.pages,
per_page=results.per_page,
total=results.total,
)
except Exception as e:
self._services.logger.error("Problem getting paginated image DTOs")
raise e
def delete(self, image_type: ImageType, image_name: str):
# TODO: Consider using a transaction here to ensure consistency between storage and database
try:
self._services.files.delete(image_type, image_name)
self._services.records.delete(image_type, image_name)
except ImageRecordStorageBase.ImageRecordDeleteException:
self._services.logger.error(f"Failed to delete image record")
raise
except ImageFileStorageBase.ImageFileDeleteException:
self._services.logger.error(f"Failed to delete image file")
raise
except Exception as e:
self._services.logger.error("Problem deleting image record and file")
raise e
def add_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
raise NotImplementedError("The 'add_tag' method is not implemented yet.")
def remove_tag(self, image_type: ImageType, image_id: str, tag: str) -> None:
raise NotImplementedError("The 'remove_tag' method is not implemented yet.")
def favorite(self, image_type: ImageType, image_id: str) -> None:
raise NotImplementedError("The 'favorite' method is not implemented yet.")
def unfavorite(self, image_type: ImageType, image_id: str) -> None:
raise NotImplementedError("The 'unfavorite' method is not implemented yet.")
def _create_image_name(
self,
image_type: ImageType,
image_category: ImageCategory,
node_id: Optional[str] = None,
session_id: Optional[str] = None,
) -> str:
"""Create a unique image name."""
uuid_str = str(uuid.uuid4())
if node_id is not None and session_id is not None:
return f"{image_type.value}_{image_category.value}_{session_id}_{node_id}_{uuid_str}.png"
return f"{image_type.value}_{image_category.value}_{uuid_str}.png"

View File

@ -1,26 +1,32 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from typing import TYPE_CHECKING
from logging import Logger
from typing import types
from invokeai.app.services.metadata import MetadataServiceBase
from invokeai.app.services.images import ImageService
from invokeai.backend import ModelManager
from .events import EventServiceBase
from .latent_storage import LatentsStorageBase
from .image_storage import ImageStorageBase
from .image_file_storage import ImageFileStorageBase
from .restoration_services import RestorationServices
from .invocation_queue import InvocationQueueABC
from .item_storage import ItemStorageABC
if TYPE_CHECKING:
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
from invokeai.app.services.invoker import InvocationProcessorABC
class InvocationServices:
"""Services that can be used by invocations"""
events: EventServiceBase
latents: LatentsStorageBase
images: ImageStorageBase
metadata: MetadataServiceBase
images: ImageFileStorageBase
queue: InvocationQueueABC
model_manager: ModelManager
restoration: RestorationServices
images_new: ImageService
# NOTE: we must forward-declare any types that include invocations, since invocations can use services
graph_library: ItemStorageABC["LibraryGraph"]
@ -28,26 +34,26 @@ class InvocationServices:
processor: "InvocationProcessorABC"
def __init__(
self,
model_manager: ModelManager,
events: EventServiceBase,
logger: types.ModuleType,
latents: LatentsStorageBase,
images: ImageStorageBase,
metadata: MetadataServiceBase,
queue: InvocationQueueABC,
graph_library: ItemStorageABC["LibraryGraph"],
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
processor: "InvocationProcessorABC",
restoration: RestorationServices,
self,
model_manager: ModelManager,
events: EventServiceBase,
logger: Logger,
latents: LatentsStorageBase,
images: ImageFileStorageBase,
queue: InvocationQueueABC,
images_new: ImageService,
graph_library: ItemStorageABC["LibraryGraph"],
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
processor: "InvocationProcessorABC",
restoration: RestorationServices,
):
self.model_manager = model_manager
self.events = events
self.logger = logger
self.latents = latents
self.images = images
self.metadata = metadata
self.queue = queue
self.images_new = images_new
self.graph_library = graph_library
self.graph_execution_manager = graph_execution_manager
self.processor = processor

View File

@ -16,7 +16,7 @@ class LatentsStorageBase(ABC):
pass
@abstractmethod
def set(self, name: str, data: torch.Tensor) -> None:
def save(self, name: str, data: torch.Tensor) -> None:
pass
@abstractmethod
@ -47,8 +47,8 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
self.__set_cache(name, latent)
return latent
def set(self, name: str, data: torch.Tensor) -> None:
self.__underlying_storage.set(name, data)
def save(self, name: str, data: torch.Tensor) -> None:
self.__underlying_storage.save(name, data)
self.__set_cache(name, data)
def delete(self, name: str) -> None:
@ -80,7 +80,7 @@ class DiskLatentsStorage(LatentsStorageBase):
latent_path = self.get_path(name)
return torch.load(latent_path)
def set(self, name: str, data: torch.Tensor) -> None:
def save(self, name: str, data: torch.Tensor) -> None:
latent_path = self.get_path(name)
torch.save(data, latent_path)

View File

@ -22,16 +22,24 @@ class MetadataLatentsField(TypedDict):
class MetadataColorField(TypedDict):
"""Pydantic-less ColorField, used for metadata parsing"""
r: int
g: int
b: int
a: int
# TODO: This is a placeholder for `InvocationsUnion` pending resolution of circular imports
NodeMetadata = Dict[
str, None | str | int | float | bool | MetadataImageField | MetadataLatentsField | MetadataColorField
str,
None
| str
| int
| float
| bool
| MetadataImageField
| MetadataLatentsField
| MetadataColorField,
]
@ -67,6 +75,11 @@ class MetadataServiceBase(ABC):
"""Builds an InvokeAIMetadata object"""
pass
# @abstractmethod
# def create_metadata(self, session_id: str, node_id: str) -> dict:
# """Creates metadata for a result"""
# pass
class PngMetadataService(MetadataServiceBase):
"""Handles loading and building metadata for images."""

View File

@ -0,0 +1,71 @@
import datetime
import sqlite3
from typing import Optional, Union
from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageCategory, ImageType
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.util.misc import get_iso_timestamp
class ImageRecord(BaseModel):
"""Deserialized image record."""
image_name: str = Field(description="The name of the image.")
image_type: ImageType = Field(description="The type of the image.")
image_category: ImageCategory = Field(description="The category of the image.")
created_at: Union[datetime.datetime, str] = Field(
description="The created timestamp of the image."
)
session_id: Optional[str] = Field(default=None, description="The session ID.")
node_id: Optional[str] = Field(default=None, description="The node ID.")
metadata: Optional[ImageMetadata] = Field(
default=None, description="The image's metadata."
)
class ImageDTO(ImageRecord):
"""Deserialized image record with URLs."""
image_url: str = Field(description="The URL of the image.")
thumbnail_url: str = Field(description="The thumbnail URL of the image.")
def image_record_to_dto(
image_record: ImageRecord, image_url: str, thumbnail_url: str
) -> ImageDTO:
"""Converts an image record to an image DTO."""
return ImageDTO(
image_name=image_record.image_name,
image_type=image_record.image_type,
image_category=image_record.image_category,
created_at=image_record.created_at,
session_id=image_record.session_id,
node_id=image_record.node_id,
metadata=image_record.metadata,
image_url=image_url,
thumbnail_url=thumbnail_url,
)
def deserialize_image_record(image_row: sqlite3.Row) -> ImageRecord:
"""Deserializes an image record."""
image_dict = dict(image_row)
image_type = ImageType(image_dict.get("image_type", ImageType.RESULT.value))
raw_metadata = image_dict.get("metadata", "{}")
metadata = ImageMetadata.parse_raw(raw_metadata)
return ImageRecord(
image_name=image_dict.get("id", "unknown"),
session_id=image_dict.get("session_id", None),
node_id=image_dict.get("node_id", None),
metadata=metadata,
image_type=image_type,
image_category=ImageCategory(
image_dict.get("image_category", ImageCategory.IMAGE.value)
),
created_at=image_dict.get("created_at", get_iso_timestamp()),
)

View File

@ -1,7 +1,10 @@
import time
import traceback
from threading import Event, Thread, BoundedSemaphore
from typing import Any, TypeGuard
from invokeai.app.invocations.image import ImageOutput
from invokeai.app.models.image import ImageType
from ..invocations.baseinvocation import InvocationContext
from .invocation_queue import InvocationQueueItem
from .invoker import InvocationProcessorABC, Invoker
@ -88,12 +91,30 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
graph_execution_state
)
def is_image_output(obj: Any) -> TypeGuard[ImageOutput]:
return obj.__class__ == ImageOutput
outputs_dict = outputs.dict()
if is_image_output(outputs):
image_url = self.__invoker.services.images_new.get_url(
ImageType.RESULT, outputs.image.image_name
)
thumbnail_url = self.__invoker.services.images_new.get_url(
ImageType.RESULT, outputs.image.image_name, True
)
else:
image_url = None
thumbnail_url = None
# Send complete event
self.__invoker.services.events.emit_invocation_complete(
graph_execution_state_id=graph_execution_state.id,
node=invocation.dict(),
source_node_id=source_node_id,
result=outputs.dict(),
result=outputs_dict,
image_url=image_url,
thumbnail_url=thumbnail_url,
)
except KeyboardInterrupt:

View File

@ -0,0 +1,657 @@
from abc import ABC, abstractmethod
from enum import Enum
import enum
import sqlite3
import threading
from typing import Optional, Type, TypeVar, Union
from PIL.Image import Image as PILImage
from pydantic import BaseModel, Field
from torch import Tensor
from invokeai.app.services.item_storage import PaginatedResults
"""
Substantial proposed changes to the management of images and tensor.
tl;dr:
With the upcoming move to latents-only nodes, we need to handle metadata differently. After struggling with this unsuccessfully - trying to smoosh it in to the existing setup - I believe we need to expand the scope of the refactor to include the management of images and latents - and make `latents` a special case of `tensor`.
full story:
The consensus for tensor-only nodes' metadata was to traverse the execution graph and grab the core parameters to write to the image. This was straightforward, and I've written functions to find the nearest t2l/l2l, noise, and compel nodes and build the metadata from those.
But struggling to integrate this and the associated edge cases this brought up a number of issues deeper in the system (some of which I had previously implemented). The ImageStorageService is doing way too much, and we have a need to be able to retrieve sessions the session given image/latents id, which is not currently feasible due to SQLite's JSON parsing performance.
I made a new ResultsService and `results` table in the db to facilitate this. This first attempt failed because it doesn't handle uploads and leaves the codebase messy.
So I've spent the day trying to figure out to handle this in a sane way and think I've got something decent. I've described some changes to service bases and the database below.
The gist of it is to store the core parameters for an image in its metadata when the image is saved, but never to read from it. Instead, the same metadata is stored in the database, which will be set up for efficient access. So when a page of images is requested, the metadata comes from the db instead of a filesystem operation.
The URL generation responsibilities have been split off the image storage service in to a URL service. New database services/tables for images and tensor are added. These services will provide paginated images/tensors for the API to serve. This also paves the way for handling tensors as first-class outputs.
"""
# TODO: Make a new model for this
class ResourceOrigin(str, Enum):
"""The origin of a resource (eg image or tensor)."""
RESULTS = "results"
UPLOADS = "uploads"
INTERMEDIATES = "intermediates"
class ImageKind(str, Enum):
"""The kind of an image."""
IMAGE = "image"
CONTROL_IMAGE = "control_image"
class TensorKind(str, Enum):
"""The kind of a tensor."""
IMAGE_TENSOR = "tensor"
CONDITIONING = "conditioning"
"""
Core Generation Metadata Pydantic Model
I've already implemented the code to traverse a session to build this object.
"""
class CoreGenerationMetadata(BaseModel):
"""Core generation metadata for an image/tensor generated in InvokeAI.
Generated by traversing the execution graph, collecting the parameters of the nearest ancestors of a given node.
Full metadata may be accessed by querying for the session in the `graph_executions` table.
"""
positive_conditioning: Optional[str] = Field(
description="The positive conditioning."
)
negative_conditioning: Optional[str] = Field(
description="The negative conditioning."
)
width: Optional[int] = Field(description="Width of the image/tensor in pixels.")
height: Optional[int] = Field(description="Height of the image/tensor in pixels.")
seed: Optional[int] = Field(description="The seed used for noise generation.")
cfg_scale: Optional[float] = Field(
description="The classifier-free guidance scale."
)
steps: Optional[int] = Field(description="The number of steps used for inference.")
scheduler: Optional[str] = Field(description="The scheduler used for inference.")
model: Optional[str] = Field(description="The model used for inference.")
strength: Optional[float] = Field(
description="The strength used for image-to-image/tensor-to-tensor."
)
image: Optional[str] = Field(description="The ID of the initial image.")
tensor: Optional[str] = Field(description="The ID of the initial tensor.")
# Pending model refactor:
# vae: Optional[str] = Field(description="The VAE used for decoding.")
# unet: Optional[str] = Field(description="The UNet used dor inference.")
# clip: Optional[str] = Field(description="The CLIP Encoder used for conditioning.")
"""
Minimal Uploads Metadata Model
"""
class UploadsMetadata(BaseModel):
"""Limited metadata for an uploaded image/tensor."""
width: Optional[int] = Field(description="Width of the image/tensor in pixels.")
height: Optional[int] = Field(description="Height of the image/tensor in pixels.")
# The extra field will be the contents of the PNG file's tEXt chunk. It may have come
# from another SD application or InvokeAI, so we need to make it very flexible. I think it's
# best to just store it as a string and let the frontend parse it.
# If the upload is a tensor type, this will be omitted.
extra: Optional[str] = Field(
description="Extra metadata, extracted from the PNG tEXt chunk."
)
"""
Slimmed-down Image Storage Service Base
- No longer lists images or generates URLs - only stores and retrieves images.
- OSS implementation for disk storage
"""
class ImageStorageBase(ABC):
"""Responsible for storing and retrieving images."""
@abstractmethod
def save(
self,
image: PILImage,
image_kind: ImageKind,
origin: ResourceOrigin,
context_id: str,
node_id: str,
metadata: CoreGenerationMetadata,
) -> str:
"""Saves an image and its thumbnail, returning its unique identifier."""
pass
@abstractmethod
def get(self, id: str, thumbnail: bool = False) -> Union[PILImage, None]:
"""Retrieves an image as a PIL Image."""
pass
@abstractmethod
def delete(self, id: str) -> None:
"""Deletes an image."""
pass
class TensorStorageBase(ABC):
"""Responsible for storing and retrieving tensors."""
@abstractmethod
def save(
self,
tensor: Tensor,
tensor_kind: TensorKind,
origin: ResourceOrigin,
context_id: str,
node_id: str,
metadata: CoreGenerationMetadata,
) -> str:
"""Saves a tensor, returning its unique identifier."""
pass
@abstractmethod
def get(self, id: str, thumbnail: bool = False) -> Union[Tensor, None]:
"""Retrieves a tensor as a torch Tensor."""
pass
@abstractmethod
def delete(self, id: str) -> None:
"""Deletes a tensor."""
pass
"""
New Url Service Base
- Abstracts the logic for generating URLs out of the storage service
- OSS implementation for locally-hosted URLs
- Also provides a method to get the internal path to a resource (for OSS, the FS path)
"""
class ResourceLocationServiceBase(ABC):
"""Responsible for locating resources (eg images or tensors)."""
@abstractmethod
def get_url(self, id: str) -> str:
"""Gets the URL for a resource."""
pass
@abstractmethod
def get_path(self, id: str) -> str:
"""Gets the path for a resource."""
pass
"""
New Images Database Service Base
This is a new service that will be responsible for the new `images` table(s):
- Storing images in the table
- Retrieving individual images and pages of images
- Deleting individual images
Operations will typically use joins with the various `images` tables.
"""
class ImagesDbServiceBase(ABC):
"""Responsible for interfacing with `images` table."""
class GeneratedImageEntity(BaseModel):
id: str = Field(description="The unique identifier for the image.")
session_id: str = Field(description="The session ID.")
node_id: str = Field(description="The node ID.")
metadata: CoreGenerationMetadata = Field(
description="The metadata for the image."
)
class UploadedImageEntity(BaseModel):
id: str = Field(description="The unique identifier for the image.")
metadata: UploadsMetadata = Field(description="The metadata for the image.")
@abstractmethod
def get(self, id: str) -> Union[GeneratedImageEntity, UploadedImageEntity, None]:
"""Gets an image from the `images` table."""
pass
@abstractmethod
def get_many(
self, image_kind: ImageKind, page: int = 0, per_page: int = 10
) -> PaginatedResults[Union[GeneratedImageEntity, UploadedImageEntity]]:
"""Gets a page of images from the `images` table."""
pass
@abstractmethod
def delete(self, id: str) -> None:
"""Deletes an image from the `images` table."""
pass
@abstractmethod
def set(
self,
id: str,
image_kind: ImageKind,
session_id: Optional[str],
node_id: Optional[str],
metadata: CoreGenerationMetadata | UploadsMetadata,
) -> None:
"""Sets an image in the `images` table."""
pass
"""
New Tensor Database Service Base
This is a new service that will be responsible for the new `tensor` table:
- Storing tensor in the table
- Retrieving individual tensor and pages of tensor
- Deleting individual tensor
Operations will always use joins with the `tensor_metadata` table.
"""
class TensorDbServiceBase(ABC):
"""Responsible for interfacing with `tensor` table."""
class GeneratedTensorEntity(BaseModel):
id: str = Field(description="The unique identifier for the tensor.")
session_id: str = Field(description="The session ID.")
node_id: str = Field(description="The node ID.")
metadata: CoreGenerationMetadata = Field(
description="The metadata for the tensor."
)
class UploadedTensorEntity(BaseModel):
id: str = Field(description="The unique identifier for the tensor.")
metadata: UploadsMetadata = Field(description="The metadata for the tensor.")
@abstractmethod
def get(self, id: str) -> Union[GeneratedTensorEntity, UploadedTensorEntity, None]:
"""Gets a tensor from the `tensor` table."""
pass
@abstractmethod
def get_many(
self, tensor_kind: TensorKind, page: int = 0, per_page: int = 10
) -> PaginatedResults[Union[GeneratedTensorEntity, UploadedTensorEntity]]:
"""Gets a page of tensor from the `tensor` table."""
pass
@abstractmethod
def delete(self, id: str) -> None:
"""Deletes a tensor from the `tensor` table."""
pass
@abstractmethod
def set(
self,
id: str,
tensor_kind: TensorKind,
session_id: Optional[str],
node_id: Optional[str],
metadata: CoreGenerationMetadata | UploadsMetadata,
) -> None:
"""Sets a tensor in the `tensor` table."""
pass
"""
Database Changes
The existing tables will remain as-is, new tables will be added.
Tensor now also have the same types as images - `results`, `intermediates`, `uploads`. Storage, retrieval, and operations may diverge from images in the future, so they are managed separately.
A few `images` tables are created to store all images:
- `results` and `intermediates` images have additional data: `session_id` and `node_id`, and may be further differentiated in the future. For this reason, they each get their own table.
- `uploads` do not get their own table, as they are never going to have more than an `id`, `image_kind` and `timestamp`.
- `images_metadata` holds the same image metadata that is written to the image. This table, along with the URL service, allow us to more efficiently serve images without having to read the image from storage.
The same tables are made for `tensor` and for the moment, implementation is expected to be identical.
Schemas for each table below.
Insertions and updates of ancillary tables (e.g. `results_images`, `images_metadata`, etc) will need to be done manually in the services, but should be straightforward. Deletion via cascading will be handled by the database.
"""
def create_sql_values_string_from_string_enum(enum: Type[Enum]):
"""
Creates a string of the form "('value1'), ('value2'), ..., ('valueN')" from a StrEnum.
"""
delimiter = ", "
values = [f"('{e.value}')" for e in enum]
return delimiter.join(values)
def create_sql_table_from_enum(
enum: Type[Enum],
table_name: str,
primary_key_name: str,
cursor: sqlite3.Cursor,
lock: threading.Lock,
):
"""
Creates and populates a table to be used as a functional enum.
"""
try:
lock.acquire()
values_string = create_sql_values_string_from_string_enum(enum)
cursor.execute(
f"""--sql
CREATE TABLE IF NOT EXISTS {table_name} (
{primary_key_name} TEXT PRIMARY KEY
);
"""
)
cursor.execute(
f"""--sql
INSERT OR IGNORE INTO {table_name} ({primary_key_name}) VALUES {values_string};
"""
)
finally:
lock.release()
"""
`resource_origins` functions as an enum for the ResourceOrigin model.
"""
def create_resource_origins_table(cursor: sqlite3.Cursor, lock: threading.Lock):
create_sql_table_from_enum(
enum=ResourceOrigin,
table_name="resource_origins",
primary_key_name="origin_name",
cursor=cursor,
lock=lock,
)
"""
`image_kinds` functions as an enum for the ImageType model.
"""
def create_image_kinds_table(cursor: sqlite3.Cursor, lock: threading.Lock):
create_sql_table_from_enum(
enum=ImageKind,
table_name="image_kinds",
primary_key_name="kind_name",
cursor=cursor,
lock=lock,
)
"""
`tensor_kinds` functions as an enum for the TensorType model.
"""
def create_tensor_kinds_table(cursor: sqlite3.Cursor, lock: threading.Lock):
create_sql_table_from_enum(
enum=TensorKind,
table_name="tensor_kinds",
primary_key_name="kind_name",
cursor=cursor,
lock=lock,
)
"""
`images` stores all images, regardless of type
"""
def create_images_table(cursor: sqlite3.Cursor, lock: threading.Lock):
try:
lock.acquire()
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS images (
id TEXT PRIMARY KEY,
origin TEXT,
image_kind TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY(origin) REFERENCES resource_origins(origin_name),
FOREIGN KEY(image_kind) REFERENCES image_kinds(kind_name)
);
"""
)
cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_images_id ON images(id);
"""
)
cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_origin ON images(origin);
"""
)
cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_images_image_kind ON images(image_kind);
"""
)
finally:
lock.release()
"""
`image_results` stores additional data specific to `results` images.
"""
def create_image_results_table(cursor: sqlite3.Cursor, lock: threading.Lock):
try:
lock.acquire()
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS image_results (
images_id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
node_id TEXT NOT NULL,
FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE
);
"""
)
cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_image_results_images_id ON image_results(id);
"""
)
finally:
lock.release()
"""
`image_intermediates` stores additional data specific to `intermediates` images
"""
def create_image_intermediates_table(cursor: sqlite3.Cursor, lock: threading.Lock):
try:
lock.acquire()
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS image_intermediates (
images_id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
node_id TEXT NOT NULL,
FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE
);
"""
)
cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_image_intermediates_images_id ON image_intermediates(id);
"""
)
finally:
lock.release()
"""
`images_metadata` stores basic metadata for any image type
"""
def create_images_metadata_table(cursor: sqlite3.Cursor, lock: threading.Lock):
try:
lock.acquire()
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS images_metadata (
images_id TEXT PRIMARY KEY,
metadata TEXT,
FOREIGN KEY(images_id) REFERENCES images(id) ON DELETE CASCADE
);
"""
)
cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_images_metadata_images_id ON images_metadata(images_id);
"""
)
finally:
lock.release()
# `tensor` table: stores references to tensor
def create_tensors_table(cursor: sqlite3.Cursor, lock: threading.Lock):
try:
lock.acquire()
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS tensors (
id TEXT PRIMARY KEY,
origin TEXT,
tensor_kind TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY(origin) REFERENCES resource_origins(origin_name),
FOREIGN KEY(tensor_kind) REFERENCES tensor_kinds(kind_name),
);
"""
)
cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_id ON tensors(id);
"""
)
cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_tensors_origin ON tensors(origin);
"""
)
cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_tensors_tensor_kind ON tensors(tensor_kind);
"""
)
finally:
lock.release()
# `results_tensor` stores additional data specific to `result` tensor
def create_tensor_results_table(cursor: sqlite3.Cursor, lock: threading.Lock):
try:
lock.acquire()
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS tensor_results (
tensor_id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
node_id TEXT NOT NULL,
FOREIGN KEY(tensor_id) REFERENCES tensors(id) ON DELETE CASCADE
);
"""
)
cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_tensor_results_tensor_id ON tensor_results(tensor_id);
"""
)
finally:
lock.release()
# `tensor_intermediates` stores additional data specific to `intermediate` tensor
def create_tensor_intermediates_table(cursor: sqlite3.Cursor, lock: threading.Lock):
try:
lock.acquire()
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS tensor_intermediates (
tensor_id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
node_id TEXT NOT NULL,
FOREIGN KEY(tensor_id) REFERENCES tensors(id) ON DELETE CASCADE
);
"""
)
cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_tensor_intermediates_tensor_id ON tensor_intermediates(tensor_id);
"""
)
finally:
lock.release()
# `tensors_metadata` table: stores generated/transformed metadata for tensor
def create_tensors_metadata_table(cursor: sqlite3.Cursor, lock: threading.Lock):
try:
lock.acquire()
cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS tensors_metadata (
tensor_id TEXT PRIMARY KEY,
metadata TEXT,
FOREIGN KEY(tensor_id) REFERENCES tensors(id) ON DELETE CASCADE
);
"""
)
cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_tensors_metadata_tensor_id ON tensors_metadata(tensor_id);
"""
)
finally:
lock.release()

View File

@ -0,0 +1,466 @@
from enum import Enum
from abc import ABC, abstractmethod
import json
import sqlite3
from threading import Lock
from typing import Any, Union
import networkx as nx
from pydantic import BaseModel, Field, parse_obj_as, parse_raw_as
from invokeai.app.invocations.image import ImageOutput
from invokeai.app.services.graph import Edge, GraphExecutionState
from invokeai.app.invocations.latent import LatentsOutput
from invokeai.app.services.item_storage import PaginatedResults
from invokeai.app.util.misc import get_timestamp
class ResultType(str, Enum):
image_output = "image_output"
latents_output = "latents_output"
class Result(BaseModel):
"""A session result"""
id: str = Field(description="Result ID")
session_id: str = Field(description="Session ID")
node_id: str = Field(description="Node ID")
data: Union[LatentsOutput, ImageOutput] = Field(description="The result data")
class ResultWithSession(BaseModel):
"""A result with its session"""
result: Result = Field(description="The result")
session: GraphExecutionState = Field(description="The session")
# Create a directed graph
from typing import Any, TypedDict, Union
from networkx import DiGraph
import networkx as nx
import json
# We need to use a loose class for nodes to allow for graceful parsing - we cannot use the stricter
# model used by the system, because we may be a graph in an old format. We can, however, use the
# Edge model, because the edge format does not change.
class LooseGraph(BaseModel):
id: str
nodes: dict[str, dict[str, Any]]
edges: list[Edge]
# An intermediate type used during parsing
class NearestAncestor(TypedDict):
node_id: str
metadata: dict[str, Any]
# The ancestor types that contain the core metadata
ANCESTOR_TYPES = ['t2l', 'l2l']
# The core metadata parameters in the ancestor types
ANCESTOR_PARAMS = ['steps', 'model', 'cfg_scale', 'scheduler', 'strength']
# The core metadata parameters in the noise node
NOISE_FIELDS = ['seed', 'width', 'height']
# Find nearest t2l or l2l ancestor from a given l2i node
def find_nearest_ancestor(G: DiGraph, node_id: str) -> Union[NearestAncestor, None]:
"""Returns metadata for the nearest ancestor of a given node.
Parameters:
G (DiGraph): A directed graph.
node_id (str): The ID of the starting node.
Returns:
NearestAncestor | None: An object with the ID and metadata of the nearest ancestor.
"""
# Retrieve the node from the graph
node = G.nodes[node_id]
# If the node type is one of the core metadata node types, gather necessary metadata and return
if node.get('type') in ANCESTOR_TYPES:
parsed_metadata = {param: val for param, val in node.items() if param in ANCESTOR_PARAMS}
return NearestAncestor(node_id=node_id, metadata=parsed_metadata)
# Else, look for the ancestor in the predecessor nodes
for predecessor in G.predecessors(node_id):
result = find_nearest_ancestor(G, predecessor)
if result:
return result
# If there are no valid ancestors, return None
return None
def get_additional_metadata(graph: LooseGraph, node_id: str) -> Union[dict[str, Any], None]:
"""Collects additional metadata from nodes connected to a given node.
Parameters:
graph (LooseGraph): The graph.
node_id (str): The ID of the node.
Returns:
dict | None: A dictionary containing additional metadata.
"""
metadata = {}
# Iterate over all edges in the graph
for edge in graph.edges:
dest_node_id = edge.destination.node_id
dest_field = edge.destination.field
source_node = graph.nodes[edge.source.node_id]
# If the destination node ID matches the given node ID, gather necessary metadata
if dest_node_id == node_id:
# If the destination field is 'positive_conditioning', add the 'prompt' from the source node
if dest_field == 'positive_conditioning':
metadata['positive_conditioning'] = source_node.get('prompt')
# If the destination field is 'negative_conditioning', add the 'prompt' from the source node
if dest_field == 'negative_conditioning':
metadata['negative_conditioning'] = source_node.get('prompt')
# If the destination field is 'noise', add the core noise fields from the source node
if dest_field == 'noise':
for field in NOISE_FIELDS:
metadata[field] = source_node.get(field)
return metadata
def build_core_metadata(graph_raw: str, node_id: str) -> Union[dict, None]:
"""Builds the core metadata for a given node.
Parameters:
graph_raw (str): The graph structure as a raw string.
node_id (str): The ID of the node.
Returns:
dict | None: A dictionary containing core metadata.
"""
# Create a directed graph to facilitate traversal
G = nx.DiGraph()
# Convert the raw graph string into a JSON object
graph = parse_obj_as(LooseGraph, graph_raw)
# Add nodes and edges to the graph
for node_id, node_data in graph.nodes.items():
G.add_node(node_id, **node_data)
for edge in graph.edges:
G.add_edge(edge.source.node_id, edge.destination.node_id)
# Find the nearest ancestor of the given node
ancestor = find_nearest_ancestor(G, node_id)
# If no ancestor was found, return None
if ancestor is None:
return None
metadata = ancestor['metadata']
ancestor_id = ancestor['node_id']
# Get additional metadata related to the ancestor
addl_metadata = get_additional_metadata(graph, ancestor_id)
# If additional metadata was found, add it to the main metadata
if addl_metadata is not None:
metadata.update(addl_metadata)
return metadata
class ResultsServiceABC(ABC):
"""The Results service is responsible for retrieving results."""
@abstractmethod
def get(
self, result_id: str, result_type: ResultType
) -> Union[ResultWithSession, None]:
pass
@abstractmethod
def get_many(
self, result_type: ResultType, page: int = 0, per_page: int = 10
) -> PaginatedResults[ResultWithSession]:
pass
@abstractmethod
def search(
self, query: str, page: int = 0, per_page: int = 10
) -> PaginatedResults[ResultWithSession]:
pass
@abstractmethod
def handle_graph_execution_state_change(self, session: GraphExecutionState) -> None:
pass
class SqliteResultsService(ResultsServiceABC):
"""SQLite implementation of the Results service."""
_filename: str
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: Lock
def __init__(self, filename: str):
super().__init__()
self._filename = filename
self._lock = Lock()
self._conn = sqlite3.connect(
self._filename, check_same_thread=False
) # TODO: figure out a better threading solution
self._cursor = self._conn.cursor()
self._create_table()
def _create_table(self):
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS results (
id TEXT PRIMARY KEY, -- the result's name
result_type TEXT, -- `image_output` | `latents_output`
node_id TEXT, -- the node that produced this result
session_id TEXT, -- the session that produced this result
created_at INTEGER, -- the time at which this result was created
data TEXT -- the result itself
);
"""
)
self._cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_result_id ON results(id);
"""
)
finally:
self._lock.release()
def _parse_joined_result(self, result_row: Any, column_names: list[str]):
result_raw = {}
session_raw = {}
for idx, name in enumerate(column_names):
if name == "session":
session_raw = json.loads(result_row[idx])
elif name == "data":
result_raw[name] = json.loads(result_row[idx])
else:
result_raw[name] = result_row[idx]
graph_raw = session_raw['execution_graph']
result = parse_obj_as(Result, result_raw)
session = parse_obj_as(GraphExecutionState, session_raw)
m = build_core_metadata(graph_raw, result.node_id)
print(m)
# g = session.execution_graph.nx_graph()
# ancestors = nx.dag.ancestors(g, result.node_id)
# nodes = [session.execution_graph.get_node(result.node_id)]
# for ancestor in ancestors:
# nodes.append(session.execution_graph.get_node(ancestor))
# filtered_nodes = filter(lambda n: n.type in NODE_TYPE_ALLOWLIST, nodes)
# print(list(map(lambda n: n.dict(), filtered_nodes)))
# metadata = {}
# for node in nodes:
# if (node.type in ['txt2img', 'img2img',])
# for field, value in node.dict().items():
# if field not in ['type', 'id']:
# if field not in metadata:
# metadata[field] = value
# print(ancestors)
# print(nodes)
# print(metadata)
# for node in nodes:
# print(node.dict())
# print(nodes)
return ResultWithSession(
result=result,
session=session,
)
def get(
self, result_id: str, result_type: ResultType
) -> Union[ResultWithSession, None]:
"""Retrieves a result by ID and type."""
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT
results.id AS id,
results.result_type AS result_type,
results.node_id AS node_id,
results.session_id AS session_id,
results.data AS data,
graph_executions.item AS session
FROM results
JOIN graph_executions ON results.session_id = graph_executions.id
WHERE results.id = ? AND results.result_type = ?
""",
(result_id, result_type),
)
result_row = self._cursor.fetchone()
if result_row is None:
return None
column_names = list(map(lambda x: x[0], self._cursor.description))
result_parsed = self._parse_joined_result(result_row, column_names)
finally:
self._lock.release()
if not result_parsed:
return None
return result_parsed
def get_many(
self,
result_type: ResultType,
page: int = 0,
per_page: int = 10,
) -> PaginatedResults[ResultWithSession]:
"""Lists results of a given type."""
try:
self._lock.acquire()
self._cursor.execute(
f"""--sql
SELECT
results.id AS id,
results.result_type AS result_type,
results.node_id AS node_id,
results.session_id AS session_id,
results.data AS data,
graph_executions.item AS session
FROM results
JOIN graph_executions ON results.session_id = graph_executions.id
WHERE results.result_type = ?
LIMIT ? OFFSET ?;
""",
(result_type.value, per_page, page * per_page),
)
result_rows = self._cursor.fetchall()
column_names = list(map(lambda c: c[0], self._cursor.description))
result_parsed = []
for result_row in result_rows:
result_parsed.append(
self._parse_joined_result(result_row, column_names)
)
self._cursor.execute("""SELECT count(*) FROM results;""")
count = self._cursor.fetchone()[0]
finally:
self._lock.release()
pageCount = int(count / per_page) + 1
return PaginatedResults[ResultWithSession](
items=result_parsed,
page=page,
pages=pageCount,
per_page=per_page,
total=count,
)
def search(
self,
query: str,
page: int = 0,
per_page: int = 10,
) -> PaginatedResults[ResultWithSession]:
"""Finds results by query."""
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT results.data, graph_executions.item
FROM results
JOIN graph_executions ON results.session_id = graph_executions.id
WHERE item LIKE ?
LIMIT ? OFFSET ?;
""",
(f"%{query}%", per_page, page * per_page),
)
result_rows = self._cursor.fetchall()
items = list(
map(
lambda r: ResultWithSession(
result=parse_raw_as(Result, r[0]),
session=parse_raw_as(GraphExecutionState, r[1]),
),
result_rows,
)
)
self._cursor.execute(
"""--sql
SELECT count(*) FROM results WHERE item LIKE ?;
""",
(f"%{query}%",),
)
count = self._cursor.fetchone()[0]
finally:
self._lock.release()
pageCount = int(count / per_page) + 1
return PaginatedResults[ResultWithSession](
items=items, page=page, pages=pageCount, per_page=per_page, total=count
)
def handle_graph_execution_state_change(self, session: GraphExecutionState) -> None:
"""Updates the results table with the results from the session."""
with self._conn as conn:
for node_id, result in session.results.items():
# We'll only process 'image_output' or 'latents_output'
if result.type not in ["image_output", "latents_output"]:
continue
# The id depends on the result type
if result.type == "image_output":
id = result.image.image_name
result_type = "image_output"
else:
id = result.latents.latents_name
result_type = "latents_output"
# Insert the result into the results table, ignoring if it already exists
conn.execute(
"""--sql
INSERT OR IGNORE INTO results (id, result_type, node_id, session_id, created_at, data)
VALUES (?, ?, ?, ?, ?, ?)
""",
(
id,
result_type,
node_id,
session.id,
get_timestamp(),
result.json(),
),
)

View File

@ -0,0 +1,30 @@
import os
from abc import ABC, abstractmethod
from invokeai.app.models.image import ImageType
from invokeai.app.util.thumbnails import get_thumbnail_name
class UrlServiceBase(ABC):
"""Responsible for building URLs for resources."""
@abstractmethod
def get_image_url(
self, image_type: ImageType, image_name: str, thumbnail: bool = False
) -> str:
"""Gets the URL for an image or thumbnail."""
pass
class LocalUrlService(UrlServiceBase):
def __init__(self, base_url: str = "api/v1"):
self._base_url = base_url
def get_image_url(
self, image_type: ImageType, image_name: str, thumbnail: bool = False
) -> str:
image_basename = os.path.basename(image_name)
if thumbnail:
return f"{self._base_url}/files/images/{image_type.value}/{image_basename}/thumbnail"
return f"{self._base_url}/files/images/{image_type.value}/{image_basename}"

View File

@ -0,0 +1,39 @@
from enum import Enum
import sqlite3
from typing import Type
def create_sql_values_string_from_string_enum(enum: Type[Enum]):
"""
Creates a string of the form "('value1'), ('value2'), ..., ('valueN')" from a StrEnum.
"""
delimiter = ", "
values = [f"('{e.value}')" for e in enum]
return delimiter.join(values)
def create_enum_table(
enum: Type[Enum],
table_name: str,
primary_key_name: str,
cursor: sqlite3.Cursor,
):
"""
Creates and populates a table to be used as a functional enum.
"""
values_string = create_sql_values_string_from_string_enum(enum)
cursor.execute(
f"""--sql
CREATE TABLE IF NOT EXISTS {table_name} (
{primary_key_name} TEXT PRIMARY KEY
);
"""
)
cursor.execute(
f"""--sql
INSERT OR IGNORE INTO {table_name} ({primary_key_name}) VALUES {values_string};
"""
)

12
invokeai/app/util/enum.py Normal file
View File

@ -0,0 +1,12 @@
from enum import EnumMeta
class MetaEnum(EnumMeta):
"""Metaclass to support `in` syntax value checking in String Enums"""
def __contains__(cls, item):
try:
cls(item)
except ValueError:
return False
return True

View File

@ -6,6 +6,14 @@ def get_timestamp():
return int(datetime.datetime.now(datetime.timezone.utc).timestamp())
def get_iso_timestamp() -> str:
return datetime.datetime.utcnow().isoformat()
def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime:
return datetime.datetime.fromisoformat(iso_timestamp)
SEED_MAX = np.iinfo(np.int32).max

View File

@ -76,16 +76,16 @@ class InvokeAILogFormatter(logging.Formatter):
reset = "\x1b[0m"
# Log Format
format = "[%(asctime)s]::[%(name)s]::%(levelname)s --> %(message)s"
log_format = "[%(asctime)s]::[%(name)s]::%(levelname)s --> %(message)s"
## More Formatting Options: %(pathname)s, %(filename)s, %(module)s, %(lineno)d
# Format Map
FORMATS = {
logging.DEBUG: cyan + format + reset,
logging.INFO: grey + format + reset,
logging.WARNING: yellow + format + reset,
logging.ERROR: red + format + reset,
logging.CRITICAL: bold_red + format + reset
logging.DEBUG: cyan + log_format + reset,
logging.INFO: grey + log_format + reset,
logging.WARNING: yellow + log_format + reset,
logging.ERROR: red + log_format + reset,
logging.CRITICAL: bold_red + log_format + reset
}
def format(self, record):
@ -98,13 +98,13 @@ class InvokeAILogger(object):
loggers = dict()
@classmethod
def getLogger(self, name: str = 'InvokeAI') -> logging.Logger:
if name not in self.loggers:
def getLogger(cls, name: str = 'InvokeAI') -> logging.Logger:
if name not in cls.loggers:
logger = logging.getLogger(name)
logger.setLevel(logging.DEBUG)
ch = logging.StreamHandler()
fmt = InvokeAILogFormatter()
ch.setFormatter(fmt)
logger.addHandler(ch)
self.loggers[name] = logger
return self.loggers[name]
cls.loggers[name] = logger
return cls.loggers[name]

View File

@ -61,6 +61,7 @@ import UpscaleSettings from 'features/parameters/components/Parameters/Upscale/U
import { allParametersSet } from 'features/parameters/store/generationSlice';
import DeleteImageButton from './ImageActionButtons/DeleteImageButton';
import { useAppToaster } from 'app/components/Toaster';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
const currentImageButtonsSelector = createSelector(
[
@ -329,7 +330,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
if (!image) return;
if (isLightboxOpen) dispatch(setIsLightboxOpen(false));
// dispatch(setInitialCanvasImage(selectedImage));
dispatch(setInitialCanvasImage(image));
dispatch(requestCanvasRescale());
if (activeTabName !== 'unifiedCanvas') {

View File

@ -8,7 +8,7 @@ import type { ImageField } from './ImageField';
* Base class for invocations that output an image
*/
export type ImageOutput = {
type: 'image';
type: 'image_output';
/**
* The output image
*/

View File

@ -11,5 +11,13 @@ export type RandomIntInvocation = {
*/
id: string;
type?: 'rand_int';
/**
* The inclusive low value
*/
low?: number;
/**
* The exclusive high value
*/
high?: number;
};

View File

@ -12,5 +12,13 @@ export const $RandomIntInvocation = {
type: {
type: 'Enum',
},
low: {
type: 'number',
description: `The inclusive low value`,
},
high: {
type: 'number',
description: `The exclusive high value`,
},
},
} as const;

View File

@ -10,11 +10,16 @@ import {
CollectInvocationOutput,
ImageType,
ImageField,
LatentsOutput,
} from 'services/api';
export const isImageOutput = (
output: GraphExecutionState['results'][string]
): output is ImageOutput => output.type === 'image';
): output is ImageOutput => output.type === 'image_output';
export const isLatentsOutput = (
output: GraphExecutionState['results'][string]
): output is LatentsOutput => output.type === 'latents_output';
export const isMaskOutput = (
output: GraphExecutionState['results'][string]