From 9769b4866115ea3a498cf2b0cc7d3a762b825732 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Sat, 15 Jul 2023 19:17:16 +1200 Subject: [PATCH] feat: Add Custom location support for model conversion --- invokeai/app/api/routers/models.py | 241 +++++++++--------- invokeai/frontend/web/public/locales/en.json | 2 + .../ModelManagerPanel/ModelConvert.tsx | 79 ++++-- .../web/src/services/api/endpoints/models.ts | 5 +- .../frontend/web/src/services/api/schema.d.ts | 17 +- .../frontend/web/src/services/api/types.d.ts | 2 + 6 files changed, 204 insertions(+), 142 deletions(-) diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index c298114cbc..12b7f991f4 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -2,7 +2,7 @@ import pathlib -from typing import Literal, List, Optional, Union +from typing import List, Literal, Optional, Union from fastapi import Body, Path, Query, Response from fastapi.routing import APIRouter @@ -10,11 +10,10 @@ from pydantic import BaseModel, parse_obj_as from starlette.exceptions import HTTPException from invokeai.backend import BaseModelType, ModelType -from invokeai.backend.model_management.models import ( - OPENAPI_MODEL_CONFIGS, - SchedulerPredictionType, -) from invokeai.backend.model_management import MergeInterpolationMethod +from invokeai.backend.model_management.models import (OPENAPI_MODEL_CONFIGS, + SchedulerPredictionType) + from ..dependencies import ApiDependencies models_router = APIRouter(prefix="/v1/models", tags=["models"]) @@ -25,32 +24,37 @@ ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)] + class ModelsList(BaseModel): models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]] + @models_router.get( "/", operation_id="list_models", - responses={200: {"model": ModelsList }}, + responses={200: {"model": ModelsList}}, ) async def list_models( base_model: Optional[BaseModelType] = Query(default=None, description="Base model"), model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"), ) -> ModelsList: """Gets a list of models""" - models_raw = ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type) - models = parse_obj_as(ModelsList, { "models": models_raw }) + models_raw = ApiDependencies.invoker.services.model_manager.list_models( + base_model, + model_type) + models = parse_obj_as(ModelsList, {"models": models_raw}) return models + @models_router.patch( "/{base_model}/{model_type}/{model_name}", operation_id="update_model", - responses={200: {"description" : "The model was updated successfully"}, - 404: {"description" : "The model could not be found"}, - 400: {"description" : "Bad request"} + responses={200: {"description": "The model was updated successfully"}, + 404: {"description": "The model could not be found"}, + 400: {"description": "Bad request"} }, - status_code = 200, - response_model = UpdateModelResponse, + status_code=200, + response_model=UpdateModelResponse, ) async def update_model( base_model: BaseModelType = Path(description="Base model"), @@ -79,40 +83,41 @@ async def update_model( return model_response + @models_router.post( "/import", operation_id="import_model", - responses= { - 201: {"description" : "The model imported successfully"}, - 404: {"description" : "The model could not be found"}, - 424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"}, - 409: {"description" : "There is already a model corresponding to this path or repo_id"}, + responses={ + 201: {"description": "The model imported successfully"}, + 404: {"description": "The model could not be found"}, + 424: {"description": "The model appeared to import successfully, but could not be found in the model manager"}, + 409: {"description": "There is already a model corresponding to this path or repo_id"}, }, status_code=201, response_model=ImportModelResponse ) async def import_model( location: str = Body(description="A model path, repo_id or URL to import"), - prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \ - Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"), + prediction_type: Optional[Literal['v_prediction', 'epsilon', 'sample']] = + Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"), ) -> ImportModelResponse: """ Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically """ - + items_to_import = {location} - prediction_types = { x.value: x for x in SchedulerPredictionType } + prediction_types = {x.value: x for x in SchedulerPredictionType} logger = ApiDependencies.invoker.services.logger try: installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import( - items_to_import = items_to_import, - prediction_type_helper = lambda x: prediction_types.get(prediction_type) + items_to_import=items_to_import, + prediction_type_helper=lambda x: prediction_types.get(prediction_type) ) info = installed_models.get(location) if not info: logger.error("Import failed") raise HTTPException(status_code=424) - + logger.info(f'Successfully imported {location}, got {info}') model_raw = ApiDependencies.invoker.services.model_manager.list_model( model_name=info.name, @@ -120,22 +125,23 @@ async def import_model( model_type=info.model_type ) return parse_obj_as(ImportModelResponse, model_raw) - + except KeyError as e: logger.error(str(e)) raise HTTPException(status_code=404, detail=str(e)) except ValueError as e: logger.error(str(e)) raise HTTPException(status_code=409, detail=str(e)) - + + @models_router.post( "/add", operation_id="add_model", - responses= { - 201: {"description" : "The model added successfully"}, - 404: {"description" : "The model could not be found"}, - 424: {"description" : "The model appeared to add successfully, but could not be found in the model manager"}, - 409: {"description" : "There is already a model corresponding to this path or repo_id"}, + responses={ + 201: {"description": "The model added successfully"}, + 404: {"description": "The model could not be found"}, + 424: {"description": "The model appeared to add successfully, but could not be found in the model manager"}, + 409: {"description": "There is already a model corresponding to this path or repo_id"}, }, status_code=201, response_model=ImportModelResponse @@ -144,7 +150,7 @@ async def add_model( info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"), ) -> ImportModelResponse: """ Add a model using the configuration information appropriate for its type. Only local models can be added by path""" - + logger = ApiDependencies.invoker.services.logger try: @@ -152,7 +158,7 @@ async def add_model( info.model_name, info.base_model, info.model_type, - model_attributes = info.dict() + model_attributes=info.dict() ) logger.info(f'Successfully added {info.model_name}') model_raw = ApiDependencies.invoker.services.model_manager.list_model( @@ -168,13 +174,14 @@ async def add_model( logger.error(str(e)) raise HTTPException(status_code=409, detail=str(e)) + @models_router.post( "/rename/{base_model}/{model_type}/{model_name}", operation_id="rename_model", - responses= { - 201: {"description" : "The model was renamed successfully"}, - 404: {"description" : "The model could not be found"}, - 409: {"description" : "There is already a model corresponding to the new name"}, + responses={ + 201: {"description": "The model was renamed successfully"}, + 404: {"description": "The model could not be found"}, + 409: {"description": "There is already a model corresponding to the new name"}, }, status_code=201, response_model=ImportModelResponse @@ -187,16 +194,16 @@ async def rename_model( new_base: Optional[BaseModelType] = Query(description="new model base", default=None), ) -> ImportModelResponse: """ Rename a model""" - + logger = ApiDependencies.invoker.services.logger try: result = ApiDependencies.invoker.services.model_manager.rename_model( - base_model = base_model, - model_type = model_type, - model_name = model_name, - new_name = new_name, - new_base = new_base, + base_model=base_model, + model_type=model_type, + model_name=model_name, + new_name=new_name, + new_base=new_base, ) logger.debug(result) logger.info(f'Successfully renamed {model_name}=>{new_name}') @@ -212,16 +219,17 @@ async def rename_model( except ValueError as e: logger.error(str(e)) raise HTTPException(status_code=409, detail=str(e)) - + + @models_router.delete( "/{base_model}/{model_type}/{model_name}", operation_id="del_model", responses={ - 204: { "description": "Model deleted successfully" }, - 404: { "description": "Model not found" } + 204: {"description": "Model deleted successfully"}, + 404: {"description": "Model not found"} }, - status_code = 204, - response_model = None, + status_code=204, + response_model=None, ) async def delete_model( base_model: BaseModelType = Path(description="Base model"), @@ -230,142 +238,145 @@ async def delete_model( ) -> Response: """Delete Model""" logger = ApiDependencies.invoker.services.logger - + try: - ApiDependencies.invoker.services.model_manager.del_model(model_name, - base_model = base_model, - model_type = model_type - ) + ApiDependencies.invoker.services.model_manager.del_model( + model_name, base_model=base_model, model_type=model_type) logger.info(f"Deleted model: {model_name}") return Response(status_code=204) except KeyError: logger.error(f"Model not found: {model_name}") - raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") + raise HTTPException( + status_code=404, detail=f"Model '{model_name}' not found") + @models_router.put( "/convert/{base_model}/{model_type}/{model_name}", operation_id="convert_model", responses={ - 200: { "description": "Model converted successfully" }, - 400: {"description" : "Bad request" }, - 404: { "description": "Model not found" }, + 200: {"description": "Model converted successfully"}, + 400: {"description": "Bad request"}, + 404: {"description": "Model not found"}, }, - status_code = 200, - response_model = ConvertModelResponse, + status_code=200, + response_model=ConvertModelResponse, ) async def convert_model( base_model: BaseModelType = Path(description="Base model"), model_type: ModelType = Path(description="The type of model"), model_name: str = Path(description="model name"), - convert_dest_directory: Optional[str] = Query(default=None, description="Save the converted model to the designated directory"), + convert_dest_directory: Optional[str] = Body(description="Save the converted model to the designated directory", default=None, embed=True) ) -> ConvertModelResponse: """Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none.""" logger = ApiDependencies.invoker.services.logger try: logger.info(f"Converting model: {model_name}") - dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None - ApiDependencies.invoker.services.model_manager.convert_model(model_name, - base_model = base_model, - model_type = model_type, - convert_dest_directory = dest, - ) - model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name, - base_model = base_model, - model_type = model_type) + dest = pathlib.Path( + convert_dest_directory) if convert_dest_directory else None + ApiDependencies.invoker.services.model_manager.convert_model( + model_name, base_model=base_model, model_type=model_type, + convert_dest_directory=dest,) + model_raw = ApiDependencies.invoker.services.model_manager.list_model( + model_name, base_model=base_model, model_type=model_type) response = parse_obj_as(ConvertModelResponse, model_raw) except KeyError: - raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") + raise HTTPException( + status_code=404, detail=f"Model '{model_name}' not found") except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) return response + @models_router.get( "/search", operation_id="search_for_models", responses={ - 200: { "description": "Directory searched successfully" }, - 404: { "description": "Invalid directory path" }, + 200: {"description": "Directory searched successfully"}, + 404: {"description": "Invalid directory path"}, }, - status_code = 200, - response_model = List[pathlib.Path] + status_code=200, + response_model=List[pathlib.Path] ) async def search_for_models( search_path: pathlib.Path = Query(description="Directory path to search for models") -)->List[pathlib.Path]: +) -> List[pathlib.Path]: if not search_path.is_dir(): - raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory") - return ApiDependencies.invoker.services.model_manager.search_for_models([search_path]) + raise HTTPException( + status_code=404, + detail=f"The search path '{search_path}' does not exist or is not directory") + return ApiDependencies.invoker.services.model_manager.search_for_models([ + search_path]) + @models_router.get( "/ckpt_confs", operation_id="list_ckpt_configs", responses={ - 200: { "description" : "paths retrieved successfully" }, + 200: {"description": "paths retrieved successfully"}, }, - status_code = 200, - response_model = List[pathlib.Path] + status_code=200, + response_model=List[pathlib.Path] ) async def list_ckpt_configs( -)->List[pathlib.Path]: +) -> List[pathlib.Path]: """Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT.""" return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs() - - + + @models_router.get( "/sync", operation_id="sync_to_config", responses={ - 201: { "description": "synchronization successful" }, + 201: {"description": "synchronization successful"}, }, - status_code = 201, - response_model = None + status_code=201, + response_model=None ) async def sync_to_config( -)->None: +) -> None: """Call after making changes to models.yaml, autoimport directories or models directory to synchronize in-memory data structures with disk data structures.""" return ApiDependencies.invoker.services.model_manager.sync_to_config() - + + @models_router.put( "/merge/{base_model}", operation_id="merge_models", responses={ - 200: { "description": "Model converted successfully" }, - 400: { "description": "Incompatible models" }, - 404: { "description": "One or more models not found" }, + 200: {"description": "Model converted successfully"}, + 400: {"description": "Incompatible models"}, + 404: {"description": "One or more models not found"}, }, - status_code = 200, - response_model = MergeModelResponse, + status_code=200, + response_model=MergeModelResponse, ) async def merge_models( - base_model: BaseModelType = Path(description="Base model"), - model_names: List[str] = Body(description="model name", min_items=2, max_items=3), - merged_model_name: Optional[str] = Body(description="Name of destination model"), - alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5), + base_model: BaseModelType = Path(description="Base model"), + model_names: List[str] = Body(description="model name", min_items=2, max_items=3), + merged_model_name: Optional[str] = Body(description="Name of destination model"), + alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5), interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"), - force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False), - merge_dest_directory: Optional[str] = Body(description="Save the merged model to the designated directory (with 'merged_model_name' appended)", default=None) + force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False), + merge_dest_directory: Optional[str] = Body(description="Save the merged model to the designated directory (with 'merged_model_name' appended)", default=None) ) -> MergeModelResponse: """Convert a checkpoint model into a diffusers model""" logger = ApiDependencies.invoker.services.logger try: - logger.info(f"Merging models: {model_names} into {merge_dest_directory or ''}/{merged_model_name}") - dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None - result = ApiDependencies.invoker.services.model_manager.merge_models(model_names, - base_model, - merged_model_name=merged_model_name or "+".join(model_names), - alpha=alpha, - interp=interp, - force=force, - merge_dest_directory = dest - ) - model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name, - base_model = base_model, - model_type = ModelType.Main, - ) + logger.info( + f"Merging models: {model_names} into {merge_dest_directory or ''}/{merged_model_name}") + dest = pathlib.Path( + merge_dest_directory) if merge_dest_directory else None + result = ApiDependencies.invoker.services.model_manager.merge_models( + model_names, base_model, + merged_model_name=merged_model_name or "+".join(model_names), + alpha=alpha, interp=interp, force=force, merge_dest_directory=dest) + model_raw = ApiDependencies.invoker.services.model_manager.list_model( + result.name, base_model=base_model, model_type=ModelType.Main, ) response = parse_obj_as(ConvertModelResponse, model_raw) except KeyError: - raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found") + raise HTTPException( + status_code=404, + detail=f"One or more of the models '{model_names}' not found") except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) return response diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index ae1778101d..36cf1d7af6 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -415,6 +415,8 @@ "convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 2GB-7GB in size.", "convertToDiffusersHelpText6": "Do you wish to convert this model?", "convertToDiffusersSaveLocation": "Save Location", + "noCustomLocationProvided": "No Custom Location Provided", + "convertingModelBegin": "Converting Model. Please wait.", "v1": "v1", "v2_base": "v2 (512px)", "v2_768": "v2 (768px)", diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx index 922fdacee7..9c7130f2ad 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelConvert.tsx @@ -1,9 +1,18 @@ -import { Flex, ListItem, Text, UnorderedList } from '@chakra-ui/react'; -// import { convertToDiffusers } from 'app/socketio/actions'; +import { + Flex, + ListItem, + Radio, + RadioGroup, + Text, + Tooltip, + UnorderedList, +} from '@chakra-ui/react'; import { makeToast } from 'app/components/Toaster'; +// import { convertToDiffusers } from 'app/socketio/actions'; import { useAppDispatch } from 'app/store/storeHooks'; import IAIAlertDialog from 'common/components/IAIAlertDialog'; import IAIButton from 'common/components/IAIButton'; +import IAIInput from 'common/components/IAIInput'; import { addToast } from 'features/system/store/systemSlice'; import { useEffect, useState } from 'react'; import { useTranslation } from 'react-i18next'; @@ -15,6 +24,8 @@ interface ModelConvertProps { model: CheckpointModelConfig; } +type SaveLocation = 'InvokeAIRoot' | 'Custom'; + export default function ModelConvert(props: ModelConvertProps) { const { model } = props; @@ -23,22 +34,51 @@ export default function ModelConvert(props: ModelConvertProps) { const [convertModel, { isLoading }] = useConvertMainModelsMutation(); - const [saveLocation, setSaveLocation] = useState('same'); + const [saveLocation, setSaveLocation] = + useState('InvokeAIRoot'); const [customSaveLocation, setCustomSaveLocation] = useState(''); useEffect(() => { - setSaveLocation('same'); + setSaveLocation('InvokeAIRoot'); }, [model]); const modelConvertCancelHandler = () => { - setSaveLocation('same'); + setSaveLocation('InvokeAIRoot'); }; const modelConvertHandler = () => { const responseBody = { base_model: model.base_model, model_name: model.model_name, + body: { + convert_dest_directory: + saveLocation === 'Custom' ? customSaveLocation : undefined, + }, }; + + if (saveLocation === 'Custom' && customSaveLocation === '') { + dispatch( + addToast( + makeToast({ + title: t('modelManager.noCustomLocationProvided'), + status: 'error', + }) + ) + ); + return; + } + + dispatch( + addToast( + makeToast({ + title: `${t('modelManager.convertingModelBegin')}: ${ + model.model_name + }`, + status: 'success', + }) + ) + ); + convertModel(responseBody) .unwrap() .then((_) => { @@ -94,35 +134,30 @@ export default function ModelConvert(props: ModelConvertProps) { {t('modelManager.convertToDiffusersHelpText6')} - {/* + {t('modelManager.convertToDiffusersSaveLocation')} - setSaveLocation(v)}> + setSaveLocation(v as SaveLocation)} + > - - - {t('modelManager.sameFolder')} - - - - + {t('modelManager.invokeRoot')} - - + {t('modelManager.custom')} - */} - - {/* {saveLocation === 'custom' && ( + + {saveLocation === 'Custom' && ( {t('modelManager.customSaveLocation')} @@ -130,13 +165,13 @@ export default function ModelConvert(props: ModelConvertProps) { { - if (e.target.value !== '') - setCustomSaveLocation(e.target.value); + setCustomSaveLocation(e.target.value); }} width="full" /> - )} */} + )} + ); } diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index c86ad91100..79e685313e 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -5,6 +5,7 @@ import { BaseModelType, CheckpointModelConfig, ControlNetModelConfig, + ConvertModelConfig, DiffusersModelConfig, LoRAModelConfig, MainModelConfig, @@ -62,6 +63,7 @@ type DeleteMainModelResponse = void; type ConvertMainModelArg = { base_model: BaseModelType; model_name: string; + body: ConvertModelConfig; }; type ConvertMainModelResponse = @@ -176,10 +178,11 @@ export const modelsApi = api.injectEndpoints({ ConvertMainModelResponse, ConvertMainModelArg >({ - query: ({ base_model, model_name }) => { + query: ({ base_model, model_name, body }) => { return { url: `models/convert/${base_model}/main/${model_name}`, method: 'PUT', + body: body, }; }, invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }], diff --git a/invokeai/frontend/web/src/services/api/schema.d.ts b/invokeai/frontend/web/src/services/api/schema.d.ts index 2ae5109f4f..610e9fa05e 100644 --- a/invokeai/frontend/web/src/services/api/schema.d.ts +++ b/invokeai/frontend/web/src/services/api/schema.d.ts @@ -378,6 +378,14 @@ export type components = { */ image_count: number; }; + /** Body_convert_model */ + Body_convert_model: { + /** + * Convert Dest Directory + * @description Save the converted model to the designated directory + */ + convert_dest_directory?: string; + }; /** Body_create_board_image */ Body_create_board_image: { /** @@ -5200,10 +5208,6 @@ export type operations = { */ convert_model: { parameters: { - query?: { - /** @description Save the converted model to the designated directory */ - convert_dest_directory?: string; - }; path: { /** @description Base model */ base_model: components["schemas"]["BaseModelType"]; @@ -5213,6 +5217,11 @@ export type operations = { model_name: string; }; }; + requestBody?: { + content: { + "application/json": components["schemas"]["Body_convert_model"]; + }; + }; responses: { /** @description Model converted successfully */ 200: { diff --git a/invokeai/frontend/web/src/services/api/types.d.ts b/invokeai/frontend/web/src/services/api/types.d.ts index fcbbd1a6a0..57258fb19b 100644 --- a/invokeai/frontend/web/src/services/api/types.d.ts +++ b/invokeai/frontend/web/src/services/api/types.d.ts @@ -55,7 +55,9 @@ export type AnyModelConfig = | ControlNetModelConfig | TextualInversionModelConfig | MainModelConfig; + export type MergeModelConfig = components['schemas']['Body_merge_models']; +export type ConvertModelConfig = components['schemas']['Body_convert_model']; // Graphs export type Graph = components['schemas']['Graph'];