feat(nodes): add nameservice

Currenly only used to make names for images, but when latents, conditioning, etc are managed in DB, will do the same for them.

Intended to eventually support custom naming schemes.
This commit is contained in:
psychedelicious 2023-05-27 09:10:02 +10:00 committed by Kent Keirsey
parent 9a796364da
commit 33a0af4637
5 changed files with 44 additions and 24 deletions

View File

@ -5,6 +5,7 @@ import os
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService
from invokeai.app.services.metadata import CoreMetadataService
from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService
from invokeai.backend.util.logging import InvokeAILogger
@ -67,7 +68,7 @@ class ApiDependencies:
metadata = CoreMetadataService()
image_record_storage = SqliteImageRecordStorage(db_location)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
names = SimpleNameService()
latents = ForwardCacheLatentsStorage(
DiskLatentsStorage(f"{output_folder}/latents")
)
@ -78,6 +79,7 @@ class ApiDependencies:
metadata=metadata,
url=urls,
logger=logger,
names=names,
graph_execution_manager=graph_execution_manager,
)

View File

@ -16,6 +16,7 @@ from pydantic.fields import Field
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService
from invokeai.app.services.metadata import CoreMetadataService
from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService
@ -229,6 +230,7 @@ def invoke_cli():
metadata = CoreMetadataService()
image_record_storage = SqliteImageRecordStorage(db_location)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
names = SimpleNameService()
images = ImageService(
image_record_storage=image_record_storage,
@ -236,6 +238,7 @@ def invoke_cli():
metadata=metadata,
url=urls,
logger=logger,
names=names,
graph_execution_manager=graph_execution_manager,
)

View File

@ -103,7 +103,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
def __init__(self, filename: str) -> None:
super().__init__()
self._filename = filename
self._conn = sqlite3.connect(filename, check_same_thread=False)
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)

View File

@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from logging import Logger
from os import name
from typing import Optional, TYPE_CHECKING, Union
import uuid
from PIL.Image import Image as PILImageType
@ -31,6 +32,7 @@ from invokeai.app.services.image_file_storage import (
)
from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults
from invokeai.app.services.metadata import MetadataServiceBase
from invokeai.app.services.resource_name import NameServiceBase
from invokeai.app.services.urls import UrlServiceBase
if TYPE_CHECKING:
@ -120,6 +122,7 @@ class ImageServiceDependencies:
metadata: MetadataServiceBase
urls: UrlServiceBase
logger: Logger
names: NameServiceBase
graph_execution_manager: ItemStorageABC["GraphExecutionState"]
def __init__(
@ -129,6 +132,7 @@ class ImageServiceDependencies:
metadata: MetadataServiceBase,
url: UrlServiceBase,
logger: Logger,
names: NameServiceBase,
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
):
self.records = image_record_storage
@ -136,6 +140,7 @@ class ImageServiceDependencies:
self.metadata = metadata
self.urls = url
self.logger = logger
self.names = names
self.graph_execution_manager = graph_execution_manager
@ -149,6 +154,7 @@ class ImageService(ImageServiceABC):
metadata: MetadataServiceBase,
url: UrlServiceBase,
logger: Logger,
names: NameServiceBase,
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
):
self._services = ImageServiceDependencies(
@ -157,6 +163,7 @@ class ImageService(ImageServiceABC):
metadata=metadata,
url=url,
logger=logger,
names=names,
graph_execution_manager=graph_execution_manager,
)
@ -175,12 +182,7 @@ class ImageService(ImageServiceABC):
if image_category not in ImageCategory:
raise InvalidImageCategoryException
image_name = self._create_image_name(
image_type=image_type,
image_category=image_category,
node_id=node_id,
session_id=session_id,
)
image_name = self._services.names.create_image_name()
metadata = self._get_metadata(session_id, node_id)
@ -260,7 +262,6 @@ class ImageService(ImageServiceABC):
except Exception as e:
self._services.logger.error("Problem updating image record")
raise e
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
try:
@ -378,21 +379,6 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem deleting image record and file")
raise e
def _create_image_name(
self,
image_type: ImageType,
image_category: ImageCategory,
node_id: Optional[str] = None,
session_id: Optional[str] = None,
) -> str:
"""Create a unique image name."""
uuid_str = str(uuid.uuid4())
if node_id is not None and session_id is not None:
return f"{image_type.value}_{image_category.value}_{session_id}_{node_id}_{uuid_str}.png"
return f"{image_type.value}_{image_category.value}_{uuid_str}.png"
def _get_metadata(
self, session_id: Optional[str] = None, node_id: Optional[str] = None
) -> Union[ImageMetadata, None]:

View File

@ -0,0 +1,30 @@
from abc import ABC, abstractmethod
from enum import Enum, EnumMeta
import uuid
class ResourceType(str, Enum, metaclass=EnumMeta):
"""Enum for resource types."""
IMAGE = "image"
LATENT = "latent"
class NameServiceBase(ABC):
"""Low-level service responsible for naming resources (images, latents, etc)."""
# TODO: Add customizable naming schemes
@abstractmethod
def create_image_name(self) -> str:
"""Creates a name for an image."""
pass
class SimpleNameService(NameServiceBase):
"""Creates image names from UUIDs."""
# TODO: Add customizable naming schemes
def create_image_name(self) -> str:
uuid_str = str(uuid.uuid4())
filename = f"{uuid_str}.png"
return filename