Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

feat: Add Custom location support for model conversion

parent 8c8eddcc60
commit 9769b48661
@@ -2,7 +2,7 @@
 
 
 import pathlib
-from typing import Literal, List, Optional, Union
+from typing import List, Literal, Optional, Union
 
 from fastapi import Body, Path, Query, Response
 from fastapi.routing import APIRouter
@@ -10,11 +10,10 @@ from pydantic import BaseModel, parse_obj_as
 from starlette.exceptions import HTTPException
 
 from invokeai.backend import BaseModelType, ModelType
-from invokeai.backend.model_management.models import (
-    OPENAPI_MODEL_CONFIGS,
-    SchedulerPredictionType,
-)
 from invokeai.backend.model_management import MergeInterpolationMethod
+from invokeai.backend.model_management.models import (OPENAPI_MODEL_CONFIGS,
+                                                      SchedulerPredictionType)
 
 from ..dependencies import ApiDependencies
 
 models_router = APIRouter(prefix="/v1/models", tags=["models"])
@@ -25,32 +24,37 @@ ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
 MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
 ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
 
 
 class ModelsList(BaseModel):
     models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
 
 
 @models_router.get(
     "/",
     operation_id="list_models",
-    responses={200: {"model": ModelsList }},
+    responses={200: {"model": ModelsList}},
 )
 async def list_models(
     base_model: Optional[BaseModelType] = Query(default=None, description="Base model"),
     model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
 ) -> ModelsList:
     """Gets a list of models"""
-    models_raw = ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type)
-    models = parse_obj_as(ModelsList, { "models": models_raw })
+    models_raw = ApiDependencies.invoker.services.model_manager.list_models(
+        base_model,
+        model_type)
+    models = parse_obj_as(ModelsList, {"models": models_raw})
     return models
 
 
 @models_router.patch(
     "/{base_model}/{model_type}/{model_name}",
     operation_id="update_model",
-    responses={200: {"description" : "The model was updated successfully"},
-               404: {"description" : "The model could not be found"},
-               400: {"description" : "Bad request"}
+    responses={200: {"description": "The model was updated successfully"},
+               404: {"description": "The model could not be found"},
+               400: {"description": "Bad request"}
               },
-    status_code = 200,
-    response_model = UpdateModelResponse,
+    status_code=200,
+    response_model=UpdateModelResponse,
 )
 async def update_model(
     base_model: BaseModelType = Path(description="Base model"),
@@ -79,40 +83,41 @@ async def update_model(
 
     return model_response
 
 
 @models_router.post(
     "/import",
     operation_id="import_model",
-    responses= {
-        201: {"description" : "The model imported successfully"},
-        404: {"description" : "The model could not be found"},
-        424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"},
-        409: {"description" : "There is already a model corresponding to this path or repo_id"},
+    responses={
+        201: {"description": "The model imported successfully"},
+        404: {"description": "The model could not be found"},
+        424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
+        409: {"description": "There is already a model corresponding to this path or repo_id"},
     },
     status_code=201,
     response_model=ImportModelResponse
 )
 async def import_model(
     location: str = Body(description="A model path, repo_id or URL to import"),
-    prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
-        Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
+    prediction_type: Optional[Literal['v_prediction', 'epsilon', 'sample']] =
+        Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
 ) -> ImportModelResponse:
     """ Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically """
 
     items_to_import = {location}
-    prediction_types = { x.value: x for x in SchedulerPredictionType }
+    prediction_types = {x.value: x for x in SchedulerPredictionType}
     logger = ApiDependencies.invoker.services.logger
 
     try:
         installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
-            items_to_import = items_to_import,
-            prediction_type_helper = lambda x: prediction_types.get(prediction_type)
+            items_to_import=items_to_import,
+            prediction_type_helper=lambda x: prediction_types.get(prediction_type)
         )
         info = installed_models.get(location)
 
         if not info:
             logger.error("Import failed")
             raise HTTPException(status_code=424)
 
         logger.info(f'Successfully imported {location}, got {info}')
         model_raw = ApiDependencies.invoker.services.model_manager.list_model(
             model_name=info.name,
@@ -120,22 +125,23 @@ async def import_model(
             model_type=info.model_type
         )
         return parse_obj_as(ImportModelResponse, model_raw)
 
     except KeyError as e:
         logger.error(str(e))
         raise HTTPException(status_code=404, detail=str(e))
     except ValueError as e:
         logger.error(str(e))
         raise HTTPException(status_code=409, detail=str(e))
 
 
 @models_router.post(
     "/add",
     operation_id="add_model",
-    responses= {
-        201: {"description" : "The model added successfully"},
-        404: {"description" : "The model could not be found"},
-        424: {"description" : "The model appeared to add successfully, but could not be found in the model manager"},
-        409: {"description" : "There is already a model corresponding to this path or repo_id"},
+    responses={
+        201: {"description": "The model added successfully"},
+        404: {"description": "The model could not be found"},
+        424: {"description": "The model appeared to add successfully, but could not be found in the model manager"},
+        409: {"description": "There is already a model corresponding to this path or repo_id"},
     },
     status_code=201,
     response_model=ImportModelResponse
@@ -144,7 +150,7 @@ async def add_model(
     info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
 ) -> ImportModelResponse:
     """ Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
 
     logger = ApiDependencies.invoker.services.logger
 
     try:
@@ -152,7 +158,7 @@ async def add_model(
             info.model_name,
             info.base_model,
             info.model_type,
-            model_attributes = info.dict()
+            model_attributes=info.dict()
         )
         logger.info(f'Successfully added {info.model_name}')
         model_raw = ApiDependencies.invoker.services.model_manager.list_model(
@@ -168,13 +174,14 @@ async def add_model(
         logger.error(str(e))
         raise HTTPException(status_code=409, detail=str(e))
 
 
 @models_router.post(
     "/rename/{base_model}/{model_type}/{model_name}",
     operation_id="rename_model",
-    responses= {
-        201: {"description" : "The model was renamed successfully"},
-        404: {"description" : "The model could not be found"},
-        409: {"description" : "There is already a model corresponding to the new name"},
+    responses={
+        201: {"description": "The model was renamed successfully"},
+        404: {"description": "The model could not be found"},
+        409: {"description": "There is already a model corresponding to the new name"},
     },
     status_code=201,
     response_model=ImportModelResponse
@@ -187,16 +194,16 @@ async def rename_model(
     new_base: Optional[BaseModelType] = Query(description="new model base", default=None),
 ) -> ImportModelResponse:
     """ Rename a model"""
 
     logger = ApiDependencies.invoker.services.logger
 
     try:
         result = ApiDependencies.invoker.services.model_manager.rename_model(
-            base_model = base_model,
-            model_type = model_type,
-            model_name = model_name,
-            new_name = new_name,
-            new_base = new_base,
+            base_model=base_model,
+            model_type=model_type,
+            model_name=model_name,
+            new_name=new_name,
+            new_base=new_base,
         )
         logger.debug(result)
         logger.info(f'Successfully renamed {model_name}=>{new_name}')
@@ -212,16 +219,17 @@ async def rename_model(
     except ValueError as e:
         logger.error(str(e))
         raise HTTPException(status_code=409, detail=str(e))
 
 
 @models_router.delete(
     "/{base_model}/{model_type}/{model_name}",
     operation_id="del_model",
     responses={
-        204: { "description": "Model deleted successfully" },
-        404: { "description": "Model not found" }
+        204: {"description": "Model deleted successfully"},
+        404: {"description": "Model not found"}
     },
-    status_code = 204,
-    response_model = None,
+    status_code=204,
+    response_model=None,
 )
 async def delete_model(
     base_model: BaseModelType = Path(description="Base model"),
@@ -230,142 +238,145 @@ async def delete_model(
 ) -> Response:
     """Delete Model"""
     logger = ApiDependencies.invoker.services.logger
 
     try:
-        ApiDependencies.invoker.services.model_manager.del_model(model_name,
-            base_model = base_model,
-            model_type = model_type
-        )
+        ApiDependencies.invoker.services.model_manager.del_model(
+            model_name, base_model=base_model, model_type=model_type)
         logger.info(f"Deleted model: {model_name}")
         return Response(status_code=204)
     except KeyError:
         logger.error(f"Model not found: {model_name}")
-        raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
+        raise HTTPException(
+            status_code=404, detail=f"Model '{model_name}' not found")
 
 
 @models_router.put(
     "/convert/{base_model}/{model_type}/{model_name}",
     operation_id="convert_model",
     responses={
-        200: { "description": "Model converted successfully" },
-        400: {"description" : "Bad request" },
-        404: { "description": "Model not found" },
+        200: {"description": "Model converted successfully"},
+        400: {"description": "Bad request"},
+        404: {"description": "Model not found"},
     },
-    status_code = 200,
-    response_model = ConvertModelResponse,
+    status_code=200,
+    response_model=ConvertModelResponse,
 )
 async def convert_model(
     base_model: BaseModelType = Path(description="Base model"),
     model_type: ModelType = Path(description="The type of model"),
     model_name: str = Path(description="model name"),
-    convert_dest_directory: Optional[str] = Query(default=None, description="Save the converted model to the designated directory"),
+    convert_dest_directory: Optional[str] = Body(description="Save the converted model to the designated directory", default=None, embed=True)
 ) -> ConvertModelResponse:
     """Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
     logger = ApiDependencies.invoker.services.logger
     try:
         logger.info(f"Converting model: {model_name}")
-        dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
-        ApiDependencies.invoker.services.model_manager.convert_model(model_name,
-            base_model = base_model,
-            model_type = model_type,
-            convert_dest_directory = dest,
-        )
-        model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name,
-            base_model = base_model,
-            model_type = model_type)
+        dest = pathlib.Path(
+            convert_dest_directory) if convert_dest_directory else None
+        ApiDependencies.invoker.services.model_manager.convert_model(
+            model_name, base_model=base_model, model_type=model_type,
+            convert_dest_directory=dest,)
+        model_raw = ApiDependencies.invoker.services.model_manager.list_model(
+            model_name, base_model=base_model, model_type=model_type)
         response = parse_obj_as(ConvertModelResponse, model_raw)
     except KeyError:
-        raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
+        raise HTTPException(
+            status_code=404, detail=f"Model '{model_name}' not found")
     except ValueError as e:
         raise HTTPException(status_code=400, detail=str(e))
     return response
 
 
 @models_router.get(
     "/search",
     operation_id="search_for_models",
     responses={
-        200: { "description": "Directory searched successfully" },
-        404: { "description": "Invalid directory path" },
+        200: {"description": "Directory searched successfully"},
+        404: {"description": "Invalid directory path"},
     },
-    status_code = 200,
-    response_model = List[pathlib.Path]
+    status_code=200,
+    response_model=List[pathlib.Path]
 )
 async def search_for_models(
     search_path: pathlib.Path = Query(description="Directory path to search for models")
-)->List[pathlib.Path]:
+) -> List[pathlib.Path]:
     if not search_path.is_dir():
-        raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory")
-    return ApiDependencies.invoker.services.model_manager.search_for_models([search_path])
+        raise HTTPException(
+            status_code=404,
+            detail=f"The search path '{search_path}' does not exist or is not directory")
+    return ApiDependencies.invoker.services.model_manager.search_for_models([
+        search_path])
 
 
 @models_router.get(
     "/ckpt_confs",
     operation_id="list_ckpt_configs",
     responses={
-        200: { "description" : "paths retrieved successfully" },
+        200: {"description": "paths retrieved successfully"},
     },
-    status_code = 200,
-    response_model = List[pathlib.Path]
+    status_code=200,
+    response_model=List[pathlib.Path]
 )
 async def list_ckpt_configs(
-)->List[pathlib.Path]:
+) -> List[pathlib.Path]:
     """Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT."""
     return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs()
 
 
 @models_router.get(
     "/sync",
     operation_id="sync_to_config",
     responses={
-        201: { "description": "synchronization successful" },
+        201: {"description": "synchronization successful"},
     },
-    status_code = 201,
-    response_model = None
+    status_code=201,
+    response_model=None
 )
 async def sync_to_config(
-)->None:
+) -> None:
     """Call after making changes to models.yaml, autoimport directories or models directory to synchronize
     in-memory data structures with disk data structures."""
     return ApiDependencies.invoker.services.model_manager.sync_to_config()
 
 
 @models_router.put(
     "/merge/{base_model}",
     operation_id="merge_models",
     responses={
-        200: { "description": "Model converted successfully" },
-        400: { "description": "Incompatible models" },
-        404: { "description": "One or more models not found" },
+        200: {"description": "Model converted successfully"},
+        400: {"description": "Incompatible models"},
+        404: {"description": "One or more models not found"},
     },
-    status_code = 200,
-    response_model = MergeModelResponse,
+    status_code=200,
+    response_model=MergeModelResponse,
 )
 async def merge_models(
     base_model: BaseModelType = Path(description="Base model"),
     model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
     merged_model_name: Optional[str] = Body(description="Name of destination model"),
     alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
     interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
     force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False),
     merge_dest_directory: Optional[str] = Body(description="Save the merged model to the designated directory (with 'merged_model_name' appended)", default=None)
 ) -> MergeModelResponse:
     """Convert a checkpoint model into a diffusers model"""
     logger = ApiDependencies.invoker.services.logger
     try:
-        logger.info(f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
-        dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
-        result = ApiDependencies.invoker.services.model_manager.merge_models(model_names,
-            base_model,
-            merged_model_name=merged_model_name or "+".join(model_names),
-            alpha=alpha,
-            interp=interp,
-            force=force,
-            merge_dest_directory = dest
-        )
-        model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name,
-            base_model = base_model,
-            model_type = ModelType.Main,
-        )
+        logger.info(
+            f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
+        dest = pathlib.Path(
+            merge_dest_directory) if merge_dest_directory else None
+        result = ApiDependencies.invoker.services.model_manager.merge_models(
+            model_names, base_model,
+            merged_model_name=merged_model_name or "+".join(model_names),
+            alpha=alpha, interp=interp, force=force, merge_dest_directory=dest)
+        model_raw = ApiDependencies.invoker.services.model_manager.list_model(
+            result.name, base_model=base_model, model_type=ModelType.Main, )
         response = parse_obj_as(ConvertModelResponse, model_raw)
     except KeyError:
-        raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
+        raise HTTPException(
+            status_code=404,
+            detail=f"One or more of the models '{model_names}' not found")
     except ValueError as e:
         raise HTTPException(status_code=400, detail=str(e))
     return response
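
The change that gives the commit its name is in the convert_model route above: convert_dest_directory moves from a query parameter to a request-body field declared with Body(..., embed=True), which is why the UI now sends a nested body object and the generated OpenAPI schema gains a Body_convert_model type (see the frontend hunks below). A minimal client sketch of the new request shape; the host, port, /api prefix, and model names are illustrative assumptions, not part of the commit:

import requests

# PUT /v1/models/convert/{base_model}/{model_type}/{model_name}
# With embed=True, FastAPI expects the single field wrapped in a JSON object
# keyed by the parameter name, rather than a bare string or query parameter.
resp = requests.put(
    "http://127.0.0.1:9090/api/v1/models/convert/sd-1/main/my-checkpoint-model",
    json={"convert_dest_directory": "/data/converted-models"},
    timeout=600,
)
resp.raise_for_status()
print(resp.json())  # configuration record of the converted diffusers model

Omitting convert_dest_directory (or sending null) keeps the previous behaviour of converting into the standard models directory, which is what the UI's "InvokeAIRoot" radio option does.
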
@@ -415,6 +415,8 @@
     "convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 2GB-7GB in size.",
     "convertToDiffusersHelpText6": "Do you wish to convert this model?",
     "convertToDiffusersSaveLocation": "Save Location",
+    "noCustomLocationProvided": "No Custom Location Provided",
+    "convertingModelBegin": "Converting Model. Please wait.",
     "v1": "v1",
     "v2_base": "v2 (512px)",
     "v2_768": "v2 (768px)",
@@ -1,9 +1,18 @@
-import { Flex, ListItem, Text, UnorderedList } from '@chakra-ui/react';
-// import { convertToDiffusers } from 'app/socketio/actions';
+import {
+  Flex,
+  ListItem,
+  Radio,
+  RadioGroup,
+  Text,
+  Tooltip,
+  UnorderedList,
+} from '@chakra-ui/react';
 import { makeToast } from 'app/components/Toaster';
+// import { convertToDiffusers } from 'app/socketio/actions';
 import { useAppDispatch } from 'app/store/storeHooks';
 import IAIAlertDialog from 'common/components/IAIAlertDialog';
 import IAIButton from 'common/components/IAIButton';
+import IAIInput from 'common/components/IAIInput';
 import { addToast } from 'features/system/store/systemSlice';
 import { useEffect, useState } from 'react';
 import { useTranslation } from 'react-i18next';
@@ -15,6 +24,8 @@ interface ModelConvertProps {
   model: CheckpointModelConfig;
 }
 
+type SaveLocation = 'InvokeAIRoot' | 'Custom';
+
 export default function ModelConvert(props: ModelConvertProps) {
   const { model } = props;
 
@@ -23,22 +34,51 @@ export default function ModelConvert(props: ModelConvertProps) {
 
   const [convertModel, { isLoading }] = useConvertMainModelsMutation();
 
-  const [saveLocation, setSaveLocation] = useState<string>('same');
+  const [saveLocation, setSaveLocation] =
+    useState<SaveLocation>('InvokeAIRoot');
   const [customSaveLocation, setCustomSaveLocation] = useState<string>('');
 
   useEffect(() => {
-    setSaveLocation('same');
+    setSaveLocation('InvokeAIRoot');
   }, [model]);
 
   const modelConvertCancelHandler = () => {
-    setSaveLocation('same');
+    setSaveLocation('InvokeAIRoot');
   };
 
   const modelConvertHandler = () => {
     const responseBody = {
       base_model: model.base_model,
       model_name: model.model_name,
+      body: {
+        convert_dest_directory:
+          saveLocation === 'Custom' ? customSaveLocation : undefined,
+      },
     };
 
+    if (saveLocation === 'Custom' && customSaveLocation === '') {
+      dispatch(
+        addToast(
+          makeToast({
+            title: t('modelManager.noCustomLocationProvided'),
+            status: 'error',
+          })
+        )
+      );
+      return;
+    }
+
+    dispatch(
+      addToast(
+        makeToast({
+          title: `${t('modelManager.convertingModelBegin')}: ${
+            model.model_name
+          }`,
+          status: 'success',
+        })
+      )
+    );
+
     convertModel(responseBody)
       .unwrap()
       .then((_) => {
@@ -94,35 +134,30 @@ export default function ModelConvert(props: ModelConvertProps) {
         <Text>{t('modelManager.convertToDiffusersHelpText6')}</Text>
       </Flex>
 
-      {/* <Flex flexDir="column" gap={4}>
+      <Flex flexDir="column" gap={2}>
         <Flex marginTop={4} flexDir="column" gap={2}>
           <Text fontWeight="600">
             {t('modelManager.convertToDiffusersSaveLocation')}
           </Text>
-          <RadioGroup value={saveLocation} onChange={(v) => setSaveLocation(v)}>
+          <RadioGroup
+            value={saveLocation}
+            onChange={(v) => setSaveLocation(v as SaveLocation)}
+          >
             <Flex gap={4}>
-              <Radio value="same">
-                <Tooltip label="Save converted model in the same folder">
-                  {t('modelManager.sameFolder')}
-                </Tooltip>
-              </Radio>
-
-              <Radio value="root">
+              <Radio value="InvokeAIRoot">
                 <Tooltip label="Save converted model in the InvokeAI root folder">
                   {t('modelManager.invokeRoot')}
                 </Tooltip>
               </Radio>
-              <Radio value="custom">
+              <Radio value="Custom">
                 <Tooltip label="Save converted model in a custom folder">
                   {t('modelManager.custom')}
                 </Tooltip>
               </Radio>
             </Flex>
           </RadioGroup>
-        </Flex> */}
-      {/* {saveLocation === 'custom' && (
+        </Flex>
+        {saveLocation === 'Custom' && (
           <Flex flexDirection="column" rowGap={2}>
             <Text fontWeight="500" fontSize="sm" variant="subtext">
               {t('modelManager.customSaveLocation')}
|
|||||||
<IAIInput
|
<IAIInput
|
||||||
value={customSaveLocation}
|
value={customSaveLocation}
|
||||||
onChange={(e) => {
|
onChange={(e) => {
|
||||||
if (e.target.value !== '')
|
setCustomSaveLocation(e.target.value);
|
||||||
setCustomSaveLocation(e.target.value);
|
|
||||||
}}
|
}}
|
||||||
width="full"
|
width="full"
|
||||||
/>
|
/>
|
||||||
</Flex>
|
</Flex>
|
||||||
)} */}
|
)}
|
||||||
|
</Flex>
|
||||||
</IAIAlertDialog>
|
</IAIAlertDialog>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@@ -5,6 +5,7 @@ import {
   BaseModelType,
   CheckpointModelConfig,
   ControlNetModelConfig,
+  ConvertModelConfig,
   DiffusersModelConfig,
   LoRAModelConfig,
   MainModelConfig,
@@ -62,6 +63,7 @@ type DeleteMainModelResponse = void;
 type ConvertMainModelArg = {
   base_model: BaseModelType;
   model_name: string;
+  body: ConvertModelConfig;
 };
 
 type ConvertMainModelResponse =
@@ -176,10 +178,11 @@ export const modelsApi = api.injectEndpoints({
       ConvertMainModelResponse,
       ConvertMainModelArg
     >({
-      query: ({ base_model, model_name }) => {
+      query: ({ base_model, model_name, body }) => {
        return {
          url: `models/convert/${base_model}/main/${model_name}`,
          method: 'PUT',
+          body: body,
        };
      },
      invalidatesTags: [{ type: 'MainModel', id: LIST_TAG }],
@@ -378,6 +378,14 @@ export type components = {
       */
      image_count: number;
    };
+    /** Body_convert_model */
+    Body_convert_model: {
+      /**
+       * Convert Dest Directory
+       * @description Save the converted model to the designated directory
+       */
+      convert_dest_directory?: string;
+    };
    /** Body_create_board_image */
    Body_create_board_image: {
      /**
@@ -5200,10 +5208,6 @@ export type operations = {
     */
  convert_model: {
    parameters: {
-      query?: {
-        /** @description Save the converted model to the designated directory */
-        convert_dest_directory?: string;
-      };
      path: {
        /** @description Base model */
        base_model: components["schemas"]["BaseModelType"];
@@ -5213,6 +5217,11 @@
        model_name: string;
      };
    };
+    requestBody?: {
+      content: {
+        "application/json": components["schemas"]["Body_convert_model"];
+      };
+    };
    responses: {
      /** @description Model converted successfully */
      200: {
@@ -55,7 +55,9 @@ export type AnyModelConfig =
   | ControlNetModelConfig
   | TextualInversionModelConfig
   | MainModelConfig;
 
 export type MergeModelConfig = components['schemas']['Body_merge_models'];
+export type ConvertModelConfig = components['schemas']['Body_convert_model'];
 
 // Graphs
 export type Graph = components['schemas']['Graph'];