feat: Add Custom location support for model conversion

This commit is contained in:
blessedcoolant 2023-07-15 19:17:16 +12:00
parent 8c8eddcc60
commit 9769b48661
6 changed files with 204 additions and 142 deletions

View File

@ -2,7 +2,7 @@
import pathlib import pathlib
from typing import Literal, List, Optional, Union from typing import List, Literal, Optional, Union
from fastapi import Body, Path, Query, Response from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
@ -10,11 +10,10 @@ from pydantic import BaseModel, parse_obj_as
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from invokeai.backend import BaseModelType, ModelType from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management.models import (
OPENAPI_MODEL_CONFIGS,
SchedulerPredictionType,
)
from invokeai.backend.model_management import MergeInterpolationMethod from invokeai.backend.model_management import MergeInterpolationMethod
from invokeai.backend.model_management.models import (OPENAPI_MODEL_CONFIGS,
SchedulerPredictionType)
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
models_router = APIRouter(prefix="/v1/models", tags=["models"]) models_router = APIRouter(prefix="/v1/models", tags=["models"])
@ -25,32 +24,37 @@ ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)] ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
class ModelsList(BaseModel): class ModelsList(BaseModel):
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]] models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
@models_router.get( @models_router.get(
"/", "/",
operation_id="list_models", operation_id="list_models",
responses={200: {"model": ModelsList }}, responses={200: {"model": ModelsList}},
) )
async def list_models( async def list_models(
base_model: Optional[BaseModelType] = Query(default=None, description="Base model"), base_model: Optional[BaseModelType] = Query(default=None, description="Base model"),
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"), model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
) -> ModelsList: ) -> ModelsList:
"""Gets a list of models""" """Gets a list of models"""
models_raw = ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type) models_raw = ApiDependencies.invoker.services.model_manager.list_models(
models = parse_obj_as(ModelsList, { "models": models_raw }) base_model,
model_type)
models = parse_obj_as(ModelsList, {"models": models_raw})
return models return models
@models_router.patch( @models_router.patch(
"/{base_model}/{model_type}/{model_name}", "/{base_model}/{model_type}/{model_name}",
operation_id="update_model", operation_id="update_model",
responses={200: {"description" : "The model was updated successfully"}, responses={200: {"description": "The model was updated successfully"},
404: {"description" : "The model could not be found"}, 404: {"description": "The model could not be found"},
400: {"description" : "Bad request"} 400: {"description": "Bad request"}
}, },
status_code = 200, status_code=200,
response_model = UpdateModelResponse, response_model=UpdateModelResponse,
) )
async def update_model( async def update_model(
base_model: BaseModelType = Path(description="Base model"), base_model: BaseModelType = Path(description="Base model"),
@ -79,40 +83,41 @@ async def update_model(
return model_response return model_response
@models_router.post( @models_router.post(
"/import", "/import",
operation_id="import_model", operation_id="import_model",
responses= { responses={
201: {"description" : "The model imported successfully"}, 201: {"description": "The model imported successfully"},
404: {"description" : "The model could not be found"}, 404: {"description": "The model could not be found"},
424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"}, 424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
409: {"description" : "There is already a model corresponding to this path or repo_id"}, 409: {"description": "There is already a model corresponding to this path or repo_id"},
}, },
status_code=201, status_code=201,
response_model=ImportModelResponse response_model=ImportModelResponse
) )
async def import_model( async def import_model(
location: str = Body(description="A model path, repo_id or URL to import"), location: str = Body(description="A model path, repo_id or URL to import"),
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \ prediction_type: Optional[Literal['v_prediction', 'epsilon', 'sample']] =
Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"), Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
) -> ImportModelResponse: ) -> ImportModelResponse:
""" Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically """ """ Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically """
items_to_import = {location} items_to_import = {location}
prediction_types = { x.value: x for x in SchedulerPredictionType } prediction_types = {x.value: x for x in SchedulerPredictionType}
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
try: try:
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import( installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
items_to_import = items_to_import, items_to_import=items_to_import,
prediction_type_helper = lambda x: prediction_types.get(prediction_type) prediction_type_helper=lambda x: prediction_types.get(prediction_type)
) )
info = installed_models.get(location) info = installed_models.get(location)
if not info: if not info:
logger.error("Import failed") logger.error("Import failed")
raise HTTPException(status_code=424) raise HTTPException(status_code=424)
logger.info(f'Successfully imported {location}, got {info}') logger.info(f'Successfully imported {location}, got {info}')
model_raw = ApiDependencies.invoker.services.model_manager.list_model( model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=info.name, model_name=info.name,
@ -120,22 +125,23 @@ async def import_model(
model_type=info.model_type model_type=info.model_type
) )
return parse_obj_as(ImportModelResponse, model_raw) return parse_obj_as(ImportModelResponse, model_raw)
except KeyError as e: except KeyError as e:
logger.error(str(e)) logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
except ValueError as e: except ValueError as e:
logger.error(str(e)) logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e)) raise HTTPException(status_code=409, detail=str(e))
@models_router.post( @models_router.post(
"/add", "/add",
operation_id="add_model", operation_id="add_model",
responses= { responses={
201: {"description" : "The model added successfully"}, 201: {"description": "The model added successfully"},
404: {"description" : "The model could not be found"}, 404: {"description": "The model could not be found"},
424: {"description" : "The model appeared to add successfully, but could not be found in the model manager"}, 424: {"description": "The model appeared to add successfully, but could not be found in the model manager"},
409: {"description" : "There is already a model corresponding to this path or repo_id"}, 409: {"description": "There is already a model corresponding to this path or repo_id"},
}, },
status_code=201, status_code=201,
response_model=ImportModelResponse response_model=ImportModelResponse
@ -144,7 +150,7 @@ async def add_model(
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"), info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> ImportModelResponse: ) -> ImportModelResponse:
""" Add a model using the configuration information appropriate for its type. Only local models can be added by path""" """ Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
try: try:
@ -152,7 +158,7 @@ async def add_model(
info.model_name, info.model_name,
info.base_model, info.base_model,
info.model_type, info.model_type,
model_attributes = info.dict() model_attributes=info.dict()
) )
logger.info(f'Successfully added {info.model_name}') logger.info(f'Successfully added {info.model_name}')
model_raw = ApiDependencies.invoker.services.model_manager.list_model( model_raw = ApiDependencies.invoker.services.model_manager.list_model(
@ -168,13 +174,14 @@ async def add_model(
logger.error(str(e)) logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e)) raise HTTPException(status_code=409, detail=str(e))
@models_router.post( @models_router.post(
"/rename/{base_model}/{model_type}/{model_name}", "/rename/{base_model}/{model_type}/{model_name}",
operation_id="rename_model", operation_id="rename_model",
responses= { responses={
201: {"description" : "The model was renamed successfully"}, 201: {"description": "The model was renamed successfully"},
404: {"description" : "The model could not be found"}, 404: {"description": "The model could not be found"},
409: {"description" : "There is already a model corresponding to the new name"}, 409: {"description": "There is already a model corresponding to the new name"},
}, },
status_code=201, status_code=201,
response_model=ImportModelResponse response_model=ImportModelResponse
@ -187,16 +194,16 @@ async def rename_model(
new_base: Optional[BaseModelType] = Query(description="new model base", default=None), new_base: Optional[BaseModelType] = Query(description="new model base", default=None),
) -> ImportModelResponse: ) -> ImportModelResponse:
""" Rename a model""" """ Rename a model"""
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
try: try:
result = ApiDependencies.invoker.services.model_manager.rename_model( result = ApiDependencies.invoker.services.model_manager.rename_model(
base_model = base_model, base_model=base_model,
model_type = model_type, model_type=model_type,
model_name = model_name, model_name=model_name,
new_name = new_name, new_name=new_name,
new_base = new_base, new_base=new_base,
) )
logger.debug(result) logger.debug(result)
logger.info(f'Successfully renamed {model_name}=>{new_name}') logger.info(f'Successfully renamed {model_name}=>{new_name}')
@ -212,16 +219,17 @@ async def rename_model(
except ValueError as e: except ValueError as e:
logger.error(str(e)) logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e)) raise HTTPException(status_code=409, detail=str(e))
@models_router.delete( @models_router.delete(
"/{base_model}/{model_type}/{model_name}", "/{base_model}/{model_type}/{model_name}",
operation_id="del_model", operation_id="del_model",
responses={ responses={
204: { "description": "Model deleted successfully" }, 204: {"description": "Model deleted successfully"},
404: { "description": "Model not found" } 404: {"description": "Model not found"}
}, },
status_code = 204, status_code=204,
response_model = None, response_model=None,
) )
async def delete_model( async def delete_model(
base_model: BaseModelType = Path(description="Base model"), base_model: BaseModelType = Path(description="Base model"),
@ -230,142 +238,145 @@ async def delete_model(
) -> Response: ) -> Response:
"""Delete Model""" """Delete Model"""
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
try: try:
ApiDependencies.invoker.services.model_manager.del_model(model_name, ApiDependencies.invoker.services.model_manager.del_model(
base_model = base_model, model_name, base_model=base_model, model_type=model_type)
model_type = model_type
)
logger.info(f"Deleted model: {model_name}") logger.info(f"Deleted model: {model_name}")
return Response(status_code=204) return Response(status_code=204)
except KeyError: except KeyError:
logger.error(f"Model not found: {model_name}") logger.error(f"Model not found: {model_name}")
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") raise HTTPException(
status_code=404, detail=f"Model '{model_name}' not found")
@models_router.put( @models_router.put(
"/convert/{base_model}/{model_type}/{model_name}", "/convert/{base_model}/{model_type}/{model_name}",
operation_id="convert_model", operation_id="convert_model",
responses={ responses={
200: { "description": "Model converted successfully" }, 200: {"description": "Model converted successfully"},
400: {"description" : "Bad request" }, 400: {"description": "Bad request"},
404: { "description": "Model not found" }, 404: {"description": "Model not found"},
}, },
status_code = 200, status_code=200,
response_model = ConvertModelResponse, response_model=ConvertModelResponse,
) )
async def convert_model( async def convert_model(
base_model: BaseModelType = Path(description="Base model"), base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"), model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="model name"), model_name: str = Path(description="model name"),
convert_dest_directory: Optional[str] = Query(default=None, description="Save the converted model to the designated directory"), convert_dest_directory: Optional[str] = Body(description="Save the converted model to the designated directory", default=None, embed=True)
) -> ConvertModelResponse: ) -> ConvertModelResponse:
"""Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none.""" """Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
try: try:
logger.info(f"Converting model: {model_name}") logger.info(f"Converting model: {model_name}")
dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None dest = pathlib.Path(
ApiDependencies.invoker.services.model_manager.convert_model(model_name, convert_dest_directory) if convert_dest_directory else None
base_model = base_model, ApiDependencies.invoker.services.model_manager.convert_model(
model_type = model_type, model_name, base_model=base_model, model_type=model_type,
convert_dest_directory = dest, convert_dest_directory=dest,)
) model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name, model_name, base_model=base_model, model_type=model_type)
base_model = base_model,
model_type = model_type)
response = parse_obj_as(ConvertModelResponse, model_raw) response = parse_obj_as(ConvertModelResponse, model_raw)
except KeyError: except KeyError:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") raise HTTPException(
status_code=404, detail=f"Model '{model_name}' not found")
except ValueError as e: except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
return response return response
@models_router.get( @models_router.get(
"/search", "/search",
operation_id="search_for_models", operation_id="search_for_models",
responses={ responses={
200: { "description": "Directory searched successfully" }, 200: {"description": "Directory searched successfully"},
404: { "description": "Invalid directory path" }, 404: {"description": "Invalid directory path"},
}, },
status_code = 200, status_code=200,
response_model = List[pathlib.Path] response_model=List[pathlib.Path]
) )
async def search_for_models( async def search_for_models(
search_path: pathlib.Path = Query(description="Directory path to search for models") search_path: pathlib.Path = Query(description="Directory path to search for models")
)->List[pathlib.Path]: ) -> List[pathlib.Path]:
if not search_path.is_dir(): if not search_path.is_dir():
raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory") raise HTTPException(
return ApiDependencies.invoker.services.model_manager.search_for_models([search_path]) status_code=404,
detail=f"The search path '{search_path}' does not exist or is not directory")
return ApiDependencies.invoker.services.model_manager.search_for_models([
search_path])
@models_router.get( @models_router.get(
"/ckpt_confs", "/ckpt_confs",
operation_id="list_ckpt_configs", operation_id="list_ckpt_configs",
responses={ responses={
200: { "description" : "paths retrieved successfully" }, 200: {"description": "paths retrieved successfully"},
}, },
status_code = 200, status_code=200,
response_model = List[pathlib.Path] response_model=List[pathlib.Path]
) )
async def list_ckpt_configs( async def list_ckpt_configs(
)->List[pathlib.Path]: ) -> List[pathlib.Path]:
"""Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT.""" """Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT."""
return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs() return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs()
@models_router.get( @models_router.get(
"/sync", "/sync",
operation_id="sync_to_config", operation_id="sync_to_config",
responses={ responses={
201: { "description": "synchronization successful" }, 201: {"description": "synchronization successful"},
}, },
status_code = 201, status_code=201,
response_model = None response_model=None
) )
async def sync_to_config( async def sync_to_config(
)->None: ) -> None:
"""Call after making changes to models.yaml, autoimport directories or models directory to synchronize """Call after making changes to models.yaml, autoimport directories or models directory to synchronize
in-memory data structures with disk data structures.""" in-memory data structures with disk data structures."""
return ApiDependencies.invoker.services.model_manager.sync_to_config() return ApiDependencies.invoker.services.model_manager.sync_to_config()
@models_router.put( @models_router.put(
"/merge/{base_model}", "/merge/{base_model}",
operation_id="merge_models", operation_id="merge_models",
responses={ responses={
200: { "description": "Model converted successfully" }, 200: {"description": "Model converted successfully"},
400: { "description": "Incompatible models" }, 400: {"description": "Incompatible models"},
404: { "description": "One or more models not found" }, 404: {"description": "One or more models not found"},
}, },
status_code = 200, status_code=200,
response_model = MergeModelResponse, response_model=MergeModelResponse,
) )
async def merge_models( async def merge_models(
base_model: BaseModelType = Path(description="Base model"), base_model: BaseModelType = Path(description="Base model"),
model_names: List[str] = Body(description="model name", min_items=2, max_items=3), model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
merged_model_name: Optional[str] = Body(description="Name of destination model"), merged_model_name: Optional[str] = Body(description="Name of destination model"),
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5), alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"), interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False), force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False),
merge_dest_directory: Optional[str] = Body(description="Save the merged model to the designated directory (with 'merged_model_name' appended)", default=None) merge_dest_directory: Optional[str] = Body(description="Save the merged model to the designated directory (with 'merged_model_name' appended)", default=None)
) -> MergeModelResponse: ) -> MergeModelResponse:
"""Convert a checkpoint model into a diffusers model""" """Convert a checkpoint model into a diffusers model"""
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
try: try:
logger.info(f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}") logger.info(
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
result = ApiDependencies.invoker.services.model_manager.merge_models(model_names, dest = pathlib.Path(
base_model, merge_dest_directory) if merge_dest_directory else None
merged_model_name=merged_model_name or "+".join(model_names), result = ApiDependencies.invoker.services.model_manager.merge_models(
alpha=alpha, model_names, base_model,
interp=interp, merged_model_name=merged_model_name or "+".join(model_names),
force=force, alpha=alpha, interp=interp, force=force, merge_dest_directory=dest)
merge_dest_directory = dest model_raw = ApiDependencies.invoker.services.model_manager.list_model(
) result.name, base_model=base_model, model_type=ModelType.Main, )
model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name,
base_model = base_model,
model_type = ModelType.Main,
)
response = parse_obj_as(ConvertModelResponse, model_raw) response = parse_obj_as(ConvertModelResponse, model_raw)
except KeyError: except KeyError:
raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found") raise HTTPException(
status_code=404,
detail=f"One or more of the models '{model_names}' not found")
except ValueError as e: except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
return response return response

