fetching model image, still not working

This commit is contained in:
Jennifer Player 2024-03-05 22:57:05 -05:00 committed by Kent Keirsey
parent c1cdfd132b
commit 2f6964bfa5
14 changed files with 461 additions and 53 deletions

View File

@ -25,6 +25,7 @@ from ..services.invocation_cache.invocation_cache_memory import MemoryInvocation
from ..services.invocation_services import InvocationServices from ..services.invocation_services import InvocationServices
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
from ..services.invoker import Invoker from ..services.invoker import Invoker
from ..services.model_images.model_images_default import ModelImagesService
from ..services.model_manager.model_manager_default import ModelManagerService from ..services.model_manager.model_manager_default import ModelManagerService
from ..services.model_records import ModelRecordServiceSQL from ..services.model_records import ModelRecordServiceSQL
from ..services.names.names_default import SimpleNameService from ..services.names.names_default import SimpleNameService
@ -71,6 +72,8 @@ class ApiDependencies:
image_files = DiskImageFileStorage(f"{output_folder}/images") image_files = DiskImageFileStorage(f"{output_folder}/images")
model_images_folder = config.models_path
db = init_db(config=config, logger=logger, image_files=image_files) db = init_db(config=config, logger=logger, image_files=image_files)
configuration = config configuration = config
@ -92,6 +95,7 @@ class ApiDependencies:
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True) ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
) )
download_queue_service = DownloadQueueService(event_bus=events) download_queue_service = DownloadQueueService(event_bus=events)
model_images_service = ModelImagesService(model_images_folder / "model_images")
model_manager = ModelManagerService.build_model_manager( model_manager = ModelManagerService.build_model_manager(
app_config=configuration, app_config=configuration,
model_record_service=ModelRecordServiceSQL(db=db), model_record_service=ModelRecordServiceSQL(db=db),
@ -118,6 +122,7 @@ class ApiDependencies:
images=images, images=images,
invocation_cache=invocation_cache, invocation_cache=invocation_cache,
logger=logger, logger=logger,
model_images=model_images_service,
model_manager=model_manager, model_manager=model_manager,
download_queue=download_queue_service, download_queue=download_queue_service,
names=names, names=names,

View File

@ -1,12 +1,16 @@
# Copyright (c) 2023 Lincoln D. Stein # Copyright (c) 2023 Lincoln D. Stein
"""FastAPI route for model configuration records.""" """FastAPI route for model configuration records."""
import io
import pathlib import pathlib
import shutil import shutil
import traceback
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from fastapi import Body, Path, Query, Response from fastapi import Body, Path, Query, Response, UploadFile
from fastapi.responses import FileResponse
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
from PIL import Image
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from typing_extensions import Annotated from typing_extensions import Annotated
@ -31,6 +35,9 @@ from ..dependencies import ApiDependencies
model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"]) model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"])
# images are immutable; set a high max-age
IMAGE_MAX_AGE = 31536000
class ModelsList(BaseModel): class ModelsList(BaseModel):
"""Return list of configs.""" """Return list of configs."""
@ -72,7 +79,7 @@ example_model_input = {
"description": "Model description", "description": "Model description",
"vae": None, "vae": None,
"variant": "normal", "variant": "normal",
"image": "blob" "image": "blob",
} }
############################################################################## ##############################################################################
@ -267,6 +274,93 @@ async def update_model_record(
return model_response return model_response
@model_manager_router.get(
"/i/{key}/image",
operation_id="get_model_image",
responses={
200: {
"description": "The model image was fetched successfully",
},
400: {"description": "Bad request"},
404: {"description": "The model could not be found"},
},
status_code=200,
)
async def get_model_image(
key: str = Path(description="The name of model image file to get"),
) -> FileResponse:
"""Gets a full-resolution image file"""
try:
path = ApiDependencies.invoker.services.model_images.get_path(key + ".png")
if not path:
raise HTTPException(status_code=404)
response = FileResponse(
path,
media_type="image/png",
filename=key + ".png",
content_disposition_type="inline",
)
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
return response
except Exception:
raise HTTPException(status_code=404)
# async def get_model_image(
# key: Annotated[str, Path(description="Unique key of model")],
# ) -> Optional[str]:
# model_images = ApiDependencies.invoker.services.model_images
# try:
# url = model_images.get_url(key)
# if not url:
# return None
# return url
# except Exception:
# raise HTTPException(status_code=404)
@model_manager_router.patch(
"/i/{key}/image",
operation_id="update_model_image",
responses={
200: {
"description": "The model image was updated successfully",
},
400: {"description": "Bad request"},
},
status_code=200,
)
async def update_model_image(
key: Annotated[str, Path(description="Unique key of model")],
image: UploadFile,
) -> None:
if not image.content_type or not image.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
contents = await image.read()
try:
pil_image = Image.open(io.BytesIO(contents))
except Exception:
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
raise HTTPException(status_code=415, detail="Failed to read image")
logger = ApiDependencies.invoker.services.logger
model_images = ApiDependencies.invoker.services.model_images
try:
model_images.save(pil_image, key)
logger.info(f"Updated image for model: {key}")
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
return
@model_manager_router.delete( @model_manager_router.delete(
"/i/{key}", "/i/{key}",
operation_id="delete_model", operation_id="delete_model",

View File

@ -25,6 +25,7 @@ if TYPE_CHECKING:
from .images.images_base import ImageServiceABC from .images.images_base import ImageServiceABC
from .invocation_cache.invocation_cache_base import InvocationCacheBase from .invocation_cache.invocation_cache_base import InvocationCacheBase
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
from .model_images.model_images_base import ModelImagesBase
from .model_manager.model_manager_base import ModelManagerServiceBase from .model_manager.model_manager_base import ModelManagerServiceBase
from .names.names_base import NameServiceBase from .names.names_base import NameServiceBase
from .session_processor.session_processor_base import SessionProcessorBase from .session_processor.session_processor_base import SessionProcessorBase
@ -49,6 +50,7 @@ class InvocationServices:
image_files: "ImageFileStorageBase", image_files: "ImageFileStorageBase",
image_records: "ImageRecordStorageBase", image_records: "ImageRecordStorageBase",
logger: "Logger", logger: "Logger",
model_images: "ModelImagesBase",
model_manager: "ModelManagerServiceBase", model_manager: "ModelManagerServiceBase",
download_queue: "DownloadQueueServiceBase", download_queue: "DownloadQueueServiceBase",
performance_statistics: "InvocationStatsServiceBase", performance_statistics: "InvocationStatsServiceBase",
@ -72,6 +74,7 @@ class InvocationServices:
self.image_files = image_files self.image_files = image_files
self.image_records = image_records self.image_records = image_records
self.logger = logger self.logger = logger
self.model_images = model_images
self.model_manager = model_manager self.model_manager = model_manager
self.download_queue = download_queue self.download_queue = download_queue
self.performance_statistics = performance_statistics self.performance_statistics = performance_statistics

View File

@ -3,30 +3,34 @@ from pathlib import Path
from PIL.Image import Image as PILImageType from PIL.Image import Image as PILImageType
class ModelImagesBase(ABC): class ModelImagesBase(ABC):
"""Low-level service responsible for storing and retrieving image files.""" """Low-level service responsible for storing and retrieving image files."""
@abstractmethod @abstractmethod
def get(self, image_name: str) -> PILImageType: def get(self, model_key: str) -> PILImageType:
"""Retrieves an image as PIL Image.""" """Retrieves a model image as PIL Image."""
pass pass
@abstractmethod @abstractmethod
def get_path(self, image_name: str) -> Path: def get_path(self, model_key: str) -> Path:
"""Gets the internal path to an image.""" """Gets the internal path to a model image."""
pass
@abstractmethod
def get_url(self, model_key: str) -> str:
"""Gets the url for a model image."""
pass pass
@abstractmethod @abstractmethod
def save( def save(
self, self,
image: PILImageType, image: PILImageType,
image_name: str, model_key: str,
) -> None: ) -> None:
"""Saves an image. Returns a tuple of the image name and created timestamp.""" """Saves a model image. Returns a tuple of the image name and created timestamp."""
pass pass
@abstractmethod @abstractmethod
def delete(self, image_name: str) -> None: def delete(self, model_key: str) -> None:
"""Deletes an image.""" """Deletes a model image."""
pass pass

View File

@ -2,19 +2,19 @@
class ModelImageFileNotFoundException(Exception): class ModelImageFileNotFoundException(Exception):
"""Raised when an image file is not found in storage.""" """Raised when an image file is not found in storage."""
def __init__(self, message="Image file not found"): def __init__(self, message="Model image file not found"):
super().__init__(message) super().__init__(message)
class ModelImageFileSaveException(Exception): class ModelImageFileSaveException(Exception):
"""Raised when an image cannot be saved.""" """Raised when an image cannot be saved."""
def __init__(self, message="Image file not saved"): def __init__(self, message="Model image file not saved"):
super().__init__(message) super().__init__(message)
class ModelImageFileDeleteException(Exception): class ModelImageFileDeleteException(Exception):
"""Raised when an image cannot be deleted.""" """Raised when an image cannot be deleted."""
def __init__(self, message="Image file not deleted"): def __init__(self, message="Model image file not deleted"):
super().__init__(message) super().__init__(message)

View File

@ -10,25 +10,24 @@ from invokeai.app.services.invoker import Invoker
from .model_images_base import ModelImagesBase from .model_images_base import ModelImagesBase
from .model_images_common import ModelImageFileDeleteException, ModelImageFileNotFoundException, ModelImageFileSaveException from .model_images_common import ModelImageFileDeleteException, ModelImageFileNotFoundException, ModelImageFileSaveException
class ModelImagesService(ModelImagesBase):
class DiskImageFileStorage(ModelImagesBase):
"""Stores images on disk""" """Stores images on disk"""
__output_folder: Path __model_images_folder: Path
__invoker: Invoker __invoker: Invoker
def __init__(self, output_folder: Union[str, Path]): def __init__(self, model_images_folder: Union[str, Path]):
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder) self.__model_images_folder: Path = model_images_folder if isinstance(model_images_folder, Path) else Path(model_images_folder)
# Validate required output folders at launch # Validate required folders at launch
self.__validate_storage_folders() self.__validate_storage_folders()
def start(self, invoker: Invoker) -> None: def start(self, invoker: Invoker) -> None:
self.__invoker = invoker self.__invoker = invoker
def get(self, image_name: str) -> PILImageType: def get(self, model_key: str) -> PILImageType:
try: try:
image_path = self.get_path(image_name) image_path = self.get_path(model_key + '.png')
image = Image.open(image_path) image = Image.open(image_path)
return image return image
@ -38,17 +37,13 @@ class DiskImageFileStorage(ModelImagesBase):
def save( def save(
self, self,
image: PILImageType, image: PILImageType,
image_name: str, model_key: str,
) -> None: ) -> None:
try: try:
self.__validate_storage_folders() self.__validate_storage_folders()
image_path = self.get_path(image_name) image_path = self.get_path(model_key + '.png')
pnginfo = PngImagePlugin.PngInfo() pnginfo = PngImagePlugin.PngInfo()
info_dict = {}
# When saving the image, the image object's info field is not populated. We need to set it
image.info = info_dict
image.save( image.save(
image_path, image_path,
"PNG", "PNG",
@ -59,9 +54,17 @@ class DiskImageFileStorage(ModelImagesBase):
except Exception as e: except Exception as e:
raise ModelImageFileSaveException from e raise ModelImageFileSaveException from e
def delete(self, image_name: str) -> None: def get_path(self, model_key: str) -> Path:
path = self.__model_images_folder / model_key
return path
def get_url(self, model_key: str) -> str:
return self.__invoker.services.urls.get_model_image_url(model_key)
def delete(self, model_key: str) -> None:
try: try:
image_path = self.get_path(image_name) image_path = self.get_path(model_key + '.png')
if image_path.exists(): if image_path.exists():
send2trash(image_path) send2trash(image_path)
@ -69,14 +72,8 @@ class DiskImageFileStorage(ModelImagesBase):
except Exception as e: except Exception as e:
raise ModelImageFileDeleteException from e raise ModelImageFileDeleteException from e
# TODO: make this a bit more flexible for e.g. cloud storage
def get_path(self, image_name: str) -> Path:
path = self.__output_folder / image_name
return path
def __validate_storage_folders(self) -> None: def __validate_storage_folders(self) -> None:
"""Checks if the required output folders exist and create them if they don't""" """Checks if the required folders exist and create them if they don't"""
folders: list[Path] = [self.__output_folder] folders: list[Path] = [self.__model_images_folder]
for folder in folders: for folder in folders:
folder.mkdir(parents=True, exist_ok=True) folder.mkdir(parents=True, exist_ok=True)

View File

@ -8,3 +8,8 @@ class UrlServiceBase(ABC):
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str: def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
"""Gets the URL for an image or thumbnail.""" """Gets the URL for an image or thumbnail."""
pass pass
@abstractmethod
def get_model_image_url(self, model_key: str) -> str:
"""Gets the URL for a model image"""
pass

View File

@ -15,3 +15,6 @@ class LocalUrlService(UrlServiceBase):
return f"{self._base_url}/images/i/{image_basename}/thumbnail" return f"{self._base_url}/images/i/{image_basename}/thumbnail"
return f"{self._base_url}/images/i/{image_basename}/full" return f"{self._base_url}/images/i/{image_basename}/full"
def get_model_image_url(self, model_key: str) -> str:
return f"{self._base_url}/model_images/{model_key}.png"

View File

@ -161,6 +161,7 @@ class ModelConfigBase(BaseModel):
default_settings: Optional[ModelDefaultSettings] = Field( default_settings: Optional[ModelDefaultSettings] = Field(
description="Default settings for this model", default=None description="Default settings for this model", default=None
) )
image: Optional[str] = Field(description="Image to preview model", default=None)
@staticmethod @staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None: def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
@ -374,6 +375,10 @@ AnyModelConfig = Annotated[
AnyModelConfigValidator = TypeAdapter(AnyModelConfig) AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
class ModelImage(str, Enum):
path: str
class ModelConfigFactory(object): class ModelConfigFactory(object):
"""Class for parsing config dicts into StableDiffusion Config obects.""" """Class for parsing config dicts into StableDiffusion Config obects."""

View File

@ -0,0 +1,73 @@
import { Box, IconButton, Image } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { useCallback } from 'react';
import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types';
import { Button } from '@invoke-ai/ui-library';
import { useDropzone } from 'react-dropzone';
import { PiArrowCounterClockwiseBold, PiUploadSimpleBold } from 'react-icons/pi';
const ModelImageUpload = (props: UseControllerProps<AnyModelConfig>) => {
const { field } = useController(props);
const onDropAccepted = useCallback(
(files: File[]) => {
const file = files[0];
if (!file) {
return;
}
field.onChange(file);
},
[field]
);
const handleResetControlImage = useCallback(() => {
field.onChange(undefined);
}, [field]);
console.log('field', field);
const { getInputProps, getRootProps } = useDropzone({
accept: { 'image/png': ['.png'], 'image/jpeg': ['.jpg', '.jpeg', '.png'] },
onDropAccepted,
noDrag: true,
multiple: false,
});
if (field.value) {
return (
<Box>
<Image
src={field.value ? URL.createObjectURL(field.value) : 'http://localhost:5173/api/v2/models/i/6b8a6c0d68127ad8db5550f16d9a304b/image'}
objectFit="contain"
maxW="full"
maxH="200px"
borderRadius="base"
/>
<IconButton
onClick={handleResetControlImage}
aria-label="reset this image"
tooltip="reset this image"
icon={<PiArrowCounterClockwiseBold size={16} />}
size="sm"
variant="link"
/>
</Box>
);
}
return (
<>
<Button leftIcon={<PiUploadSimpleBold />} {...getRootProps()} pointerEvents="auto">
Upload Image
</Button>
<input {...getInputProps()} />
</>
);
};
export default typedMemo(ModelImageUpload);

View File

@ -1,18 +1,22 @@
import { Box, IconButton, Image } from '@invoke-ai/ui-library'; import { Box, IconButton, Image } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo'; import { typedMemo } from 'common/util/typedMemo';
import { useCallback } from 'react'; import { useCallback, useEffect, useMemo } from 'react';
import type { UseControllerProps } from 'react-hook-form'; import type { UseControllerProps } from 'react-hook-form';
import { useController } from 'react-hook-form'; import { useController, useWatch } from 'react-hook-form';
import type { AnyModelConfig } from 'services/api/types'; import type { AnyModelConfig } from 'services/api/types';
import { Button } from '@invoke-ai/ui-library'; import { Button } from '@invoke-ai/ui-library';
import { useDropzone } from 'react-dropzone'; import { useDropzone } from 'react-dropzone';
import { PiArrowCounterClockwiseBold, PiUploadSimpleBold } from 'react-icons/pi'; import { PiArrowCounterClockwiseBold, PiUploadSimpleBold } from 'react-icons/pi';
import IAIDndImageIcon from 'common/components/IAIDndImageIcon'; import { useGetModelImageQuery } from 'services/api/endpoints/models';
const ModelImageUpload = (props: UseControllerProps<AnyModelConfig>) => { const ModelImageUpload = (props: UseControllerProps<AnyModelConfig>) => {
const { field } = useController(props); const { field } = useController(props);
const key = useWatch({ control: props.control, name: 'key' });
const { data } = useGetModelImageQuery(key);
const onDropAccepted = useCallback( const onDropAccepted = useCallback(
(files: File[]) => { (files: File[]) => {
const file = files[0]; const file = files[0];
@ -30,8 +34,6 @@ const ModelImageUpload = (props: UseControllerProps<AnyModelConfig>) => {
field.onChange(undefined); field.onChange(undefined);
}, [field]); }, [field]);
console.log('field', field);
const { getInputProps, getRootProps } = useDropzone({ const { getInputProps, getRootProps } = useDropzone({
accept: { 'image/png': ['.png'], 'image/jpeg': ['.jpg', '.jpeg', '.png'] }, accept: { 'image/png': ['.png'], 'image/jpeg': ['.jpg', '.jpeg', '.png'] },
onDropAccepted, onDropAccepted,
@ -39,11 +41,20 @@ const ModelImageUpload = (props: UseControllerProps<AnyModelConfig>) => {
multiple: false, multiple: false,
}); });
if (field.value) { const image = useMemo(() => {
console.log(field.value, 'asdf' );
if (field.value) {
return URL.createObjectURL(field.value);
}
return data;
}, [field.value, data]);
if (image) {
return ( return (
<Box> <Box>
<Image <Image
src={URL.createObjectURL(field.value)} src={image}
objectFit="contain" objectFit="contain"
maxW="full" maxW="full"
maxH="200px" maxH="200px"
@ -56,7 +67,6 @@ const ModelImageUpload = (props: UseControllerProps<AnyModelConfig>) => {
icon={<PiArrowCounterClockwiseBold size={16} />} icon={<PiArrowCounterClockwiseBold size={16} />}
size="sm" size="sm"
variant="link" variant="link"
// sx={sx}
/> />
</Box> </Box>
); );
@ -73,4 +83,3 @@ const ModelImageUpload = (props: UseControllerProps<AnyModelConfig>) => {
}; };
export default typedMemo(ModelImageUpload); export default typedMemo(ModelImageUpload);

View File

@ -20,7 +20,11 @@ import type { SubmitHandler } from 'react-hook-form';
import { useForm } from 'react-hook-form'; import { useForm } from 'react-hook-form';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import type { UpdateModelArg } from 'services/api/endpoints/models'; import type { UpdateModelArg } from 'services/api/endpoints/models';
import { useGetModelConfigQuery, useUpdateModelMutation } from 'services/api/endpoints/models'; import {
useGetModelConfigQuery,
useUpdateModelImageMutation,
useUpdateModelsMutation,
} from 'services/api/endpoints/models';
import BaseModelSelect from './Fields/BaseModelSelect'; import BaseModelSelect from './Fields/BaseModelSelect';
import ModelImageUpload from './Fields/ModelImageUpload'; import ModelImageUpload from './Fields/ModelImageUpload';
@ -32,7 +36,8 @@ export const ModelEdit = () => {
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey); const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken); const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelMutation(); const [updateModel, { isLoading: isSubmitting }] = useUpdateModelsMutation();
const [updateModelImage] = useUpdateModelImageMutation();
const { t } = useTranslation(); const { t } = useTranslation();
@ -55,11 +60,15 @@ export const ModelEdit = () => {
return; return;
} }
// remove image from body
const image = values.image;
if (values.image) {
delete values.image;
}
const responseBody: UpdateModelArg = { const responseBody: UpdateModelArg = {
key: data.key, key: data.key,
body: values, body: values,
}; };
console.log(responseBody, 'responseBody')
updateModel(responseBody) updateModel(responseBody)
.unwrap() .unwrap()
@ -86,6 +95,33 @@ export const ModelEdit = () => {
) )
); );
}); });
if (image) {
updateModelImage({ key: data.key, image: image })
.unwrap()
.then((payload) => {
reset(payload, { keepDefaultValues: true });
dispatch(setSelectedModelMode('view'));
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdated'),
status: 'success',
})
)
);
})
.catch((_) => {
reset();
dispatch(
addToast(
makeToast({
title: t('modelManager.modelUpdateFailed'),
status: 'error',
})
)
);
});
}
}, },
[dispatch, data?.key, reset, t, updateModel] [dispatch, data?.key, reset, t, updateModel]
); );

