moved model image to edit page, added model_images service

This commit is contained in:
Jennifer Player 2024-03-04 12:22:26 -05:00 committed by Kent Keirsey
parent f6bfe5e6f2
commit c1cdfd132b
6 changed files with 143 additions and 4 deletions

View File

@ -72,6 +72,7 @@ example_model_input = {
"description": "Model description", "description": "Model description",
"vae": None, "vae": None,
"variant": "normal", "variant": "normal",
"image": "blob"
} }
############################################################################## ##############################################################################

View File

@ -0,0 +1,32 @@
from abc import ABC, abstractmethod
from pathlib import Path
from PIL.Image import Image as PILImageType
class ModelImagesBase(ABC):
"""Low-level service responsible for storing and retrieving image files."""
@abstractmethod
def get(self, image_name: str) -> PILImageType:
"""Retrieves an image as PIL Image."""
pass
@abstractmethod
def get_path(self, image_name: str) -> Path:
"""Gets the internal path to an image."""
pass
@abstractmethod
def save(
self,
image: PILImageType,
image_name: str,
) -> None:
"""Saves an image. Returns a tuple of the image name and created timestamp."""
pass
@abstractmethod
def delete(self, image_name: str) -> None:
"""Deletes an image."""
pass

View File

@ -0,0 +1,20 @@
# TODO: Should these excpetions subclass existing python exceptions?
class ModelImageFileNotFoundException(Exception):
"""Raised when an image file is not found in storage."""
def __init__(self, message="Image file not found"):
super().__init__(message)
class ModelImageFileSaveException(Exception):
"""Raised when an image cannot be saved."""
def __init__(self, message="Image file not saved"):
super().__init__(message)
class ModelImageFileDeleteException(Exception):
"""Raised when an image cannot be deleted."""
def __init__(self, message="Image file not deleted"):
super().__init__(message)

View File

@ -0,0 +1,82 @@
from pathlib import Path
from typing import Union
from PIL import Image, PngImagePlugin
from PIL.Image import Image as PILImageType
from send2trash import send2trash
from invokeai.app.services.invoker import Invoker
from .model_images_base import ModelImagesBase
from .model_images_common import ModelImageFileDeleteException, ModelImageFileNotFoundException, ModelImageFileSaveException
class DiskImageFileStorage(ModelImagesBase):
"""Stores images on disk"""
__output_folder: Path
__invoker: Invoker
def __init__(self, output_folder: Union[str, Path]):
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
# Validate required output folders at launch
self.__validate_storage_folders()
def start(self, invoker: Invoker) -> None:
self.__invoker = invoker
def get(self, image_name: str) -> PILImageType:
try:
image_path = self.get_path(image_name)
image = Image.open(image_path)
return image
except FileNotFoundError as e:
raise ModelImageFileNotFoundException from e
def save(
self,
image: PILImageType,
image_name: str,
) -> None:
try:
self.__validate_storage_folders()
image_path = self.get_path(image_name)
pnginfo = PngImagePlugin.PngInfo()
info_dict = {}
# When saving the image, the image object's info field is not populated. We need to set it
image.info = info_dict
image.save(
image_path,
"PNG",
pnginfo=pnginfo,
compress_level=self.__invoker.services.configuration.png_compress_level,
)
except Exception as e:
raise ModelImageFileSaveException from e
def delete(self, image_name: str) -> None:
try:
image_path = self.get_path(image_name)
if image_path.exists():
send2trash(image_path)
except Exception as e:
raise ModelImageFileDeleteException from e
# TODO: make this a bit more flexible for e.g. cloud storage
def get_path(self, image_name: str) -> Path:
path = self.__output_folder / image_name
return path
def __validate_storage_folders(self) -> None:
"""Checks if the required output folders exist and create them if they don't"""
folders: list[Path] = [self.__output_folder]
for folder in folders:
folder.mkdir(parents=True, exist_ok=True)

View File

@ -23,6 +23,7 @@ import type { UpdateModelArg } from 'services/api/endpoints/models';
import { useGetModelConfigQuery, useUpdateModelMutation } from 'services/api/endpoints/models'; import { useGetModelConfigQuery, useUpdateModelMutation } from 'services/api/endpoints/models';
import BaseModelSelect from './Fields/BaseModelSelect'; import BaseModelSelect from './Fields/BaseModelSelect';
import ModelImageUpload from './Fields/ModelImageUpload';
import ModelVariantSelect from './Fields/ModelVariantSelect'; import ModelVariantSelect from './Fields/ModelVariantSelect';
import PredictionTypeSelect from './Fields/PredictionTypeSelect'; import PredictionTypeSelect from './Fields/PredictionTypeSelect';
@ -58,6 +59,7 @@ export const ModelEdit = () => {
key: data.key, key: data.key,
body: values, body: values,
}; };
console.log(responseBody, 'responseBody')
updateModel(responseBody) updateModel(responseBody)
.unwrap() .unwrap()
@ -129,7 +131,8 @@ export const ModelEdit = () => {
</Flex> </Flex>
<Flex flexDir="column" gap={3} mt="4"> <Flex flexDir="column" gap={3} mt="4">
<Flex> <Flex gap="4" alignItems="center">
<ModelImageUpload control={control} name="image" />
<FormControl flexDir="column" alignItems="flex-start" gap={1}> <FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.description')}</FormLabel> <FormLabel>{t('modelManager.description')}</FormLabel>
<Textarea fontSize="md" resize="none" {...register('description')} /> <Textarea fontSize="md" resize="none" {...register('description')} />

File diff suppressed because one or more lines are too long