fetching model image, still not working

This commit is contained in:
Jennifer Player 2024-03-05 22:57:05 -05:00 committed by Kent Keirsey
parent c1cdfd132b
commit 2f6964bfa5
14 changed files with 461 additions and 53 deletions

View File

@ -25,6 +25,7 @@ from ..services.invocation_cache.invocation_cache_memory import MemoryInvocation
from ..services.invocation_services import InvocationServices
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
from ..services.invoker import Invoker
from ..services.model_images.model_images_default import ModelImagesService
from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService
@ -71,6 +72,8 @@ class ApiDependencies:
image_files = DiskImageFileStorage(f"{output_folder}/images")
model_images_folder = config.models_path
db = init_db(config=config, logger=logger, image_files=image_files)
configuration = config
@ -92,6 +95,7 @@ class ApiDependencies:
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
)
download_queue_service = DownloadQueueService(event_bus=events)
model_images_service = ModelImagesService(model_images_folder / "model_images")
model_manager = ModelManagerService.build_model_manager(
app_config=configuration,
model_record_service=ModelRecordServiceSQL(db=db),
@ -118,6 +122,7 @@ class ApiDependencies:
images=images,
invocation_cache=invocation_cache,
logger=logger,
model_images=model_images_service,
model_manager=model_manager,
download_queue=download_queue_service,
names=names,

View File

@ -1,12 +1,16 @@
# Copyright (c) 2023 Lincoln D. Stein
"""FastAPI route for model configuration records."""
import io
import pathlib
import shutil
import traceback
from typing import Any, Dict, List, Optional
from fastapi import Body, Path, Query, Response
from fastapi import Body, Path, Query, Response, UploadFile
from fastapi.responses import FileResponse
from fastapi.routing import APIRouter
from PIL import Image
from pydantic import BaseModel, ConfigDict, Field
from starlette.exceptions import HTTPException
from typing_extensions import Annotated
@ -31,6 +35,9 @@ from ..dependencies import ApiDependencies
model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"])
# images are immutable; set a high max-age
IMAGE_MAX_AGE = 31536000
class ModelsList(BaseModel):
"""Return list of configs."""
@ -72,7 +79,7 @@ example_model_input = {
"description": "Model description",
"vae": None,
"variant": "normal",
"image": "blob"
"image": "blob",
}
##############################################################################
@ -267,6 +274,93 @@ async def update_model_record(
return model_response
@model_manager_router.get(
"/i/{key}/image",
operation_id="get_model_image",
responses={
200: {
"description": "The model image was fetched successfully",
},
400: {"description": "Bad request"},
404: {"description": "The model could not be found"},
},
status_code=200,
)
async def get_model_image(
key: str = Path(description="The name of model image file to get"),
) -> FileResponse:
"""Gets a full-resolution image file"""
try:
path = ApiDependencies.invoker.services.model_images.get_path(key + ".png")
if not path:
raise HTTPException(status_code=404)
response = FileResponse(
path,
media_type="image/png",
filename=key + ".png",
content_disposition_type="inline",
)
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
return response
except Exception:
raise HTTPException(status_code=404)
# async def get_model_image(
# key: Annotated[str, Path(description="Unique key of model")],
# ) -> Optional[str]:
# model_images = ApiDependencies.invoker.services.model_images
# try:
# url = model_images.get_url(key)
# if not url:
# return None
# return url
# except Exception:
# raise HTTPException(status_code=404)
@model_manager_router.patch(
"/i/{key}/image",
operation_id="update_model_image",
responses={
200: {
"description": "The model image was updated successfully",
},
400: {"description": "Bad request"},
},
status_code=200,
)
async def update_model_image(
key: Annotated[str, Path(description="Unique key of model")],
image: UploadFile,
) -> None:
if not image.content_type or not image.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
contents = await image.read()
try:
pil_image = Image.open(io.BytesIO(contents))
except Exception:
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=415, detail="Failed to read image")
logger = ApiDependencies.invoker.services.logger
model_images = ApiDependencies.invoker.services.model_images
try:
model_images.save(pil_image, key)
logger.info(f"Updated image for model: {key}")
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
return
@model_manager_router.delete(
"/i/{key}",
operation_id="delete_model",

View File

@ -25,6 +25,7 @@ if TYPE_CHECKING:
from .images.images_base import ImageServiceABC
from .invocation_cache.invocation_cache_base import InvocationCacheBase
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
from .model_images.model_images_base import ModelImagesBase
from .model_manager.model_manager_base import ModelManagerServiceBase
from .names.names_base import NameServiceBase
from .session_processor.session_processor_base import SessionProcessorBase
@ -49,6 +50,7 @@ class InvocationServices:
image_files: "ImageFileStorageBase",
image_records: "ImageRecordStorageBase",
logger: "Logger",
model_images: "ModelImagesBase",
model_manager: "ModelManagerServiceBase",
download_queue: "DownloadQueueServiceBase",
performance_statistics: "InvocationStatsServiceBase",
@ -72,6 +74,7 @@ class InvocationServices:
self.image_files = image_files
self.image_records = image_records
self.logger = logger
self.model_images = model_images
self.model_manager = model_manager
self.download_queue = download_queue
self.performance_statistics = performance_statistics

View File

@ -3,30 +3,34 @@ from pathlib import Path
from PIL.Image import Image as PILImageType
class ModelImagesBase(ABC):
"""Low-level service responsible for storing and retrieving image files."""
@abstractmethod
def get(self, image_name: str) -> PILImageType:
"""Retrieves an image as PIL Image."""
def get(self, model_key: str) -> PILImageType:
"""Retrieves a model image as PIL Image."""
pass
@abstractmethod
def get_path(self, image_name: str) -> Path:
"""Gets the internal path to an image."""
def get_path(self, model_key: str) -> Path:
"""Gets the internal path to a model image."""
pass
@abstractmethod
def get_url(self, model_key: str) -> str:
"""Gets the url for a model image."""
pass
@abstractmethod
def save(
self,
image: PILImageType,
image_name: str,
model_key: str,
) -> None:
"""Saves an image. Returns a tuple of the image name and created timestamp."""
"""Saves a model image. Returns a tuple of the image name and created timestamp."""
pass
@abstractmethod
def delete(self, image_name: str) -> None:
"""Deletes an image."""
def delete(self, model_key: str) -> None:
"""Deletes a model image."""
pass

View File

@ -2,19 +2,19 @@
class ModelImageFileNotFoundException(Exception):
"""Raised when an image file is not found in storage."""
def __init__(self, message="Image file not found"):
def __init__(self, message="Model image file not found"):
super().__init__(message)
class ModelImageFileSaveException(Exception):
"""Raised when an image cannot be saved."""
def __init__(self, message="Image file not saved"):
def __init__(self, message="Model image file not saved"):
super().__init__(message)
class ModelImageFileDeleteException(Exception):
"""Raised when an image cannot be deleted."""
def __init__(self, message="Image file not deleted"):
def __init__(self, message="Model image file not deleted"):
super().__init__(message)

View File

@ -10,25 +10,24 @@ from invokeai.app.services.invoker import Invoker
from .model_images_base import ModelImagesBase
from .model_images_common import ModelImageFileDeleteException, ModelImageFileNotFoundException, ModelImageFileSaveException
class DiskImageFileStorage(ModelImagesBase):
class ModelImagesService(ModelImagesBase):
"""Stores images on disk"""
__output_folder: Path
__model_images_folder: Path
__invoker: Invoker
def __init__(self, output_folder: Union[str, Path]):
def __init__(self, model_images_folder: Union[str, Path]):
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
# Validate required output folders at launch
self.__model_images_folder: Path = model_images_folder if isinstance(model_images_folder, Path) else Path(model_images_folder)
# Validate required folders at launch
self.__validate_storage_folders()
def start(self, invoker: Invoker) -> None:
self.__invoker = invoker
def get(self, image_name: str) -> PILImageType:
def get(self, model_key: str) -> PILImageType:
try:
image_path = self.get_path(image_name)
image_path = self.get_path(model_key + '.png')
image = Image.open(image_path)
return image
@ -38,17 +37,13 @@ class DiskImageFileStorage(ModelImagesBase):
def save(
self,
image: PILImageType,
image_name: str,
model_key: str,
) -> None:
try:
self.__validate_storage_folders()
image_path = self.get_path(image_name)
image_path = self.get_path(model_key + '.png')
pnginfo = PngImagePlugin.PngInfo()
info_dict = {}
# When saving the image, the image object's info field is not populated. We need to set it
image.info = info_dict
image.save(
image_path,
"PNG",
@ -59,9 +54,17 @@ class DiskImageFileStorage(ModelImagesBase):
except Exception as e:
raise ModelImageFileSaveException from e
def delete(self, image_name: str) -> None:
def get_path(self, model_key: str) -> Path:
path = self.__model_images_folder / model_key
return path
def get_url(self, model_key: str) -> str:
return self.__invoker.services.urls.get_model_image_url(model_key)
def delete(self, model_key: str) -> None:
try:
image_path = self.get_path(image_name)
image_path = self.get_path(model_key + '.png')
if image_path.exists():
send2trash(image_path)
@ -69,14 +72,8 @@ class DiskImageFileStorage(ModelImagesBase):
except Exception as e:
raise ModelImageFileDeleteException from e
# TODO: make this a bit more flexible for e.g. cloud storage
def get_path(self, image_name: str) -> Path:
path = self.__output_folder / image_name
return path
def __validate_storage_folders(self) -> None:
"""Checks if the required output folders exist and create them if they don't"""
folders: list[Path] = [self.__output_folder]
"""Checks if the required folders exist and create them if they don't"""
folders: list[Path] = [self.__model_images_folder]
for folder in folders:
folder.mkdir(parents=True, exist_ok=True)

View File

@ -8,3 +8,8 @@ class UrlServiceBase(ABC):
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
"""Gets the URL for an image or thumbnail."""
pass
@abstractmethod
def get_model_image_url(self, model_key: str) -> str:
"""Gets the URL for a model image"""
pass

View File

@ -15,3 +15,6 @@ class LocalUrlService(UrlServiceBase):
return f"{self._base_url}/images/i/{image_basename}/thumbnail"
return f"{self._base_url}/images/i/{image_basename}/full"
def get_model_image_url(self, model_key: str) -> str:
return f"{self._base_url}/model_images/{model_key}.png"

View File

@ -161,6 +161,7 @@ class ModelConfigBase(BaseModel):
default_settings: Optional[ModelDefaultSettings] = Field(
description="Default settings for this model", default=None
)
image: Optional[str] = Field(description="Image to preview model", default=None)
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
@ -374,6 +375,10 @@ AnyModelConfig = Annotated[
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
class ModelImage(str, Enum):
path: str
class ModelConfigFactory(object):
"""Class for parsing config dicts into StableDiffusion Config obects."""

View File

@ -0,0 +1,73 @@
import { Box, IconButton, Image } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { useCallback } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types';
import { Button } from '@invoke-ai/ui-library';
import { useDropzone } from 'react-dropzone';
import { PiArrowCounterClockwiseBold, PiUploadSimpleBold } from 'react-icons/pi';
const ModelImageUpload = (props: UseControllerProps<AnyModelConfig>) => {
const { field } = useController(props);
const onDropAccepted = useCallback(
(files: File[]) => {
const file = files[0];
if (!file) {
return;
}
field.onChange(file);
},
[field]
);
const handleResetControlImage = useCallback(() => {
field.onChange(undefined);
}, [field]);
console.log('field', field);
const { getInputProps, getRootProps } = useDropzone({
accept: { 'image/png': ['.png'], 'image/jpeg': ['.jpg', '.jpeg', '.png'] },
onDropAccepted,
noDrag: true,
multiple: false,
});
if (field.value) {
return (
<Box>
<Image
src={field.value ? URL.createObjectURL(field.value) : 'http://localhost:5173/api/v2/models/i/6b8a6c0d68127ad8db5550f16d9a304b/image'}
objectFit="contain"
maxW="full"
maxH="200px"
borderRadius="base"
/>
<IconButton
onClick={handleResetControlImage}
aria-label="reset this image"
tooltip="reset this image"
icon={<PiArrowCounterClockwiseBold size={16} />}
size="sm"
variant="link"
/>
</Box>
);
}
return (
<>
<Button leftIcon={<PiUploadSimpleBold />} {...getRootProps()} pointerEvents="auto">
Upload Image
</Button>
<input {...getInputProps()} />
</>
);
};
export default typedMemo(ModelImageUpload);

View File

@ -1,18 +1,22 @@
import { Box, IconButton, Image } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { useCallback } from 'react';
import { useCallback, useEffect, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import { useController, useWatch } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types';
import { Button } from '@invoke-ai/ui-library';
import { useDropzone } from 'react-dropzone';
import { PiArrowCounterClockwiseBold, PiUploadSimpleBold } from 'react-icons/pi';
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
import { useGetModelImageQuery } from 'services/api/endpoints/models';
const ModelImageUpload = (props: UseControllerProps<AnyModelConfig>) => {
const { field } = useController(props);
const key = useWatch({ control: props.control, name: 'key' });
const { data } = useGetModelImageQuery(key);
const onDropAccepted = useCallback(
(files: File[]) => {
const file = files[0];
@ -30,8 +34,6 @@ const ModelImageUpload = (props: UseControllerProps<AnyModelConfig>) => {
field.onChange(undefined);
}, [field]);
console.log('field', field);
const { getInputProps, getRootProps } = useDropzone({
accept: { 'image/png': ['.png'], 'image/jpeg': ['.jpg', '.jpeg', '.png'] },
onDropAccepted,
@ -39,11 +41,20 @@ const ModelImageUpload = (props: UseControllerProps<AnyModelConfig>) => {
multiple: false,
});
if (field.value) {
const image = useMemo(() => {
console.log(field.value, 'asdf' );
if (field.value) {
return URL.createObjectURL(field.value);
}
return data;
}, [field.value, data]);
if (image) {
return (
<Box>
<Image
src={URL.createObjectURL(field.value)}
src={image}
objectFit="contain"
maxW="full"
maxH="200px"
@ -56,7 +67,6 @@ const ModelImageUpload = (props: UseControllerProps<AnyModelConfig>) => {
icon={<PiArrowCounterClockwiseBold size={16} />}
size="sm"
variant="link"
// sx={sx}
/>
</Box>
);
@ -73,4 +83,3 @@ const ModelImageUpload = (props: UseControllerProps<AnyModelConfig>) => {
};
export default typedMemo(ModelImageUpload);

View File

@ -20,7 +20,11 @@ import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
import type { UpdateModelArg } from 'services/api/endpoints/models';
import { useGetModelConfigQuery, useUpdateModelMutation } from 'services/api/endpoints/models';
import {
useGetModelConfigQuery,
useUpdateModelImageMutation,
useUpdateModelsMutation,
} from 'services/api/endpoints/models';
import BaseModelSelect from './Fields/BaseModelSelect';
import ModelImageUpload from './Fields/ModelImageUpload';
@ -32,7 +36,8 @@ export const ModelEdit = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelMutation();
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelsMutation();
const [updateModelImage] = useUpdateModelImageMutation();
const { t } = useTranslation();
@ -55,11 +60,15 @@ export const ModelEdit = () => {
return;
}
// remove image from body
const image = values.image;
if (values.image) {
delete values.image;
}
const responseBody: UpdateModelArg = {
key: data.key,
body: values,
};
console.log(responseBody, 'responseBody')
updateModel(responseBody)
.unwrap()
@ -86,6 +95,33 @@ export const ModelEdit = () => {
)
);
});
if (image) {
updateModelImage({ key: data.key, image: image })
.unwrap()
.then((payload) => {
reset(payload, { keepDefaultValues: true });
dispatch(setSelectedModelMode('view'));
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdated'),
status: 'success',
})
)
);
})
.catch((_) => {
reset();
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdateFailed'),
status: 'error',
})
)
);
});
}
},
[dispatch, data?.key, reset, t, updateModel]
);

