mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: Add Custom location support for model conversion
This commit is contained in:
parent
8c8eddcc60
commit
9769b48661
@ -2,7 +2,7 @@
|
||||
|
||||
|
||||
import pathlib
|
||||
from typing import Literal, List, Optional, Union
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from fastapi import Body, Path, Query, Response
|
||||
from fastapi.routing import APIRouter
|
||||
@ -10,11 +10,10 @@ from pydantic import BaseModel, parse_obj_as
|
||||
from starlette.exceptions import HTTPException
|
||||
|
||||
from invokeai.backend import BaseModelType, ModelType
|
||||
from invokeai.backend.model_management.models import (
|
||||
OPENAPI_MODEL_CONFIGS,
|
||||
SchedulerPredictionType,
|
||||
)
|
||||
from invokeai.backend.model_management import MergeInterpolationMethod
|
||||
from invokeai.backend.model_management.models import (OPENAPI_MODEL_CONFIGS,
|
||||
SchedulerPredictionType)
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||
@ -25,9 +24,11 @@ ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
|
||||
|
||||
|
||||
@models_router.get(
|
||||
"/",
|
||||
operation_id="list_models",
|
||||
@ -38,10 +39,13 @@ async def list_models(
|
||||
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
|
||||
) -> ModelsList:
|
||||
"""Gets a list of models"""
|
||||
models_raw = ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type)
|
||||
models_raw = ApiDependencies.invoker.services.model_manager.list_models(
|
||||
base_model,
|
||||
model_type)
|
||||
models = parse_obj_as(ModelsList, {"models": models_raw})
|
||||
return models
|
||||
|
||||
|
||||
@models_router.patch(
|
||||
"/{base_model}/{model_type}/{model_name}",
|
||||
operation_id="update_model",
|
||||
@ -79,6 +83,7 @@ async def update_model(
|
||||
|
||||
return model_response
|
||||
|
||||
|
||||
@models_router.post(
|
||||
"/import",
|
||||
operation_id="import_model",
|
||||
@ -93,7 +98,7 @@ async def update_model(
|
||||
)
|
||||
async def import_model(
|
||||
location: str = Body(description="A model path, repo_id or URL to import"),
|
||||
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
|
||||
prediction_type: Optional[Literal['v_prediction', 'epsilon', 'sample']] =
|
||||
Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
|
||||
) -> ImportModelResponse:
|
||||
""" Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically """
|
||||
@ -128,6 +133,7 @@ async def import_model(
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
|
||||
@models_router.post(
|
||||
"/add",
|
||||
operation_id="add_model",
|
||||
@ -168,6 +174,7 @@ async def add_model(
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
|
||||
@models_router.post(
|
||||
"/rename/{base_model}/{model_type}/{model_name}",
|
||||
operation_id="rename_model",
|
||||
@ -213,6 +220,7 @@ async def rename_model(
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
|
||||
@models_router.delete(
|
||||
"/{base_model}/{model_type}/{model_name}",
|
||||
operation_id="del_model",
|
||||
@ -232,15 +240,15 @@ async def delete_model(
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.model_manager.del_model(model_name,
|
||||
base_model = base_model,
|
||||
model_type = model_type
|
||||
)
|
||||
ApiDependencies.invoker.services.model_manager.del_model(
|
||||
model_name, base_model=base_model, model_type=model_type)
|
||||
logger.info(f"Deleted model: {model_name}")
|
||||
return Response(status_code=204)
|
||||
except KeyError:
|
||||
logger.error(f"Model not found: {model_name}")
|
||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Model '{model_name}' not found")
|
||||
|
||||
|
||||
@models_router.put(
|
||||
"/convert/{base_model}/{model_type}/{model_name}",
|
||||
@ -257,28 +265,28 @@ async def convert_model(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="model name"),
|
||||
convert_dest_directory: Optional[str] = Query(default=None, description="Save the converted model to the designated directory"),
|
||||
convert_dest_directory: Optional[str] = Body(description="Save the converted model to the designated directory", default=None, embed=True)
|
||||
) -> ConvertModelResponse:
|
||||
"""Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
logger.info(f"Converting model: {model_name}")
|
||||
dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
|
||||
ApiDependencies.invoker.services.model_manager.convert_model(model_name,
|
||||
base_model = base_model,
|
||||
model_type = model_type,
|
||||
convert_dest_directory = dest,
|
||||
)
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name,
|
||||
base_model = base_model,
|
||||
model_type = model_type)
|
||||
dest = pathlib.Path(
|
||||
convert_dest_directory) if convert_dest_directory else None
|
||||
ApiDependencies.invoker.services.model_manager.convert_model(
|
||||
model_name, base_model=base_model, model_type=model_type,
|
||||
convert_dest_directory=dest,)
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name, base_model=base_model, model_type=model_type)
|
||||
response = parse_obj_as(ConvertModelResponse, model_raw)
|
||||
except KeyError:
|
||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Model '{model_name}' not found")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return response
|
||||
|
||||
|
||||
@models_router.get(
|
||||
"/search",
|
||||
operation_id="search_for_models",
|
||||
@ -293,8 +301,12 @@ async def search_for_models(
|
||||
search_path: pathlib.Path = Query(description="Directory path to search for models")
|
||||
) -> List[pathlib.Path]:
|
||||
if not search_path.is_dir():
|
||||
raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory")
|
||||
return ApiDependencies.invoker.services.model_manager.search_for_models([search_path])
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"The search path '{search_path}' does not exist or is not directory")
|
||||
return ApiDependencies.invoker.services.model_manager.search_for_models([
|
||||
search_path])
|
||||
|
||||
|
||||
@models_router.get(
|
||||
"/ckpt_confs",
|
||||
@ -326,6 +338,7 @@ async def sync_to_config(
|
||||
in-memory data structures with disk data structures."""
|
||||
return ApiDependencies.invoker.services.model_manager.sync_to_config()
|
||||
|
||||
|
||||
@models_router.put(
|
||||
"/merge/{base_model}",
|
||||
operation_id="merge_models",
|
||||
@ -349,23 +362,21 @@ async def merge_models(
|
||||
"""Convert a checkpoint model into a diffusers model"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
logger.info(f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
||||
result = ApiDependencies.invoker.services.model_manager.merge_models(model_names,
|
||||
base_model,
|
||||
logger.info(
|
||||
f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||
dest = pathlib.Path(
|
||||
merge_dest_directory) if merge_dest_directory else None
|
||||
result = ApiDependencies.invoker.services.model_manager.merge_models(
|
||||
model_names, base_model,
|
||||
merged_model_name=merged_model_name or "+".join(model_names),
|
||||
alpha=alpha,
|
||||
interp=interp,
|
||||
force=force,
|
||||
merge_dest_directory = dest
|
||||
)
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name,
|
||||
base_model = base_model,
|
||||
model_type = ModelType.Main,
|
||||
)
|
||||
alpha=alpha, interp=interp, force=force, merge_dest_directory=dest)
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
result.name, base_model=base_model, model_type=ModelType.Main, )
|
||||
response = parse_obj_as(ConvertModelResponse, model_raw)
|
||||
except KeyError:
|
||||
raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"One or more of the models '{model_names}' not found")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return response
|
||||
|
@ -415,6 +415,8 @@
|
||||
"convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 2GB-7GB in size.",
|
||||
"convertToDiffusersHelpText6": "Do you wish to convert this model?",
|
||||
"convertToDiffusersSaveLocation": "Save Location",
|
||||
"noCustomLocationProvided": "No Custom Location Provided",
|
||||
"convertingModelBegin": "Converting Model. Please wait.",
|
||||
"v1": "v1",
|
||||
"v2_base": "v2 (512px)",
|
||||
"v2_768": "v2 (768px)",
|
||||
|
@ -1,9 +1,18 @@
|
||||
import { Flex, ListItem, Text, UnorderedList } from '@chakra-ui/react';
|
||||
// import { convertToDiffusers } from 'app/socketio/actions';
|
||||
import {
|
||||
Flex,
|
||||
ListItem,
|
||||
Radio,
|
||||
RadioGroup,
|
||||
Text,
|
||||
Tooltip,
|
||||
UnorderedList,
|
||||
} from '@chakra-ui/react';
|
||||
import { makeToast } from 'app/components/Toaster';
|
||||
// import { convertToDiffusers } from 'app/socketio/actions';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIAlertDialog from 'common/components/IAIAlertDialog';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import IAIInput from 'common/components/IAIInput';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { useEffect, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -15,6 +24,8 @@ interface ModelConvertProps {
|
||||
model: CheckpointModelConfig;
|
||||
}
|
||||
|
||||
type SaveLocation = 'InvokeAIRoot' | 'Custom';
|
||||
|
||||
export default function ModelConvert(props: ModelConvertProps) {
|
||||
const { model } = props;
|
||||
|
||||
@ -23,22 +34,51 @@ export default function ModelConvert(props: ModelConvertProps) {
|
||||
|
||||
const [convertModel, { isLoading }] = useConvertMainModelsMutation();
|
||||
|
||||
const [saveLocation, setSaveLocation] = useState<string>('same');
|
||||
const [saveLocation, setSaveLocation] =
|
||||
useState<SaveLocation>('InvokeAIRoot');
|
||||
const [customSaveLocation, setCustomSaveLocation] = useState<string>('');
|
||||
|
||||
useEffect(() => {
|
||||
setSaveLocation('same');
|
||||
setSaveLocation('InvokeAIRoot');
|
||||
}, [model]);
|
||||
|
||||
const modelConvertCancelHandler = () => {
|
||||
setSaveLocation('same');
|
||||
setSaveLocation('InvokeAIRoot');
|
||||
};
|
||||
|
||||
const modelConvertHandler = () => {
|
||||
const responseBody = {
|
||||
base_model: model.base_model,
|
||||
model_name: model.model_name,
|
||||
body: {
|
||||
convert_dest_directory:
|
||||
saveLocation === 'Custom' ? customSaveLocation : undefined,
|
||||
},
|
||||
};
|
||||
|
||||
if (saveLocation === 'Custom' && customSaveLocation === '') {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('modelManager.noCustomLocationProvided'),
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: `${t('modelManager.convertingModelBegin')}: ${
|
||||
model.model_name
|
||||
}`,
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
|
||||
convertModel(responseBody)
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
@ -94,35 +134,30 @@ export default function ModelConvert(props: ModelConvertProps) {
|
||||
<Text>{t('modelManager.convertToDiffusersHelpText6')}</Text>
|
||||
</Flex>
|
||||
|
||||
{/* <Flex flexDir="column" gap={4}>
|
||||
<Flex flexDir="column" gap={2}>
|
||||
<Flex marginTop={4} flexDir="column" gap={2}>
|
||||
<Text fontWeight="600">
|
||||
{t('modelManager.convertToDiffusersSaveLocation')}
|
||||
</Text>
|
||||
<RadioGroup value={saveLocation} onChange={(v) => setSaveLocation(v)}>
|
||||
<RadioGroup
|
||||
value={saveLocation}
|
||||
onChange={(v) => setSaveLocation(v as SaveLocation)}
|
||||
>
|
||||
<Flex gap={4}>
|
||||
<Radio value="same">
|
||||
<Tooltip label="Save converted model in the same folder">
|
||||
{t('modelManager.sameFolder')}
|
||||
</Tooltip>
|
||||
</Radio>
|
||||
|
||||
<Radio value="root">
|
||||
<Radio value="InvokeAIRoot">
|
||||
<Tooltip label="Save converted model in the InvokeAI root folder">
|
||||
{t('modelManager.invokeRoot')}
|
||||
</Tooltip>
|
||||
</Radio>
|
||||
|
||||
<Radio value="custom">
|
||||
<Radio value="Custom">
|
||||
<Tooltip label="Save converted model in a custom folder">
|
||||
{t('modelManager.custom')}
|
||||
</Tooltip>
|
||||
</Radio>
|
||||
</Flex>
|
||||
</RadioGroup>
|
||||
</Flex> */}
|
||||
|
||||
{/* {saveLocation === 'custom' && (
|
||||
</Flex>
|
||||
{saveLocation === 'Custom' && (
|
||||
<Flex flexDirection="column" rowGap={2}>
|
||||
<Text fontWeight="500" fontSize="sm" variant="subtext">
|
||||
{t('modelManager.customSaveLocation')}
|
||||
@ -130,13 +165,13 @@ export default function ModelConvert(props: ModelConvertProps) {
|
||||
<IAIInput
|
||||
value={customSaveLocation}
|
||||
onChange={(e) => {
|
||||
if (e.target.value !== '')
|
||||
setCustomSaveLocation(e.target.value);
|
||||
}}
|
||||
width="full"
|
||||
/>
|
||||
</Flex>
|
||||
)} */}
|
||||
)}
|
||||
</Flex>
|
||||
</IAIAlertDialog>
|
||||
);
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ import {
|
||||
BaseModelType,
|
||||
CheckpointModelConfig,
|
||||
ControlNetModelConfig,
|
||||
ConvertModelConfig,
|
||||
DiffusersModelConfig,
|
||||
LoRAModelConfig,
|
||||
MainModelConfig,
|
||||
@ -62,6 +63,7 @@ type DeleteMainModelResponse = void;
|
||||
type ConvertMainModelArg = {
|
||||
base_model: BaseModelType;
|
||||
model_name: string;
|
||||
body: ConvertModelConfig;
|
||||
};
|
||||
|
||||
type ConvertMainModelResponse =
|
||||
@ -176,10 +178,11 @@ export const modelsApi = api.injectEndpoints({
|
||||
ConvertMainModelResponse,
|
||||
ConvertMainModelArg
|
||||
>({
|
||||
query: ({ base_model, model_name }) => {
|
||||
query: ({ base_model, model_name, body }) => {
|
||||
return {
|
||||
url: `models/convert/${base_model}/main/${model_name}`,
|
||||
method: 'PUT',
|
||||
body: body,
|
||||
};
|
||||
},
|
||||
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
|
||||
|
@ -378,6 +378,14 @@ export type components = {
|
||||
*/
|
||||
image_count: number;
|
||||
};
|
||||
/** Body_convert_model */
|
||||
Body_convert_model: {
|
||||
/**
|
||||
* Convert Dest Directory
|
||||
* @description Save the converted model to the designated directory
|
||||
*/
|
||||
convert_dest_directory?: string;
|
||||
};
|
||||
/** Body_create_board_image */
|
||||
Body_create_board_image: {
|
||||
/**
|
||||
@ -5200,10 +5208,6 @@ export type operations = {
|
||||
*/
|
||||
convert_model: {
|
||||
parameters: {
|
||||
query?: {
|
||||
/** @description Save the converted model to the designated directory */
|
||||
convert_dest_directory?: string;
|
||||
};
|
||||
path: {
|
||||
/** @description Base model */
|
||||
base_model: components["schemas"]["BaseModelType"];
|
||||
@ -5213,6 +5217,11 @@ export type operations = {
|
||||
model_name: string;
|
||||
};
|
||||
};
|
||||
requestBody?: {
|
||||
content: {
|
||||
"application/json": components["schemas"]["Body_convert_model"];
|
||||
};
|
||||
};
|
||||
responses: {
|
||||
/** @description Model converted successfully */
|
||||
200: {
|
||||
|
@ -55,7 +55,9 @@ export type AnyModelConfig =
|
||||
| ControlNetModelConfig
|
||||
| TextualInversionModelConfig
|
||||
| MainModelConfig;
|
||||
|
||||
export type MergeModelConfig = components['schemas']['Body_merge_models'];
|
||||
export type ConvertModelConfig = components['schemas']['Body_convert_model'];
|
||||
|
||||
// Graphs
|
||||
export type Graph = components['schemas']['Graph'];
|
||||
|
Loading…
Reference in New Issue
Block a user