feat(nodes): add nameservice

Currenly only used to make names for images, but when latents, conditioning, etc are managed in DB, will do the same for them.

Intended to eventually support custom naming schemes.
This commit is contained in:
psychedelicious 2023-05-27 09:10:02 +10:00 committed by Kent Keirsey
parent 9a796364da
commit 33a0af4637
5 changed files with 44 additions and 24 deletions

View File

@ -5,6 +5,7 @@ import os
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService from invokeai.app.services.images import ImageService
from invokeai.app.services.metadata import CoreMetadataService from invokeai.app.services.metadata import CoreMetadataService
from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService from invokeai.app.services.urls import LocalUrlService
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
@ -67,7 +68,7 @@ class ApiDependencies:
metadata = CoreMetadataService() metadata = CoreMetadataService()
image_record_storage = SqliteImageRecordStorage(db_location) image_record_storage = SqliteImageRecordStorage(db_location)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images") image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
names = SimpleNameService()
latents = ForwardCacheLatentsStorage( latents = ForwardCacheLatentsStorage(
DiskLatentsStorage(f"{output_folder}/latents") DiskLatentsStorage(f"{output_folder}/latents")
) )
@ -78,6 +79,7 @@ class ApiDependencies:
metadata=metadata, metadata=metadata,
url=urls, url=urls,
logger=logger, logger=logger,
names=names,
graph_execution_manager=graph_execution_manager, graph_execution_manager=graph_execution_manager,
) )

View File

@ -16,6 +16,7 @@ from pydantic.fields import Field
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService from invokeai.app.services.images import ImageService
from invokeai.app.services.metadata import CoreMetadataService from invokeai.app.services.metadata import CoreMetadataService
from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService from invokeai.app.services.urls import LocalUrlService
@ -229,6 +230,7 @@ def invoke_cli():
metadata = CoreMetadataService() metadata = CoreMetadataService()
image_record_storage = SqliteImageRecordStorage(db_location) image_record_storage = SqliteImageRecordStorage(db_location)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images") image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
names = SimpleNameService()
images = ImageService( images = ImageService(
image_record_storage=image_record_storage, image_record_storage=image_record_storage,
@ -236,6 +238,7 @@ def invoke_cli():
metadata=metadata, metadata=metadata,
url=urls, url=urls,
logger=logger, logger=logger,
names=names,
graph_execution_manager=graph_execution_manager, graph_execution_manager=graph_execution_manager,
) )

View File

@ -103,7 +103,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
def __init__(self, filename: str) -> None: def __init__(self, filename: str) -> None:
super().__init__() super().__init__()
self._filename = filename self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False) self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!) # Enable row factory to get rows as dictionaries (must be done before making the cursor!)

View File

