mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fetching model image, still not working
This commit is contained in:
parent
c1cdfd132b
commit
2f6964bfa5
@ -25,6 +25,7 @@ from ..services.invocation_cache.invocation_cache_memory import MemoryInvocation
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
|
||||
from ..services.invoker import Invoker
|
||||
from ..services.model_images.model_images_default import ModelImagesService
|
||||
from ..services.model_manager.model_manager_default import ModelManagerService
|
||||
from ..services.model_records import ModelRecordServiceSQL
|
||||
from ..services.names.names_default import SimpleNameService
|
||||
@ -71,6 +72,8 @@ class ApiDependencies:
|
||||
|
||||
image_files = DiskImageFileStorage(f"{output_folder}/images")
|
||||
|
||||
model_images_folder = config.models_path
|
||||
|
||||
db = init_db(config=config, logger=logger, image_files=image_files)
|
||||
|
||||
configuration = config
|
||||
@ -92,6 +95,7 @@ class ApiDependencies:
|
||||
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
|
||||
)
|
||||
download_queue_service = DownloadQueueService(event_bus=events)
|
||||
model_images_service = ModelImagesService(model_images_folder / "model_images")
|
||||
model_manager = ModelManagerService.build_model_manager(
|
||||
app_config=configuration,
|
||||
model_record_service=ModelRecordServiceSQL(db=db),
|
||||
@ -118,6 +122,7 @@ class ApiDependencies:
|
||||
images=images,
|
||||
invocation_cache=invocation_cache,
|
||||
logger=logger,
|
||||
model_images=model_images_service,
|
||||
model_manager=model_manager,
|
||||
download_queue=download_queue_service,
|
||||
names=names,
|
||||
|
@ -1,12 +1,16 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein
|
||||
"""FastAPI route for model configuration records."""
|
||||
|
||||
import io
|
||||
import pathlib
|
||||
import shutil
|
||||
import traceback
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import Body, Path, Query, Response
|
||||
from fastapi import Body, Path, Query, Response, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.routing import APIRouter
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from starlette.exceptions import HTTPException
|
||||
from typing_extensions import Annotated
|
||||
@ -31,6 +35,9 @@ from ..dependencies import ApiDependencies
|
||||
|
||||
model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"])
|
||||
|
||||
# images are immutable; set a high max-age
|
||||
IMAGE_MAX_AGE = 31536000
|
||||
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
"""Return list of configs."""
|
||||
@ -72,7 +79,7 @@ example_model_input = {
|
||||
"description": "Model description",
|
||||
"vae": None,
|
||||
"variant": "normal",
|
||||
"image": "blob"
|
||||
"image": "blob",
|
||||
}
|
||||
|
||||
##############################################################################
|
||||
@ -267,6 +274,93 @@ async def update_model_record(
|
||||
return model_response
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/i/{key}/image",
|
||||
operation_id="get_model_image",
|
||||
responses={
|
||||
200: {
|
||||
"description": "The model image was fetched successfully",
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "The model could not be found"},
|
||||
},
|
||||
status_code=200,
|
||||
)
|
||||
async def get_model_image(
|
||||
key: str = Path(description="The name of model image file to get"),
|
||||
) -> FileResponse:
|
||||
"""Gets a full-resolution image file"""
|
||||
|
||||
try:
|
||||
path = ApiDependencies.invoker.services.model_images.get_path(key + ".png")
|
||||
|
||||
if not path:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
response = FileResponse(
|
||||
path,
|
||||
media_type="image/png",
|
||||
filename=key + ".png",
|
||||
content_disposition_type="inline",
|
||||
)
|
||||
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
||||
return response
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
# async def get_model_image(
|
||||
# key: Annotated[str, Path(description="Unique key of model")],
|
||||
# ) -> Optional[str]:
|
||||
# model_images = ApiDependencies.invoker.services.model_images
|
||||
# try:
|
||||
# url = model_images.get_url(key)
|
||||
|
||||
# if not url:
|
||||
# return None
|
||||
|
||||
# return url
|
||||
# except Exception:
|
||||
# raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@model_manager_router.patch(
|
||||
"/i/{key}/image",
|
||||
operation_id="update_model_image",
|
||||
responses={
|
||||
200: {
|
||||
"description": "The model image was updated successfully",
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
status_code=200,
|
||||
)
|
||||
async def update_model_image(
|
||||
key: Annotated[str, Path(description="Unique key of model")],
|
||||
image: UploadFile,
|
||||
) -> None:
|
||||
if not image.content_type or not image.content_type.startswith("image"):
|
||||
raise HTTPException(status_code=415, detail="Not an image")
|
||||
|
||||
contents = await image.read()
|
||||
try:
|
||||
pil_image = Image.open(io.BytesIO(contents))
|
||||
|
||||
except Exception:
|
||||
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=415, detail="Failed to read image")
|
||||
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
model_images = ApiDependencies.invoker.services.model_images
|
||||
try:
|
||||
model_images.save(pil_image, key)
|
||||
logger.info(f"Updated image for model: {key}")
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
return
|
||||
|
||||
|
||||
@model_manager_router.delete(
|
||||
"/i/{key}",
|
||||
operation_id="delete_model",
|
||||
|
@ -25,6 +25,7 @@ if TYPE_CHECKING:
|
||||
from .images.images_base import ImageServiceABC
|
||||
from .invocation_cache.invocation_cache_base import InvocationCacheBase
|
||||
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
|
||||
from .model_images.model_images_base import ModelImagesBase
|
||||
from .model_manager.model_manager_base import ModelManagerServiceBase
|
||||
from .names.names_base import NameServiceBase
|
||||
from .session_processor.session_processor_base import SessionProcessorBase
|
||||
@ -49,6 +50,7 @@ class InvocationServices:
|
||||
image_files: "ImageFileStorageBase",
|
||||
image_records: "ImageRecordStorageBase",
|
||||
logger: "Logger",
|
||||
model_images: "ModelImagesBase",
|
||||
model_manager: "ModelManagerServiceBase",
|
||||
download_queue: "DownloadQueueServiceBase",
|
||||
performance_statistics: "InvocationStatsServiceBase",
|
||||
@ -72,6 +74,7 @@ class InvocationServices:
|
||||
self.image_files = image_files
|
||||
self.image_records = image_records
|
||||
self.logger = logger
|
||||
self.model_images = model_images
|
||||
self.model_manager = model_manager
|
||||
self.download_queue = download_queue
|
||||
self.performance_statistics = performance_statistics
|
||||
|
@ -3,30 +3,34 @@ from pathlib import Path
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
|
||||
class ModelImagesBase(ABC):
|
||||
"""Low-level service responsible for storing and retrieving image files."""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, image_name: str) -> PILImageType:
|
||||
"""Retrieves an image as PIL Image."""
|
||||
def get(self, model_key: str) -> PILImageType:
|
||||
"""Retrieves a model image as PIL Image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_path(self, image_name: str) -> Path:
|
||||
"""Gets the internal path to an image."""
|
||||
def get_path(self, model_key: str) -> Path:
|
||||
"""Gets the internal path to a model image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_url(self, model_key: str) -> str:
|
||||
"""Gets the url for a model image."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_name: str,
|
||||
model_key: str,
|
||||
) -> None:
|
||||
"""Saves an image. Returns a tuple of the image name and created timestamp."""
|
||||
"""Saves a model image. Returns a tuple of the image name and created timestamp."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, image_name: str) -> None:
|
||||
"""Deletes an image."""
|
||||
def delete(self, model_key: str) -> None:
|
||||
"""Deletes a model image."""
|
||||
pass
|
||||
|
@ -2,19 +2,19 @@
|
||||
class ModelImageFileNotFoundException(Exception):
|
||||
"""Raised when an image file is not found in storage."""
|
||||
|
||||
def __init__(self, message="Image file not found"):
|
||||
def __init__(self, message="Model image file not found"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ModelImageFileSaveException(Exception):
|
||||
"""Raised when an image cannot be saved."""
|
||||
|
||||
def __init__(self, message="Image file not saved"):
|
||||
def __init__(self, message="Model image file not saved"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ModelImageFileDeleteException(Exception):
|
||||
"""Raised when an image cannot be deleted."""
|
||||
|
||||
def __init__(self, message="Image file not deleted"):
|
||||
def __init__(self, message="Model image file not deleted"):
|
||||
super().__init__(message)
|
||||
|
@ -10,25 +10,24 @@ from invokeai.app.services.invoker import Invoker
|
||||
from .model_images_base import ModelImagesBase
|
||||
from .model_images_common import ModelImageFileDeleteException, ModelImageFileNotFoundException, ModelImageFileSaveException
|
||||
|
||||
|
||||
class DiskImageFileStorage(ModelImagesBase):
|
||||
class ModelImagesService(ModelImagesBase):
|
||||
"""Stores images on disk"""
|
||||
|
||||
__output_folder: Path
|
||||
__model_images_folder: Path
|
||||
__invoker: Invoker
|
||||
|
||||
def __init__(self, output_folder: Union[str, Path]):
|
||||
def __init__(self, model_images_folder: Union[str, Path]):
|
||||
|
||||
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
||||
# Validate required output folders at launch
|
||||
self.__model_images_folder: Path = model_images_folder if isinstance(model_images_folder, Path) else Path(model_images_folder)
|
||||
# Validate required folders at launch
|
||||
self.__validate_storage_folders()
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self.__invoker = invoker
|
||||
|
||||
def get(self, image_name: str) -> PILImageType:
|
||||
def get(self, model_key: str) -> PILImageType:
|
||||
try:
|
||||
image_path = self.get_path(image_name)
|
||||
image_path = self.get_path(model_key + '.png')
|
||||
|
||||
image = Image.open(image_path)
|
||||
return image
|
||||
@ -38,17 +37,13 @@ class DiskImageFileStorage(ModelImagesBase):
|
||||
def save(
|
||||
self,
|
||||
image: PILImageType,
|
||||
image_name: str,
|
||||
model_key: str,
|
||||
) -> None:
|
||||
try:
|
||||
self.__validate_storage_folders()
|
||||
image_path = self.get_path(image_name)
|
||||
|
||||
image_path = self.get_path(model_key + '.png')
|
||||
pnginfo = PngImagePlugin.PngInfo()
|
||||
info_dict = {}
|
||||
|
||||
# When saving the image, the image object's info field is not populated. We need to set it
|
||||
image.info = info_dict
|
||||
image.save(
|
||||
image_path,
|
||||
"PNG",
|
||||
@ -59,9 +54,17 @@ class DiskImageFileStorage(ModelImagesBase):
|
||||
except Exception as e:
|
||||
raise ModelImageFileSaveException from e
|
||||
|
||||
def delete(self, image_name: str) -> None:
|
||||
def get_path(self, model_key: str) -> Path:
|
||||
path = self.__model_images_folder / model_key
|
||||
|
||||
return path
|
||||
|
||||
def get_url(self, model_key: str) -> str:
|
||||
return self.__invoker.services.urls.get_model_image_url(model_key)
|
||||
|
||||
def delete(self, model_key: str) -> None:
|
||||
try:
|
||||
image_path = self.get_path(image_name)
|
||||
image_path = self.get_path(model_key + '.png')
|
||||
|
||||
if image_path.exists():
|
||||
send2trash(image_path)
|
||||
@ -69,14 +72,8 @@ class DiskImageFileStorage(ModelImagesBase):
|
||||
except Exception as e:
|
||||
raise ModelImageFileDeleteException from e
|
||||
|
||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||
def get_path(self, image_name: str) -> Path:
|
||||
path = self.__output_folder / image_name
|
||||
|
||||
return path
|
||||
|
||||
def __validate_storage_folders(self) -> None:
|
||||
"""Checks if the required output folders exist and create them if they don't"""
|
||||
folders: list[Path] = [self.__output_folder]
|
||||
"""Checks if the required folders exist and create them if they don't"""
|
||||
folders: list[Path] = [self.__model_images_folder]
|
||||
for folder in folders:
|
||||
folder.mkdir(parents=True, exist_ok=True)
|
@ -8,3 +8,8 @@ class UrlServiceBase(ABC):
|
||||
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||
"""Gets the URL for an image or thumbnail."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_model_image_url(self, model_key: str) -> str:
|
||||
"""Gets the URL for a model image"""
|
||||
pass
|
||||
|
@ -15,3 +15,6 @@ class LocalUrlService(UrlServiceBase):
|
||||
return f"{self._base_url}/images/i/{image_basename}/thumbnail"
|
||||
|
||||
return f"{self._base_url}/images/i/{image_basename}/full"
|
||||
|
||||
def get_model_image_url(self, model_key: str) -> str:
|
||||
return f"{self._base_url}/model_images/{model_key}.png"
|
@ -161,6 +161,7 @@ class ModelConfigBase(BaseModel):
|
||||
default_settings: Optional[ModelDefaultSettings] = Field(
|
||||
description="Default settings for this model", default=None
|
||||
)
|
||||
image: Optional[str] = Field(description="Image to preview model", default=None)
|
||||
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||
@ -374,6 +375,10 @@ AnyModelConfig = Annotated[
|
||||
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
|
||||
|
||||
|
||||
class ModelImage(str, Enum):
|
||||
path: str
|
||||
|
||||
|
||||
class ModelConfigFactory(object):
|
||||
"""Class for parsing config dicts into StableDiffusion Config obects."""
|
||||
|
||||
|
@ -0,0 +1,73 @@
|
||||
import { Box, IconButton, Image } from '@invoke-ai/ui-library';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { useCallback } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
import { Button } from '@invoke-ai/ui-library';
|
||||
import { useDropzone } from 'react-dropzone';
|
||||
import { PiArrowCounterClockwiseBold, PiUploadSimpleBold } from 'react-icons/pi';
|
||||
|
||||
const ModelImageUpload = (props: UseControllerProps<AnyModelConfig>) => {
|
||||
const { field } = useController(props);
|
||||
|
||||
const onDropAccepted = useCallback(
|
||||
(files: File[]) => {
|
||||
const file = files[0];
|
||||
|
||||
if (!file) {
|
||||
return;
|
||||
}
|
||||
|
||||
field.onChange(file);
|
||||
},
|
||||
[field]
|
||||
);
|
||||
|
||||
const handleResetControlImage = useCallback(() => {
|
||||
field.onChange(undefined);
|
||||
}, [field]);
|
||||
|
||||
console.log('field', field);
|
||||
|
||||
const { getInputProps, getRootProps } = useDropzone({
|
||||
accept: { 'image/png': ['.png'], 'image/jpeg': ['.jpg', '.jpeg', '.png'] },
|
||||
onDropAccepted,
|
||||
noDrag: true,
|
||||
multiple: false,
|
||||
});
|
||||
|
||||
if (field.value) {
|
||||
return (
|
||||
<Box>
|
||||
<Image
|
||||
src={field.value ? URL.createObjectURL(field.value) : 'http://localhost:5173/api/v2/models/i/6b8a6c0d68127ad8db5550f16d9a304b/image'}
|
||||
objectFit="contain"
|
||||
maxW="full"
|
||||
maxH="200px"
|
||||
borderRadius="base"
|
||||
/>
|
||||
<IconButton
|
||||
onClick={handleResetControlImage}
|
||||
aria-label="reset this image"
|
||||
tooltip="reset this image"
|
||||
icon={<PiArrowCounterClockwiseBold size={16} />}
|
||||
size="sm"
|
||||
variant="link"
|
||||
/>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Button leftIcon={<PiUploadSimpleBold />} {...getRootProps()} pointerEvents="auto">
|
||||
Upload Image
|
||||
</Button>
|
||||
<input {...getInputProps()} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export default typedMemo(ModelImageUpload);
|
@ -1,18 +1,22 @@
|
||||
import { Box, IconButton, Image } from '@invoke-ai/ui-library';
|
||||
import { typedMemo } from 'common/util/typedMemo';
|
||||
import { useCallback } from 'react';
|
||||
import { useCallback, useEffect, useMemo } from 'react';
|
||||
import type { UseControllerProps } from 'react-hook-form';
|
||||
import { useController } from 'react-hook-form';
|
||||
import { useController, useWatch } from 'react-hook-form';
|
||||
import type { AnyModelConfig } from 'services/api/types';
|
||||
|
||||
import { Button } from '@invoke-ai/ui-library';
|
||||
import { useDropzone } from 'react-dropzone';
|
||||
import { PiArrowCounterClockwiseBold, PiUploadSimpleBold } from 'react-icons/pi';
|
||||
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
|
||||
import { useGetModelImageQuery } from 'services/api/endpoints/models';
|
||||
|
||||
const ModelImageUpload = (props: UseControllerProps<AnyModelConfig>) => {
|
||||
const { field } = useController(props);
|
||||
|
||||
const key = useWatch({ control: props.control, name: 'key' });
|
||||
|
||||
const { data } = useGetModelImageQuery(key);
|
||||
|
||||
const onDropAccepted = useCallback(
|
||||
(files: File[]) => {
|
||||
const file = files[0];
|
||||
@ -30,8 +34,6 @@ const ModelImageUpload = (props: UseControllerProps<AnyModelConfig>) => {
|
||||
field.onChange(undefined);
|
||||
}, [field]);
|
||||
|
||||
console.log('field', field);
|
||||
|
||||
const { getInputProps, getRootProps } = useDropzone({
|
||||
accept: { 'image/png': ['.png'], 'image/jpeg': ['.jpg', '.jpeg', '.png'] },
|
||||
onDropAccepted,
|
||||
@ -39,11 +41,20 @@ const ModelImageUpload = (props: UseControllerProps<AnyModelConfig>) => {
|
||||
multiple: false,
|
||||
});
|
||||
|
||||
if (field.value) {
|
||||
const image = useMemo(() => {
|
||||
console.log(field.value, 'asdf' );
|
||||
if (field.value) {
|
||||
return URL.createObjectURL(field.value);
|
||||
}
|
||||
|
||||
return data;
|
||||
}, [field.value, data]);
|
||||
|
||||
if (image) {
|
||||
return (
|
||||
<Box>
|
||||
<Image
|
||||
src={URL.createObjectURL(field.value)}
|
||||
src={image}
|
||||
objectFit="contain"
|
||||
maxW="full"
|
||||
maxH="200px"
|
||||
@ -56,7 +67,6 @@ const ModelImageUpload = (props: UseControllerProps<AnyModelConfig>) => {
|
||||
icon={<PiArrowCounterClockwiseBold size={16} />}
|
||||
size="sm"
|
||||
variant="link"
|
||||
// sx={sx}
|
||||
/>
|
||||
</Box>
|
||||
);
|
||||
@ -73,4 +83,3 @@ const ModelImageUpload = (props: UseControllerProps<AnyModelConfig>) => {
|
||||
};
|
||||
|
||||
export default typedMemo(ModelImageUpload);
|
||||
|
||||
|
@ -20,7 +20,11 @@ import type { SubmitHandler } from 'react-hook-form';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { UpdateModelArg } from 'services/api/endpoints/models';
|
||||
import { useGetModelConfigQuery, useUpdateModelMutation } from 'services/api/endpoints/models';
|
||||
import {
|
||||
useGetModelConfigQuery,
|
||||
useUpdateModelImageMutation,
|
||||
useUpdateModelsMutation,
|
||||
} from 'services/api/endpoints/models';
|
||||
|
||||
import BaseModelSelect from './Fields/BaseModelSelect';
|
||||
import ModelImageUpload from './Fields/ModelImageUpload';
|
||||
@ -32,7 +36,8 @@ export const ModelEdit = () => {
|
||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
||||
|
||||
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelMutation();
|
||||
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelsMutation();
|
||||
const [updateModelImage] = useUpdateModelImageMutation();
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
@ -55,11 +60,15 @@ export const ModelEdit = () => {
|
||||
return;
|
||||
}
|
||||
|
||||
// remove image from body
|
||||
const image = values.image;
|
||||
if (values.image) {
|
||||
delete values.image;
|
||||
}
|
||||
const responseBody: UpdateModelArg = {
|
||||
key: data.key,
|
||||
body: values,
|
||||
};
|
||||
console.log(responseBody, 'responseBody')
|
||||
|
||||
updateModel(responseBody)
|
||||
.unwrap()
|
||||
@ -86,6 +95,33 @@ export const ModelEdit = () => {
|
||||
)
|
||||
);
|
||||
});
|
||||
if (image) {
|
||||
updateModelImage({ key: data.key, image: image })
|
||||
.unwrap()
|
||||
.then((payload) => {
|
||||
reset(payload, { keepDefaultValues: true });
|
||||
dispatch(setSelectedModelMode('view'));
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelUpdated'),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
})
|
||||
.catch((_) => {
|
||||
reset();
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.modelUpdateFailed'),
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
});
|
||||
}
|
||||
},
|
||||
[dispatch, data?.key, reset, t, updateModel]
|
||||
);
|
||||
|
@ -23,7 +23,16 @@ export type UpdateModelArg = {
|
||||
body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json'];
|
||||
};
|
||||
|
||||
export type UpdateModelImageArg = {
|
||||
key: paths['/api/v2/models/i/{key}/image']['patch']['parameters']['path']['key'];
|
||||
image: paths['/api/v2/models/i/{key}/image']['patch']['formData']['content']['multipart/form-data'];
|
||||
};
|
||||
|
||||
type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
|
||||
type UpdateModelImageResponse =
|
||||
paths['/api/v2/models/i/{key}/image']['patch']['responses']['200']['content']['application/json'];
|
||||
type GetModelImageResponse =
|
||||
paths['/api/v2/models/i/{key}/image']['get']['responses']['200']['content']['application/json'];
|
||||
|
||||
type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
|
||||
|
||||
@ -144,6 +153,21 @@ export const modelsApi = api.injectEndpoints({
|
||||
},
|
||||
invalidatesTags: ['Model'],
|
||||
}),
|
||||
getModelImage: build.query<GetModelImageResponse, string>({
|
||||
query: (key) => buildModelsUrl(`i/${key}/image`),
|
||||
}),
|
||||
updateModelImage: build.mutation<UpdateModelImageResponse, UpdateModelImageArg>({
|
||||
query: ({ key, image }) => {
|
||||
const formData = new FormData();
|
||||
formData.append('image', image);
|
||||
return {
|
||||
url: buildModelsUrl(`i/${key}/image`),
|
||||
method: 'PATCH',
|
||||
body: formData,
|
||||
};
|
||||
},
|
||||
invalidatesTags: ['Model'],
|
||||
}),
|
||||
installModel: build.mutation<InstallModelResponse, InstallModelArg>({
|
||||
query: ({ source }) => {
|
||||
return {
|
||||
@ -330,7 +354,9 @@ export const {
|
||||
useGetTextualInversionModelsQuery,
|
||||
useGetVaeModelsQuery,
|
||||
useDeleteModelsMutation,
|
||||
useUpdateModelMutation,
|
||||
useUpdateModelsMutation,
|
||||
useGetModelImageQuery,
|
||||
useUpdateModelImageMutation,
|
||||
useInstallModelMutation,
|
||||
useConvertModelMutation,
|
||||
useSyncModelsMutation,
|
||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user