mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
23 Commits
release/ad
...
release/3.
Author | SHA1 | Date | |
---|---|---|---|
f6127a1b6b | |||
7f457ca03d | |||
2b972cda6c | |||
47b0e1d7b4 | |||
fe0a16c846 | |||
f19c6069a9 | |||
fcba4382b2 | |||
6f45931711 | |||
278392d52c | |||
b2f942d714 | |||
6bc2253894 | |||
97d6f207d8 | |||
dc9a9d7bec | |||
15a3e49a40 | |||
7ccfc499dc | |||
56d0d80a39 | |||
2d64ee7f9e | |||
10ada84404 | |||
7744e01e2c | |||
ce8e5f9adf | |||
fc1021b6be | |||
fadfe1dfe9 | |||
2716ae353b |
2
.github/workflows/pypi-release.yml
vendored
2
.github/workflows/pypi-release.yml
vendored
@ -28,7 +28,7 @@ jobs:
|
|||||||
run: twine check dist/*
|
run: twine check dist/*
|
||||||
|
|
||||||
- name: check PyPI versions
|
- name: check PyPI versions
|
||||||
if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/v2.3' || github.ref == 'refs/heads/v3.3.0post1'
|
if: github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/')
|
||||||
run: |
|
run: |
|
||||||
pip install --upgrade requests
|
pip install --upgrade requests
|
||||||
python -c "\
|
python -c "\
|
||||||
|
@ -1,35 +1,35 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
|
|
||||||
|
from invokeai.app.services.board_image_record_storage import SqliteBoardImageRecordStorage
|
||||||
|
from invokeai.app.services.board_images import BoardImagesService, BoardImagesServiceDependencies
|
||||||
|
from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
|
||||||
|
from invokeai.app.services.boards import BoardService, BoardServiceDependencies
|
||||||
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||||
|
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
||||||
|
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
|
||||||
|
from invokeai.app.services.resource_name import SimpleNameService
|
||||||
|
from invokeai.app.services.session_processor.session_processor_default import DefaultSessionProcessor
|
||||||
|
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
||||||
|
from invokeai.app.services.urls import LocalUrlService
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
from invokeai.version.invokeai_version import __version__
|
from invokeai.version.invokeai_version import __version__
|
||||||
|
|
||||||
from ..services.board_image_records.board_image_records_sqlite import SqliteBoardImageRecordStorage
|
from ..services.default_graphs import create_system_graphs
|
||||||
from ..services.board_images.board_images_default import BoardImagesService
|
from ..services.graph import GraphExecutionState, LibraryGraph
|
||||||
from ..services.board_records.board_records_sqlite import SqliteBoardRecordStorage
|
from ..services.image_file_storage import DiskImageFileStorage
|
||||||
from ..services.boards.boards_default import BoardService
|
from ..services.invocation_queue import MemoryInvocationQueue
|
||||||
from ..services.config import InvokeAIAppConfig
|
|
||||||
from ..services.image_files.image_files_disk import DiskImageFileStorage
|
|
||||||
from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage
|
|
||||||
from ..services.images.images_default import ImageService
|
|
||||||
from ..services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
|
|
||||||
from ..services.invocation_processor.invocation_processor_default import DefaultInvocationProcessor
|
|
||||||
from ..services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue
|
|
||||||
from ..services.invocation_services import InvocationServices
|
from ..services.invocation_services import InvocationServices
|
||||||
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
|
from ..services.invocation_stats import InvocationStatsService
|
||||||
from ..services.invoker import Invoker
|
from ..services.invoker import Invoker
|
||||||
from ..services.item_storage.item_storage_sqlite import SqliteItemStorage
|
from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||||
from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage
|
from ..services.model_manager_service import ModelManagerService
|
||||||
from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage
|
from ..services.processor import DefaultInvocationProcessor
|
||||||
from ..services.model_manager.model_manager_default import ModelManagerService
|
from ..services.sqlite import SqliteItemStorage
|
||||||
from ..services.names.names_default import SimpleNameService
|
from ..services.thread import lock
|
||||||
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
|
|
||||||
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
|
||||||
from ..services.shared.default_graphs import create_system_graphs
|
|
||||||
from ..services.shared.graph import GraphExecutionState, LibraryGraph
|
|
||||||
from ..services.shared.sqlite import SqliteDatabase
|
|
||||||
from ..services.urls.urls_default import LocalUrlService
|
|
||||||
from .events import FastAPIEventService
|
from .events import FastAPIEventService
|
||||||
|
|
||||||
|
|
||||||
@ -63,64 +63,100 @@ class ApiDependencies:
|
|||||||
logger.info(f"Root directory = {str(config.root_path)}")
|
logger.info(f"Root directory = {str(config.root_path)}")
|
||||||
logger.debug(f"Internet connectivity is {config.internet_available}")
|
logger.debug(f"Internet connectivity is {config.internet_available}")
|
||||||
|
|
||||||
|
events = FastAPIEventService(event_handler_id)
|
||||||
|
|
||||||
output_folder = config.output_path
|
output_folder = config.output_path
|
||||||
|
|
||||||
db = SqliteDatabase(config, logger)
|
# TODO: build a file/path manager?
|
||||||
|
if config.use_memory_db:
|
||||||
|
db_location = ":memory:"
|
||||||
|
else:
|
||||||
|
db_path = config.db_path
|
||||||
|
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
db_location = str(db_path)
|
||||||
|
|
||||||
configuration = config
|
logger.info(f"Using database at {db_location}")
|
||||||
logger = logger
|
db_conn = sqlite3.connect(db_location, check_same_thread=False) # TODO: figure out a better threading solution
|
||||||
|
|
||||||
|
if config.log_sql:
|
||||||
|
db_conn.set_trace_callback(print)
|
||||||
|
db_conn.execute("PRAGMA foreign_keys = ON;")
|
||||||
|
|
||||||
|
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
|
||||||
|
conn=db_conn, table_name="graph_executions", lock=lock
|
||||||
|
)
|
||||||
|
|
||||||
board_image_records = SqliteBoardImageRecordStorage(db=db)
|
|
||||||
board_images = BoardImagesService()
|
|
||||||
board_records = SqliteBoardRecordStorage(db=db)
|
|
||||||
boards = BoardService()
|
|
||||||
events = FastAPIEventService(event_handler_id)
|
|
||||||
graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
|
|
||||||
graph_library = SqliteItemStorage[LibraryGraph](db=db, table_name="graphs")
|
|
||||||
image_files = DiskImageFileStorage(f"{output_folder}/images")
|
|
||||||
image_records = SqliteImageRecordStorage(db=db)
|
|
||||||
images = ImageService()
|
|
||||||
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
|
|
||||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
|
|
||||||
model_manager = ModelManagerService(config, logger)
|
|
||||||
names = SimpleNameService()
|
|
||||||
performance_statistics = InvocationStatsService()
|
|
||||||
processor = DefaultInvocationProcessor()
|
|
||||||
queue = MemoryInvocationQueue()
|
|
||||||
session_processor = DefaultSessionProcessor()
|
|
||||||
session_queue = SqliteSessionQueue(db=db)
|
|
||||||
urls = LocalUrlService()
|
urls = LocalUrlService()
|
||||||
|
image_record_storage = SqliteImageRecordStorage(conn=db_conn, lock=lock)
|
||||||
|
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||||
|
names = SimpleNameService()
|
||||||
|
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
|
||||||
|
|
||||||
|
board_record_storage = SqliteBoardRecordStorage(conn=db_conn, lock=lock)
|
||||||
|
board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn, lock=lock)
|
||||||
|
|
||||||
|
boards = BoardService(
|
||||||
|
services=BoardServiceDependencies(
|
||||||
|
board_image_record_storage=board_image_record_storage,
|
||||||
|
board_record_storage=board_record_storage,
|
||||||
|
image_record_storage=image_record_storage,
|
||||||
|
url=urls,
|
||||||
|
logger=logger,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
board_images = BoardImagesService(
|
||||||
|
services=BoardImagesServiceDependencies(
|
||||||
|
board_image_record_storage=board_image_record_storage,
|
||||||
|
board_record_storage=board_record_storage,
|
||||||
|
image_record_storage=image_record_storage,
|
||||||
|
url=urls,
|
||||||
|
logger=logger,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
images = ImageService(
|
||||||
|
services=ImageServiceDependencies(
|
||||||
|
board_image_record_storage=board_image_record_storage,
|
||||||
|
image_record_storage=image_record_storage,
|
||||||
|
image_file_storage=image_file_storage,
|
||||||
|
url=urls,
|
||||||
|
logger=logger,
|
||||||
|
names=names,
|
||||||
|
graph_execution_manager=graph_execution_manager,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
services = InvocationServices(
|
services = InvocationServices(
|
||||||
board_image_records=board_image_records,
|
model_manager=ModelManagerService(config, logger),
|
||||||
board_images=board_images,
|
|
||||||
board_records=board_records,
|
|
||||||
boards=boards,
|
|
||||||
configuration=configuration,
|
|
||||||
events=events,
|
events=events,
|
||||||
graph_execution_manager=graph_execution_manager,
|
|
||||||
graph_library=graph_library,
|
|
||||||
image_files=image_files,
|
|
||||||
image_records=image_records,
|
|
||||||
images=images,
|
|
||||||
invocation_cache=invocation_cache,
|
|
||||||
latents=latents,
|
latents=latents,
|
||||||
|
images=images,
|
||||||
|
boards=boards,
|
||||||
|
board_images=board_images,
|
||||||
|
queue=MemoryInvocationQueue(),
|
||||||
|
graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, lock=lock, table_name="graphs"),
|
||||||
|
graph_execution_manager=graph_execution_manager,
|
||||||
|
processor=DefaultInvocationProcessor(),
|
||||||
|
configuration=config,
|
||||||
|
performance_statistics=InvocationStatsService(graph_execution_manager),
|
||||||
logger=logger,
|
logger=logger,
|
||||||
model_manager=model_manager,
|
session_queue=SqliteSessionQueue(conn=db_conn, lock=lock),
|
||||||
names=names,
|
session_processor=DefaultSessionProcessor(),
|
||||||
performance_statistics=performance_statistics,
|
invocation_cache=MemoryInvocationCache(max_cache_size=config.node_cache_size),
|
||||||
processor=processor,
|
|
||||||
queue=queue,
|
|
||||||
session_processor=session_processor,
|
|
||||||
session_queue=session_queue,
|
|
||||||
urls=urls,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
create_system_graphs(services.graph_library)
|
create_system_graphs(services.graph_library)
|
||||||
|
|
||||||
ApiDependencies.invoker = Invoker(services)
|
ApiDependencies.invoker = Invoker(services)
|
||||||
|
|
||||||
db.clean()
|
try:
|
||||||
|
lock.acquire()
|
||||||
|
db_conn.execute("VACUUM;")
|
||||||
|
db_conn.commit()
|
||||||
|
logger.info("Cleaned database")
|
||||||
|
finally:
|
||||||
|
lock.release()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def shutdown():
|
def shutdown():
|
||||||
|
@ -7,7 +7,7 @@ from typing import Any
|
|||||||
|
|
||||||
from fastapi_events.dispatcher import dispatch
|
from fastapi_events.dispatcher import dispatch
|
||||||
|
|
||||||
from ..services.events.events_base import EventServiceBase
|
from ..services.events import EventServiceBase
|
||||||
|
|
||||||
|
|
||||||
class FastAPIEventService(EventServiceBase):
|
class FastAPIEventService(EventServiceBase):
|
||||||
|
@ -4,9 +4,9 @@ from fastapi import Body, HTTPException, Path, Query
|
|||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from invokeai.app.services.board_records.board_records_common import BoardChanges
|
from invokeai.app.services.board_record_storage import BoardChanges
|
||||||
from invokeai.app.services.boards.boards_common import BoardDTO
|
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
from invokeai.app.services.models.board_record import BoardDTO
|
||||||
|
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
|
@ -8,9 +8,9 @@ from PIL import Image
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from invokeai.app.invocations.metadata import ImageMetadata
|
from invokeai.app.invocations.metadata import ImageMetadata
|
||||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
|
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||||
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
|
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
from invokeai.app.services.models.image_record import ImageDTO, ImageRecordChanges, ImageUrlsDTO
|
||||||
|
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
|
@ -18,9 +18,9 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
|||||||
SessionQueueItemDTO,
|
SessionQueueItemDTO,
|
||||||
SessionQueueStatus,
|
SessionQueueStatus,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.shared.graph import Graph
|
from invokeai.app.services.shared.models import CursorPaginatedResults
|
||||||
from invokeai.app.services.shared.pagination import CursorPaginatedResults
|
|
||||||
|
|
||||||
|
from ...services.graph import Graph
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
session_queue_router = APIRouter(prefix="/v1/queue", tags=["queue"])
|
session_queue_router = APIRouter(prefix="/v1/queue", tags=["queue"])
|
||||||
|
@ -6,12 +6,11 @@ from fastapi import Body, HTTPException, Path, Query, Response
|
|||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
from pydantic.fields import Field
|
from pydantic.fields import Field
|
||||||
|
|
||||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
|
||||||
|
|
||||||
# Importing * is bad karma but needed here for node detection
|
# Importing * is bad karma but needed here for node detection
|
||||||
from ...invocations import * # noqa: F401 F403
|
from ...invocations import * # noqa: F401 F403
|
||||||
from ...invocations.baseinvocation import BaseInvocation
|
from ...invocations.baseinvocation import BaseInvocation
|
||||||
from ...services.shared.graph import Edge, EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError
|
from ...services.graph import Edge, EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError
|
||||||
|
from ...services.item_storage import PaginatedResults
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
|
session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
|
||||||
|
@ -5,7 +5,7 @@ from fastapi_events.handlers.local import local_handler
|
|||||||
from fastapi_events.typing import Event
|
from fastapi_events.typing import Event
|
||||||
from socketio import ASGIApp, AsyncServer
|
from socketio import ASGIApp, AsyncServer
|
||||||
|
|
||||||
from ..services.events.events_base import EventServiceBase
|
from ..services.events import EventServiceBase
|
||||||
|
|
||||||
|
|
||||||
class SocketIO:
|
class SocketIO:
|
||||||
|
@ -28,7 +28,7 @@ from pydantic import BaseModel, Field, validator
|
|||||||
from pydantic.fields import ModelField, Undefined
|
from pydantic.fields import ModelField, Undefined
|
||||||
from pydantic.typing import NoArgAnyCallable
|
from pydantic.typing import NoArgAnyCallable
|
||||||
|
|
||||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..services.invocation_services import InvocationServices
|
from ..services.invocation_services import InvocationServices
|
||||||
|
@ -27,9 +27,9 @@ from PIL import Image
|
|||||||
from pydantic import BaseModel, Field, validator
|
from pydantic import BaseModel, Field, validator
|
||||||
|
|
||||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
|
||||||
|
|
||||||
from ...backend.model_management import BaseModelType
|
from ...backend.model_management import BaseModelType
|
||||||
|
from ..models.image import ImageCategory, ResourceOrigin
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
|
@ -6,7 +6,7 @@ import numpy
|
|||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
|
|
||||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@ from invokeai.app.invocations.baseinvocation import (
|
|||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("face_mask_output")
|
@invocation_output("face_mask_output")
|
||||||
|
@ -9,10 +9,10 @@ from PIL import Image, ImageChops, ImageFilter, ImageOps
|
|||||||
|
|
||||||
from invokeai.app.invocations.metadata import CoreMetadata
|
from invokeai.app.invocations.metadata import CoreMetadata
|
||||||
from invokeai.app.invocations.primitives import BoardField, ColorField, ImageField, ImageOutput
|
from invokeai.app.invocations.primitives import BoardField, ColorField, ImageField, ImageOutput
|
||||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
|
||||||
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
||||||
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
||||||
|
|
||||||
|
from ..models.image import ImageCategory, ResourceOrigin
|
||||||
from .baseinvocation import BaseInvocation, FieldDescriptions, Input, InputField, InvocationContext, invocation
|
from .baseinvocation import BaseInvocation, FieldDescriptions, Input, InputField, InvocationContext, invocation
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,12 +7,12 @@ import numpy as np
|
|||||||
from PIL import Image, ImageOps
|
from PIL import Image, ImageOps
|
||||||
|
|
||||||
from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput
|
from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput
|
||||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
|
||||||
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
||||||
from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
|
from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
|
||||||
from invokeai.backend.image_util.lama import LaMA
|
from invokeai.backend.image_util.lama import LaMA
|
||||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||||
|
|
||||||
|
from ..models.image import ImageCategory, ResourceOrigin
|
||||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||||
from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES
|
from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES
|
||||||
|
|
||||||
|
@ -34,7 +34,6 @@ from invokeai.app.invocations.primitives import (
|
|||||||
build_latents_output,
|
build_latents_output,
|
||||||
)
|
)
|
||||||
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
|
||||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
||||||
@ -55,6 +54,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
|
|||||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||||
from ...backend.util.devices import choose_precision, choose_torch_device
|
from ...backend.util.devices import choose_precision, choose_torch_device
|
||||||
|
from ..models.image import ImageCategory, ResourceOrigin
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
|
@ -14,13 +14,13 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
from invokeai.app.invocations.metadata import CoreMetadata
|
from invokeai.app.invocations.metadata import CoreMetadata
|
||||||
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput
|
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput
|
||||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
|
||||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||||
from invokeai.backend import BaseModelType, ModelType, SubModelType
|
from invokeai.backend import BaseModelType, ModelType, SubModelType
|
||||||
|
|
||||||
from ...backend.model_management import ONNXModelPatcher
|
from ...backend.model_management import ONNXModelPatcher
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
from ...backend.util import choose_torch_device
|
from ...backend.util import choose_torch_device
|
||||||
|
from ..models.image import ImageCategory, ResourceOrigin
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
|
@ -10,7 +10,7 @@ from PIL import Image
|
|||||||
from realesrgan import RealESRGANer
|
from realesrgan import RealESRGANer
|
||||||
|
|
||||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||||
from invokeai.backend.util.devices import choose_torch_device
|
from invokeai.backend.util.devices import choose_torch_device
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||||
|
4
invokeai/app/models/exceptions.py
Normal file
4
invokeai/app/models/exceptions.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
class CanceledException(Exception):
|
||||||
|
"""Execution canceled by user."""
|
||||||
|
|
||||||
|
pass
|
71
invokeai/app/models/image.py
Normal file
71
invokeai/app/models/image.py
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from invokeai.app.util.metaenum import MetaEnum
|
||||||
|
|
||||||
|
|
||||||
|
class ProgressImage(BaseModel):
|
||||||
|
"""The progress image sent intermittently during processing"""
|
||||||
|
|
||||||
|
width: int = Field(description="The effective width of the image in pixels")
|
||||||
|
height: int = Field(description="The effective height of the image in pixels")
|
||||||
|
dataURL: str = Field(description="The image data as a b64 data URL")
|
||||||
|
|
||||||
|
|
||||||
|
class ResourceOrigin(str, Enum, metaclass=MetaEnum):
|
||||||
|
"""The origin of a resource (eg image).
|
||||||
|
|
||||||
|
- INTERNAL: The resource was created by the application.
|
||||||
|
- EXTERNAL: The resource was not created by the application.
|
||||||
|
This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
|
||||||
|
"""
|
||||||
|
|
||||||
|
INTERNAL = "internal"
|
||||||
|
"""The resource was created by the application."""
|
||||||
|
EXTERNAL = "external"
|
||||||
|
"""The resource was not created by the application.
|
||||||
|
This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidOriginException(ValueError):
|
||||||
|
"""Raised when a provided value is not a valid ResourceOrigin.
|
||||||
|
|
||||||
|
Subclasses `ValueError`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, message="Invalid resource origin."):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageCategory(str, Enum, metaclass=MetaEnum):
|
||||||
|
"""The category of an image.
|
||||||
|
|
||||||
|
- GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose.
|
||||||
|
- MASK: The image is a mask image.
|
||||||
|
- CONTROL: The image is a ControlNet control image.
|
||||||
|
- USER: The image is a user-provide image.
|
||||||
|
- OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
GENERAL = "general"
|
||||||
|
"""GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose."""
|
||||||
|
MASK = "mask"
|
||||||
|
"""MASK: The image is a mask image."""
|
||||||
|
CONTROL = "control"
|
||||||
|
"""CONTROL: The image is a ControlNet control image."""
|
||||||
|
USER = "user"
|
||||||
|
"""USER: The image is a user-provide image."""
|
||||||
|
OTHER = "other"
|
||||||
|
"""OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes."""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidImageCategoryException(ValueError):
|
||||||
|
"""Raised when a provided value is not a valid ImageCategory.
|
||||||
|
|
||||||
|
Subclasses `ValueError`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, message="Invalid image category."):
|
||||||
|
super().__init__(message)
|
@ -1,12 +1,55 @@
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
import threading
|
import threading
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
from typing import Optional, cast
|
from typing import Optional, cast
|
||||||
|
|
||||||
from invokeai.app.services.image_records.image_records_common import ImageRecord, deserialize_image_record
|
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
from invokeai.app.services.models.image_record import ImageRecord, deserialize_image_record
|
||||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
|
||||||
|
|
||||||
from .board_image_records_base import BoardImageRecordStorageBase
|
|
||||||
|
class BoardImageRecordStorageBase(ABC):
|
||||||
|
"""Abstract base class for the one-to-many board-image relationship record storage."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def add_image_to_board(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
image_name: str,
|
||||||
|
) -> None:
|
||||||
|
"""Adds an image to a board."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def remove_image_from_board(
|
||||||
|
self,
|
||||||
|
image_name: str,
|
||||||
|
) -> None:
|
||||||
|
"""Removes an image from a board."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_all_board_image_names_for_board(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Gets all board images for a board, as a list of the image names."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_board_for_image(
|
||||||
|
self,
|
||||||
|
image_name: str,
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""Gets an image's board id, if it has one."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_image_count_for_board(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
) -> int:
|
||||||
|
"""Gets the number of images for a board."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
||||||
@ -14,11 +57,13 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
|
|||||||
_cursor: sqlite3.Cursor
|
_cursor: sqlite3.Cursor
|
||||||
_lock: threading.Lock
|
_lock: threading.Lock
|
||||||
|
|
||||||
def __init__(self, db: SqliteDatabase) -> None:
|
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._lock = db.lock
|
self._conn = conn
|
||||||
self._conn = db.conn
|
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||||
|
self._conn.row_factory = sqlite3.Row
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
|
self._lock = lock
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
@ -1,47 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
|
|
||||||
class BoardImageRecordStorageBase(ABC):
|
|
||||||
"""Abstract base class for the one-to-many board-image relationship record storage."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def add_image_to_board(
|
|
||||||
self,
|
|
||||||
board_id: str,
|
|
||||||
image_name: str,
|
|
||||||
) -> None:
|
|
||||||
"""Adds an image to a board."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def remove_image_from_board(
|
|
||||||
self,
|
|
||||||
image_name: str,
|
|
||||||
) -> None:
|
|
||||||
"""Removes an image from a board."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_all_board_image_names_for_board(
|
|
||||||
self,
|
|
||||||
board_id: str,
|
|
||||||
) -> list[str]:
|
|
||||||
"""Gets all board images for a board, as a list of the image names."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_board_for_image(
|
|
||||||
self,
|
|
||||||
image_name: str,
|
|
||||||
) -> Optional[str]:
|
|
||||||
"""Gets an image's board id, if it has one."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_image_count_for_board(
|
|
||||||
self,
|
|
||||||
board_id: str,
|
|
||||||
) -> int:
|
|
||||||
"""Gets the number of images for a board."""
|
|
||||||
pass
|
|
112
invokeai/app/services/board_images.py
Normal file
112
invokeai/app/services/board_images.py
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from logging import Logger
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
|
||||||
|
from invokeai.app.services.board_record_storage import BoardRecord, BoardRecordStorageBase
|
||||||
|
from invokeai.app.services.image_record_storage import ImageRecordStorageBase
|
||||||
|
from invokeai.app.services.models.board_record import BoardDTO
|
||||||
|
from invokeai.app.services.urls import UrlServiceBase
|
||||||
|
|
||||||
|
|
||||||
|
class BoardImagesServiceABC(ABC):
|
||||||
|
"""High-level service for board-image relationship management."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def add_image_to_board(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
image_name: str,
|
||||||
|
) -> None:
|
||||||
|
"""Adds an image to a board."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def remove_image_from_board(
|
||||||
|
self,
|
||||||
|
image_name: str,
|
||||||
|
) -> None:
|
||||||
|
"""Removes an image from a board."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_all_board_image_names_for_board(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Gets all board images for a board, as a list of the image names."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_board_for_image(
|
||||||
|
self,
|
||||||
|
image_name: str,
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""Gets an image's board id, if it has one."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BoardImagesServiceDependencies:
|
||||||
|
"""Service dependencies for the BoardImagesService."""
|
||||||
|
|
||||||
|
board_image_records: BoardImageRecordStorageBase
|
||||||
|
board_records: BoardRecordStorageBase
|
||||||
|
image_records: ImageRecordStorageBase
|
||||||
|
urls: UrlServiceBase
|
||||||
|
logger: Logger
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
board_image_record_storage: BoardImageRecordStorageBase,
|
||||||
|
image_record_storage: ImageRecordStorageBase,
|
||||||
|
board_record_storage: BoardRecordStorageBase,
|
||||||
|
url: UrlServiceBase,
|
||||||
|
logger: Logger,
|
||||||
|
):
|
||||||
|
self.board_image_records = board_image_record_storage
|
||||||
|
self.image_records = image_record_storage
|
||||||
|
self.board_records = board_record_storage
|
||||||
|
self.urls = url
|
||||||
|
self.logger = logger
|
||||||
|
|
||||||
|
|
||||||
|
class BoardImagesService(BoardImagesServiceABC):
|
||||||
|
_services: BoardImagesServiceDependencies
|
||||||
|
|
||||||
|
def __init__(self, services: BoardImagesServiceDependencies):
|
||||||
|
self._services = services
|
||||||
|
|
||||||
|
def add_image_to_board(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
image_name: str,
|
||||||
|
) -> None:
|
||||||
|
self._services.board_image_records.add_image_to_board(board_id, image_name)
|
||||||
|
|
||||||
|
def remove_image_from_board(
|
||||||
|
self,
|
||||||
|
image_name: str,
|
||||||
|
) -> None:
|
||||||
|
self._services.board_image_records.remove_image_from_board(image_name)
|
||||||
|
|
||||||
|
def get_all_board_image_names_for_board(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
) -> list[str]:
|
||||||
|
return self._services.board_image_records.get_all_board_image_names_for_board(board_id)
|
||||||
|
|
||||||
|
def get_board_for_image(
|
||||||
|
self,
|
||||||
|
image_name: str,
|
||||||
|
) -> Optional[str]:
|
||||||
|
board_id = self._services.board_image_records.get_board_for_image(image_name)
|
||||||
|
return board_id
|
||||||
|
|
||||||
|
|
||||||
|
def board_record_to_dto(board_record: BoardRecord, cover_image_name: Optional[str], image_count: int) -> BoardDTO:
|
||||||
|
"""Converts a board record to a board DTO."""
|
||||||
|
return BoardDTO(
|
||||||
|
**board_record.dict(exclude={"cover_image_name"}),
|
||||||
|
cover_image_name=cover_image_name,
|
||||||
|
image_count=image_count,
|
||||||
|
)
|
@ -1,39 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
|
|
||||||
class BoardImagesServiceABC(ABC):
|
|
||||||
"""High-level service for board-image relationship management."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def add_image_to_board(
|
|
||||||
self,
|
|
||||||
board_id: str,
|
|
||||||
image_name: str,
|
|
||||||
) -> None:
|
|
||||||
"""Adds an image to a board."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def remove_image_from_board(
|
|
||||||
self,
|
|
||||||
image_name: str,
|
|
||||||
) -> None:
|
|
||||||
"""Removes an image from a board."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_all_board_image_names_for_board(
|
|
||||||
self,
|
|
||||||
board_id: str,
|
|
||||||
) -> list[str]:
|
|
||||||
"""Gets all board images for a board, as a list of the image names."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_board_for_image(
|
|
||||||
self,
|
|
||||||
image_name: str,
|
|
||||||
) -> Optional[str]:
|
|
||||||
"""Gets an image's board id, if it has one."""
|
|
||||||
pass
|
|
@ -1,38 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
from invokeai.app.services.invoker import Invoker
|
|
||||||
|
|
||||||
from .board_images_base import BoardImagesServiceABC
|
|
||||||
|
|
||||||
|
|
||||||
class BoardImagesService(BoardImagesServiceABC):
|
|
||||||
__invoker: Invoker
|
|
||||||
|
|
||||||
def start(self, invoker: Invoker) -> None:
|
|
||||||
self.__invoker = invoker
|
|
||||||
|
|
||||||
def add_image_to_board(
|
|
||||||
self,
|
|
||||||
board_id: str,
|
|
||||||
image_name: str,
|
|
||||||
) -> None:
|
|
||||||
self.__invoker.services.board_image_records.add_image_to_board(board_id, image_name)
|
|
||||||
|
|
||||||
def remove_image_from_board(
|
|
||||||
self,
|
|
||||||
image_name: str,
|
|
||||||
) -> None:
|
|
||||||
self.__invoker.services.board_image_records.remove_image_from_board(image_name)
|
|
||||||
|
|
||||||
def get_all_board_image_names_for_board(
|
|
||||||
self,
|
|
||||||
board_id: str,
|
|
||||||
) -> list[str]:
|
|
||||||
return self.__invoker.services.board_image_records.get_all_board_image_names_for_board(board_id)
|
|
||||||
|
|
||||||
def get_board_for_image(
|
|
||||||
self,
|
|
||||||
image_name: str,
|
|
||||||
) -> Optional[str]:
|
|
||||||
board_id = self.__invoker.services.board_image_records.get_board_for_image(image_name)
|
|
||||||
return board_id
|
|
@ -1,20 +1,89 @@
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
import threading
|
import threading
|
||||||
from typing import Union, cast
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional, Union, cast
|
||||||
|
|
||||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
from pydantic import BaseModel, Extra, Field
|
||||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
|
||||||
|
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||||
|
from invokeai.app.services.models.board_record import BoardRecord, deserialize_board_record
|
||||||
from invokeai.app.util.misc import uuid_string
|
from invokeai.app.util.misc import uuid_string
|
||||||
|
|
||||||
from .board_records_base import BoardRecordStorageBase
|
|
||||||
from .board_records_common import (
|
class BoardChanges(BaseModel, extra=Extra.forbid):
|
||||||
BoardChanges,
|
board_name: Optional[str] = Field(description="The board's new name.")
|
||||||
BoardRecord,
|
cover_image_name: Optional[str] = Field(description="The name of the board's new cover image.")
|
||||||
BoardRecordDeleteException,
|
|
||||||
BoardRecordNotFoundException,
|
|
||||||
BoardRecordSaveException,
|
class BoardRecordNotFoundException(Exception):
|
||||||
deserialize_board_record,
|
"""Raised when an board record is not found."""
|
||||||
)
|
|
||||||
|
def __init__(self, message="Board record not found"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class BoardRecordSaveException(Exception):
|
||||||
|
"""Raised when an board record cannot be saved."""
|
||||||
|
|
||||||
|
def __init__(self, message="Board record not saved"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class BoardRecordDeleteException(Exception):
|
||||||
|
"""Raised when an board record cannot be deleted."""
|
||||||
|
|
||||||
|
def __init__(self, message="Board record not deleted"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class BoardRecordStorageBase(ABC):
|
||||||
|
"""Low-level service responsible for interfacing with the board record store."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete(self, board_id: str) -> None:
|
||||||
|
"""Deletes a board record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def save(
|
||||||
|
self,
|
||||||
|
board_name: str,
|
||||||
|
) -> BoardRecord:
|
||||||
|
"""Saves a board record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
) -> BoardRecord:
|
||||||
|
"""Gets a board record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
changes: BoardChanges,
|
||||||
|
) -> BoardRecord:
|
||||||
|
"""Updates a board record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_many(
|
||||||
|
self,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int = 10,
|
||||||
|
) -> OffsetPaginatedResults[BoardRecord]:
|
||||||
|
"""Gets many board records."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_all(
|
||||||
|
self,
|
||||||
|
) -> list[BoardRecord]:
|
||||||
|
"""Gets all board records."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
||||||
@ -22,11 +91,13 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
|||||||
_cursor: sqlite3.Cursor
|
_cursor: sqlite3.Cursor
|
||||||
_lock: threading.Lock
|
_lock: threading.Lock
|
||||||
|
|
||||||
def __init__(self, db: SqliteDatabase) -> None:
|
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._lock = db.lock
|
self._conn = conn
|
||||||
self._conn = db.conn
|
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||||
|
self._conn.row_factory = sqlite3.Row
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
|
self._lock = lock
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
@ -1,55 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
|
|
||||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
|
||||||
|
|
||||||
from .board_records_common import BoardChanges, BoardRecord
|
|
||||||
|
|
||||||
|
|
||||||
class BoardRecordStorageBase(ABC):
|
|
||||||
"""Low-level service responsible for interfacing with the board record store."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def delete(self, board_id: str) -> None:
|
|
||||||
"""Deletes a board record."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def save(
|
|
||||||
self,
|
|
||||||
board_name: str,
|
|
||||||
) -> BoardRecord:
|
|
||||||
"""Saves a board record."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get(
|
|
||||||
self,
|
|
||||||
board_id: str,
|
|
||||||
) -> BoardRecord:
|
|
||||||
"""Gets a board record."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def update(
|
|
||||||
self,
|
|
||||||
board_id: str,
|
|
||||||
changes: BoardChanges,
|
|
||||||
) -> BoardRecord:
|
|
||||||
"""Updates a board record."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_many(
|
|
||||||
self,
|
|
||||||
offset: int = 0,
|
|
||||||
limit: int = 10,
|
|
||||||
) -> OffsetPaginatedResults[BoardRecord]:
|
|
||||||
"""Gets many board records."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_all(
|
|
||||||
self,
|
|
||||||
) -> list[BoardRecord]:
|
|
||||||
"""Gets all board records."""
|
|
||||||
pass
|
|
158
invokeai/app/services/boards.py
Normal file
158
invokeai/app/services/boards.py
Normal file
@ -0,0 +1,158 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from logging import Logger
|
||||||
|
|
||||||
|
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
|
||||||
|
from invokeai.app.services.board_images import board_record_to_dto
|
||||||
|
from invokeai.app.services.board_record_storage import BoardChanges, BoardRecordStorageBase
|
||||||
|
from invokeai.app.services.image_record_storage import ImageRecordStorageBase, OffsetPaginatedResults
|
||||||
|
from invokeai.app.services.models.board_record import BoardDTO
|
||||||
|
from invokeai.app.services.urls import UrlServiceBase
|
||||||
|
|
||||||
|
|
||||||
|
class BoardServiceABC(ABC):
|
||||||
|
"""High-level service for board management."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create(
|
||||||
|
self,
|
||||||
|
board_name: str,
|
||||||
|
) -> BoardDTO:
|
||||||
|
"""Creates a board."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_dto(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
) -> BoardDTO:
|
||||||
|
"""Gets a board."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
changes: BoardChanges,
|
||||||
|
) -> BoardDTO:
|
||||||
|
"""Updates a board."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""Deletes a board."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_many(
|
||||||
|
self,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int = 10,
|
||||||
|
) -> OffsetPaginatedResults[BoardDTO]:
|
||||||
|
"""Gets many boards."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_all(
|
||||||
|
self,
|
||||||
|
) -> list[BoardDTO]:
|
||||||
|
"""Gets all boards."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BoardServiceDependencies:
|
||||||
|
"""Service dependencies for the BoardService."""
|
||||||
|
|
||||||
|
board_image_records: BoardImageRecordStorageBase
|
||||||
|
board_records: BoardRecordStorageBase
|
||||||
|
image_records: ImageRecordStorageBase
|
||||||
|
urls: UrlServiceBase
|
||||||
|
logger: Logger
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
board_image_record_storage: BoardImageRecordStorageBase,
|
||||||
|
image_record_storage: ImageRecordStorageBase,
|
||||||
|
board_record_storage: BoardRecordStorageBase,
|
||||||
|
url: UrlServiceBase,
|
||||||
|
logger: Logger,
|
||||||
|
):
|
||||||
|
self.board_image_records = board_image_record_storage
|
||||||
|
self.image_records = image_record_storage
|
||||||
|
self.board_records = board_record_storage
|
||||||
|
self.urls = url
|
||||||
|
self.logger = logger
|
||||||
|
|
||||||
|
|
||||||
|
class BoardService(BoardServiceABC):
|
||||||
|
_services: BoardServiceDependencies
|
||||||
|
|
||||||
|
def __init__(self, services: BoardServiceDependencies):
|
||||||
|
self._services = services
|
||||||
|
|
||||||
|
def create(
|
||||||
|
self,
|
||||||
|
board_name: str,
|
||||||
|
) -> BoardDTO:
|
||||||
|
board_record = self._services.board_records.save(board_name)
|
||||||
|
return board_record_to_dto(board_record, None, 0)
|
||||||
|
|
||||||
|
def get_dto(self, board_id: str) -> BoardDTO:
|
||||||
|
board_record = self._services.board_records.get(board_id)
|
||||||
|
cover_image = self._services.image_records.get_most_recent_image_for_board(board_record.board_id)
|
||||||
|
if cover_image:
|
||||||
|
cover_image_name = cover_image.image_name
|
||||||
|
else:
|
||||||
|
cover_image_name = None
|
||||||
|
image_count = self._services.board_image_records.get_image_count_for_board(board_id)
|
||||||
|
return board_record_to_dto(board_record, cover_image_name, image_count)
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
board_id: str,
|
||||||
|
changes: BoardChanges,
|
||||||
|
) -> BoardDTO:
|
||||||
|
board_record = self._services.board_records.update(board_id, changes)
|
||||||
|
cover_image = self._services.image_records.get_most_recent_image_for_board(board_record.board_id)
|
||||||
|
if cover_image:
|
||||||
|
cover_image_name = cover_image.image_name
|
||||||
|
else:
|
||||||
|
cover_image_name = None
|
||||||
|
|
||||||
|
image_count = self._services.board_image_records.get_image_count_for_board(board_id)
|
||||||
|
return board_record_to_dto(board_record, cover_image_name, image_count)
|
||||||
|
|
||||||
|
def delete(self, board_id: str) -> None:
|
||||||
|
self._services.board_records.delete(board_id)
|
||||||
|
|
||||||
|
def get_many(self, offset: int = 0, limit: int = 10) -> OffsetPaginatedResults[BoardDTO]:
|
||||||
|
board_records = self._services.board_records.get_many(offset, limit)
|
||||||
|
board_dtos = []
|
||||||
|
for r in board_records.items:
|
||||||
|
cover_image = self._services.image_records.get_most_recent_image_for_board(r.board_id)
|
||||||
|
if cover_image:
|
||||||
|
cover_image_name = cover_image.image_name
|
||||||
|
else:
|
||||||
|
cover_image_name = None
|
||||||
|
|
||||||
|
image_count = self._services.board_image_records.get_image_count_for_board(r.board_id)
|
||||||
|
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
|
||||||
|
|
||||||
|
return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos))
|
||||||
|
|
||||||
|
def get_all(self) -> list[BoardDTO]:
|
||||||
|
board_records = self._services.board_records.get_all()
|
||||||
|
board_dtos = []
|
||||||
|
for r in board_records:
|
||||||
|
cover_image = self._services.image_records.get_most_recent_image_for_board(r.board_id)
|
||||||
|
if cover_image:
|
||||||
|
cover_image_name = cover_image.image_name
|
||||||
|
else:
|
||||||
|
cover_image_name = None
|
||||||
|
|
||||||
|
image_count = self._services.board_image_records.get_image_count_for_board(r.board_id)
|
||||||
|
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
|
||||||
|
|
||||||
|
return board_dtos
|
@ -1,59 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
|
|
||||||
from invokeai.app.services.board_records.board_records_common import BoardChanges
|
|
||||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
|
||||||
|
|
||||||
from .boards_common import BoardDTO
|
|
||||||
|
|
||||||
|
|
||||||
class BoardServiceABC(ABC):
|
|
||||||
"""High-level service for board management."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def create(
|
|
||||||
self,
|
|
||||||
board_name: str,
|
|
||||||
) -> BoardDTO:
|
|
||||||
"""Creates a board."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_dto(
|
|
||||||
self,
|
|
||||||
board_id: str,
|
|
||||||
) -> BoardDTO:
|
|
||||||
"""Gets a board."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def update(
|
|
||||||
self,
|
|
||||||
board_id: str,
|
|
||||||
changes: BoardChanges,
|
|
||||||
) -> BoardDTO:
|
|
||||||
"""Updates a board."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def delete(
|
|
||||||
self,
|
|
||||||
board_id: str,
|
|
||||||
) -> None:
|
|
||||||
"""Deletes a board."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_many(
|
|
||||||
self,
|
|
||||||
offset: int = 0,
|
|
||||||
limit: int = 10,
|
|
||||||
) -> OffsetPaginatedResults[BoardDTO]:
|
|
||||||
"""Gets many boards."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_all(
|
|
||||||
self,
|
|
||||||
) -> list[BoardDTO]:
|
|
||||||
"""Gets all boards."""
|
|
||||||
pass
|
|
@ -1,23 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
from ..board_records.board_records_common import BoardRecord
|
|
||||||
|
|
||||||
|
|
||||||
class BoardDTO(BoardRecord):
|
|
||||||
"""Deserialized board record with cover image URL and image count."""
|
|
||||||
|
|
||||||
cover_image_name: Optional[str] = Field(description="The name of the board's cover image.")
|
|
||||||
"""The URL of the thumbnail of the most recent image in the board."""
|
|
||||||
image_count: int = Field(description="The number of images in the board.")
|
|
||||||
"""The number of images in the board."""
|
|
||||||
|
|
||||||
|
|
||||||
def board_record_to_dto(board_record: BoardRecord, cover_image_name: Optional[str], image_count: int) -> BoardDTO:
|
|
||||||
"""Converts a board record to a board DTO."""
|
|
||||||
return BoardDTO(
|
|
||||||
**board_record.dict(exclude={"cover_image_name"}),
|
|
||||||
cover_image_name=cover_image_name,
|
|
||||||
image_count=image_count,
|
|
||||||
)
|
|
@ -1,79 +0,0 @@
|
|||||||
from invokeai.app.services.board_records.board_records_common import BoardChanges
|
|
||||||
from invokeai.app.services.boards.boards_common import BoardDTO
|
|
||||||
from invokeai.app.services.invoker import Invoker
|
|
||||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
|
||||||
|
|
||||||
from .boards_base import BoardServiceABC
|
|
||||||
from .boards_common import board_record_to_dto
|
|
||||||
|
|
||||||
|
|
||||||
class BoardService(BoardServiceABC):
|
|
||||||
__invoker: Invoker
|
|
||||||
|
|
||||||
def start(self, invoker: Invoker) -> None:
|
|
||||||
self.__invoker = invoker
|
|
||||||
|
|
||||||
def create(
|
|
||||||
self,
|
|
||||||
board_name: str,
|
|
||||||
) -> BoardDTO:
|
|
||||||
board_record = self.__invoker.services.board_records.save(board_name)
|
|
||||||
return board_record_to_dto(board_record, None, 0)
|
|
||||||
|
|
||||||
def get_dto(self, board_id: str) -> BoardDTO:
|
|
||||||
board_record = self.__invoker.services.board_records.get(board_id)
|
|
||||||
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(board_record.board_id)
|
|
||||||
if cover_image:
|
|
||||||
cover_image_name = cover_image.image_name
|
|
||||||
else:
|
|
||||||
cover_image_name = None
|
|
||||||
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(board_id)
|
|
||||||
return board_record_to_dto(board_record, cover_image_name, image_count)
|
|
||||||
|
|
||||||
def update(
|
|
||||||
self,
|
|
||||||
board_id: str,
|
|
||||||
changes: BoardChanges,
|
|
||||||
) -> BoardDTO:
|
|
||||||
board_record = self.__invoker.services.board_records.update(board_id, changes)
|
|
||||||
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(board_record.board_id)
|
|
||||||
if cover_image:
|
|
||||||
cover_image_name = cover_image.image_name
|
|
||||||
else:
|
|
||||||
cover_image_name = None
|
|
||||||
|
|
||||||
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(board_id)
|
|
||||||
return board_record_to_dto(board_record, cover_image_name, image_count)
|
|
||||||
|
|
||||||
def delete(self, board_id: str) -> None:
|
|
||||||
self.__invoker.services.board_records.delete(board_id)
|
|
||||||
|
|
||||||
def get_many(self, offset: int = 0, limit: int = 10) -> OffsetPaginatedResults[BoardDTO]:
|
|
||||||
board_records = self.__invoker.services.board_records.get_many(offset, limit)
|
|
||||||
board_dtos = []
|
|
||||||
for r in board_records.items:
|
|
||||||
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)
|
|
||||||
if cover_image:
|
|
||||||
cover_image_name = cover_image.image_name
|
|
||||||
else:
|
|
||||||
cover_image_name = None
|
|
||||||
|
|
||||||
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(r.board_id)
|
|
||||||
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
|
|
||||||
|
|
||||||
return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos))
|
|
||||||
|
|
||||||
def get_all(self) -> list[BoardDTO]:
|
|
||||||
board_records = self.__invoker.services.board_records.get_all()
|
|
||||||
board_dtos = []
|
|
||||||
for r in board_records:
|
|
||||||
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)
|
|
||||||
if cover_image:
|
|
||||||
cover_image_name = cover_image.image_name
|
|
||||||
else:
|
|
||||||
cover_image_name = None
|
|
||||||
|
|
||||||
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(r.board_id)
|
|
||||||
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
|
|
||||||
|
|
||||||
return board_dtos
|
|
@ -2,5 +2,5 @@
|
|||||||
Init file for InvokeAI configure package
|
Init file for InvokeAI configure package
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .config_base import PagingArgumentParser # noqa F401
|
from .base import PagingArgumentParser # noqa F401
|
||||||
from .config_default import InvokeAIAppConfig, get_invokeai_config # noqa F401
|
from .invokeai_config import InvokeAIAppConfig, get_invokeai_config # noqa F401
|
||||||
|
@ -12,6 +12,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
|
import pydoc
|
||||||
import sys
|
import sys
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -20,7 +21,16 @@ from typing import ClassVar, Dict, List, Literal, Optional, Union, get_args, get
|
|||||||
from omegaconf import DictConfig, ListConfig, OmegaConf
|
from omegaconf import DictConfig, ListConfig, OmegaConf
|
||||||
from pydantic import BaseSettings
|
from pydantic import BaseSettings
|
||||||
|
|
||||||
from invokeai.app.services.config.config_common import PagingArgumentParser, int_or_float_or_str
|
|
||||||
|
class PagingArgumentParser(argparse.ArgumentParser):
|
||||||
|
"""
|
||||||
|
A custom ArgumentParser that uses pydoc to page its output.
|
||||||
|
It also supports reading defaults from an init file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def print_help(self, file=None):
|
||||||
|
text = self.format_help()
|
||||||
|
pydoc.pager(text)
|
||||||
|
|
||||||
|
|
||||||
class InvokeAISettings(BaseSettings):
|
class InvokeAISettings(BaseSettings):
|
||||||
@ -213,3 +223,18 @@ class InvokeAISettings(BaseSettings):
|
|||||||
action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
|
action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
|
||||||
help=field.field_info.description,
|
help=field.field_info.description,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def int_or_float_or_str(value: str) -> Union[int, float, str]:
|
||||||
|
"""
|
||||||
|
Workaround for argparse type checking.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return int(value)
|
||||||
|
except Exception as e: # noqa F841
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
return float(value)
|
||||||
|
except Exception as e: # noqa F841
|
||||||
|
pass
|
||||||
|
return str(value)
|
@ -1,41 +0,0 @@
|
|||||||
# Copyright (c) 2023 Lincoln Stein (https://github.com/lstein) and the InvokeAI Development Team
|
|
||||||
|
|
||||||
"""
|
|
||||||
Base class for the InvokeAI configuration system.
|
|
||||||
It defines a type of pydantic BaseSettings object that
|
|
||||||
is able to read and write from an omegaconf-based config file,
|
|
||||||
with overriding of settings from environment variables and/or
|
|
||||||
the command line.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import pydoc
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
|
|
||||||
class PagingArgumentParser(argparse.ArgumentParser):
|
|
||||||
"""
|
|
||||||
A custom ArgumentParser that uses pydoc to page its output.
|
|
||||||
It also supports reading defaults from an init file.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def print_help(self, file=None):
|
|
||||||
text = self.format_help()
|
|
||||||
pydoc.pager(text)
|
|
||||||
|
|
||||||
|
|
||||||
def int_or_float_or_str(value: str) -> Union[int, float, str]:
|
|
||||||
"""
|
|
||||||
Workaround for argparse type checking.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
return int(value)
|
|
||||||
except Exception as e: # noqa F841
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
return float(value)
|
|
||||||
except Exception as e: # noqa F841
|
|
||||||
pass
|
|
||||||
return str(value)
|
|
@ -177,7 +177,7 @@ from typing import ClassVar, Dict, List, Literal, Optional, Union, get_type_hint
|
|||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from pydantic import Field, parse_obj_as
|
from pydantic import Field, parse_obj_as
|
||||||
|
|
||||||
from .config_base import InvokeAISettings
|
from .base import InvokeAISettings
|
||||||
|
|
||||||
INIT_FILE = Path("invokeai.yaml")
|
INIT_FILE = Path("invokeai.yaml")
|
||||||
DB_FILE = Path("invokeai.db")
|
DB_FILE = Path("invokeai.db")
|
@ -1,11 +1,10 @@
|
|||||||
from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC
|
from ..invocations.compel import CompelInvocation
|
||||||
|
from ..invocations.image import ImageNSFWBlurInvocation
|
||||||
from ...invocations.compel import CompelInvocation
|
from ..invocations.latent import DenoiseLatentsInvocation, LatentsToImageInvocation
|
||||||
from ...invocations.image import ImageNSFWBlurInvocation
|
from ..invocations.noise import NoiseInvocation
|
||||||
from ...invocations.latent import DenoiseLatentsInvocation, LatentsToImageInvocation
|
from ..invocations.primitives import IntegerInvocation
|
||||||
from ...invocations.noise import NoiseInvocation
|
|
||||||
from ...invocations.primitives import IntegerInvocation
|
|
||||||
from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph
|
from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph
|
||||||
|
from .item_storage import ItemStorageABC
|
||||||
|
|
||||||
default_text_to_image_graph_id = "539b2af5-2b4d-4d8c-8071-e54a3255fc74"
|
default_text_to_image_graph_id = "539b2af5-2b4d-4d8c-8071-e54a3255fc74"
|
||||||
|
|
@ -2,8 +2,8 @@
|
|||||||
|
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from invokeai.app.invocations.model import ModelInfo
|
from invokeai.app.models.image import ProgressImage
|
||||||
from invokeai.app.services.invocation_processor.invocation_processor_common import ProgressImage
|
from invokeai.app.services.model_manager_service import BaseModelType, ModelInfo, ModelType, SubModelType
|
||||||
from invokeai.app.services.session_queue.session_queue_common import (
|
from invokeai.app.services.session_queue.session_queue_common import (
|
||||||
BatchStatus,
|
BatchStatus,
|
||||||
EnqueueBatchResult,
|
EnqueueBatchResult,
|
||||||
@ -11,7 +11,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
|||||||
SessionQueueStatus,
|
SessionQueueStatus,
|
||||||
)
|
)
|
||||||
from invokeai.app.util.misc import get_timestamp
|
from invokeai.app.util.misc import get_timestamp
|
||||||
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
|
|
||||||
|
|
||||||
|
|
||||||
class EventServiceBase:
|
class EventServiceBase:
|
@ -8,9 +8,11 @@ import networkx as nx
|
|||||||
from pydantic import BaseModel, root_validator, validator
|
from pydantic import BaseModel, root_validator, validator
|
||||||
from pydantic.fields import Field
|
from pydantic.fields import Field
|
||||||
|
|
||||||
|
from invokeai.app.util.misc import uuid_string
|
||||||
|
|
||||||
# Importing * is bad karma but needed here for node detection
|
# Importing * is bad karma but needed here for node detection
|
||||||
from invokeai.app.invocations import * # noqa: F401 F403
|
from ..invocations import * # noqa: F401 F403
|
||||||
from invokeai.app.invocations.baseinvocation import (
|
from ..invocations.baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
Input,
|
Input,
|
||||||
@ -21,7 +23,6 @@ from invokeai.app.invocations.baseinvocation import (
|
|||||||
invocation,
|
invocation,
|
||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
from invokeai.app.util.misc import uuid_string
|
|
||||||
|
|
||||||
# in 3.10 this would be "from types import NoneType"
|
# in 3.10 this would be "from types import NoneType"
|
||||||
NoneType = type(None)
|
NoneType = type(None)
|
@ -1,5 +1,6 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||||
import json
|
import json
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from typing import Dict, Optional, Union
|
from typing import Dict, Optional, Union
|
||||||
@ -8,11 +9,68 @@ from PIL import Image, PngImagePlugin
|
|||||||
from PIL.Image import Image as PILImageType
|
from PIL.Image import Image as PILImageType
|
||||||
from send2trash import send2trash
|
from send2trash import send2trash
|
||||||
|
|
||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
|
||||||
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
||||||
|
|
||||||
from .image_files_base import ImageFileStorageBase
|
|
||||||
from .image_files_common import ImageFileDeleteException, ImageFileNotFoundException, ImageFileSaveException
|
# TODO: Should these excpetions subclass existing python exceptions?
|
||||||
|
class ImageFileNotFoundException(Exception):
|
||||||
|
"""Raised when an image file is not found in storage."""
|
||||||
|
|
||||||
|
def __init__(self, message="Image file not found"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageFileSaveException(Exception):
|
||||||
|
"""Raised when an image cannot be saved."""
|
||||||
|
|
||||||
|
def __init__(self, message="Image file not saved"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageFileDeleteException(Exception):
|
||||||
|
"""Raised when an image cannot be deleted."""
|
||||||
|
|
||||||
|
def __init__(self, message="Image file not deleted"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageFileStorageBase(ABC):
|
||||||
|
"""Low-level service responsible for storing and retrieving image files."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get(self, image_name: str) -> PILImageType:
|
||||||
|
"""Retrieves an image as PIL Image."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||||
|
"""Gets the internal path to an image or thumbnail."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
# TODO: We need to validate paths before starlette makes the FileResponse, else we get a
|
||||||
|
# 500 internal server error. I don't like having this method on the service.
|
||||||
|
@abstractmethod
|
||||||
|
def validate_path(self, path: str) -> bool:
|
||||||
|
"""Validates the path given for an image or thumbnail."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def save(
|
||||||
|
self,
|
||||||
|
image: PILImageType,
|
||||||
|
image_name: str,
|
||||||
|
metadata: Optional[dict] = None,
|
||||||
|
workflow: Optional[str] = None,
|
||||||
|
thumbnail_size: int = 256,
|
||||||
|
) -> None:
|
||||||
|
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete(self, image_name: str) -> None:
|
||||||
|
"""Deletes an image and its thumbnail (if one exists)."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class DiskImageFileStorage(ImageFileStorageBase):
|
class DiskImageFileStorage(ImageFileStorageBase):
|
||||||
@ -22,7 +80,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
__cache_ids: Queue # TODO: this is an incredibly naive cache
|
__cache_ids: Queue # TODO: this is an incredibly naive cache
|
||||||
__cache: Dict[Path, PILImageType]
|
__cache: Dict[Path, PILImageType]
|
||||||
__max_cache_size: int
|
__max_cache_size: int
|
||||||
__invoker: Invoker
|
__compress_level: int
|
||||||
|
|
||||||
def __init__(self, output_folder: Union[str, Path]):
|
def __init__(self, output_folder: Union[str, Path]):
|
||||||
self.__cache = dict()
|
self.__cache = dict()
|
||||||
@ -31,12 +89,10 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
|
|
||||||
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
||||||
self.__thumbnails_folder = self.__output_folder / "thumbnails"
|
self.__thumbnails_folder = self.__output_folder / "thumbnails"
|
||||||
|
self.__compress_level = InvokeAIAppConfig.get_config().png_compress_level
|
||||||
# Validate required output folders at launch
|
# Validate required output folders at launch
|
||||||
self.__validate_storage_folders()
|
self.__validate_storage_folders()
|
||||||
|
|
||||||
def start(self, invoker: Invoker) -> None:
|
|
||||||
self.__invoker = invoker
|
|
||||||
|
|
||||||
def get(self, image_name: str) -> PILImageType:
|
def get(self, image_name: str) -> PILImageType:
|
||||||
try:
|
try:
|
||||||
image_path = self.get_path(image_name)
|
image_path = self.get_path(image_name)
|
||||||
@ -80,12 +136,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
|
|||||||
if original_workflow is not None:
|
if original_workflow is not None:
|
||||||
pnginfo.add_text("invokeai_workflow", original_workflow)
|
pnginfo.add_text("invokeai_workflow", original_workflow)
|
||||||
|
|
||||||
image.save(
|
image.save(image_path, "PNG", pnginfo=pnginfo, compress_level=self.__compress_level)
|
||||||
image_path,
|
|
||||||
"PNG",
|
|
||||||
pnginfo=pnginfo,
|
|
||||||
compress_level=self.__invoker.services.configuration.png_compress_level,
|
|
||||||
)
|
|
||||||
|
|
||||||
thumbnail_name = get_thumbnail_name(image_name)
|
thumbnail_name = get_thumbnail_name(image_name)
|
||||||
thumbnail_path = self.get_path(thumbnail_name, thumbnail=True)
|
thumbnail_path = self.get_path(thumbnail_name, thumbnail=True)
|
@ -1,42 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from PIL.Image import Image as PILImageType
|
|
||||||
|
|
||||||
|
|
||||||
class ImageFileStorageBase(ABC):
|
|
||||||
"""Low-level service responsible for storing and retrieving image files."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get(self, image_name: str) -> PILImageType:
|
|
||||||
"""Retrieves an image as PIL Image."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
|
||||||
"""Gets the internal path to an image or thumbnail."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
# TODO: We need to validate paths before starlette makes the FileResponse, else we get a
|
|
||||||
# 500 internal server error. I don't like having this method on the service.
|
|
||||||
@abstractmethod
|
|
||||||
def validate_path(self, path: str) -> bool:
|
|
||||||
"""Validates the path given for an image or thumbnail."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def save(
|
|
||||||
self,
|
|
||||||
image: PILImageType,
|
|
||||||
image_name: str,
|
|
||||||
metadata: Optional[dict] = None,
|
|
||||||
workflow: Optional[str] = None,
|
|
||||||
thumbnail_size: int = 256,
|
|
||||||
) -> None:
|
|
||||||
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def delete(self, image_name: str) -> None:
|
|
||||||
"""Deletes an image and its thumbnail (if one exists)."""
|
|
||||||
pass
|
|
@ -1,20 +0,0 @@
|
|||||||
# TODO: Should these excpetions subclass existing python exceptions?
|
|
||||||
class ImageFileNotFoundException(Exception):
|
|
||||||
"""Raised when an image file is not found in storage."""
|
|
||||||
|
|
||||||
def __init__(self, message="Image file not found"):
|
|
||||||
super().__init__(message)
|
|
||||||
|
|
||||||
|
|
||||||
class ImageFileSaveException(Exception):
|
|
||||||
"""Raised when an image cannot be saved."""
|
|
||||||
|
|
||||||
def __init__(self, message="Image file not saved"):
|
|
||||||
super().__init__(message)
|
|
||||||
|
|
||||||
|
|
||||||
class ImageFileDeleteException(Exception):
|
|
||||||
"""Raised when an image cannot be deleted."""
|
|
||||||
|
|
||||||
def __init__(self, message="Image file not deleted"):
|
|
||||||
super().__init__(message)
|
|
@ -1,36 +1,164 @@
|
|||||||
import json
|
import json
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import threading
|
import threading
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, cast
|
from typing import Generic, Optional, TypeVar, cast
|
||||||
|
|
||||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
from pydantic import BaseModel, Field
|
||||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
from pydantic.generics import GenericModel
|
||||||
|
|
||||||
from .image_records_base import ImageRecordStorageBase
|
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||||
from .image_records_common import (
|
from invokeai.app.services.models.image_record import ImageRecord, ImageRecordChanges, deserialize_image_record
|
||||||
IMAGE_DTO_COLS,
|
|
||||||
ImageCategory,
|
T = TypeVar("T", bound=BaseModel)
|
||||||
ImageRecord,
|
|
||||||
ImageRecordChanges,
|
|
||||||
ImageRecordDeleteException,
|
class OffsetPaginatedResults(GenericModel, Generic[T]):
|
||||||
ImageRecordNotFoundException,
|
"""Offset-paginated results"""
|
||||||
ImageRecordSaveException,
|
|
||||||
ResourceOrigin,
|
# fmt: off
|
||||||
deserialize_image_record,
|
items: list[T] = Field(description="Items")
|
||||||
|
offset: int = Field(description="Offset from which to retrieve items")
|
||||||
|
limit: int = Field(description="Limit of items to get")
|
||||||
|
total: int = Field(description="Total number of items in result")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Should these excpetions subclass existing python exceptions?
|
||||||
|
class ImageRecordNotFoundException(Exception):
|
||||||
|
"""Raised when an image record is not found."""
|
||||||
|
|
||||||
|
def __init__(self, message="Image record not found"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageRecordSaveException(Exception):
|
||||||
|
"""Raised when an image record cannot be saved."""
|
||||||
|
|
||||||
|
def __init__(self, message="Image record not saved"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageRecordDeleteException(Exception):
|
||||||
|
"""Raised when an image record cannot be deleted."""
|
||||||
|
|
||||||
|
def __init__(self, message="Image record not deleted"):
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
IMAGE_DTO_COLS = ", ".join(
|
||||||
|
list(
|
||||||
|
map(
|
||||||
|
lambda c: "images." + c,
|
||||||
|
[
|
||||||
|
"image_name",
|
||||||
|
"image_origin",
|
||||||
|
"image_category",
|
||||||
|
"width",
|
||||||
|
"height",
|
||||||
|
"session_id",
|
||||||
|
"node_id",
|
||||||
|
"is_intermediate",
|
||||||
|
"created_at",
|
||||||
|
"updated_at",
|
||||||
|
"deleted_at",
|
||||||
|
"starred",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageRecordStorageBase(ABC):
|
||||||
|
"""Low-level service responsible for interfacing with the image record store."""
|
||||||
|
|
||||||
|
# TODO: Implement an `update()` method
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get(self, image_name: str) -> ImageRecord:
|
||||||
|
"""Gets an image record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_metadata(self, image_name: str) -> Optional[dict]:
|
||||||
|
"""Gets an image's metadata'."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
image_name: str,
|
||||||
|
changes: ImageRecordChanges,
|
||||||
|
) -> None:
|
||||||
|
"""Updates an image record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_many(
|
||||||
|
self,
|
||||||
|
offset: Optional[int] = None,
|
||||||
|
limit: Optional[int] = None,
|
||||||
|
image_origin: Optional[ResourceOrigin] = None,
|
||||||
|
categories: Optional[list[ImageCategory]] = None,
|
||||||
|
is_intermediate: Optional[bool] = None,
|
||||||
|
board_id: Optional[str] = None,
|
||||||
|
) -> OffsetPaginatedResults[ImageRecord]:
|
||||||
|
"""Gets a page of image records."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
# TODO: The database has a nullable `deleted_at` column, currently unused.
|
||||||
|
# Should we implement soft deletes? Would need coordination with ImageFileStorage.
|
||||||
|
@abstractmethod
|
||||||
|
def delete(self, image_name: str) -> None:
|
||||||
|
"""Deletes an image record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete_many(self, image_names: list[str]) -> None:
|
||||||
|
"""Deletes many image records."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete_intermediates(self) -> list[str]:
|
||||||
|
"""Deletes all intermediate image records, returning a list of deleted image names."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def save(
|
||||||
|
self,
|
||||||
|
image_name: str,
|
||||||
|
image_origin: ResourceOrigin,
|
||||||
|
image_category: ImageCategory,
|
||||||
|
width: int,
|
||||||
|
height: int,
|
||||||
|
session_id: Optional[str],
|
||||||
|
node_id: Optional[str],
|
||||||
|
metadata: Optional[dict],
|
||||||
|
is_intermediate: bool = False,
|
||||||
|
starred: bool = False,
|
||||||
|
) -> datetime:
|
||||||
|
"""Saves an image record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
|
||||||
|
"""Gets the most recent image for a board."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class SqliteImageRecordStorage(ImageRecordStorageBase):
|
class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||||
_conn: sqlite3.Connection
|
_conn: sqlite3.Connection
|
||||||
_cursor: sqlite3.Cursor
|
_cursor: sqlite3.Cursor
|
||||||
_lock: threading.Lock
|
_lock: threading.Lock
|
||||||
|
|
||||||
def __init__(self, db: SqliteDatabase) -> None:
|
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._lock = db.lock
|
self._conn = conn
|
||||||
self._conn = db.conn
|
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||||
|
self._conn.row_factory = sqlite3.Row
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
|
self._lock = lock
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._lock.acquire()
|
self._lock.acquire()
|
@ -1,84 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
|
||||||
|
|
||||||
from .image_records_common import ImageCategory, ImageRecord, ImageRecordChanges, ResourceOrigin
|
|
||||||
|
|
||||||
|
|
||||||
class ImageRecordStorageBase(ABC):
|
|
||||||
"""Low-level service responsible for interfacing with the image record store."""
|
|
||||||
|
|
||||||
# TODO: Implement an `update()` method
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get(self, image_name: str) -> ImageRecord:
|
|
||||||
"""Gets an image record."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_metadata(self, image_name: str) -> Optional[dict]:
|
|
||||||
"""Gets an image's metadata'."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def update(
|
|
||||||
self,
|
|
||||||
image_name: str,
|
|
||||||
changes: ImageRecordChanges,
|
|
||||||
) -> None:
|
|
||||||
"""Updates an image record."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_many(
|
|
||||||
self,
|
|
||||||
offset: Optional[int] = None,
|
|
||||||
limit: Optional[int] = None,
|
|
||||||
image_origin: Optional[ResourceOrigin] = None,
|
|
||||||
categories: Optional[list[ImageCategory]] = None,
|
|
||||||
is_intermediate: Optional[bool] = None,
|
|
||||||
board_id: Optional[str] = None,
|
|
||||||
) -> OffsetPaginatedResults[ImageRecord]:
|
|
||||||
"""Gets a page of image records."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
# TODO: The database has a nullable `deleted_at` column, currently unused.
|
|
||||||
# Should we implement soft deletes? Would need coordination with ImageFileStorage.
|
|
||||||
@abstractmethod
|
|
||||||
def delete(self, image_name: str) -> None:
|
|
||||||
"""Deletes an image record."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def delete_many(self, image_names: list[str]) -> None:
|
|
||||||
"""Deletes many image records."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def delete_intermediates(self) -> list[str]:
|
|
||||||
"""Deletes all intermediate image records, returning a list of deleted image names."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def save(
|
|
||||||
self,
|
|
||||||
image_name: str,
|
|
||||||
image_origin: ResourceOrigin,
|
|
||||||
image_category: ImageCategory,
|
|
||||||
width: int,
|
|
||||||
height: int,
|
|
||||||
session_id: Optional[str],
|
|
||||||
node_id: Optional[str],
|
|
||||||
metadata: Optional[dict],
|
|
||||||
is_intermediate: bool = False,
|
|
||||||
starred: bool = False,
|
|
||||||
) -> datetime:
|
|
||||||
"""Saves an image record."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
|
|
||||||
"""Gets the most recent image for a board."""
|
|
||||||
pass
|
|
449
invokeai/app/services/images.py
Normal file
449
invokeai/app/services/images.py
Normal file
@ -0,0 +1,449 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from logging import Logger
|
||||||
|
from typing import TYPE_CHECKING, Callable, Optional
|
||||||
|
|
||||||
|
from PIL.Image import Image as PILImageType
|
||||||
|
|
||||||
|
from invokeai.app.invocations.metadata import ImageMetadata
|
||||||
|
from invokeai.app.models.image import (
|
||||||
|
ImageCategory,
|
||||||
|
InvalidImageCategoryException,
|
||||||
|
InvalidOriginException,
|
||||||
|
ResourceOrigin,
|
||||||
|
)
|
||||||
|
from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
|
||||||
|
from invokeai.app.services.image_file_storage import (
|
||||||
|
ImageFileDeleteException,
|
||||||
|
ImageFileNotFoundException,
|
||||||
|
ImageFileSaveException,
|
||||||
|
ImageFileStorageBase,
|
||||||
|
)
|
||||||
|
from invokeai.app.services.image_record_storage import (
|
||||||
|
ImageRecordDeleteException,
|
||||||
|
ImageRecordNotFoundException,
|
||||||
|
ImageRecordSaveException,
|
||||||
|
ImageRecordStorageBase,
|
||||||
|
OffsetPaginatedResults,
|
||||||
|
)
|
||||||
|
from invokeai.app.services.item_storage import ItemStorageABC
|
||||||
|
from invokeai.app.services.models.image_record import ImageDTO, ImageRecord, ImageRecordChanges, image_record_to_dto
|
||||||
|
from invokeai.app.services.resource_name import NameServiceBase
|
||||||
|
from invokeai.app.services.urls import UrlServiceBase
|
||||||
|
from invokeai.app.util.metadata import get_metadata_graph_from_raw_session
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from invokeai.app.services.graph import GraphExecutionState
|
||||||
|
|
||||||
|
|
||||||
|
class ImageServiceABC(ABC):
|
||||||
|
"""High-level service for image management."""
|
||||||
|
|
||||||
|
_on_changed_callbacks: list[Callable[[ImageDTO], None]]
|
||||||
|
_on_deleted_callbacks: list[Callable[[str], None]]
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._on_changed_callbacks = list()
|
||||||
|
self._on_deleted_callbacks = list()
|
||||||
|
|
||||||
|
def on_changed(self, on_changed: Callable[[ImageDTO], None]) -> None:
|
||||||
|
"""Register a callback for when an image is changed"""
|
||||||
|
self._on_changed_callbacks.append(on_changed)
|
||||||
|
|
||||||
|
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
|
||||||
|
"""Register a callback for when an image is deleted"""
|
||||||
|
self._on_deleted_callbacks.append(on_deleted)
|
||||||
|
|
||||||
|
def _on_changed(self, item: ImageDTO) -> None:
|
||||||
|
for callback in self._on_changed_callbacks:
|
||||||
|
callback(item)
|
||||||
|
|
||||||
|
def _on_deleted(self, item_id: str) -> None:
|
||||||
|
for callback in self._on_deleted_callbacks:
|
||||||
|
callback(item_id)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create(
|
||||||
|
self,
|
||||||
|
image: PILImageType,
|
||||||
|
image_origin: ResourceOrigin,
|
||||||
|
image_category: ImageCategory,
|
||||||
|
node_id: Optional[str] = None,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
board_id: Optional[str] = None,
|
||||||
|
is_intermediate: bool = False,
|
||||||
|
metadata: Optional[dict] = None,
|
||||||
|
workflow: Optional[str] = None,
|
||||||
|
) -> ImageDTO:
|
||||||
|
"""Creates an image, storing the file and its metadata."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
image_name: str,
|
||||||
|
changes: ImageRecordChanges,
|
||||||
|
) -> ImageDTO:
|
||||||
|
"""Updates an image."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_pil_image(self, image_name: str) -> PILImageType:
|
||||||
|
"""Gets an image as a PIL image."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_record(self, image_name: str) -> ImageRecord:
|
||||||
|
"""Gets an image record."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_dto(self, image_name: str) -> ImageDTO:
|
||||||
|
"""Gets an image DTO."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_metadata(self, image_name: str) -> ImageMetadata:
|
||||||
|
"""Gets an image's metadata."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||||
|
"""Gets an image's path."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def validate_path(self, path: str) -> bool:
|
||||||
|
"""Validates an image's path."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||||
|
"""Gets an image's or thumbnail's URL."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_many(
|
||||||
|
self,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int = 10,
|
||||||
|
image_origin: Optional[ResourceOrigin] = None,
|
||||||
|
categories: Optional[list[ImageCategory]] = None,
|
||||||
|
is_intermediate: Optional[bool] = None,
|
||||||
|
board_id: Optional[str] = None,
|
||||||
|
) -> OffsetPaginatedResults[ImageDTO]:
|
||||||
|
"""Gets a paginated list of image DTOs."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete(self, image_name: str):
|
||||||
|
"""Deletes an image."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete_intermediates(self) -> int:
|
||||||
|
"""Deletes all intermediate images."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete_images_on_board(self, board_id: str):
|
||||||
|
"""Deletes all images on a board."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ImageServiceDependencies:
|
||||||
|
"""Service dependencies for the ImageService."""
|
||||||
|
|
||||||
|
image_records: ImageRecordStorageBase
|
||||||
|
image_files: ImageFileStorageBase
|
||||||
|
board_image_records: BoardImageRecordStorageBase
|
||||||
|
urls: UrlServiceBase
|
||||||
|
logger: Logger
|
||||||
|
names: NameServiceBase
|
||||||
|
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
image_record_storage: ImageRecordStorageBase,
|
||||||
|
image_file_storage: ImageFileStorageBase,
|
||||||
|
board_image_record_storage: BoardImageRecordStorageBase,
|
||||||
|
url: UrlServiceBase,
|
||||||
|
logger: Logger,
|
||||||
|
names: NameServiceBase,
|
||||||
|
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||||
|
):
|
||||||
|
self.image_records = image_record_storage
|
||||||
|
self.image_files = image_file_storage
|
||||||
|
self.board_image_records = board_image_record_storage
|
||||||
|
self.urls = url
|
||||||
|
self.logger = logger
|
||||||
|
self.names = names
|
||||||
|
self.graph_execution_manager = graph_execution_manager
|
||||||
|
|
||||||
|
|
||||||
|
class ImageService(ImageServiceABC):
|
||||||
|
_services: ImageServiceDependencies
|
||||||
|
|
||||||
|
def __init__(self, services: ImageServiceDependencies):
|
||||||
|
super().__init__()
|
||||||
|
self._services = services
|
||||||
|
|
||||||
|
def create(
|
||||||
|
self,
|
||||||
|
image: PILImageType,
|
||||||
|
image_origin: ResourceOrigin,
|
||||||
|
image_category: ImageCategory,
|
||||||
|
node_id: Optional[str] = None,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
board_id: Optional[str] = None,
|
||||||
|
is_intermediate: bool = False,
|
||||||
|
metadata: Optional[dict] = None,
|
||||||
|
workflow: Optional[str] = None,
|
||||||
|
) -> ImageDTO:
|
||||||
|
if image_origin not in ResourceOrigin:
|
||||||
|
raise InvalidOriginException
|
||||||
|
|
||||||
|
if image_category not in ImageCategory:
|
||||||
|
raise InvalidImageCategoryException
|
||||||
|
|
||||||
|
image_name = self._services.names.create_image_name()
|
||||||
|
|
||||||
|
# TODO: Do we want to store the graph in the image at all? I don't think so...
|
||||||
|
# graph = None
|
||||||
|
# if session_id is not None:
|
||||||
|
# session_raw = self._services.graph_execution_manager.get_raw(session_id)
|
||||||
|
# if session_raw is not None:
|
||||||
|
# try:
|
||||||
|
# graph = get_metadata_graph_from_raw_session(session_raw)
|
||||||
|
# except Exception as e:
|
||||||
|
# self._services.logger.warn(f"Failed to parse session graph: {e}")
|
||||||
|
# graph = None
|
||||||
|
|
||||||
|
(width, height) = image.size
|
||||||
|
|
||||||
|
try:
|
||||||
|
# TODO: Consider using a transaction here to ensure consistency between storage and database
|
||||||
|
self._services.image_records.save(
|
||||||
|
# Non-nullable fields
|
||||||
|
image_name=image_name,
|
||||||
|
image_origin=image_origin,
|
||||||
|
image_category=image_category,
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
# Meta fields
|
||||||
|
is_intermediate=is_intermediate,
|
||||||
|
# Nullable fields
|
||||||
|
node_id=node_id,
|
||||||
|
metadata=metadata,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
if board_id is not None:
|
||||||
|
self._services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
|
||||||
|
self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, workflow=workflow)
|
||||||
|
image_dto = self.get_dto(image_name)
|
||||||
|
|
||||||
|
self._on_changed(image_dto)
|
||||||
|
return image_dto
|
||||||
|
except ImageRecordSaveException:
|
||||||
|
self._services.logger.error("Failed to save image record")
|
||||||
|
raise
|
||||||
|
except ImageFileSaveException:
|
||||||
|
self._services.logger.error("Failed to save image file")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self._services.logger.error(f"Problem saving image record and file: {str(e)}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
image_name: str,
|
||||||
|
changes: ImageRecordChanges,
|
||||||
|
) -> ImageDTO:
|
||||||
|
try:
|
||||||
|
self._services.image_records.update(image_name, changes)
|
||||||
|
image_dto = self.get_dto(image_name)
|
||||||
|
self._on_changed(image_dto)
|
||||||
|
return image_dto
|
||||||
|
except ImageRecordSaveException:
|
||||||
|
self._services.logger.error("Failed to update image record")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self._services.logger.error("Problem updating image record")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def get_pil_image(self, image_name: str) -> PILImageType:
|
||||||
|
try:
|
||||||
|
return self._services.image_files.get(image_name)
|
||||||
|
except ImageFileNotFoundException:
|
||||||
|
self._services.logger.error("Failed to get image file")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self._services.logger.error("Problem getting image file")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def get_record(self, image_name: str) -> ImageRecord:
|
||||||
|
try:
|
||||||
|
return self._services.image_records.get(image_name)
|
||||||
|
except ImageRecordNotFoundException:
|
||||||
|
self._services.logger.error("Image record not found")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self._services.logger.error("Problem getting image record")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def get_dto(self, image_name: str) -> ImageDTO:
|
||||||
|
try:
|
||||||
|
image_record = self._services.image_records.get(image_name)
|
||||||
|
|
||||||
|
image_dto = image_record_to_dto(
|
||||||
|
image_record,
|
||||||
|
self._services.urls.get_image_url(image_name),
|
||||||
|
self._services.urls.get_image_url(image_name, True),
|
||||||
|
self._services.board_image_records.get_board_for_image(image_name),
|
||||||
|
)
|
||||||
|
|
||||||
|
return image_dto
|
||||||
|
except ImageRecordNotFoundException:
|
||||||
|
self._services.logger.error("Image record not found")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self._services.logger.error("Problem getting image DTO")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def get_metadata(self, image_name: str) -> Optional[ImageMetadata]:
|
||||||
|
try:
|
||||||
|
image_record = self._services.image_records.get(image_name)
|
||||||
|
metadata = self._services.image_records.get_metadata(image_name)
|
||||||
|
|
||||||
|
if not image_record.session_id:
|
||||||
|
return ImageMetadata(metadata=metadata)
|
||||||
|
|
||||||
|
session_raw = self._services.graph_execution_manager.get_raw(image_record.session_id)
|
||||||
|
graph = None
|
||||||
|
|
||||||
|
if session_raw:
|
||||||
|
try:
|
||||||
|
graph = get_metadata_graph_from_raw_session(session_raw)
|
||||||
|
except Exception as e:
|
||||||
|
self._services.logger.warn(f"Failed to parse session graph: {e}")
|
||||||
|
graph = None
|
||||||
|
|
||||||
|
return ImageMetadata(graph=graph, metadata=metadata)
|
||||||
|
except ImageRecordNotFoundException:
|
||||||
|
self._services.logger.error("Image record not found")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self._services.logger.error("Problem getting image DTO")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
||||||
|
try:
|
||||||
|
return self._services.image_files.get_path(image_name, thumbnail)
|
||||||
|
except Exception as e:
|
||||||
|
self._services.logger.error("Problem getting image path")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def validate_path(self, path: str) -> bool:
|
||||||
|
try:
|
||||||
|
return self._services.image_files.validate_path(path)
|
||||||
|
except Exception as e:
|
||||||
|
self._services.logger.error("Problem validating image path")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def get_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||||
|
try:
|
||||||
|
return self._services.urls.get_image_url(image_name, thumbnail)
|
||||||
|
except Exception as e:
|
||||||
|
self._services.logger.error("Problem getting image path")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def get_many(
|
||||||
|
self,
|
||||||
|
offset: int = 0,
|
||||||
|
limit: int = 10,
|
||||||
|
image_origin: Optional[ResourceOrigin] = None,
|
||||||
|
categories: Optional[list[ImageCategory]] = None,
|
||||||
|
is_intermediate: Optional[bool] = None,
|
||||||
|
board_id: Optional[str] = None,
|
||||||
|
) -> OffsetPaginatedResults[ImageDTO]:
|
||||||
|
try:
|
||||||
|
results = self._services.image_records.get_many(
|
||||||
|
offset,
|
||||||
|
limit,
|
||||||
|
image_origin,
|
||||||
|
categories,
|
||||||
|
is_intermediate,
|
||||||
|
board_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
image_dtos = list(
|
||||||
|
map(
|
||||||
|
lambda r: image_record_to_dto(
|
||||||
|
r,
|
||||||
|
self._services.urls.get_image_url(r.image_name),
|
||||||
|
self._services.urls.get_image_url(r.image_name, True),
|
||||||
|
self._services.board_image_records.get_board_for_image(r.image_name),
|
||||||
|
),
|
||||||
|
results.items,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return OffsetPaginatedResults[ImageDTO](
|
||||||
|
items=image_dtos,
|
||||||
|
offset=results.offset,
|
||||||
|
limit=results.limit,
|
||||||
|
total=results.total,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
self._services.logger.error("Problem getting paginated image DTOs")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def delete(self, image_name: str):
|
||||||
|
try:
|
||||||
|
self._services.image_files.delete(image_name)
|
||||||
|
self._services.image_records.delete(image_name)
|
||||||
|
self._on_deleted(image_name)
|
||||||
|
except ImageRecordDeleteException:
|
||||||
|
self._services.logger.error("Failed to delete image record")
|
||||||
|
raise
|
||||||
|
except ImageFileDeleteException:
|
||||||
|
self._services.logger.error("Failed to delete image file")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self._services.logger.error("Problem deleting image record and file")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def delete_images_on_board(self, board_id: str):
|
||||||
|
try:
|
||||||
|
image_names = self._services.board_image_records.get_all_board_image_names_for_board(board_id)
|
||||||
|
for image_name in image_names:
|
||||||
|
self._services.image_files.delete(image_name)
|
||||||
|
self._services.image_records.delete_many(image_names)
|
||||||
|
for image_name in image_names:
|
||||||
|
self._on_deleted(image_name)
|
||||||
|
except ImageRecordDeleteException:
|
||||||
|
self._services.logger.error("Failed to delete image records")
|
||||||
|
raise
|
||||||
|
except ImageFileDeleteException:
|
||||||
|
self._services.logger.error("Failed to delete image files")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self._services.logger.error("Problem deleting image records and files")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def delete_intermediates(self) -> int:
|
||||||
|
try:
|
||||||
|
image_names = self._services.image_records.delete_intermediates()
|
||||||
|
count = len(image_names)
|
||||||
|
for image_name in image_names:
|
||||||
|
self._services.image_files.delete(image_name)
|
||||||
|
self._on_deleted(image_name)
|
||||||
|
return count
|
||||||
|
except ImageRecordDeleteException:
|
||||||
|
self._services.logger.error("Failed to delete image records")
|
||||||
|
raise
|
||||||
|
except ImageFileDeleteException:
|
||||||
|
self._services.logger.error("Failed to delete image files")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self._services.logger.error("Problem deleting image records and files")
|
||||||
|
raise e
|
@ -1,129 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Callable, Optional
|
|
||||||
|
|
||||||
from PIL.Image import Image as PILImageType
|
|
||||||
|
|
||||||
from invokeai.app.invocations.metadata import ImageMetadata
|
|
||||||
from invokeai.app.services.image_records.image_records_common import (
|
|
||||||
ImageCategory,
|
|
||||||
ImageRecord,
|
|
||||||
ImageRecordChanges,
|
|
||||||
ResourceOrigin,
|
|
||||||
)
|
|
||||||
from invokeai.app.services.images.images_common import ImageDTO
|
|
||||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
|
||||||
|
|
||||||
|
|
||||||
class ImageServiceABC(ABC):
|
|
||||||
"""High-level service for image management."""
|
|
||||||
|
|
||||||
_on_changed_callbacks: list[Callable[[ImageDTO], None]]
|
|
||||||
_on_deleted_callbacks: list[Callable[[str], None]]
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._on_changed_callbacks = list()
|
|
||||||
self._on_deleted_callbacks = list()
|
|
||||||
|
|
||||||
def on_changed(self, on_changed: Callable[[ImageDTO], None]) -> None:
|
|
||||||
"""Register a callback for when an image is changed"""
|
|
||||||
self._on_changed_callbacks.append(on_changed)
|
|
||||||
|
|
||||||
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
|
|
||||||
"""Register a callback for when an image is deleted"""
|
|
||||||
self._on_deleted_callbacks.append(on_deleted)
|
|
||||||
|
|
||||||
def _on_changed(self, item: ImageDTO) -> None:
|
|
||||||
for callback in self._on_changed_callbacks:
|
|
||||||
callback(item)
|
|
||||||
|
|
||||||
def _on_deleted(self, item_id: str) -> None:
|
|
||||||
for callback in self._on_deleted_callbacks:
|
|
||||||
callback(item_id)
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def create(
|
|
||||||
self,
|
|
||||||
image: PILImageType,
|
|
||||||
image_origin: ResourceOrigin,
|
|
||||||
image_category: ImageCategory,
|
|
||||||
node_id: Optional[str] = None,
|
|
||||||
session_id: Optional[str] = None,
|
|
||||||
board_id: Optional[str] = None,
|
|
||||||
is_intermediate: bool = False,
|
|
||||||
metadata: Optional[dict] = None,
|
|
||||||
workflow: Optional[str] = None,
|
|
||||||
) -> ImageDTO:
|
|
||||||
"""Creates an image, storing the file and its metadata."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def update(
|
|
||||||
self,
|
|
||||||
image_name: str,
|
|
||||||
changes: ImageRecordChanges,
|
|
||||||
) -> ImageDTO:
|
|
||||||
"""Updates an image."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_pil_image(self, image_name: str) -> PILImageType:
|
|
||||||
"""Gets an image as a PIL image."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_record(self, image_name: str) -> ImageRecord:
|
|
||||||
"""Gets an image record."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_dto(self, image_name: str) -> ImageDTO:
|
|
||||||
"""Gets an image DTO."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_metadata(self, image_name: str) -> ImageMetadata:
|
|
||||||
"""Gets an image's metadata."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
|
||||||
"""Gets an image's path."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def validate_path(self, path: str) -> bool:
|
|
||||||
"""Validates an image's path."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_url(self, image_name: str, thumbnail: bool = False) -> str:
|
|
||||||
"""Gets an image's or thumbnail's URL."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_many(
|
|
||||||
self,
|
|
||||||
offset: int = 0,
|
|
||||||
limit: int = 10,
|
|
||||||
image_origin: Optional[ResourceOrigin] = None,
|
|
||||||
categories: Optional[list[ImageCategory]] = None,
|
|
||||||
is_intermediate: Optional[bool] = None,
|
|
||||||
board_id: Optional[str] = None,
|
|
||||||
) -> OffsetPaginatedResults[ImageDTO]:
|
|
||||||
"""Gets a paginated list of image DTOs."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def delete(self, image_name: str):
|
|
||||||
"""Deletes an image."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def delete_intermediates(self) -> int:
|
|
||||||
"""Deletes all intermediate images."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def delete_images_on_board(self, board_id: str):
|
|
||||||
"""Deletes all images on a board."""
|
|
||||||
pass
|
|
@ -1,41 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
from invokeai.app.services.image_records.image_records_common import ImageRecord
|
|
||||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
|
||||||
|
|
||||||
|
|
||||||
class ImageUrlsDTO(BaseModelExcludeNull):
|
|
||||||
"""The URLs for an image and its thumbnail."""
|
|
||||||
|
|
||||||
image_name: str = Field(description="The unique name of the image.")
|
|
||||||
"""The unique name of the image."""
|
|
||||||
image_url: str = Field(description="The URL of the image.")
|
|
||||||
"""The URL of the image."""
|
|
||||||
thumbnail_url: str = Field(description="The URL of the image's thumbnail.")
|
|
||||||
"""The URL of the image's thumbnail."""
|
|
||||||
|
|
||||||
|
|
||||||
class ImageDTO(ImageRecord, ImageUrlsDTO):
|
|
||||||
"""Deserialized image record, enriched for the frontend."""
|
|
||||||
|
|
||||||
board_id: Optional[str] = Field(description="The id of the board the image belongs to, if one exists.")
|
|
||||||
"""The id of the board the image belongs to, if one exists."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def image_record_to_dto(
|
|
||||||
image_record: ImageRecord,
|
|
||||||
image_url: str,
|
|
||||||
thumbnail_url: str,
|
|
||||||
board_id: Optional[str],
|
|
||||||
) -> ImageDTO:
|
|
||||||
"""Converts an image record to an image DTO."""
|
|
||||||
return ImageDTO(
|
|
||||||
**image_record.dict(),
|
|
||||||
image_url=image_url,
|
|
||||||
thumbnail_url=thumbnail_url,
|
|
||||||
board_id=board_id,
|
|
||||||
)
|
|
@ -1,286 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
|
|
||||||
from PIL.Image import Image as PILImageType
|
|
||||||
|
|
||||||
from invokeai.app.invocations.metadata import ImageMetadata
|
|
||||||
from invokeai.app.services.invoker import Invoker
|
|
||||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
|
||||||
from invokeai.app.util.metadata import get_metadata_graph_from_raw_session
|
|
||||||
|
|
||||||
from ..image_files.image_files_common import (
|
|
||||||
ImageFileDeleteException,
|
|
||||||
ImageFileNotFoundException,
|
|
||||||
ImageFileSaveException,
|
|
||||||
)
|
|
||||||
from ..image_records.image_records_common import (
|
|
||||||
ImageCategory,
|
|
||||||
ImageRecord,
|
|
||||||
ImageRecordChanges,
|
|
||||||
ImageRecordDeleteException,
|
|
||||||
ImageRecordNotFoundException,
|
|
||||||
ImageRecordSaveException,
|
|
||||||
InvalidImageCategoryException,
|
|
||||||
InvalidOriginException,
|
|
||||||
ResourceOrigin,
|
|
||||||
)
|
|
||||||
from .images_base import ImageServiceABC
|
|
||||||
from .images_common import ImageDTO, image_record_to_dto
|
|
||||||
|
|
||||||
|
|
||||||
class ImageService(ImageServiceABC):
|
|
||||||
__invoker: Invoker
|
|
||||||
|
|
||||||
def start(self, invoker: Invoker) -> None:
|
|
||||||
self.__invoker = invoker
|
|
||||||
|
|
||||||
def create(
|
|
||||||
self,
|
|
||||||
image: PILImageType,
|
|
||||||
image_origin: ResourceOrigin,
|
|
||||||
image_category: ImageCategory,
|
|
||||||
node_id: Optional[str] = None,
|
|
||||||
session_id: Optional[str] = None,
|
|
||||||
board_id: Optional[str] = None,
|
|
||||||
is_intermediate: bool = False,
|
|
||||||
metadata: Optional[dict] = None,
|
|
||||||
workflow: Optional[str] = None,
|
|
||||||
) -> ImageDTO:
|
|
||||||
if image_origin not in ResourceOrigin:
|
|
||||||
raise InvalidOriginException
|
|
||||||
|
|
||||||
if image_category not in ImageCategory:
|
|
||||||
raise InvalidImageCategoryException
|
|
||||||
|
|
||||||
image_name = self.__invoker.services.names.create_image_name()
|
|
||||||
|
|
||||||
(width, height) = image.size
|
|
||||||
|
|
||||||
try:
|
|
||||||
# TODO: Consider using a transaction here to ensure consistency between storage and database
|
|
||||||
self.__invoker.services.image_records.save(
|
|
||||||
# Non-nullable fields
|
|
||||||
image_name=image_name,
|
|
||||||
image_origin=image_origin,
|
|
||||||
image_category=image_category,
|
|
||||||
width=width,
|
|
||||||
height=height,
|
|
||||||
# Meta fields
|
|
||||||
is_intermediate=is_intermediate,
|
|
||||||
# Nullable fields
|
|
||||||
node_id=node_id,
|
|
||||||
metadata=metadata,
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
if board_id is not None:
|
|
||||||
self.__invoker.services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
|
|
||||||
self.__invoker.services.image_files.save(
|
|
||||||
image_name=image_name, image=image, metadata=metadata, workflow=workflow
|
|
||||||
)
|
|
||||||
image_dto = self.get_dto(image_name)
|
|
||||||
|
|
||||||
self._on_changed(image_dto)
|
|
||||||
return image_dto
|
|
||||||
except ImageRecordSaveException:
|
|
||||||
self.__invoker.services.logger.error("Failed to save image record")
|
|
||||||
raise
|
|
||||||
except ImageFileSaveException:
|
|
||||||
self.__invoker.services.logger.error("Failed to save image file")
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
self.__invoker.services.logger.error(f"Problem saving image record and file: {str(e)}")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def update(
|
|
||||||
self,
|
|
||||||
image_name: str,
|
|
||||||
changes: ImageRecordChanges,
|
|
||||||
) -> ImageDTO:
|
|
||||||
try:
|
|
||||||
self.__invoker.services.image_records.update(image_name, changes)
|
|
||||||
image_dto = self.get_dto(image_name)
|
|
||||||
self._on_changed(image_dto)
|
|
||||||
return image_dto
|
|
||||||
except ImageRecordSaveException:
|
|
||||||
self.__invoker.services.logger.error("Failed to update image record")
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
self.__invoker.services.logger.error("Problem updating image record")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def get_pil_image(self, image_name: str) -> PILImageType:
|
|
||||||
try:
|
|
||||||
return self.__invoker.services.image_files.get(image_name)
|
|
||||||
except ImageFileNotFoundException:
|
|
||||||
self.__invoker.services.logger.error("Failed to get image file")
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
self.__invoker.services.logger.error("Problem getting image file")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def get_record(self, image_name: str) -> ImageRecord:
|
|
||||||
try:
|
|
||||||
return self.__invoker.services.image_records.get(image_name)
|
|
||||||
except ImageRecordNotFoundException:
|
|
||||||
self.__invoker.services.logger.error("Image record not found")
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
self.__invoker.services.logger.error("Problem getting image record")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def get_dto(self, image_name: str) -> ImageDTO:
|
|
||||||
try:
|
|
||||||
image_record = self.__invoker.services.image_records.get(image_name)
|
|
||||||
|
|
||||||
image_dto = image_record_to_dto(
|
|
||||||
image_record,
|
|
||||||
self.__invoker.services.urls.get_image_url(image_name),
|
|
||||||
self.__invoker.services.urls.get_image_url(image_name, True),
|
|
||||||
self.__invoker.services.board_image_records.get_board_for_image(image_name),
|
|
||||||
)
|
|
||||||
|
|
||||||
return image_dto
|
|
||||||
except ImageRecordNotFoundException:
|
|
||||||
self.__invoker.services.logger.error("Image record not found")
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
self.__invoker.services.logger.error("Problem getting image DTO")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def get_metadata(self, image_name: str) -> Optional[ImageMetadata]:
|
|
||||||
try:
|
|
||||||
image_record = self.__invoker.services.image_records.get(image_name)
|
|
||||||
metadata = self.__invoker.services.image_records.get_metadata(image_name)
|
|
||||||
|
|
||||||
if not image_record.session_id:
|
|
||||||
return ImageMetadata(metadata=metadata)
|
|
||||||
|
|
||||||
session_raw = self.__invoker.services.graph_execution_manager.get_raw(image_record.session_id)
|
|
||||||
graph = None
|
|
||||||
|
|
||||||
if session_raw:
|
|
||||||
try:
|
|
||||||
graph = get_metadata_graph_from_raw_session(session_raw)
|
|
||||||
except Exception as e:
|
|
||||||
self.__invoker.services.logger.warn(f"Failed to parse session graph: {e}")
|
|
||||||
graph = None
|
|
||||||
|
|
||||||
return ImageMetadata(graph=graph, metadata=metadata)
|
|
||||||
except ImageRecordNotFoundException:
|
|
||||||
self.__invoker.services.logger.error("Image record not found")
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
self.__invoker.services.logger.error("Problem getting image DTO")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
|
|
||||||
try:
|
|
||||||
return self.__invoker.services.image_files.get_path(image_name, thumbnail)
|
|
||||||
except Exception as e:
|
|
||||||
self.__invoker.services.logger.error("Problem getting image path")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def validate_path(self, path: str) -> bool:
|
|
||||||
try:
|
|
||||||
return self.__invoker.services.image_files.validate_path(path)
|
|
||||||
except Exception as e:
|
|
||||||
self.__invoker.services.logger.error("Problem validating image path")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def get_url(self, image_name: str, thumbnail: bool = False) -> str:
|
|
||||||
try:
|
|
||||||
return self.__invoker.services.urls.get_image_url(image_name, thumbnail)
|
|
||||||
except Exception as e:
|
|
||||||
self.__invoker.services.logger.error("Problem getting image path")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def get_many(
|
|
||||||
self,
|
|
||||||
offset: int = 0,
|
|
||||||
limit: int = 10,
|
|
||||||
image_origin: Optional[ResourceOrigin] = None,
|
|
||||||
categories: Optional[list[ImageCategory]] = None,
|
|
||||||
is_intermediate: Optional[bool] = None,
|
|
||||||
board_id: Optional[str] = None,
|
|
||||||
) -> OffsetPaginatedResults[ImageDTO]:
|
|
||||||
try:
|
|
||||||
results = self.__invoker.services.image_records.get_many(
|
|
||||||
offset,
|
|
||||||
limit,
|
|
||||||
image_origin,
|
|
||||||
categories,
|
|
||||||
is_intermediate,
|
|
||||||
board_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
image_dtos = list(
|
|
||||||
map(
|
|
||||||
lambda r: image_record_to_dto(
|
|
||||||
r,
|
|
||||||
self.__invoker.services.urls.get_image_url(r.image_name),
|
|
||||||
self.__invoker.services.urls.get_image_url(r.image_name, True),
|
|
||||||
self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
|
|
||||||
),
|
|
||||||
results.items,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return OffsetPaginatedResults[ImageDTO](
|
|
||||||
items=image_dtos,
|
|
||||||
offset=results.offset,
|
|
||||||
limit=results.limit,
|
|
||||||
total=results.total,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
self.__invoker.services.logger.error("Problem getting paginated image DTOs")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def delete(self, image_name: str):
|
|
||||||
try:
|
|
||||||
self.__invoker.services.image_files.delete(image_name)
|
|
||||||
self.__invoker.services.image_records.delete(image_name)
|
|
||||||
self._on_deleted(image_name)
|
|
||||||
except ImageRecordDeleteException:
|
|
||||||
self.__invoker.services.logger.error("Failed to delete image record")
|
|
||||||
raise
|
|
||||||
except ImageFileDeleteException:
|
|
||||||
self.__invoker.services.logger.error("Failed to delete image file")
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
self.__invoker.services.logger.error("Problem deleting image record and file")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def delete_images_on_board(self, board_id: str):
|
|
||||||
try:
|
|
||||||
image_names = self.__invoker.services.board_image_records.get_all_board_image_names_for_board(board_id)
|
|
||||||
for image_name in image_names:
|
|
||||||
self.__invoker.services.image_files.delete(image_name)
|
|
||||||
self.__invoker.services.image_records.delete_many(image_names)
|
|
||||||
for image_name in image_names:
|
|
||||||
self._on_deleted(image_name)
|
|
||||||
except ImageRecordDeleteException:
|
|
||||||
self.__invoker.services.logger.error("Failed to delete image records")
|
|
||||||
raise
|
|
||||||
except ImageFileDeleteException:
|
|
||||||
self.__invoker.services.logger.error("Failed to delete image files")
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
self.__invoker.services.logger.error("Problem deleting image records and files")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def delete_intermediates(self) -> int:
|
|
||||||
try:
|
|
||||||
image_names = self.__invoker.services.image_records.delete_intermediates()
|
|
||||||
count = len(image_names)
|
|
||||||
for image_name in image_names:
|
|
||||||
self.__invoker.services.image_files.delete(image_name)
|
|
||||||
self._on_deleted(image_name)
|
|
||||||
return count
|
|
||||||
except ImageRecordDeleteException:
|
|
||||||
self.__invoker.services.logger.error("Failed to delete image records")
|
|
||||||
raise
|
|
||||||
except ImageFileDeleteException:
|
|
||||||
self.__invoker.services.logger.error("Failed to delete image files")
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
self.__invoker.services.logger.error("Problem deleting image records and files")
|
|
||||||
raise e
|
|
@ -1,5 +0,0 @@
|
|||||||
from abc import ABC
|
|
||||||
|
|
||||||
|
|
||||||
class InvocationProcessorABC(ABC):
|
|
||||||
pass
|
|
@ -1,15 +0,0 @@
|
|||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
|
|
||||||
class ProgressImage(BaseModel):
|
|
||||||
"""The progress image sent intermittently during processing"""
|
|
||||||
|
|
||||||
width: int = Field(description="The effective width of the image in pixels")
|
|
||||||
height: int = Field(description="The effective height of the image in pixels")
|
|
||||||
dataURL: str = Field(description="The image data as a b64 data URL")
|
|
||||||
|
|
||||||
|
|
||||||
class CanceledException(Exception):
|
|
||||||
"""Execution canceled by user."""
|
|
||||||
|
|
||||||
pass
|
|
@ -1,11 +1,45 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from .invocation_queue_base import InvocationQueueABC
|
from pydantic import BaseModel, Field
|
||||||
from .invocation_queue_common import InvocationQueueItem
|
|
||||||
|
|
||||||
|
class InvocationQueueItem(BaseModel):
|
||||||
|
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
|
||||||
|
invocation_id: str = Field(description="The ID of the node being invoked")
|
||||||
|
session_queue_id: str = Field(description="The ID of the session queue from which this invocation queue item came")
|
||||||
|
session_queue_item_id: int = Field(
|
||||||
|
description="The ID of session queue item from which this invocation queue item came"
|
||||||
|
)
|
||||||
|
session_queue_batch_id: str = Field(
|
||||||
|
description="The ID of the session batch from which this invocation queue item came"
|
||||||
|
)
|
||||||
|
invoke_all: bool = Field(default=False)
|
||||||
|
timestamp: float = Field(default_factory=time.time)
|
||||||
|
|
||||||
|
|
||||||
|
class InvocationQueueABC(ABC):
|
||||||
|
"""Abstract base class for all invocation queues"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get(self) -> InvocationQueueItem:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def put(self, item: Optional[InvocationQueueItem]) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def cancel(self, graph_execution_state_id: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def is_canceled(self, graph_execution_state_id: str) -> bool:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class MemoryInvocationQueue(InvocationQueueABC):
|
class MemoryInvocationQueue(InvocationQueueABC):
|
@ -1,26 +0,0 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from .invocation_queue_common import InvocationQueueItem
|
|
||||||
|
|
||||||
|
|
||||||
class InvocationQueueABC(ABC):
|
|
||||||
"""Abstract base class for all invocation queues"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get(self) -> InvocationQueueItem:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def put(self, item: Optional[InvocationQueueItem]) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def cancel(self, graph_execution_state_id: str) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def is_canceled(self, graph_execution_state_id: str) -> bool:
|
|
||||||
pass
|
|
@ -1,19 +0,0 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
|
||||||
|
|
||||||
import time
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
|
|
||||||
class InvocationQueueItem(BaseModel):
|
|
||||||
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
|
|
||||||
invocation_id: str = Field(description="The ID of the node being invoked")
|
|
||||||
session_queue_id: str = Field(description="The ID of the session queue from which this invocation queue item came")
|
|
||||||
session_queue_item_id: int = Field(
|
|
||||||
description="The ID of session queue item from which this invocation queue item came"
|
|
||||||
)
|
|
||||||
session_queue_batch_id: str = Field(
|
|
||||||
description="The ID of the session batch from which this invocation queue item came"
|
|
||||||
)
|
|
||||||
invoke_all: bool = Field(default=False)
|
|
||||||
timestamp: float = Field(default_factory=time.time)
|
|
@ -6,27 +6,21 @@ from typing import TYPE_CHECKING
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
|
|
||||||
from .board_image_records.board_image_records_base import BoardImageRecordStorageBase
|
from invokeai.app.services.board_images import BoardImagesServiceABC
|
||||||
from .board_images.board_images_base import BoardImagesServiceABC
|
from invokeai.app.services.boards import BoardServiceABC
|
||||||
from .board_records.board_records_base import BoardRecordStorageBase
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from .boards.boards_base import BoardServiceABC
|
from invokeai.app.services.events import EventServiceBase
|
||||||
from .config import InvokeAIAppConfig
|
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
|
||||||
from .events.events_base import EventServiceBase
|
from invokeai.app.services.images import ImageServiceABC
|
||||||
from .image_files.image_files_base import ImageFileStorageBase
|
from invokeai.app.services.invocation_cache.invocation_cache_base import InvocationCacheBase
|
||||||
from .image_records.image_records_base import ImageRecordStorageBase
|
from invokeai.app.services.invocation_queue import InvocationQueueABC
|
||||||
from .images.images_base import ImageServiceABC
|
from invokeai.app.services.invocation_stats import InvocationStatsServiceBase
|
||||||
from .invocation_cache.invocation_cache_base import InvocationCacheBase
|
from invokeai.app.services.invoker import InvocationProcessorABC
|
||||||
from .invocation_processor.invocation_processor_base import InvocationProcessorABC
|
from invokeai.app.services.item_storage import ItemStorageABC
|
||||||
from .invocation_queue.invocation_queue_base import InvocationQueueABC
|
from invokeai.app.services.latent_storage import LatentsStorageBase
|
||||||
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
|
from invokeai.app.services.model_manager_service import ModelManagerServiceBase
|
||||||
from .item_storage.item_storage_base import ItemStorageABC
|
from invokeai.app.services.session_processor.session_processor_base import SessionProcessorBase
|
||||||
from .latents_storage.latents_storage_base import LatentsStorageBase
|
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
|
||||||
from .model_manager.model_manager_base import ModelManagerServiceBase
|
|
||||||
from .names.names_base import NameServiceBase
|
|
||||||
from .session_processor.session_processor_base import SessionProcessorBase
|
|
||||||
from .session_queue.session_queue_base import SessionQueueBase
|
|
||||||
from .shared.graph import GraphExecutionState, LibraryGraph
|
|
||||||
from .urls.urls_base import UrlServiceBase
|
|
||||||
|
|
||||||
|
|
||||||
class InvocationServices:
|
class InvocationServices:
|
||||||
@ -34,16 +28,12 @@ class InvocationServices:
|
|||||||
|
|
||||||
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
|
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
|
||||||
board_images: "BoardImagesServiceABC"
|
board_images: "BoardImagesServiceABC"
|
||||||
board_image_record_storage: "BoardImageRecordStorageBase"
|
|
||||||
boards: "BoardServiceABC"
|
boards: "BoardServiceABC"
|
||||||
board_records: "BoardRecordStorageBase"
|
|
||||||
configuration: "InvokeAIAppConfig"
|
configuration: "InvokeAIAppConfig"
|
||||||
events: "EventServiceBase"
|
events: "EventServiceBase"
|
||||||
graph_execution_manager: "ItemStorageABC[GraphExecutionState]"
|
graph_execution_manager: "ItemStorageABC[GraphExecutionState]"
|
||||||
graph_library: "ItemStorageABC[LibraryGraph]"
|
graph_library: "ItemStorageABC[LibraryGraph]"
|
||||||
images: "ImageServiceABC"
|
images: "ImageServiceABC"
|
||||||
image_records: "ImageRecordStorageBase"
|
|
||||||
image_files: "ImageFileStorageBase"
|
|
||||||
latents: "LatentsStorageBase"
|
latents: "LatentsStorageBase"
|
||||||
logger: "Logger"
|
logger: "Logger"
|
||||||
model_manager: "ModelManagerServiceBase"
|
model_manager: "ModelManagerServiceBase"
|
||||||
@ -53,22 +43,16 @@ class InvocationServices:
|
|||||||
session_queue: "SessionQueueBase"
|
session_queue: "SessionQueueBase"
|
||||||
session_processor: "SessionProcessorBase"
|
session_processor: "SessionProcessorBase"
|
||||||
invocation_cache: "InvocationCacheBase"
|
invocation_cache: "InvocationCacheBase"
|
||||||
names: "NameServiceBase"
|
|
||||||
urls: "UrlServiceBase"
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
board_images: "BoardImagesServiceABC",
|
board_images: "BoardImagesServiceABC",
|
||||||
board_image_records: "BoardImageRecordStorageBase",
|
|
||||||
boards: "BoardServiceABC",
|
boards: "BoardServiceABC",
|
||||||
board_records: "BoardRecordStorageBase",
|
|
||||||
configuration: "InvokeAIAppConfig",
|
configuration: "InvokeAIAppConfig",
|
||||||
events: "EventServiceBase",
|
events: "EventServiceBase",
|
||||||
graph_execution_manager: "ItemStorageABC[GraphExecutionState]",
|
graph_execution_manager: "ItemStorageABC[GraphExecutionState]",
|
||||||
graph_library: "ItemStorageABC[LibraryGraph]",
|
graph_library: "ItemStorageABC[LibraryGraph]",
|
||||||
images: "ImageServiceABC",
|
images: "ImageServiceABC",
|
||||||
image_files: "ImageFileStorageBase",
|
|
||||||
image_records: "ImageRecordStorageBase",
|
|
||||||
latents: "LatentsStorageBase",
|
latents: "LatentsStorageBase",
|
||||||
logger: "Logger",
|
logger: "Logger",
|
||||||
model_manager: "ModelManagerServiceBase",
|
model_manager: "ModelManagerServiceBase",
|
||||||
@ -78,20 +62,14 @@ class InvocationServices:
|
|||||||
session_queue: "SessionQueueBase",
|
session_queue: "SessionQueueBase",
|
||||||
session_processor: "SessionProcessorBase",
|
session_processor: "SessionProcessorBase",
|
||||||
invocation_cache: "InvocationCacheBase",
|
invocation_cache: "InvocationCacheBase",
|
||||||
names: "NameServiceBase",
|
|
||||||
urls: "UrlServiceBase",
|
|
||||||
):
|
):
|
||||||
self.board_images = board_images
|
self.board_images = board_images
|
||||||
self.board_image_records = board_image_records
|
|
||||||
self.boards = boards
|
self.boards = boards
|
||||||
self.board_records = board_records
|
|
||||||
self.configuration = configuration
|
self.configuration = configuration
|
||||||
self.events = events
|
self.events = events
|
||||||
self.graph_execution_manager = graph_execution_manager
|
self.graph_execution_manager = graph_execution_manager
|
||||||
self.graph_library = graph_library
|
self.graph_library = graph_library
|
||||||
self.images = images
|
self.images = images
|
||||||
self.image_files = image_files
|
|
||||||
self.image_records = image_records
|
|
||||||
self.latents = latents
|
self.latents = latents
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.model_manager = model_manager
|
self.model_manager = model_manager
|
||||||
@ -101,5 +79,3 @@ class InvocationServices:
|
|||||||
self.session_queue = session_queue
|
self.session_queue = session_queue
|
||||||
self.session_processor = session_processor
|
self.session_processor = session_processor
|
||||||
self.invocation_cache = invocation_cache
|
self.invocation_cache = invocation_cache
|
||||||
self.names = names
|
|
||||||
self.urls = urls
|
|
||||||
|
@ -1,35 +1,171 @@
|
|||||||
|
# Copyright 2023 Lincoln D. Stein <lincoln.stein@gmail.com>
|
||||||
|
"""Utility to collect execution time and GPU usage stats on invocations in flight
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
statistics = InvocationStatsService(graph_execution_manager)
|
||||||
|
with statistics.collect_stats(invocation, graph_execution_state.id):
|
||||||
|
... execute graphs...
|
||||||
|
statistics.log_stats()
|
||||||
|
|
||||||
|
Typical output:
|
||||||
|
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Graph stats: c7764585-9c68-4d9d-a199-55e8186790f3
|
||||||
|
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Node Calls Seconds VRAM Used
|
||||||
|
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> main_model_loader 1 0.005s 0.01G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> clip_skip 1 0.004s 0.01G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> compel 2 0.512s 0.26G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> rand_int 1 0.001s 0.01G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> range_of_size 1 0.001s 0.01G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> iterate 1 0.001s 0.01G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> metadata_accumulator 1 0.002s 0.01G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> noise 1 0.002s 0.01G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> t2l 1 3.541s 1.93G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> l2i 1 0.679s 0.58G
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> TOTAL GRAPH EXECUTION TIME: 4.749s
|
||||||
|
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> Current VRAM utilization 0.01G
|
||||||
|
|
||||||
|
The abstract base class for this class is InvocationStatsServiceBase. An implementing class which
|
||||||
|
writes to the system log is stored in InvocationServices.performance_statistics.
|
||||||
|
"""
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from contextlib import AbstractContextManager
|
||||||
|
from dataclasses import dataclass, field
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
|
||||||
from invokeai.app.services.invoker import Invoker
|
|
||||||
from invokeai.app.services.model_manager.model_manager_base import ModelManagerServiceBase
|
|
||||||
from invokeai.backend.model_management.model_cache import CacheStats
|
from invokeai.backend.model_management.model_cache import CacheStats
|
||||||
|
|
||||||
from .invocation_stats_base import InvocationStatsServiceBase
|
from ..invocations.baseinvocation import BaseInvocation
|
||||||
from .invocation_stats_common import GIG, NodeLog, NodeStats
|
from .graph import GraphExecutionState
|
||||||
|
from .item_storage import ItemStorageABC
|
||||||
|
from .model_manager_service import ModelManagerService
|
||||||
|
|
||||||
|
# size of GIG in bytes
|
||||||
|
GIG = 1073741824
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class NodeStats:
|
||||||
|
"""Class for tracking execution stats of an invocation node"""
|
||||||
|
|
||||||
|
calls: int = 0
|
||||||
|
time_used: float = 0.0 # seconds
|
||||||
|
max_vram: float = 0.0 # GB
|
||||||
|
cache_hits: int = 0
|
||||||
|
cache_misses: int = 0
|
||||||
|
cache_high_watermark: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class NodeLog:
|
||||||
|
"""Class for tracking node usage"""
|
||||||
|
|
||||||
|
# {node_type => NodeStats}
|
||||||
|
nodes: Dict[str, NodeStats] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class InvocationStatsServiceBase(ABC):
|
||||||
|
"Abstract base class for recording node memory/time performance statistics"
|
||||||
|
|
||||||
|
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
|
||||||
|
# {graph_id => NodeLog}
|
||||||
|
_stats: Dict[str, NodeLog]
|
||||||
|
_cache_stats: Dict[str, CacheStats]
|
||||||
|
ram_used: float
|
||||||
|
ram_changed: float
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
|
||||||
|
"""
|
||||||
|
Initialize the InvocationStatsService and reset counters to zero
|
||||||
|
:param graph_execution_manager: Graph execution manager for this session
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def collect_stats(
|
||||||
|
self,
|
||||||
|
invocation: BaseInvocation,
|
||||||
|
graph_execution_state_id: str,
|
||||||
|
) -> AbstractContextManager:
|
||||||
|
"""
|
||||||
|
Return a context object that will capture the statistics on the execution
|
||||||
|
of invocaation. Use with: to place around the part of the code that executes the invocation.
|
||||||
|
:param invocation: BaseInvocation object from the current graph.
|
||||||
|
:param graph_execution_state: GraphExecutionState object from the current session.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def reset_stats(self, graph_execution_state_id: str):
|
||||||
|
"""
|
||||||
|
Reset all statistics for the indicated graph
|
||||||
|
:param graph_execution_state_id
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def reset_all_stats(self):
|
||||||
|
"""Zero all statistics"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def update_invocation_stats(
|
||||||
|
self,
|
||||||
|
graph_id: str,
|
||||||
|
invocation_type: str,
|
||||||
|
time_used: float,
|
||||||
|
vram_used: float,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Add timing information on execution of a node. Usually
|
||||||
|
used internally.
|
||||||
|
:param graph_id: ID of the graph that is currently executing
|
||||||
|
:param invocation_type: String literal type of the node
|
||||||
|
:param time_used: Time used by node's exection (sec)
|
||||||
|
:param vram_used: Maximum VRAM used during exection (GB)
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def log_stats(self):
|
||||||
|
"""
|
||||||
|
Write out the accumulated statistics to the log or somewhere else.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def update_mem_stats(
|
||||||
|
self,
|
||||||
|
ram_used: float,
|
||||||
|
ram_changed: float,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Update the collector with RAM memory usage info.
|
||||||
|
|
||||||
|
:param ram_used: How much RAM is currently in use.
|
||||||
|
:param ram_changed: How much RAM changed since last generation.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class InvocationStatsService(InvocationStatsServiceBase):
|
class InvocationStatsService(InvocationStatsServiceBase):
|
||||||
"""Accumulate performance information about a running graph. Collects time spent in each node,
|
"""Accumulate performance information about a running graph. Collects time spent in each node,
|
||||||
as well as the maximum and current VRAM utilisation for CUDA systems"""
|
as well as the maximum and current VRAM utilisation for CUDA systems"""
|
||||||
|
|
||||||
_invoker: Invoker
|
def __init__(self, graph_execution_manager: ItemStorageABC["GraphExecutionState"]):
|
||||||
|
self.graph_execution_manager = graph_execution_manager
|
||||||
def __init__(self):
|
|
||||||
# {graph_id => NodeLog}
|
# {graph_id => NodeLog}
|
||||||
self._stats: Dict[str, NodeLog] = {}
|
self._stats: Dict[str, NodeLog] = {}
|
||||||
self._cache_stats: Dict[str, CacheStats] = {}
|
self._cache_stats: Dict[str, CacheStats] = {}
|
||||||
self.ram_used: float = 0.0
|
self.ram_used: float = 0.0
|
||||||
self.ram_changed: float = 0.0
|
self.ram_changed: float = 0.0
|
||||||
|
|
||||||
def start(self, invoker: Invoker) -> None:
|
|
||||||
self._invoker = invoker
|
|
||||||
|
|
||||||
class StatsContext:
|
class StatsContext:
|
||||||
"""Context manager for collecting statistics."""
|
"""Context manager for collecting statistics."""
|
||||||
|
|
||||||
@ -38,13 +174,13 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
|||||||
graph_id: str
|
graph_id: str
|
||||||
start_time: float
|
start_time: float
|
||||||
ram_used: int
|
ram_used: int
|
||||||
model_manager: ModelManagerServiceBase
|
model_manager: ModelManagerService
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
invocation: BaseInvocation,
|
invocation: BaseInvocation,
|
||||||
graph_id: str,
|
graph_id: str,
|
||||||
model_manager: ModelManagerServiceBase,
|
model_manager: ModelManagerService,
|
||||||
collector: "InvocationStatsServiceBase",
|
collector: "InvocationStatsServiceBase",
|
||||||
):
|
):
|
||||||
"""Initialize statistics for this run."""
|
"""Initialize statistics for this run."""
|
||||||
@ -81,11 +217,12 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
|||||||
self,
|
self,
|
||||||
invocation: BaseInvocation,
|
invocation: BaseInvocation,
|
||||||
graph_execution_state_id: str,
|
graph_execution_state_id: str,
|
||||||
|
model_manager: ModelManagerService,
|
||||||
) -> StatsContext:
|
) -> StatsContext:
|
||||||
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
|
if not self._stats.get(graph_execution_state_id): # first time we're seeing this
|
||||||
self._stats[graph_execution_state_id] = NodeLog()
|
self._stats[graph_execution_state_id] = NodeLog()
|
||||||
self._cache_stats[graph_execution_state_id] = CacheStats()
|
self._cache_stats[graph_execution_state_id] = CacheStats()
|
||||||
return self.StatsContext(invocation, graph_execution_state_id, self._invoker.services.model_manager, self)
|
return self.StatsContext(invocation, graph_execution_state_id, model_manager, self)
|
||||||
|
|
||||||
def reset_all_stats(self):
|
def reset_all_stats(self):
|
||||||
"""Zero all statistics"""
|
"""Zero all statistics"""
|
||||||
@ -124,7 +261,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
|||||||
errored = set()
|
errored = set()
|
||||||
for graph_id, node_log in self._stats.items():
|
for graph_id, node_log in self._stats.items():
|
||||||
try:
|
try:
|
||||||
current_graph_state = self._invoker.services.graph_execution_manager.get(graph_id)
|
current_graph_state = self.graph_execution_manager.get(graph_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
errored.add(graph_id)
|
errored.add(graph_id)
|
||||||
continue
|
continue
|
@ -1,121 +0,0 @@
|
|||||||
# Copyright 2023 Lincoln D. Stein <lincoln.stein@gmail.com>
|
|
||||||
"""Utility to collect execution time and GPU usage stats on invocations in flight
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
|
|
||||||
statistics = InvocationStatsService(graph_execution_manager)
|
|
||||||
with statistics.collect_stats(invocation, graph_execution_state.id):
|
|
||||||
... execute graphs...
|
|
||||||
statistics.log_stats()
|
|
||||||
|
|
||||||
Typical output:
|
|
||||||
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Graph stats: c7764585-9c68-4d9d-a199-55e8186790f3
|
|
||||||
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Node Calls Seconds VRAM Used
|
|
||||||
[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> main_model_loader 1 0.005s 0.01G
|
|
||||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> clip_skip 1 0.004s 0.01G
|
|
||||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> compel 2 0.512s 0.26G
|
|
||||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> rand_int 1 0.001s 0.01G
|
|
||||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> range_of_size 1 0.001s 0.01G
|
|
||||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> iterate 1 0.001s 0.01G
|
|
||||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> metadata_accumulator 1 0.002s 0.01G
|
|
||||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> noise 1 0.002s 0.01G
|
|
||||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> t2l 1 3.541s 1.93G
|
|
||||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> l2i 1 0.679s 0.58G
|
|
||||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> TOTAL GRAPH EXECUTION TIME: 4.749s
|
|
||||||
[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> Current VRAM utilization 0.01G
|
|
||||||
|
|
||||||
The abstract base class for this class is InvocationStatsServiceBase. An implementing class which
|
|
||||||
writes to the system log is stored in InvocationServices.performance_statistics.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from contextlib import AbstractContextManager
|
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
|
||||||
from invokeai.backend.model_management.model_cache import CacheStats
|
|
||||||
|
|
||||||
from .invocation_stats_common import NodeLog
|
|
||||||
|
|
||||||
|
|
||||||
class InvocationStatsServiceBase(ABC):
|
|
||||||
"Abstract base class for recording node memory/time performance statistics"
|
|
||||||
|
|
||||||
# {graph_id => NodeLog}
|
|
||||||
_stats: Dict[str, NodeLog]
|
|
||||||
_cache_stats: Dict[str, CacheStats]
|
|
||||||
ram_used: float
|
|
||||||
ram_changed: float
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def __init__(self):
|
|
||||||
"""
|
|
||||||
Initialize the InvocationStatsService and reset counters to zero
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def collect_stats(
|
|
||||||
self,
|
|
||||||
invocation: BaseInvocation,
|
|
||||||
graph_execution_state_id: str,
|
|
||||||
) -> AbstractContextManager:
|
|
||||||
"""
|
|
||||||
Return a context object that will capture the statistics on the execution
|
|
||||||
of invocaation. Use with: to place around the part of the code that executes the invocation.
|
|
||||||
:param invocation: BaseInvocation object from the current graph.
|
|
||||||
:param graph_execution_state_id: The id of the current session.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def reset_stats(self, graph_execution_state_id: str):
|
|
||||||
"""
|
|
||||||
Reset all statistics for the indicated graph
|
|
||||||
:param graph_execution_state_id
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def reset_all_stats(self):
|
|
||||||
"""Zero all statistics"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def update_invocation_stats(
|
|
||||||
self,
|
|
||||||
graph_id: str,
|
|
||||||
invocation_type: str,
|
|
||||||
time_used: float,
|
|
||||||
vram_used: float,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Add timing information on execution of a node. Usually
|
|
||||||
used internally.
|
|
||||||
:param graph_id: ID of the graph that is currently executing
|
|
||||||
:param invocation_type: String literal type of the node
|
|
||||||
:param time_used: Time used by node's exection (sec)
|
|
||||||
:param vram_used: Maximum VRAM used during exection (GB)
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def log_stats(self):
|
|
||||||
"""
|
|
||||||
Write out the accumulated statistics to the log or somewhere else.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def update_mem_stats(
|
|
||||||
self,
|
|
||||||
ram_used: float,
|
|
||||||
ram_changed: float,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Update the collector with RAM memory usage info.
|
|
||||||
|
|
||||||
:param ram_used: How much RAM is currently in use.
|
|
||||||
:param ram_changed: How much RAM changed since last generation.
|
|
||||||
"""
|
|
||||||
pass
|
|
@ -1,25 +0,0 @@
|
|||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
# size of GIG in bytes
|
|
||||||
GIG = 1073741824
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class NodeStats:
|
|
||||||
"""Class for tracking execution stats of an invocation node"""
|
|
||||||
|
|
||||||
calls: int = 0
|
|
||||||
time_used: float = 0.0 # seconds
|
|
||||||
max_vram: float = 0.0 # GB
|
|
||||||
cache_hits: int = 0
|
|
||||||
cache_misses: int = 0
|
|
||||||
cache_high_watermark: int = 0
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class NodeLog:
|
|
||||||
"""Class for tracking node usage"""
|
|
||||||
|
|
||||||
# {node_type => NodeStats}
|
|
||||||
nodes: Dict[str, NodeStats] = field(default_factory=dict)
|
|
@ -1,10 +1,11 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
from abc import ABC
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from .invocation_queue.invocation_queue_common import InvocationQueueItem
|
from .graph import Graph, GraphExecutionState
|
||||||
|
from .invocation_queue import InvocationQueueItem
|
||||||
from .invocation_services import InvocationServices
|
from .invocation_services import InvocationServices
|
||||||
from .shared.graph import Graph, GraphExecutionState
|
|
||||||
|
|
||||||
|
|
||||||
class Invoker:
|
class Invoker:
|
||||||
@ -83,3 +84,7 @@ class Invoker:
|
|||||||
self.__stop_service(getattr(self.services, service))
|
self.__stop_service(getattr(self.services, service))
|
||||||
|
|
||||||
self.services.queue.put(None)
|
self.services.queue.put(None)
|
||||||
|
|
||||||
|
|
||||||
|
class InvocationProcessorABC(ABC):
|
||||||
|
pass
|
||||||
|
@ -1,16 +1,25 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Callable, Generic, Optional, TypeVar
|
from typing import Callable, Generic, Optional, TypeVar
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
|
from pydantic.generics import GenericModel
|
||||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
|
||||||
|
|
||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
class ItemStorageABC(ABC, Generic[T]):
|
class PaginatedResults(GenericModel, Generic[T]):
|
||||||
"""Provides storage for a single type of item. The type must be a Pydantic model."""
|
"""Paginated results"""
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
items: list[T] = Field(description="Items")
|
||||||
|
page: int = Field(description="Current Page")
|
||||||
|
pages: int = Field(description="Total number of pages")
|
||||||
|
per_page: int = Field(description="Number of items per page")
|
||||||
|
total: int = Field(description="Total number of items in result")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
class ItemStorageABC(ABC, Generic[T]):
|
||||||
_on_changed_callbacks: list[Callable[[T], None]]
|
_on_changed_callbacks: list[Callable[[T], None]]
|
||||||
_on_deleted_callbacks: list[Callable[[str], None]]
|
_on_deleted_callbacks: list[Callable[[str], None]]
|
||||||
|
|
119
invokeai/app/services/latent_storage.py
Normal file
119
invokeai/app/services/latent_storage.py
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from pathlib import Path
|
||||||
|
from queue import Queue
|
||||||
|
from typing import Callable, Dict, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class LatentsStorageBase(ABC):
|
||||||
|
"""Responsible for storing and retrieving latents."""
|
||||||
|
|
||||||
|
_on_changed_callbacks: list[Callable[[torch.Tensor], None]]
|
||||||
|
_on_deleted_callbacks: list[Callable[[str], None]]
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._on_changed_callbacks = list()
|
||||||
|
self._on_deleted_callbacks = list()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get(self, name: str) -> torch.Tensor:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def save(self, name: str, data: torch.Tensor) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def delete(self, name: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_changed(self, on_changed: Callable[[torch.Tensor], None]) -> None:
|
||||||
|
"""Register a callback for when an item is changed"""
|
||||||
|
self._on_changed_callbacks.append(on_changed)
|
||||||
|
|
||||||
|
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
|
||||||
|
"""Register a callback for when an item is deleted"""
|
||||||
|
self._on_deleted_callbacks.append(on_deleted)
|
||||||
|
|
||||||
|
def _on_changed(self, item: torch.Tensor) -> None:
|
||||||
|
for callback in self._on_changed_callbacks:
|
||||||
|
callback(item)
|
||||||
|
|
||||||
|
def _on_deleted(self, item_id: str) -> None:
|
||||||
|
for callback in self._on_deleted_callbacks:
|
||||||
|
callback(item_id)
|
||||||
|
|
||||||
|
|
||||||
|
class ForwardCacheLatentsStorage(LatentsStorageBase):
|
||||||
|
"""Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage"""
|
||||||
|
|
||||||
|
__cache: Dict[str, torch.Tensor]
|
||||||
|
__cache_ids: Queue
|
||||||
|
__max_cache_size: int
|
||||||
|
__underlying_storage: LatentsStorageBase
|
||||||
|
|
||||||
|
def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
|
||||||
|
super().__init__()
|
||||||
|
self.__underlying_storage = underlying_storage
|
||||||
|
self.__cache = dict()
|
||||||
|
self.__cache_ids = Queue()
|
||||||
|
self.__max_cache_size = max_cache_size
|
||||||
|
|
||||||
|
def get(self, name: str) -> torch.Tensor:
|
||||||
|
cache_item = self.__get_cache(name)
|
||||||
|
if cache_item is not None:
|
||||||
|
return cache_item
|
||||||
|
|
||||||
|
latent = self.__underlying_storage.get(name)
|
||||||
|
self.__set_cache(name, latent)
|
||||||
|
return latent
|
||||||
|
|
||||||
|
def save(self, name: str, data: torch.Tensor) -> None:
|
||||||
|
self.__underlying_storage.save(name, data)
|
||||||
|
self.__set_cache(name, data)
|
||||||
|
self._on_changed(data)
|
||||||
|
|
||||||
|
def delete(self, name: str) -> None:
|
||||||
|
self.__underlying_storage.delete(name)
|
||||||
|
if name in self.__cache:
|
||||||
|
del self.__cache[name]
|
||||||
|
self._on_deleted(name)
|
||||||
|
|
||||||
|
def __get_cache(self, name: str) -> Optional[torch.Tensor]:
|
||||||
|
return None if name not in self.__cache else self.__cache[name]
|
||||||
|
|
||||||
|
def __set_cache(self, name: str, data: torch.Tensor):
|
||||||
|
if name not in self.__cache:
|
||||||
|
self.__cache[name] = data
|
||||||
|
self.__cache_ids.put(name)
|
||||||
|
if self.__cache_ids.qsize() > self.__max_cache_size:
|
||||||
|
self.__cache.pop(self.__cache_ids.get())
|
||||||
|
|
||||||
|
|
||||||
|
class DiskLatentsStorage(LatentsStorageBase):
|
||||||
|
"""Stores latents in a folder on disk without caching"""
|
||||||
|
|
||||||
|
__output_folder: Union[str, Path]
|
||||||
|
|
||||||
|
def __init__(self, output_folder: Union[str, Path]):
|
||||||
|
self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
||||||
|
self.__output_folder.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
def get(self, name: str) -> torch.Tensor:
|
||||||
|
latent_path = self.get_path(name)
|
||||||
|
return torch.load(latent_path)
|
||||||
|
|
||||||
|
def save(self, name: str, data: torch.Tensor) -> None:
|
||||||
|
self.__output_folder.mkdir(parents=True, exist_ok=True)
|
||||||
|
latent_path = self.get_path(name)
|
||||||
|
torch.save(data, latent_path)
|
||||||
|
|
||||||
|
def delete(self, name: str) -> None:
|
||||||
|
latent_path = self.get_path(name)
|
||||||
|
latent_path.unlink()
|
||||||
|
|
||||||
|
def get_path(self, name: str) -> Path:
|
||||||
|
return self.__output_folder / name
|
@ -1,45 +0,0 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Callable
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
class LatentsStorageBase(ABC):
|
|
||||||
"""Responsible for storing and retrieving latents."""
|
|
||||||
|
|
||||||
_on_changed_callbacks: list[Callable[[torch.Tensor], None]]
|
|
||||||
_on_deleted_callbacks: list[Callable[[str], None]]
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._on_changed_callbacks = list()
|
|
||||||
self._on_deleted_callbacks = list()
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get(self, name: str) -> torch.Tensor:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def save(self, name: str, data: torch.Tensor) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def delete(self, name: str) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_changed(self, on_changed: Callable[[torch.Tensor], None]) -> None:
|
|
||||||
"""Register a callback for when an item is changed"""
|
|
||||||
self._on_changed_callbacks.append(on_changed)
|
|
||||||
|
|
||||||
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
|
|
||||||
"""Register a callback for when an item is deleted"""
|
|
||||||
self._on_deleted_callbacks.append(on_deleted)
|
|
||||||
|
|
||||||
def _on_changed(self, item: torch.Tensor) -> None:
|
|
||||||
for callback in self._on_changed_callbacks:
|
|
||||||
callback(item)
|
|
||||||
|
|
||||||
def _on_deleted(self, item_id: str) -> None:
|
|
||||||
for callback in self._on_deleted_callbacks:
|
|
||||||
callback(item_id)
|
|
@ -1,34 +0,0 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from .latents_storage_base import LatentsStorageBase
|
|
||||||
|
|
||||||
|
|
||||||
class DiskLatentsStorage(LatentsStorageBase):
|
|
||||||
"""Stores latents in a folder on disk without caching"""
|
|
||||||
|
|
||||||
__output_folder: Path
|
|
||||||
|
|
||||||
def __init__(self, output_folder: Union[str, Path]):
|
|
||||||
self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
|
||||||
self.__output_folder.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
def get(self, name: str) -> torch.Tensor:
|
|
||||||
latent_path = self.get_path(name)
|
|
||||||
return torch.load(latent_path)
|
|
||||||
|
|
||||||
def save(self, name: str, data: torch.Tensor) -> None:
|
|
||||||
self.__output_folder.mkdir(parents=True, exist_ok=True)
|
|
||||||
latent_path = self.get_path(name)
|
|
||||||
torch.save(data, latent_path)
|
|
||||||
|
|
||||||
def delete(self, name: str) -> None:
|
|
||||||
latent_path = self.get_path(name)
|
|
||||||
latent_path.unlink()
|
|
||||||
|
|
||||||
def get_path(self, name: str) -> Path:
|
|
||||||
return self.__output_folder / name
|
|
@ -1,54 +0,0 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
|
||||||
|
|
||||||
from queue import Queue
|
|
||||||
from typing import Dict, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from .latents_storage_base import LatentsStorageBase
|
|
||||||
|
|
||||||
|
|
||||||
class ForwardCacheLatentsStorage(LatentsStorageBase):
|
|
||||||
"""Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage"""
|
|
||||||
|
|
||||||
__cache: Dict[str, torch.Tensor]
|
|
||||||
__cache_ids: Queue
|
|
||||||
__max_cache_size: int
|
|
||||||
__underlying_storage: LatentsStorageBase
|
|
||||||
|
|
||||||
def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
|
|
||||||
super().__init__()
|
|
||||||
self.__underlying_storage = underlying_storage
|
|
||||||
self.__cache = dict()
|
|
||||||
self.__cache_ids = Queue()
|
|
||||||
self.__max_cache_size = max_cache_size
|
|
||||||
|
|
||||||
def get(self, name: str) -> torch.Tensor:
|
|
||||||
cache_item = self.__get_cache(name)
|
|
||||||
if cache_item is not None:
|
|
||||||
return cache_item
|
|
||||||
|
|
||||||
latent = self.__underlying_storage.get(name)
|
|
||||||
self.__set_cache(name, latent)
|
|
||||||
return latent
|
|
||||||
|
|
||||||
def save(self, name: str, data: torch.Tensor) -> None:
|
|
||||||
self.__underlying_storage.save(name, data)
|
|
||||||
self.__set_cache(name, data)
|
|
||||||
self._on_changed(data)
|
|
||||||
|
|
||||||
def delete(self, name: str) -> None:
|
|
||||||
self.__underlying_storage.delete(name)
|
|
||||||
if name in self.__cache:
|
|
||||||
del self.__cache[name]
|
|
||||||
self._on_deleted(name)
|
|
||||||
|
|
||||||
def __get_cache(self, name: str) -> Optional[torch.Tensor]:
|
|
||||||
return None if name not in self.__cache else self.__cache[name]
|
|
||||||
|
|
||||||
def __set_cache(self, name: str, data: torch.Tensor):
|
|
||||||
if name not in self.__cache:
|
|
||||||
self.__cache[name] = data
|
|
||||||
self.__cache_ids.put(name)
|
|
||||||
if self.__cache_ids.qsize() > self.__max_cache_size:
|
|
||||||
self.__cache.pop(self.__cache_ids.get())
|
|
@ -1,286 +0,0 @@
|
|||||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from logging import Logger
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union
|
|
||||||
|
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
|
||||||
from invokeai.backend.model_management import (
|
|
||||||
AddModelResult,
|
|
||||||
BaseModelType,
|
|
||||||
MergeInterpolationMethod,
|
|
||||||
ModelInfo,
|
|
||||||
ModelType,
|
|
||||||
SchedulerPredictionType,
|
|
||||||
SubModelType,
|
|
||||||
)
|
|
||||||
from invokeai.backend.model_management.model_cache import CacheStats
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, InvocationContext
|
|
||||||
|
|
||||||
|
|
||||||
class ModelManagerServiceBase(ABC):
|
|
||||||
"""Responsible for managing models on disk and in memory"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
config: InvokeAIAppConfig,
|
|
||||||
logger: Logger,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Initialize with the path to the models.yaml config file.
|
|
||||||
Optional parameters are the torch device type, precision, max_models,
|
|
||||||
and sequential_offload boolean. Note that the default device
|
|
||||||
type and precision are set up for a CUDA system running at half precision.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_model(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: ModelType,
|
|
||||||
submodel: Optional[SubModelType] = None,
|
|
||||||
node: Optional[BaseInvocation] = None,
|
|
||||||
context: Optional[InvocationContext] = None,
|
|
||||||
) -> ModelInfo:
|
|
||||||
"""Retrieve the indicated model with name and type.
|
|
||||||
submodel can be used to get a part (such as the vae)
|
|
||||||
of a diffusers pipeline."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def logger(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def model_exists(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: ModelType,
|
|
||||||
) -> bool:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
|
||||||
"""
|
|
||||||
Given a model name returns a dict-like (OmegaConf) object describing it.
|
|
||||||
Uses the exact format as the omegaconf stanza.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def list_models(self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None) -> dict:
|
|
||||||
"""
|
|
||||||
Return a dict of models in the format:
|
|
||||||
{ model_type1:
|
|
||||||
{ model_name1: {'status': 'active'|'cached'|'not loaded',
|
|
||||||
'model_name' : name,
|
|
||||||
'model_type' : SDModelType,
|
|
||||||
'description': description,
|
|
||||||
'format': 'folder'|'safetensors'|'ckpt'
|
|
||||||
},
|
|
||||||
model_name2: { etc }
|
|
||||||
},
|
|
||||||
model_type2:
|
|
||||||
{ model_name_n: etc
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
|
||||||
"""
|
|
||||||
Return information about the model using the same format as list_models()
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
|
||||||
"""
|
|
||||||
Returns a list of all the model names known.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def add_model(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: ModelType,
|
|
||||||
model_attributes: dict,
|
|
||||||
clobber: bool = False,
|
|
||||||
) -> AddModelResult:
|
|
||||||
"""
|
|
||||||
Update the named model with a dictionary of attributes. Will fail with an
|
|
||||||
assertion error if the name already exists. Pass clobber=True to overwrite.
|
|
||||||
On a successful update, the config will be changed in memory. Will fail
|
|
||||||
with an assertion error if provided attributes are incorrect or
|
|
||||||
the model name is missing. Call commit() to write changes to disk.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def update_model(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: ModelType,
|
|
||||||
model_attributes: dict,
|
|
||||||
) -> AddModelResult:
|
|
||||||
"""
|
|
||||||
Update the named model with a dictionary of attributes. Will fail with a
|
|
||||||
ModelNotFoundException if the name does not already exist.
|
|
||||||
|
|
||||||
On a successful update, the config will be changed in memory. Will fail
|
|
||||||
with an assertion error if provided attributes are incorrect or
|
|
||||||
the model name is missing. Call commit() to write changes to disk.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def del_model(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: ModelType,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Delete the named model from configuration. If delete_files is true,
|
|
||||||
then the underlying weight file or diffusers directory will be deleted
|
|
||||||
as well. Call commit() to write to disk.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def rename_model(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: ModelType,
|
|
||||||
new_name: str,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Rename the indicated model.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def list_checkpoint_configs(self) -> List[Path]:
|
|
||||||
"""
|
|
||||||
List the checkpoint config paths from ROOT/configs/stable-diffusion.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def convert_model(
|
|
||||||
self,
|
|
||||||
model_name: str,
|
|
||||||
base_model: BaseModelType,
|
|
||||||
model_type: Literal[ModelType.Main, ModelType.Vae],
|
|
||||||
) -> AddModelResult:
|
|
||||||
"""
|
|
||||||
Convert a checkpoint file into a diffusers folder, deleting the cached
|
|
||||||
version and deleting the original checkpoint file if it is in the models
|
|
||||||
directory.
|
|
||||||
:param model_name: Name of the model to convert
|
|
||||||
:param base_model: Base model type
|
|
||||||
:param model_type: Type of model ['vae' or 'main']
|
|
||||||
|
|
||||||
This will raise a ValueError unless the model is not a checkpoint. It will
|
|
||||||
also raise a ValueError in the event that there is a similarly-named diffusers
|
|
||||||
directory already in place.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def heuristic_import(
|
|
||||||
self,
|
|
||||||
items_to_import: set[str],
|
|
||||||
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
|
||||||
) -> dict[str, AddModelResult]:
|
|
||||||
"""Import a list of paths, repo_ids or URLs. Returns the set of
|
|
||||||
successfully imported items.
|
|
||||||
:param items_to_import: Set of strings corresponding to models to be imported.
|
|
||||||
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
|
||||||
|
|
||||||
The prediction type helper is necessary to distinguish between
|
|
||||||
models based on Stable Diffusion 2 Base (requiring
|
|
||||||
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
|
|
||||||
(requiring SchedulerPredictionType.VPrediction). It is
|
|
||||||
generally impossible to do this programmatically, so the
|
|
||||||
prediction_type_helper usually asks the user to choose.
|
|
||||||
|
|
||||||
The result is a set of successfully installed models. Each element
|
|
||||||
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
|
||||||
that model.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def merge_models(
|
|
||||||
self,
|
|
||||||
model_names: List[str] = Field(
|
|
||||||
default=None, min_items=2, max_items=3, description="List of model names to merge"
|
|
||||||
),
|
|
||||||
base_model: Union[BaseModelType, str] = Field(
|
|
||||||
default=None, description="Base model shared by all models to be merged"
|
|
||||||
),
|
|
||||||
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
|
||||||
alpha: Optional[float] = 0.5,
|
|
||||||
interp: Optional[MergeInterpolationMethod] = None,
|
|
||||||
force: Optional[bool] = False,
|
|
||||||
merge_dest_directory: Optional[Path] = None,
|
|
||||||
) -> AddModelResult:
|
|
||||||
"""
|
|
||||||
Merge two to three diffusrs pipeline models and save as a new model.
|
|
||||||
:param model_names: List of 2-3 models to merge
|
|
||||||
:param base_model: Base model to use for all models
|
|
||||||
:param merged_model_name: Name of destination merged model
|
|
||||||
:param alpha: Alpha strength to apply to 2d and 3d model
|
|
||||||
:param interp: Interpolation method. None (default)
|
|
||||||
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def search_for_models(self, directory: Path) -> List[Path]:
|
|
||||||
"""
|
|
||||||
Return list of all models found in the designated directory.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def sync_to_config(self):
|
|
||||||
"""
|
|
||||||
Re-read models.yaml, rescan the models directory, and reimport models
|
|
||||||
in the autoimport directories. Call after making changes outside the
|
|
||||||
model manager API.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def collect_cache_stats(self, cache_stats: CacheStats):
|
|
||||||
"""
|
|
||||||
Reset model cache statistics for graph with graph_id.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def commit(self, conf_file: Optional[Path] = None) -> None:
|
|
||||||
"""
|
|
||||||
Write current configuration out to the indicated file.
|
|
||||||
If no conf_file is provided, then replaces the
|
|
||||||
original file/database used to initialize the object.
|
|
||||||
"""
|
|
||||||
pass
|
|
@ -2,15 +2,16 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from types import ModuleType
|
||||||
from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
from invokeai.app.models.exceptions import CanceledException
|
||||||
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
|
|
||||||
from invokeai.backend.model_management import (
|
from invokeai.backend.model_management import (
|
||||||
AddModelResult,
|
AddModelResult,
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
@ -25,12 +26,273 @@ from invokeai.backend.model_management import (
|
|||||||
)
|
)
|
||||||
from invokeai.backend.model_management.model_cache import CacheStats
|
from invokeai.backend.model_management.model_cache import CacheStats
|
||||||
from invokeai.backend.model_management.model_search import FindModels
|
from invokeai.backend.model_management.model_search import FindModels
|
||||||
from invokeai.backend.util import choose_precision, choose_torch_device
|
|
||||||
|
|
||||||
from .model_manager_base import ModelManagerServiceBase
|
from ...backend.util import choose_precision, choose_torch_device
|
||||||
|
from .config import InvokeAIAppConfig
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from invokeai.app.invocations.baseinvocation import InvocationContext
|
from ..invocations.baseinvocation import BaseInvocation, InvocationContext
|
||||||
|
|
||||||
|
|
||||||
|
class ModelManagerServiceBase(ABC):
|
||||||
|
"""Responsible for managing models on disk and in memory"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: InvokeAIAppConfig,
|
||||||
|
logger: ModuleType,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize with the path to the models.yaml config file.
|
||||||
|
Optional parameters are the torch device type, precision, max_models,
|
||||||
|
and sequential_offload boolean. Note that the default device
|
||||||
|
type and precision are set up for a CUDA system running at half precision.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_model(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
|
submodel: Optional[SubModelType] = None,
|
||||||
|
node: Optional[BaseInvocation] = None,
|
||||||
|
context: Optional[InvocationContext] = None,
|
||||||
|
) -> ModelInfo:
|
||||||
|
"""Retrieve the indicated model with name and type.
|
||||||
|
submodel can be used to get a part (such as the vae)
|
||||||
|
of a diffusers pipeline."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def logger(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def model_exists(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
|
) -> bool:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||||
|
"""
|
||||||
|
Given a model name returns a dict-like (OmegaConf) object describing it.
|
||||||
|
Uses the exact format as the omegaconf stanza.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def list_models(self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None) -> dict:
|
||||||
|
"""
|
||||||
|
Return a dict of models in the format:
|
||||||
|
{ model_type1:
|
||||||
|
{ model_name1: {'status': 'active'|'cached'|'not loaded',
|
||||||
|
'model_name' : name,
|
||||||
|
'model_type' : SDModelType,
|
||||||
|
'description': description,
|
||||||
|
'format': 'folder'|'safetensors'|'ckpt'
|
||||||
|
},
|
||||||
|
model_name2: { etc }
|
||||||
|
},
|
||||||
|
model_type2:
|
||||||
|
{ model_name_n: etc
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||||
|
"""
|
||||||
|
Return information about the model using the same format as list_models()
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
|
||||||
|
"""
|
||||||
|
Returns a list of all the model names known.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def add_model(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
|
model_attributes: dict,
|
||||||
|
clobber: bool = False,
|
||||||
|
) -> AddModelResult:
|
||||||
|
"""
|
||||||
|
Update the named model with a dictionary of attributes. Will fail with an
|
||||||
|
assertion error if the name already exists. Pass clobber=True to overwrite.
|
||||||
|
On a successful update, the config will be changed in memory. Will fail
|
||||||
|
with an assertion error if provided attributes are incorrect or
|
||||||
|
the model name is missing. Call commit() to write changes to disk.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def update_model(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
|
model_attributes: dict,
|
||||||
|
) -> AddModelResult:
|
||||||
|
"""
|
||||||
|
Update the named model with a dictionary of attributes. Will fail with a
|
||||||
|
ModelNotFoundException if the name does not already exist.
|
||||||
|
|
||||||
|
On a successful update, the config will be changed in memory. Will fail
|
||||||
|
with an assertion error if provided attributes are incorrect or
|
||||||
|
the model name is missing. Call commit() to write changes to disk.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def del_model(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Delete the named model from configuration. If delete_files is true,
|
||||||
|
then the underlying weight file or diffusers directory will be deleted
|
||||||
|
as well. Call commit() to write to disk.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def rename_model(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
|
new_name: str,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Rename the indicated model.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def list_checkpoint_configs(self) -> List[Path]:
|
||||||
|
"""
|
||||||
|
List the checkpoint config paths from ROOT/configs/stable-diffusion.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def convert_model(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
model_type: Literal[ModelType.Main, ModelType.Vae],
|
||||||
|
) -> AddModelResult:
|
||||||
|
"""
|
||||||
|
Convert a checkpoint file into a diffusers folder, deleting the cached
|
||||||
|
version and deleting the original checkpoint file if it is in the models
|
||||||
|
directory.
|
||||||
|
:param model_name: Name of the model to convert
|
||||||
|
:param base_model: Base model type
|
||||||
|
:param model_type: Type of model ['vae' or 'main']
|
||||||
|
|
||||||
|
This will raise a ValueError unless the model is not a checkpoint. It will
|
||||||
|
also raise a ValueError in the event that there is a similarly-named diffusers
|
||||||
|
directory already in place.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def heuristic_import(
|
||||||
|
self,
|
||||||
|
items_to_import: set[str],
|
||||||
|
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
|
||||||
|
) -> dict[str, AddModelResult]:
|
||||||
|
"""Import a list of paths, repo_ids or URLs. Returns the set of
|
||||||
|
successfully imported items.
|
||||||
|
:param items_to_import: Set of strings corresponding to models to be imported.
|
||||||
|
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
|
||||||
|
|
||||||
|
The prediction type helper is necessary to distinguish between
|
||||||
|
models based on Stable Diffusion 2 Base (requiring
|
||||||
|
SchedulerPredictionType.Epsilson) and Stable Diffusion 768
|
||||||
|
(requiring SchedulerPredictionType.VPrediction). It is
|
||||||
|
generally impossible to do this programmatically, so the
|
||||||
|
prediction_type_helper usually asks the user to choose.
|
||||||
|
|
||||||
|
The result is a set of successfully installed models. Each element
|
||||||
|
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||||
|
that model.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def merge_models(
|
||||||
|
self,
|
||||||
|
model_names: List[str] = Field(
|
||||||
|
default=None, min_items=2, max_items=3, description="List of model names to merge"
|
||||||
|
),
|
||||||
|
base_model: Union[BaseModelType, str] = Field(
|
||||||
|
default=None, description="Base model shared by all models to be merged"
|
||||||
|
),
|
||||||
|
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
|
||||||
|
alpha: Optional[float] = 0.5,
|
||||||
|
interp: Optional[MergeInterpolationMethod] = None,
|
||||||
|
force: Optional[bool] = False,
|
||||||
|
merge_dest_directory: Optional[Path] = None,
|
||||||
|
) -> AddModelResult:
|
||||||
|
"""
|
||||||
|
Merge two to three diffusrs pipeline models and save as a new model.
|
||||||
|
:param model_names: List of 2-3 models to merge
|
||||||
|
:param base_model: Base model to use for all models
|
||||||
|
:param merged_model_name: Name of destination merged model
|
||||||
|
:param alpha: Alpha strength to apply to 2d and 3d model
|
||||||
|
:param interp: Interpolation method. None (default)
|
||||||
|
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def search_for_models(self, directory: Path) -> List[Path]:
|
||||||
|
"""
|
||||||
|
Return list of all models found in the designated directory.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def sync_to_config(self):
|
||||||
|
"""
|
||||||
|
Re-read models.yaml, rescan the models directory, and reimport models
|
||||||
|
in the autoimport directories. Call after making changes outside the
|
||||||
|
model manager API.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def collect_cache_stats(self, cache_stats: CacheStats):
|
||||||
|
"""
|
||||||
|
Reset model cache statistics for graph with graph_id.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def commit(self, conf_file: Optional[Path] = None) -> None:
|
||||||
|
"""
|
||||||
|
Write current configuration out to the indicated file.
|
||||||
|
If no conf_file is provided, then replaces the
|
||||||
|
original file/database used to initialize the object.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
# simple implementation
|
# simple implementation
|
@ -1,7 +1,7 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Extra, Field
|
from pydantic import Field
|
||||||
|
|
||||||
from invokeai.app.util.misc import get_iso_timestamp
|
from invokeai.app.util.misc import get_iso_timestamp
|
||||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||||
@ -24,6 +24,15 @@ class BoardRecord(BaseModelExcludeNull):
|
|||||||
"""The name of the cover image of the board."""
|
"""The name of the cover image of the board."""
|
||||||
|
|
||||||
|
|
||||||
|
class BoardDTO(BoardRecord):
|
||||||
|
"""Deserialized board record with cover image URL and image count."""
|
||||||
|
|
||||||
|
cover_image_name: Optional[str] = Field(description="The name of the board's cover image.")
|
||||||
|
"""The URL of the thumbnail of the most recent image in the board."""
|
||||||
|
image_count: int = Field(description="The number of images in the board.")
|
||||||
|
"""The number of images in the board."""
|
||||||
|
|
||||||
|
|
||||||
def deserialize_board_record(board_dict: dict) -> BoardRecord:
|
def deserialize_board_record(board_dict: dict) -> BoardRecord:
|
||||||
"""Deserializes a board record."""
|
"""Deserializes a board record."""
|
||||||
|
|
||||||
@ -44,29 +53,3 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
|
|||||||
updated_at=updated_at,
|
updated_at=updated_at,
|
||||||
deleted_at=deleted_at,
|
deleted_at=deleted_at,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class BoardChanges(BaseModel, extra=Extra.forbid):
|
|
||||||
board_name: Optional[str] = Field(description="The board's new name.")
|
|
||||||
cover_image_name: Optional[str] = Field(description="The name of the board's new cover image.")
|
|
||||||
|
|
||||||
|
|
||||||
class BoardRecordNotFoundException(Exception):
|
|
||||||
"""Raised when an board record is not found."""
|
|
||||||
|
|
||||||
def __init__(self, message="Board record not found"):
|
|
||||||
super().__init__(message)
|
|
||||||
|
|
||||||
|
|
||||||
class BoardRecordSaveException(Exception):
|
|
||||||
"""Raised when an board record cannot be saved."""
|
|
||||||
|
|
||||||
def __init__(self, message="Board record not saved"):
|
|
||||||
super().__init__(message)
|
|
||||||
|
|
||||||
|
|
||||||
class BoardRecordDeleteException(Exception):
|
|
||||||
"""Raised when an board record cannot be deleted."""
|
|
||||||
|
|
||||||
def __init__(self, message="Board record not deleted"):
|
|
||||||
super().__init__(message)
|
|
@ -1,117 +1,13 @@
|
|||||||
# TODO: Should these excpetions subclass existing python exceptions?
|
|
||||||
import datetime
|
import datetime
|
||||||
from enum import Enum
|
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from pydantic import Extra, Field, StrictBool, StrictStr
|
from pydantic import Extra, Field, StrictBool, StrictStr
|
||||||
|
|
||||||
from invokeai.app.util.metaenum import MetaEnum
|
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||||
from invokeai.app.util.misc import get_iso_timestamp
|
from invokeai.app.util.misc import get_iso_timestamp
|
||||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||||
|
|
||||||
|
|
||||||
class ResourceOrigin(str, Enum, metaclass=MetaEnum):
|
|
||||||
"""The origin of a resource (eg image).
|
|
||||||
|
|
||||||
- INTERNAL: The resource was created by the application.
|
|
||||||
- EXTERNAL: The resource was not created by the application.
|
|
||||||
This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
|
|
||||||
"""
|
|
||||||
|
|
||||||
INTERNAL = "internal"
|
|
||||||
"""The resource was created by the application."""
|
|
||||||
EXTERNAL = "external"
|
|
||||||
"""The resource was not created by the application.
|
|
||||||
This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class InvalidOriginException(ValueError):
|
|
||||||
"""Raised when a provided value is not a valid ResourceOrigin.
|
|
||||||
|
|
||||||
Subclasses `ValueError`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, message="Invalid resource origin."):
|
|
||||||
super().__init__(message)
|
|
||||||
|
|
||||||
|
|
||||||
class ImageCategory(str, Enum, metaclass=MetaEnum):
|
|
||||||
"""The category of an image.
|
|
||||||
|
|
||||||
- GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose.
|
|
||||||
- MASK: The image is a mask image.
|
|
||||||
- CONTROL: The image is a ControlNet control image.
|
|
||||||
- USER: The image is a user-provide image.
|
|
||||||
- OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
GENERAL = "general"
|
|
||||||
"""GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose."""
|
|
||||||
MASK = "mask"
|
|
||||||
"""MASK: The image is a mask image."""
|
|
||||||
CONTROL = "control"
|
|
||||||
"""CONTROL: The image is a ControlNet control image."""
|
|
||||||
USER = "user"
|
|
||||||
"""USER: The image is a user-provide image."""
|
|
||||||
OTHER = "other"
|
|
||||||
"""OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes."""
|
|
||||||
|
|
||||||
|
|
||||||
class InvalidImageCategoryException(ValueError):
|
|
||||||
"""Raised when a provided value is not a valid ImageCategory.
|
|
||||||
|
|
||||||
Subclasses `ValueError`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, message="Invalid image category."):
|
|
||||||
super().__init__(message)
|
|
||||||
|
|
||||||
|
|
||||||
class ImageRecordNotFoundException(Exception):
|
|
||||||
"""Raised when an image record is not found."""
|
|
||||||
|
|
||||||
def __init__(self, message="Image record not found"):
|
|
||||||
super().__init__(message)
|
|
||||||
|
|
||||||
|
|
||||||
class ImageRecordSaveException(Exception):
|
|
||||||
"""Raised when an image record cannot be saved."""
|
|
||||||
|
|
||||||
def __init__(self, message="Image record not saved"):
|
|
||||||
super().__init__(message)
|
|
||||||
|
|
||||||
|
|
||||||
class ImageRecordDeleteException(Exception):
|
|
||||||
"""Raised when an image record cannot be deleted."""
|
|
||||||
|
|
||||||
def __init__(self, message="Image record not deleted"):
|
|
||||||
super().__init__(message)
|
|
||||||
|
|
||||||
|
|
||||||
IMAGE_DTO_COLS = ", ".join(
|
|
||||||
list(
|
|
||||||
map(
|
|
||||||
lambda c: "images." + c,
|
|
||||||
[
|
|
||||||
"image_name",
|
|
||||||
"image_origin",
|
|
||||||
"image_category",
|
|
||||||
"width",
|
|
||||||
"height",
|
|
||||||
"session_id",
|
|
||||||
"node_id",
|
|
||||||
"is_intermediate",
|
|
||||||
"created_at",
|
|
||||||
"updated_at",
|
|
||||||
"deleted_at",
|
|
||||||
"starred",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ImageRecord(BaseModelExcludeNull):
|
class ImageRecord(BaseModelExcludeNull):
|
||||||
"""Deserialized image record without metadata."""
|
"""Deserialized image record without metadata."""
|
||||||
|
|
||||||
@ -170,6 +66,41 @@ class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid):
|
|||||||
"""The image's new `starred` state."""
|
"""The image's new `starred` state."""
|
||||||
|
|
||||||
|
|
||||||
|
class ImageUrlsDTO(BaseModelExcludeNull):
|
||||||
|
"""The URLs for an image and its thumbnail."""
|
||||||
|
|
||||||
|
image_name: str = Field(description="The unique name of the image.")
|
||||||
|
"""The unique name of the image."""
|
||||||
|
image_url: str = Field(description="The URL of the image.")
|
||||||
|
"""The URL of the image."""
|
||||||
|
thumbnail_url: str = Field(description="The URL of the image's thumbnail.")
|
||||||
|
"""The URL of the image's thumbnail."""
|
||||||
|
|
||||||
|
|
||||||
|
class ImageDTO(ImageRecord, ImageUrlsDTO):
|
||||||
|
"""Deserialized image record, enriched for the frontend."""
|
||||||
|
|
||||||
|
board_id: Optional[str] = Field(description="The id of the board the image belongs to, if one exists.")
|
||||||
|
"""The id of the board the image belongs to, if one exists."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def image_record_to_dto(
|
||||||
|
image_record: ImageRecord,
|
||||||
|
image_url: str,
|
||||||
|
thumbnail_url: str,
|
||||||
|
board_id: Optional[str],
|
||||||
|
) -> ImageDTO:
|
||||||
|
"""Converts an image record to an image DTO."""
|
||||||
|
return ImageDTO(
|
||||||
|
**image_record.dict(),
|
||||||
|
image_url=image_url,
|
||||||
|
thumbnail_url=thumbnail_url,
|
||||||
|
board_id=board_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
||||||
"""Deserializes an image record."""
|
"""Deserializes an image record."""
|
||||||
|
|
@ -1,11 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
|
|
||||||
|
|
||||||
class NameServiceBase(ABC):
|
|
||||||
"""Low-level service responsible for naming resources (images, latents, etc)."""
|
|
||||||
|
|
||||||
# TODO: Add customizable naming schemes
|
|
||||||
@abstractmethod
|
|
||||||
def create_image_name(self) -> str:
|
|
||||||
"""Creates a name for an image."""
|
|
||||||
pass
|
|
@ -1,8 +0,0 @@
|
|||||||
from enum import Enum, EnumMeta
|
|
||||||
|
|
||||||
|
|
||||||
class ResourceType(str, Enum, metaclass=EnumMeta):
|
|
||||||
"""Enum for resource types."""
|
|
||||||
|
|
||||||
IMAGE = "image"
|
|
||||||
LATENT = "latent"
|
|
@ -1,13 +0,0 @@
|
|||||||
from invokeai.app.util.misc import uuid_string
|
|
||||||
|
|
||||||
from .names_base import NameServiceBase
|
|
||||||
|
|
||||||
|
|
||||||
class SimpleNameService(NameServiceBase):
|
|
||||||
"""Creates image names from UUIDs."""
|
|
||||||
|
|
||||||
# TODO: Add customizable naming schemes
|
|
||||||
def create_image_name(self) -> str:
|
|
||||||
uuid_str = uuid_string()
|
|
||||||
filename = f"{uuid_str}.png"
|
|
||||||
return filename
|
|
@ -4,12 +4,12 @@ from threading import BoundedSemaphore, Event, Thread
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.invocations.baseinvocation import InvocationContext
|
|
||||||
from invokeai.app.services.invocation_queue.invocation_queue_common import InvocationQueueItem
|
|
||||||
|
|
||||||
from ..invoker import Invoker
|
from ..invocations.baseinvocation import InvocationContext
|
||||||
from .invocation_processor_base import InvocationProcessorABC
|
from ..models.exceptions import CanceledException
|
||||||
from .invocation_processor_common import CanceledException
|
from .invocation_queue import InvocationQueueItem
|
||||||
|
from .invocation_stats import InvocationStatsServiceBase
|
||||||
|
from .invoker import InvocationProcessorABC, Invoker
|
||||||
|
|
||||||
|
|
||||||
class DefaultInvocationProcessor(InvocationProcessorABC):
|
class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||||
@ -37,6 +37,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
def __process(self, stop_event: Event):
|
def __process(self, stop_event: Event):
|
||||||
try:
|
try:
|
||||||
self.__threadLimit.acquire()
|
self.__threadLimit.acquire()
|
||||||
|
statistics: InvocationStatsServiceBase = self.__invoker.services.performance_statistics
|
||||||
queue_item: Optional[InvocationQueueItem] = None
|
queue_item: Optional[InvocationQueueItem] = None
|
||||||
|
|
||||||
while not stop_event.is_set():
|
while not stop_event.is_set():
|
||||||
@ -96,7 +97,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
# Invoke
|
# Invoke
|
||||||
try:
|
try:
|
||||||
graph_id = graph_execution_state.id
|
graph_id = graph_execution_state.id
|
||||||
with self.__invoker.services.performance_statistics.collect_stats(invocation, graph_id):
|
model_manager = self.__invoker.services.model_manager
|
||||||
|
with statistics.collect_stats(invocation, graph_id, model_manager):
|
||||||
# use the internal invoke_internal(), which wraps the node's invoke() method,
|
# use the internal invoke_internal(), which wraps the node's invoke() method,
|
||||||
# which handles a few things:
|
# which handles a few things:
|
||||||
# - nodes that require a value, but get it only from a connection
|
# - nodes that require a value, but get it only from a connection
|
||||||
@ -131,13 +133,13 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
source_node_id=source_node_id,
|
source_node_id=source_node_id,
|
||||||
result=outputs.dict(),
|
result=outputs.dict(),
|
||||||
)
|
)
|
||||||
self.__invoker.services.performance_statistics.log_stats()
|
statistics.log_stats()
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
except CanceledException:
|
except CanceledException:
|
||||||
self.__invoker.services.performance_statistics.reset_stats(graph_execution_state.id)
|
statistics.reset_stats(graph_execution_state.id)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -162,7 +164,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
|||||||
error_type=e.__class__.__name__,
|
error_type=e.__class__.__name__,
|
||||||
error=error,
|
error=error,
|
||||||
)
|
)
|
||||||
self.__invoker.services.performance_statistics.reset_stats(graph_execution_state.id)
|
statistics.reset_stats(graph_execution_state.id)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Check queue to see if this is canceled, and skip if so
|
# Check queue to see if this is canceled, and skip if so
|
31
invokeai/app/services/resource_name.py
Normal file
31
invokeai/app/services/resource_name.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from enum import Enum, EnumMeta
|
||||||
|
|
||||||
|
from invokeai.app.util.misc import uuid_string
|
||||||
|
|
||||||
|
|
||||||
|
class ResourceType(str, Enum, metaclass=EnumMeta):
|
||||||
|
"""Enum for resource types."""
|
||||||
|
|
||||||
|
IMAGE = "image"
|
||||||
|
LATENT = "latent"
|
||||||
|
|
||||||
|
|
||||||
|
class NameServiceBase(ABC):
|
||||||
|
"""Low-level service responsible for naming resources (images, latents, etc)."""
|
||||||
|
|
||||||
|
# TODO: Add customizable naming schemes
|
||||||
|
@abstractmethod
|
||||||
|
def create_image_name(self) -> str:
|
||||||
|
"""Creates a name for an image."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleNameService(NameServiceBase):
|
||||||
|
"""Creates image names from UUIDs."""
|
||||||
|
|
||||||
|
# TODO: Add customizable naming schemes
|
||||||
|
def create_image_name(self) -> str:
|
||||||
|
uuid_str = uuid_string()
|
||||||
|
filename = f"{uuid_str}.png"
|
||||||
|
return filename
|
@ -7,7 +7,7 @@ from typing import Optional
|
|||||||
from fastapi_events.handlers.local import local_handler
|
from fastapi_events.handlers.local import local_handler
|
||||||
from fastapi_events.typing import Event as FastAPIEvent
|
from fastapi_events.typing import Event as FastAPIEvent
|
||||||
|
|
||||||
from invokeai.app.services.events.events_base import EventServiceBase
|
from invokeai.app.services.events import EventServiceBase
|
||||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||||
|
|
||||||
from ..invoker import Invoker
|
from ..invoker import Invoker
|
||||||
@ -97,6 +97,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
resume_event.set()
|
resume_event.set()
|
||||||
self.__threadLimit.acquire()
|
self.__threadLimit.acquire()
|
||||||
queue_item: Optional[SessionQueueItem] = None
|
queue_item: Optional[SessionQueueItem] = None
|
||||||
|
self.__invoker.services.logger
|
||||||
while not stop_event.is_set():
|
while not stop_event.is_set():
|
||||||
poll_now_event.clear()
|
poll_now_event.clear()
|
||||||
try:
|
try:
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from invokeai.app.services.graph import Graph
|
||||||
from invokeai.app.services.session_queue.session_queue_common import (
|
from invokeai.app.services.session_queue.session_queue_common import (
|
||||||
QUEUE_ITEM_STATUS,
|
QUEUE_ITEM_STATUS,
|
||||||
Batch,
|
Batch,
|
||||||
@ -17,8 +18,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
|||||||
SessionQueueItemDTO,
|
SessionQueueItemDTO,
|
||||||
SessionQueueStatus,
|
SessionQueueStatus,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.shared.graph import Graph
|
from invokeai.app.services.shared.models import CursorPaginatedResults
|
||||||
from invokeai.app.services.shared.pagination import CursorPaginatedResults
|
|
||||||
|
|
||||||
|
|
||||||
class SessionQueueBase(ABC):
|
class SessionQueueBase(ABC):
|
||||||
|
@ -7,7 +7,7 @@ from pydantic import BaseModel, Field, StrictStr, parse_raw_as, root_validator,
|
|||||||
from pydantic.json import pydantic_encoder
|
from pydantic.json import pydantic_encoder
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||||
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, NodeNotFoundError
|
from invokeai.app.services.graph import Graph, GraphExecutionState, NodeNotFoundError
|
||||||
from invokeai.app.util.misc import uuid_string
|
from invokeai.app.util.misc import uuid_string
|
||||||
|
|
||||||
# region Errors
|
# region Errors
|
||||||
|
@ -5,7 +5,8 @@ from typing import Optional, Union, cast
|
|||||||
from fastapi_events.handlers.local import local_handler
|
from fastapi_events.handlers.local import local_handler
|
||||||
from fastapi_events.typing import Event as FastAPIEvent
|
from fastapi_events.typing import Event as FastAPIEvent
|
||||||
|
|
||||||
from invokeai.app.services.events.events_base import EventServiceBase
|
from invokeai.app.services.events import EventServiceBase
|
||||||
|
from invokeai.app.services.graph import Graph
|
||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
|
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
|
||||||
from invokeai.app.services.session_queue.session_queue_common import (
|
from invokeai.app.services.session_queue.session_queue_common import (
|
||||||
@ -28,9 +29,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
|||||||
calc_session_count,
|
calc_session_count,
|
||||||
prepare_values_to_insert,
|
prepare_values_to_insert,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.shared.graph import Graph
|
from invokeai.app.services.shared.models import CursorPaginatedResults
|
||||||
from invokeai.app.services.shared.pagination import CursorPaginatedResults
|
|
||||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
|
||||||
|
|
||||||
|
|
||||||
class SqliteSessionQueue(SessionQueueBase):
|
class SqliteSessionQueue(SessionQueueBase):
|
||||||
@ -46,11 +45,13 @@ class SqliteSessionQueue(SessionQueueBase):
|
|||||||
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event)
|
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event)
|
||||||
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
|
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
|
||||||
|
|
||||||
def __init__(self, db: SqliteDatabase) -> None:
|
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.__lock = db.lock
|
self.__conn = conn
|
||||||
self.__conn = db.conn
|
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||||
|
self.__conn.row_factory = sqlite3.Row
|
||||||
self.__cursor = self.__conn.cursor()
|
self.__cursor = self.__conn.cursor()
|
||||||
|
self.__lock = lock
|
||||||
self._create_tables()
|
self._create_tables()
|
||||||
|
|
||||||
def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool:
|
def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool:
|
||||||
|
14
invokeai/app/services/shared/models.py
Normal file
14
invokeai/app/services/shared/models.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from pydantic.generics import GenericModel
|
||||||
|
|
||||||
|
GenericBaseModel = TypeVar("GenericBaseModel", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
|
class CursorPaginatedResults(GenericModel, Generic[GenericBaseModel]):
|
||||||
|
"""Cursor-paginated results"""
|
||||||
|
|
||||||
|
limit: int = Field(..., description="Limit of items to get")
|
||||||
|
has_more: bool = Field(..., description="Whether there are more items available")
|
||||||
|
items: list[GenericBaseModel] = Field(..., description="Items")
|
@ -1,42 +0,0 @@
|
|||||||
from typing import Generic, TypeVar
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from pydantic.generics import GenericModel
|
|
||||||
|
|
||||||
GenericBaseModel = TypeVar("GenericBaseModel", bound=BaseModel)
|
|
||||||
|
|
||||||
|
|
||||||
class CursorPaginatedResults(GenericModel, Generic[GenericBaseModel]):
|
|
||||||
"""
|
|
||||||
Cursor-paginated results
|
|
||||||
Generic must be a Pydantic model
|
|
||||||
"""
|
|
||||||
|
|
||||||
limit: int = Field(..., description="Limit of items to get")
|
|
||||||
has_more: bool = Field(..., description="Whether there are more items available")
|
|
||||||
items: list[GenericBaseModel] = Field(..., description="Items")
|
|
||||||
|
|
||||||
|
|
||||||
class OffsetPaginatedResults(GenericModel, Generic[GenericBaseModel]):
|
|
||||||
"""
|
|
||||||
Offset-paginated results
|
|
||||||
Generic must be a Pydantic model
|
|
||||||
"""
|
|
||||||
|
|
||||||
limit: int = Field(description="Limit of items to get")
|
|
||||||
offset: int = Field(description="Offset from which to retrieve items")
|
|
||||||
total: int = Field(description="Total number of items in result")
|
|
||||||
items: list[GenericBaseModel] = Field(description="Items")
|
|
||||||
|
|
||||||
|
|
||||||
class PaginatedResults(GenericModel, Generic[GenericBaseModel]):
|
|
||||||
"""
|
|
||||||
Paginated results
|
|
||||||
Generic must be a Pydantic model
|
|
||||||
"""
|
|
||||||
|
|
||||||
page: int = Field(description="Current Page")
|
|
||||||
pages: int = Field(description="Total number of pages")
|
|
||||||
per_page: int = Field(description="Number of items per page")
|
|
||||||
total: int = Field(description="Total number of items in result")
|
|
||||||
items: list[GenericBaseModel] = Field(description="Items")
|
|
@ -1,48 +0,0 @@
|
|||||||
import sqlite3
|
|
||||||
import threading
|
|
||||||
from logging import Logger
|
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
|
||||||
|
|
||||||
sqlite_memory = ":memory:"
|
|
||||||
|
|
||||||
|
|
||||||
class SqliteDatabase:
|
|
||||||
conn: sqlite3.Connection
|
|
||||||
lock: threading.Lock
|
|
||||||
_logger: Logger
|
|
||||||
_config: InvokeAIAppConfig
|
|
||||||
|
|
||||||
def __init__(self, config: InvokeAIAppConfig, logger: Logger):
|
|
||||||
self._logger = logger
|
|
||||||
self._config = config
|
|
||||||
|
|
||||||
if self._config.use_memory_db:
|
|
||||||
location = sqlite_memory
|
|
||||||
logger.info("Using in-memory database")
|
|
||||||
else:
|
|
||||||
db_path = self._config.db_path
|
|
||||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
location = str(db_path)
|
|
||||||
self._logger.info(f"Using database at {location}")
|
|
||||||
|
|
||||||
self.conn = sqlite3.connect(location, check_same_thread=False)
|
|
||||||
self.lock = threading.Lock()
|
|
||||||
self.conn.row_factory = sqlite3.Row
|
|
||||||
|
|
||||||
if self._config.log_sql:
|
|
||||||
self.conn.set_trace_callback(self._logger.debug)
|
|
||||||
|
|
||||||
self.conn.execute("PRAGMA foreign_keys = ON;")
|
|
||||||
|
|
||||||
def clean(self) -> None:
|
|
||||||
try:
|
|
||||||
self.lock.acquire()
|
|
||||||
self.conn.execute("VACUUM;")
|
|
||||||
self.conn.commit()
|
|
||||||
self._logger.info("Cleaned database")
|
|
||||||
except Exception as e:
|
|
||||||
self._logger.error(f"Error cleaning database: {e}")
|
|
||||||
raise e
|
|
||||||
finally:
|
|
||||||
self.lock.release()
|
|
@ -4,13 +4,12 @@ from typing import Generic, Optional, TypeVar, get_args
|
|||||||
|
|
||||||
from pydantic import BaseModel, parse_raw_as
|
from pydantic import BaseModel, parse_raw_as
|
||||||
|
|
||||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
from .item_storage import ItemStorageABC, PaginatedResults
|
||||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
|
||||||
|
|
||||||
from .item_storage_base import ItemStorageABC
|
|
||||||
|
|
||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
sqlite_memory = ":memory:"
|
||||||
|
|
||||||
|
|
||||||
class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
||||||
_table_name: str
|
_table_name: str
|
||||||
@ -19,13 +18,13 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
_id_field: str
|
_id_field: str
|
||||||
_lock: threading.Lock
|
_lock: threading.Lock
|
||||||
|
|
||||||
def __init__(self, db: SqliteDatabase, table_name: str, id_field: str = "id"):
|
def __init__(self, conn: sqlite3.Connection, table_name: str, lock: threading.Lock, id_field: str = "id"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self._lock = db.lock
|
|
||||||
self._conn = db.conn
|
|
||||||
self._table_name = table_name
|
self._table_name = table_name
|
||||||
self._id_field = id_field # TODO: validate that T has this field
|
self._id_field = id_field # TODO: validate that T has this field
|
||||||
|
self._lock = lock
|
||||||
|
self._conn = conn
|
||||||
self._cursor = self._conn.cursor()
|
self._cursor = self._conn.cursor()
|
||||||
|
|
||||||
self._create_table()
|
self._create_table()
|
||||||
@ -45,8 +44,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
|
|||||||
self._lock.release()
|
self._lock.release()
|
||||||
|
|
||||||
def _parse_item(self, item: str) -> T:
|
def _parse_item(self, item: str) -> T:
|
||||||
# __orig_class__ is technically an implementation detail of the typing module, not a supported API
|
item_type = get_args(self.__orig_class__)[0]
|
||||||
item_type = get_args(self.__orig_class__)[0] # type: ignore
|
|
||||||
return parse_raw_as(item_type, item)
|
return parse_raw_as(item_type, item)
|
||||||
|
|
||||||
def set(self, item: T):
|
def set(self, item: T):
|
3
invokeai/app/services/thread.py
Normal file
3
invokeai/app/services/thread.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
import threading
|
||||||
|
|
||||||
|
lock = threading.Lock()
|
@ -1,6 +1,14 @@
|
|||||||
import os
|
import os
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
from .urls_base import UrlServiceBase
|
|
||||||
|
class UrlServiceBase(ABC):
|
||||||
|
"""Responsible for building URLs for resources."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||||
|
"""Gets the URL for an image or thumbnail."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class LocalUrlService(UrlServiceBase):
|
class LocalUrlService(UrlServiceBase):
|
@ -1,10 +0,0 @@
|
|||||||
from abc import ABC, abstractmethod
|
|
||||||
|
|
||||||
|
|
||||||
class UrlServiceBase(ABC):
|
|
||||||
"""Responsible for building URLs for resources."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
|
|
||||||
"""Gets the URL for an image or thumbnail."""
|
|
||||||
pass
|
|
@ -3,7 +3,7 @@ from typing import Optional
|
|||||||
|
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from invokeai.app.services.shared.graph import Edge
|
from invokeai.app.services.graph import Edge
|
||||||
|
|
||||||
|
|
||||||
def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]:
|
def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]:
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException, ProgressImage
|
from invokeai.app.models.exceptions import CanceledException
|
||||||
|
from invokeai.app.models.image import ProgressImage
|
||||||
|
|
||||||
from ...backend.model_management.models import BaseModelType
|
from ...backend.model_management.models import BaseModelType
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user