mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
get model image url from model config, added thumbnail formatting for images
This commit is contained in:
parent
239b1e8cc7
commit
8411029d93
@ -113,6 +113,9 @@ async def list_model_records(
|
|||||||
found_models.extend(
|
found_models.extend(
|
||||||
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
|
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
|
||||||
)
|
)
|
||||||
|
for model in found_models:
|
||||||
|
cover_image = ApiDependencies.invoker.services.model_images.get_url(model.key)
|
||||||
|
model.cover_image = cover_image
|
||||||
return ModelsList(models=found_models)
|
return ModelsList(models=found_models)
|
||||||
|
|
||||||
|
|
||||||
@ -156,6 +159,8 @@ async def get_model_record(
|
|||||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||||
try:
|
try:
|
||||||
config: AnyModelConfig = record_store.get_model(key)
|
config: AnyModelConfig = record_store.get_model(key)
|
||||||
|
cover_image = ApiDependencies.invoker.services.model_images.get_url(key)
|
||||||
|
config.cover_image = cover_image
|
||||||
return config
|
return config
|
||||||
except UnknownModelException as e:
|
except UnknownModelException as e:
|
||||||
raise HTTPException(status_code=404, detail=str(e))
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
@ -292,8 +297,7 @@ async def get_model_image(
|
|||||||
"""Gets a full-resolution image file"""
|
"""Gets a full-resolution image file"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# still need to handle this gracefully when path doesnt exist instead of throwing error
|
path = ApiDependencies.invoker.services.model_images.get_path(key)
|
||||||
path = ApiDependencies.invoker.services.model_images.get_path(key + ".png")
|
|
||||||
|
|
||||||
if not path:
|
if not path:
|
||||||
raise HTTPException(status_code=404)
|
raise HTTPException(status_code=404)
|
||||||
|
@ -12,10 +12,15 @@ class ModelImagesBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_path(self, model_key: str) -> Path:
|
def get_path(self, model_key: str) -> Path | None:
|
||||||
"""Gets the internal path to a model image."""
|
"""Gets the internal path to a model image."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_url(self, model_key: str) -> str | None:
|
||||||
|
"""Gets the URL to a model image."""
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
|
@ -6,6 +6,7 @@ from PIL.Image import Image as PILImageType
|
|||||||
from send2trash import send2trash
|
from send2trash import send2trash
|
||||||
|
|
||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
|
from invokeai.app.util.thumbnails import make_thumbnail
|
||||||
|
|
||||||
from .model_images_base import ModelImagesBase
|
from .model_images_base import ModelImagesBase
|
||||||
from .model_images_common import ModelImageFileDeleteException, ModelImageFileNotFoundException, ModelImageFileSaveException
|
from .model_images_common import ModelImageFileDeleteException, ModelImageFileNotFoundException, ModelImageFileSaveException
|
||||||
@ -27,9 +28,12 @@ class ModelImagesService(ModelImagesBase):
|
|||||||
|
|
||||||
def get(self, model_key: str) -> PILImageType:
|
def get(self, model_key: str) -> PILImageType:
|
||||||
try:
|
try:
|
||||||
image_path = self.get_path(model_key + '.png')
|
path = self.get_path(model_key)
|
||||||
|
|
||||||
|
if not self.validate_path(path):
|
||||||
|
raise ModelImageFileNotFoundException
|
||||||
|
|
||||||
image = Image.open(image_path)
|
image = Image.open(path)
|
||||||
return image
|
return image
|
||||||
except FileNotFoundError as e:
|
except FileNotFoundError as e:
|
||||||
raise ModelImageFileNotFoundException from e
|
raise ModelImageFileNotFoundException from e
|
||||||
@ -41,8 +45,12 @@ class ModelImagesService(ModelImagesBase):
|
|||||||
) -> None:
|
) -> None:
|
||||||
try:
|
try:
|
||||||
self.__validate_storage_folders()
|
self.__validate_storage_folders()
|
||||||
image_path = self.get_path(model_key + '.png')
|
logger = self.__invoker.services.logger
|
||||||
pnginfo = PngImagePlugin.PngInfo()
|
image_path = self.__model_images_folder / (model_key + '.png')
|
||||||
|
logger.debug(f"Saving image for model {model_key} to image_path {image_path}")
|
||||||
|
|
||||||
|
pnginfo = PngImagePlugin.PngInfo()
|
||||||
|
image = make_thumbnail(image, 256)
|
||||||
|
|
||||||
image.save(
|
image.save(
|
||||||
image_path,
|
image_path,
|
||||||
@ -55,22 +63,33 @@ class ModelImagesService(ModelImagesBase):
|
|||||||
raise ModelImageFileSaveException from e
|
raise ModelImageFileSaveException from e
|
||||||
|
|
||||||
def get_path(self, model_key: str) -> Path:
|
def get_path(self, model_key: str) -> Path:
|
||||||
path = self.__model_images_folder / model_key
|
path = self.__model_images_folder / (model_key + '.png')
|
||||||
|
|
||||||
return path
|
return path
|
||||||
|
|
||||||
def get_url(self, model_key: str) -> str:
|
def get_url(self, model_key: str) -> str | None:
|
||||||
|
path = self.get_path(model_key)
|
||||||
|
if not self.validate_path(path):
|
||||||
|
return
|
||||||
|
|
||||||
return self.__invoker.services.urls.get_model_image_url(model_key)
|
return self.__invoker.services.urls.get_model_image_url(model_key)
|
||||||
|
|
||||||
def delete(self, model_key: str) -> None:
|
def delete(self, model_key: str) -> None:
|
||||||
try:
|
try:
|
||||||
image_path = self.get_path(model_key + '.png')
|
path = self.get_path(model_key)
|
||||||
|
|
||||||
if image_path.exists():
|
if not self.validate_path(path):
|
||||||
send2trash(image_path)
|
raise ModelImageFileNotFoundException
|
||||||
|
|
||||||
|
send2trash(path)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ModelImageFileDeleteException from e
|
raise ModelImageFileDeleteException from e
|
||||||
|
|
||||||
|
def validate_path(self, path: Union[str, Path]) -> bool:
|
||||||
|
"""Validates the path given for an image."""
|
||||||
|
path = path if isinstance(path, Path) else Path(path)
|
||||||
|
return path.exists()
|
||||||
|
|
||||||
def __validate_storage_folders(self) -> None:
|
def __validate_storage_folders(self) -> None:
|
||||||
"""Checks if the required folders exist and create them if they don't"""
|
"""Checks if the required folders exist and create them if they don't"""
|
||||||
|
@ -4,8 +4,9 @@ from .urls_base import UrlServiceBase
|
|||||||
|
|
||||||
|
|
||||||
class LocalUrlService(UrlServiceBase):
|
class LocalUrlService(UrlServiceBase):
|
||||||
def __init__(self, base_url: str = "api/v1"):
|
def __init__(self, base_url: str = "api/v1", base_url_v2: str = "api/v2"):
|
||||||
self._base_url = base_url
|
self._base_url = base_url
|
||||||
|
self._base_url_v2 = base_url_v2
|
||||||
|
|
||||||
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
|
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||||
image_basename = os.path.basename(image_name)
|
image_basename = os.path.basename(image_name)
|
||||||
@ -17,4 +18,4 @@ class LocalUrlService(UrlServiceBase):
|
|||||||
return f"{self._base_url}/images/i/{image_basename}/full"
|
return f"{self._base_url}/images/i/{image_basename}/full"
|
||||||
|
|
||||||
def get_model_image_url(self, model_key: str) -> str:
|
def get_model_image_url(self, model_key: str) -> str:
|
||||||
return f"{self._base_url}/model_images/{model_key}.png"
|
return f"{self._base_url_v2}/models/i/{model_key}/image"
|
@ -20,6 +20,7 @@ Validation errors will raise an InvalidModelConfigException error.
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
import time
|
import time
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Literal, Optional, Type, Union
|
from typing import Literal, Optional, Type, Union
|
||||||
@ -161,7 +162,7 @@ class ModelConfigBase(BaseModel):
|
|||||||
default_settings: Optional[ModelDefaultSettings] = Field(
|
default_settings: Optional[ModelDefaultSettings] = Field(
|
||||||
description="Default settings for this model", default=None
|
description="Default settings for this model", default=None
|
||||||
)
|
)
|
||||||
image: Optional[str] = Field(description="Image to preview model", default=None)
|
cover_image: Optional[str] = Field(description="Url for image to preview model", default=None)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||||
|
@ -746,6 +746,7 @@
|
|||||||
"delete": "Delete",
|
"delete": "Delete",
|
||||||
"deleteConfig": "Delete Config",
|
"deleteConfig": "Delete Config",
|
||||||
"deleteModel": "Delete Model",
|
"deleteModel": "Delete Model",
|
||||||
|
"deleteModelImage": "Delete Model Image",
|
||||||
"deleteMsg1": "Are you sure you want to delete this model from InvokeAI?",
|
"deleteMsg1": "Are you sure you want to delete this model from InvokeAI?",
|
||||||
"deleteMsg2": "This WILL delete the model from disk if it is in the InvokeAI root folder. If you are using a custom location, then the model WILL NOT be deleted from disk.",
|
"deleteMsg2": "This WILL delete the model from disk if it is in the InvokeAI root folder. If you are using a custom location, then the model WILL NOT be deleted from disk.",
|
||||||
"description": "Description",
|
"description": "Description",
|
||||||
@ -786,6 +787,10 @@
|
|||||||
"modelDeleteFailed": "Failed to delete model",
|
"modelDeleteFailed": "Failed to delete model",
|
||||||
"modelEntryDeleted": "Model Entry Deleted",
|
"modelEntryDeleted": "Model Entry Deleted",
|
||||||
"modelExists": "Model Exists",
|
"modelExists": "Model Exists",
|
||||||
|
"modelImageDeleted": "Model Image Deleted",
|
||||||
|
"modelImageDeleteFailed": "Model Image Delete Failed",
|
||||||
|
"modelImageUpdated": "Model Image Updated",
|
||||||
|
"modelImageUpdateFailed": "Model Image Update Failed",
|
||||||
"modelLocation": "Model Location",
|
"modelLocation": "Model Location",
|
||||||
"modelLocationValidationMsg": "Provide the path to a local folder where your Diffusers Model is stored",
|
"modelLocationValidationMsg": "Provide the path to a local folder where your Diffusers Model is stored",
|
||||||
"modelManager": "Model Manager",
|
"modelManager": "Model Manager",
|
||||||
@ -829,7 +834,6 @@
|
|||||||
"repo_id": "Repo ID",
|
"repo_id": "Repo ID",
|
||||||
"repoIDValidationMsg": "Online repository of your model",
|
"repoIDValidationMsg": "Online repository of your model",
|
||||||
"repoVariant": "Repo Variant",
|
"repoVariant": "Repo Variant",
|
||||||
"resetImage": "Reset This Image",
|
|
||||||
"safetensorModels": "SafeTensors",
|
"safetensorModels": "SafeTensors",
|
||||||
"sameFolder": "Same folder",
|
"sameFolder": "Same folder",
|
||||||
"scan": "Scan",
|
"scan": "Scan",
|
||||||
|
@ -1,22 +1,19 @@
|
|||||||
import { Box, Image } from '@invoke-ai/ui-library';
|
import { Box, Image } from '@invoke-ai/ui-library';
|
||||||
import { typedMemo } from 'common/util/typedMemo';
|
import { typedMemo } from 'common/util/typedMemo';
|
||||||
import { useState } from 'react';
|
|
||||||
|
|
||||||
import { buildModelsUrl } from 'services/api/endpoints/models';
|
import { useGetModelConfigQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
model_key: string;
|
image_url?: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
const ModelImage = ({ model_key }: Props) => {
|
const ModelImage = ({ image_url }: Props) => {
|
||||||
const [image, setImage] = useState<string | undefined>(buildModelsUrl(`i/${model_key}/image`));
|
|
||||||
|
|
||||||
if (!image) return <Box height="50px" width="50px"></Box>;
|
if (!image_url) return <Box height="50px" width="50px" />;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Image
|
<Image
|
||||||
onError={() => setImage(undefined)}
|
src={image_url}
|
||||||
src={image}
|
|
||||||
objectFit="cover"
|
objectFit="cover"
|
||||||
objectPosition="50% 50%"
|
objectPosition="50% 50%"
|
||||||
height="50px"
|
height="50px"
|
||||||
|
@ -74,7 +74,7 @@ const ModelListItem = (props: ModelListItemProps) => {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex gap={2} alignItems="center" w="full">
|
<Flex gap={2} alignItems="center" w="full">
|
||||||
<ModelImage model_key={model.key} />
|
<ModelImage image_url={model.cover_image} />
|
||||||
<Flex
|
<Flex
|
||||||
as={Button}
|
as={Button}
|
||||||
isChecked={isSelected}
|
isChecked={isSelected}
|
||||||
|
@ -6,18 +6,19 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
|||||||
import { Button } from '@invoke-ai/ui-library';
|
import { Button } from '@invoke-ai/ui-library';
|
||||||
import { useDropzone } from 'react-dropzone';
|
import { useDropzone } from 'react-dropzone';
|
||||||
import { PiArrowCounterClockwiseBold, PiUploadSimpleBold } from 'react-icons/pi';
|
import { PiArrowCounterClockwiseBold, PiUploadSimpleBold } from 'react-icons/pi';
|
||||||
import { buildModelsUrl, useUpdateModelImageMutation, useDeleteModelImageMutation } from 'services/api/endpoints/models';
|
import { useUpdateModelImageMutation, useDeleteModelImageMutation } from 'services/api/endpoints/models';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
import { makeToast } from 'features/system/util/makeToast';
|
import { makeToast } from 'features/system/util/makeToast';
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
model_key: string;
|
model_key: string | null;
|
||||||
|
model_image: string | null;
|
||||||
};
|
};
|
||||||
|
|
||||||
const ModelImageUpload = ({ model_key }: Props) => {
|
const ModelImageUpload = ({ model_key, model_image }: Props) => {
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const [image, setImage] = useState<string | undefined>(buildModelsUrl(`i/${model_key}/image`));
|
const [image, setImage] = useState<string | null>(model_image);
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const [updateModelImage] = useUpdateModelImageMutation();
|
const [updateModelImage] = useUpdateModelImageMutation();
|
||||||
@ -33,7 +34,7 @@ type Props = {
|
|||||||
|
|
||||||
setImage(URL.createObjectURL(file));
|
setImage(URL.createObjectURL(file));
|
||||||
|
|
||||||
updateModelImage({ key: model_key, image: image })
|
updateModelImage({ key: model_key, image: file })
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.then(() => {
|
.then(() => {
|
||||||
dispatch(
|
dispatch(
|
||||||
@ -60,6 +61,9 @@ type Props = {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const handleResetImage = useCallback(() => {
|
const handleResetImage = useCallback(() => {
|
||||||
|
if (!model_key) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
setImage(undefined);
|
setImage(undefined);
|
||||||
deleteModelImage(model_key)
|
deleteModelImage(model_key)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
@ -41,7 +41,7 @@ export const Model = () => {
|
|||||||
<ModelAttrView label="Description" value={data.description} />
|
<ModelAttrView label="Description" value={data.description} />
|
||||||
</Box>
|
</Box>
|
||||||
</Flex>
|
</Flex>
|
||||||
<ModelImageUpload model_key={selectedModelKey || ''} />
|
<ModelImageUpload model_key={selectedModelKey} model_image={data.cover_image} />
|
||||||
</Flex>
|
</Flex>
|
||||||
|
|
||||||
<Tabs mt="4" h="100%">
|
<Tabs mt="4" h="100%">
|
||||||
|
@ -138,7 +138,7 @@ const buildTransformResponse =
|
|||||||
* buildModelsUrl('some-path')
|
* buildModelsUrl('some-path')
|
||||||
* // '/api/v1/models/some-path'
|
* // '/api/v1/models/some-path'
|
||||||
*/
|
*/
|
||||||
export const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`);
|
const buildModelsUrl = (path: string = '') => buildV2Url(`models/${path}`);
|
||||||
|
|
||||||
export const modelsApi = api.injectEndpoints({
|
export const modelsApi = api.injectEndpoints({
|
||||||
endpoints: (build) => ({
|
endpoints: (build) => ({
|
||||||
@ -162,6 +162,7 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
body: formData,
|
body: formData,
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
|
invalidatesTags: ['Model'],
|
||||||
}),
|
}),
|
||||||
installModel: build.mutation<InstallModelResponse, InstallModelArg>({
|
installModel: build.mutation<InstallModelResponse, InstallModelArg>({
|
||||||
query: ({ source }) => {
|
query: ({ source }) => {
|
||||||
@ -189,6 +190,7 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
method: 'DELETE',
|
method: 'DELETE',
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
|
invalidatesTags: ['Model'],
|
||||||
}),
|
}),
|
||||||
getModelImage: build.query<string, string>({
|
getModelImage: build.query<string, string>({
|
||||||
query: (key) => buildModelsUrl(`i/${key}/image`)
|
query: (key) => buildModelsUrl(`i/${key}/image`)
|
||||||
|
Loading…
Reference in New Issue
Block a user