View File

@ -23,7 +23,16 @@ export type UpdateModelArg = {
body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json']; body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json'];
}; };
export type UpdateModelImageArg = {
key: paths['/api/v2/models/i/{key}/image']['patch']['parameters']['path']['key'];
image: paths['/api/v2/models/i/{key}/image']['patch']['formData']['content']['multipart/form-data'];
};
type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json']; type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
type UpdateModelImageResponse =
paths['/api/v2/models/i/{key}/image']['patch']['responses']['200']['content']['application/json'];
type GetModelImageResponse =
paths['/api/v2/models/i/{key}/image']['get']['responses']['200']['content']['application/json'];
type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json']; type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
@ -144,6 +153,21 @@ export const modelsApi = api.injectEndpoints({
}, },
invalidatesTags: ['Model'], invalidatesTags: ['Model'],
}), }),
getModelImage: build.query<GetModelImageResponse, string>({
query: (key) => buildModelsUrl(`i/${key}/image`),
}),
updateModelImage: build.mutation<UpdateModelImageResponse, UpdateModelImageArg>({
query: ({ key, image }) => {
const formData = new FormData();
formData.append('image', image);
return {
url: buildModelsUrl(`i/${key}/image`),
method: 'PATCH',
body: formData,
};
},
invalidatesTags: ['Model'],
}),
installModel: build.mutation<InstallModelResponse, InstallModelArg>({ installModel: build.mutation<InstallModelResponse, InstallModelArg>({
query: ({ source }) => { query: ({ source }) => {
return { return {
@ -330,7 +354,9 @@ export const {
useGetTextualInversionModelsQuery, useGetTextualInversionModelsQuery,
useGetVaeModelsQuery, useGetVaeModelsQuery,
useDeleteModelsMutation, useDeleteModelsMutation,
useUpdateModelMutation, useUpdateModelsMutation,
useGetModelImageQuery,
useUpdateModelImageMutation,
useInstallModelMutation, useInstallModelMutation,
useConvertModelMutation, useConvertModelMutation,
useSyncModelsMutation, useSyncModelsMutation,

File diff suppressed because one or more lines are too long