tidy(mm): ModelImages service

This commit is contained in:
psychedelicious 2024-03-07 11:31:39 +11:00 committed by Kent Keirsey
parent 347f1fd0b7
commit 9b48029bc9
4 changed files with 26 additions and 46 deletions

View File

@ -25,7 +25,7 @@ from ..services.invocation_cache.invocation_cache_memory import MemoryInvocation
from ..services.invocation_services import InvocationServices from ..services.invocation_services import InvocationServices
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
from ..services.invoker import Invoker from ..services.invoker import Invoker
from ..services.model_images.model_images_default import ModelImagesService from ..services.model_images.model_images_default import ModelImageFileStorageDisk
from ..services.model_manager.model_manager_default import ModelManagerService from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_records import ModelRecordServiceSQL from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService from ..services.names.names_default import SimpleNameService
@ -95,7 +95,7 @@ class ApiDependencies:
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True) ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
) )
download_queue_service = DownloadQueueService(event_bus=events) download_queue_service = DownloadQueueService(event_bus=events)
model_images_service = ModelImagesService(model_images_folder / "model_images") model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
model_manager = ModelManagerService.build_model_manager( model_manager = ModelManagerService.build_model_manager(
app_config=configuration, app_config=configuration,
model_record_service=ModelRecordServiceSQL(db=db), model_record_service=ModelRecordServiceSQL(db=db),

View File

@ -25,7 +25,7 @@ if TYPE_CHECKING:
from .images.images_base import ImageServiceABC from .images.images_base import ImageServiceABC
from .invocation_cache.invocation_cache_base import InvocationCacheBase from .invocation_cache.invocation_cache_base import InvocationCacheBase
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
from .model_images.model_images_base import ModelImagesBase from .model_images.model_images_base import ModelImageFileStorageBase
from .model_manager.model_manager_base import ModelManagerServiceBase from .model_manager.model_manager_base import ModelManagerServiceBase
from .names.names_base import NameServiceBase from .names.names_base import NameServiceBase
from .session_processor.session_processor_base import SessionProcessorBase from .session_processor.session_processor_base import SessionProcessorBase
@ -50,7 +50,7 @@ class InvocationServices:
image_files: "ImageFileStorageBase", image_files: "ImageFileStorageBase",
image_records: "ImageRecordStorageBase", image_records: "ImageRecordStorageBase",
logger: "Logger", logger: "Logger",
model_images: "ModelImagesBase", model_images: "ModelImageFileStorageBase",
model_manager: "ModelManagerServiceBase", model_manager: "ModelManagerServiceBase",
download_queue: "DownloadQueueServiceBase", download_queue: "DownloadQueueServiceBase",
performance_statistics: "InvocationStatsServiceBase", performance_statistics: "InvocationStatsServiceBase",

View File

@ -4,7 +4,7 @@ from pathlib import Path
from PIL.Image import Image as PILImageType from PIL.Image import Image as PILImageType
class ModelImagesBase(ABC): class ModelImageFileStorageBase(ABC):
"""Low-level service responsible for storing and retrieving image files.""" """Low-level service responsible for storing and retrieving image files."""
@abstractmethod @abstractmethod
@ -23,11 +23,7 @@ class ModelImagesBase(ABC):
pass pass
@abstractmethod @abstractmethod
def save( def save(self, image: PILImageType, model_key: str) -> None:
self,
image: PILImageType,
model_key: str,
) -> None:
"""Saves a model image.""" """Saves a model image."""
pass pass

View File

@ -1,5 +1,4 @@
from pathlib import Path from pathlib import Path
from typing import Union
from PIL import Image from PIL import Image
from PIL.Image import Image as PILImageType from PIL.Image import Image as PILImageType
@ -8,7 +7,7 @@ from send2trash import send2trash
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.util.thumbnails import make_thumbnail from invokeai.app.util.thumbnails import make_thumbnail
from .model_images_base import ModelImagesBase from .model_images_base import ModelImageFileStorageBase
from .model_images_common import ( from .model_images_common import (
ModelImageFileDeleteException, ModelImageFileDeleteException,
ModelImageFileNotFoundException, ModelImageFileNotFoundException,
@ -16,66 +15,54 @@ from .model_images_common import (
) )
class ModelImagesService(ModelImagesBase): class ModelImageFileStorageDisk(ModelImageFileStorageBase):
"""Stores images on disk""" """Stores images on disk"""
__model_images_folder: Path def __init__(self, model_images_folder: Path):
__invoker: Invoker self._model_images_folder = model_images_folder
self._validate_storage_folders()
def __init__(self, model_images_folder: Union[str, Path]):
self.__model_images_folder: Path = (
model_images_folder if isinstance(model_images_folder, Path) else Path(model_images_folder)
)
# Validate required folders at launch
self.__validate_storage_folders()
def start(self, invoker: Invoker) -> None: def start(self, invoker: Invoker) -> None:
self.__invoker = invoker self._invoker = invoker
def get(self, model_key: str) -> PILImageType: def get(self, model_key: str) -> PILImageType:
try: try:
path = self.get_path(model_key) path = self.get_path(model_key)
if not self.validate_path(path): if not self._validate_path(path):
raise ModelImageFileNotFoundException raise ModelImageFileNotFoundException
image = Image.open(path) return Image.open(path)
return image
except FileNotFoundError as e: except FileNotFoundError as e:
raise ModelImageFileNotFoundException from e raise ModelImageFileNotFoundException from e
def save( def save(self, image: PILImageType, model_key: str) -> None:
self,
image: PILImageType,
model_key: str,
) -> None:
try: try:
self.__validate_storage_folders() self._validate_storage_folders()
image_path = self.__model_images_folder / (model_key + ".webp") image_path = self._model_images_folder / (model_key + ".webp")
image = make_thumbnail(image, 256) thumbnail = make_thumbnail(image, 256)
thumbnail.save(image_path, format="webp")
image.save(image_path, format="webp")
except Exception as e: except Exception as e:
raise ModelImageFileSaveException from e raise ModelImageFileSaveException from e
def get_path(self, model_key: str) -> Path: def get_path(self, model_key: str) -> Path:
path = self.__model_images_folder / (model_key + ".webp") path = self._model_images_folder / (model_key + ".webp")
return path return path
def get_url(self, model_key: str) -> str | None: def get_url(self, model_key: str) -> str | None:
path = self.get_path(model_key) path = self.get_path(model_key)
if not self.validate_path(path): if not self._validate_path(path):
return return
return self.__invoker.services.urls.get_model_image_url(model_key) return self._invoker.services.urls.get_model_image_url(model_key)
def delete(self, model_key: str) -> None: def delete(self, model_key: str) -> None:
try: try:
path = self.get_path(model_key) path = self.get_path(model_key)
if not self.validate_path(path): if not self._validate_path(path):
raise ModelImageFileNotFoundException raise ModelImageFileNotFoundException
send2trash(path) send2trash(path)
@ -83,13 +70,10 @@ class ModelImagesService(ModelImagesBase):
except Exception as e: except Exception as e:
raise ModelImageFileDeleteException from e raise ModelImageFileDeleteException from e
def validate_path(self, path: Union[str, Path]) -> bool: def _validate_path(self, path: Path) -> bool:
"""Validates the path given for an image.""" """Validates the path given for an image."""
path = path if isinstance(path, Path) else Path(path)
return path.exists() return path.exists()
def __validate_storage_folders(self) -> None: def _validate_storage_folders(self) -> None:
"""Checks if the required folders exist and create them if they don't""" """Checks if the required folders exist and create them if they don't"""
folders: list[Path] = [self.__model_images_folder] self._model_images_folder.mkdir(parents=True, exist_ok=True)
for folder in folders:
folder.mkdir(parents=True, exist_ok=True)