mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Apply black
This commit is contained in:
@ -78,9 +78,7 @@ class ApiDependencies:
|
||||
image_record_storage = SqliteImageRecordStorage(db_location)
|
||||
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
|
||||
names = SimpleNameService()
|
||||
latents = ForwardCacheLatentsStorage(
|
||||
DiskLatentsStorage(f"{output_folder}/latents")
|
||||
)
|
||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
|
||||
|
||||
board_record_storage = SqliteBoardRecordStorage(db_location)
|
||||
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
|
||||
@ -125,9 +123,7 @@ class ApiDependencies:
|
||||
boards=boards,
|
||||
board_images=board_images,
|
||||
queue=MemoryInvocationQueue(),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](
|
||||
filename=db_location, table_name="graphs"
|
||||
),
|
||||
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
processor=DefaultInvocationProcessor(),
|
||||
configuration=config,
|
||||
|
@ -15,6 +15,7 @@ from invokeai.version import __version__
|
||||
from ..dependencies import ApiDependencies
|
||||
from invokeai.backend.util.logging import logging
|
||||
|
||||
|
||||
class LogLevel(int, Enum):
|
||||
NotSet = logging.NOTSET
|
||||
Debug = logging.DEBUG
|
||||
@ -23,10 +24,12 @@ class LogLevel(int, Enum):
|
||||
Error = logging.ERROR
|
||||
Critical = logging.CRITICAL
|
||||
|
||||
|
||||
class Upscaler(BaseModel):
|
||||
upscaling_method: str = Field(description="Name of upscaling method")
|
||||
upscaling_models: list[str] = Field(description="List of upscaling models for this method")
|
||||
|
||||
|
||||
|
||||
app_router = APIRouter(prefix="/v1/app", tags=["app"])
|
||||
|
||||
|
||||
@ -45,38 +48,30 @@ class AppConfig(BaseModel):
|
||||
watermarking_methods: list[str] = Field(description="List of invisible watermark methods")
|
||||
|
||||
|
||||
@app_router.get(
|
||||
"/version", operation_id="app_version", status_code=200, response_model=AppVersion
|
||||
)
|
||||
@app_router.get("/version", operation_id="app_version", status_code=200, response_model=AppVersion)
|
||||
async def get_version() -> AppVersion:
|
||||
return AppVersion(version=__version__)
|
||||
|
||||
|
||||
@app_router.get(
|
||||
"/config", operation_id="get_config", status_code=200, response_model=AppConfig
|
||||
)
|
||||
@app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig)
|
||||
async def get_config() -> AppConfig:
|
||||
infill_methods = ['tile']
|
||||
infill_methods = ["tile"]
|
||||
if PatchMatch.patchmatch_available():
|
||||
infill_methods.append('patchmatch')
|
||||
|
||||
infill_methods.append("patchmatch")
|
||||
|
||||
upscaling_models = []
|
||||
for model in typing.get_args(ESRGAN_MODELS):
|
||||
upscaling_models.append(str(Path(model).stem))
|
||||
upscaler = Upscaler(
|
||||
upscaling_method = 'esrgan',
|
||||
upscaling_models = upscaling_models
|
||||
)
|
||||
|
||||
upscaler = Upscaler(upscaling_method="esrgan", upscaling_models=upscaling_models)
|
||||
|
||||
nsfw_methods = []
|
||||
if SafetyChecker.safety_checker_available():
|
||||
nsfw_methods.append('nsfw_checker')
|
||||
nsfw_methods.append("nsfw_checker")
|
||||
|
||||
watermarking_methods = []
|
||||
if InvisibleWatermark.invisible_watermark_available():
|
||||
watermarking_methods.append('invisible_watermark')
|
||||
|
||||
watermarking_methods.append("invisible_watermark")
|
||||
|
||||
return AppConfig(
|
||||
infill_methods=infill_methods,
|
||||
upscaling_methods=[upscaler],
|
||||
@ -84,25 +79,26 @@ async def get_config() -> AppConfig:
|
||||
watermarking_methods=watermarking_methods,
|
||||
)
|
||||
|
||||
|
||||
@app_router.get(
|
||||
"/logging",
|
||||
operation_id="get_log_level",
|
||||
responses={200: {"description" : "The operation was successful"}},
|
||||
response_model = LogLevel,
|
||||
responses={200: {"description": "The operation was successful"}},
|
||||
response_model=LogLevel,
|
||||
)
|
||||
async def get_log_level(
|
||||
) -> LogLevel:
|
||||
async def get_log_level() -> LogLevel:
|
||||
"""Returns the log level"""
|
||||
return LogLevel(ApiDependencies.invoker.services.logger.level)
|
||||
|
||||
|
||||
@app_router.post(
|
||||
"/logging",
|
||||
operation_id="set_log_level",
|
||||
responses={200: {"description" : "The operation was successful"}},
|
||||
response_model = LogLevel,
|
||||
responses={200: {"description": "The operation was successful"}},
|
||||
response_model=LogLevel,
|
||||
)
|
||||
async def set_log_level(
|
||||
level: LogLevel = Body(description="New log verbosity level"),
|
||||
level: LogLevel = Body(description="New log verbosity level"),
|
||||
) -> LogLevel:
|
||||
"""Sets the log verbosity level"""
|
||||
ApiDependencies.invoker.services.logger.setLevel(level)
|
||||
|
@ -52,4 +52,3 @@ async def remove_board_image(
|
||||
return result
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Failed to update board")
|
||||
|
||||
|
@ -18,9 +18,7 @@ class DeleteBoardResult(BaseModel):
|
||||
deleted_board_images: list[str] = Field(
|
||||
description="The image names of the board-images relationships that were deleted."
|
||||
)
|
||||
deleted_images: list[str] = Field(
|
||||
description="The names of the images that were deleted."
|
||||
)
|
||||
deleted_images: list[str] = Field(description="The names of the images that were deleted.")
|
||||
|
||||
|
||||
@boards_router.post(
|
||||
@ -73,22 +71,16 @@ async def update_board(
|
||||
) -> BoardDTO:
|
||||
"""Updates a board"""
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.boards.update(
|
||||
board_id=board_id, changes=changes
|
||||
)
|
||||
result = ApiDependencies.invoker.services.boards.update(board_id=board_id, changes=changes)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail="Failed to update board")
|
||||
|
||||
|
||||
@boards_router.delete(
|
||||
"/{board_id}", operation_id="delete_board", response_model=DeleteBoardResult
|
||||
)
|
||||
@boards_router.delete("/{board_id}", operation_id="delete_board", response_model=DeleteBoardResult)
|
||||
async def delete_board(
|
||||
board_id: str = Path(description="The id of board to delete"),
|
||||
include_images: Optional[bool] = Query(
|
||||
description="Permanently delete all images on the board", default=False
|
||||
),
|
||||
include_images: Optional[bool] = Query(description="Permanently delete all images on the board", default=False),
|
||||
) -> DeleteBoardResult:
|
||||
"""Deletes a board"""
|
||||
try:
|
||||
@ -96,9 +88,7 @@ async def delete_board(
|
||||
deleted_images = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
|
||||
board_id=board_id
|
||||
)
|
||||
ApiDependencies.invoker.services.images.delete_images_on_board(
|
||||
board_id=board_id
|
||||
)
|
||||
ApiDependencies.invoker.services.images.delete_images_on_board(board_id=board_id)
|
||||
ApiDependencies.invoker.services.boards.delete(board_id=board_id)
|
||||
return DeleteBoardResult(
|
||||
board_id=board_id,
|
||||
@ -127,9 +117,7 @@ async def delete_board(
|
||||
async def list_boards(
|
||||
all: Optional[bool] = Query(default=None, description="Whether to list all boards"),
|
||||
offset: Optional[int] = Query(default=None, description="The page offset"),
|
||||
limit: Optional[int] = Query(
|
||||
default=None, description="The number of boards per page"
|
||||
),
|
||||
limit: Optional[int] = Query(default=None, description="The number of boards per page"),
|
||||
) -> Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]]:
|
||||
"""Gets a list of boards"""
|
||||
if all:
|
||||
|
@ -40,15 +40,9 @@ async def upload_image(
|
||||
response: Response,
|
||||
image_category: ImageCategory = Query(description="The category of the image"),
|
||||
is_intermediate: bool = Query(description="Whether this is an intermediate image"),
|
||||
board_id: Optional[str] = Query(
|
||||
default=None, description="The board to add this image to, if any"
|
||||
),
|
||||
session_id: Optional[str] = Query(
|
||||
default=None, description="The session ID associated with this upload, if any"
|
||||
),
|
||||
crop_visible: Optional[bool] = Query(
|
||||
default=False, description="Whether to crop the image"
|
||||
),
|
||||
board_id: Optional[str] = Query(default=None, description="The board to add this image to, if any"),
|
||||
session_id: Optional[str] = Query(default=None, description="The session ID associated with this upload, if any"),
|
||||
crop_visible: Optional[bool] = Query(default=False, description="Whether to crop the image"),
|
||||
) -> ImageDTO:
|
||||
"""Uploads an image"""
|
||||
if not file.content_type.startswith("image"):
|
||||
@ -115,9 +109,7 @@ async def clear_intermediates() -> int:
|
||||
)
|
||||
async def update_image(
|
||||
image_name: str = Path(description="The name of the image to update"),
|
||||
image_changes: ImageRecordChanges = Body(
|
||||
description="The changes to apply to the image"
|
||||
),
|
||||
image_changes: ImageRecordChanges = Body(description="The changes to apply to the image"),
|
||||
) -> ImageDTO:
|
||||
"""Updates an image"""
|
||||
|
||||
@ -212,15 +204,11 @@ async def get_image_thumbnail(
|
||||
"""Gets a thumbnail image file"""
|
||||
|
||||
try:
|
||||
path = ApiDependencies.invoker.services.images.get_path(
|
||||
image_name, thumbnail=True
|
||||
)
|
||||
path = ApiDependencies.invoker.services.images.get_path(image_name, thumbnail=True)
|
||||
if not ApiDependencies.invoker.services.images.validate_path(path):
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
response = FileResponse(
|
||||
path, media_type="image/webp", content_disposition_type="inline"
|
||||
)
|
||||
response = FileResponse(path, media_type="image/webp", content_disposition_type="inline")
|
||||
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
||||
return response
|
||||
except Exception as e:
|
||||
@ -239,9 +227,7 @@ async def get_image_urls(
|
||||
|
||||
try:
|
||||
image_url = ApiDependencies.invoker.services.images.get_url(image_name)
|
||||
thumbnail_url = ApiDependencies.invoker.services.images.get_url(
|
||||
image_name, thumbnail=True
|
||||
)
|
||||
thumbnail_url = ApiDependencies.invoker.services.images.get_url(image_name, thumbnail=True)
|
||||
return ImageUrlsDTO(
|
||||
image_name=image_name,
|
||||
image_url=image_url,
|
||||
@ -257,15 +243,9 @@ async def get_image_urls(
|
||||
response_model=OffsetPaginatedResults[ImageDTO],
|
||||
)
|
||||
async def list_image_dtos(
|
||||
image_origin: Optional[ResourceOrigin] = Query(
|
||||
default=None, description="The origin of images to list."
|
||||
),
|
||||
categories: Optional[list[ImageCategory]] = Query(
|
||||
default=None, description="The categories of image to include."
|
||||
),
|
||||
is_intermediate: Optional[bool] = Query(
|
||||
default=None, description="Whether to list intermediate images."
|
||||
),
|
||||
image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to list."),
|
||||
categories: Optional[list[ImageCategory]] = Query(default=None, description="The categories of image to include."),
|
||||
is_intermediate: Optional[bool] = Query(default=None, description="Whether to list intermediate images."),
|
||||
board_id: Optional[str] = Query(
|
||||
default=None,
|
||||
description="The board id to filter by. Use 'none' to find images without a board.",
|
||||
|
@ -28,49 +28,52 @@ ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
|
||||
|
||||
|
||||
@models_router.get(
|
||||
"/",
|
||||
operation_id="list_models",
|
||||
responses={200: {"model": ModelsList }},
|
||||
responses={200: {"model": ModelsList}},
|
||||
)
|
||||
async def list_models(
|
||||
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
|
||||
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
|
||||
) -> ModelsList:
|
||||
"""Gets a list of models"""
|
||||
if base_models and len(base_models)>0:
|
||||
if base_models and len(base_models) > 0:
|
||||
models_raw = list()
|
||||
for base_model in base_models:
|
||||
models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
|
||||
else:
|
||||
models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type)
|
||||
models = parse_obj_as(ModelsList, { "models": models_raw })
|
||||
models = parse_obj_as(ModelsList, {"models": models_raw})
|
||||
return models
|
||||
|
||||
|
||||
@models_router.patch(
|
||||
"/{base_model}/{model_type}/{model_name}",
|
||||
operation_id="update_model",
|
||||
responses={200: {"description" : "The model was updated successfully"},
|
||||
400: {"description" : "Bad request"},
|
||||
404: {"description" : "The model could not be found"},
|
||||
409: {"description" : "There is already a model corresponding to the new name"},
|
||||
},
|
||||
status_code = 200,
|
||||
response_model = UpdateModelResponse,
|
||||
responses={
|
||||
200: {"description": "The model was updated successfully"},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "The model could not be found"},
|
||||
409: {"description": "There is already a model corresponding to the new name"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=UpdateModelResponse,
|
||||
)
|
||||
async def update_model(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="model name"),
|
||||
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="model name"),
|
||||
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
||||
) -> UpdateModelResponse:
|
||||
""" Update model contents with a new config. If the model name or base fields are changed, then the model is renamed. """
|
||||
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
|
||||
try:
|
||||
previous_info = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=model_name,
|
||||
@ -81,13 +84,13 @@ async def update_model(
|
||||
# rename operation requested
|
||||
if info.model_name != model_name or info.base_model != base_model:
|
||||
ApiDependencies.invoker.services.model_manager.rename_model(
|
||||
base_model = base_model,
|
||||
model_type = model_type,
|
||||
model_name = model_name,
|
||||
new_name = info.model_name,
|
||||
new_base = info.base_model,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
model_name=model_name,
|
||||
new_name=info.model_name,
|
||||
new_base=info.base_model,
|
||||
)
|
||||
logger.info(f'Successfully renamed {base_model}/{model_name}=>{info.base_model}/{info.model_name}')
|
||||
logger.info(f"Successfully renamed {base_model}/{model_name}=>{info.base_model}/{info.model_name}")
|
||||
# update information to support an update of attributes
|
||||
model_name = info.model_name
|
||||
base_model = info.base_model
|
||||
@ -96,16 +99,15 @@ async def update_model(
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
)
|
||||
if new_info.get('path') != previous_info.get('path'): # model manager moved model path during rename - don't overwrite it
|
||||
info.path = new_info.get('path')
|
||||
|
||||
if new_info.get("path") != previous_info.get(
|
||||
"path"
|
||||
): # model manager moved model path during rename - don't overwrite it
|
||||
info.path = new_info.get("path")
|
||||
|
||||
ApiDependencies.invoker.services.model_manager.update_model(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
model_attributes=info.dict()
|
||||
model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info.dict()
|
||||
)
|
||||
|
||||
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
@ -123,49 +125,48 @@ async def update_model(
|
||||
|
||||
return model_response
|
||||
|
||||
|
||||
@models_router.post(
|
||||
"/import",
|
||||
operation_id="import_model",
|
||||
responses= {
|
||||
201: {"description" : "The model imported successfully"},
|
||||
404: {"description" : "The model could not be found"},
|
||||
415: {"description" : "Unrecognized file/folder format"},
|
||||
424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"},
|
||||
409: {"description" : "There is already a model corresponding to this path or repo_id"},
|
||||
responses={
|
||||
201: {"description": "The model imported successfully"},
|
||||
404: {"description": "The model could not be found"},
|
||||
415: {"description": "Unrecognized file/folder format"},
|
||||
424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=ImportModelResponse
|
||||
response_model=ImportModelResponse,
|
||||
)
|
||||
async def import_model(
|
||||
location: str = Body(description="A model path, repo_id or URL to import"),
|
||||
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
|
||||
Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
|
||||
location: str = Body(description="A model path, repo_id or URL to import"),
|
||||
prediction_type: Optional[Literal["v_prediction", "epsilon", "sample"]] = Body(
|
||||
description="Prediction type for SDv2 checkpoint files", default="v_prediction"
|
||||
),
|
||||
) -> ImportModelResponse:
|
||||
""" Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically """
|
||||
|
||||
"""Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically"""
|
||||
|
||||
items_to_import = {location}
|
||||
prediction_types = { x.value: x for x in SchedulerPredictionType }
|
||||
prediction_types = {x.value: x for x in SchedulerPredictionType}
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
||||
items_to_import = items_to_import,
|
||||
prediction_type_helper = lambda x: prediction_types.get(prediction_type)
|
||||
items_to_import=items_to_import, prediction_type_helper=lambda x: prediction_types.get(prediction_type)
|
||||
)
|
||||
info = installed_models.get(location)
|
||||
|
||||
if not info:
|
||||
logger.error("Import failed")
|
||||
raise HTTPException(status_code=415)
|
||||
|
||||
logger.info(f'Successfully imported {location}, got {info}')
|
||||
|
||||
logger.info(f"Successfully imported {location}, got {info}")
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=info.name,
|
||||
base_model=info.base_model,
|
||||
model_type=info.model_type
|
||||
model_name=info.name, base_model=info.base_model, model_type=info.model_type
|
||||
)
|
||||
return parse_obj_as(ImportModelResponse, model_raw)
|
||||
|
||||
|
||||
except ModelNotFoundException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
@ -175,38 +176,34 @@ async def import_model(
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
|
||||
|
||||
@models_router.post(
|
||||
"/add",
|
||||
operation_id="add_model",
|
||||
responses= {
|
||||
201: {"description" : "The model added successfully"},
|
||||
404: {"description" : "The model could not be found"},
|
||||
424: {"description" : "The model appeared to add successfully, but could not be found in the model manager"},
|
||||
409: {"description" : "There is already a model corresponding to this path or repo_id"},
|
||||
responses={
|
||||
201: {"description": "The model added successfully"},
|
||||
404: {"description": "The model could not be found"},
|
||||
424: {"description": "The model appeared to add successfully, but could not be found in the model manager"},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=ImportModelResponse
|
||||
response_model=ImportModelResponse,
|
||||
)
|
||||
async def add_model(
|
||||
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
||||
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
||||
) -> ImportModelResponse:
|
||||
""" Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
|
||||
|
||||
"""Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
|
||||
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.model_manager.add_model(
|
||||
info.model_name,
|
||||
info.base_model,
|
||||
info.model_type,
|
||||
model_attributes = info.dict()
|
||||
info.model_name, info.base_model, info.model_type, model_attributes=info.dict()
|
||||
)
|
||||
logger.info(f'Successfully added {info.model_name}')
|
||||
logger.info(f"Successfully added {info.model_name}")
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=info.model_name,
|
||||
base_model=info.base_model,
|
||||
model_type=info.model_type
|
||||
model_name=info.model_name, base_model=info.base_model, model_type=info.model_type
|
||||
)
|
||||
return parse_obj_as(ImportModelResponse, model_raw)
|
||||
except ModelNotFoundException as e:
|
||||
@ -216,66 +213,66 @@ async def add_model(
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
|
||||
|
||||
@models_router.delete(
|
||||
"/{base_model}/{model_type}/{model_name}",
|
||||
operation_id="del_model",
|
||||
responses={
|
||||
204: { "description": "Model deleted successfully" },
|
||||
404: { "description": "Model not found" }
|
||||
},
|
||||
status_code = 204,
|
||||
response_model = None,
|
||||
responses={204: {"description": "Model deleted successfully"}, 404: {"description": "Model not found"}},
|
||||
status_code=204,
|
||||
response_model=None,
|
||||
)
|
||||
async def delete_model(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="model name"),
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="model name"),
|
||||
) -> Response:
|
||||
"""Delete Model"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.model_manager.del_model(model_name,
|
||||
base_model = base_model,
|
||||
model_type = model_type
|
||||
)
|
||||
ApiDependencies.invoker.services.model_manager.del_model(
|
||||
model_name, base_model=base_model, model_type=model_type
|
||||
)
|
||||
logger.info(f"Deleted model: {model_name}")
|
||||
return Response(status_code=204)
|
||||
except ModelNotFoundException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@models_router.put(
|
||||
"/convert/{base_model}/{model_type}/{model_name}",
|
||||
operation_id="convert_model",
|
||||
responses={
|
||||
200: { "description": "Model converted successfully" },
|
||||
400: {"description" : "Bad request" },
|
||||
404: { "description": "Model not found" },
|
||||
200: {"description": "Model converted successfully"},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "Model not found"},
|
||||
},
|
||||
status_code = 200,
|
||||
response_model = ConvertModelResponse,
|
||||
status_code=200,
|
||||
response_model=ConvertModelResponse,
|
||||
)
|
||||
async def convert_model(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="model name"),
|
||||
convert_dest_directory: Optional[str] = Query(default=None, description="Save the converted model to the designated directory"),
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="model name"),
|
||||
convert_dest_directory: Optional[str] = Query(
|
||||
default=None, description="Save the converted model to the designated directory"
|
||||
),
|
||||
) -> ConvertModelResponse:
|
||||
"""Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
logger.info(f"Converting model: {model_name}")
|
||||
dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
|
||||
ApiDependencies.invoker.services.model_manager.convert_model(model_name,
|
||||
base_model = base_model,
|
||||
model_type = model_type,
|
||||
convert_dest_directory = dest,
|
||||
)
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name,
|
||||
base_model = base_model,
|
||||
model_type = model_type)
|
||||
ApiDependencies.invoker.services.model_manager.convert_model(
|
||||
model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
convert_dest_directory=dest,
|
||||
)
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name, base_model=base_model, model_type=model_type
|
||||
)
|
||||
response = parse_obj_as(ConvertModelResponse, model_raw)
|
||||
except ModelNotFoundException as e:
|
||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
|
||||
@ -283,91 +280,101 @@ async def convert_model(
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return response
|
||||
|
||||
|
||||
@models_router.get(
|
||||
"/search",
|
||||
operation_id="search_for_models",
|
||||
responses={
|
||||
200: { "description": "Directory searched successfully" },
|
||||
404: { "description": "Invalid directory path" },
|
||||
200: {"description": "Directory searched successfully"},
|
||||
404: {"description": "Invalid directory path"},
|
||||
},
|
||||
status_code = 200,
|
||||
response_model = List[pathlib.Path]
|
||||
status_code=200,
|
||||
response_model=List[pathlib.Path],
|
||||
)
|
||||
async def search_for_models(
|
||||
search_path: pathlib.Path = Query(description="Directory path to search for models")
|
||||
)->List[pathlib.Path]:
|
||||
search_path: pathlib.Path = Query(description="Directory path to search for models"),
|
||||
) -> List[pathlib.Path]:
|
||||
if not search_path.is_dir():
|
||||
raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory")
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory"
|
||||
)
|
||||
return ApiDependencies.invoker.services.model_manager.search_for_models(search_path)
|
||||
|
||||
|
||||
@models_router.get(
|
||||
"/ckpt_confs",
|
||||
operation_id="list_ckpt_configs",
|
||||
responses={
|
||||
200: { "description" : "paths retrieved successfully" },
|
||||
200: {"description": "paths retrieved successfully"},
|
||||
},
|
||||
status_code = 200,
|
||||
response_model = List[pathlib.Path]
|
||||
status_code=200,
|
||||
response_model=List[pathlib.Path],
|
||||
)
|
||||
async def list_ckpt_configs(
|
||||
)->List[pathlib.Path]:
|
||||
async def list_ckpt_configs() -> List[pathlib.Path]:
|
||||
"""Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT."""
|
||||
return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs()
|
||||
|
||||
|
||||
|
||||
|
||||
@models_router.post(
|
||||
"/sync",
|
||||
operation_id="sync_to_config",
|
||||
responses={
|
||||
201: { "description": "synchronization successful" },
|
||||
201: {"description": "synchronization successful"},
|
||||
},
|
||||
status_code = 201,
|
||||
response_model = bool
|
||||
status_code=201,
|
||||
response_model=bool,
|
||||
)
|
||||
async def sync_to_config(
|
||||
)->bool:
|
||||
async def sync_to_config() -> bool:
|
||||
"""Call after making changes to models.yaml, autoimport directories or models directory to synchronize
|
||||
in-memory data structures with disk data structures."""
|
||||
ApiDependencies.invoker.services.model_manager.sync_to_config()
|
||||
return True
|
||||
|
||||
|
||||
|
||||
@models_router.put(
|
||||
"/merge/{base_model}",
|
||||
operation_id="merge_models",
|
||||
responses={
|
||||
200: { "description": "Model converted successfully" },
|
||||
400: { "description": "Incompatible models" },
|
||||
404: { "description": "One or more models not found" },
|
||||
200: {"description": "Model converted successfully"},
|
||||
400: {"description": "Incompatible models"},
|
||||
404: {"description": "One or more models not found"},
|
||||
},
|
||||
status_code = 200,
|
||||
response_model = MergeModelResponse,
|
||||
status_code=200,
|
||||
response_model=MergeModelResponse,
|
||||
)
|
||||
async def merge_models(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
|
||||
merged_model_name: Optional[str] = Body(description="Name of destination model"),
|
||||
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
||||
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
|
||||
force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False),
|
||||
merge_dest_directory: Optional[str] = Body(description="Save the merged model to the designated directory (with 'merged_model_name' appended)", default=None)
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
|
||||
merged_model_name: Optional[str] = Body(description="Name of destination model"),
|
||||
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
||||
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
|
||||
force: Optional[bool] = Body(
|
||||
description="Force merging of models created with different versions of diffusers", default=False
|
||||
),
|
||||
merge_dest_directory: Optional[str] = Body(
|
||||
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
|
||||
default=None,
|
||||
),
|
||||
) -> MergeModelResponse:
|
||||
"""Convert a checkpoint model into a diffusers model"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
logger.info(f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
||||
result = ApiDependencies.invoker.services.model_manager.merge_models(model_names,
|
||||
base_model,
|
||||
merged_model_name=merged_model_name or "+".join(model_names),
|
||||
alpha=alpha,
|
||||
interp=interp,
|
||||
force=force,
|
||||
merge_dest_directory = dest
|
||||
)
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name,
|
||||
base_model = base_model,
|
||||
model_type = ModelType.Main,
|
||||
)
|
||||
result = ApiDependencies.invoker.services.model_manager.merge_models(
|
||||
model_names,
|
||||
base_model,
|
||||
merged_model_name=merged_model_name or "+".join(model_names),
|
||||
alpha=alpha,
|
||||
interp=interp,
|
||||
force=force,
|
||||
merge_dest_directory=dest,
|
||||
)
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
result.name,
|
||||
base_model=base_model,
|
||||
model_type=ModelType.Main,
|
||||
)
|
||||
response = parse_obj_as(ConvertModelResponse, model_raw)
|
||||
except ModelNotFoundException:
|
||||
raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
|
||||
|
@ -30,9 +30,7 @@ session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
|
||||
},
|
||||
)
|
||||
async def create_session(
|
||||
graph: Optional[Graph] = Body(
|
||||
default=None, description="The graph to initialize the session with"
|
||||
)
|
||||
graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with")
|
||||
) -> GraphExecutionState:
|
||||
"""Creates a new session, optionally initializing it with an invocation graph"""
|
||||
session = ApiDependencies.invoker.create_execution_state(graph)
|
||||
@ -51,13 +49,9 @@ async def list_sessions(
|
||||
) -> PaginatedResults[GraphExecutionState]:
|
||||
"""Gets a list of sessions, optionally searching"""
|
||||
if query == "":
|
||||
result = ApiDependencies.invoker.services.graph_execution_manager.list(
|
||||
page, per_page
|
||||
)
|
||||
result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page)
|
||||
else:
|
||||
result = ApiDependencies.invoker.services.graph_execution_manager.search(
|
||||
query, page, per_page
|
||||
)
|
||||
result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page)
|
||||
return result
|
||||
|
||||
|
||||
@ -91,9 +85,9 @@ async def get_session(
|
||||
)
|
||||
async def add_node(
|
||||
session_id: str = Path(description="The id of the session"),
|
||||
node: Annotated[
|
||||
Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore
|
||||
] = Body(description="The node to add"),
|
||||
node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore
|
||||
description="The node to add"
|
||||
),
|
||||
) -> str:
|
||||
"""Adds a node to the graph"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
@ -124,9 +118,9 @@ async def add_node(
|
||||
async def update_node(
|
||||
session_id: str = Path(description="The id of the session"),
|
||||
node_path: str = Path(description="The path to the node in the graph"),
|
||||
node: Annotated[
|
||||
Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore
|
||||
] = Body(description="The new node"),
|
||||
node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore
|
||||
description="The new node"
|
||||
),
|
||||
) -> GraphExecutionState:
|
||||
"""Updates a node in the graph and removes all linked edges"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
@ -230,7 +224,7 @@ async def delete_edge(
|
||||
try:
|
||||
edge = Edge(
|
||||
source=EdgeConnection(node_id=from_node_id, field=from_field),
|
||||
destination=EdgeConnection(node_id=to_node_id, field=to_field)
|
||||
destination=EdgeConnection(node_id=to_node_id, field=to_field),
|
||||
)
|
||||
session.delete_edge(edge)
|
||||
ApiDependencies.invoker.services.graph_execution_manager.set(
|
||||
@ -255,9 +249,7 @@ async def delete_edge(
|
||||
)
|
||||
async def invoke_session(
|
||||
session_id: str = Path(description="The id of the session to invoke"),
|
||||
all: bool = Query(
|
||||
default=False, description="Whether or not to invoke all remaining invocations"
|
||||
),
|
||||
all: bool = Query(default=False, description="Whether or not to invoke all remaining invocations"),
|
||||
) -> Response:
|
||||
"""Invokes a session"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
@ -274,9 +266,7 @@ async def invoke_session(
|
||||
@session_router.delete(
|
||||
"/{session_id}/invoke",
|
||||
operation_id="cancel_session_invoke",
|
||||
responses={
|
||||
202: {"description": "The invocation is canceled"}
|
||||
},
|
||||
responses={202: {"description": "The invocation is canceled"}},
|
||||
)
|
||||
async def cancel_session_invoke(
|
||||
session_id: str = Path(description="The id of the session to cancel"),
|
||||
|
@ -16,9 +16,7 @@ class SocketIO:
|
||||
self.__sio.on("subscribe", handler=self._handle_sub)
|
||||
self.__sio.on("unsubscribe", handler=self._handle_unsub)
|
||||
|
||||
local_handler.register(
|
||||
event_name=EventServiceBase.session_event, _func=self._handle_session_event
|
||||
)
|
||||
local_handler.register(event_name=EventServiceBase.session_event, _func=self._handle_session_event)
|
||||
|
||||
async def _handle_session_event(self, event: Event):
|
||||
await self.__sio.emit(
|
||||
|
Reference in New Issue
Block a user