get model image url from model config, added thumbnail formatting for images

This commit is contained in:
Jennifer Player 2024-03-06 13:15:33 -05:00 committed by Kent Keirsey
parent 239b1e8cc7
commit 8411029d93
11 changed files with 69 additions and 32 deletions

View File

@ -113,6 +113,9 @@ async def list_model_records(
found_models.extend(
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
)
for model in found_models:
cover_image = ApiDependencies.invoker.services.model_images.get_url(model.key)
model.cover_image = cover_image
return ModelsList(models=found_models)
@ -156,6 +159,8 @@ async def get_model_record(
record_store = ApiDependencies.invoker.services.model_manager.store
try:
config: AnyModelConfig = record_store.get_model(key)
cover_image = ApiDependencies.invoker.services.model_images.get_url(key)
config.cover_image = cover_image
return config
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))
@ -292,8 +297,7 @@ async def get_model_image(
"""Gets a full-resolution image file"""
try:
# still need to handle this gracefully when path doesnt exist instead of throwing error
path = ApiDependencies.invoker.services.model_images.get_path(key + ".png")
path = ApiDependencies.invoker.services.model_images.get_path(key)
if not path:
raise HTTPException(status_code=404)

View File

@ -12,10 +12,15 @@ class ModelImagesBase(ABC):
pass
@abstractmethod
def get_path(self, model_key: str) -> Path:
def get_path(self, model_key: str) -> Path | None:
"""Gets the internal path to a model image."""
pass
@abstractmethod
def get_url(self, model_key: str) -> str | None:
"""Gets the URL to a model image."""
pass
@abstractmethod
def save(
self,

View File

@ -6,6 +6,7 @@ from PIL.Image import Image as PILImageType
from send2trash import send2trash
from invokeai.app.services.invoker import Invoker
from invokeai.app.util.thumbnails import make_thumbnail
from .model_images_base import ModelImagesBase
from .model_images_common import ModelImageFileDeleteException, ModelImageFileNotFoundException, ModelImageFileSaveException
@ -27,9 +28,12 @@ class ModelImagesService(ModelImagesBase):
def get(self, model_key: str) -> PILImageType:
try:
image_path = self.get_path(model_key + '.png')
path = self.get_path(model_key)
if not self.validate_path(path):
raise ModelImageFileNotFoundException
image = Image.open(image_path)
image = Image.open(path)
return image
except FileNotFoundError as e:
raise ModelImageFileNotFoundException from e
@ -41,8 +45,12 @@ class ModelImagesService(ModelImagesBase):
) -> None:
try:
self.__validate_storage_folders()
image_path = self.get_path(model_key + '.png')
pnginfo = PngImagePlugin.PngInfo()
logger = self.__invoker.services.logger
image_path = self.__model_images_folder / (model_key + '.png')
logger.debug(f"Saving image for model {model_key} to image_path {image_path}")
pnginfo = PngImagePlugin.PngInfo()
image = make_thumbnail(image, 256)
image.save(
image_path,
@ -55,22 +63,33 @@ class ModelImagesService(ModelImagesBase):
raise ModelImageFileSaveException from e
def get_path(self, model_key: str) -> Path:
path = self.__model_images_folder / model_key
path = self.__model_images_folder / (model_key + '.png')
return path
def get_url(self, model_key: str) -> str:
def get_url(self, model_key: str) -> str | None:
path = self.get_path(model_key)
if not self.validate_path(path):
return
return self.__invoker.services.urls.get_model_image_url(model_key)
def delete(self, model_key: str) -> None:
try:
image_path = self.get_path(model_key + '.png')
path = self.get_path(model_key)
if image_path.exists():
send2trash(image_path)
if not self.validate_path(path):
raise ModelImageFileNotFoundException
send2trash(path)
except Exception as e:
raise ModelImageFileDeleteException from e
def validate_path(self, path: Union[str, Path]) -> bool:
"""Validates the path given for an image."""
path = path if isinstance(path, Path) else Path(path)
return path.exists()
def __validate_storage_folders(self) -> None:
"""Checks if the required folders exist and create them if they don't"""

View File

@ -4,8 +4,9 @@ from .urls_base import UrlServiceBase
class LocalUrlService(UrlServiceBase):
def __init__(self, base_url: str = "api/v1"):
def __init__(self, base_url: str = "api/v1", base_url_v2: str = "api/v2"):
self._base_url = base_url
self._base_url_v2 = base_url_v2
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
image_basename = os.path.basename(image_name)
@ -17,4 +18,4 @@ class LocalUrlService(UrlServiceBase):
return f"{self._base_url}/images/i/{image_basename}/full"
def get_model_image_url(self, model_key: str) -> str:
return f"{self._base_url}/model_images/{model_key}.png"
return f"{self._base_url_v2}/models/i/{model_key}/image"

View File

@ -20,6 +20,7 @@ Validation errors will raise an InvalidModelConfigException error.
"""
from pathlib import Path
import time
from enum import Enum
from typing import Literal, Optional, Type, Union
@ -161,7 +162,7 @@ class ModelConfigBase(BaseModel):
default_settings: Optional[ModelDefaultSettings] = Field(
description="Default settings for this model", default=None
)
image: Optional[str] = Field(description="Image to preview model", default=None)
cover_image: Optional[str] = Field(description="Url for image to preview model", default=None)
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:

View File

@ -746,6 +746,7 @@
"delete": "Delete",
"deleteConfig": "Delete Config",
"deleteModel": "Delete Model",
"deleteModelImage": "Delete Model Image",
"deleteMsg1": "Are you sure you want to delete this model from InvokeAI?",
"deleteMsg2": "This WILL delete the model from disk if it is in the InvokeAI root folder. If you are using a custom location, then the model WILL NOT be deleted from disk.",
"description": "Description",
@ -786,6 +787,10 @@
"modelDeleteFailed": "Failed to delete model",
"modelEntryDeleted": "Model Entry Deleted",
"modelExists": "Model Exists",
"modelImageDeleted": "Model Image Deleted",
"modelImageDeleteFailed": "Model Image Delete Failed",
"modelImageUpdated": "Model Image Updated",
"modelImageUpdateFailed": "Model Image Update Failed",
"modelLocation": "Model Location",
"modelLocationValidationMsg": "Provide the path to a local folder where your Diffusers Model is stored",
"modelManager": "Model Manager",
@ -829,7 +834,6 @@
"repo_id": "Repo ID",
"repoIDValidationMsg": "Online repository of your model",
"repoVariant": "Repo Variant",
"resetImage": "Reset This Image",
"safetensorModels": "SafeTensors",
"sameFolder": "Same folder",
"scan": "Scan",

View File

@ -1,22 +1,19 @@
import { Box, Image } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { useState } from 'react';
import { buildModelsUrl } from 'services/api/endpoints/models';
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
type Props = {
model_key: string;
image_url?: string;
};
const ModelImage = ({ model_key }: Props) => {
const [image, setImage] = useState<string | undefined>(buildModelsUrl(`i/${model_key}/image`));
const ModelImage = ({ image_url }: Props) => {
if (!image) return <Box height="50px" width="50px"></Box>;
if (!image_url) return <Box height="50px" width="50px" />;
return (
<Image
onError={() => setImage(undefined)}
src={image}
src={image_url}
objectFit="cover"
objectPosition="50% 50%"
height="50px"

View File

@ -74,7 +74,7 @@ const ModelListItem = (props: ModelListItemProps) => {
return (
<Flex gap={2} alignItems="center" w="full">
<ModelImage model_key={model.key} />
<ModelImage image_url={model.cover_image} />
<Flex
as={Button}
isChecked={isSelected}

View File

@ -6,18 +6,19 @@ import { useAppDispatch } from 'app/store/storeHooks';
import { Button } from '@invoke-ai/ui-library';
import { useDropzone } from 'react-dropzone';
import { PiArrowCounterClockwiseBold, PiUploadSimpleBold } from 'react-icons/pi';
import { buildModelsUrl, useUpdateModelImageMutation, useDeleteModelImageMutation } from 'services/api/endpoints/models';
import { useUpdateModelImageMutation, useDeleteModelImageMutation } from 'services/api/endpoints/models';
import { useTranslation } from 'react-i18next';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
type Props = {
model_key: string;
model_key: string | null;
model_image: string | null;
};
const ModelImageUpload = ({ model_key }: Props) => {
const ModelImageUpload = ({ model_key, model_image }: Props) => {
const dispatch = useAppDispatch();
const [image, setImage] = useState<string | undefined>(buildModelsUrl(`i/${model_key}/image`));
const [image, setImage] = useState<string | null>(model_image);
const { t } = useTranslation();
const [updateModelImage] = useUpdateModelImageMutation();
@ -33,7 +34,7 @@ type Props = {
setImage(URL.createObjectURL(file));
updateModelImage({ key: model_key, image: image })
updateModelImage({ key: model_key, image: file })
.unwrap()
.then(() => {
dispatch(
@ -60,6 +61,9 @@ type Props = {
);
const handleResetImage = useCallback(() => {
if (!model_key) {
return;
}
setImage(undefined);
deleteModelImage(model_key)
.unwrap()

View File

@ -41,7 +41,7 @@ export const Model = () => {
<ModelAttrView label="Description" value={data.description} />
</Box>
</Flex>
<ModelImageUpload model_key={selectedModelKey || ''} />
<ModelImageUpload model_key={selectedModelKey} model_image={data.cover_image} />
</Flex>
<Tabs mt="4" h="100%">

View File

@ -138,7 +138,7 @@ const buildTransformResponse =
* buildModelsUrl('some-path')
* // '/api/v1/models/some-path'
*/
export const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`);
const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`);
export const modelsApi = api.injectEndpoints({
endpoints: (build) => ({
@ -162,6 +162,7 @@ export const modelsApi = api.injectEndpoints({
body: formData,
};
},
invalidatesTags: ['Model'],
}),
installModel: build.mutation<InstallModelResponse, InstallModelArg>({
query: ({ source }) => {
@ -189,6 +190,7 @@ export const modelsApi = api.injectEndpoints({
method: 'DELETE',
};
},
invalidatesTags: ['Model'],
}),
getModelImage: build.query<string, string>({
query: (key) => buildModelsUrl(`i/${key}/image`)