reformat with black

Lincoln Stein 2023-07-27 15:01:00 -04:00
parent 00988e4972
commit fd75a1dd10
3 changed files with 249 additions and 272 deletions

@@ -28,49 +28,52 @@ ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]


class ModelsList(BaseModel):
    models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]


@models_router.get(
    "/",
    operation_id="list_models",
    responses={200: {"model": ModelsList}},
)
async def list_models(
    base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
    model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
) -> ModelsList:
    """Gets a list of models"""
    if base_models and len(base_models) > 0:
        models_raw = list()
        for base_model in base_models:
            models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
    else:
        models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type)
    models = parse_obj_as(ModelsList, {"models": models_raw})
    return models


@models_router.patch(
    "/{base_model}/{model_type}/{model_name}",
    operation_id="update_model",
    responses={
        200: {"description": "The model was updated successfully"},
        400: {"description": "Bad request"},
        404: {"description": "The model could not be found"},
        409: {"description": "There is already a model corresponding to the new name"},
    },
    status_code=200,
    response_model=UpdateModelResponse,
)
async def update_model(
    base_model: BaseModelType = Path(description="Base model"),
    model_type: ModelType = Path(description="The type of model"),
    model_name: str = Path(description="model name"),
    info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> UpdateModelResponse:
    """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
    logger = ApiDependencies.invoker.services.logger
    try:
        previous_info = ApiDependencies.invoker.services.model_manager.list_model(
            model_name=model_name,

@@ -81,13 +84,13 @@ async def update_model(
        # rename operation requested
        if info.model_name != model_name or info.base_model != base_model:
            ApiDependencies.invoker.services.model_manager.rename_model(
                base_model=base_model,
                model_type=model_type,
                model_name=model_name,
                new_name=info.model_name,
                new_base=info.base_model,
            )
            logger.info(f"Successfully renamed {base_model.value}/{model_name}=>{info.base_model}/{info.model_name}")
            # update information to support an update of attributes
            model_name = info.model_name
            base_model = info.base_model

@@ -96,14 +99,13 @@ async def update_model(
            base_model=base_model,
            model_type=model_type,
        )
        if new_info.get("path") != previous_info.get(
            "path"
        ):  # model manager moved model path during rename - don't overwrite it
            info.path = new_info.get("path")

        ApiDependencies.invoker.services.model_manager.update_model(
            model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info.dict()
        )

        model_raw = ApiDependencies.invoker.services.model_manager.list_model(

@@ -123,34 +125,35 @@ async def update_model(
    return model_response


@models_router.post(
    "/import",
    operation_id="import_model",
    responses={
        201: {"description": "The model imported successfully"},
        404: {"description": "The model could not be found"},
        415: {"description": "Unrecognized file/folder format"},
        424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
        409: {"description": "There is already a model corresponding to this path or repo_id"},
    },
    status_code=201,
    response_model=ImportModelResponse,
)
async def import_model(
    location: str = Body(description="A model path, repo_id or URL to import"),
    prediction_type: Optional[Literal["v_prediction", "epsilon", "sample"]] = Body(
        description="Prediction type for SDv2 checkpoint files", default="v_prediction"
    ),
) -> ImportModelResponse:
    """Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically"""
    items_to_import = {location}
    prediction_types = {x.value: x for x in SchedulerPredictionType}
    logger = ApiDependencies.invoker.services.logger

    try:
        installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
            items_to_import=items_to_import, prediction_type_helper=lambda x: prediction_types.get(prediction_type)
        )
        info = installed_models.get(location)

@@ -158,11 +161,9 @@ async def import_model(
            logger.error("Import failed")
            raise HTTPException(status_code=415)

        logger.info(f"Successfully imported {location}, got {info}")
        model_raw = ApiDependencies.invoker.services.model_manager.list_model(
            model_name=info.name, base_model=info.base_model, model_type=info.model_type
        )
        return parse_obj_as(ImportModelResponse, model_raw)

@@ -176,37 +177,33 @@ async def import_model(
        logger.error(str(e))
        raise HTTPException(status_code=409, detail=str(e))


@models_router.post(
    "/add",
    operation_id="add_model",
    responses={
        201: {"description": "The model added successfully"},
        404: {"description": "The model could not be found"},
        424: {"description": "The model appeared to add successfully, but could not be found in the model manager"},
        409: {"description": "There is already a model corresponding to this path or repo_id"},
    },
    status_code=201,
    response_model=ImportModelResponse,
)
async def add_model(
    info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> ImportModelResponse:
    """Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
    logger = ApiDependencies.invoker.services.logger

    try:
        ApiDependencies.invoker.services.model_manager.add_model(
            info.model_name, info.base_model, info.model_type, model_attributes=info.dict()
        )
        logger.info(f"Successfully added {info.model_name}")
        model_raw = ApiDependencies.invoker.services.model_manager.list_model(
            model_name=info.model_name, base_model=info.base_model, model_type=info.model_type
        )
        return parse_obj_as(ImportModelResponse, model_raw)
    except ModelNotFoundException as e:

@@ -220,62 +217,62 @@ async def add_model(
@models_router.delete(
    "/{base_model}/{model_type}/{model_name}",
    operation_id="del_model",
    responses={204: {"description": "Model deleted successfully"}, 404: {"description": "Model not found"}},
    status_code=204,
    response_model=None,
)
async def delete_model(
    base_model: BaseModelType = Path(description="Base model"),
    model_type: ModelType = Path(description="The type of model"),
    model_name: str = Path(description="model name"),
) -> Response:
    """Delete Model"""
    logger = ApiDependencies.invoker.services.logger

    try:
        ApiDependencies.invoker.services.model_manager.del_model(
            model_name, base_model=base_model, model_type=model_type
        )
        logger.info(f"Deleted model: {model_name}")
        return Response(status_code=204)
    except ModelNotFoundException as e:
        logger.error(str(e))
        raise HTTPException(status_code=404, detail=str(e))


@models_router.put(
    "/convert/{base_model}/{model_type}/{model_name}",
    operation_id="convert_model",
    responses={
        200: {"description": "Model converted successfully"},
        400: {"description": "Bad request"},
        404: {"description": "Model not found"},
    },
    status_code=200,
    response_model=ConvertModelResponse,
)
async def convert_model(
    base_model: BaseModelType = Path(description="Base model"),
    model_type: ModelType = Path(description="The type of model"),
    model_name: str = Path(description="model name"),
    convert_dest_directory: Optional[str] = Query(
        default=None, description="Save the converted model to the designated directory"
    ),
) -> ConvertModelResponse:
    """Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
    logger = ApiDependencies.invoker.services.logger
    try:
        logger.info(f"Converting model: {model_name}")
        dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
        ApiDependencies.invoker.services.model_manager.convert_model(
            model_name,
            base_model=base_model,
            model_type=model_type,
            convert_dest_directory=dest,
        )
        model_raw = ApiDependencies.invoker.services.model_manager.list_model(
            model_name, base_model=base_model, model_type=model_type
        )
        response = parse_obj_as(ConvertModelResponse, model_raw)
    except ModelNotFoundException as e:
        raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")

@@ -283,34 +280,37 @@ async def convert_model(
        raise HTTPException(status_code=400, detail=str(e))
    return response


@models_router.get(
    "/search",
    operation_id="search_for_models",
    responses={
        200: {"description": "Directory searched successfully"},
        404: {"description": "Invalid directory path"},
    },
    status_code=200,
    response_model=List[pathlib.Path],
)
async def search_for_models(
    search_path: pathlib.Path = Query(description="Directory path to search for models"),
) -> List[pathlib.Path]:
    if not search_path.is_dir():
        raise HTTPException(
            status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory"
        )
    return ApiDependencies.invoker.services.model_manager.search_for_models(search_path)


@models_router.get(
    "/ckpt_confs",
    operation_id="list_ckpt_configs",
    responses={
        200: {"description": "paths retrieved successfully"},
    },
    status_code=200,
    response_model=List[pathlib.Path],
)
async def list_ckpt_configs() -> List[pathlib.Path]:
    """Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT."""
    return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs()

@@ -319,55 +319,62 @@ async def list_ckpt_configs(
    "/sync",
    operation_id="sync_to_config",
    responses={
        201: {"description": "synchronization successful"},
    },
    status_code=201,
    response_model=bool,
)
async def sync_to_config() -> bool:
    """Call after making changes to models.yaml, autoimport directories or models directory to synchronize
    in-memory data structures with disk data structures."""
    ApiDependencies.invoker.services.model_manager.sync_to_config()
    return True


@models_router.put(
    "/merge/{base_model}",
    operation_id="merge_models",
    responses={
        200: {"description": "Model converted successfully"},
        400: {"description": "Incompatible models"},
        404: {"description": "One or more models not found"},
    },
    status_code=200,
    response_model=MergeModelResponse,
)
async def merge_models(
    base_model: BaseModelType = Path(description="Base model"),
    model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
    merged_model_name: Optional[str] = Body(description="Name of destination model"),
    alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
    interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
    force: Optional[bool] = Body(
        description="Force merging of models created with different versions of diffusers", default=False
    ),
    merge_dest_directory: Optional[str] = Body(
        description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
        default=None,
    ),
) -> MergeModelResponse:
    """Convert a checkpoint model into a diffusers model"""
    logger = ApiDependencies.invoker.services.logger
    try:
        logger.info(f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
        dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
        result = ApiDependencies.invoker.services.model_manager.merge_models(
            model_names,
            base_model,
            merged_model_name=merged_model_name or "+".join(model_names),
            alpha=alpha,
            interp=interp,
            force=force,
            merge_dest_directory=dest,
        )
        model_raw = ApiDependencies.invoker.services.model_manager.list_model(
            result.name,
            base_model=base_model,
            model_type=ModelType.Main,
        )
        response = parse_obj_as(ConvertModelResponse, model_raw)
    except ModelNotFoundException:
        raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")


@@ -17,15 +17,16 @@ from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import (
    ConditioningData,
    ControlNetData,
    StableDiffusionGeneratorPipeline,
    image_resized_to_grid_as_tensor,
)
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import choose_torch_device, torch_dtype, choose_precision
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from .compel import ConditioningField
from .controlnet_image_processors import ControlField
from .image import ImageOutput

@@ -46,8 +47,7 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device())
class LatentsField(BaseModel):
    """A latents field used for passing latents between invocations"""

    latents_name: Optional[str] = Field(default=None, description="The name of the latents")

    class Config:
        schema_extra = {"required": ["latents_name"]}

@@ -55,14 +55,15 @@ class LatentsField(BaseModel):
class LatentsOutput(BaseInvocationOutput):
    """Base class for invocations that output latents"""

    # fmt: off
    type: Literal["latents_output"] = "latents_output"

    # Inputs
    latents: LatentsField = Field(default=None, description="The output latents")
    width: int = Field(description="The width of the latents in pixels")
    height: int = Field(description="The height of the latents in pixels")
    # fmt: on


def build_latents_output(latents_name: str, latents: torch.Tensor):

@@ -73,9 +74,7 @@ def build_latents_output(latents_name: str, latents: torch.Tensor):
    )


SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]


def get_scheduler(

@@ -83,11 +82,10 @@ def get_scheduler(
    scheduler_info: ModelInfo,
    scheduler_name: str,
) -> Scheduler:
    scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
    orig_scheduler_info = context.services.model_manager.get_model(
        **scheduler_info.dict(),
        context=context,
    )
    with orig_scheduler_info as orig_scheduler:
        scheduler_config = orig_scheduler.config

@@ -102,7 +100,7 @@ def get_scheduler(
    scheduler = scheduler_class.from_config(scheduler_config)

    # hack copied over from generate.py
    if not hasattr(scheduler, "uses_inpainting_model"):
        scheduler.uses_inpainting_model = lambda: False
    return scheduler

@@ -123,8 +121,8 @@ class TextToLatentsInvocation(BaseInvocation):
    scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
    unet: UNetField = Field(default=None, description="UNet submodel")
    control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
    # seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
    # seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
    # fmt: on

    @validator("cfg_scale")

@@ -133,10 +131,10 @@ class TextToLatentsInvocation(BaseInvocation):
        if isinstance(v, list):
            for i in v:
                if i < 1:
                    raise ValueError("cfg_scale must be greater than 1")
        else:
            if v < 1:
                raise ValueError("cfg_scale must be greater than 1")
        return v

    # Schema customisation

@@ -149,8 +147,8 @@ class TextToLatentsInvocation(BaseInvocation):
                    "model": "model",
                    "control": "control",
                    # "cfg_scale": "float",
                    "cfg_scale": "number",
                },
            },
        }

@@ -190,16 +188,14 @@ class TextToLatentsInvocation(BaseInvocation):
                threshold=0.0,  # threshold,
                warmup=0.2,  # warmup,
                h_symmetry_time_pct=None,  # h_symmetry_time_pct,
                v_symmetry_time_pct=None,  # v_symmetry_time_pct,
            ),
        )

        conditioning_data = conditioning_data.add_scheduler_args_if_applicable(
            scheduler,
            # for ddim scheduler
            eta=0.0,  # ddim_eta
            # for ancestral and sde schedulers
            generator=torch.Generator(device=unet.device).manual_seed(0),
        )

@@ -247,7 +243,6 @@ class TextToLatentsInvocation(BaseInvocation):
        exit_stack: ExitStack,
        do_classifier_free_guidance: bool = True,
    ) -> List[ControlNetData]:
        # assuming fixed dimensional scaling of 8:1 for image:latents
        control_height_resize = latents_shape[2] * 8
        control_width_resize = latents_shape[3] * 8

@@ -261,7 +256,7 @@ class TextToLatentsInvocation(BaseInvocation):
            control_list = control_input
        else:
            control_list = None
        if control_list is None:
            control_data = None
            # from above handling, any control that is not None should now be of type list[ControlField]
        else:

@@ -281,9 +276,7 @@ class TextToLatentsInvocation(BaseInvocation):
                control_models.append(control_model)
                control_image_field = control_info.image
                input_image = context.services.images.get_pil_image(control_image_field.image_name)
                # self.image.image_type, self.image.image_name
                # FIXME: still need to test with different widths, heights, devices, dtypes
                # and add in batch_size, num_images_per_prompt?

@@ -322,9 +315,7 @@ class TextToLatentsInvocation(BaseInvocation):
        noise = context.services.latents.get(self.noise.latents_name)

        # Get the source node id (we are invoking the prepared node)
        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
        source_node_id = graph_execution_state.prepared_source_mapping[self.id]

        def step_callback(state: PipelineIntermediateState):

@@ -333,19 +324,20 @@ class TextToLatentsInvocation(BaseInvocation):
        def _lora_loader():
            for lora in self.unet.loras:
                lora_info = context.services.model_manager.get_model(
                    **lora.dict(exclude={"weight"}),
                    context=context,
                )
                yield (lora_info.context.model, lora.weight)
                del lora_info
            return

        unet_info = context.services.model_manager.get_model(
            **self.unet.unet.dict(),
            context=context,
        )

        with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
            unet_info.context.model, _lora_loader()
        ), unet_info as unet:
            noise = noise.to(device=unet.device, dtype=unet.dtype)

@@ -358,7 +350,9 @@ class TextToLatentsInvocation(BaseInvocation):
            conditioning_data = self.get_conditioning_data(context, scheduler, unet)

            control_data = self.prep_control_data(
                model=pipeline,
                context=context,
                control_input=self.control,
                latents_shape=noise.shape,
                # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
                do_classifier_free_guidance=True,

@@ -379,7 +373,7 @@ class TextToLatentsInvocation(BaseInvocation):
            result_latents = result_latents.to("cpu")
            torch.cuda.empty_cache()

        name = f"{context.graph_execution_state_id}__{self.id}"
        context.services.latents.save(name, result_latents)
        return build_latents_output(latents_name=name, latents=result_latents)

@@ -390,11 +384,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
    type: Literal["l2l"] = "l2l"

    # Inputs
    latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
    strength: float = Field(default=0.7, ge=0, le=1, description="The strength of the latents to use")

    # Schema customisation
    class Config(InvocationConfig):

@@ -406,7 +397,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
                    "model": "model",
                    "control": "control",
                    "cfg_scale": "number",
                },
            },
        }

@@ -417,9 +408,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
        latent = context.services.latents.get(self.latents.latents_name)

        # Get the source node id (we are invoking the prepared node)
        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
        source_node_id = graph_execution_state.prepared_source_mapping[self.id]

        def step_callback(state: PipelineIntermediateState):

@@ -428,19 +417,20 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
        def _lora_loader():
            for lora in self.unet.loras:
                lora_info = context.services.model_manager.get_model(
                    **lora.dict(exclude={"weight"}),
                    context=context,
                )
                yield (lora_info.context.model, lora.weight)
                del lora_info
            return

        unet_info = context.services.model_manager.get_model(
            **self.unet.unet.dict(),
            context=context,
        )

        with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
            unet_info.context.model, _lora_loader()
        ), unet_info as unet:
            noise = noise.to(device=unet.device, dtype=unet.dtype)
            latent = latent.to(device=unet.device, dtype=unet.dtype)

@@ -454,7 +444,9 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
            conditioning_data = self.get_conditioning_data(context, scheduler, unet)

            control_data = self.prep_control_data(
                model=pipeline,
                context=context,
                control_input=self.control,
                latents_shape=noise.shape,
                # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
                do_classifier_free_guidance=True,

@@ -462,8 +454,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
            )

            # TODO: Verify the noise is the right size
            initial_latents = (
                latent if self.strength < 1.0 else torch.zeros_like(latent, device=unet.device, dtype=latent.dtype)
            )

            timesteps, _ = pipeline.get_img2img_timesteps(

@@ -479,14 +471,14 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
                num_inference_steps=self.steps,
                conditioning_data=conditioning_data,
                control_data=control_data,  # list[ControlNetData]
                callback=step_callback,
            )

            # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
            result_latents = result_latents.to("cpu")
            torch.cuda.empty_cache()

        name = f"{context.graph_execution_state_id}__{self.id}"
        context.services.latents.save(name, result_latents)
        return build_latents_output(latents_name=name, latents=result_latents)

@@ -498,14 +490,13 @@ class LatentsToImageInvocation(BaseInvocation):
    type: Literal["l2i"] = "l2i"

    # Inputs
    latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
    vae: VaeField = Field(default=None, description="Vae submodel")
    tiled: bool = Field(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
    fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision")
    metadata: Optional[CoreMetadata] = Field(
        default=None, description="Optional core metadata to be written to the image"
    )

    # Schema customisation
    class Config(InvocationConfig):

@@ -521,7 +512,8 @@ class LatentsToImageInvocation(BaseInvocation):
        latents = context.services.latents.get(self.latents.latents_name)

        vae_info = context.services.model_manager.get_model(
            **self.vae.vae.dict(),
            context=context,
        )

        with vae_info as vae:

@@ -588,8 +580,7 @@ class LatentsToImageInvocation(BaseInvocation):
        )


LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]


class ResizeLatentsInvocation(BaseInvocation):

@@ -598,36 +589,30 @@ class ResizeLatentsInvocation(BaseInvocation):
    type: Literal["lresize"] = "lresize"

    # Inputs
    latents: Optional[LatentsField] = Field(description="The latents to resize")
    width: Union[int, None] = Field(default=512, ge=64, multiple_of=8, description="The width to resize to (px)")
    height: Union[int, None] = Field(default=512, ge=64, multiple_of=8, description="The height to resize to (px)")
    mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
    antialias: bool = Field(
        default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)"
    )

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Resize Latents", "tags": ["latents", "resize"]},
        }

    def invoke(self, context: InvocationContext) -> LatentsOutput:
        latents = context.services.latents.get(self.latents.latents_name)

        # TODO:
        device = choose_torch_device()

        resized_latents = torch.nn.functional.interpolate(
            latents.to(device),
            size=(self.height // 8, self.width // 8),
            mode=self.mode,
            antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
        )

        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699

@@ -646,35 +631,30 @@ class ScaleLatentsInvocation(BaseInvocation):
    type: Literal["lscale"] = "lscale"

    # Inputs
    latents: Optional[LatentsField] = Field(description="The latents to scale")
    scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
    mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
    antialias: bool = Field(
        default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)"
    )

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Scale Latents", "tags": ["latents", "scale"]},
        }

    def invoke(self, context: InvocationContext) -> LatentsOutput:
        latents = context.services.latents.get(self.latents.latents_name)

        # TODO:
        device = choose_torch_device()

        # resizing
        resized_latents = torch.nn.functional.interpolate(
            latents.to(device),
            scale_factor=self.scale_factor,
            mode=self.mode,
            antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
        )

        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699

@@ -695,19 +675,13 @@ class ImageToLatentsInvocation(BaseInvocation):
    # Inputs
    image: Optional[ImageField] = Field(description="The image to encode")
    vae: VaeField = Field(default=None, description="Vae submodel")
    tiled: bool = Field(default=False, description="Encode latents by overlaping tiles(less memory consumption)")
    fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision")

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {"title": "Image To Latents", "tags": ["latents", "image"]},
        }

    @torch.no_grad()

@@ -717,9 +691,10 @@ class ImageToLatentsInvocation(BaseInvocation):
        # )
        image = context.services.images.get_pil_image(self.image.image_name)

        # vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
        vae_info = context.services.model_manager.get_model(
            **self.vae.vae.dict(),
            context=context,
        )

        image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))

@@ -746,12 +721,12 @@ class ImageToLatentsInvocation(BaseInvocation):
                    vae.post_quant_conv.to(orig_dtype)
                    vae.decoder.conv_in.to(orig_dtype)
                    vae.decoder.mid_block.to(orig_dtype)
                # else:
                #     latents = latents.float()

            else:
                vae.to(dtype=torch.float16)
                # latents = latents.half()

            if self.tiled:
                vae.enable_tiling()

@@ -762,9 +737,7 @@ class ImageToLatentsInvocation(BaseInvocation):
            image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
            with torch.inference_mode():
                image_tensor_dist = vae.encode(image_tensor).latent_dist
                latents = image_tensor_dist.sample().to(dtype=vae.dtype)  # FIXME: uses torch.randn. make reproducible!
                latents = vae.config.scaling_factor * latents
                latents = latents.to(dtype=orig_dtype)


@@ -7,13 +7,10 @@ from invokeai.backend.model_management.model_probe import ModelProbe
parser = argparse.ArgumentParser(description="Probe model type")
parser.add_argument(
    "model_path",
    type=Path,
)
args = parser.parse_args()

info = ModelProbe().probe(args.model_path)
print(info)