feat: Add Custom location support for model conversion

This commit is contained in:
blessedcoolant 2023-07-15 19:17:16 +12:00
parent 8c8eddcc60
commit 9769b48661
6 changed files with 204 additions and 142 deletions

View File

@ -2,7 +2,7 @@
import pathlib
from typing import Literal, List, Optional, Union
from typing import List, Literal, Optional, Union
from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter
@ -10,11 +10,10 @@ from pydantic import BaseModel, parse_obj_as
from starlette.exceptions import HTTPException
from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management.models import (
OPENAPI_MODEL_CONFIGS,
SchedulerPredictionType,
)
from invokeai.backend.model_management import MergeInterpolationMethod
from invokeai.backend.model_management.models import (OPENAPI_MODEL_CONFIGS,
SchedulerPredictionType)
from ..dependencies import ApiDependencies
models_router = APIRouter(prefix="/v1/models", tags=["models"])
@ -25,32 +24,37 @@ ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
class ModelsList(BaseModel):
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
@models_router.get(
"/",
operation_id="list_models",
responses={200: {"model": ModelsList }},
responses={200: {"model": ModelsList}},
)
async def list_models(
base_model: Optional[BaseModelType] = Query(default=None, description="Base model"),
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
) -> ModelsList:
"""Gets a list of models"""
models_raw = ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type)
models = parse_obj_as(ModelsList, { "models": models_raw })
models_raw = ApiDependencies.invoker.services.model_manager.list_models(
base_model,
model_type)
models = parse_obj_as(ModelsList, {"models": models_raw})
return models
@models_router.patch(
"/{base_model}/{model_type}/{model_name}",
operation_id="update_model",
responses={200: {"description" : "The model was updated successfully"},
404: {"description" : "The model could not be found"},
400: {"description" : "Bad request"}
responses={200: {"description": "The model was updated successfully"},
404: {"description": "The model could not be found"},
400: {"description": "Bad request"}
},
status_code = 200,
response_model = UpdateModelResponse,
status_code=200,
response_model=UpdateModelResponse,
)
async def update_model(
base_model: BaseModelType = Path(description="Base model"),
@ -79,33 +83,34 @@ async def update_model(
return model_response
@models_router.post(
"/import",
operation_id="import_model",
responses= {
201: {"description" : "The model imported successfully"},
404: {"description" : "The model could not be found"},
424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"},
409: {"description" : "There is already a model corresponding to this path or repo_id"},
responses={
201: {"description": "The model imported successfully"},
404: {"description": "The model could not be found"},
424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
409: {"description": "There is already a model corresponding to this path or repo_id"},
},
status_code=201,
response_model=ImportModelResponse
)
async def import_model(
location: str = Body(description="A model path, repo_id or URL to import"),
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
prediction_type: Optional[Literal['v_prediction', 'epsilon', 'sample']] =
Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
) -> ImportModelResponse:
""" Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically """
items_to_import = {location}
prediction_types = { x.value: x for x in SchedulerPredictionType }
prediction_types = {x.value: x for x in SchedulerPredictionType}
logger = ApiDependencies.invoker.services.logger
try:
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
items_to_import = items_to_import,
prediction_type_helper = lambda x: prediction_types.get(prediction_type)
items_to_import=items_to_import,
prediction_type_helper=lambda x: prediction_types.get(prediction_type)
)
info = installed_models.get(location)
@ -128,14 +133,15 @@ async def import_model(
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
@models_router.post(
"/add",
operation_id="add_model",
responses= {
201: {"description" : "The model added successfully"},
404: {"description" : "The model could not be found"},
424: {"description" : "The model appeared to add successfully, but could not be found in the model manager"},
409: {"description" : "There is already a model corresponding to this path or repo_id"},
responses={
201: {"description": "The model added successfully"},
404: {"description": "The model could not be found"},
424: {"description": "The model appeared to add successfully, but could not be found in the model manager"},
409: {"description": "There is already a model corresponding to this path or repo_id"},
},
status_code=201,
response_model=ImportModelResponse
@ -152,7 +158,7 @@ async def add_model(
info.model_name,
info.base_model,
info.model_type,
model_attributes = info.dict()
model_attributes=info.dict()
)
logger.info(f'Successfully added {info.model_name}')
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
@ -168,13 +174,14 @@ async def add_model(
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
@models_router.post(
"/rename/{base_model}/{model_type}/{model_name}",
operation_id="rename_model",
responses= {
201: {"description" : "The model was renamed successfully"},
404: {"description" : "The model could not be found"},
409: {"description" : "There is already a model corresponding to the new name"},
responses={
201: {"description": "The model was renamed successfully"},
404: {"description": "The model could not be found"},
409: {"description": "There is already a model corresponding to the new name"},
},
status_code=201,
response_model=ImportModelResponse
@ -192,11 +199,11 @@ async def rename_model(
try:
result = ApiDependencies.invoker.services.model_manager.rename_model(
base_model = base_model,
model_type = model_type,
model_name = model_name,
new_name = new_name,
new_base = new_base,
base_model=base_model,
model_type=model_type,
model_name=model_name,
new_name=new_name,
new_base=new_base,
)
logger.debug(result)
logger.info(f'Successfully renamed {model_name}=>{new_name}')
@ -213,15 +220,16 @@ async def rename_model(
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
@models_router.delete(
"/{base_model}/{model_type}/{model_name}",
operation_id="del_model",
responses={
204: { "description": "Model deleted successfully" },
404: { "description": "Model not found" }
204: {"description": "Model deleted successfully"},
404: {"description": "Model not found"}
},
status_code = 204,
response_model = None,
status_code=204,
response_model=None,
)
async def delete_model(
base_model: BaseModelType = Path(description="Base model"),
@ -232,81 +240,85 @@ async def delete_model(
logger = ApiDependencies.invoker.services.logger
try:
ApiDependencies.invoker.services.model_manager.del_model(model_name,
base_model = base_model,
model_type = model_type
)
ApiDependencies.invoker.services.model_manager.del_model(
model_name, base_model=base_model, model_type=model_type)
logger.info(f"Deleted model: {model_name}")
return Response(status_code=204)
except KeyError:
logger.error(f"Model not found: {model_name}")
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
raise HTTPException(
status_code=404, detail=f"Model '{model_name}' not found")
@models_router.put(
"/convert/{base_model}/{model_type}/{model_name}",
operation_id="convert_model",
responses={
200: { "description": "Model converted successfully" },
400: {"description" : "Bad request" },
404: { "description": "Model not found" },
200: {"description": "Model converted successfully"},
400: {"description": "Bad request"},
404: {"description": "Model not found"},
},
status_code = 200,
response_model = ConvertModelResponse,
status_code=200,
response_model=ConvertModelResponse,
)
async def convert_model(
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="model name"),
convert_dest_directory: Optional[str] = Query(default=None, description="Save the converted model to the designated directory"),
convert_dest_directory: Optional[str] = Body(description="Save the converted model to the designated directory", default=None, embed=True)
) -> ConvertModelResponse:
"""Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
logger = ApiDependencies.invoker.services.logger
try:
logger.info(f"Converting model: {model_name}")
dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
ApiDependencies.invoker.services.model_manager.convert_model(model_name,
base_model = base_model,
model_type = model_type,
convert_dest_directory = dest,
)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name,
base_model = base_model,
model_type = model_type)
dest = pathlib.Path(
convert_dest_directory) if convert_dest_directory else None
ApiDependencies.invoker.services.model_manager.convert_model(
model_name, base_model=base_model, model_type=model_type,
convert_dest_directory=dest,)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name, base_model=base_model, model_type=model_type)
response = parse_obj_as(ConvertModelResponse, model_raw)
except KeyError:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
raise HTTPException(
status_code=404, detail=f"Model '{model_name}' not found")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return response
@models_router.get(
"/search",
operation_id="search_for_models",
responses={
200: { "description": "Directory searched successfully" },
404: { "description": "Invalid directory path" },
200: {"description": "Directory searched successfully"},
404: {"description": "Invalid directory path"},
},
status_code = 200,
response_model = List[pathlib.Path]
status_code=200,
response_model=List[pathlib.Path]
)
async def search_for_models(
search_path: pathlib.Path = Query(description="Directory path to search for models")
)->List[pathlib.Path]:
) -> List[pathlib.Path]:
if not search_path.is_dir():
raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory")
return ApiDependencies.invoker.services.model_manager.search_for_models([search_path])
raise HTTPException(
status_code=404,
detail=f"The search path '{search_path}' does not exist or is not directory")
return ApiDependencies.invoker.services.model_manager.search_for_models([
search_path])
@models_router.get(
"/ckpt_confs",
operation_id="list_ckpt_configs",
responses={
200: { "description" : "paths retrieved successfully" },
200: {"description": "paths retrieved successfully"},
},
status_code = 200,
response_model = List[pathlib.Path]
status_code=200,
response_model=List[pathlib.Path]
)
async def list_ckpt_configs(
)->List[pathlib.Path]:
) -> List[pathlib.Path]:
"""Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT."""
return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs()
@ -315,27 +327,28 @@ async def list_ckpt_configs(
"/sync",
operation_id="sync_to_config",
responses={
201: { "description": "synchronization successful" },
201: {"description": "synchronization successful"},
},
status_code = 201,
response_model = None
status_code=201,
response_model=None
)
async def sync_to_config(
)->None:
) -> None:
"""Call after making changes to models.yaml, autoimport directories or models directory to synchronize
in-memory data structures with disk data structures."""
return ApiDependencies.invoker.services.model_manager.sync_to_config()
@models_router.put(
"/merge/{base_model}",
operation_id="merge_models",
responses={
200: { "description": "Model converted successfully" },
400: { "description": "Incompatible models" },
404: { "description": "One or more models not found" },
200: {"description": "Model converted successfully"},
400: {"description": "Incompatible models"},
404: {"description": "One or more models not found"},
},
status_code = 200,
response_model = MergeModelResponse,
status_code=200,
response_model=MergeModelResponse,
)
async def merge_models(
base_model: BaseModelType = Path(description="Base model"),
@ -349,23 +362,21 @@ async def merge_models(
"""Convert a checkpoint model into a diffusers model"""
logger = ApiDependencies.invoker.services.logger
try:
logger.info(f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
result = ApiDependencies.invoker.services.model_manager.merge_models(model_names,
base_model,
logger.info(
f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
dest = pathlib.Path(
merge_dest_directory) if merge_dest_directory else None
result = ApiDependencies.invoker.services.model_manager.merge_models(
model_names, base_model,
merged_model_name=merged_model_name or "+".join(model_names),
alpha=alpha,
interp=interp,
force=force,
merge_dest_directory = dest
)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name,
base_model = base_model,
model_type = ModelType.Main,
)
alpha=alpha, interp=interp, force=force, merge_dest_directory=dest)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
result.name, base_model=base_model, model_type=ModelType.Main, )
response = parse_obj_as(ConvertModelResponse, model_raw)
except KeyError:
raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
raise HTTPException(
status_code=404,
detail=f"One or more of the models '{model_names}' not found")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return response

View File

@ -415,6 +415,8 @@
"convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 2GB-7GB in size.",
"convertToDiffusersHelpText6": "Do you wish to convert this model?",
"convertToDiffusersSaveLocation": "Save Location",
"noCustomLocationProvided": "No Custom Location Provided",
"convertingModelBegin": "Converting Model. Please wait.",
"v1": "v1",
"v2_base": "v2 (512px)",
"v2_768": "v2 (768px)",

View File

@ -1,9 +1,18 @@
import { Flex, ListItem, Text, UnorderedList } from '@chakra-ui/react';
// import { convertToDiffusers } from 'app/socketio/actions';
import {
Flex,
ListItem,
Radio,
RadioGroup,
Text,
Tooltip,
UnorderedList,
} from '@chakra-ui/react';
import { makeToast } from 'app/components/Toaster';
// import { convertToDiffusers } from 'app/socketio/actions';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIAlertDialog from 'common/components/IAIAlertDialog';
import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import { addToast } from 'features/system/store/systemSlice';
import { useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
@ -15,6 +24,8 @@ interface ModelConvertProps {
model: CheckpointModelConfig;
}
type SaveLocation = 'InvokeAIRoot' | 'Custom';
export default function ModelConvert(props: ModelConvertProps) {
const { model } = props;
@ -23,22 +34,51 @@ export default function ModelConvert(props: ModelConvertProps) {
const [convertModel, { isLoading }] = useConvertMainModelsMutation();
const [saveLocation, setSaveLocation] = useState<string>('same');
const [saveLocation, setSaveLocation] =
useState<SaveLocation>('InvokeAIRoot');
const [customSaveLocation, setCustomSaveLocation] = useState<string>('');
useEffect(() => {
setSaveLocation('same');
setSaveLocation('InvokeAIRoot');
}, [model]);
const modelConvertCancelHandler = () => {
setSaveLocation('same');
setSaveLocation('InvokeAIRoot');
};
const modelConvertHandler = () => {
const responseBody = {
base_model: model.base_model,
model_name: model.model_name,
body: {
convert_dest_directory:
saveLocation === 'Custom' ? customSaveLocation : undefined,
},
};
if (saveLocation === 'Custom' && customSaveLocation === '') {
dispatch(
addToast(
makeToast({
title: t('modelManager.noCustomLocationProvided'),
status: 'error',
})
)
);
return;
}
dispatch(
addToast(
makeToast({
title: `${t('modelManager.convertingModelBegin')}: ${
model.model_name
}`,
status: 'success',
})
)
);
convertModel(responseBody)
.unwrap()
.then((_) => {
@ -94,35 +134,30 @@ export default function ModelConvert(props: ModelConvertProps) {
<Text>{t('modelManager.convertToDiffusersHelpText6')}</Text>
</Flex>
{/* <Flex flexDir="column" gap={4}>
<Flex flexDir="column" gap={2}>
<Flex marginTop={4} flexDir="column" gap={2}>
<Text fontWeight="600">
{t('modelManager.convertToDiffusersSaveLocation')}
</Text>
<RadioGroup value={saveLocation} onChange={(v) => setSaveLocation(v)}>
<RadioGroup
value={saveLocation}
onChange={(v) => setSaveLocation(v as SaveLocation)}
>
<Flex gap={4}>
<Radio value="same">
<Tooltip label="Save converted model in the same folder">
{t('modelManager.sameFolder')}
</Tooltip>
</Radio>
<Radio value="root">
<Radio value="InvokeAIRoot">
<Tooltip label="Save converted model in the InvokeAI root folder">
{t('modelManager.invokeRoot')}
</Tooltip>
</Radio>
<Radio value="custom">
<Radio value="Custom">
<Tooltip label="Save converted model in a custom folder">
{t('modelManager.custom')}
</Tooltip>
</Radio>
</Flex>
</RadioGroup>
</Flex> */}
{/* {saveLocation === 'custom' && (
</Flex>
{saveLocation === 'Custom' && (
<Flex flexDirection="column" rowGap={2}>
<Text fontWeight="500" fontSize="sm" variant="subtext">
{t('modelManager.customSaveLocation')}
@ -130,13 +165,13 @@ export default function ModelConvert(props: ModelConvertProps) {
<IAIInput
value={customSaveLocation}
onChange={(e) => {
if (e.target.value !== '')
setCustomSaveLocation(e.target.value);
}}
width="full"
/>
</Flex>
)} */}
)}
</Flex>
</IAIAlertDialog>
);
}

View File

@ -5,6 +5,7 @@ import {
BaseModelType,
CheckpointModelConfig,
ControlNetModelConfig,
ConvertModelConfig,
DiffusersModelConfig,
LoRAModelConfig,
MainModelConfig,
@ -62,6 +63,7 @@ type DeleteMainModelResponse = void;
type ConvertMainModelArg = {
base_model: BaseModelType;
model_name: string;
body: ConvertModelConfig;
};
type ConvertMainModelResponse =
@ -176,10 +178,11 @@ export const modelsApi = api.injectEndpoints({
ConvertMainModelResponse,
ConvertMainModelArg
>({
query: ({ base_model, model_name }) => {
query: ({ base_model, model_name, body }) => {
return {
url: `models/convert/${base_model}/main/${model_name}`,
method: 'PUT',
body: body,
};
},
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],

View File

@ -378,6 +378,14 @@ export type components = {
*/
image_count: number;
};
/** Body_convert_model */
Body_convert_model: {
/**
* Convert Dest Directory
* @description Save the converted model to the designated directory
*/
convert_dest_directory?: string;
};
/** Body_create_board_image */
Body_create_board_image: {
/**
@ -5200,10 +5208,6 @@ export type operations = {
*/
convert_model: {
parameters: {
query?: {
/** @description Save the converted model to the designated directory */
convert_dest_directory?: string;
};
path: {
/** @description Base model */
base_model: components["schemas"]["BaseModelType"];
@ -5213,6 +5217,11 @@ export type operations = {
model_name: string;
};
};
requestBody?: {
content: {
"application/json": components["schemas"]["Body_convert_model"];
};
};
responses: {
/** @description Model converted successfully */
200: {

View File

@ -55,7 +55,9 @@ export type AnyModelConfig =
| ControlNetModelConfig
| TextualInversionModelConfig
| MainModelConfig;
export type MergeModelConfig = components['schemas']['Body_merge_models'];
export type ConvertModelConfig = components['schemas']['Body_convert_model'];
// Graphs
export type Graph = components['schemas']['Graph'];