mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fetching model image, still not working
This commit is contained in:
parent
c1cdfd132b
commit
2f6964bfa5
@ -25,6 +25,7 @@ from ..services.invocation_cache.invocation_cache_memory import MemoryInvocation
|
|||||||
from ..services.invocation_services import InvocationServices
|
from ..services.invocation_services import InvocationServices
|
||||||
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
|
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
|
||||||
from ..services.invoker import Invoker
|
from ..services.invoker import Invoker
|
||||||
|
from ..services.model_images.model_images_default import ModelImagesService
|
||||||
from ..services.model_manager.model_manager_default import ModelManagerService
|
from ..services.model_manager.model_manager_default import ModelManagerService
|
||||||
from ..services.model_records import ModelRecordServiceSQL
|
from ..services.model_records import ModelRecordServiceSQL
|
||||||
from ..services.names.names_default import SimpleNameService
|
from ..services.names.names_default import SimpleNameService
|
||||||
@ -71,6 +72,8 @@ class ApiDependencies:
|
|||||||
|
|
||||||
image_files = DiskImageFileStorage(f"{output_folder}/images")
|
image_files = DiskImageFileStorage(f"{output_folder}/images")
|
||||||
|
|
||||||
|
model_images_folder = config.models_path
|
||||||
|
|
||||||
db = init_db(config=config, logger=logger, image_files=image_files)
|
db = init_db(config=config, logger=logger, image_files=image_files)
|
||||||
|
|
||||||
configuration = config
|
configuration = config
|
||||||
@ -92,6 +95,7 @@ class ApiDependencies:
|
|||||||
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
|
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
|
||||||
)
|
)
|
||||||
download_queue_service = DownloadQueueService(event_bus=events)
|
download_queue_service = DownloadQueueService(event_bus=events)
|
||||||
|
model_images_service = ModelImagesService(model_images_folder / "model_images")
|
||||||
model_manager = ModelManagerService.build_model_manager(
|
model_manager = ModelManagerService.build_model_manager(
|
||||||
app_config=configuration,
|
app_config=configuration,
|
||||||
model_record_service=ModelRecordServiceSQL(db=db),
|
model_record_service=ModelRecordServiceSQL(db=db),
|
||||||
@ -118,6 +122,7 @@ class ApiDependencies:
|
|||||||
images=images,
|
images=images,
|
||||||
invocation_cache=invocation_cache,
|
invocation_cache=invocation_cache,
|
||||||
logger=logger,
|
logger=logger,
|
||||||
|
model_images=model_images_service,
|
||||||
model_manager=model_manager,
|
model_manager=model_manager,
|
||||||
download_queue=download_queue_service,
|
download_queue=download_queue_service,
|
||||||
names=names,
|
names=names,
|
||||||
|
@ -1,12 +1,16 @@
|
|||||||
# Copyright (c) 2023 Lincoln D. Stein
|
# Copyright (c) 2023 Lincoln D. Stein
|
||||||
"""FastAPI route for model configuration records."""
|
"""FastAPI route for model configuration records."""
|
||||||
|
|
||||||
|
import io
|
||||||
import pathlib
|
import pathlib
|
||||||
import shutil
|
import shutil
|
||||||
|
import traceback
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from fastapi import Body, Path, Query, Response
|
from fastapi import Body, Path, Query, Response, UploadFile
|
||||||
|
from fastapi.responses import FileResponse
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
|
from PIL import Image
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
from starlette.exceptions import HTTPException
|
from starlette.exceptions import HTTPException
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
@ -31,6 +35,9 @@ from ..dependencies import ApiDependencies
|
|||||||
|
|
||||||
model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"])
|
model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"])
|
||||||
|
|
||||||
|
# images are immutable; set a high max-age
|
||||||
|
IMAGE_MAX_AGE = 31536000
|
||||||
|
|
||||||
|
|
||||||
class ModelsList(BaseModel):
|
class ModelsList(BaseModel):
|
||||||
"""Return list of configs."""
|
"""Return list of configs."""
|
||||||
@ -72,7 +79,7 @@ example_model_input = {
|
|||||||
"description": "Model description",
|
"description": "Model description",
|
||||||
"vae": None,
|
"vae": None,
|
||||||
"variant": "normal",
|
"variant": "normal",
|
||||||
"image": "blob"
|
"image": "blob",
|
||||||
}
|
}
|
||||||
|
|
||||||
##############################################################################
|
##############################################################################
|
||||||
@ -267,6 +274,93 @@ async def update_model_record(
|
|||||||
return model_response
|
return model_response
|
||||||
|
|
||||||
|
|
||||||
|
@model_manager_router.get(
|
||||||
|
"/i/{key}/image",
|
||||||
|
operation_id="get_model_image",
|
||||||
|
responses={
|
||||||
|
200: {
|
||||||
|
"description": "The model image was fetched successfully",
|
||||||
|
},
|
||||||
|
400: {"description": "Bad request"},
|
||||||
|
404: {"description": "The model could not be found"},
|
||||||
|
},
|
||||||
|
status_code=200,
|
||||||
|
)
|
||||||
|
async def get_model_image(
|
||||||
|
key: str = Path(description="The name of model image file to get"),
|
||||||
|
) -> FileResponse:
|
||||||
|
"""Gets a full-resolution image file"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
path = ApiDependencies.invoker.services.model_images.get_path(key + ".png")
|
||||||
|
|
||||||
|
if not path:
|
||||||
|
raise HTTPException(status_code=404)
|
||||||
|
|
||||||
|
response = FileResponse(
|
||||||
|
path,
|
||||||
|
media_type="image/png",
|
||||||
|
filename=key + ".png",
|
||||||
|
content_disposition_type="inline",
|
||||||
|
)
|
||||||
|
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
||||||
|
return response
|
||||||
|
except Exception:
|
||||||
|
raise HTTPException(status_code=404)
|
||||||
|
|
||||||
|
|
||||||
|
# async def get_model_image(
|
||||||
|
# key: Annotated[str, Path(description="Unique key of model")],
|
||||||
|
# ) -> Optional[str]:
|
||||||
|
# model_images = ApiDependencies.invoker.services.model_images
|
||||||
|
# try:
|
||||||
|
# url = model_images.get_url(key)
|
||||||
|
|
||||||
|
# if not url:
|
||||||
|
# return None
|
||||||
|
|
||||||
|
# return url
|
||||||
|
# except Exception:
|
||||||
|
# raise HTTPException(status_code=404)
|
||||||
|
|
||||||
|
|
||||||
|
@model_manager_router.patch(
|
||||||
|
"/i/{key}/image",
|
||||||
|
operation_id="update_model_image",
|
||||||
|
responses={
|
||||||
|
200: {
|
||||||
|
"description": "The model image was updated successfully",
|
||||||
|
},
|
||||||
|
400: {"description": "Bad request"},
|
||||||
|
},
|
||||||
|
status_code=200,
|
||||||
|
)
|
||||||
|
async def update_model_image(
|
||||||
|
key: Annotated[str, Path(description="Unique key of model")],
|
||||||
|
image: UploadFile,
|
||||||
|
) -> None:
|
||||||
|
if not image.content_type or not image.content_type.startswith("image"):
|
||||||
|
raise HTTPException(status_code=415, detail="Not an image")
|
||||||
|
|
||||||
|
contents = await image.read()
|
||||||
|
try:
|
||||||
|
pil_image = Image.open(io.BytesIO(contents))
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
ApiDependencies.invoker.services.logger.error(traceback.format_exc())
|
||||||
|
raise HTTPException(status_code=415, detail="Failed to read image")
|
||||||
|
|
||||||
|
logger = ApiDependencies.invoker.services.logger
|
||||||
|
model_images = ApiDependencies.invoker.services.model_images
|
||||||
|
try:
|
||||||
|
model_images.save(pil_image, key)
|
||||||
|
logger.info(f"Updated image for model: {key}")
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(str(e))
|
||||||
|
raise HTTPException(status_code=409, detail=str(e))
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
@model_manager_router.delete(
|
@model_manager_router.delete(
|
||||||
"/i/{key}",
|
"/i/{key}",
|
||||||
operation_id="delete_model",
|
operation_id="delete_model",
|
||||||
|
@ -25,6 +25,7 @@ if TYPE_CHECKING:
|
|||||||
from .images.images_base import ImageServiceABC
|
from .images.images_base import ImageServiceABC
|
||||||
from .invocation_cache.invocation_cache_base import InvocationCacheBase
|
from .invocation_cache.invocation_cache_base import InvocationCacheBase
|
||||||
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
|
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
|
||||||
|
from .model_images.model_images_base import ModelImagesBase
|
||||||
from .model_manager.model_manager_base import ModelManagerServiceBase
|
from .model_manager.model_manager_base import ModelManagerServiceBase
|
||||||
from .names.names_base import NameServiceBase
|
from .names.names_base import NameServiceBase
|
||||||
from .session_processor.session_processor_base import SessionProcessorBase
|
from .session_processor.session_processor_base import SessionProcessorBase
|
||||||
@ -49,6 +50,7 @@ class InvocationServices:
|
|||||||
image_files: "ImageFileStorageBase",
|
image_files: "ImageFileStorageBase",
|
||||||
image_records: "ImageRecordStorageBase",
|
image_records: "ImageRecordStorageBase",
|
||||||
logger: "Logger",
|
logger: "Logger",
|
||||||
|
model_images: "ModelImagesBase",
|
||||||
model_manager: "ModelManagerServiceBase",
|
model_manager: "ModelManagerServiceBase",
|
||||||
download_queue: "DownloadQueueServiceBase",
|
download_queue: "DownloadQueueServiceBase",
|
||||||
performance_statistics: "InvocationStatsServiceBase",
|
performance_statistics: "InvocationStatsServiceBase",
|
||||||
@ -72,6 +74,7 @@ class InvocationServices:
|
|||||||
self.image_files = image_files
|
self.image_files = image_files
|
||||||
self.image_records = image_records
|
self.image_records = image_records
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
|
self.model_images = model_images
|
||||||
self.model_manager = model_manager
|
self.model_manager = model_manager
|
||||||
self.download_queue = download_queue
|
self.download_queue = download_queue
|
||||||
self.performance_statistics = performance_statistics
|
self.performance_statistics = performance_statistics
|
||||||
|
@ -3,30 +3,34 @@ from pathlib import Path
|
|||||||
|
|
||||||
from PIL.Image import Image as PILImageType
|
from PIL.Image import Image as PILImageType
|
||||||
|
|
||||||
|
|
||||||
class ModelImagesBase(ABC):
|
class ModelImagesBase(ABC):
|
||||||
"""Low-level service responsible for storing and retrieving image files."""
|
"""Low-level service responsible for storing and retrieving image files."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get(self, image_name: str) -> PILImageType:
|
def get(self, model_key: str) -> PILImageType:
|
||||||
"""Retrieves an image as PIL Image."""
|
"""Retrieves a model image as PIL Image."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_path(self, image_name: str) -> Path:
|
def get_path(self, model_key: str) -> Path:
|
||||||
"""Gets the internal path to an image."""
|
"""Gets the internal path to a model image."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_url(self, model_key: str) -> str:
|
||||||
|
"""Gets the url for a model image."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
image: PILImageType,
|
image: PILImageType,
|
||||||
image_name: str,
|
model_key: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Saves an image. Returns a tuple of the image name and created timestamp."""
|
"""Saves a model image. Returns a tuple of the image name and created timestamp."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def delete(self, image_name: str) -> None:
|
def delete(self, model_key: str) -> None:
|
||||||
"""Deletes an image."""
|
"""Deletes a model image."""
|
||||||
pass
|
pass
|
||||||
|
@ -2,19 +2,19 @@
|
|||||||
class ModelImageFileNotFoundException(Exception):
|
class ModelImageFileNotFoundException(Exception):
|
||||||
"""Raised when an image file is not found in storage."""
|
"""Raised when an image file is not found in storage."""
|
||||||
|
|
||||||
def __init__(self, message="Image file not found"):
|
def __init__(self, message="Model image file not found"):
|
||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
class ModelImageFileSaveException(Exception):
|
class ModelImageFileSaveException(Exception):
|
||||||
"""Raised when an image cannot be saved."""
|
"""Raised when an image cannot be saved."""
|
||||||
|
|
||||||
def __init__(self, message="Image file not saved"):
|
def __init__(self, message="Model image file not saved"):
|
||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
class ModelImageFileDeleteException(Exception):
|
class ModelImageFileDeleteException(Exception):
|
||||||
"""Raised when an image cannot be deleted."""
|
"""Raised when an image cannot be deleted."""
|
||||||
|
|
||||||
def __init__(self, message="Image file not deleted"):
|
def __init__(self, message="Model image file not deleted"):
|
||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
|
@ -10,25 +10,24 @@ from invokeai.app.services.invoker import Invoker
|
|||||||
from .model_images_base import ModelImagesBase
|
from .model_images_base import ModelImagesBase
|
||||||
from .model_images_common import ModelImageFileDeleteException, ModelImageFileNotFoundException, ModelImageFileSaveException
|
from .model_images_common import ModelImageFileDeleteException, ModelImageFileNotFoundException, ModelImageFileSaveException
|
||||||
|
|
||||||
|
class ModelImagesService(ModelImagesBase):
|
||||||
class DiskImageFileStorage(ModelImagesBase):
|
|
||||||
"""Stores images on disk"""
|
"""Stores images on disk"""
|
||||||
|
|
||||||
__output_folder: Path
|
__model_images_folder: Path
|
||||||
__invoker: Invoker
|
__invoker: Invoker
|
||||||
|
|
||||||
def __init__(self, output_folder: Union[str, Path]):
|
def __init__(self, model_images_folder: Union[str, Path]):
|
||||||
|
|
||||||
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
|
self.__model_images_folder: Path = model_images_folder if isinstance(model_images_folder, Path) else Path(model_images_folder)
|
||||||
# Validate required output folders at launch
|
# Validate required folders at launch
|
||||||
self.__validate_storage_folders()
|
self.__validate_storage_folders()
|
||||||
|
|
||||||
def start(self, invoker: Invoker) -> None:
|
def start(self, invoker: Invoker) -> None:
|
||||||
self.__invoker = invoker
|
self.__invoker = invoker
|
||||||
|
|
||||||
def get(self, image_name: str) -> PILImageType:
|
def get(self, model_key: str) -> PILImageType:
|
||||||
try:
|
try:
|
||||||
image_path = self.get_path(image_name)
|
image_path = self.get_path(model_key + '.png')
|
||||||
|
|
||||||
image = Image.open(image_path)
|
image = Image.open(image_path)
|
||||||
return image
|
return image
|
||||||
@ -38,17 +37,13 @@ class DiskImageFileStorage(ModelImagesBase):
|
|||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
image: PILImageType,
|
image: PILImageType,
|
||||||
image_name: str,
|
model_key: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
try:
|
try:
|
||||||
self.__validate_storage_folders()
|
self.__validate_storage_folders()
|
||||||
image_path = self.get_path(image_name)
|
image_path = self.get_path(model_key + '.png')
|
||||||
|
|
||||||
pnginfo = PngImagePlugin.PngInfo()
|
pnginfo = PngImagePlugin.PngInfo()
|
||||||
info_dict = {}
|
|
||||||
|
|
||||||
# When saving the image, the image object's info field is not populated. We need to set it
|
|
||||||
image.info = info_dict
|
|
||||||
image.save(
|
image.save(
|
||||||
image_path,
|
image_path,
|
||||||
"PNG",
|
"PNG",
|
||||||
@ -59,9 +54,17 @@ class DiskImageFileStorage(ModelImagesBase):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ModelImageFileSaveException from e
|
raise ModelImageFileSaveException from e
|
||||||
|
|
||||||
def delete(self, image_name: str) -> None:
|
def get_path(self, model_key: str) -> Path:
|
||||||
|
path = self.__model_images_folder / model_key
|
||||||
|
|
||||||
|
return path
|
||||||
|
|
||||||
|
def get_url(self, model_key: str) -> str:
|
||||||
|
return self.__invoker.services.urls.get_model_image_url(model_key)
|
||||||
|
|
||||||
|
def delete(self, model_key: str) -> None:
|
||||||
try:
|
try:
|
||||||
image_path = self.get_path(image_name)
|
image_path = self.get_path(model_key + '.png')
|
||||||
|
|
||||||
if image_path.exists():
|
if image_path.exists():
|
||||||
send2trash(image_path)
|
send2trash(image_path)
|
||||||
@ -69,14 +72,8 @@ class DiskImageFileStorage(ModelImagesBase):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ModelImageFileDeleteException from e
|
raise ModelImageFileDeleteException from e
|
||||||
|
|
||||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
|
||||||
def get_path(self, image_name: str) -> Path:
|
|
||||||
path = self.__output_folder / image_name
|
|
||||||
|
|
||||||
return path
|
|
||||||
|
|
||||||
def __validate_storage_folders(self) -> None:
|
def __validate_storage_folders(self) -> None:
|
||||||
"""Checks if the required output folders exist and create them if they don't"""
|
"""Checks if the required folders exist and create them if they don't"""
|
||||||
folders: list[Path] = [self.__output_folder]
|
folders: list[Path] = [self.__model_images_folder]
|
||||||
for folder in folders:
|
for folder in folders:
|
||||||
folder.mkdir(parents=True, exist_ok=True)
|
folder.mkdir(parents=True, exist_ok=True)
|
@ -8,3 +8,8 @@ class UrlServiceBase(ABC):
|
|||||||
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
|
def get_image_url(self, image_name: str, thumbnail: bool = False) -> str:
|
||||||
"""Gets the URL for an image or thumbnail."""
|
"""Gets the URL for an image or thumbnail."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_model_image_url(self, model_key: str) -> str:
|
||||||
|
"""Gets the URL for a model image"""
|
||||||
|
pass
|
||||||
|
@ -15,3 +15,6 @@ class LocalUrlService(UrlServiceBase):
|
|||||||
return f"{self._base_url}/images/i/{image_basename}/thumbnail"
|
return f"{self._base_url}/images/i/{image_basename}/thumbnail"
|
||||||
|
|
||||||
return f"{self._base_url}/images/i/{image_basename}/full"
|
return f"{self._base_url}/images/i/{image_basename}/full"
|
||||||
|
|
||||||
|
def get_model_image_url(self, model_key: str) -> str:
|
||||||
|
return f"{self._base_url}/model_images/{model_key}.png"
|
@ -161,6 +161,7 @@ class ModelConfigBase(BaseModel):
|
|||||||
default_settings: Optional[ModelDefaultSettings] = Field(
|
default_settings: Optional[ModelDefaultSettings] = Field(
|
||||||
description="Default settings for this model", default=None
|
description="Default settings for this model", default=None
|
||||||
)
|
)
|
||||||
|
image: Optional[str] = Field(description="Image to preview model", default=None)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||||
@ -374,6 +375,10 @@ AnyModelConfig = Annotated[
|
|||||||
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
|
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelImage(str, Enum):
|
||||||
|
path: str
|
||||||
|
|
||||||
|
|
||||||
class ModelConfigFactory(object):
|
class ModelConfigFactory(object):
|
||||||
"""Class for parsing config dicts into StableDiffusion Config obects."""
|
"""Class for parsing config dicts into StableDiffusion Config obects."""
|
||||||
|
|
||||||
|
@ -0,0 +1,73 @@
|
|||||||
|
import { Box, IconButton, Image } from '@invoke-ai/ui-library';
|
||||||
|
import { typedMemo } from 'common/util/typedMemo';
|
||||||
|
import { useCallback } from 'react';
|
||||||
|
import type { UseControllerProps } from 'react-hook-form';
|
||||||
|
import { useController } from 'react-hook-form';
|
||||||
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
|
import { Button } from '@invoke-ai/ui-library';
|
||||||
|
import { useDropzone } from 'react-dropzone';
|
||||||
|
import { PiArrowCounterClockwiseBold, PiUploadSimpleBold } from 'react-icons/pi';
|
||||||
|
|
||||||
|
const ModelImageUpload = (props: UseControllerProps<AnyModelConfig>) => {
|
||||||
|
const { field } = useController(props);
|
||||||
|
|
||||||
|
const onDropAccepted = useCallback(
|
||||||
|
(files: File[]) => {
|
||||||
|
const file = files[0];
|
||||||
|
|
||||||
|
if (!file) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
field.onChange(file);
|
||||||
|
},
|
||||||
|
[field]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleResetControlImage = useCallback(() => {
|
||||||
|
field.onChange(undefined);
|
||||||
|
}, [field]);
|
||||||
|
|
||||||
|
console.log('field', field);
|
||||||
|
|
||||||
|
const { getInputProps, getRootProps } = useDropzone({
|
||||||
|
accept: { 'image/png': ['.png'], 'image/jpeg': ['.jpg', '.jpeg', '.png'] },
|
||||||
|
onDropAccepted,
|
||||||
|
noDrag: true,
|
||||||
|
multiple: false,
|
||||||
|
});
|
||||||
|
|
||||||
|
if (field.value) {
|
||||||
|
return (
|
||||||
|
<Box>
|
||||||
|
<Image
|
||||||
|
src={field.value ? URL.createObjectURL(field.value) : 'http://localhost:5173/api/v2/models/i/6b8a6c0d68127ad8db5550f16d9a304b/image'}
|
||||||
|
objectFit="contain"
|
||||||
|
maxW="full"
|
||||||
|
maxH="200px"
|
||||||
|
borderRadius="base"
|
||||||
|
/>
|
||||||
|
<IconButton
|
||||||
|
onClick={handleResetControlImage}
|
||||||
|
aria-label="reset this image"
|
||||||
|
tooltip="reset this image"
|
||||||
|
icon={<PiArrowCounterClockwiseBold size={16} />}
|
||||||
|
size="sm"
|
||||||
|
variant="link"
|
||||||
|
/>
|
||||||
|
</Box>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<Button leftIcon={<PiUploadSimpleBold />} {...getRootProps()} pointerEvents="auto">
|
||||||
|
Upload Image
|
||||||
|
</Button>
|
||||||
|
<input {...getInputProps()} />
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default typedMemo(ModelImageUpload);
|
@ -1,18 +1,22 @@
|
|||||||
import { Box, IconButton, Image } from '@invoke-ai/ui-library';
|
import { Box, IconButton, Image } from '@invoke-ai/ui-library';
|
||||||
import { typedMemo } from 'common/util/typedMemo';
|
import { typedMemo } from 'common/util/typedMemo';
|
||||||
import { useCallback } from 'react';
|
import { useCallback, useEffect, useMemo } from 'react';
|
||||||
import type { UseControllerProps } from 'react-hook-form';
|
import type { UseControllerProps } from 'react-hook-form';
|
||||||
import { useController } from 'react-hook-form';
|
import { useController, useWatch } from 'react-hook-form';
|
||||||
import type { AnyModelConfig } from 'services/api/types';
|
import type { AnyModelConfig } from 'services/api/types';
|
||||||
|
|
||||||
import { Button } from '@invoke-ai/ui-library';
|
import { Button } from '@invoke-ai/ui-library';
|
||||||
import { useDropzone } from 'react-dropzone';
|
import { useDropzone } from 'react-dropzone';
|
||||||
import { PiArrowCounterClockwiseBold, PiUploadSimpleBold } from 'react-icons/pi';
|
import { PiArrowCounterClockwiseBold, PiUploadSimpleBold } from 'react-icons/pi';
|
||||||
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
|
import { useGetModelImageQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
const ModelImageUpload = (props: UseControllerProps<AnyModelConfig>) => {
|
const ModelImageUpload = (props: UseControllerProps<AnyModelConfig>) => {
|
||||||
const { field } = useController(props);
|
const { field } = useController(props);
|
||||||
|
|
||||||
|
const key = useWatch({ control: props.control, name: 'key' });
|
||||||
|
|
||||||
|
const { data } = useGetModelImageQuery(key);
|
||||||
|
|
||||||
const onDropAccepted = useCallback(
|
const onDropAccepted = useCallback(
|
||||||
(files: File[]) => {
|
(files: File[]) => {
|
||||||
const file = files[0];
|
const file = files[0];
|
||||||
@ -30,8 +34,6 @@ const ModelImageUpload = (props: UseControllerProps<AnyModelConfig>) => {
|
|||||||
field.onChange(undefined);
|
field.onChange(undefined);
|
||||||
}, [field]);
|
}, [field]);
|
||||||
|
|
||||||
console.log('field', field);
|
|
||||||
|
|
||||||
const { getInputProps, getRootProps } = useDropzone({
|
const { getInputProps, getRootProps } = useDropzone({
|
||||||
accept: { 'image/png': ['.png'], 'image/jpeg': ['.jpg', '.jpeg', '.png'] },
|
accept: { 'image/png': ['.png'], 'image/jpeg': ['.jpg', '.jpeg', '.png'] },
|
||||||
onDropAccepted,
|
onDropAccepted,
|
||||||
@ -39,11 +41,20 @@ const ModelImageUpload = (props: UseControllerProps<AnyModelConfig>) => {
|
|||||||
multiple: false,
|
multiple: false,
|
||||||
});
|
});
|
||||||
|
|
||||||
if (field.value) {
|
const image = useMemo(() => {
|
||||||
|
console.log(field.value, 'asdf' );
|
||||||
|
if (field.value) {
|
||||||
|
return URL.createObjectURL(field.value);
|
||||||
|
}
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}, [field.value, data]);
|
||||||
|
|
||||||
|
if (image) {
|
||||||
return (
|
return (
|
||||||
<Box>
|
<Box>
|
||||||
<Image
|
<Image
|
||||||
src={URL.createObjectURL(field.value)}
|
src={image}
|
||||||
objectFit="contain"
|
objectFit="contain"
|
||||||
maxW="full"
|
maxW="full"
|
||||||
maxH="200px"
|
maxH="200px"
|
||||||
@ -56,7 +67,6 @@ const ModelImageUpload = (props: UseControllerProps<AnyModelConfig>) => {
|
|||||||
icon={<PiArrowCounterClockwiseBold size={16} />}
|
icon={<PiArrowCounterClockwiseBold size={16} />}
|
||||||
size="sm"
|
size="sm"
|
||||||
variant="link"
|
variant="link"
|
||||||
// sx={sx}
|
|
||||||
/>
|
/>
|
||||||
</Box>
|
</Box>
|
||||||
);
|
);
|
||||||
@ -73,4 +83,3 @@ const ModelImageUpload = (props: UseControllerProps<AnyModelConfig>) => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
export default typedMemo(ModelImageUpload);
|
export default typedMemo(ModelImageUpload);
|
||||||
|
|
||||||
|
@ -20,7 +20,11 @@ import type { SubmitHandler } from 'react-hook-form';
|
|||||||
import { useForm } from 'react-hook-form';
|
import { useForm } from 'react-hook-form';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import type { UpdateModelArg } from 'services/api/endpoints/models';
|
import type { UpdateModelArg } from 'services/api/endpoints/models';
|
||||||
import { useGetModelConfigQuery, useUpdateModelMutation } from 'services/api/endpoints/models';
|
import {
|
||||||
|
useGetModelConfigQuery,
|
||||||
|
useUpdateModelImageMutation,
|
||||||
|
useUpdateModelsMutation,
|
||||||
|
} from 'services/api/endpoints/models';
|
||||||
|
|
||||||
import BaseModelSelect from './Fields/BaseModelSelect';
|
import BaseModelSelect from './Fields/BaseModelSelect';
|
||||||
import ModelImageUpload from './Fields/ModelImageUpload';
|
import ModelImageUpload from './Fields/ModelImageUpload';
|
||||||
@ -32,7 +36,8 @@ export const ModelEdit = () => {
|
|||||||
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
const selectedModelKey = useAppSelector((s) => s.modelmanagerV2.selectedModelKey);
|
||||||
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
const { data, isLoading } = useGetModelConfigQuery(selectedModelKey ?? skipToken);
|
||||||
|
|
||||||
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelMutation();
|
const [updateModel, { isLoading: isSubmitting }] = useUpdateModelsMutation();
|
||||||
|
const [updateModelImage] = useUpdateModelImageMutation();
|
||||||
|
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
@ -55,11 +60,15 @@ export const ModelEdit = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// remove image from body
|
||||||
|
const image = values.image;
|
||||||
|
if (values.image) {
|
||||||
|
delete values.image;
|
||||||
|
}
|
||||||
const responseBody: UpdateModelArg = {
|
const responseBody: UpdateModelArg = {
|
||||||
key: data.key,
|
key: data.key,
|
||||||
body: values,
|
body: values,
|
||||||
};
|
};
|
||||||
console.log(responseBody, 'responseBody')
|
|
||||||
|
|
||||||
updateModel(responseBody)
|
updateModel(responseBody)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
@ -86,6 +95,33 @@ export const ModelEdit = () => {
|
|||||||
)
|
)
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
if (image) {
|
||||||
|
updateModelImage({ key: data.key, image: image })
|
||||||
|
.unwrap()
|
||||||
|
.then((payload) => {
|
||||||
|
reset(payload, { keepDefaultValues: true });
|
||||||
|
dispatch(setSelectedModelMode('view'));
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: t('modelManager.modelUpdated'),
|
||||||
|
status: 'success',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
})
|
||||||
|
.catch((_) => {
|
||||||
|
reset();
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: t('modelManager.modelUpdateFailed'),
|
||||||
|
status: 'error',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
});
|
||||||
|
}
|
||||||
},
|
},
|
||||||
[dispatch, data?.key, reset, t, updateModel]
|
[dispatch, data?.key, reset, t, updateModel]
|
||||||
);
|
);
|
||||||
|
@ -23,7 +23,16 @@ export type UpdateModelArg = {
|
|||||||
body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json'];
|
body: paths['/api/v2/models/i/{key}']['patch']['requestBody']['content']['application/json'];
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type UpdateModelImageArg = {
|
||||||
|
key: paths['/api/v2/models/i/{key}/image']['patch']['parameters']['path']['key'];
|
||||||
|
image: paths['/api/v2/models/i/{key}/image']['patch']['formData']['content']['multipart/form-data'];
|
||||||
|
};
|
||||||
|
|
||||||
type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
|
type UpdateModelResponse = paths['/api/v2/models/i/{key}']['patch']['responses']['200']['content']['application/json'];
|
||||||
|
type UpdateModelImageResponse =
|
||||||
|
paths['/api/v2/models/i/{key}/image']['patch']['responses']['200']['content']['application/json'];
|
||||||
|
type GetModelImageResponse =
|
||||||
|
paths['/api/v2/models/i/{key}/image']['get']['responses']['200']['content']['application/json'];
|
||||||
|
|
||||||
type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
|
type GetModelConfigResponse = paths['/api/v2/models/i/{key}']['get']['responses']['200']['content']['application/json'];
|
||||||
|
|
||||||
@ -144,6 +153,21 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
},
|
},
|
||||||
invalidatesTags: ['Model'],
|
invalidatesTags: ['Model'],
|
||||||
}),
|
}),
|
||||||
|
getModelImage: build.query<GetModelImageResponse, string>({
|
||||||
|
query: (key) => buildModelsUrl(`i/${key}/image`),
|
||||||
|
}),
|
||||||
|
updateModelImage: build.mutation<UpdateModelImageResponse, UpdateModelImageArg>({
|
||||||
|
query: ({ key, image }) => {
|
||||||
|
const formData = new FormData();
|
||||||
|
formData.append('image', image);
|
||||||
|
return {
|
||||||
|
url: buildModelsUrl(`i/${key}/image`),
|
||||||
|
method: 'PATCH',
|
||||||
|
body: formData,
|
||||||
|
};
|
||||||
|
},
|
||||||
|
invalidatesTags: ['Model'],
|
||||||
|
}),
|
||||||
installModel: build.mutation<InstallModelResponse, InstallModelArg>({
|
installModel: build.mutation<InstallModelResponse, InstallModelArg>({
|
||||||
query: ({ source }) => {
|
query: ({ source }) => {
|
||||||
return {
|
return {
|
||||||
@ -330,7 +354,9 @@ export const {
|
|||||||
useGetTextualInversionModelsQuery,
|
useGetTextualInversionModelsQuery,
|
||||||
useGetVaeModelsQuery,
|
useGetVaeModelsQuery,
|
||||||
useDeleteModelsMutation,
|
useDeleteModelsMutation,
|
||||||
useUpdateModelMutation,
|
useUpdateModelsMutation,
|
||||||
|
useGetModelImageQuery,
|
||||||
|
useUpdateModelImageMutation,
|
||||||
useInstallModelMutation,
|
useInstallModelMutation,
|
||||||
useConvertModelMutation,
|
useConvertModelMutation,
|
||||||
useSyncModelsMutation,
|
useSyncModelsMutation,
|
||||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user