get model image url from model config, added thumbnail formatting for images

This commit is contained in:
Jennifer Player 2024-03-06 13:15:33 -05:00 committed by Kent Keirsey
parent 239b1e8cc7
commit 8411029d93
11 changed files with 69 additions and 32 deletions

View File

@ -113,6 +113,9 @@ async def list_model_records(
found_models.extend( found_models.extend(
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format) record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
) )
for model in found_models:
cover_image = ApiDependencies.invoker.services.model_images.get_url(model.key)
model.cover_image = cover_image
return ModelsList(models=found_models) return ModelsList(models=found_models)
@ -156,6 +159,8 @@ async def get_model_record(
record_store = ApiDependencies.invoker.services.model_manager.store record_store = ApiDependencies.invoker.services.model_manager.store
try: try:
config: AnyModelConfig = record_store.get_model(key) config: AnyModelConfig = record_store.get_model(key)
cover_image = ApiDependencies.invoker.services.model_images.get_url(key)
config.cover_image = cover_image
return config return config
except UnknownModelException as e: except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
@ -292,8 +297,7 @@ async def get_model_image(
"""Gets a full-resolution image file""" """Gets a full-resolution image file"""
try: try:
# still need to handle this gracefully when path doesnt exist instead of throwing error path = ApiDependencies.invoker.services.model_images.get_path(key)
path = ApiDependencies.invoker.services.model_images.get_path(key + ".png")
if not path: if not path:
raise HTTPException(status_code=404) raise HTTPException(status_code=404)

View File

@ -12,10 +12,15 @@ class ModelImagesBase(ABC):
pass pass
@abstractmethod @abstractmethod
def get_path(self, model_key: str) -> Path: def get_path(self, model_key: str) -> Path | None:
"""Gets the internal path to a model image.""" """Gets the internal path to a model image."""
pass pass
@abstractmethod
def get_url(self, model_key: str) -> str | None:
"""Gets the URL to a model image."""
pass
@abstractmethod @abstractmethod
def save( def save(
self, self,

View File

@ -6,6 +6,7 @@ from PIL.Image import Image as PILImageType
from send2trash import send2trash from send2trash import send2trash
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.util.thumbnails import make_thumbnail
from .model_images_base import ModelImagesBase from .model_images_base import ModelImagesBase
from .model_images_common import ModelImageFileDeleteException, ModelImageFileNotFoundException, ModelImageFileSaveException from .model_images_common import ModelImageFileDeleteException, ModelImageFileNotFoundException, ModelImageFileSaveException
@ -27,9 +28,12 @@ class ModelImagesService(ModelImagesBase):
def get(self, model_key: str) -> PILImageType: def get(self, model_key: str) -> PILImageType:
try: try:
image_path = self.get_path(model_key + '.png') path = self.get_path(model_key)
if not self.validate_path(path):
raise ModelImageFileNotFoundException
image = Image.open(image_path) image = Image.open(path)
return image return image
except FileNotFoundError as e: except FileNotFoundError as e:
raise ModelImageFileNotFoundException from e raise ModelImageFileNotFoundException from e
@ -41,8 +45,12 @@ class ModelImagesService(ModelImagesBase):
) -> None: ) -> None:
try: try:
self.__validate_storage_folders() self.__validate_storage_folders()
image_path = self.get_path(model_key + '.png') logger = self.__invoker.services.logger
pnginfo = PngImagePlugin.PngInfo() image_path = self.__model_images_folder / (model_key + '.png')
logger.debug(f"Saving image for model {model_key} to image_path {image_path}")
pnginfo = PngImagePlugin.PngInfo()
image = make_thumbnail(image, 256)
image.save( image.save(
image_path, image_path,
@ -55,22 +63,33 @@ class ModelImagesService(ModelImagesBase):
raise ModelImageFileSaveException from e raise ModelImageFileSaveException from e
def get_path(self, model_key: str) -> Path: def get_path(self, model_key: str) -> Path:
path = self.__model_images_folder / model_key path = self.__model_images_folder / (model_key + '.png')
return path return path
def get_url(self, model_key: str) -> str: def get_url(self, model_key: str) -> str | None:
path = self.get_path(model_key)
if not self.validate_path(path):
return
return self.__invoker.services.urls.get_model_image_url(model_key) return self.__invoker.services.urls.get_model_image_url(model_key)
def delete(self, model_key: str) -> None: def delete(self, model_key: str) -> None:
try: try:
image_path = self.get_path(model_key + '.png') path = self.get_path(model_key)
if image_path.exists(): if not self.validate_path(path):
send2trash(image_path) raise ModelImageFileNotFoundException
send2trash(path)
except Exception as e: except Exception as e:
raise ModelImageFileDeleteException from e raise ModelImageFileDeleteException from e
def validate_path(self, path: Union[str, Path]) -> bool:
"""Validates the path given for an image."""
path = path if isinstance(path, Path) else Path(path)
return path.exists()
def __validate_storage_folders(self) -> None: def __validate_storage_folders(self) -> None:
"""Checks if the required folders exist and create them if they don't""" """Checks if the required folders exist and create them if they don't"""

View File

@ -4,8 +4,9 @@ from .urls_base import UrlServiceBase
class LocalUrlService(UrlServiceBase): class LocalUrlService(UrlServiceBase):
def __init__(self, base_url: str = "api/v1"): def __init__(self, base_url: str = "api/v1", base_url_v2: str = "api/v2"):
self._base_url = base_url self._base_url = base_url
self._base_url_v2 = base_url_v2
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str: def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
image_basename = os.path.basename(image_name) image_basename = os.path.basename(image_name)
@ -17,4 +18,4 @@ class LocalUrlService(UrlServiceBase):
return f"{self._base_url}/images/i/{image_basename}/full" return f"{self._base_url}/images/i/{image_basename}/full"
def get_model_image_url(self, model_key: str) -> str: def get_model_image_url(self, model_key: str) -> str:
return f"{self._base_url}/model_images/{model_key}.png" return f"{self._base_url_v2}/models/i/{model_key}/image"

View File

@ -20,6 +20,7 @@ Validation errors will raise an InvalidModelConfigException error.
""" """
from pathlib import Path
import time import time
from enum import Enum from enum import Enum
from typing import Literal, Optional, Type, Union from typing import Literal, Optional, Type, Union
@ -161,7 +162,7 @@ class ModelConfigBase(BaseModel):
default_settings: Optional[ModelDefaultSettings] = Field( default_settings: Optional[ModelDefaultSettings] = Field(
description="Default settings for this model", default=None description="Default settings for this model", default=None
) )
image: Optional[str] = Field(description="Image to preview model", default=None) cover_image: Optional[str] = Field(description="Url for image to preview model", default=None)
@staticmethod @staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None: def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:

View File

@ -746,6 +746,7 @@
"delete": "Delete", "delete": "Delete",
"deleteConfig": "Delete Config", "deleteConfig": "Delete Config",
"deleteModel": "Delete Model", "deleteModel": "Delete Model",
"deleteModelImage": "Delete Model Image",
"deleteMsg1": "Are you sure you want to delete this model from InvokeAI?", "deleteMsg1": "Are you sure you want to delete this model from InvokeAI?",
"deleteMsg2": "This WILL delete the model from disk if it is in the InvokeAI root folder. If you are using a custom location, then the model WILL NOT be deleted from disk.", "deleteMsg2": "This WILL delete the model from disk if it is in the InvokeAI root folder. If you are using a custom location, then the model WILL NOT be deleted from disk.",
"description": "Description", "description": "Description",
@ -786,6 +787,10 @@
"modelDeleteFailed": "Failed to delete model", "modelDeleteFailed": "Failed to delete model",
"modelEntryDeleted": "Model Entry Deleted", "modelEntryDeleted": "Model Entry Deleted",
"modelExists": "Model Exists", "modelExists": "Model Exists",
"modelImageDeleted": "Model Image Deleted",
"modelImageDeleteFailed": "Model Image Delete Failed",
"modelImageUpdated": "Model Image Updated",
"modelImageUpdateFailed": "Model Image Update Failed",
"modelLocation": "Model Location", "modelLocation": "Model Location",
"modelLocationValidationMsg": "Provide the path to a local folder where your Diffusers Model is stored", "modelLocationValidationMsg": "Provide the path to a local folder where your Diffusers Model is stored",
"modelManager": "Model Manager", "modelManager": "Model Manager",
@ -829,7 +834,6 @@
"repo_id": "Repo ID", "repo_id": "Repo ID",
"repoIDValidationMsg": "Online repository of your model", "repoIDValidationMsg": "Online repository of your model",
"repoVariant": "Repo Variant", "repoVariant": "Repo Variant",
"resetImage": "Reset This Image",
"safetensorModels": "SafeTensors", "safetensorModels": "SafeTensors",
"sameFolder": "Same folder", "sameFolder": "Same folder",
"scan": "Scan", "scan": "Scan",

View File

@ -1,22 +1,19 @@
import { Box, Image } from '@invoke-ai/ui-library'; import { Box, Image } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo'; import { typedMemo } from 'common/util/typedMemo';
import { useState } from 'react';
import { buildModelsUrl } from 'services/api/endpoints/models'; import { useGetModelConfigQuery } from 'services/api/endpoints/models';
type Props = { type Props = {
model_key: string; image_url?: string;
}; };
const ModelImage = ({ model_key }: Props) => { const ModelImage = ({ image_url }: Props) => {
const [image, setImage] = useState<string | undefined>(buildModelsUrl(`i/${model_key}/image`));
if (!image) return <Box height="50px" width="50px"></Box>; if (!image_url) return <Box height="50px" width="50px" />;
return ( return (
<Image <Image
onError={() => setImage(undefined)} src={image_url}
src={image}
objectFit="cover" objectFit="cover"
objectPosition="50% 50%" objectPosition="50% 50%"
height="50px" height="50px"

View File

@ -74,7 +74,7 @@ const ModelListItem = (props: ModelListItemProps) => {
return ( return (
<Flex gap={2} alignItems="center" w="full"> <Flex gap={2} alignItems="center" w="full">
<ModelImage model_key={model.key} /> <ModelImage image_url={model.cover_image} />
<Flex <Flex
as={Button} as={Button}
isChecked={isSelected} isChecked={isSelected}

View File

@ -6,18 +6,19 @@ import { useAppDispatch } from 'app/store/storeHooks';
import { Button } from '@invoke-ai/ui-library'; import { Button } from '@invoke-ai/ui-library';
import { useDropzone } from 'react-dropzone'; import { useDropzone } from 'react-dropzone';
import { PiArrowCounterClockwiseBold, PiUploadSimpleBold } from 'react-icons/pi'; import { PiArrowCounterClockwiseBold, PiUploadSimpleBold } from 'react-icons/pi';
import { buildModelsUrl, useUpdateModelImageMutation, useDeleteModelImageMutation } from 'services/api/endpoints/models'; import { useUpdateModelImageMutation, useDeleteModelImageMutation } from 'services/api/endpoints/models';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast'; import { makeToast } from 'features/system/util/makeToast';
type Props = { type Props = {
model_key: string; model_key: string | null;
model_image: string | null;
}; };
const ModelImageUpload = ({ model_key }: Props) => { const ModelImageUpload = ({ model_key, model_image }: Props) => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const [image, setImage] = useState<string | undefined>(buildModelsUrl(`i/${model_key}/image`)); const [image, setImage] = useState<string | null>(model_image);
const { t } = useTranslation(); const { t } = useTranslation();
const [updateModelImage] = useUpdateModelImageMutation(); const [updateModelImage] = useUpdateModelImageMutation();
@ -33,7 +34,7 @@ type Props = {
setImage(URL.createObjectURL(file)); setImage(URL.createObjectURL(file));
updateModelImage({ key: model_key, image: image }) updateModelImage({ key: model_key, image: file })
.unwrap() .unwrap()
.then(() => { .then(() => {
dispatch( dispatch(
@ -60,6 +61,9 @@ type Props = {
); );
const handleResetImage = useCallback(() => { const handleResetImage = useCallback(() => {
if (!model_key) {
return;
}
setImage(undefined); setImage(undefined);
deleteModelImage(model_key) deleteModelImage(model_key)
.unwrap() .unwrap()

View File

@ -41,7 +41,7 @@ export const Model = () => {
<ModelAttrView label="Description" value={data.description} /> <ModelAttrView label="Description" value={data.description} />
</Box> </Box>
</Flex> </Flex>
<ModelImageUpload model_key={selectedModelKey || ''} /> <ModelImageUpload model_key={selectedModelKey} model_image={data.cover_image} />
</Flex> </Flex>
<Tabs mt="4" h="100%"> <Tabs mt="4" h="100%">

View File

@ -138,7 +138,7 @@ const buildTransformResponse =
* buildModelsUrl('some-path') * buildModelsUrl('some-path')
* // '/api/v1/models/some-path' * // '/api/v1/models/some-path'
*/ */
export const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`); const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`);
export const modelsApi = api.injectEndpoints({ export const modelsApi = api.injectEndpoints({
endpoints: (build) => ({ endpoints: (build) => ({
@ -162,6 +162,7 @@ export const modelsApi = api.injectEndpoints({
body: formData, body: formData,
}; };
}, },
invalidatesTags: ['Model'],
}), }),
installModel: build.mutation<InstallModelResponse, InstallModelArg>({ installModel: build.mutation<InstallModelResponse, InstallModelArg>({
query: ({ source }) => { query: ({ source }) => {
@ -189,6 +190,7 @@ export const modelsApi = api.injectEndpoints({
method: 'DELETE', method: 'DELETE',
}; };
}, },
invalidatesTags: ['Model'],
}), }),
getModelImage: build.query<string, string>({ getModelImage: build.query<string, string>({
query: (key) => buildModelsUrl(`i/${key}/image`) query: (key) => buildModelsUrl(`i/${key}/image`)