View File

@ -23,7 +23,16 @@ export type UpdateModelArg = {
body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json'];
};
export type UpdateModelImageArg = {
key: paths['/api/v2/models/i/{key}/image']['patch']['parameters']['path']['key'];
image: paths['/api/v2/models/i/{key}/image']['patch']['formData']['content']['multipart/form-data'];
};
type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
type UpdateModelImageResponse =
paths['/api/v2/models/i/{key}/image']['patch']['responses']['200']['content']['application/json'];
type GetModelImageResponse =
paths['/api/v2/models/i/{key}/image']['get']['responses']['200']['content']['application/json'];
type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
@ -144,6 +153,21 @@ export const modelsApi = api.injectEndpoints({
},
invalidatesTags: ['Model'],
}),
getModelImage: build.query<GetModelImageResponse, string>({
query: (key) => buildModelsUrl(`i/${key}/image`),
}),
updateModelImage: build.mutation<UpdateModelImageResponse, UpdateModelImageArg>({
query: ({ key, image }) => {
const formData = new FormData();
formData.append('image', image);
return {
url: buildModelsUrl(`i/${key}/image`),
method: 'PATCH',
body: formData,
};
},
invalidatesTags: ['Model'],
}),
installModel: build.mutation<InstallModelResponse, InstallModelArg>({
query: ({ source }) => {
return {
@ -330,7 +354,9 @@ export const {
useGetTextualInversionModelsQuery,
useGetVaeModelsQuery,
useDeleteModelsMutation,
useUpdateModelMutation,
useUpdateModelsMutation,
useGetModelImageQuery,
useUpdateModelImageMutation,
useInstallModelMutation,
useConvertModelMutation,
useSyncModelsMutation,

File diff suppressed because one or more lines are too long