mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
reformat with black
This commit is contained in:
parent
00988e4972
commit
fd75a1dd10
@ -28,49 +28,52 @@ ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
|||||||
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||||
|
|
||||||
|
|
||||||
class ModelsList(BaseModel):
|
class ModelsList(BaseModel):
|
||||||
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
|
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
|
||||||
|
|
||||||
|
|
||||||
@models_router.get(
|
@models_router.get(
|
||||||
"/",
|
"/",
|
||||||
operation_id="list_models",
|
operation_id="list_models",
|
||||||
responses={200: {"model": ModelsList }},
|
responses={200: {"model": ModelsList}},
|
||||||
)
|
)
|
||||||
async def list_models(
|
async def list_models(
|
||||||
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
|
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
|
||||||
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
|
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
|
||||||
) -> ModelsList:
|
) -> ModelsList:
|
||||||
"""Gets a list of models"""
|
"""Gets a list of models"""
|
||||||
if base_models and len(base_models)>0:
|
if base_models and len(base_models) > 0:
|
||||||
models_raw = list()
|
models_raw = list()
|
||||||
for base_model in base_models:
|
for base_model in base_models:
|
||||||
models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
|
models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
|
||||||
else:
|
else:
|
||||||
models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type)
|
models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type)
|
||||||
models = parse_obj_as(ModelsList, { "models": models_raw })
|
models = parse_obj_as(ModelsList, {"models": models_raw})
|
||||||
return models
|
return models
|
||||||
|
|
||||||
|
|
||||||
@models_router.patch(
|
@models_router.patch(
|
||||||
"/{base_model}/{model_type}/{model_name}",
|
"/{base_model}/{model_type}/{model_name}",
|
||||||
operation_id="update_model",
|
operation_id="update_model",
|
||||||
responses={200: {"description" : "The model was updated successfully"},
|
responses={
|
||||||
400: {"description" : "Bad request"},
|
200: {"description": "The model was updated successfully"},
|
||||||
404: {"description" : "The model could not be found"},
|
400: {"description": "Bad request"},
|
||||||
409: {"description" : "There is already a model corresponding to the new name"},
|
404: {"description": "The model could not be found"},
|
||||||
},
|
409: {"description": "There is already a model corresponding to the new name"},
|
||||||
status_code = 200,
|
},
|
||||||
response_model = UpdateModelResponse,
|
status_code=200,
|
||||||
|
response_model=UpdateModelResponse,
|
||||||
)
|
)
|
||||||
async def update_model(
|
async def update_model(
|
||||||
base_model: BaseModelType = Path(description="Base model"),
|
base_model: BaseModelType = Path(description="Base model"),
|
||||||
model_type: ModelType = Path(description="The type of model"),
|
model_type: ModelType = Path(description="The type of model"),
|
||||||
model_name: str = Path(description="model name"),
|
model_name: str = Path(description="model name"),
|
||||||
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
||||||
) -> UpdateModelResponse:
|
) -> UpdateModelResponse:
|
||||||
""" Update model contents with a new config. If the model name or base fields are changed, then the model is renamed. """
|
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
previous_info = ApiDependencies.invoker.services.model_manager.list_model(
|
previous_info = ApiDependencies.invoker.services.model_manager.list_model(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
@ -81,13 +84,13 @@ async def update_model(
|
|||||||
# rename operation requested
|
# rename operation requested
|
||||||
if info.model_name != model_name or info.base_model != base_model:
|
if info.model_name != model_name or info.base_model != base_model:
|
||||||
ApiDependencies.invoker.services.model_manager.rename_model(
|
ApiDependencies.invoker.services.model_manager.rename_model(
|
||||||
base_model = base_model,
|
base_model=base_model,
|
||||||
model_type = model_type,
|
model_type=model_type,
|
||||||
model_name = model_name,
|
model_name=model_name,
|
||||||
new_name = info.model_name,
|
new_name=info.model_name,
|
||||||
new_base = info.base_model,
|
new_base=info.base_model,
|
||||||
)
|
)
|
||||||
logger.info(f'Successfully renamed {base_model.value}/{model_name}=>{info.base_model}/{info.model_name}')
|
logger.info(f"Successfully renamed {base_model.value}/{model_name}=>{info.base_model}/{info.model_name}")
|
||||||
# update information to support an update of attributes
|
# update information to support an update of attributes
|
||||||
model_name = info.model_name
|
model_name = info.model_name
|
||||||
base_model = info.base_model
|
base_model = info.base_model
|
||||||
@ -96,14 +99,13 @@ async def update_model(
|
|||||||
base_model=base_model,
|
base_model=base_model,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
)
|
)
|
||||||
if new_info.get('path') != previous_info.get('path'): # model manager moved model path during rename - don't overwrite it
|
if new_info.get("path") != previous_info.get(
|
||||||
info.path = new_info.get('path')
|
"path"
|
||||||
|
): # model manager moved model path during rename - don't overwrite it
|
||||||
|
info.path = new_info.get("path")
|
||||||
|
|
||||||
ApiDependencies.invoker.services.model_manager.update_model(
|
ApiDependencies.invoker.services.model_manager.update_model(
|
||||||
model_name=model_name,
|
model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info.dict()
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
model_attributes=info.dict()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||||
@ -123,34 +125,35 @@ async def update_model(
|
|||||||
|
|
||||||
return model_response
|
return model_response
|
||||||
|
|
||||||
|
|
||||||
@models_router.post(
|
@models_router.post(
|
||||||
"/import",
|
"/import",
|
||||||
operation_id="import_model",
|
operation_id="import_model",
|
||||||
responses= {
|
responses={
|
||||||
201: {"description" : "The model imported successfully"},
|
201: {"description": "The model imported successfully"},
|
||||||
404: {"description" : "The model could not be found"},
|
404: {"description": "The model could not be found"},
|
||||||
415: {"description" : "Unrecognized file/folder format"},
|
415: {"description": "Unrecognized file/folder format"},
|
||||||
424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"},
|
424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
|
||||||
409: {"description" : "There is already a model corresponding to this path or repo_id"},
|
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||||
},
|
},
|
||||||
status_code=201,
|
status_code=201,
|
||||||
response_model=ImportModelResponse
|
response_model=ImportModelResponse,
|
||||||
)
|
)
|
||||||
async def import_model(
|
async def import_model(
|
||||||
location: str = Body(description="A model path, repo_id or URL to import"),
|
location: str = Body(description="A model path, repo_id or URL to import"),
|
||||||
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
|
prediction_type: Optional[Literal["v_prediction", "epsilon", "sample"]] = Body(
|
||||||
Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
|
description="Prediction type for SDv2 checkpoint files", default="v_prediction"
|
||||||
|
),
|
||||||
) -> ImportModelResponse:
|
) -> ImportModelResponse:
|
||||||
""" Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically """
|
"""Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically"""
|
||||||
|
|
||||||
items_to_import = {location}
|
items_to_import = {location}
|
||||||
prediction_types = { x.value: x for x in SchedulerPredictionType }
|
prediction_types = {x.value: x for x in SchedulerPredictionType}
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
|
|
||||||
try:
|
try:
|
||||||
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
||||||
items_to_import = items_to_import,
|
items_to_import=items_to_import, prediction_type_helper=lambda x: prediction_types.get(prediction_type)
|
||||||
prediction_type_helper = lambda x: prediction_types.get(prediction_type)
|
|
||||||
)
|
)
|
||||||
info = installed_models.get(location)
|
info = installed_models.get(location)
|
||||||
|
|
||||||
@ -158,11 +161,9 @@ async def import_model(
|
|||||||
logger.error("Import failed")
|
logger.error("Import failed")
|
||||||
raise HTTPException(status_code=415)
|
raise HTTPException(status_code=415)
|
||||||
|
|
||||||
logger.info(f'Successfully imported {location}, got {info}')
|
logger.info(f"Successfully imported {location}, got {info}")
|
||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||||
model_name=info.name,
|
model_name=info.name, base_model=info.base_model, model_type=info.model_type
|
||||||
base_model=info.base_model,
|
|
||||||
model_type=info.model_type
|
|
||||||
)
|
)
|
||||||
return parse_obj_as(ImportModelResponse, model_raw)
|
return parse_obj_as(ImportModelResponse, model_raw)
|
||||||
|
|
||||||
@ -176,37 +177,33 @@ async def import_model(
|
|||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
raise HTTPException(status_code=409, detail=str(e))
|
raise HTTPException(status_code=409, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
@models_router.post(
|
@models_router.post(
|
||||||
"/add",
|
"/add",
|
||||||
operation_id="add_model",
|
operation_id="add_model",
|
||||||
responses= {
|
responses={
|
||||||
201: {"description" : "The model added successfully"},
|
201: {"description": "The model added successfully"},
|
||||||
404: {"description" : "The model could not be found"},
|
404: {"description": "The model could not be found"},
|
||||||
424: {"description" : "The model appeared to add successfully, but could not be found in the model manager"},
|
424: {"description": "The model appeared to add successfully, but could not be found in the model manager"},
|
||||||
409: {"description" : "There is already a model corresponding to this path or repo_id"},
|
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||||
},
|
},
|
||||||
status_code=201,
|
status_code=201,
|
||||||
response_model=ImportModelResponse
|
response_model=ImportModelResponse,
|
||||||
)
|
)
|
||||||
async def add_model(
|
async def add_model(
|
||||||
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
||||||
) -> ImportModelResponse:
|
) -> ImportModelResponse:
|
||||||
""" Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
|
"""Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
|
||||||
|
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ApiDependencies.invoker.services.model_manager.add_model(
|
ApiDependencies.invoker.services.model_manager.add_model(
|
||||||
info.model_name,
|
info.model_name, info.base_model, info.model_type, model_attributes=info.dict()
|
||||||
info.base_model,
|
|
||||||
info.model_type,
|
|
||||||
model_attributes = info.dict()
|
|
||||||
)
|
)
|
||||||
logger.info(f'Successfully added {info.model_name}')
|
logger.info(f"Successfully added {info.model_name}")
|
||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||||
model_name=info.model_name,
|
model_name=info.model_name, base_model=info.base_model, model_type=info.model_type
|
||||||
base_model=info.base_model,
|
|
||||||
model_type=info.model_type
|
|
||||||
)
|
)
|
||||||
return parse_obj_as(ImportModelResponse, model_raw)
|
return parse_obj_as(ImportModelResponse, model_raw)
|
||||||
except ModelNotFoundException as e:
|
except ModelNotFoundException as e:
|
||||||
@ -220,62 +217,62 @@ async def add_model(
|
|||||||
@models_router.delete(
|
@models_router.delete(
|
||||||
"/{base_model}/{model_type}/{model_name}",
|
"/{base_model}/{model_type}/{model_name}",
|
||||||
operation_id="del_model",
|
operation_id="del_model",
|
||||||
responses={
|
responses={204: {"description": "Model deleted successfully"}, 404: {"description": "Model not found"}},
|
||||||
204: { "description": "Model deleted successfully" },
|
status_code=204,
|
||||||
404: { "description": "Model not found" }
|
response_model=None,
|
||||||
},
|
|
||||||
status_code = 204,
|
|
||||||
response_model = None,
|
|
||||||
)
|
)
|
||||||
async def delete_model(
|
async def delete_model(
|
||||||
base_model: BaseModelType = Path(description="Base model"),
|
base_model: BaseModelType = Path(description="Base model"),
|
||||||
model_type: ModelType = Path(description="The type of model"),
|
model_type: ModelType = Path(description="The type of model"),
|
||||||
model_name: str = Path(description="model name"),
|
model_name: str = Path(description="model name"),
|
||||||
) -> Response:
|
) -> Response:
|
||||||
"""Delete Model"""
|
"""Delete Model"""
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ApiDependencies.invoker.services.model_manager.del_model(model_name,
|
ApiDependencies.invoker.services.model_manager.del_model(
|
||||||
base_model = base_model,
|
model_name, base_model=base_model, model_type=model_type
|
||||||
model_type = model_type
|
)
|
||||||
)
|
|
||||||
logger.info(f"Deleted model: {model_name}")
|
logger.info(f"Deleted model: {model_name}")
|
||||||
return Response(status_code=204)
|
return Response(status_code=204)
|
||||||
except ModelNotFoundException as e:
|
except ModelNotFoundException as e:
|
||||||
logger.error(str(e))
|
logger.error(str(e))
|
||||||
raise HTTPException(status_code=404, detail=str(e))
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
@models_router.put(
|
@models_router.put(
|
||||||
"/convert/{base_model}/{model_type}/{model_name}",
|
"/convert/{base_model}/{model_type}/{model_name}",
|
||||||
operation_id="convert_model",
|
operation_id="convert_model",
|
||||||
responses={
|
responses={
|
||||||
200: { "description": "Model converted successfully" },
|
200: {"description": "Model converted successfully"},
|
||||||
400: {"description" : "Bad request" },
|
400: {"description": "Bad request"},
|
||||||
404: { "description": "Model not found" },
|
404: {"description": "Model not found"},
|
||||||
},
|
},
|
||||||
status_code = 200,
|
status_code=200,
|
||||||
response_model = ConvertModelResponse,
|
response_model=ConvertModelResponse,
|
||||||
)
|
)
|
||||||
async def convert_model(
|
async def convert_model(
|
||||||
base_model: BaseModelType = Path(description="Base model"),
|
base_model: BaseModelType = Path(description="Base model"),
|
||||||
model_type: ModelType = Path(description="The type of model"),
|
model_type: ModelType = Path(description="The type of model"),
|
||||||
model_name: str = Path(description="model name"),
|
model_name: str = Path(description="model name"),
|
||||||
convert_dest_directory: Optional[str] = Query(default=None, description="Save the converted model to the designated directory"),
|
convert_dest_directory: Optional[str] = Query(
|
||||||
|
default=None, description="Save the converted model to the designated directory"
|
||||||
|
),
|
||||||
) -> ConvertModelResponse:
|
) -> ConvertModelResponse:
|
||||||
"""Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
|
"""Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
try:
|
try:
|
||||||
logger.info(f"Converting model: {model_name}")
|
logger.info(f"Converting model: {model_name}")
|
||||||
dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
|
dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
|
||||||
ApiDependencies.invoker.services.model_manager.convert_model(model_name,
|
ApiDependencies.invoker.services.model_manager.convert_model(
|
||||||
base_model = base_model,
|
model_name,
|
||||||
model_type = model_type,
|
base_model=base_model,
|
||||||
convert_dest_directory = dest,
|
model_type=model_type,
|
||||||
)
|
convert_dest_directory=dest,
|
||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name,
|
)
|
||||||
base_model = base_model,
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||||
model_type = model_type)
|
model_name, base_model=base_model, model_type=model_type
|
||||||
|
)
|
||||||
response = parse_obj_as(ConvertModelResponse, model_raw)
|
response = parse_obj_as(ConvertModelResponse, model_raw)
|
||||||
except ModelNotFoundException as e:
|
except ModelNotFoundException as e:
|
||||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
|
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
|
||||||
@ -283,34 +280,37 @@ async def convert_model(
|
|||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
@models_router.get(
|
@models_router.get(
|
||||||
"/search",
|
"/search",
|
||||||
operation_id="search_for_models",
|
operation_id="search_for_models",
|
||||||
responses={
|
responses={
|
||||||
200: { "description": "Directory searched successfully" },
|
200: {"description": "Directory searched successfully"},
|
||||||
404: { "description": "Invalid directory path" },
|
404: {"description": "Invalid directory path"},
|
||||||
},
|
},
|
||||||
status_code = 200,
|
status_code=200,
|
||||||
response_model = List[pathlib.Path]
|
response_model=List[pathlib.Path],
|
||||||
)
|
)
|
||||||
async def search_for_models(
|
async def search_for_models(
|
||||||
search_path: pathlib.Path = Query(description="Directory path to search for models")
|
search_path: pathlib.Path = Query(description="Directory path to search for models"),
|
||||||
)->List[pathlib.Path]:
|
) -> List[pathlib.Path]:
|
||||||
if not search_path.is_dir():
|
if not search_path.is_dir():
|
||||||
raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory")
|
raise HTTPException(
|
||||||
|
status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory"
|
||||||
|
)
|
||||||
return ApiDependencies.invoker.services.model_manager.search_for_models(search_path)
|
return ApiDependencies.invoker.services.model_manager.search_for_models(search_path)
|
||||||
|
|
||||||
|
|
||||||
@models_router.get(
|
@models_router.get(
|
||||||
"/ckpt_confs",
|
"/ckpt_confs",
|
||||||
operation_id="list_ckpt_configs",
|
operation_id="list_ckpt_configs",
|
||||||
responses={
|
responses={
|
||||||
200: { "description" : "paths retrieved successfully" },
|
200: {"description": "paths retrieved successfully"},
|
||||||
},
|
},
|
||||||
status_code = 200,
|
status_code=200,
|
||||||
response_model = List[pathlib.Path]
|
response_model=List[pathlib.Path],
|
||||||
)
|
)
|
||||||
async def list_ckpt_configs(
|
async def list_ckpt_configs() -> List[pathlib.Path]:
|
||||||
)->List[pathlib.Path]:
|
|
||||||
"""Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT."""
|
"""Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT."""
|
||||||
return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs()
|
return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs()
|
||||||
|
|
||||||
@ -319,55 +319,62 @@ async def list_ckpt_configs(
|
|||||||
"/sync",
|
"/sync",
|
||||||
operation_id="sync_to_config",
|
operation_id="sync_to_config",
|
||||||
responses={
|
responses={
|
||||||
201: { "description": "synchronization successful" },
|
201: {"description": "synchronization successful"},
|
||||||
},
|
},
|
||||||
status_code = 201,
|
status_code=201,
|
||||||
response_model = bool
|
response_model=bool,
|
||||||
)
|
)
|
||||||
async def sync_to_config(
|
async def sync_to_config() -> bool:
|
||||||
)->bool:
|
|
||||||
"""Call after making changes to models.yaml, autoimport directories or models directory to synchronize
|
"""Call after making changes to models.yaml, autoimport directories or models directory to synchronize
|
||||||
in-memory data structures with disk data structures."""
|
in-memory data structures with disk data structures."""
|
||||||
ApiDependencies.invoker.services.model_manager.sync_to_config()
|
ApiDependencies.invoker.services.model_manager.sync_to_config()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
@models_router.put(
|
@models_router.put(
|
||||||
"/merge/{base_model}",
|
"/merge/{base_model}",
|
||||||
operation_id="merge_models",
|
operation_id="merge_models",
|
||||||
responses={
|
responses={
|
||||||
200: { "description": "Model converted successfully" },
|
200: {"description": "Model converted successfully"},
|
||||||
400: { "description": "Incompatible models" },
|
400: {"description": "Incompatible models"},
|
||||||
404: { "description": "One or more models not found" },
|
404: {"description": "One or more models not found"},
|
||||||
},
|
},
|
||||||
status_code = 200,
|
status_code=200,
|
||||||
response_model = MergeModelResponse,
|
response_model=MergeModelResponse,
|
||||||
)
|
)
|
||||||
async def merge_models(
|
async def merge_models(
|
||||||
base_model: BaseModelType = Path(description="Base model"),
|
base_model: BaseModelType = Path(description="Base model"),
|
||||||
model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
|
model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
|
||||||
merged_model_name: Optional[str] = Body(description="Name of destination model"),
|
merged_model_name: Optional[str] = Body(description="Name of destination model"),
|
||||||
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
|
||||||
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
|
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
|
||||||
force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False),
|
force: Optional[bool] = Body(
|
||||||
merge_dest_directory: Optional[str] = Body(description="Save the merged model to the designated directory (with 'merged_model_name' appended)", default=None)
|
description="Force merging of models created with different versions of diffusers", default=False
|
||||||
|
),
|
||||||
|
merge_dest_directory: Optional[str] = Body(
|
||||||
|
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
|
||||||
|
default=None,
|
||||||
|
),
|
||||||
) -> MergeModelResponse:
|
) -> MergeModelResponse:
|
||||||
"""Convert a checkpoint model into a diffusers model"""
|
"""Convert a checkpoint model into a diffusers model"""
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
try:
|
try:
|
||||||
logger.info(f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
logger.info(f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
|
||||||
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
|
||||||
result = ApiDependencies.invoker.services.model_manager.merge_models(model_names,
|
result = ApiDependencies.invoker.services.model_manager.merge_models(
|
||||||
base_model,
|
model_names,
|
||||||
merged_model_name=merged_model_name or "+".join(model_names),
|
base_model,
|
||||||
alpha=alpha,
|
merged_model_name=merged_model_name or "+".join(model_names),
|
||||||
interp=interp,
|
alpha=alpha,
|
||||||
force=force,
|
interp=interp,
|
||||||
merge_dest_directory = dest
|
force=force,
|
||||||
)
|
merge_dest_directory=dest,
|
||||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name,
|
)
|
||||||
base_model = base_model,
|
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||||
model_type = ModelType.Main,
|
result.name,
|
||||||
)
|
base_model=base_model,
|
||||||
|
model_type=ModelType.Main,
|
||||||
|
)
|
||||||
response = parse_obj_as(ConvertModelResponse, model_raw)
|
response = parse_obj_as(ConvertModelResponse, model_raw)
|
||||||
except ModelNotFoundException:
|
except ModelNotFoundException:
|
||||||
raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
|
raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
|
||||||
|
@ -17,15 +17,16 @@ from invokeai.backend.model_management.models import ModelType, SilenceWarnings
|
|||||||
from ...backend.model_management.lora import ModelPatcher
|
from ...backend.model_management.lora import ModelPatcher
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
from ...backend.stable_diffusion.diffusers_pipeline import (
|
from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||||
ConditioningData, ControlNetData, StableDiffusionGeneratorPipeline,
|
ConditioningData,
|
||||||
image_resized_to_grid_as_tensor)
|
ControlNetData,
|
||||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
|
StableDiffusionGeneratorPipeline,
|
||||||
PostprocessingSettings
|
image_resized_to_grid_as_tensor,
|
||||||
|
)
|
||||||
|
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||||
from ...backend.util.devices import choose_torch_device, torch_dtype, choose_precision
|
from ...backend.util.devices import choose_torch_device, torch_dtype, choose_precision
|
||||||
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
||||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
||||||
InvocationConfig, InvocationContext)
|
|
||||||
from .compel import ConditioningField
|
from .compel import ConditioningField
|
||||||
from .controlnet_image_processors import ControlField
|
from .controlnet_image_processors import ControlField
|
||||||
from .image import ImageOutput
|
from .image import ImageOutput
|
||||||
@ -46,8 +47,7 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device())
|
|||||||
class LatentsField(BaseModel):
|
class LatentsField(BaseModel):
|
||||||
"""A latents field used for passing latents between invocations"""
|
"""A latents field used for passing latents between invocations"""
|
||||||
|
|
||||||
latents_name: Optional[str] = Field(
|
latents_name: Optional[str] = Field(default=None, description="The name of the latents")
|
||||||
default=None, description="The name of the latents")
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
schema_extra = {"required": ["latents_name"]}
|
schema_extra = {"required": ["latents_name"]}
|
||||||
@ -55,14 +55,15 @@ class LatentsField(BaseModel):
|
|||||||
|
|
||||||
class LatentsOutput(BaseInvocationOutput):
|
class LatentsOutput(BaseInvocationOutput):
|
||||||
"""Base class for invocations that output latents"""
|
"""Base class for invocations that output latents"""
|
||||||
#fmt: off
|
|
||||||
|
# fmt: off
|
||||||
type: Literal["latents_output"] = "latents_output"
|
type: Literal["latents_output"] = "latents_output"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
latents: LatentsField = Field(default=None, description="The output latents")
|
latents: LatentsField = Field(default=None, description="The output latents")
|
||||||
width: int = Field(description="The width of the latents in pixels")
|
width: int = Field(description="The width of the latents in pixels")
|
||||||
height: int = Field(description="The height of the latents in pixels")
|
height: int = Field(description="The height of the latents in pixels")
|
||||||
#fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
def build_latents_output(latents_name: str, latents: torch.Tensor):
|
def build_latents_output(latents_name: str, latents: torch.Tensor):
|
||||||
@ -73,9 +74,7 @@ def build_latents_output(latents_name: str, latents: torch.Tensor):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
SAMPLER_NAME_VALUES = Literal[
|
SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
|
||||||
tuple(list(SCHEDULER_MAP.keys()))
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def get_scheduler(
|
def get_scheduler(
|
||||||
@ -83,11 +82,10 @@ def get_scheduler(
|
|||||||
scheduler_info: ModelInfo,
|
scheduler_info: ModelInfo,
|
||||||
scheduler_name: str,
|
scheduler_name: str,
|
||||||
) -> Scheduler:
|
) -> Scheduler:
|
||||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(
|
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
||||||
scheduler_name, SCHEDULER_MAP['ddim']
|
|
||||||
)
|
|
||||||
orig_scheduler_info = context.services.model_manager.get_model(
|
orig_scheduler_info = context.services.model_manager.get_model(
|
||||||
**scheduler_info.dict(), context=context,
|
**scheduler_info.dict(),
|
||||||
|
context=context,
|
||||||
)
|
)
|
||||||
with orig_scheduler_info as orig_scheduler:
|
with orig_scheduler_info as orig_scheduler:
|
||||||
scheduler_config = orig_scheduler.config
|
scheduler_config = orig_scheduler.config
|
||||||
@ -102,7 +100,7 @@ def get_scheduler(
|
|||||||
scheduler = scheduler_class.from_config(scheduler_config)
|
scheduler = scheduler_class.from_config(scheduler_config)
|
||||||
|
|
||||||
# hack copied over from generate.py
|
# hack copied over from generate.py
|
||||||
if not hasattr(scheduler, 'uses_inpainting_model'):
|
if not hasattr(scheduler, "uses_inpainting_model"):
|
||||||
scheduler.uses_inpainting_model = lambda: False
|
scheduler.uses_inpainting_model = lambda: False
|
||||||
return scheduler
|
return scheduler
|
||||||
|
|
||||||
@ -123,8 +121,8 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
|
||||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||||
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
|
||||||
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
|
||||||
#seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
@validator("cfg_scale")
|
@validator("cfg_scale")
|
||||||
@ -133,10 +131,10 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
for i in v:
|
for i in v:
|
||||||
if i < 1:
|
if i < 1:
|
||||||
raise ValueError('cfg_scale must be greater than 1')
|
raise ValueError("cfg_scale must be greater than 1")
|
||||||
else:
|
else:
|
||||||
if v < 1:
|
if v < 1:
|
||||||
raise ValueError('cfg_scale must be greater than 1')
|
raise ValueError("cfg_scale must be greater than 1")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
# Schema customisation
|
# Schema customisation
|
||||||
@ -149,8 +147,8 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
"model": "model",
|
"model": "model",
|
||||||
"control": "control",
|
"control": "control",
|
||||||
# "cfg_scale": "float",
|
# "cfg_scale": "float",
|
||||||
"cfg_scale": "number"
|
"cfg_scale": "number",
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -190,16 +188,14 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
threshold=0.0, # threshold,
|
threshold=0.0, # threshold,
|
||||||
warmup=0.2, # warmup,
|
warmup=0.2, # warmup,
|
||||||
h_symmetry_time_pct=None, # h_symmetry_time_pct,
|
h_symmetry_time_pct=None, # h_symmetry_time_pct,
|
||||||
v_symmetry_time_pct=None # v_symmetry_time_pct,
|
v_symmetry_time_pct=None, # v_symmetry_time_pct,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
conditioning_data = conditioning_data.add_scheduler_args_if_applicable(
|
conditioning_data = conditioning_data.add_scheduler_args_if_applicable(
|
||||||
scheduler,
|
scheduler,
|
||||||
|
|
||||||
# for ddim scheduler
|
# for ddim scheduler
|
||||||
eta=0.0, # ddim_eta
|
eta=0.0, # ddim_eta
|
||||||
|
|
||||||
# for ancestral and sde schedulers
|
# for ancestral and sde schedulers
|
||||||
generator=torch.Generator(device=unet.device).manual_seed(0),
|
generator=torch.Generator(device=unet.device).manual_seed(0),
|
||||||
)
|
)
|
||||||
@ -247,7 +243,6 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
exit_stack: ExitStack,
|
exit_stack: ExitStack,
|
||||||
do_classifier_free_guidance: bool = True,
|
do_classifier_free_guidance: bool = True,
|
||||||
) -> List[ControlNetData]:
|
) -> List[ControlNetData]:
|
||||||
|
|
||||||
# assuming fixed dimensional scaling of 8:1 for image:latents
|
# assuming fixed dimensional scaling of 8:1 for image:latents
|
||||||
control_height_resize = latents_shape[2] * 8
|
control_height_resize = latents_shape[2] * 8
|
||||||
control_width_resize = latents_shape[3] * 8
|
control_width_resize = latents_shape[3] * 8
|
||||||
@ -261,7 +256,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
control_list = control_input
|
control_list = control_input
|
||||||
else:
|
else:
|
||||||
control_list = None
|
control_list = None
|
||||||
if (control_list is None):
|
if control_list is None:
|
||||||
control_data = None
|
control_data = None
|
||||||
# from above handling, any control that is not None should now be of type list[ControlField]
|
# from above handling, any control that is not None should now be of type list[ControlField]
|
||||||
else:
|
else:
|
||||||
@ -281,9 +276,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
control_models.append(control_model)
|
control_models.append(control_model)
|
||||||
control_image_field = control_info.image
|
control_image_field = control_info.image
|
||||||
input_image = context.services.images.get_pil_image(
|
input_image = context.services.images.get_pil_image(control_image_field.image_name)
|
||||||
control_image_field.image_name
|
|
||||||
)
|
|
||||||
# self.image.image_type, self.image.image_name
|
# self.image.image_type, self.image.image_name
|
||||||
# FIXME: still need to test with different widths, heights, devices, dtypes
|
# FIXME: still need to test with different widths, heights, devices, dtypes
|
||||||
# and add in batch_size, num_images_per_prompt?
|
# and add in batch_size, num_images_per_prompt?
|
||||||
@ -322,9 +315,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
noise = context.services.latents.get(self.noise.latents_name)
|
noise = context.services.latents.get(self.noise.latents_name)
|
||||||
|
|
||||||
# Get the source node id (we are invoking the prepared node)
|
# Get the source node id (we are invoking the prepared node)
|
||||||
graph_execution_state = context.services.graph_execution_manager.get(
|
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||||
context.graph_execution_state_id
|
|
||||||
)
|
|
||||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||||
|
|
||||||
def step_callback(state: PipelineIntermediateState):
|
def step_callback(state: PipelineIntermediateState):
|
||||||
@ -333,19 +324,20 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
def _lora_loader():
|
def _lora_loader():
|
||||||
for lora in self.unet.loras:
|
for lora in self.unet.loras:
|
||||||
lora_info = context.services.model_manager.get_model(
|
lora_info = context.services.model_manager.get_model(
|
||||||
**lora.dict(exclude={"weight"}), context=context,
|
**lora.dict(exclude={"weight"}),
|
||||||
|
context=context,
|
||||||
)
|
)
|
||||||
yield (lora_info.context.model, lora.weight)
|
yield (lora_info.context.model, lora.weight)
|
||||||
del lora_info
|
del lora_info
|
||||||
return
|
return
|
||||||
|
|
||||||
unet_info = context.services.model_manager.get_model(
|
unet_info = context.services.model_manager.get_model(
|
||||||
**self.unet.unet.dict(), context=context,
|
**self.unet.unet.dict(),
|
||||||
|
context=context,
|
||||||
)
|
)
|
||||||
with ExitStack() as exit_stack,\
|
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
|
||||||
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
unet_info.context.model, _lora_loader()
|
||||||
unet_info as unet:
|
), unet_info as unet:
|
||||||
|
|
||||||
noise = noise.to(device=unet.device, dtype=unet.dtype)
|
noise = noise.to(device=unet.device, dtype=unet.dtype)
|
||||||
|
|
||||||
scheduler = get_scheduler(
|
scheduler = get_scheduler(
|
||||||
@ -358,7 +350,9 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
conditioning_data = self.get_conditioning_data(context, scheduler, unet)
|
conditioning_data = self.get_conditioning_data(context, scheduler, unet)
|
||||||
|
|
||||||
control_data = self.prep_control_data(
|
control_data = self.prep_control_data(
|
||||||
model=pipeline, context=context, control_input=self.control,
|
model=pipeline,
|
||||||
|
context=context,
|
||||||
|
control_input=self.control,
|
||||||
latents_shape=noise.shape,
|
latents_shape=noise.shape,
|
||||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||||
do_classifier_free_guidance=True,
|
do_classifier_free_guidance=True,
|
||||||
@ -379,7 +373,7 @@ class TextToLatentsInvocation(BaseInvocation):
|
|||||||
result_latents = result_latents.to("cpu")
|
result_latents = result_latents.to("cpu")
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||||
context.services.latents.save(name, result_latents)
|
context.services.latents.save(name, result_latents)
|
||||||
return build_latents_output(latents_name=name, latents=result_latents)
|
return build_latents_output(latents_name=name, latents=result_latents)
|
||||||
|
|
||||||
@ -390,11 +384,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
type: Literal["l2l"] = "l2l"
|
type: Literal["l2l"] = "l2l"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
latents: Optional[LatentsField] = Field(
|
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
|
||||||
description="The latents to use as a base image")
|
strength: float = Field(default=0.7, ge=0, le=1, description="The strength of the latents to use")
|
||||||
strength: float = Field(
|
|
||||||
default=0.7, ge=0, le=1,
|
|
||||||
description="The strength of the latents to use")
|
|
||||||
|
|
||||||
# Schema customisation
|
# Schema customisation
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
@ -406,7 +397,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
"model": "model",
|
"model": "model",
|
||||||
"control": "control",
|
"control": "control",
|
||||||
"cfg_scale": "number",
|
"cfg_scale": "number",
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -417,9 +408,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
latent = context.services.latents.get(self.latents.latents_name)
|
latent = context.services.latents.get(self.latents.latents_name)
|
||||||
|
|
||||||
# Get the source node id (we are invoking the prepared node)
|
# Get the source node id (we are invoking the prepared node)
|
||||||
graph_execution_state = context.services.graph_execution_manager.get(
|
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||||
context.graph_execution_state_id
|
|
||||||
)
|
|
||||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||||
|
|
||||||
def step_callback(state: PipelineIntermediateState):
|
def step_callback(state: PipelineIntermediateState):
|
||||||
@ -428,19 +417,20 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
def _lora_loader():
|
def _lora_loader():
|
||||||
for lora in self.unet.loras:
|
for lora in self.unet.loras:
|
||||||
lora_info = context.services.model_manager.get_model(
|
lora_info = context.services.model_manager.get_model(
|
||||||
**lora.dict(exclude={"weight"}), context=context,
|
**lora.dict(exclude={"weight"}),
|
||||||
|
context=context,
|
||||||
)
|
)
|
||||||
yield (lora_info.context.model, lora.weight)
|
yield (lora_info.context.model, lora.weight)
|
||||||
del lora_info
|
del lora_info
|
||||||
return
|
return
|
||||||
|
|
||||||
unet_info = context.services.model_manager.get_model(
|
unet_info = context.services.model_manager.get_model(
|
||||||
**self.unet.unet.dict(), context=context,
|
**self.unet.unet.dict(),
|
||||||
|
context=context,
|
||||||
)
|
)
|
||||||
with ExitStack() as exit_stack,\
|
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
|
||||||
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
unet_info.context.model, _lora_loader()
|
||||||
unet_info as unet:
|
), unet_info as unet:
|
||||||
|
|
||||||
noise = noise.to(device=unet.device, dtype=unet.dtype)
|
noise = noise.to(device=unet.device, dtype=unet.dtype)
|
||||||
latent = latent.to(device=unet.device, dtype=unet.dtype)
|
latent = latent.to(device=unet.device, dtype=unet.dtype)
|
||||||
|
|
||||||
@ -454,7 +444,9 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
conditioning_data = self.get_conditioning_data(context, scheduler, unet)
|
conditioning_data = self.get_conditioning_data(context, scheduler, unet)
|
||||||
|
|
||||||
control_data = self.prep_control_data(
|
control_data = self.prep_control_data(
|
||||||
model=pipeline, context=context, control_input=self.control,
|
model=pipeline,
|
||||||
|
context=context,
|
||||||
|
control_input=self.control,
|
||||||
latents_shape=noise.shape,
|
latents_shape=noise.shape,
|
||||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||||
do_classifier_free_guidance=True,
|
do_classifier_free_guidance=True,
|
||||||
@ -462,8 +454,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# TODO: Verify the noise is the right size
|
# TODO: Verify the noise is the right size
|
||||||
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
|
initial_latents = (
|
||||||
latent, device=unet.device, dtype=latent.dtype
|
latent if self.strength < 1.0 else torch.zeros_like(latent, device=unet.device, dtype=latent.dtype)
|
||||||
)
|
)
|
||||||
|
|
||||||
timesteps, _ = pipeline.get_img2img_timesteps(
|
timesteps, _ = pipeline.get_img2img_timesteps(
|
||||||
@ -479,14 +471,14 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
|
|||||||
num_inference_steps=self.steps,
|
num_inference_steps=self.steps,
|
||||||
conditioning_data=conditioning_data,
|
conditioning_data=conditioning_data,
|
||||||
control_data=control_data, # list[ControlNetData]
|
control_data=control_data, # list[ControlNetData]
|
||||||
callback=step_callback
|
callback=step_callback,
|
||||||
)
|
)
|
||||||
|
|
||||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
result_latents = result_latents.to("cpu")
|
result_latents = result_latents.to("cpu")
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
name = f'{context.graph_execution_state_id}__{self.id}'
|
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||||
context.services.latents.save(name, result_latents)
|
context.services.latents.save(name, result_latents)
|
||||||
return build_latents_output(latents_name=name, latents=result_latents)
|
return build_latents_output(latents_name=name, latents=result_latents)
|
||||||
|
|
||||||
@ -498,14 +490,13 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
type: Literal["l2i"] = "l2i"
|
type: Literal["l2i"] = "l2i"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
latents: Optional[LatentsField] = Field(
|
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
|
||||||
description="The latents to generate an image from")
|
|
||||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||||
tiled: bool = Field(
|
tiled: bool = Field(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
|
||||||
default=False,
|
fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision")
|
||||||
description="Decode latents by overlaping tiles(less memory consumption)")
|
metadata: Optional[CoreMetadata] = Field(
|
||||||
fp32: bool = Field(DEFAULT_PRECISION=='float32', description="Decode in full precision")
|
default=None, description="Optional core metadata to be written to the image"
|
||||||
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
|
)
|
||||||
|
|
||||||
# Schema customisation
|
# Schema customisation
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
@ -521,7 +512,8 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
latents = context.services.latents.get(self.latents.latents_name)
|
latents = context.services.latents.get(self.latents.latents_name)
|
||||||
|
|
||||||
vae_info = context.services.model_manager.get_model(
|
vae_info = context.services.model_manager.get_model(
|
||||||
**self.vae.vae.dict(), context=context,
|
**self.vae.vae.dict(),
|
||||||
|
context=context,
|
||||||
)
|
)
|
||||||
|
|
||||||
with vae_info as vae:
|
with vae_info as vae:
|
||||||
@ -588,8 +580,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear",
|
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
|
||||||
"bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
|
|
||||||
|
|
||||||
|
|
||||||
class ResizeLatentsInvocation(BaseInvocation):
|
class ResizeLatentsInvocation(BaseInvocation):
|
||||||
@ -598,36 +589,30 @@ class ResizeLatentsInvocation(BaseInvocation):
|
|||||||
type: Literal["lresize"] = "lresize"
|
type: Literal["lresize"] = "lresize"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
latents: Optional[LatentsField] = Field(
|
latents: Optional[LatentsField] = Field(description="The latents to resize")
|
||||||
description="The latents to resize")
|
width: Union[int, None] = Field(default=512, ge=64, multiple_of=8, description="The width to resize to (px)")
|
||||||
width: Union[int, None] = Field(default=512,
|
height: Union[int, None] = Field(default=512, ge=64, multiple_of=8, description="The height to resize to (px)")
|
||||||
ge=64, multiple_of=8, description="The width to resize to (px)")
|
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
|
||||||
height: Union[int, None] = Field(default=512,
|
|
||||||
ge=64, multiple_of=8, description="The height to resize to (px)")
|
|
||||||
mode: LATENTS_INTERPOLATION_MODE = Field(
|
|
||||||
default="bilinear", description="The interpolation mode")
|
|
||||||
antialias: bool = Field(
|
antialias: bool = Field(
|
||||||
default=False,
|
default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)"
|
||||||
description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
)
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {
|
"ui": {"title": "Resize Latents", "tags": ["latents", "resize"]},
|
||||||
"title": "Resize Latents",
|
|
||||||
"tags": ["latents", "resize"]
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
latents = context.services.latents.get(self.latents.latents_name)
|
latents = context.services.latents.get(self.latents.latents_name)
|
||||||
|
|
||||||
# TODO:
|
# TODO:
|
||||||
device=choose_torch_device()
|
device = choose_torch_device()
|
||||||
|
|
||||||
resized_latents = torch.nn.functional.interpolate(
|
resized_latents = torch.nn.functional.interpolate(
|
||||||
latents.to(device), size=(self.height // 8, self.width // 8),
|
latents.to(device),
|
||||||
mode=self.mode, antialias=self.antialias
|
size=(self.height // 8, self.width // 8),
|
||||||
if self.mode in ["bilinear", "bicubic"] else False,
|
mode=self.mode,
|
||||||
|
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
@ -646,35 +631,30 @@ class ScaleLatentsInvocation(BaseInvocation):
|
|||||||
type: Literal["lscale"] = "lscale"
|
type: Literal["lscale"] = "lscale"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
latents: Optional[LatentsField] = Field(
|
latents: Optional[LatentsField] = Field(description="The latents to scale")
|
||||||
description="The latents to scale")
|
scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
|
||||||
scale_factor: float = Field(
|
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
|
||||||
gt=0, description="The factor by which to scale the latents")
|
|
||||||
mode: LATENTS_INTERPOLATION_MODE = Field(
|
|
||||||
default="bilinear", description="The interpolation mode")
|
|
||||||
antialias: bool = Field(
|
antialias: bool = Field(
|
||||||
default=False,
|
default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)"
|
||||||
description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
|
)
|
||||||
|
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {
|
"ui": {"title": "Scale Latents", "tags": ["latents", "scale"]},
|
||||||
"title": "Scale Latents",
|
|
||||||
"tags": ["latents", "scale"]
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||||
latents = context.services.latents.get(self.latents.latents_name)
|
latents = context.services.latents.get(self.latents.latents_name)
|
||||||
|
|
||||||
# TODO:
|
# TODO:
|
||||||
device=choose_torch_device()
|
device = choose_torch_device()
|
||||||
|
|
||||||
# resizing
|
# resizing
|
||||||
resized_latents = torch.nn.functional.interpolate(
|
resized_latents = torch.nn.functional.interpolate(
|
||||||
latents.to(device), scale_factor=self.scale_factor, mode=self.mode,
|
latents.to(device),
|
||||||
antialias=self.antialias
|
scale_factor=self.scale_factor,
|
||||||
if self.mode in ["bilinear", "bicubic"] else False,
|
mode=self.mode,
|
||||||
|
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
@ -695,19 +675,13 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
# Inputs
|
# Inputs
|
||||||
image: Optional[ImageField] = Field(description="The image to encode")
|
image: Optional[ImageField] = Field(description="The image to encode")
|
||||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||||
tiled: bool = Field(
|
tiled: bool = Field(default=False, description="Encode latents by overlaping tiles(less memory consumption)")
|
||||||
default=False,
|
fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision")
|
||||||
description="Encode latents by overlaping tiles(less memory consumption)")
|
|
||||||
fp32: bool = Field(DEFAULT_PRECISION=='float32', description="Decode in full precision")
|
|
||||||
|
|
||||||
|
|
||||||
# Schema customisation
|
# Schema customisation
|
||||||
class Config(InvocationConfig):
|
class Config(InvocationConfig):
|
||||||
schema_extra = {
|
schema_extra = {
|
||||||
"ui": {
|
"ui": {"title": "Image To Latents", "tags": ["latents", "image"]},
|
||||||
"title": "Image To Latents",
|
|
||||||
"tags": ["latents", "image"]
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@ -717,9 +691,10 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
# )
|
# )
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
|
||||||
#vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
|
# vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
|
||||||
vae_info = context.services.model_manager.get_model(
|
vae_info = context.services.model_manager.get_model(
|
||||||
**self.vae.vae.dict(), context=context,
|
**self.vae.vae.dict(),
|
||||||
|
context=context,
|
||||||
)
|
)
|
||||||
|
|
||||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||||
@ -746,12 +721,12 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
vae.post_quant_conv.to(orig_dtype)
|
vae.post_quant_conv.to(orig_dtype)
|
||||||
vae.decoder.conv_in.to(orig_dtype)
|
vae.decoder.conv_in.to(orig_dtype)
|
||||||
vae.decoder.mid_block.to(orig_dtype)
|
vae.decoder.mid_block.to(orig_dtype)
|
||||||
#else:
|
# else:
|
||||||
# latents = latents.float()
|
# latents = latents.float()
|
||||||
|
|
||||||
else:
|
else:
|
||||||
vae.to(dtype=torch.float16)
|
vae.to(dtype=torch.float16)
|
||||||
#latents = latents.half()
|
# latents = latents.half()
|
||||||
|
|
||||||
if self.tiled:
|
if self.tiled:
|
||||||
vae.enable_tiling()
|
vae.enable_tiling()
|
||||||
@ -762,9 +737,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
|||||||
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
|
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
image_tensor_dist = vae.encode(image_tensor).latent_dist
|
image_tensor_dist = vae.encode(image_tensor).latent_dist
|
||||||
latents = image_tensor_dist.sample().to(
|
latents = image_tensor_dist.sample().to(dtype=vae.dtype) # FIXME: uses torch.randn. make reproducible!
|
||||||
dtype=vae.dtype
|
|
||||||
) # FIXME: uses torch.randn. make reproducible!
|
|
||||||
|
|
||||||
latents = vae.config.scaling_factor * latents
|
latents = vae.config.scaling_factor * latents
|
||||||
latents = latents.to(dtype=orig_dtype)
|
latents = latents.to(dtype=orig_dtype)
|
||||||
|
@ -7,13 +7,10 @@ from invokeai.backend.model_management.model_probe import ModelProbe
|
|||||||
|
|
||||||
parser = argparse.ArgumentParser(description="Probe model type")
|
parser = argparse.ArgumentParser(description="Probe model type")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'model_path',
|
"model_path",
|
||||||
type=Path,
|
type=Path,
|
||||||
)
|
)
|
||||||
args=parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
info = ModelProbe().probe(args.model_path)
|
info = ModelProbe().probe(args.model_path)
|
||||||
print(info)
|
print(info)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user