Mirror of https://github.com/invoke-ai/InvokeAI · synced 2024-08-30 20:32:17 +00:00
feat: refactor services folder/module structure
Refactor services folder/module structure.

**Motivation**

While working on our services I've repeatedly encountered circular imports and a general lack of clarity regarding where to put things. The structure introduced here goes a long way towards resolving those issues, setting us up for a clean structure going forward.

**Services**

Services are now in their own folder with a few files:

- `services/{service_name}/__init__.py`: init as needed, mostly empty for now
- `services/{service_name}/{service_name}_base.py`: the base class for the service
- `services/{service_name}/{service_name}_{impl_type}.py`: the default concrete implementation of the service - typically one of `sqlite`, `default`, or `memory`
- `services/{service_name}/{service_name}_common.py`: any common items - models, exceptions, utilities, etc.

Though it's a bit verbose to have the service name both as the folder name and as the prefix for its files, I found it _extremely_ confusing to have all of the base classes simply named `base.py`. So, at the cost of some verbosity when importing things, the service name is included in each filename.

There are some minor logic changes. For example, in `InvocationProcessor`, instead of assigning the model manager service to a variable to be used later in the file, the service is now used directly via the `Invoker`.

**Shared**

Things that are used across disparate services live in `services/shared/`:

- `default_graphs.py`: previously in `services/`
- `graph.py`: previously in `services/`
- `pagination.py`: generic pagination models used in a few services
- `sqlite.py`: the `SqliteDatabase` class and other sqlite-specific things
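The pattern is easiest to see on one of the services touched here. Below is a minimal sketch of the resulting import layout, using the board-record service as the example; every module referenced appears in the diff that follows:

```python
# The abstract interface for a service lives in <service>_base.py:
from invokeai.app.services.board_records.board_records_base import BoardRecordStorageBase

# The concrete implementation lives in <service>_<impl_type>.py (sqlite here):
from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage

# Shared models, exceptions and helpers for the service live in <service>_common.py:
from invokeai.app.services.board_records.board_records_common import (
    BoardChanges,
    BoardRecord,
    BoardRecordNotFoundException,
)

# Cross-cutting pieces used by many services live under services/shared/:
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite import SqliteDatabase
```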
parent 88bee96ca3
commit 402cf9b0ee
@@ -2,33 +2,34 @@
 from logging import Logger
 
-from invokeai.app.services.board_image_record_storage import SqliteBoardImageRecordStorage
-from invokeai.app.services.board_images import BoardImagesService
-from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage
-from invokeai.app.services.boards import BoardService
-from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
-from invokeai.app.services.images import ImageService
-from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
-from invokeai.app.services.resource_name import SimpleNameService
-from invokeai.app.services.session_processor.session_processor_default import DefaultSessionProcessor
-from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
-from invokeai.app.services.shared.sqlite import SqliteDatabase
-from invokeai.app.services.urls import LocalUrlService
 from invokeai.backend.util.logging import InvokeAILogger
 from invokeai.version.invokeai_version import __version__
 
-from ..services.default_graphs import create_system_graphs
-from ..services.graph import GraphExecutionState, LibraryGraph
-from ..services.image_file_storage import DiskImageFileStorage
-from ..services.invocation_queue import MemoryInvocationQueue
+from ..services.board_image_records.board_image_records_sqlite import SqliteBoardImageRecordStorage
+from ..services.board_images.board_images_default import BoardImagesService
+from ..services.board_records.board_records_sqlite import SqliteBoardRecordStorage
+from ..services.boards.boards_default import BoardService
+from ..services.config import InvokeAIAppConfig
+from ..services.image_files.image_files_disk import DiskImageFileStorage
+from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage
+from ..services.images.images_default import ImageService
+from ..services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
+from ..services.invocation_processor.invocation_processor_default import DefaultInvocationProcessor
+from ..services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue
 from ..services.invocation_services import InvocationServices
-from ..services.invocation_stats import InvocationStatsService
+from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
 from ..services.invoker import Invoker
-from ..services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
-from ..services.model_manager_service import ModelManagerService
-from ..services.processor import DefaultInvocationProcessor
-from ..services.sqlite import SqliteItemStorage
+from ..services.item_storage.item_storage_sqlite import SqliteItemStorage
+from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage
+from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage
+from ..services.model_manager.model_manager_default import ModelManagerService
+from ..services.names.names_default import SimpleNameService
+from ..services.session_processor.session_processor_default import DefaultSessionProcessor
+from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
+from ..services.shared.default_graphs import create_system_graphs
+from ..services.shared.graph import GraphExecutionState, LibraryGraph
+from ..services.shared.sqlite import SqliteDatabase
+from ..services.urls.urls_default import LocalUrlService
 from .events import FastAPIEventService
 
 
@@ -7,7 +7,7 @@ from typing import Any
 from fastapi_events.dispatcher import dispatch
 
-from ..services.events import EventServiceBase
+from ..services.events.events_base import EventServiceBase
 
 
 class FastAPIEventService(EventServiceBase):
@@ -4,8 +4,8 @@ from fastapi import Body, HTTPException, Path, Query
 from fastapi.routing import APIRouter
 from pydantic import BaseModel, Field
 
-from invokeai.app.services.board_record_storage import BoardChanges
-from invokeai.app.services.models.board_record import BoardDTO
+from invokeai.app.services.board_records.board_records_common import BoardChanges
+from invokeai.app.services.boards.boards_common import BoardDTO
 from invokeai.app.services.shared.pagination import OffsetPaginatedResults
 
 from ..dependencies import ApiDependencies
@@ -8,8 +8,8 @@ from PIL import Image
 from pydantic import BaseModel, Field
 
 from invokeai.app.invocations.metadata import ImageMetadata
-from invokeai.app.models.image import ImageCategory, ResourceOrigin
-from invokeai.app.services.models.image_record import ImageDTO, ImageRecordChanges, ImageUrlsDTO
+from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
+from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
 from invokeai.app.services.shared.pagination import OffsetPaginatedResults
 
 from ..dependencies import ApiDependencies
@@ -18,9 +18,9 @@ from invokeai.app.services.session_queue.session_queue_common import (
     SessionQueueItemDTO,
     SessionQueueStatus,
 )
+from invokeai.app.services.shared.graph import Graph
 from invokeai.app.services.shared.pagination import CursorPaginatedResults
 
-from ...services.graph import Graph
 from ..dependencies import ApiDependencies
 
 session_queue_router = APIRouter(prefix="/v1/queue", tags=["queue"])
@@ -11,7 +11,7 @@ from invokeai.app.services.shared.pagination import PaginatedResults
 # Importing * is bad karma but needed here for node detection
 from ...invocations import *  # noqa: F401 F403
 from ...invocations.baseinvocation import BaseInvocation
-from ...services.graph import Edge, EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError
+from ...services.shared.graph import Edge, EdgeConnection, Graph, GraphExecutionState, NodeAlreadyExecutedError
 from ..dependencies import ApiDependencies
 
 session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
@@ -5,7 +5,7 @@ from fastapi_events.handlers.local import local_handler
 from fastapi_events.typing import Event
 from socketio import ASGIApp, AsyncServer
 
-from ..services.events import EventServiceBase
+from ..services.events.events_base import EventServiceBase
 
 
 class SocketIO:
@@ -28,7 +28,7 @@ from pydantic import BaseModel, Field, validator
 from pydantic.fields import ModelField, Undefined
 from pydantic.typing import NoArgAnyCallable
 
-from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
+from invokeai.app.services.config.config_default import InvokeAIAppConfig
 
 if TYPE_CHECKING:
     from ..services.invocation_services import InvocationServices
@@ -27,9 +27,9 @@ from PIL import Image
 from pydantic import BaseModel, Field, validator
 
 from invokeai.app.invocations.primitives import ImageField, ImageOutput
+from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
 
 from ...backend.model_management import BaseModelType
-from ..models.image import ImageCategory, ResourceOrigin
 from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
@@ -6,7 +6,7 @@ import numpy
 from PIL import Image, ImageOps
 
 from invokeai.app.invocations.primitives import ImageField, ImageOutput
-from invokeai.app.models.image import ImageCategory, ResourceOrigin
+from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
 
 from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
 
@@ -9,10 +9,10 @@ from PIL import Image, ImageChops, ImageFilter, ImageOps
 
 from invokeai.app.invocations.metadata import CoreMetadata
 from invokeai.app.invocations.primitives import BoardField, ColorField, ImageField, ImageOutput
+from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
 from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
 from invokeai.backend.image_util.safety_checker import SafetyChecker
 
-from ..models.image import ImageCategory, ResourceOrigin
 from .baseinvocation import BaseInvocation, FieldDescriptions, Input, InputField, InvocationContext, invocation
 
 
@@ -7,12 +7,12 @@ import numpy as np
 from PIL import Image, ImageOps
 
 from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput
+from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
 from invokeai.app.util.misc import SEED_MAX, get_random_seed
 from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
 from invokeai.backend.image_util.lama import LaMA
 from invokeai.backend.image_util.patchmatch import PatchMatch
 
-from ..models.image import ImageCategory, ResourceOrigin
 from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
 from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES
 
@@ -34,6 +34,7 @@ from invokeai.app.invocations.primitives import (
     build_latents_output,
 )
 from invokeai.app.invocations.t2i_adapter import T2IAdapterField
+from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
 from invokeai.app.util.controlnet_utils import prepare_control_image
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
@@ -54,7 +55,6 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
 from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
 from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
 from ...backend.util.devices import choose_precision, choose_torch_device
-from ..models.image import ImageCategory, ResourceOrigin
 from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
@@ -14,13 +14,13 @@ from tqdm import tqdm
 
 from invokeai.app.invocations.metadata import CoreMetadata
 from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput
+from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
 from invokeai.backend import BaseModelType, ModelType, SubModelType
 
 from ...backend.model_management import ONNXModelPatcher
 from ...backend.stable_diffusion import PipelineIntermediateState
 from ...backend.util import choose_torch_device
-from ..models.image import ImageCategory, ResourceOrigin
 from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
@@ -10,7 +10,7 @@ from PIL import Image
 from realesrgan import RealESRGANer
 
 from invokeai.app.invocations.primitives import ImageField, ImageOutput
-from invokeai.app.models.image import ImageCategory, ResourceOrigin
+from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
 from invokeai.backend.util.devices import choose_torch_device
 
 from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
@@ -1,4 +0,0 @@
-class CanceledException(Exception):
-    """Execution canceled by user."""
-
-    pass
@@ -1,71 +0,0 @@
-from enum import Enum
-
-from pydantic import BaseModel, Field
-
-from invokeai.app.util.metaenum import MetaEnum
-
-
-class ProgressImage(BaseModel):
-    """The progress image sent intermittently during processing"""
-
-    width: int = Field(description="The effective width of the image in pixels")
-    height: int = Field(description="The effective height of the image in pixels")
-    dataURL: str = Field(description="The image data as a b64 data URL")
-
-
-class ResourceOrigin(str, Enum, metaclass=MetaEnum):
-    """The origin of a resource (eg image).
-
-    - INTERNAL: The resource was created by the application.
-    - EXTERNAL: The resource was not created by the application.
-    This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
-    """
-
-    INTERNAL = "internal"
-    """The resource was created by the application."""
-    EXTERNAL = "external"
-    """The resource was not created by the application.
-    This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
-    """
-
-
-class InvalidOriginException(ValueError):
-    """Raised when a provided value is not a valid ResourceOrigin.
-
-    Subclasses `ValueError`.
-    """
-
-    def __init__(self, message="Invalid resource origin."):
-        super().__init__(message)
-
-
-class ImageCategory(str, Enum, metaclass=MetaEnum):
-    """The category of an image.
-
-    - GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose.
-    - MASK: The image is a mask image.
-    - CONTROL: The image is a ControlNet control image.
-    - USER: The image is a user-provide image.
-    - OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes.
-    """
-
-    GENERAL = "general"
-    """GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose."""
-    MASK = "mask"
-    """MASK: The image is a mask image."""
-    CONTROL = "control"
-    """CONTROL: The image is a ControlNet control image."""
-    USER = "user"
-    """USER: The image is a user-provide image."""
-    OTHER = "other"
-    """OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes."""
-
-
-class InvalidImageCategoryException(ValueError):
-    """Raised when a provided value is not a valid ImageCategory.
-
-    Subclasses `ValueError`.
-    """
-
-    def __init__(self, message="Invalid image category."):
-        super().__init__(message)
@@ -0,0 +1,47 @@
+from abc import ABC, abstractmethod
+from typing import Optional
+
+
+class BoardImageRecordStorageBase(ABC):
+    """Abstract base class for the one-to-many board-image relationship record storage."""
+
+    @abstractmethod
+    def add_image_to_board(
+        self,
+        board_id: str,
+        image_name: str,
+    ) -> None:
+        """Adds an image to a board."""
+        pass
+
+    @abstractmethod
+    def remove_image_from_board(
+        self,
+        image_name: str,
+    ) -> None:
+        """Removes an image from a board."""
+        pass
+
+    @abstractmethod
+    def get_all_board_image_names_for_board(
+        self,
+        board_id: str,
+    ) -> list[str]:
+        """Gets all board images for a board, as a list of the image names."""
+        pass
+
+    @abstractmethod
+    def get_board_for_image(
+        self,
+        image_name: str,
+    ) -> Optional[str]:
+        """Gets an image's board id, if it has one."""
+        pass
+
+    @abstractmethod
+    def get_image_count_for_board(
+        self,
+        board_id: str,
+    ) -> int:
+        """Gets the number of images for a board."""
+        pass
@@ -1,56 +1,12 @@
 import sqlite3
 import threading
-from abc import ABC, abstractmethod
 from typing import Optional, cast
 
-from invokeai.app.services.models.image_record import ImageRecord, deserialize_image_record
-from invokeai.app.services.shared.sqlite import SqliteDatabase
+from invokeai.app.services.image_records.image_records_common import ImageRecord, deserialize_image_record
 from invokeai.app.services.shared.pagination import OffsetPaginatedResults
+from invokeai.app.services.shared.sqlite import SqliteDatabase
 
-
-class BoardImageRecordStorageBase(ABC):
-    """Abstract base class for the one-to-many board-image relationship record storage."""
-
-    @abstractmethod
-    def add_image_to_board(
-        self,
-        board_id: str,
-        image_name: str,
-    ) -> None:
-        """Adds an image to a board."""
-        pass
-
-    @abstractmethod
-    def remove_image_from_board(
-        self,
-        image_name: str,
-    ) -> None:
-        """Removes an image from a board."""
-        pass
-
-    @abstractmethod
-    def get_all_board_image_names_for_board(
-        self,
-        board_id: str,
-    ) -> list[str]:
-        """Gets all board images for a board, as a list of the image names."""
-        pass
-
-    @abstractmethod
-    def get_board_for_image(
-        self,
-        image_name: str,
-    ) -> Optional[str]:
-        """Gets an image's board id, if it has one."""
-        pass
-
-    @abstractmethod
-    def get_image_count_for_board(
-        self,
-        board_id: str,
-    ) -> int:
-        """Gets the number of images for a board."""
-        pass
+from .board_image_records_base import BoardImageRecordStorageBase
 
 
 class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
@@ -1,85 +0,0 @@
-from abc import ABC, abstractmethod
-from typing import Optional
-
-from invokeai.app.services.board_record_storage import BoardRecord
-from invokeai.app.services.invoker import Invoker
-from invokeai.app.services.models.board_record import BoardDTO
-
-
-class BoardImagesServiceABC(ABC):
-    """High-level service for board-image relationship management."""
-
-    @abstractmethod
-    def add_image_to_board(
-        self,
-        board_id: str,
-        image_name: str,
-    ) -> None:
-        """Adds an image to a board."""
-        pass
-
-    @abstractmethod
-    def remove_image_from_board(
-        self,
-        image_name: str,
-    ) -> None:
-        """Removes an image from a board."""
-        pass
-
-    @abstractmethod
-    def get_all_board_image_names_for_board(
-        self,
-        board_id: str,
-    ) -> list[str]:
-        """Gets all board images for a board, as a list of the image names."""
-        pass
-
-    @abstractmethod
-    def get_board_for_image(
-        self,
-        image_name: str,
-    ) -> Optional[str]:
-        """Gets an image's board id, if it has one."""
-        pass
-
-
-class BoardImagesService(BoardImagesServiceABC):
-    __invoker: Invoker
-
-    def start(self, invoker: Invoker) -> None:
-        self.__invoker = invoker
-
-    def add_image_to_board(
-        self,
-        board_id: str,
-        image_name: str,
-    ) -> None:
-        self.__invoker.services.board_image_records.add_image_to_board(board_id, image_name)
-
-    def remove_image_from_board(
-        self,
-        image_name: str,
-    ) -> None:
-        self.__invoker.services.board_image_records.remove_image_from_board(image_name)
-
-    def get_all_board_image_names_for_board(
-        self,
-        board_id: str,
-    ) -> list[str]:
-        return self.__invoker.services.board_image_records.get_all_board_image_names_for_board(board_id)
-
-    def get_board_for_image(
-        self,
-        image_name: str,
-    ) -> Optional[str]:
-        board_id = self.__invoker.services.board_image_records.get_board_for_image(image_name)
-        return board_id
-
-
-def board_record_to_dto(board_record: BoardRecord, cover_image_name: Optional[str], image_count: int) -> BoardDTO:
-    """Converts a board record to a board DTO."""
-    return BoardDTO(
-        **board_record.dict(exclude={"cover_image_name"}),
-        cover_image_name=cover_image_name,
-        image_count=image_count,
-    )
invokeai/app/services/board_images/__init__.py (new file, 0 lines)
invokeai/app/services/board_images/board_images_base.py (new file, 39 lines)
@@ -0,0 +1,39 @@
+from abc import ABC, abstractmethod
+from typing import Optional
+
+
+class BoardImagesServiceABC(ABC):
+    """High-level service for board-image relationship management."""
+
+    @abstractmethod
+    def add_image_to_board(
+        self,
+        board_id: str,
+        image_name: str,
+    ) -> None:
+        """Adds an image to a board."""
+        pass
+
+    @abstractmethod
+    def remove_image_from_board(
+        self,
+        image_name: str,
+    ) -> None:
+        """Removes an image from a board."""
+        pass
+
+    @abstractmethod
+    def get_all_board_image_names_for_board(
+        self,
+        board_id: str,
+    ) -> list[str]:
+        """Gets all board images for a board, as a list of the image names."""
+        pass
+
+    @abstractmethod
+    def get_board_for_image(
+        self,
+        image_name: str,
+    ) -> Optional[str]:
+        """Gets an image's board id, if it has one."""
+        pass
invokeai/app/services/board_images/board_images_default.py (new file, 38 lines)
@@ -0,0 +1,38 @@
+from typing import Optional
+
+from invokeai.app.services.invoker import Invoker
+
+from .board_images_base import BoardImagesServiceABC
+
+
+class BoardImagesService(BoardImagesServiceABC):
+    __invoker: Invoker
+
+    def start(self, invoker: Invoker) -> None:
+        self.__invoker = invoker
+
+    def add_image_to_board(
+        self,
+        board_id: str,
+        image_name: str,
+    ) -> None:
+        self.__invoker.services.board_image_records.add_image_to_board(board_id, image_name)
+
+    def remove_image_from_board(
+        self,
+        image_name: str,
+    ) -> None:
+        self.__invoker.services.board_image_records.remove_image_from_board(image_name)
+
+    def get_all_board_image_names_for_board(
+        self,
+        board_id: str,
+    ) -> list[str]:
+        return self.__invoker.services.board_image_records.get_all_board_image_names_for_board(board_id)
+
+    def get_board_for_image(
+        self,
+        image_name: str,
+    ) -> Optional[str]:
+        board_id = self.__invoker.services.board_image_records.get_board_for_image(image_name)
+        return board_id
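The default implementation above keeps no state beyond the `Invoker` it receives in `start()`; every call is delegated to the low-level `board_image_records` service at call time, mirroring the `InvocationProcessor` change described in the commit message. A brief usage sketch, assuming an already-started `Invoker` whose service registry exposes this service as `board_images` (the helper function below is illustrative only, not part of this commit):

```python
from invokeai.app.services.invoker import Invoker


def file_image_on_board(invoker: Invoker, board_id: str, image_name: str) -> None:
    # Hypothetical helper: the high-level service delegates to board_image_records internally.
    invoker.services.board_images.add_image_to_board(board_id=board_id, image_name=image_name)
    assert invoker.services.board_images.get_board_for_image(image_name) == board_id
```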
invokeai/app/services/board_records/board_records_base.py (new file, 55 lines)
@@ -0,0 +1,55 @@
+from abc import ABC, abstractmethod
+
+from invokeai.app.services.shared.pagination import OffsetPaginatedResults
+
+from .board_records_common import BoardChanges, BoardRecord
+
+
+class BoardRecordStorageBase(ABC):
+    """Low-level service responsible for interfacing with the board record store."""
+
+    @abstractmethod
+    def delete(self, board_id: str) -> None:
+        """Deletes a board record."""
+        pass
+
+    @abstractmethod
+    def save(
+        self,
+        board_name: str,
+    ) -> BoardRecord:
+        """Saves a board record."""
+        pass
+
+    @abstractmethod
+    def get(
+        self,
+        board_id: str,
+    ) -> BoardRecord:
+        """Gets a board record."""
+        pass
+
+    @abstractmethod
+    def update(
+        self,
+        board_id: str,
+        changes: BoardChanges,
+    ) -> BoardRecord:
+        """Updates a board record."""
+        pass
+
+    @abstractmethod
+    def get_many(
+        self,
+        offset: int = 0,
+        limit: int = 10,
+    ) -> OffsetPaginatedResults[BoardRecord]:
+        """Gets many board records."""
+        pass
+
+    @abstractmethod
+    def get_all(
+        self,
+    ) -> list[BoardRecord]:
+        """Gets all board records."""
+        pass
@@ -1,7 +1,7 @@
 from datetime import datetime
 from typing import Optional, Union
 
-from pydantic import Field
+from pydantic import BaseModel, Extra, Field
 
 from invokeai.app.util.misc import get_iso_timestamp
 from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
@@ -24,15 +24,6 @@ class BoardRecord(BaseModelExcludeNull):
     """The name of the cover image of the board."""
 
 
-class BoardDTO(BoardRecord):
-    """Deserialized board record with cover image URL and image count."""
-
-    cover_image_name: Optional[str] = Field(description="The name of the board's cover image.")
-    """The URL of the thumbnail of the most recent image in the board."""
-    image_count: int = Field(description="The number of images in the board.")
-    """The number of images in the board."""
-
-
 def deserialize_board_record(board_dict: dict) -> BoardRecord:
     """Deserializes a board record."""
 
@@ -53,3 +44,29 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
         updated_at=updated_at,
         deleted_at=deleted_at,
     )
+
+
+class BoardChanges(BaseModel, extra=Extra.forbid):
+    board_name: Optional[str] = Field(description="The board's new name.")
+    cover_image_name: Optional[str] = Field(description="The name of the board's new cover image.")
+
+
+class BoardRecordNotFoundException(Exception):
+    """Raised when an board record is not found."""
+
+    def __init__(self, message="Board record not found"):
+        super().__init__(message)
+
+
+class BoardRecordSaveException(Exception):
+    """Raised when an board record cannot be saved."""
+
+    def __init__(self, message="Board record not saved"):
+        super().__init__(message)
+
+
+class BoardRecordDeleteException(Exception):
+    """Raised when an board record cannot be deleted."""
+
+    def __init__(self, message="Board record not deleted"):
+        super().__init__(message)
@@ -1,90 +1,20 @@
 import sqlite3
 import threading
-from abc import ABC, abstractmethod
-from typing import Optional, Union, cast
-
-from pydantic import BaseModel, Extra, Field
+from typing import Union, cast
 
-from invokeai.app.services.models.board_record import BoardRecord, deserialize_board_record
-from invokeai.app.services.shared.sqlite import SqliteDatabase
 from invokeai.app.services.shared.pagination import OffsetPaginatedResults
+from invokeai.app.services.shared.sqlite import SqliteDatabase
 from invokeai.app.util.misc import uuid_string
 
-
-class BoardChanges(BaseModel, extra=Extra.forbid):
-    board_name: Optional[str] = Field(description="The board's new name.")
-    cover_image_name: Optional[str] = Field(description="The name of the board's new cover image.")
-
-
-class BoardRecordNotFoundException(Exception):
-    """Raised when an board record is not found."""
-
-    def __init__(self, message="Board record not found"):
-        super().__init__(message)
-
-
-class BoardRecordSaveException(Exception):
-    """Raised when an board record cannot be saved."""
-
-    def __init__(self, message="Board record not saved"):
-        super().__init__(message)
-
-
-class BoardRecordDeleteException(Exception):
-    """Raised when an board record cannot be deleted."""
-
-    def __init__(self, message="Board record not deleted"):
-        super().__init__(message)
-
-
-class BoardRecordStorageBase(ABC):
-    """Low-level service responsible for interfacing with the board record store."""
-
-    @abstractmethod
-    def delete(self, board_id: str) -> None:
-        """Deletes a board record."""
-        pass
-
-    @abstractmethod
-    def save(
-        self,
-        board_name: str,
-    ) -> BoardRecord:
-        """Saves a board record."""
-        pass
-
-    @abstractmethod
-    def get(
-        self,
-        board_id: str,
-    ) -> BoardRecord:
-        """Gets a board record."""
-        pass
-
-    @abstractmethod
-    def update(
-        self,
-        board_id: str,
-        changes: BoardChanges,
-    ) -> BoardRecord:
-        """Updates a board record."""
-        pass
-
-    @abstractmethod
-    def get_many(
-        self,
-        offset: int = 0,
-        limit: int = 10,
-    ) -> OffsetPaginatedResults[BoardRecord]:
-        """Gets many board records."""
-        pass
-
-    @abstractmethod
-    def get_all(
-        self,
-    ) -> list[BoardRecord]:
-        """Gets all board records."""
-        pass
+from .board_records_base import BoardRecordStorageBase
+from .board_records_common import (
+    BoardChanges,
+    BoardRecord,
+    BoardRecordDeleteException,
+    BoardRecordNotFoundException,
+    BoardRecordSaveException,
+    deserialize_board_record,
+)
 
 
 class SqliteBoardRecordStorage(BoardRecordStorageBase):
invokeai/app/services/boards/__init__.py (new file, 0 lines)
invokeai/app/services/boards/boards_base.py (new file, 59 lines)
@@ -0,0 +1,59 @@
+from abc import ABC, abstractmethod
+
+from invokeai.app.services.board_records.board_records_common import BoardChanges
+from invokeai.app.services.shared.pagination import OffsetPaginatedResults
+
+from .boards_common import BoardDTO
+
+
+class BoardServiceABC(ABC):
+    """High-level service for board management."""
+
+    @abstractmethod
+    def create(
+        self,
+        board_name: str,
+    ) -> BoardDTO:
+        """Creates a board."""
+        pass
+
+    @abstractmethod
+    def get_dto(
+        self,
+        board_id: str,
+    ) -> BoardDTO:
+        """Gets a board."""
+        pass
+
+    @abstractmethod
+    def update(
+        self,
+        board_id: str,
+        changes: BoardChanges,
+    ) -> BoardDTO:
+        """Updates a board."""
+        pass
+
+    @abstractmethod
+    def delete(
+        self,
+        board_id: str,
+    ) -> None:
+        """Deletes a board."""
+        pass
+
+    @abstractmethod
+    def get_many(
+        self,
+        offset: int = 0,
+        limit: int = 10,
+    ) -> OffsetPaginatedResults[BoardDTO]:
+        """Gets many boards."""
+        pass
+
+    @abstractmethod
+    def get_all(
+        self,
+    ) -> list[BoardDTO]:
+        """Gets all boards."""
+        pass
invokeai/app/services/boards/boards_common.py (new file, 23 lines)
@@ -0,0 +1,23 @@
+from typing import Optional
+
+from pydantic import Field
+
+from ..board_records.board_records_common import BoardRecord
+
+
+class BoardDTO(BoardRecord):
+    """Deserialized board record with cover image URL and image count."""
+
+    cover_image_name: Optional[str] = Field(description="The name of the board's cover image.")
+    """The URL of the thumbnail of the most recent image in the board."""
+    image_count: int = Field(description="The number of images in the board.")
+    """The number of images in the board."""
+
+
+def board_record_to_dto(board_record: BoardRecord, cover_image_name: Optional[str], image_count: int) -> BoardDTO:
+    """Converts a board record to a board DTO."""
+    return BoardDTO(
+        **board_record.dict(exclude={"cover_image_name"}),
+        cover_image_name=cover_image_name,
+        image_count=image_count,
+    )
@@ -1,63 +1,10 @@
-from abc import ABC, abstractmethod
-
-from invokeai.app.services.board_images import board_record_to_dto
-from invokeai.app.services.board_record_storage import BoardChanges
+from invokeai.app.services.board_records.board_records_common import BoardChanges
+from invokeai.app.services.boards.boards_common import BoardDTO
 from invokeai.app.services.invoker import Invoker
-from invokeai.app.services.models.board_record import BoardDTO
 from invokeai.app.services.shared.pagination import OffsetPaginatedResults
 
-
-class BoardServiceABC(ABC):
-    """High-level service for board management."""
-
-    @abstractmethod
-    def create(
-        self,
-        board_name: str,
-    ) -> BoardDTO:
-        """Creates a board."""
-        pass
-
-    @abstractmethod
-    def get_dto(
-        self,
-        board_id: str,
-    ) -> BoardDTO:
-        """Gets a board."""
-        pass
-
-    @abstractmethod
-    def update(
-        self,
-        board_id: str,
-        changes: BoardChanges,
-    ) -> BoardDTO:
-        """Updates a board."""
-        pass
-
-    @abstractmethod
-    def delete(
-        self,
-        board_id: str,
-    ) -> None:
-        """Deletes a board."""
-        pass
-
-    @abstractmethod
-    def get_many(
-        self,
-        offset: int = 0,
-        limit: int = 10,
-    ) -> OffsetPaginatedResults[BoardDTO]:
-        """Gets many boards."""
-        pass
-
-    @abstractmethod
-    def get_all(
-        self,
-    ) -> list[BoardDTO]:
-        """Gets all boards."""
-        pass
+from .boards_base import BoardServiceABC
+from .boards_common import board_record_to_dto
 
 
 class BoardService(BoardServiceABC):
@@ -2,5 +2,5 @@
 Init file for InvokeAI configure package
 """
 
-from .base import PagingArgumentParser  # noqa F401
-from .invokeai_config import InvokeAIAppConfig, get_invokeai_config  # noqa F401
+from .config_base import PagingArgumentParser  # noqa F401
+from .config_default import InvokeAIAppConfig, get_invokeai_config  # noqa F401
@@ -12,7 +12,6 @@ from __future__ import annotations
 
 import argparse
 import os
-import pydoc
 import sys
 from argparse import ArgumentParser
 from pathlib import Path
@@ -21,16 +20,7 @@ from typing import ClassVar, Dict, List, Literal, Optional, Union, get_args, get
 from omegaconf import DictConfig, ListConfig, OmegaConf
 from pydantic import BaseSettings
 
-
-class PagingArgumentParser(argparse.ArgumentParser):
-    """
-    A custom ArgumentParser that uses pydoc to page its output.
-    It also supports reading defaults from an init file.
-    """
-
-    def print_help(self, file=None):
-        text = self.format_help()
-        pydoc.pager(text)
+from invokeai.app.services.config.config_common import PagingArgumentParser, int_or_float_or_str
 
 
 class InvokeAISettings(BaseSettings):
@@ -223,18 +213,3 @@ class InvokeAISettings(BaseSettings):
                 action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
                 help=field.field_info.description,
             )
-
-
-def int_or_float_or_str(value: str) -> Union[int, float, str]:
-    """
-    Workaround for argparse type checking.
-    """
-    try:
-        return int(value)
-    except Exception as e:  # noqa F841
-        pass
-    try:
-        return float(value)
-    except Exception as e:  # noqa F841
-        pass
-    return str(value)
invokeai/app/services/config/config_common.py (new file, 41 lines)
@@ -0,0 +1,41 @@
+# Copyright (c) 2023 Lincoln Stein (https://github.com/lstein) and the InvokeAI Development Team
+
+"""
+Base class for the InvokeAI configuration system.
+It defines a type of pydantic BaseSettings object that
+is able to read and write from an omegaconf-based config file,
+with overriding of settings from environment variables and/or
+the command line.
+"""
+
+from __future__ import annotations
+
+import argparse
+import pydoc
+from typing import Union
+
+
+class PagingArgumentParser(argparse.ArgumentParser):
+    """
+    A custom ArgumentParser that uses pydoc to page its output.
+    It also supports reading defaults from an init file.
+    """
+
+    def print_help(self, file=None):
+        text = self.format_help()
+        pydoc.pager(text)
+
+
+def int_or_float_or_str(value: str) -> Union[int, float, str]:
+    """
+    Workaround for argparse type checking.
+    """
+    try:
+        return int(value)
+    except Exception as e:  # noqa F841
+        pass
+    try:
+        return float(value)
+    except Exception as e:  # noqa F841
+        pass
+    return str(value)
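`int_or_float_or_str` is intended to be handed to argparse as a `type=` callable so a single flag can accept an int, a float, or a bare string, while `PagingArgumentParser` pages its help output through `pydoc`. A small usage sketch (the `--cfg_scale` flag name is illustrative only, not part of this commit):

```python
from invokeai.app.services.config.config_common import PagingArgumentParser, int_or_float_or_str

parser = PagingArgumentParser(description="demo parser")
parser.add_argument("--cfg_scale", type=int_or_float_or_str, default=7.5)

print(type(parser.parse_args(["--cfg_scale", "7"]).cfg_scale))     # <class 'int'>
print(type(parser.parse_args(["--cfg_scale", "7.5"]).cfg_scale))   # <class 'float'>
print(type(parser.parse_args(["--cfg_scale", "auto"]).cfg_scale))  # <class 'str'>
```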
@@ -177,7 +177,7 @@ from typing import ClassVar, Dict, List, Literal, Optional, Union, get_type_hints
 from omegaconf import DictConfig, OmegaConf
 from pydantic import Field, parse_obj_as
 
-from .base import InvokeAISettings
+from .config_base import InvokeAISettings
 
 INIT_FILE = Path("invokeai.yaml")
 DB_FILE = Path("invokeai.db")
invokeai/app/services/events/__init__.py (new file, 0 lines)
@@ -2,8 +2,8 @@
 
 from typing import Any, Optional
 
-from invokeai.app.models.image import ProgressImage
-from invokeai.app.services.model_manager_service import BaseModelType, ModelInfo, ModelType, SubModelType
+from invokeai.app.invocations.model import ModelInfo
+from invokeai.app.services.invocation_processor.invocation_processor_common import ProgressImage
 from invokeai.app.services.session_queue.session_queue_common import (
     BatchStatus,
     EnqueueBatchResult,
@@ -11,6 +11,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
     SessionQueueStatus,
 )
 from invokeai.app.util.misc import get_timestamp
+from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
 
 
 class EventServiceBase:
invokeai/app/services/image_files/__init__.py (new file, 0 lines)
invokeai/app/services/image_files/image_files_base.py (new file, 42 lines)
@@ -0,0 +1,42 @@
+from abc import ABC, abstractmethod
+from typing import Optional
+
+from PIL.Image import Image as PILImageType
+
+
+class ImageFileStorageBase(ABC):
+    """Low-level service responsible for storing and retrieving image files."""
+
+    @abstractmethod
+    def get(self, image_name: str) -> PILImageType:
+        """Retrieves an image as PIL Image."""
+        pass
+
+    @abstractmethod
+    def get_path(self, image_name: str, thumbnail: bool = False) -> str:
+        """Gets the internal path to an image or thumbnail."""
+        pass
+
+    # TODO: We need to validate paths before starlette makes the FileResponse, else we get a
+    # 500 internal server error. I don't like having this method on the service.
+    @abstractmethod
+    def validate_path(self, path: str) -> bool:
+        """Validates the path given for an image or thumbnail."""
+        pass
+
+    @abstractmethod
+    def save(
+        self,
+        image: PILImageType,
+        image_name: str,
+        metadata: Optional[dict] = None,
+        workflow: Optional[str] = None,
+        thumbnail_size: int = 256,
+    ) -> None:
+        """Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
+        pass
+
+    @abstractmethod
+    def delete(self, image_name: str) -> None:
+        """Deletes an image and its thumbnail (if one exists)."""
+        pass
invokeai/app/services/image_files/image_files_common.py (new file, 20 lines)
@@ -0,0 +1,20 @@
+# TODO: Should these excpetions subclass existing python exceptions?
+class ImageFileNotFoundException(Exception):
+    """Raised when an image file is not found in storage."""
+
+    def __init__(self, message="Image file not found"):
+        super().__init__(message)
+
+
+class ImageFileSaveException(Exception):
+    """Raised when an image cannot be saved."""
+
+    def __init__(self, message="Image file not saved"):
+        super().__init__(message)
+
+
+class ImageFileDeleteException(Exception):
+    """Raised when an image cannot be deleted."""
+
+    def __init__(self, message="Image file not deleted"):
+        super().__init__(message)
@@ -1,6 +1,5 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
 import json
-from abc import ABC, abstractmethod
 from pathlib import Path
 from queue import Queue
 from typing import Dict, Optional, Union
@@ -12,65 +11,8 @@ from send2trash import send2trash
 from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
 from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
 
-# TODO: Should these excpetions subclass existing python exceptions?
-class ImageFileNotFoundException(Exception):
-    """Raised when an image file is not found in storage."""
-
-    def __init__(self, message="Image file not found"):
-        super().__init__(message)
-
-
-class ImageFileSaveException(Exception):
-    """Raised when an image cannot be saved."""
-
-    def __init__(self, message="Image file not saved"):
-        super().__init__(message)
-
-
-class ImageFileDeleteException(Exception):
-    """Raised when an image cannot be deleted."""
-
-    def __init__(self, message="Image file not deleted"):
-        super().__init__(message)
-
-
-class ImageFileStorageBase(ABC):
-    """Low-level service responsible for storing and retrieving image files."""
-
-    @abstractmethod
-    def get(self, image_name: str) -> PILImageType:
-        """Retrieves an image as PIL Image."""
-        pass
-
-    @abstractmethod
-    def get_path(self, image_name: str, thumbnail: bool = False) -> str:
-        """Gets the internal path to an image or thumbnail."""
-        pass
-
-    # TODO: We need to validate paths before starlette makes the FileResponse, else we get a
-    # 500 internal server error. I don't like having this method on the service.
-    @abstractmethod
-    def validate_path(self, path: str) -> bool:
-        """Validates the path given for an image or thumbnail."""
-        pass
-
-    @abstractmethod
-    def save(
-        self,
-        image: PILImageType,
-        image_name: str,
-        metadata: Optional[dict] = None,
-        workflow: Optional[str] = None,
-        thumbnail_size: int = 256,
-    ) -> None:
-        """Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
-        pass
-
-    @abstractmethod
-    def delete(self, image_name: str) -> None:
-        """Deletes an image and its thumbnail (if one exists)."""
-        pass
+from .image_files_base import ImageFileStorageBase
+from .image_files_common import ImageFileDeleteException, ImageFileNotFoundException, ImageFileSaveException
 
 
 class DiskImageFileStorage(ImageFileStorageBase):
0 invokeai/app/services/image_records/__init__.py Normal file
84 invokeai/app/services/image_records/image_records_base.py Normal file
@@ -0,0 +1,84 @@
+from abc import ABC, abstractmethod
+from datetime import datetime
+from typing import Optional
+
+from invokeai.app.services.shared.pagination import OffsetPaginatedResults
+
+from .image_records_common import ImageCategory, ImageRecord, ImageRecordChanges, ResourceOrigin
+
+
+class ImageRecordStorageBase(ABC):
+    """Low-level service responsible for interfacing with the image record store."""
+
+    # TODO: Implement an `update()` method
+
+    @abstractmethod
+    def get(self, image_name: str) -> ImageRecord:
+        """Gets an image record."""
+        pass
+
+    @abstractmethod
+    def get_metadata(self, image_name: str) -> Optional[dict]:
+        """Gets an image's metadata'."""
+        pass
+
+    @abstractmethod
+    def update(
+        self,
+        image_name: str,
+        changes: ImageRecordChanges,
+    ) -> None:
+        """Updates an image record."""
+        pass
+
+    @abstractmethod
+    def get_many(
+        self,
+        offset: Optional[int] = None,
+        limit: Optional[int] = None,
+        image_origin: Optional[ResourceOrigin] = None,
+        categories: Optional[list[ImageCategory]] = None,
+        is_intermediate: Optional[bool] = None,
+        board_id: Optional[str] = None,
+    ) -> OffsetPaginatedResults[ImageRecord]:
+        """Gets a page of image records."""
+        pass
+
+    # TODO: The database has a nullable `deleted_at` column, currently unused.
+    # Should we implement soft deletes? Would need coordination with ImageFileStorage.
+    @abstractmethod
+    def delete(self, image_name: str) -> None:
+        """Deletes an image record."""
+        pass
+
+    @abstractmethod
+    def delete_many(self, image_names: list[str]) -> None:
+        """Deletes many image records."""
+        pass
+
+    @abstractmethod
+    def delete_intermediates(self) -> list[str]:
+        """Deletes all intermediate image records, returning a list of deleted image names."""
+        pass
+
+    @abstractmethod
+    def save(
+        self,
+        image_name: str,
+        image_origin: ResourceOrigin,
+        image_category: ImageCategory,
+        width: int,
+        height: int,
+        session_id: Optional[str],
+        node_id: Optional[str],
+        metadata: Optional[dict],
+        is_intermediate: bool = False,
+        starred: bool = False,
+    ) -> datetime:
+        """Saves an image record."""
+        pass
+
+    @abstractmethod
+    def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
+        """Gets the most recent image for a board."""
+        pass
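(Illustrative only, not part of the commit: a minimal sketch of paging through records via the base class above, assuming `image_records` is some concrete ImageRecordStorageBase implementation.)

    page = image_records.get_many(offset=0, limit=20, categories=[ImageCategory.GENERAL])
    for record in page.items:
        print(record.image_name, record.width, record.height)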
@@ -1,13 +1,117 @@
+# TODO: Should these excpetions subclass existing python exceptions?
 import datetime
+from enum import Enum
 from typing import Optional, Union
 
 from pydantic import Extra, Field, StrictBool, StrictStr
 
-from invokeai.app.models.image import ImageCategory, ResourceOrigin
+from invokeai.app.util.metaenum import MetaEnum
 from invokeai.app.util.misc import get_iso_timestamp
 from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
 
 
+class ResourceOrigin(str, Enum, metaclass=MetaEnum):
+    """The origin of a resource (eg image).
+
+    - INTERNAL: The resource was created by the application.
+    - EXTERNAL: The resource was not created by the application.
+    This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
+    """
+
+    INTERNAL = "internal"
+    """The resource was created by the application."""
+    EXTERNAL = "external"
+    """The resource was not created by the application.
+    This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
+    """
+
+
+class InvalidOriginException(ValueError):
+    """Raised when a provided value is not a valid ResourceOrigin.
+
+    Subclasses `ValueError`.
+    """
+
+    def __init__(self, message="Invalid resource origin."):
+        super().__init__(message)
+
+
+class ImageCategory(str, Enum, metaclass=MetaEnum):
+    """The category of an image.
+
+    - GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose.
+    - MASK: The image is a mask image.
+    - CONTROL: The image is a ControlNet control image.
+    - USER: The image is a user-provide image.
+    - OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes.
+    """
+
+    GENERAL = "general"
+    """GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose."""
+    MASK = "mask"
+    """MASK: The image is a mask image."""
+    CONTROL = "control"
+    """CONTROL: The image is a ControlNet control image."""
+    USER = "user"
+    """USER: The image is a user-provide image."""
+    OTHER = "other"
+    """OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes."""
+
+
+class InvalidImageCategoryException(ValueError):
+    """Raised when a provided value is not a valid ImageCategory.
+
+    Subclasses `ValueError`.
+    """
+
+    def __init__(self, message="Invalid image category."):
+        super().__init__(message)
+
+
+class ImageRecordNotFoundException(Exception):
+    """Raised when an image record is not found."""
+
+    def __init__(self, message="Image record not found"):
+        super().__init__(message)
+
+
+class ImageRecordSaveException(Exception):
+    """Raised when an image record cannot be saved."""
+
+    def __init__(self, message="Image record not saved"):
+        super().__init__(message)
+
+
+class ImageRecordDeleteException(Exception):
+    """Raised when an image record cannot be deleted."""
+
+    def __init__(self, message="Image record not deleted"):
+        super().__init__(message)
+
+
+IMAGE_DTO_COLS = ", ".join(
+    list(
+        map(
+            lambda c: "images." + c,
+            [
+                "image_name",
+                "image_origin",
+                "image_category",
+                "width",
+                "height",
+                "session_id",
+                "node_id",
+                "is_intermediate",
+                "created_at",
+                "updated_at",
+                "deleted_at",
+                "starred",
+            ],
+        )
+    )
+)
+
+
 class ImageRecord(BaseModelExcludeNull):
     """Deserialized image record without metadata."""
 
@@ -66,41 +170,6 @@ class ImageRecordChanges(BaseModelExcludeNull, extra=Extra.forbid):
     """The image's new `starred` state."""
 
 
-class ImageUrlsDTO(BaseModelExcludeNull):
-    """The URLs for an image and its thumbnail."""
-
-    image_name: str = Field(description="The unique name of the image.")
-    """The unique name of the image."""
-    image_url: str = Field(description="The URL of the image.")
-    """The URL of the image."""
-    thumbnail_url: str = Field(description="The URL of the image's thumbnail.")
-    """The URL of the image's thumbnail."""
-
-
-class ImageDTO(ImageRecord, ImageUrlsDTO):
-    """Deserialized image record, enriched for the frontend."""
-
-    board_id: Optional[str] = Field(description="The id of the board the image belongs to, if one exists.")
-    """The id of the board the image belongs to, if one exists."""
-
-    pass
-
-
-def image_record_to_dto(
-    image_record: ImageRecord,
-    image_url: str,
-    thumbnail_url: str,
-    board_id: Optional[str],
-) -> ImageDTO:
-    """Converts an image record to an image DTO."""
-    return ImageDTO(
-        **image_record.dict(),
-        image_url=image_url,
-        thumbnail_url=thumbnail_url,
-        board_id=board_id,
-    )
-
-
 def deserialize_image_record(image_dict: dict) -> ImageRecord:
     """Deserializes an image record."""
 
@@ -1,138 +1,26 @@
 import json
 import sqlite3
 import threading
-from abc import ABC, abstractmethod
 from datetime import datetime
 from typing import Optional, cast
 
-from invokeai.app.models.image import ImageCategory, ResourceOrigin
-from invokeai.app.services.models.image_record import ImageRecord, ImageRecordChanges, deserialize_image_record
-from invokeai.app.services.shared.sqlite import SqliteDatabase
 from invokeai.app.services.shared.pagination import OffsetPaginatedResults
+from invokeai.app.services.shared.sqlite import SqliteDatabase
 
-# TODO: Should these excpetions subclass existing python exceptions?
-class ImageRecordNotFoundException(Exception):
-    """Raised when an image record is not found."""
-
-    def __init__(self, message="Image record not found"):
-        super().__init__(message)
-
-
-class ImageRecordSaveException(Exception):
-    """Raised when an image record cannot be saved."""
-
-    def __init__(self, message="Image record not saved"):
-        super().__init__(message)
-
-
-class ImageRecordDeleteException(Exception):
-    """Raised when an image record cannot be deleted."""
-
-    def __init__(self, message="Image record not deleted"):
-        super().__init__(message)
-
-
-IMAGE_DTO_COLS = ", ".join(
-    list(
-        map(
-            lambda c: "images." + c,
-            [
-                "image_name",
-                "image_origin",
-                "image_category",
-                "width",
-                "height",
-                "session_id",
-                "node_id",
-                "is_intermediate",
-                "created_at",
-                "updated_at",
-                "deleted_at",
-                "starred",
-            ],
-        )
-    )
-)
-
-
-class ImageRecordStorageBase(ABC):
-    """Low-level service responsible for interfacing with the image record store."""
-
-    # TODO: Implement an `update()` method
-
-    @abstractmethod
-    def get(self, image_name: str) -> ImageRecord:
-        """Gets an image record."""
-        pass
-
-    @abstractmethod
-    def get_metadata(self, image_name: str) -> Optional[dict]:
-        """Gets an image's metadata'."""
-        pass
-
-    @abstractmethod
-    def update(
-        self,
-        image_name: str,
-        changes: ImageRecordChanges,
-    ) -> None:
-        """Updates an image record."""
-        pass
-
-    @abstractmethod
-    def get_many(
-        self,
-        offset: Optional[int] = None,
-        limit: Optional[int] = None,
-        image_origin: Optional[ResourceOrigin] = None,
-        categories: Optional[list[ImageCategory]] = None,
-        is_intermediate: Optional[bool] = None,
-        board_id: Optional[str] = None,
-    ) -> OffsetPaginatedResults[ImageRecord]:
-        """Gets a page of image records."""
-        pass
-
-    # TODO: The database has a nullable `deleted_at` column, currently unused.
-    # Should we implement soft deletes? Would need coordination with ImageFileStorage.
-    @abstractmethod
-    def delete(self, image_name: str) -> None:
-        """Deletes an image record."""
-        pass
-
-    @abstractmethod
-    def delete_many(self, image_names: list[str]) -> None:
-        """Deletes many image records."""
-        pass
-
-    @abstractmethod
-    def delete_intermediates(self) -> list[str]:
-        """Deletes all intermediate image records, returning a list of deleted image names."""
-        pass
-
-    @abstractmethod
-    def save(
-        self,
-        image_name: str,
-        image_origin: ResourceOrigin,
-        image_category: ImageCategory,
-        width: int,
-        height: int,
-        session_id: Optional[str],
-        node_id: Optional[str],
-        metadata: Optional[dict],
-        is_intermediate: bool = False,
-        starred: bool = False,
-    ) -> datetime:
-        """Saves an image record."""
-        pass
-
-    @abstractmethod
-    def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
-        """Gets the most recent image for a board."""
-        pass
-
+from .image_records_base import ImageRecordStorageBase
+from .image_records_common import (
+    IMAGE_DTO_COLS,
+    ImageCategory,
+    ImageRecord,
+    ImageRecordChanges,
+    ImageRecordDeleteException,
+    ImageRecordNotFoundException,
+    ImageRecordSaveException,
+    ResourceOrigin,
+    deserialize_image_record,
+)
 
 
 class SqliteImageRecordStorage(ImageRecordStorageBase):
     _conn: sqlite3.Connection
     _cursor: sqlite3.Cursor
0 invokeai/app/services/images/__init__.py Normal file
129 invokeai/app/services/images/images_base.py Normal file
@@ -0,0 +1,129 @@
+from abc import ABC, abstractmethod
+from typing import Callable, Optional
+
+from PIL.Image import Image as PILImageType
+
+from invokeai.app.invocations.metadata import ImageMetadata
+from invokeai.app.services.image_records.image_records_common import (
+    ImageCategory,
+    ImageRecord,
+    ImageRecordChanges,
+    ResourceOrigin,
+)
+from invokeai.app.services.images.images_common import ImageDTO
+from invokeai.app.services.shared.pagination import OffsetPaginatedResults
+
+
+class ImageServiceABC(ABC):
+    """High-level service for image management."""
+
+    _on_changed_callbacks: list[Callable[[ImageDTO], None]]
+    _on_deleted_callbacks: list[Callable[[str], None]]
+
+    def __init__(self) -> None:
+        self._on_changed_callbacks = list()
+        self._on_deleted_callbacks = list()
+
+    def on_changed(self, on_changed: Callable[[ImageDTO], None]) -> None:
+        """Register a callback for when an image is changed"""
+        self._on_changed_callbacks.append(on_changed)
+
+    def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
+        """Register a callback for when an image is deleted"""
+        self._on_deleted_callbacks.append(on_deleted)
+
+    def _on_changed(self, item: ImageDTO) -> None:
+        for callback in self._on_changed_callbacks:
+            callback(item)
+
+    def _on_deleted(self, item_id: str) -> None:
+        for callback in self._on_deleted_callbacks:
+            callback(item_id)
+
+    @abstractmethod
+    def create(
+        self,
+        image: PILImageType,
+        image_origin: ResourceOrigin,
+        image_category: ImageCategory,
+        node_id: Optional[str] = None,
+        session_id: Optional[str] = None,
+        board_id: Optional[str] = None,
+        is_intermediate: bool = False,
+        metadata: Optional[dict] = None,
+        workflow: Optional[str] = None,
+    ) -> ImageDTO:
+        """Creates an image, storing the file and its metadata."""
+        pass
+
+    @abstractmethod
+    def update(
+        self,
+        image_name: str,
+        changes: ImageRecordChanges,
+    ) -> ImageDTO:
+        """Updates an image."""
+        pass
+
+    @abstractmethod
+    def get_pil_image(self, image_name: str) -> PILImageType:
+        """Gets an image as a PIL image."""
+        pass
+
+    @abstractmethod
+    def get_record(self, image_name: str) -> ImageRecord:
+        """Gets an image record."""
+        pass
+
+    @abstractmethod
+    def get_dto(self, image_name: str) -> ImageDTO:
+        """Gets an image DTO."""
+        pass
+
+    @abstractmethod
+    def get_metadata(self, image_name: str) -> ImageMetadata:
+        """Gets an image's metadata."""
+        pass
+
+    @abstractmethod
+    def get_path(self, image_name: str, thumbnail: bool = False) -> str:
+        """Gets an image's path."""
+        pass
+
+    @abstractmethod
+    def validate_path(self, path: str) -> bool:
+        """Validates an image's path."""
+        pass
+
+    @abstractmethod
+    def get_url(self, image_name: str, thumbnail: bool = False) -> str:
+        """Gets an image's or thumbnail's URL."""
+        pass
+
+    @abstractmethod
+    def get_many(
+        self,
+        offset: int = 0,
+        limit: int = 10,
+        image_origin: Optional[ResourceOrigin] = None,
+        categories: Optional[list[ImageCategory]] = None,
+        is_intermediate: Optional[bool] = None,
+        board_id: Optional[str] = None,
+    ) -> OffsetPaginatedResults[ImageDTO]:
+        """Gets a paginated list of image DTOs."""
+        pass
+
+    @abstractmethod
+    def delete(self, image_name: str):
+        """Deletes an image."""
+        pass
+
+    @abstractmethod
+    def delete_intermediates(self) -> int:
+        """Deletes all intermediate images."""
+        pass
+
+    @abstractmethod
+    def delete_images_on_board(self, board_id: str):
+        """Deletes all images on a board."""
+        pass
41 invokeai/app/services/images/images_common.py Normal file
@@ -0,0 +1,41 @@
+from typing import Optional
+
+from pydantic import Field
+
+from invokeai.app.services.image_records.image_records_common import ImageRecord
+from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
+
+
+class ImageUrlsDTO(BaseModelExcludeNull):
+    """The URLs for an image and its thumbnail."""
+
+    image_name: str = Field(description="The unique name of the image.")
+    """The unique name of the image."""
+    image_url: str = Field(description="The URL of the image.")
+    """The URL of the image."""
+    thumbnail_url: str = Field(description="The URL of the image's thumbnail.")
+    """The URL of the image's thumbnail."""
+
+
+class ImageDTO(ImageRecord, ImageUrlsDTO):
+    """Deserialized image record, enriched for the frontend."""
+
+    board_id: Optional[str] = Field(description="The id of the board the image belongs to, if one exists.")
+    """The id of the board the image belongs to, if one exists."""
+
+    pass
+
+
+def image_record_to_dto(
+    image_record: ImageRecord,
+    image_url: str,
+    thumbnail_url: str,
+    board_id: Optional[str],
+) -> ImageDTO:
+    """Converts an image record to an image DTO."""
+    return ImageDTO(
+        **image_record.dict(),
+        image_url=image_url,
+        thumbnail_url=thumbnail_url,
+        board_id=board_id,
+    )
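(Illustrative only: using the helper above to build a DTO; the record and URL values here are placeholders, not part of the commit.)

    dto = image_record_to_dto(
        image_record=record,  # an ImageRecord loaded from the record storage service
        image_url="https://example.com/images/example.png",
        thumbnail_url="https://example.com/thumbnails/example.webp",
        board_id=None,
    )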
@@ -1,144 +1,30 @@
-from abc import ABC, abstractmethod
-from typing import Callable, Optional
+from typing import Optional
 
 from PIL.Image import Image as PILImageType
 
 from invokeai.app.invocations.metadata import ImageMetadata
-from invokeai.app.models.image import (
-    ImageCategory,
-    InvalidImageCategoryException,
-    InvalidOriginException,
-    ResourceOrigin,
-)
-from invokeai.app.services.image_file_storage import (
+from invokeai.app.services.invoker import Invoker
+from invokeai.app.services.shared.pagination import OffsetPaginatedResults
+from invokeai.app.util.metadata import get_metadata_graph_from_raw_session
+
+from ..image_files.image_files_common import (
     ImageFileDeleteException,
     ImageFileNotFoundException,
     ImageFileSaveException,
 )
-from invokeai.app.services.image_record_storage import (
+from ..image_records.image_records_common import (
+    ImageCategory,
+    ImageRecord,
+    ImageRecordChanges,
     ImageRecordDeleteException,
     ImageRecordNotFoundException,
     ImageRecordSaveException,
+    InvalidImageCategoryException,
+    InvalidOriginException,
+    ResourceOrigin,
 )
-from invokeai.app.services.invoker import Invoker
-from invokeai.app.services.models.image_record import ImageDTO, ImageRecord, ImageRecordChanges, image_record_to_dto
-from invokeai.app.services.shared.pagination import OffsetPaginatedResults
-from invokeai.app.util.metadata import get_metadata_graph_from_raw_session
-
-
-class ImageServiceABC(ABC):
-    """High-level service for image management."""
-
-    _on_changed_callbacks: list[Callable[[ImageDTO], None]]
-    _on_deleted_callbacks: list[Callable[[str], None]]
-
-    def __init__(self) -> None:
-        self._on_changed_callbacks = list()
-        self._on_deleted_callbacks = list()
-
-    def on_changed(self, on_changed: Callable[[ImageDTO], None]) -> None:
-        """Register a callback for when an image is changed"""
-        self._on_changed_callbacks.append(on_changed)
-
-    def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
-        """Register a callback for when an image is deleted"""
-        self._on_deleted_callbacks.append(on_deleted)
-
-    def _on_changed(self, item: ImageDTO) -> None:
-        for callback in self._on_changed_callbacks:
-            callback(item)
-
-    def _on_deleted(self, item_id: str) -> None:
-        for callback in self._on_deleted_callbacks:
-            callback(item_id)
-
-    @abstractmethod
-    def create(
-        self,
-        image: PILImageType,
-        image_origin: ResourceOrigin,
-        image_category: ImageCategory,
-        node_id: Optional[str] = None,
-        session_id: Optional[str] = None,
-        board_id: Optional[str] = None,
-        is_intermediate: bool = False,
-        metadata: Optional[dict] = None,
-        workflow: Optional[str] = None,
-    ) -> ImageDTO:
-        """Creates an image, storing the file and its metadata."""
-        pass
-
-    @abstractmethod
-    def update(
-        self,
-        image_name: str,
-        changes: ImageRecordChanges,
-    ) -> ImageDTO:
-        """Updates an image."""
-        pass
-
-    @abstractmethod
-    def get_pil_image(self, image_name: str) -> PILImageType:
-        """Gets an image as a PIL image."""
-        pass
-
-    @abstractmethod
-    def get_record(self, image_name: str) -> ImageRecord:
-        """Gets an image record."""
-        pass
-
-    @abstractmethod
-    def get_dto(self, image_name: str) -> ImageDTO:
-        """Gets an image DTO."""
-        pass
-
-    @abstractmethod
-    def get_metadata(self, image_name: str) -> ImageMetadata:
-        """Gets an image's metadata."""
-        pass
-
-    @abstractmethod
-    def get_path(self, image_name: str, thumbnail: bool = False) -> str:
-        """Gets an image's path."""
-        pass
-
-    @abstractmethod
-    def validate_path(self, path: str) -> bool:
-        """Validates an image's path."""
-        pass
-
-    @abstractmethod
-    def get_url(self, image_name: str, thumbnail: bool = False) -> str:
-        """Gets an image's or thumbnail's URL."""
-        pass
-
-    @abstractmethod
-    def get_many(
-        self,
-        offset: int = 0,
-        limit: int = 10,
-        image_origin: Optional[ResourceOrigin] = None,
-        categories: Optional[list[ImageCategory]] = None,
-        is_intermediate: Optional[bool] = None,
-        board_id: Optional[str] = None,
-    ) -> OffsetPaginatedResults[ImageDTO]:
-        """Gets a paginated list of image DTOs."""
-        pass
-
-    @abstractmethod
-    def delete(self, image_name: str):
-        """Deletes an image."""
-        pass
-
-    @abstractmethod
-    def delete_intermediates(self) -> int:
-        """Deletes all intermediate images."""
-        pass
-
-    @abstractmethod
-    def delete_images_on_board(self, board_id: str):
-        """Deletes all images on a board."""
-        pass
-
+from .images_base import ImageServiceABC
+from .images_common import ImageDTO, image_record_to_dto
 
 
 class ImageService(ImageServiceABC):
@@ -0,0 +1,5 @@
+from abc import ABC
+
+
+class InvocationProcessorABC(ABC):
+    pass
@@ -0,0 +1,15 @@
+from pydantic import BaseModel, Field
+
+
+class ProgressImage(BaseModel):
+    """The progress image sent intermittently during processing"""
+
+    width: int = Field(description="The effective width of the image in pixels")
+    height: int = Field(description="The effective height of the image in pixels")
+    dataURL: str = Field(description="The image data as a b64 data URL")
+
+
+class CanceledException(Exception):
+    """Execution canceled by user."""
+
+    pass
@@ -4,11 +4,12 @@ from threading import BoundedSemaphore, Event, Thread
 from typing import Optional
 
 import invokeai.backend.util.logging as logger
+from invokeai.app.invocations.baseinvocation import InvocationContext
+from invokeai.app.services.invocation_queue.invocation_queue_common import InvocationQueueItem
 
-from ..invocations.baseinvocation import InvocationContext
-from ..models.exceptions import CanceledException
-from .invocation_queue import InvocationQueueItem
-from .invoker import InvocationProcessorABC, Invoker
+from ..invoker import Invoker
+from .invocation_processor_base import InvocationProcessorABC
+from .invocation_processor_common import CanceledException
 
 
 class DefaultInvocationProcessor(InvocationProcessorABC):
0 invokeai/app/services/invocation_queue/__init__.py Normal file
@@ -0,0 +1,26 @@
+# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
+
+from abc import ABC, abstractmethod
+from typing import Optional
+
+from .invocation_queue_common import InvocationQueueItem
+
+
+class InvocationQueueABC(ABC):
+    """Abstract base class for all invocation queues"""
+
+    @abstractmethod
+    def get(self) -> InvocationQueueItem:
+        pass
+
+    @abstractmethod
+    def put(self, item: Optional[InvocationQueueItem]) -> None:
+        pass
+
+    @abstractmethod
+    def cancel(self, graph_execution_state_id: str) -> None:
+        pass
+
+    @abstractmethod
+    def is_canceled(self, graph_execution_state_id: str) -> bool:
+        pass
@@ -0,0 +1,19 @@
+# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
+
+import time
+
+from pydantic import BaseModel, Field
+
+
+class InvocationQueueItem(BaseModel):
+    graph_execution_state_id: str = Field(description="The ID of the graph execution state")
+    invocation_id: str = Field(description="The ID of the node being invoked")
+    session_queue_id: str = Field(description="The ID of the session queue from which this invocation queue item came")
+    session_queue_item_id: int = Field(
+        description="The ID of session queue item from which this invocation queue item came"
+    )
+    session_queue_batch_id: str = Field(
+        description="The ID of the session batch from which this invocation queue item came"
+    )
+    invoke_all: bool = Field(default=False)
+    timestamp: float = Field(default_factory=time.time)
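(Illustrative only: constructing an InvocationQueueItem and handing it to whichever InvocationQueueABC implementation is in use; all field values below are made up.)

    item = InvocationQueueItem(
        graph_execution_state_id="some-session-id",
        invocation_id="some-node-id",
        session_queue_id="default",
        session_queue_item_id=1,
        session_queue_batch_id="some-batch-id",
        invoke_all=False,
    )
    queue.put(item)  # `queue` is e.g. a MemoryInvocationQueue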
@@ -1,45 +1,11 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
 
 import time
-from abc import ABC, abstractmethod
 from queue import Queue
 from typing import Optional
 
-from pydantic import BaseModel, Field
-
-
-class InvocationQueueItem(BaseModel):
-    graph_execution_state_id: str = Field(description="The ID of the graph execution state")
-    invocation_id: str = Field(description="The ID of the node being invoked")
-    session_queue_id: str = Field(description="The ID of the session queue from which this invocation queue item came")
-    session_queue_item_id: int = Field(
-        description="The ID of session queue item from which this invocation queue item came"
-    )
-    session_queue_batch_id: str = Field(
-        description="The ID of the session batch from which this invocation queue item came"
-    )
-    invoke_all: bool = Field(default=False)
-    timestamp: float = Field(default_factory=time.time)
-
-
-class InvocationQueueABC(ABC):
-    """Abstract base class for all invocation queues"""
-
-    @abstractmethod
-    def get(self) -> InvocationQueueItem:
-        pass
-
-    @abstractmethod
-    def put(self, item: Optional[InvocationQueueItem]) -> None:
-        pass
-
-    @abstractmethod
-    def cancel(self, graph_execution_state_id: str) -> None:
-        pass
-
-    @abstractmethod
-    def is_canceled(self, graph_execution_state_id: str) -> bool:
-        pass
+from .invocation_queue_base import InvocationQueueABC
+from .invocation_queue_common import InvocationQueueItem
 
 
 class MemoryInvocationQueue(InvocationQueueABC):
@@ -6,27 +6,27 @@ from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from logging import Logger
 
-    from invokeai.app.services.board_image_record_storage import BoardImageRecordStorageBase
-    from invokeai.app.services.board_images import BoardImagesServiceABC
-    from invokeai.app.services.board_record_storage import BoardRecordStorageBase
-    from invokeai.app.services.boards import BoardServiceABC
-    from invokeai.app.services.config import InvokeAIAppConfig
-    from invokeai.app.services.events import EventServiceBase
-    from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
-    from invokeai.app.services.image_file_storage import ImageFileStorageBase
-    from invokeai.app.services.image_record_storage import ImageRecordStorageBase
-    from invokeai.app.services.images import ImageServiceABC
-    from invokeai.app.services.invocation_cache.invocation_cache_base import InvocationCacheBase
-    from invokeai.app.services.invocation_queue import InvocationQueueABC
-    from invokeai.app.services.invocation_stats import InvocationStatsServiceBase
-    from invokeai.app.services.invoker import InvocationProcessorABC
-    from invokeai.app.services.item_storage import ItemStorageABC
-    from invokeai.app.services.latent_storage import LatentsStorageBase
-    from invokeai.app.services.model_manager_service import ModelManagerServiceBase
-    from invokeai.app.services.resource_name import NameServiceBase
-    from invokeai.app.services.session_processor.session_processor_base import SessionProcessorBase
-    from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
-    from invokeai.app.services.urls import UrlServiceBase
+    from .board_image_records.board_image_records_base import BoardImageRecordStorageBase
+    from .board_images.board_images_base import BoardImagesServiceABC
+    from .board_records.board_records_base import BoardRecordStorageBase
+    from .boards.boards_base import BoardServiceABC
+    from .config import InvokeAIAppConfig
+    from .events.events_base import EventServiceBase
+    from .image_files.image_files_base import ImageFileStorageBase
+    from .image_records.image_records_base import ImageRecordStorageBase
+    from .images.images_base import ImageServiceABC
+    from .invocation_cache.invocation_cache_base import InvocationCacheBase
+    from .invocation_processor.invocation_processor_base import InvocationProcessorABC
+    from .invocation_queue.invocation_queue_base import InvocationQueueABC
+    from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
+    from .item_storage.item_storage_base import ItemStorageABC
+    from .latents_storage.latents_storage_base import LatentsStorageBase
+    from .model_manager.model_manager_base import ModelManagerServiceBase
+    from .names.names_base import NameServiceBase
+    from .session_processor.session_processor_base import SessionProcessorBase
+    from .session_queue.session_queue_base import SessionQueueBase
+    from .shared.graph import GraphExecutionState, LibraryGraph
+    from .urls.urls_base import UrlServiceBase
 
 
 class InvocationServices:
0 invokeai/app/services/invocation_stats/__init__.py Normal file
121 invokeai/app/services/invocation_stats/invocation_stats_base.py Normal file
@@ -0,0 +1,121 @@
+# Copyright 2023 Lincoln D. Stein <lincoln.stein@gmail.com>
+"""Utility to collect execution time and GPU usage stats on invocations in flight
+
+Usage:
+
+statistics = InvocationStatsService(graph_execution_manager)
+with statistics.collect_stats(invocation, graph_execution_state.id):
+      ... execute graphs...
+statistics.log_stats()
+
+Typical output:
+[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Graph stats: c7764585-9c68-4d9d-a199-55e8186790f3
+[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Node                 Calls  Seconds  VRAM Used
+[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> main_model_loader    1      0.005s   0.01G
+[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> clip_skip            1      0.004s   0.01G
+[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> compel               2      0.512s   0.26G
+[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> rand_int             1      0.001s   0.01G
+[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> range_of_size        1      0.001s   0.01G
+[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> iterate              1      0.001s   0.01G
+[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> metadata_accumulator 1      0.002s   0.01G
+[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> noise                1      0.002s   0.01G
+[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> t2l                  1      3.541s   1.93G
+[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> l2i                  1      0.679s   0.58G
+[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> TOTAL GRAPH EXECUTION TIME: 4.749s
+[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> Current VRAM utilization 0.01G
+
+The abstract base class for this class is InvocationStatsServiceBase. An implementing class which
+writes to the system log is stored in InvocationServices.performance_statistics.
+"""
+
+from abc import ABC, abstractmethod
+from contextlib import AbstractContextManager
+from typing import Dict
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation
+from invokeai.backend.model_management.model_cache import CacheStats
+
+from .invocation_stats_common import NodeLog
+
+
+class InvocationStatsServiceBase(ABC):
+    "Abstract base class for recording node memory/time performance statistics"
+
+    # {graph_id => NodeLog}
+    _stats: Dict[str, NodeLog]
+    _cache_stats: Dict[str, CacheStats]
+    ram_used: float
+    ram_changed: float
+
+    @abstractmethod
+    def __init__(self):
+        """
+        Initialize the InvocationStatsService and reset counters to zero
+        """
+        pass
+
+    @abstractmethod
+    def collect_stats(
+        self,
+        invocation: BaseInvocation,
+        graph_execution_state_id: str,
+    ) -> AbstractContextManager:
+        """
+        Return a context object that will capture the statistics on the execution
+        of invocaation. Use with: to place around the part of the code that executes the invocation.
+        :param invocation: BaseInvocation object from the current graph.
+        :param graph_execution_state_id: The id of the current session.
+        """
+        pass
+
+    @abstractmethod
+    def reset_stats(self, graph_execution_state_id: str):
+        """
+        Reset all statistics for the indicated graph
+        :param graph_execution_state_id
+        """
+        pass
+
+    @abstractmethod
+    def reset_all_stats(self):
+        """Zero all statistics"""
+        pass
+
+    @abstractmethod
+    def update_invocation_stats(
+        self,
+        graph_id: str,
+        invocation_type: str,
+        time_used: float,
+        vram_used: float,
+    ):
+        """
+        Add timing information on execution of a node. Usually
+        used internally.
+        :param graph_id: ID of the graph that is currently executing
+        :param invocation_type: String literal type of the node
+        :param time_used: Time used by node's exection (sec)
+        :param vram_used: Maximum VRAM used during exection (GB)
+        """
+        pass
+
+    @abstractmethod
+    def log_stats(self):
+        """
+        Write out the accumulated statistics to the log or somewhere else.
+        """
+        pass
+
+    @abstractmethod
+    def update_mem_stats(
+        self,
+        ram_used: float,
+        ram_changed: float,
+    ):
+        """
+        Update the collector with RAM memory usage info.
+
+        :param ram_used: How much RAM is currently in use.
+        :param ram_changed: How much RAM changed since last generation.
+        """
+        pass
@@ -0,0 +1,25 @@
+from dataclasses import dataclass, field
+from typing import Dict
+
+# size of GIG in bytes
+GIG = 1073741824
+
+
+@dataclass
+class NodeStats:
+    """Class for tracking execution stats of an invocation node"""
+
+    calls: int = 0
+    time_used: float = 0.0  # seconds
+    max_vram: float = 0.0  # GB
+    cache_hits: int = 0
+    cache_misses: int = 0
+    cache_high_watermark: int = 0
+
+
+@dataclass
+class NodeLog:
+    """Class for tracking node usage"""
+
+    # {node_type => NodeStats}
+    nodes: Dict[str, NodeStats] = field(default_factory=dict)
@@ -1,154 +1,17 @@
-# Copyright 2023 Lincoln D. Stein <lincoln.stein@gmail.com>
-"""Utility to collect execution time and GPU usage stats on invocations in flight
-
-Usage:
-
-statistics = InvocationStatsService(graph_execution_manager)
-with statistics.collect_stats(invocation, graph_execution_state.id):
-      ... execute graphs...
-statistics.log_stats()
-
-Typical output:
-[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Graph stats: c7764585-9c68-4d9d-a199-55e8186790f3
-[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> Node                 Calls  Seconds  VRAM Used
-[2023-08-02 18:03:04,507]::[InvokeAI]::INFO --> main_model_loader    1      0.005s   0.01G
-[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> clip_skip            1      0.004s   0.01G
-[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> compel               2      0.512s   0.26G
-[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> rand_int             1      0.001s   0.01G
-[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> range_of_size        1      0.001s   0.01G
-[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> iterate              1      0.001s   0.01G
-[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> metadata_accumulator 1      0.002s   0.01G
-[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> noise                1      0.002s   0.01G
-[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> t2l                  1      3.541s   1.93G
-[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> l2i                  1      0.679s   0.58G
-[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> TOTAL GRAPH EXECUTION TIME: 4.749s
-[2023-08-02 18:03:04,508]::[InvokeAI]::INFO --> Current VRAM utilization 0.01G
-
-The abstract base class for this class is InvocationStatsServiceBase. An implementing class which
-writes to the system log is stored in InvocationServices.performance_statistics.
-"""
-
 import time
-from abc import ABC, abstractmethod
-from contextlib import AbstractContextManager
-from dataclasses import dataclass, field
 from typing import Dict
 
 import psutil
 import torch
 
 import invokeai.backend.util.logging as logger
+from invokeai.app.invocations.baseinvocation import BaseInvocation
 from invokeai.app.services.invoker import Invoker
+from invokeai.app.services.model_manager.model_manager_base import ModelManagerServiceBase
 from invokeai.backend.model_management.model_cache import CacheStats
 
-from ..invocations.baseinvocation import BaseInvocation
-from .model_manager_service import ModelManagerServiceBase
-
-# size of GIG in bytes
-GIG = 1073741824
-
-
-@dataclass
-class NodeStats:
-    """Class for tracking execution stats of an invocation node"""
-
-    calls: int = 0
-    time_used: float = 0.0  # seconds
-    max_vram: float = 0.0  # GB
-    cache_hits: int = 0
-    cache_misses: int = 0
-    cache_high_watermark: int = 0
-
-
-@dataclass
-class NodeLog:
-    """Class for tracking node usage"""
-
-    # {node_type => NodeStats}
-    nodes: Dict[str, NodeStats] = field(default_factory=dict)
-
-
-class InvocationStatsServiceBase(ABC):
-    "Abstract base class for recording node memory/time performance statistics"
-
-    # {graph_id => NodeLog}
-    _stats: Dict[str, NodeLog]
-    _cache_stats: Dict[str, CacheStats]
-    ram_used: float
-    ram_changed: float
-
-    @abstractmethod
-    def __init__(self):
-        """
-        Initialize the InvocationStatsService and reset counters to zero
-        """
-        pass
-
-    @abstractmethod
-    def collect_stats(
-        self,
-        invocation: BaseInvocation,
-        graph_execution_state_id: str,
-    ) -> AbstractContextManager:
-        """
-        Return a context object that will capture the statistics on the execution
-        of invocaation. Use with: to place around the part of the code that executes the invocation.
-        :param invocation: BaseInvocation object from the current graph.
-        :param graph_execution_state: GraphExecutionState object from the current session.
-        """
-        pass
-
-    @abstractmethod
-    def reset_stats(self, graph_execution_state_id: str):
-        """
-        Reset all statistics for the indicated graph
-        :param graph_execution_state_id
-        """
-        pass
-
-    @abstractmethod
-    def reset_all_stats(self):
-        """Zero all statistics"""
-        pass
-
-    @abstractmethod
-    def update_invocation_stats(
-        self,
-        graph_id: str,
-        invocation_type: str,
-        time_used: float,
-        vram_used: float,
-    ):
-        """
-        Add timing information on execution of a node. Usually
-        used internally.
-        :param graph_id: ID of the graph that is currently executing
-        :param invocation_type: String literal type of the node
-        :param time_used: Time used by node's exection (sec)
-        :param vram_used: Maximum VRAM used during exection (GB)
-        """
-        pass
-
-    @abstractmethod
-    def log_stats(self):
-        """
-        Write out the accumulated statistics to the log or somewhere else.
-        """
-        pass
-
-    @abstractmethod
-    def update_mem_stats(
-        self,
-        ram_used: float,
-        ram_changed: float,
-    ):
-        """
-        Update the collector with RAM memory usage info.
-
-        :param ram_used: How much RAM is currently in use.
-        :param ram_changed: How much RAM changed since last generation.
-        """
-        pass
+from .invocation_stats_base import InvocationStatsServiceBase
+from .invocation_stats_common import GIG, NodeLog, NodeStats
 
 
 class InvocationStatsService(InvocationStatsServiceBase):
@@ -1,11 +1,10 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
 
-from abc import ABC
 from typing import Optional
 
-from .graph import Graph, GraphExecutionState
-from .invocation_queue import InvocationQueueItem
+from .invocation_queue.invocation_queue_common import InvocationQueueItem
 from .invocation_services import InvocationServices
+from .shared.graph import Graph, GraphExecutionState
 
 
 class Invoker:
@@ -84,7 +83,3 @@ class Invoker:
             self.__stop_service(getattr(self.services, service))
 
         self.services.queue.put(None)
-
-
-class InvocationProcessorABC(ABC):
-    pass
0 invokeai/app/services/item_storage/__init__.py Normal file
@@ -9,6 +9,8 @@ T = TypeVar("T", bound=BaseModel)
 
 
 class ItemStorageABC(ABC, Generic[T]):
+    """Provides storage for a single type of item. The type must be a Pydantic model."""
+
     _on_changed_callbacks: list[Callable[[T], None]]
     _on_deleted_callbacks: list[Callable[[str], None]]
 
@@ -4,15 +4,13 @@ from typing import Generic, Optional, TypeVar, get_args
 
 from pydantic import BaseModel, parse_raw_as
 
-from invokeai.app.services.shared.sqlite import SqliteDatabase
 from invokeai.app.services.shared.pagination import PaginatedResults
+from invokeai.app.services.shared.sqlite import SqliteDatabase
 
-from .item_storage import ItemStorageABC
+from .item_storage_base import ItemStorageABC
 
 T = TypeVar("T", bound=BaseModel)
 
-sqlite_memory = ":memory:"
-
-
 
 class SqliteItemStorage(ItemStorageABC, Generic[T]):
     _table_name: str
@@ -47,7 +45,8 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
         self._lock.release()
 
     def _parse_item(self, item: str) -> T:
-        item_type = get_args(self.__orig_class__)[0]
+        # __orig_class__ is technically an implementation detail of the typing module, not a supported API
+        item_type = get_args(self.__orig_class__)[0]  # type: ignore
         return parse_raw_as(item_type, item)
 
     def set(self, item: T):
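(Illustrative only: the `__orig_class__` note above relies on the fact that instantiating a parameterized generic records the concrete item type on the instance; a standalone sketch of that idea, not the commit's code.)

    from typing import Generic, TypeVar, get_args

    from pydantic import BaseModel

    T = TypeVar("T", bound=BaseModel)

    class Storage(Generic[T]):
        def item_type(self) -> type:
            # set by the typing machinery when the instance is created as Storage[Widget]()
            return get_args(self.__orig_class__)[0]  # type: ignore

    class Widget(BaseModel):
        name: str

    print(Storage[Widget]().item_type())  # <class '__main__.Widget'>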
@@ -1,119 +0,0 @@
-# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
-
-from abc import ABC, abstractmethod
-from pathlib import Path
-from queue import Queue
-from typing import Callable, Dict, Optional, Union
-
-import torch
-
-
-class LatentsStorageBase(ABC):
-    """Responsible for storing and retrieving latents."""
-
-    _on_changed_callbacks: list[Callable[[torch.Tensor], None]]
-    _on_deleted_callbacks: list[Callable[[str], None]]
-
-    def __init__(self) -> None:
-        self._on_changed_callbacks = list()
-        self._on_deleted_callbacks = list()
-
-    @abstractmethod
-    def get(self, name: str) -> torch.Tensor:
-        pass
-
-    @abstractmethod
-    def save(self, name: str, data: torch.Tensor) -> None:
-        pass
-
-    @abstractmethod
-    def delete(self, name: str) -> None:
-        pass
-
-    def on_changed(self, on_changed: Callable[[torch.Tensor], None]) -> None:
-        """Register a callback for when an item is changed"""
-        self._on_changed_callbacks.append(on_changed)
-
-    def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
-        """Register a callback for when an item is deleted"""
-        self._on_deleted_callbacks.append(on_deleted)
-
-    def _on_changed(self, item: torch.Tensor) -> None:
-        for callback in self._on_changed_callbacks:
-            callback(item)
-
-    def _on_deleted(self, item_id: str) -> None:
-        for callback in self._on_deleted_callbacks:
-            callback(item_id)
-
-
-class ForwardCacheLatentsStorage(LatentsStorageBase):
-    """Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage"""
-
-    __cache: Dict[str, torch.Tensor]
-    __cache_ids: Queue
-    __max_cache_size: int
-    __underlying_storage: LatentsStorageBase
-
-    def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
-        super().__init__()
-        self.__underlying_storage = underlying_storage
-        self.__cache = dict()
-        self.__cache_ids = Queue()
-        self.__max_cache_size = max_cache_size
-
-    def get(self, name: str) -> torch.Tensor:
-        cache_item = self.__get_cache(name)
-        if cache_item is not None:
-            return cache_item
-
-        latent = self.__underlying_storage.get(name)
-        self.__set_cache(name, latent)
-        return latent
-
-    def save(self, name: str, data: torch.Tensor) -> None:
-        self.__underlying_storage.save(name, data)
-        self.__set_cache(name, data)
-        self._on_changed(data)
-
-    def delete(self, name: str) -> None:
-        self.__underlying_storage.delete(name)
-        if name in self.__cache:
-            del self.__cache[name]
-        self._on_deleted(name)
-
-    def __get_cache(self, name: str) -> Optional[torch.Tensor]:
-        return None if name not in self.__cache else self.__cache[name]
-
-    def __set_cache(self, name: str, data: torch.Tensor):
-        if name not in self.__cache:
-            self.__cache[name] = data
-            self.__cache_ids.put(name)
-            if self.__cache_ids.qsize() > self.__max_cache_size:
-                self.__cache.pop(self.__cache_ids.get())
-
-
-class DiskLatentsStorage(LatentsStorageBase):
-    """Stores latents in a folder on disk without caching"""
-
-    __output_folder: Union[str, Path]
-
-    def __init__(self, output_folder: Union[str, Path]):
-        self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
-        self.__output_folder.mkdir(parents=True, exist_ok=True)
-
-    def get(self, name: str) -> torch.Tensor:
-        latent_path = self.get_path(name)
-        return torch.load(latent_path)
-
-    def save(self, name: str, data: torch.Tensor) -> None:
-        self.__output_folder.mkdir(parents=True, exist_ok=True)
-        latent_path = self.get_path(name)
-        torch.save(data, latent_path)
-
-    def delete(self, name: str) -> None:
-        latent_path = self.get_path(name)
-        latent_path.unlink()
-
-    def get_path(self, name: str) -> Path:
-        return self.__output_folder / name
invokeai/app/services/latents_storage/__init__.py (new file, 0 lines)
@@ -0,0 +1,45 @@
+# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
+
+from abc import ABC, abstractmethod
+from typing import Callable
+
+import torch
+
+
+class LatentsStorageBase(ABC):
+    """Responsible for storing and retrieving latents."""
+
+    _on_changed_callbacks: list[Callable[[torch.Tensor], None]]
+    _on_deleted_callbacks: list[Callable[[str], None]]
+
+    def __init__(self) -> None:
+        self._on_changed_callbacks = list()
+        self._on_deleted_callbacks = list()
+
+    @abstractmethod
+    def get(self, name: str) -> torch.Tensor:
+        pass
+
+    @abstractmethod
+    def save(self, name: str, data: torch.Tensor) -> None:
+        pass
+
+    @abstractmethod
+    def delete(self, name: str) -> None:
+        pass
+
+    def on_changed(self, on_changed: Callable[[torch.Tensor], None]) -> None:
+        """Register a callback for when an item is changed"""
+        self._on_changed_callbacks.append(on_changed)
+
+    def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
+        """Register a callback for when an item is deleted"""
+        self._on_deleted_callbacks.append(on_deleted)
+
+    def _on_changed(self, item: torch.Tensor) -> None:
+        for callback in self._on_changed_callbacks:
+            callback(item)
+
+    def _on_deleted(self, item_id: str) -> None:
+        for callback in self._on_deleted_callbacks:
+            callback(item_id)
@@ -0,0 +1,34 @@
+# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
+
+from pathlib import Path
+from typing import Union
+
+import torch
+
+from .latents_storage_base import LatentsStorageBase
+
+
+class DiskLatentsStorage(LatentsStorageBase):
+    """Stores latents in a folder on disk without caching"""
+
+    __output_folder: Path
+
+    def __init__(self, output_folder: Union[str, Path]):
+        self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
+        self.__output_folder.mkdir(parents=True, exist_ok=True)
+
+    def get(self, name: str) -> torch.Tensor:
+        latent_path = self.get_path(name)
+        return torch.load(latent_path)
+
+    def save(self, name: str, data: torch.Tensor) -> None:
+        self.__output_folder.mkdir(parents=True, exist_ok=True)
+        latent_path = self.get_path(name)
+        torch.save(data, latent_path)
+
+    def delete(self, name: str) -> None:
+        latent_path = self.get_path(name)
+        latent_path.unlink()
+
+    def get_path(self, name: str) -> Path:
+        return self.__output_folder / name
@@ -0,0 +1,54 @@
+# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
+
+from queue import Queue
+from typing import Dict, Optional
+
+import torch
+
+from .latents_storage_base import LatentsStorageBase
+
+
+class ForwardCacheLatentsStorage(LatentsStorageBase):
+    """Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage"""
+
+    __cache: Dict[str, torch.Tensor]
+    __cache_ids: Queue
+    __max_cache_size: int
+    __underlying_storage: LatentsStorageBase
+
+    def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
+        super().__init__()
+        self.__underlying_storage = underlying_storage
+        self.__cache = dict()
+        self.__cache_ids = Queue()
+        self.__max_cache_size = max_cache_size
+
+    def get(self, name: str) -> torch.Tensor:
+        cache_item = self.__get_cache(name)
+        if cache_item is not None:
+            return cache_item
+
+        latent = self.__underlying_storage.get(name)
+        self.__set_cache(name, latent)
+        return latent
+
+    def save(self, name: str, data: torch.Tensor) -> None:
+        self.__underlying_storage.save(name, data)
+        self.__set_cache(name, data)
+        self._on_changed(data)
+
+    def delete(self, name: str) -> None:
+        self.__underlying_storage.delete(name)
+        if name in self.__cache:
+            del self.__cache[name]
+        self._on_deleted(name)
+
+    def __get_cache(self, name: str) -> Optional[torch.Tensor]:
+        return None if name not in self.__cache else self.__cache[name]
+
+    def __set_cache(self, name: str, data: torch.Tensor):
+        if name not in self.__cache:
+            self.__cache[name] = data
+            self.__cache_ids.put(name)
+            if self.__cache_ids.qsize() > self.__max_cache_size:
+                self.__cache.pop(self.__cache_ids.get())
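Taken together, the disk and forward-cache implementations are meant to be stacked: the cache wraps the disk store and writes through to it, while callbacks registered on the base class fire on save and delete. A minimal usage sketch follows; the module paths are assumed from the new folder layout, and the output path and callback are illustrative, not part of this commit.

import torch

from invokeai.app.services.latents_storage.latents_storage_disk import DiskLatentsStorage
from invokeai.app.services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage

# A write-through cache in front of on-disk storage.
latents = ForwardCacheLatentsStorage(DiskLatentsStorage("outputs/latents"), max_cache_size=20)

# Callbacks fire via the base-class helpers when items change or are deleted.
latents.on_deleted(lambda name: print(f"latents deleted: {name}"))

latents.save("abc123", torch.zeros(1, 4, 64, 64))  # persisted to disk and cached
tensor = latents.get("abc123")                     # served from the in-memory cache
latents.delete("abc123")                           # removed from disk and cache; callback fires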
invokeai/app/services/model_manager/__init__.py (new file, 0 lines)
invokeai/app/services/model_manager/model_manager_base.py (new file, 286 lines)
@@ -0,0 +1,286 @@
+# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from logging import Logger
+from pathlib import Path
+from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union
+
+from pydantic import Field
+
+from invokeai.app.services.config.config_default import InvokeAIAppConfig
+from invokeai.backend.model_management import (
+    AddModelResult,
+    BaseModelType,
+    MergeInterpolationMethod,
+    ModelInfo,
+    ModelType,
+    SchedulerPredictionType,
+    SubModelType,
+)
+from invokeai.backend.model_management.model_cache import CacheStats
+
+if TYPE_CHECKING:
+    from invokeai.app.invocations.baseinvocation import BaseInvocation, InvocationContext
+
+
+class ModelManagerServiceBase(ABC):
+    """Responsible for managing models on disk and in memory"""
+
+    @abstractmethod
+    def __init__(
+        self,
+        config: InvokeAIAppConfig,
+        logger: Logger,
+    ):
+        """
+        Initialize with the path to the models.yaml config file.
+        Optional parameters are the torch device type, precision, max_models,
+        and sequential_offload boolean. Note that the default device
+        type and precision are set up for a CUDA system running at half precision.
+        """
+        pass
+
+    @abstractmethod
+    def get_model(
+        self,
+        model_name: str,
+        base_model: BaseModelType,
+        model_type: ModelType,
+        submodel: Optional[SubModelType] = None,
+        node: Optional[BaseInvocation] = None,
+        context: Optional[InvocationContext] = None,
+    ) -> ModelInfo:
+        """Retrieve the indicated model with name and type.
+        submodel can be used to get a part (such as the vae)
+        of a diffusers pipeline."""
+        pass
+
+    @property
+    @abstractmethod
+    def logger(self):
+        pass
+
+    @abstractmethod
+    def model_exists(
+        self,
+        model_name: str,
+        base_model: BaseModelType,
+        model_type: ModelType,
+    ) -> bool:
+        pass
+
+    @abstractmethod
+    def model_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
+        """
+        Given a model name returns a dict-like (OmegaConf) object describing it.
+        Uses the exact format as the omegaconf stanza.
+        """
+        pass
+
+    @abstractmethod
+    def list_models(self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None) -> dict:
+        """
+        Return a dict of models in the format:
+        { model_type1:
+          { model_name1: {'status': 'active'|'cached'|'not loaded',
+                          'model_name' : name,
+                          'model_type' : SDModelType,
+                          'description': description,
+                          'format': 'folder'|'safetensors'|'ckpt'
+                          },
+            model_name2: { etc }
+          },
+          model_type2:
+            { model_name_n: etc
+          }
+        """
+        pass
+
+    @abstractmethod
+    def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
+        """
+        Return information about the model using the same format as list_models()
+        """
+        pass
+
+    @abstractmethod
+    def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]:
+        """
+        Returns a list of all the model names known.
+        """
+        pass
+
+    @abstractmethod
+    def add_model(
+        self,
+        model_name: str,
+        base_model: BaseModelType,
+        model_type: ModelType,
+        model_attributes: dict,
+        clobber: bool = False,
+    ) -> AddModelResult:
+        """
+        Update the named model with a dictionary of attributes. Will fail with an
+        assertion error if the name already exists. Pass clobber=True to overwrite.
+        On a successful update, the config will be changed in memory. Will fail
+        with an assertion error if provided attributes are incorrect or
+        the model name is missing. Call commit() to write changes to disk.
+        """
+        pass
+
+    @abstractmethod
+    def update_model(
+        self,
+        model_name: str,
+        base_model: BaseModelType,
+        model_type: ModelType,
+        model_attributes: dict,
+    ) -> AddModelResult:
+        """
+        Update the named model with a dictionary of attributes. Will fail with a
+        ModelNotFoundException if the name does not already exist.
+
+        On a successful update, the config will be changed in memory. Will fail
+        with an assertion error if provided attributes are incorrect or
+        the model name is missing. Call commit() to write changes to disk.
+        """
+        pass
+
+    @abstractmethod
+    def del_model(
+        self,
+        model_name: str,
+        base_model: BaseModelType,
+        model_type: ModelType,
+    ):
+        """
+        Delete the named model from configuration. If delete_files is true,
+        then the underlying weight file or diffusers directory will be deleted
+        as well. Call commit() to write to disk.
+        """
+        pass
+
+    @abstractmethod
+    def rename_model(
+        self,
+        model_name: str,
+        base_model: BaseModelType,
+        model_type: ModelType,
+        new_name: str,
+    ):
+        """
+        Rename the indicated model.
+        """
+        pass
+
+    @abstractmethod
+    def list_checkpoint_configs(self) -> List[Path]:
+        """
+        List the checkpoint config paths from ROOT/configs/stable-diffusion.
+        """
+        pass
+
+    @abstractmethod
+    def convert_model(
+        self,
+        model_name: str,
+        base_model: BaseModelType,
+        model_type: Literal[ModelType.Main, ModelType.Vae],
+    ) -> AddModelResult:
+        """
+        Convert a checkpoint file into a diffusers folder, deleting the cached
+        version and deleting the original checkpoint file if it is in the models
+        directory.
+        :param model_name: Name of the model to convert
+        :param base_model: Base model type
+        :param model_type: Type of model ['vae' or 'main']
+
+        This will raise a ValueError unless the model is not a checkpoint. It will
+        also raise a ValueError in the event that there is a similarly-named diffusers
+        directory already in place.
+        """
+        pass
+
+    @abstractmethod
+    def heuristic_import(
+        self,
+        items_to_import: set[str],
+        prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
+    ) -> dict[str, AddModelResult]:
+        """Import a list of paths, repo_ids or URLs. Returns the set of
+        successfully imported items.
+        :param items_to_import: Set of strings corresponding to models to be imported.
+        :param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
+
+        The prediction type helper is necessary to distinguish between
+        models based on Stable Diffusion 2 Base (requiring
+        SchedulerPredictionType.Epsilson) and Stable Diffusion 768
+        (requiring SchedulerPredictionType.VPrediction). It is
+        generally impossible to do this programmatically, so the
+        prediction_type_helper usually asks the user to choose.
+
+        The result is a set of successfully installed models. Each element
+        of the set is a dict corresponding to the newly-created OmegaConf stanza for
+        that model.
+        """
+        pass
+
+    @abstractmethod
+    def merge_models(
+        self,
+        model_names: List[str] = Field(
+            default=None, min_items=2, max_items=3, description="List of model names to merge"
+        ),
+        base_model: Union[BaseModelType, str] = Field(
+            default=None, description="Base model shared by all models to be merged"
+        ),
+        merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
+        alpha: Optional[float] = 0.5,
+        interp: Optional[MergeInterpolationMethod] = None,
+        force: Optional[bool] = False,
+        merge_dest_directory: Optional[Path] = None,
+    ) -> AddModelResult:
+        """
+        Merge two to three diffusrs pipeline models and save as a new model.
+        :param model_names: List of 2-3 models to merge
+        :param base_model: Base model to use for all models
+        :param merged_model_name: Name of destination merged model
+        :param alpha: Alpha strength to apply to 2d and 3d model
+        :param interp: Interpolation method. None (default)
+        :param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
+        """
+        pass
+
+    @abstractmethod
+    def search_for_models(self, directory: Path) -> List[Path]:
+        """
+        Return list of all models found in the designated directory.
+        """
+        pass
+
+    @abstractmethod
+    def sync_to_config(self):
+        """
+        Re-read models.yaml, rescan the models directory, and reimport models
+        in the autoimport directories. Call after making changes outside the
+        model manager API.
+        """
+        pass
+
+    @abstractmethod
+    def collect_cache_stats(self, cache_stats: CacheStats):
+        """
+        Reset model cache statistics for graph with graph_id.
+        """
+        pass
+
+    @abstractmethod
+    def commit(self, conf_file: Optional[Path] = None) -> None:
+        """
+        Write current configuration out to the indicated file.
+        If no conf_file is provided, then replaces the
+        original file/database used to initialize the object.
+        """
+        pass
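A rough usage sketch against this interface follows. The concrete service instance, the model name, and the specific enum members chosen are illustrative assumptions; only the method signatures come from the base class above.

from invokeai.app.services.model_manager.model_manager_base import ModelManagerServiceBase
from invokeai.backend.model_management import BaseModelType, ModelInfo, ModelType, SubModelType


def load_sdxl_vae(model_manager: ModelManagerServiceBase) -> ModelInfo:
    """Illustrative only: fetch one part of an installed SDXL main model via the service interface."""
    if not model_manager.model_exists("SDXL base", BaseModelType.StableDiffusionXL, ModelType.Main):
        raise ValueError("model is not installed")
    return model_manager.get_model(
        model_name="SDXL base",
        base_model=BaseModelType.StableDiffusionXL,
        model_type=ModelType.Main,
        submodel=SubModelType.Vae,  # assumed enum member; pulls just the VAE of the diffusers pipeline
    )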
@@ -2,16 +2,15 @@

 from __future__ import annotations

-from abc import ABC, abstractmethod
 from logging import Logger
 from pathlib import Path
-from types import ModuleType
 from typing import TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, Union

 import torch
 from pydantic import Field

-from invokeai.app.models.exceptions import CanceledException
+from invokeai.app.services.config.config_default import InvokeAIAppConfig
+from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
 from invokeai.backend.model_management import (
     AddModelResult,
     BaseModelType,
@@ -26,273 +25,12 @@ from invokeai.backend.model_management import (
 )
 from invokeai.backend.model_management.model_cache import CacheStats
 from invokeai.backend.model_management.model_search import FindModels
+from invokeai.backend.util import choose_precision, choose_torch_device

-from ...backend.util import choose_precision, choose_torch_device
-from .config import InvokeAIAppConfig
+from .model_manager_base import ModelManagerServiceBase

 if TYPE_CHECKING:
-    from ..invocations.baseinvocation import BaseInvocation, InvocationContext
+    from invokeai.app.invocations.baseinvocation import InvocationContext


-class ModelManagerServiceBase(ABC):
-    """Responsible for managing models on disk and in memory"""
-    ...
(the remaining ~260 removed lines are the abstract methods previously defined here — __init__, get_model,
logger, model_exists, model_info, list_models, list_model, model_names, add_model, update_model, del_model,
rename_model, list_checkpoint_configs, convert_model, heuristic_import, merge_models, search_for_models,
sync_to_config, collect_cache_stats and commit — moved unchanged into model_manager_base.py above, except
that __init__'s `logger` parameter is now typed as `Logger` rather than `ModuleType`)


 # simple implementation
invokeai/app/services/names/__init__.py (new file, 0 lines)
invokeai/app/services/names/names_base.py (new file, 11 lines)

@@ -0,0 +1,11 @@
+from abc import ABC, abstractmethod
+
+
+class NameServiceBase(ABC):
+    """Low-level service responsible for naming resources (images, latents, etc)."""
+
+    # TODO: Add customizable naming schemes
+    @abstractmethod
+    def create_image_name(self) -> str:
+        """Creates a name for an image."""
+        pass

invokeai/app/services/names/names_common.py (new file, 8 lines)

@@ -0,0 +1,8 @@
+from enum import Enum, EnumMeta
+
+
+class ResourceType(str, Enum, metaclass=EnumMeta):
+    """Enum for resource types."""
+
+    IMAGE = "image"
+    LATENT = "latent"

invokeai/app/services/names/names_default.py (new file, 13 lines)

@@ -0,0 +1,13 @@
+from invokeai.app.util.misc import uuid_string
+
+from .names_base import NameServiceBase
+
+
+class SimpleNameService(NameServiceBase):
+    """Creates image names from UUIDs."""
+
+    # TODO: Add customizable naming schemes
+    def create_image_name(self) -> str:
+        uuid_str = uuid_string()
+        filename = f"{uuid_str}.png"
+        return filename
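A quick sketch of the resulting behaviour; the example output value is illustrative, since `uuid_string()` (imported from `invokeai.app.util.misc` above) generates a fresh UUID each call.

from invokeai.app.services.names.names_default import SimpleNameService

names = SimpleNameService()
image_name = names.create_image_name()  # e.g. "1b2e3f4a-....png" - a UUID with a ".png" suffix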
@@ -1,31 +0,0 @@
-from abc import ABC, abstractmethod
-from enum import Enum, EnumMeta
-
-from invokeai.app.util.misc import uuid_string
-
-
-class ResourceType(str, Enum, metaclass=EnumMeta):
-    """Enum for resource types."""
-
-    IMAGE = "image"
-    LATENT = "latent"
-
-
-class NameServiceBase(ABC):
-    """Low-level service responsible for naming resources (images, latents, etc)."""
-
-    # TODO: Add customizable naming schemes
-    @abstractmethod
-    def create_image_name(self) -> str:
-        """Creates a name for an image."""
-        pass
-
-
-class SimpleNameService(NameServiceBase):
-    """Creates image names from UUIDs."""
-
-    # TODO: Add customizable naming schemes
-    def create_image_name(self) -> str:
-        uuid_str = uuid_string()
-        filename = f"{uuid_str}.png"
-        return filename
@@ -7,7 +7,7 @@ from typing import Optional
 from fastapi_events.handlers.local import local_handler
 from fastapi_events.typing import Event as FastAPIEvent

-from invokeai.app.services.events import EventServiceBase
+from invokeai.app.services.events.events_base import EventServiceBase
 from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem

 from ..invoker import Invoker
@@ -1,7 +1,6 @@
 from abc import ABC, abstractmethod
 from typing import Optional

-from invokeai.app.services.graph import Graph
 from invokeai.app.services.session_queue.session_queue_common import (
     QUEUE_ITEM_STATUS,
     Batch,
@@ -18,6 +17,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
     SessionQueueItemDTO,
     SessionQueueStatus,
 )
+from invokeai.app.services.shared.graph import Graph
 from invokeai.app.services.shared.pagination import CursorPaginatedResults


@@ -7,7 +7,7 @@ from pydantic import BaseModel, Field, StrictStr, parse_raw_as, root_validator,
 from pydantic.json import pydantic_encoder

 from invokeai.app.invocations.baseinvocation import BaseInvocation
-from invokeai.app.services.graph import Graph, GraphExecutionState, NodeNotFoundError
+from invokeai.app.services.shared.graph import Graph, GraphExecutionState, NodeNotFoundError
 from invokeai.app.util.misc import uuid_string

 # region Errors
@@ -5,8 +5,7 @@ from typing import Optional, Union, cast
 from fastapi_events.handlers.local import local_handler
 from fastapi_events.typing import Event as FastAPIEvent

-from invokeai.app.services.events import EventServiceBase
+from invokeai.app.services.events.events_base import EventServiceBase
-from invokeai.app.services.graph import Graph
 from invokeai.app.services.invoker import Invoker
 from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
 from invokeai.app.services.session_queue.session_queue_common import (
@@ -29,8 +28,9 @@ from invokeai.app.services.session_queue.session_queue_common import (
     calc_session_count,
     prepare_values_to_insert,
 )
-from invokeai.app.services.shared.sqlite import SqliteDatabase
+from invokeai.app.services.shared.graph import Graph
 from invokeai.app.services.shared.pagination import CursorPaginatedResults
+from invokeai.app.services.shared.sqlite import SqliteDatabase


 class SqliteSessionQueue(SessionQueueBase):
@@ -1,10 +1,11 @@
-from ..invocations.compel import CompelInvocation
-from ..invocations.image import ImageNSFWBlurInvocation
-from ..invocations.latent import DenoiseLatentsInvocation, LatentsToImageInvocation
-from ..invocations.noise import NoiseInvocation
-from ..invocations.primitives import IntegerInvocation
+from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC
+
+from ...invocations.compel import CompelInvocation
+from ...invocations.image import ImageNSFWBlurInvocation
+from ...invocations.latent import DenoiseLatentsInvocation, LatentsToImageInvocation
+from ...invocations.noise import NoiseInvocation
+from ...invocations.primitives import IntegerInvocation
 from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph
-from .item_storage import ItemStorageABC

 default_text_to_image_graph_id = "539b2af5-2b4d-4d8c-8071-e54a3255fc74"

@@ -8,11 +8,9 @@ import networkx as nx
 from pydantic import BaseModel, root_validator, validator
 from pydantic.fields import Field

-from invokeai.app.util.misc import uuid_string
-
 # Importing * is bad karma but needed here for node detection
-from ..invocations import *  # noqa: F401 F403
-from ..invocations.baseinvocation import (
+from invokeai.app.invocations import *  # noqa: F401 F403
+from invokeai.app.invocations.baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
     Input,
@@ -23,6 +21,7 @@ from ..invocations.baseinvocation import (
     invocation,
     invocation_output,
 )
+from invokeai.app.util.misc import uuid_string

 # in 3.10 this would be "from types import NoneType"
 NoneType = type(None)
@@ -4,6 +4,8 @@ from logging import Logger

 from invokeai.app.services.config import InvokeAIAppConfig

+sqlite_memory = ":memory:"
+

 class SqliteDatabase:
     conn: sqlite3.Connection
@@ -16,7 +18,7 @@ class SqliteDatabase:
         self._config = config

         if self._config.use_memory_db:
-            location = ":memory:"
+            location = sqlite_memory
             logger.info("Using in-memory database")
         else:
             db_path = self._config.db_path
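The in-memory configuration handled here is what the updated tests later in this commit rely on: storages now share a single SqliteDatabase instead of each opening its own connection. A minimal sketch, using only calls that appear in the test changes below:

from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage
from invokeai.app.services.shared.graph import GraphExecutionState
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.backend.util.logging import InvokeAILogger

# use_memory_db=True makes SqliteDatabase open an in-memory SQLite database.
config = InvokeAIAppConfig(use_memory_db=True)
db = SqliteDatabase(config, InvokeAILogger.get_logger())

# Item storages are constructed against the shared database rather than a raw connection.
graph_executions = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")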
invokeai/app/services/urls/__init__.py (new file, 0 lines)
invokeai/app/services/urls/urls_base.py (new file, 10 lines)

@@ -0,0 +1,10 @@
+from abc import ABC, abstractmethod
+
+
+class UrlServiceBase(ABC):
+    """Responsible for building URLs for resources."""
+
+    @abstractmethod
+    def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
+        """Gets the URL for an image or thumbnail."""
+        pass
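Any concrete URL service only has to fill in this one method. A hypothetical sketch for illustration — the class and the URL scheme below are made up, not InvokeAI's actual routes:

from invokeai.app.services.urls.urls_base import UrlServiceBase


class PrefixedUrlService(UrlServiceBase):
    """Hypothetical implementation that builds image URLs under a fixed prefix."""

    def __init__(self, prefix: str):
        self._prefix = prefix.rstrip("/")

    def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
        kind = "thumbnails" if thumbnail else "images"
        return f"{self._prefix}/{kind}/{image_name}"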
@@ -1,14 +1,6 @@
 import os
-from abc import ABC, abstractmethod

+from .urls_base import UrlServiceBase

-
-class UrlServiceBase(ABC):
-    """Responsible for building URLs for resources."""
-
-    @abstractmethod
-    def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
-        """Gets the URL for an image or thumbnail."""
-        pass
-

 class LocalUrlService(UrlServiceBase):
@@ -3,7 +3,7 @@ from typing import Optional

 from pydantic import ValidationError

-from invokeai.app.services.graph import Edge
+from invokeai.app.services.shared.graph import Edge


 def get_metadata_graph_from_raw_session(session_raw: str) -> Optional[dict]:
@@ -1,8 +1,7 @@
 import torch
 from PIL import Image

-from invokeai.app.models.exceptions import CanceledException
+from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException, ProgressImage
-from invokeai.app.models.image import ProgressImage

 from ...backend.model_management.models import BaseModelType
 from ...backend.stable_diffusion import PipelineIntermediateState
invokeai/frontend/web/src/services/api/schema.d.ts (vendored, 2567 lines changed)
File diff suppressed because one or more lines are too long
@@ -1,5 +1,4 @@
 import logging
-import threading

 import pytest

@@ -10,20 +9,27 @@ from .test_nodes import (  # isort: split
     TestEventService,
     TextToImageTestInvocation,
 )
-import sqlite3

 from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
 from invokeai.app.invocations.collections import RangeInvocation
 from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
-from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
+from invokeai.app.services.config.config_default import InvokeAIAppConfig
-from invokeai.app.services.graph import CollectInvocation, Graph, GraphExecutionState, IterateInvocation, LibraryGraph
 from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
-from invokeai.app.services.invocation_queue import MemoryInvocationQueue
+from invokeai.app.services.invocation_processor.invocation_processor_default import DefaultInvocationProcessor
+from invokeai.app.services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue
 from invokeai.app.services.invocation_services import InvocationServices
-from invokeai.app.services.invocation_stats import InvocationStatsService
+from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
-from invokeai.app.services.processor import DefaultInvocationProcessor
+from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage
 from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
-from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
+from invokeai.app.services.shared.graph import (
+    CollectInvocation,
+    Graph,
+    GraphExecutionState,
+    IterateInvocation,
+    LibraryGraph,
+)
+from invokeai.app.services.shared.sqlite import SqliteDatabase
+from invokeai.backend.util.logging import InvokeAILogger

 from .test_invoker import create_edge

@@ -42,29 +48,33 @@
 # the test invocations.
 @pytest.fixture
 def mock_services() -> InvocationServices:
-    lock = threading.Lock()
+    configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
+    db = SqliteDatabase(configuration, InvokeAILogger.get_logger())
     # NOTE: none of these are actually called by the test invocations
-    db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
-    graph_execution_manager = SqliteItemStorage[GraphExecutionState](
-        conn=db_conn, table_name="graph_executions", lock=lock
-    )
+    graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
     return InvocationServices(
-        model_manager=None,  # type: ignore
-        events=TestEventService(),
-        logger=logging,  # type: ignore
-        images=None,  # type: ignore
-        latents=None,  # type: ignore
-        boards=None,  # type: ignore
+        board_image_records=None,  # type: ignore
         board_images=None,  # type: ignore
-        queue=MemoryInvocationQueue(),
-        graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs", lock=lock),
+        board_records=None,  # type: ignore
+        boards=None,  # type: ignore
+        configuration=configuration,
+        events=TestEventService(),
         graph_execution_manager=graph_execution_manager,
-        performance_statistics=InvocationStatsService(graph_execution_manager),
+        graph_library=SqliteItemStorage[LibraryGraph](db=db, table_name="graphs"),
+        image_files=None,  # type: ignore
+        image_records=None,  # type: ignore
+        images=None,  # type: ignore
+        invocation_cache=MemoryInvocationCache(max_cache_size=0),
+        latents=None,  # type: ignore
+        logger=logging,  # type: ignore
+        model_manager=None,  # type: ignore
+        names=None,  # type: ignore
+        performance_statistics=InvocationStatsService(),
         processor=DefaultInvocationProcessor(),
-        configuration=InvokeAIAppConfig(node_cache_size=0),  # type: ignore
-        session_queue=None,  # type: ignore
+        queue=MemoryInvocationQueue(),
         session_processor=None,  # type: ignore
-        invocation_cache=MemoryInvocationCache(),  # type: ignore
+        session_queue=None,  # type: ignore
+        urls=None,  # type: ignore
     )


@@ -1,10 +1,9 @@
 import logging
-import sqlite3
-import threading

 import pytest

-from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
+from invokeai.app.services.config.config_default import InvokeAIAppConfig
+from invokeai.backend.util.logging import InvokeAILogger

 # This import must happen before other invoke imports or test in other files(!!) break
 from .test_nodes import (  # isort: split
@@ -16,15 +15,16 @@ from .test_nodes import (  # isort: split
     wait_until,
 )

-from invokeai.app.services.graph import Graph, GraphExecutionState, GraphInvocation, LibraryGraph
 from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
-from invokeai.app.services.invocation_queue import MemoryInvocationQueue
+from invokeai.app.services.invocation_processor.invocation_processor_default import DefaultInvocationProcessor
+from invokeai.app.services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue
 from invokeai.app.services.invocation_services import InvocationServices
-from invokeai.app.services.invocation_stats import InvocationStatsService
+from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
 from invokeai.app.services.invoker import Invoker
-from invokeai.app.services.processor import DefaultInvocationProcessor
+from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage
 from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
-from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
+from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation, LibraryGraph
+from invokeai.app.services.shared.sqlite import SqliteDatabase


 @pytest.fixture
@@ -52,29 +52,34 @@
 # the test invocations.
 @pytest.fixture
 def mock_services() -> InvocationServices:
-    lock = threading.Lock()
+    db = SqliteDatabase(InvokeAIAppConfig(use_memory_db=True), InvokeAILogger.get_logger())
+    configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
+
     # NOTE: none of these are actually called by the test invocations
-    db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
-    graph_execution_manager = SqliteItemStorage[GraphExecutionState](
-        conn=db_conn, table_name="graph_executions", lock=lock
-    )
+    graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions")
     return InvocationServices(
-        model_manager=None,  # type: ignore
-        events=TestEventService(),
-        logger=logging,  # type: ignore
-        images=None,  # type: ignore
-        latents=None,  # type: ignore
-        boards=None,  # type: ignore
+        board_image_records=None,  # type: ignore
         board_images=None,  # type: ignore
-        queue=MemoryInvocationQueue(),
-        graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, table_name="graphs", lock=lock),
+        board_records=None,  # type: ignore
+        boards=None,  # type: ignore
+        configuration=configuration,
+        events=TestEventService(),
         graph_execution_manager=graph_execution_manager,
-        processor=DefaultInvocationProcessor(),
-        performance_statistics=InvocationStatsService(graph_execution_manager),
-        configuration=InvokeAIAppConfig(node_cache_size=0),  # type: ignore
-        session_queue=None,  # type: ignore
-        session_processor=None,  # type: ignore
+        graph_library=SqliteItemStorage[LibraryGraph](db=db, table_name="graphs"),
+        image_files=None,  # type: ignore
+        image_records=None,  # type: ignore
+        images=None,  # type: ignore
         invocation_cache=MemoryInvocationCache(max_cache_size=0),
+        latents=None,  # type: ignore
+        logger=logging,  # type: ignore
+        model_manager=None,  # type: ignore
+        names=None,  # type: ignore
+        performance_statistics=InvocationStatsService(),
+        processor=DefaultInvocationProcessor(),
+        queue=MemoryInvocationQueue(),
+        session_processor=None,  # type: ignore
+        session_queue=None,  # type: ignore
+        urls=None,  # type: ignore
     )


|
@ -11,8 +11,8 @@ from invokeai.app.invocations.image import ShowImageInvocation
|
|||||||
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
|
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
|
||||||
from invokeai.app.invocations.primitives import FloatInvocation, IntegerInvocation
|
from invokeai.app.invocations.primitives import FloatInvocation, IntegerInvocation
|
||||||
from invokeai.app.invocations.upscale import ESRGANInvocation
|
from invokeai.app.invocations.upscale import ESRGANInvocation
|
||||||
from invokeai.app.services.default_graphs import create_text_to_image
|
from invokeai.app.services.shared.default_graphs import create_text_to_image
|
||||||
from invokeai.app.services.graph import (
|
from invokeai.app.services.shared.graph import (
|
||||||
CollectInvocation,
|
CollectInvocation,
|
||||||
Edge,
|
Edge,
|
||||||
EdgeConnection,
|
EdgeConnection,
|
||||||
|
@ -82,8 +82,8 @@ class PromptCollectionTestInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
# Importing these must happen after test invocations are defined or they won't register
|
# Importing these must happen after test invocations are defined or they won't register
|
||||||
from invokeai.app.services.events import EventServiceBase # noqa: E402
|
from invokeai.app.services.events.events_base import EventServiceBase # noqa: E402
|
||||||
from invokeai.app.services.graph import Edge, EdgeConnection # noqa: E402
|
from invokeai.app.services.shared.graph import Edge, EdgeConnection # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edge:
|
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edge:
|
||||||
|
@@ -1,7 +1,6 @@
 import pytest
 from pydantic import ValidationError, parse_raw_as

-from invokeai.app.services.graph import Graph, GraphExecutionState, GraphInvocation
 from invokeai.app.services.session_queue.session_queue_common import (
     Batch,
     BatchDataCollection,
@@ -12,6 +11,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
     populate_graph,
     prepare_values_to_insert,
 )
+from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation
 from tests.nodes.test_nodes import PromptTestInvocation


@@ -1,10 +1,10 @@
-import sqlite3
-import threading
-
 import pytest
 from pydantic import BaseModel, Field

-from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
+from invokeai.app.services.config.config_default import InvokeAIAppConfig
+from invokeai.app.services.item_storage.item_storage_sqlite import SqliteItemStorage
+from invokeai.app.services.shared.sqlite import SqliteDatabase
+from invokeai.backend.util.logging import InvokeAILogger


 class TestModel(BaseModel):
@@ -14,8 +14,8 @@ class TestModel(BaseModel):

 @pytest.fixture
 def db() -> SqliteItemStorage[TestModel]:
-    db_conn = sqlite3.connect(sqlite_memory, check_same_thread=False)
-    return SqliteItemStorage[TestModel](db_conn, table_name="test", id_field="id", lock=threading.Lock())
+    sqlite_db = SqliteDatabase(InvokeAIAppConfig(use_memory_db=True), InvokeAILogger.get_logger())
+    return SqliteItemStorage[TestModel](db=sqlite_db, table_name="test", id_field="id")


 def test_sqlite_service_can_create_and_get(db: SqliteItemStorage[TestModel]):
@@ -2,7 +2,7 @@ from pathlib import Path

 import pytest

-from invokeai.app.services.config import InvokeAIAppConfig
+from invokeai.app.services.config.config_default import InvokeAIAppConfig
 from invokeai.backend import BaseModelType, ModelManager, ModelType, SubModelType

 BASIC_MODEL_NAME = ("SDXL base", BaseModelType.StableDiffusionXL, ModelType.Main)