View File

@ -415,6 +415,8 @@
"convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 2GB-7GB in size.", "convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 2GB-7GB in size.",
"convertToDiffusersHelpText6": "Do you wish to convert this model?", "convertToDiffusersHelpText6": "Do you wish to convert this model?",
"convertToDiffusersSaveLocation": "Save Location", "convertToDiffusersSaveLocation": "Save Location",
"noCustomLocationProvided": "No Custom Location Provided",
"convertingModelBegin": "Converting Model. Please wait.",
"v1": "v1", "v1": "v1",
"v2_base": "v2 (512px)", "v2_base": "v2 (512px)",
"v2_768": "v2 (768px)", "v2_768": "v2 (768px)",

View File

@ -1,9 +1,18 @@
import { Flex, ListItem, Text, UnorderedList } from '@chakra-ui/react'; import {
// import { convertToDiffusers } from 'app/socketio/actions'; Flex,
ListItem,
Radio,
RadioGroup,
Text,
Tooltip,
UnorderedList,
} from '@chakra-ui/react';
import { makeToast } from 'app/components/Toaster'; import { makeToast } from 'app/components/Toaster';
// import { convertToDiffusers } from 'app/socketio/actions';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import IAIAlertDialog from 'common/components/IAIAlertDialog'; import IAIAlertDialog from 'common/components/IAIAlertDialog';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
import IAIInput from 'common/components/IAIInput';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { useEffect, useState } from 'react'; import { useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@ -15,6 +24,8 @@ interface ModelConvertProps {
model: CheckpointModelConfig; model: CheckpointModelConfig;
} }
type SaveLocation = 'InvokeAIRoot' | 'Custom';
export default function ModelConvert(props: ModelConvertProps) { export default function ModelConvert(props: ModelConvertProps) {
const { model } = props; const { model } = props;
@ -23,22 +34,51 @@ export default function ModelConvert(props: ModelConvertProps) {
const [convertModel, { isLoading }] = useConvertMainModelsMutation(); const [convertModel, { isLoading }] = useConvertMainModelsMutation();
const [saveLocation, setSaveLocation] = useState<string>('same'); const [saveLocation, setSaveLocation] =
useState<SaveLocation>('InvokeAIRoot');
const [customSaveLocation, setCustomSaveLocation] = useState<string>(''); const [customSaveLocation, setCustomSaveLocation] = useState<string>('');
useEffect(() => { useEffect(() => {
setSaveLocation('same'); setSaveLocation('InvokeAIRoot');
}, [model]); }, [model]);
const modelConvertCancelHandler = () => { const modelConvertCancelHandler = () => {
setSaveLocation('same'); setSaveLocation('InvokeAIRoot');
}; };
const modelConvertHandler = () => { const modelConvertHandler = () => {
const responseBody = { const responseBody = {
base_model: model.base_model, base_model: model.base_model,
model_name: model.model_name, model_name: model.model_name,
body: {
convert_dest_directory:
saveLocation === 'Custom' ? customSaveLocation : undefined,
},
}; };
if (saveLocation === 'Custom' && customSaveLocation === '') {
dispatch(
addToast(
makeToast({
title: t('modelManager.noCustomLocationProvided'),
status: 'error',
})
)
);
return;
}
dispatch(
addToast(
makeToast({
title: `${t('modelManager.convertingModelBegin')}: ${
model.model_name
}`,
status: 'success',
})
)
);
convertModel(responseBody) convertModel(responseBody)
.unwrap() .unwrap()
.then((_) => { .then((_) => {
@ -94,35 +134,30 @@ export default function ModelConvert(props: ModelConvertProps) {
<Text>{t('modelManager.convertToDiffusersHelpText6')}</Text> <Text>{t('modelManager.convertToDiffusersHelpText6')}</Text>
</Flex> </Flex>
{/* <Flex flexDir="column" gap={4}> <Flex flexDir="column" gap={2}>
<Flex marginTop={4} flexDir="column" gap={2}> <Flex marginTop={4} flexDir="column" gap={2}>
<Text fontWeight="600"> <Text fontWeight="600">
{t('modelManager.convertToDiffusersSaveLocation')} {t('modelManager.convertToDiffusersSaveLocation')}
</Text> </Text>
<RadioGroup value={saveLocation} onChange={(v) => setSaveLocation(v)}> <RadioGroup
value={saveLocation}
onChange={(v) => setSaveLocation(v as SaveLocation)}
>
<Flex gap={4}> <Flex gap={4}>
<Radio value="same"> <Radio value="InvokeAIRoot">
<Tooltip label="Save converted model in the same folder">
{t('modelManager.sameFolder')}
</Tooltip>
</Radio>
<Radio value="root">
<Tooltip label="Save converted model in the InvokeAI root folder"> <Tooltip label="Save converted model in the InvokeAI root folder">
{t('modelManager.invokeRoot')} {t('modelManager.invokeRoot')}
</Tooltip> </Tooltip>
</Radio> </Radio>
<Radio value="Custom">
<Radio value="custom">
<Tooltip label="Save converted model in a custom folder"> <Tooltip label="Save converted model in a custom folder">
{t('modelManager.custom')} {t('modelManager.custom')}
</Tooltip> </Tooltip>
</Radio> </Radio>
</Flex> </Flex>
</RadioGroup> </RadioGroup>
</Flex> */} </Flex>
{saveLocation === 'Custom' && (
{/* {saveLocation === 'custom' && (
<Flex flexDirection="column" rowGap={2}> <Flex flexDirection="column" rowGap={2}>
<Text fontWeight="500" fontSize="sm" variant="subtext"> <Text fontWeight="500" fontSize="sm" variant="subtext">
{t('modelManager.customSaveLocation')} {t('modelManager.customSaveLocation')}
@ -130,13 +165,13 @@ export default function ModelConvert(props: ModelConvertProps) {
<IAIInput <IAIInput
value={customSaveLocation} value={customSaveLocation}
onChange={(e) => { onChange={(e) => {
if (e.target.value !== '') setCustomSaveLocation(e.target.value);
setCustomSaveLocation(e.target.value);
}} }}
width="full" width="full"
/> />
</Flex> </Flex>
)} */} )}
</Flex>
</IAIAlertDialog> </IAIAlertDialog>
); );
} }

View File

@ -5,6 +5,7 @@ import {
BaseModelType, BaseModelType,
CheckpointModelConfig, CheckpointModelConfig,
ControlNetModelConfig, ControlNetModelConfig,
ConvertModelConfig,
DiffusersModelConfig, DiffusersModelConfig,
LoRAModelConfig, LoRAModelConfig,
MainModelConfig, MainModelConfig,
@ -62,6 +63,7 @@ type DeleteMainModelResponse = void;
type ConvertMainModelArg = { type ConvertMainModelArg = {
base_model: BaseModelType; base_model: BaseModelType;
model_name: string; model_name: string;
body: ConvertModelConfig;
}; };
type ConvertMainModelResponse = type ConvertMainModelResponse =
@ -176,10 +178,11 @@ export const modelsApi = api.injectEndpoints({
ConvertMainModelResponse, ConvertMainModelResponse,
ConvertMainModelArg ConvertMainModelArg
>({ >({
query: ({ base_model, model_name }) => { query: ({ base_model, model_name, body }) => {
return { return {
url: `models/convert/${base_model}/main/${model_name}`, url: `models/convert/${base_model}/main/${model_name}`,
method: 'PUT', method: 'PUT',
body: body,
}; };
}, },
invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }], invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],

View File

@ -378,6 +378,14 @@ export type components = {
*/ */
image_count: number; image_count: number;
}; };
/** Body_convert_model */
Body_convert_model: {
/**
* Convert Dest Directory
* @description Save the converted model to the designated directory
*/
convert_dest_directory?: string;
};
/** Body_create_board_image */ /** Body_create_board_image */
Body_create_board_image: { Body_create_board_image: {
/** /**
@ -5200,10 +5208,6 @@ export type operations = {
*/ */
convert_model: { convert_model: {
parameters: { parameters: {
query?: {
/** @description Save the converted model to the designated directory */
convert_dest_directory?: string;
};
path: { path: {
/** @description Base model */ /** @description Base model */
base_model: components["schemas"]["BaseModelType"]; base_model: components["schemas"]["BaseModelType"];
@ -5213,6 +5217,11 @@ export type operations = {
model_name: string; model_name: string;
}; };
}; };
requestBody?: {
content: {
"application/json": components["schemas"]["Body_convert_model"];
};
};
responses: { responses: {
/** @description Model converted successfully */ /** @description Model converted successfully */
200: { 200: {

View File

@ -55,7 +55,9 @@ export type AnyModelConfig =
| ControlNetModelConfig | ControlNetModelConfig
| TextualInversionModelConfig | TextualInversionModelConfig
| MainModelConfig; | MainModelConfig;
export type MergeModelConfig = components['schemas']['Body_merge_models']; export type MergeModelConfig = components['schemas']['Body_merge_models'];
export type ConvertModelConfig = components['schemas']['Body_convert_model'];
// Graphs // Graphs
export type Graph = components['schemas']['Graph']; export type Graph = components['schemas']['Graph'];