@ -1,5 +1,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from logging import Logger from logging import Logger
from os import name
from typing import Optional, TYPE_CHECKING, Union from typing import Optional, TYPE_CHECKING, Union
import uuid import uuid
from PIL.Image import Image as PILImageType from PIL.Image import Image as PILImageType
@ -31,6 +32,7 @@ from invokeai.app.services.image_file_storage import (
) )
from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults
from invokeai.app.services.metadata import MetadataServiceBase from invokeai.app.services.metadata import MetadataServiceBase
from invokeai.app.services.resource_name import NameServiceBase
from invokeai.app.services.urls import UrlServiceBase from invokeai.app.services.urls import UrlServiceBase
if TYPE_CHECKING: if TYPE_CHECKING:
@ -120,6 +122,7 @@ class ImageServiceDependencies:
metadata: MetadataServiceBase metadata: MetadataServiceBase
urls: UrlServiceBase urls: UrlServiceBase
logger: Logger logger: Logger
names: NameServiceBase
graph_execution_manager: ItemStorageABC["GraphExecutionState"] graph_execution_manager: ItemStorageABC["GraphExecutionState"]
def __init__( def __init__(
@ -129,6 +132,7 @@ class ImageServiceDependencies:
metadata: MetadataServiceBase, metadata: MetadataServiceBase,
url: UrlServiceBase, url: UrlServiceBase,
logger: Logger, logger: Logger,
names: NameServiceBase,
graph_execution_manager: ItemStorageABC["GraphExecutionState"], graph_execution_manager: ItemStorageABC["GraphExecutionState"],
): ):
self.records = image_record_storage self.records = image_record_storage
@ -136,6 +140,7 @@ class ImageServiceDependencies:
self.metadata = metadata self.metadata = metadata
self.urls = url self.urls = url
self.logger = logger self.logger = logger
self.names = names
self.graph_execution_manager = graph_execution_manager self.graph_execution_manager = graph_execution_manager
@ -149,6 +154,7 @@ class ImageService(ImageServiceABC):
metadata: MetadataServiceBase, metadata: MetadataServiceBase,
url: UrlServiceBase, url: UrlServiceBase,
logger: Logger, logger: Logger,
names: NameServiceBase,
graph_execution_manager: ItemStorageABC["GraphExecutionState"], graph_execution_manager: ItemStorageABC["GraphExecutionState"],
): ):
self._services = ImageServiceDependencies( self._services = ImageServiceDependencies(
@ -157,6 +163,7 @@ class ImageService(ImageServiceABC):
metadata=metadata, metadata=metadata,
url=url, url=url,
logger=logger, logger=logger,
names=names,
graph_execution_manager=graph_execution_manager, graph_execution_manager=graph_execution_manager,
) )
@ -175,12 +182,7 @@ class ImageService(ImageServiceABC):
if image_category not in ImageCategory: if image_category not in ImageCategory:
raise InvalidImageCategoryException raise InvalidImageCategoryException
image_name = self._create_image_name( image_name = self._services.names.create_image_name()
image_type=image_type,
image_category=image_category,
node_id=node_id,
session_id=session_id,
)
metadata = self._get_metadata(session_id, node_id) metadata = self._get_metadata(session_id, node_id)
@ -261,7 +263,6 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem updating image record") self._services.logger.error("Problem updating image record")
raise e raise e
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType: def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
try: try:
return self._services.files.get(image_type, image_name) return self._services.files.get(image_type, image_name)
@ -378,21 +379,6 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem deleting image record and file") self._services.logger.error("Problem deleting image record and file")
raise e raise e
def _create_image_name(
self,
image_type: ImageType,
image_category: ImageCategory,
node_id: Optional[str] = None,
session_id: Optional[str] = None,
) -> str:
"""Create a unique image name."""
uuid_str = str(uuid.uuid4())
if node_id is not None and session_id is not None:
return f"{image_type.value}_{image_category.value}_{session_id}_{node_id}_{uuid_str}.png"
return f"{image_type.value}_{image_category.value}_{uuid_str}.png"
def _get_metadata( def _get_metadata(
self, session_id: Optional[str] = None, node_id: Optional[str] = None self, session_id: Optional[str] = None, node_id: Optional[str] = None
) -> Union[ImageMetadata, None]: ) -> Union[ImageMetadata, None]:

View File

@ -0,0 +1,30 @@
from abc import ABC, abstractmethod
from enum import Enum, EnumMeta
import uuid
class ResourceType(str, Enum, metaclass=EnumMeta):
"""Enum for resource types."""
IMAGE = "image"
LATENT = "latent"
class NameServiceBase(ABC):
"""Low-level service responsible for naming resources (images, latents, etc)."""
# TODO: Add customizable naming schemes
@abstractmethod
def create_image_name(self) -> str:
"""Creates a name for an image."""
pass
class SimpleNameService(NameServiceBase):
"""Creates image names from UUIDs."""
# TODO: Add customizable naming schemes
def create_image_name(self) -> str:
uuid_str = str(uuid.uuid4())
filename = f"{uuid_str}.png"
return filename