diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py
index a7e0616f5b..2f59f1dd0f 100644
--- a/invokeai/app/api/routers/models.py
+++ b/invokeai/app/api/routers/models.py
@@ -28,49 +28,52 @@
 ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
 MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
 ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
 
+
 class ModelsList(BaseModel):
     models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
 
+
 @models_router.get(
     "/",
     operation_id="list_models",
-    responses={200: {"model": ModelsList }},
+    responses={200: {"model": ModelsList}},
 )
 async def list_models(
     base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
     model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
 ) -> ModelsList:
     """Gets a list of models"""
-    if base_models and len(base_models)>0:
+    if base_models and len(base_models) > 0:
         models_raw = list()
         for base_model in base_models:
             models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
     else:
         models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type)
-    models = parse_obj_as(ModelsList, { "models": models_raw })
+    models = parse_obj_as(ModelsList, {"models": models_raw})
     return models
 
+
 @models_router.patch(
     "/{base_model}/{model_type}/{model_name}",
     operation_id="update_model",
-    responses={200: {"description" : "The model was updated successfully"},
-               400: {"description" : "Bad request"},
-               404: {"description" : "The model could not be found"},
-               409: {"description" : "There is already a model corresponding to the new name"},
-    },
-    status_code = 200,
-    response_model = UpdateModelResponse,
+    responses={
+        200: {"description": "The model was updated successfully"},
+        400: {"description": "Bad request"},
+        404: {"description": "The model could not be found"},
+        409: {"description": "There is already a model corresponding to the new name"},
+    },
+    status_code=200,
+    response_model=UpdateModelResponse,
 )
 async def update_model(
-    base_model: BaseModelType = Path(description="Base model"),
-    model_type: ModelType = Path(description="The type of model"),
-    model_name: str = Path(description="model name"),
-    info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
+    base_model: BaseModelType = Path(description="Base model"),
+    model_type: ModelType = Path(description="The type of model"),
+    model_name: str = Path(description="model name"),
+    info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
 ) -> UpdateModelResponse:
-    """ Update model contents with a new config. If the model name or base fields are changed, then the model is renamed. """
+    """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
     logger = ApiDependencies.invoker.services.logger
-
     try:
         previous_info = ApiDependencies.invoker.services.model_manager.list_model(
             model_name=model_name,
@@ -81,13 +84,13 @@ async def update_model(
         # rename operation requested
         if info.model_name != model_name or info.base_model != base_model:
             ApiDependencies.invoker.services.model_manager.rename_model(
-                base_model = base_model,
-                model_type = model_type,
-                model_name = model_name,
-                new_name = info.model_name,
-                new_base = info.base_model,
+                base_model=base_model,
+                model_type=model_type,
+                model_name=model_name,
+                new_name=info.model_name,
+                new_base=info.base_model,
             )
-            logger.info(f'Successfully renamed {base_model.value}/{model_name}=>{info.base_model}/{info.model_name}')
+            logger.info(f"Successfully renamed {base_model.value}/{model_name}=>{info.base_model}/{info.model_name}")
         # update information to support an update of attributes
         model_name = info.model_name
         base_model = info.base_model
@@ -96,16 +99,15 @@ async def update_model(
             base_model=base_model,
             model_type=model_type,
         )
-        if new_info.get('path') != previous_info.get('path'): # model manager moved model path during rename - don't overwrite it
-            info.path = new_info.get('path')
+        if new_info.get("path") != previous_info.get(
+            "path"
+        ):  # model manager moved model path during rename - don't overwrite it
+            info.path = new_info.get("path")
 
         ApiDependencies.invoker.services.model_manager.update_model(
-            model_name=model_name,
-            base_model=base_model,
-            model_type=model_type,
-            model_attributes=info.dict()
+            model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info.dict()
         )
-
+
         model_raw = ApiDependencies.invoker.services.model_manager.list_model(
             model_name=model_name,
             base_model=base_model,
@@ -123,49 +125,48 @@ async def update_model(
     return model_response
 
+
 @models_router.post(
     "/import",
     operation_id="import_model",
-    responses= {
-        201: {"description" : "The model imported successfully"},
-        404: {"description" : "The model could not be found"},
-        415: {"description" : "Unrecognized file/folder format"},
-        424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"},
-        409: {"description" : "There is already a model corresponding to this path or repo_id"},
+    responses={
+        201: {"description": "The model imported successfully"},
+        404: {"description": "The model could not be found"},
+        415: {"description": "Unrecognized file/folder format"},
+        424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
+        409: {"description": "There is already a model corresponding to this path or repo_id"},
     },
     status_code=201,
-    response_model=ImportModelResponse
+    response_model=ImportModelResponse,
 )
 async def import_model(
-    location: str = Body(description="A model path, repo_id or URL to import"),
-    prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
-        Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
+    location: str = Body(description="A model path, repo_id or URL to import"),
+    prediction_type: Optional[Literal["v_prediction", "epsilon", "sample"]] = Body(
+        description="Prediction type for SDv2 checkpoint files", default="v_prediction"
+    ),
 ) -> ImportModelResponse:
-    """ Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically """
-
+    """Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically"""
+
     items_to_import = {location}
-    prediction_types = { x.value: x for x in SchedulerPredictionType }
+    prediction_types = {x.value: x for x in SchedulerPredictionType}
     logger = ApiDependencies.invoker.services.logger
 
     try:
         installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
-            items_to_import = items_to_import,
-            prediction_type_helper = lambda x: prediction_types.get(prediction_type)
+            items_to_import=items_to_import, prediction_type_helper=lambda x: prediction_types.get(prediction_type)
         )
         info = installed_models.get(location)
 
         if not info:
             logger.error("Import failed")
             raise HTTPException(status_code=415)
-
-        logger.info(f'Successfully imported {location}, got {info}')
+
+        logger.info(f"Successfully imported {location}, got {info}")
         model_raw = ApiDependencies.invoker.services.model_manager.list_model(
-            model_name=info.name,
-            base_model=info.base_model,
-            model_type=info.model_type
+            model_name=info.name, base_model=info.base_model, model_type=info.model_type
         )
         return parse_obj_as(ImportModelResponse, model_raw)
-
+
     except ModelNotFoundException as e:
         logger.error(str(e))
         raise HTTPException(status_code=404, detail=str(e))
@@ -175,38 +176,34 @@ async def import_model(
     except ValueError as e:
         logger.error(str(e))
         raise HTTPException(status_code=409, detail=str(e))
-
+
+
 @models_router.post(
     "/add",
     operation_id="add_model",
-    responses= {
-        201: {"description" : "The model added successfully"},
-        404: {"description" : "The model could not be found"},
-        424: {"description" : "The model appeared to add successfully, but could not be found in the model manager"},
-        409: {"description" : "There is already a model corresponding to this path or repo_id"},
+    responses={
+        201: {"description": "The model added successfully"},
+        404: {"description": "The model could not be found"},
+        424: {"description": "The model appeared to add successfully, but could not be found in the model manager"},
+        409: {"description": "There is already a model corresponding to this path or repo_id"},
     },
     status_code=201,
-    response_model=ImportModelResponse
+    response_model=ImportModelResponse,
 )
 async def add_model(
-    info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
+    info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
 ) -> ImportModelResponse:
-    """ Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
-
+    """Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
+
     logger = ApiDependencies.invoker.services.logger
     try:
         ApiDependencies.invoker.services.model_manager.add_model(
-            info.model_name,
-            info.base_model,
-            info.model_type,
-            model_attributes = info.dict()
+            info.model_name, info.base_model, info.model_type, model_attributes=info.dict()
        )
-        logger.info(f'Successfully added {info.model_name}')
+        logger.info(f"Successfully added {info.model_name}")
         model_raw = ApiDependencies.invoker.services.model_manager.list_model(
-            model_name=info.model_name,
-            base_model=info.base_model,
-            model_type=info.model_type
+            model_name=info.model_name, base_model=info.base_model, model_type=info.model_type
         )
         return parse_obj_as(ImportModelResponse, model_raw)
     except ModelNotFoundException as e:
@@ -216,66 +213,66 @@ async def add_model(
         logger.error(str(e))
         raise HTTPException(status_code=409, detail=str(e))
 
-
+
 @models_router.delete(
     "/{base_model}/{model_type}/{model_name}",
     operation_id="del_model",
-    responses={
-        204: { "description": "Model deleted successfully" },
-        404: { "description": "Model not found" }
-    },
-    status_code = 204,
-    response_model = None,
+    responses={204: {"description": "Model deleted successfully"}, 404: {"description": "Model not found"}},
+    status_code=204,
+    response_model=None,
 )
 async def delete_model(
-    base_model: BaseModelType = Path(description="Base model"),
-    model_type: ModelType = Path(description="The type of model"),
-    model_name: str = Path(description="model name"),
+    base_model: BaseModelType = Path(description="Base model"),
+    model_type: ModelType = Path(description="The type of model"),
+    model_name: str = Path(description="model name"),
 ) -> Response:
     """Delete Model"""
     logger = ApiDependencies.invoker.services.logger
-
+
     try:
-        ApiDependencies.invoker.services.model_manager.del_model(model_name,
-                                                                 base_model = base_model,
-                                                                 model_type = model_type
-                                                                 )
+        ApiDependencies.invoker.services.model_manager.del_model(
+            model_name, base_model=base_model, model_type=model_type
+        )
         logger.info(f"Deleted model: {model_name}")
         return Response(status_code=204)
     except ModelNotFoundException as e:
         logger.error(str(e))
         raise HTTPException(status_code=404, detail=str(e))
 
+
 @models_router.put(
     "/convert/{base_model}/{model_type}/{model_name}",
     operation_id="convert_model",
     responses={
-        200: { "description": "Model converted successfully" },
-        400: {"description" : "Bad request" },
-        404: { "description": "Model not found" },
+        200: {"description": "Model converted successfully"},
+        400: {"description": "Bad request"},
+        404: {"description": "Model not found"},
     },
-    status_code = 200,
-    response_model = ConvertModelResponse,
+    status_code=200,
+    response_model=ConvertModelResponse,
 )
 async def convert_model(
-    base_model: BaseModelType = Path(description="Base model"),
-    model_type: ModelType = Path(description="The type of model"),
-    model_name: str = Path(description="model name"),
-    convert_dest_directory: Optional[str] = Query(default=None, description="Save the converted model to the designated directory"),
+    base_model: BaseModelType = Path(description="Base model"),
+    model_type: ModelType = Path(description="The type of model"),
+    model_name: str = Path(description="model name"),
+    convert_dest_directory: Optional[str] = Query(
+        default=None, description="Save the converted model to the designated directory"
+    ),
 ) -> ConvertModelResponse:
     """Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
     logger = ApiDependencies.invoker.services.logger
     try:
         logger.info(f"Converting model: {model_name}")
         dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
-        ApiDependencies.invoker.services.model_manager.convert_model(model_name,
-                                                                     base_model = base_model,
-                                                                     model_type = model_type,
-                                                                     convert_dest_directory = dest,
-                                                                     )
-        model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name,
-                                                                              base_model = base_model,
-                                                                              model_type = model_type)
+        ApiDependencies.invoker.services.model_manager.convert_model(
+            model_name,
+            base_model=base_model,
+            model_type=model_type,
+            convert_dest_directory=dest,
+        )
+        model_raw = ApiDependencies.invoker.services.model_manager.list_model(
+            model_name, base_model=base_model, model_type=model_type
+        )
         response = parse_obj_as(ConvertModelResponse, model_raw)
     except ModelNotFoundException as e:
         raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
@@ -283,91 +280,101 @@ async def convert_model(
         raise HTTPException(status_code=400, detail=str(e))
     return response
 
+
 @models_router.get(
     "/search",
     operation_id="search_for_models",
     responses={
-        200: { "description": "Directory searched successfully" },
-        404: { "description": "Invalid directory path" },
+        200: {"description": "Directory searched successfully"},
+        404: {"description": "Invalid directory path"},
     },
-    status_code = 200,
-    response_model = List[pathlib.Path]
+    status_code=200,
+    response_model=List[pathlib.Path],
 )
 async def search_for_models(
-    search_path: pathlib.Path = Query(description="Directory path to search for models")
-)->List[pathlib.Path]:
+    search_path: pathlib.Path = Query(description="Directory path to search for models"),
+) -> List[pathlib.Path]:
     if not search_path.is_dir():
-        raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory")
+        raise HTTPException(
+            status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory"
+        )
     return ApiDependencies.invoker.services.model_manager.search_for_models(search_path)
 
+
 @models_router.get(
     "/ckpt_confs",
     operation_id="list_ckpt_configs",
     responses={
-        200: { "description" : "paths retrieved successfully" },
+        200: {"description": "paths retrieved successfully"},
     },
-    status_code = 200,
-    response_model = List[pathlib.Path]
+    status_code=200,
+    response_model=List[pathlib.Path],
 )
-async def list_ckpt_configs(
-)->List[pathlib.Path]:
+async def list_ckpt_configs() -> List[pathlib.Path]:
     """Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT."""
     return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs()
-
-
+
+
 @models_router.post(
     "/sync",
     operation_id="sync_to_config",
     responses={
-        201: { "description": "synchronization successful" },
+        201: {"description": "synchronization successful"},
     },
-    status_code = 201,
-    response_model = bool
+    status_code=201,
+    response_model=bool,
 )
-async def sync_to_config(
-)->bool:
+async def sync_to_config() -> bool:
     """Call after making changes to models.yaml, autoimport directories or models directory to synchronize in-memory data structures with disk data structures."""
     ApiDependencies.invoker.services.model_manager.sync_to_config()
     return True
-
+
+
 @models_router.put(
     "/merge/{base_model}",
     operation_id="merge_models",
     responses={
-        200: { "description": "Model converted successfully" },
-        400: { "description": "Incompatible models" },
-        404: { "description": "One or more models not found" },
+        200: {"description": "Model converted successfully"},
+        400: {"description": "Incompatible models"},
+        404: {"description": "One or more models not found"},
     },
-    status_code = 200,
-    response_model = MergeModelResponse,
+    status_code=200,
+    response_model=MergeModelResponse,
 )
 async def merge_models(
-    base_model: BaseModelType = Path(description="Base model"),
-    model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
-    merged_model_name: Optional[str] = Body(description="Name of destination model"),
-    alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
-    interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
-    force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False),
-    merge_dest_directory: Optional[str] = Body(description="Save the merged model to the designated directory (with 'merged_model_name' appended)", default=None)
+    base_model: BaseModelType = Path(description="Base model"),
+    model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
+    merged_model_name: Optional[str] = Body(description="Name of destination model"),
+    alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
+    interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
+    force: Optional[bool] = Body(
+        description="Force merging of models created with different versions of diffusers", default=False
+    ),
+    merge_dest_directory: Optional[str] = Body(
+        description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
+        default=None,
+    ),
 ) -> MergeModelResponse:
     """Convert a checkpoint model into a diffusers model"""
     logger = ApiDependencies.invoker.services.logger
     try:
         logger.info(f"Merging models: {model_names} into {merge_dest_directory or ''}/{merged_model_name}")
         dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
-        result = ApiDependencies.invoker.services.model_manager.merge_models(model_names,
-                                                                             base_model,
-                                                                             merged_model_name=merged_model_name or "+".join(model_names),
-                                                                             alpha=alpha,
-                                                                             interp=interp,
-                                                                             force=force,
-                                                                             merge_dest_directory = dest
-                                                                             )
-        model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name,
-                                                                              base_model = base_model,
-                                                                              model_type = ModelType.Main,
-                                                                              )
+        result = ApiDependencies.invoker.services.model_manager.merge_models(
+            model_names,
+            base_model,
+            merged_model_name=merged_model_name or "+".join(model_names),
+            alpha=alpha,
+            interp=interp,
+            force=force,
+            merge_dest_directory=dest,
+        )
+        model_raw = ApiDependencies.invoker.services.model_manager.list_model(
+            result.name,
+            base_model=base_model,
+            model_type=ModelType.Main,
+        )
         response = parse_obj_as(ConvertModelResponse, model_raw)
     except ModelNotFoundException:
         raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index 96c536dc94..f994e4a3d9 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -17,15 +17,16 @@ from invokeai.backend.model_management.models import ModelType, SilenceWarnings
 from ...backend.model_management.lora import ModelPatcher
 from ...backend.stable_diffusion import PipelineIntermediateState
 from ...backend.stable_diffusion.diffusers_pipeline import (
-    ConditioningData, ControlNetData, StableDiffusionGeneratorPipeline,
-    image_resized_to_grid_as_tensor)
-from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
-    PostprocessingSettings
+    ConditioningData,
+    ControlNetData,
+    StableDiffusionGeneratorPipeline,
+    image_resized_to_grid_as_tensor,
+)
+from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
 from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
 from ...backend.util.devices import choose_torch_device, torch_dtype, choose_precision
 from ..models.image import ImageCategory, ImageField, ResourceOrigin
-from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
-                             InvocationConfig, InvocationContext)
+from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
 from .compel import ConditioningField
 from .controlnet_image_processors import ControlField
 from .image import ImageOutput
@@ -46,8 +47,7 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device())
 class LatentsField(BaseModel):
     """A latents field used for passing latents between invocations"""
 
-    latents_name: Optional[str] = Field(
-        default=None, description="The name of the latents")
+    latents_name: Optional[str] = Field(default=None, description="The name of the latents")
 
     class Config:
         schema_extra = {"required": ["latents_name"]}
@@ -55,14 +55,15 @@ class LatentsField(BaseModel):
 class LatentsOutput(BaseInvocationOutput):
     """Base class for invocations that output latents"""
 
-    #fmt: off
+
+    # fmt: off
     type: Literal["latents_output"] = "latents_output"
 
     # Inputs
     latents: LatentsField = Field(default=None, description="The output latents")
     width: int = Field(description="The width of the latents in pixels")
     height: int = Field(description="The height of the latents in pixels")
-    #fmt: on
+    # fmt: on
 
 
 def build_latents_output(latents_name: str, latents: torch.Tensor):
@@ -73,9 +74,7 @@ def build_latents_output(latents_name: str, latents: torch.Tensor):
     )
 
 
-SAMPLER_NAME_VALUES = Literal[
-    tuple(list(SCHEDULER_MAP.keys()))
-]
+SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
 
 
 def get_scheduler(
@@ -83,11 +82,10 @@ def get_scheduler(
     scheduler_info: ModelInfo,
     scheduler_name: str,
 ) -> Scheduler:
-    scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(
-        scheduler_name, SCHEDULER_MAP['ddim']
-    )
+    scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
     orig_scheduler_info = context.services.model_manager.get_model(
-        **scheduler_info.dict(), context=context,
+        **scheduler_info.dict(),
+        context=context,
     )
     with orig_scheduler_info as orig_scheduler:
         scheduler_config = orig_scheduler.config
@@ -102,7 +100,7 @@ def get_scheduler(
     scheduler = scheduler_class.from_config(scheduler_config)
 
     # hack copied over from generate.py
-    if not hasattr(scheduler, 'uses_inpainting_model'):
+    if not hasattr(scheduler, "uses_inpainting_model"):
         scheduler.uses_inpainting_model = lambda: False
     return scheduler
 
@@ -123,8 +121,8 @@ class TextToLatentsInvocation(BaseInvocation):
     scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
     unet: UNetField = Field(default=None, description="UNet submodel")
     control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
-    #seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
-    #seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
+    # seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
+    # seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
     # fmt: on
 
     @validator("cfg_scale")
@@ -133,10 +131,10 @@ class TextToLatentsInvocation(BaseInvocation):
         if isinstance(v, list):
             for i in v:
                 if i < 1:
-                    raise ValueError('cfg_scale must be greater than 1')
+                    raise ValueError("cfg_scale must be greater than 1")
         else:
             if v < 1:
-                raise ValueError('cfg_scale must be greater than 1')
+                raise ValueError("cfg_scale must be greater than 1")
         return v
 
     # Schema customisation
@@ -149,8 +147,8 @@ class TextToLatentsInvocation(BaseInvocation):
                     "model": "model",
                     "control": "control",
                     # "cfg_scale": "float",
-                    "cfg_scale": "number"
-                }
+                    "cfg_scale": "number",
+                },
             },
         }
 
@@ -190,16 +188,14 @@ class TextToLatentsInvocation(BaseInvocation):
                 threshold=0.0,  # threshold,
                 warmup=0.2,  # warmup,
                 h_symmetry_time_pct=None,  # h_symmetry_time_pct,
-                v_symmetry_time_pct=None  # v_symmetry_time_pct,
+                v_symmetry_time_pct=None,  # v_symmetry_time_pct,
             ),
         )
 
         conditioning_data = conditioning_data.add_scheduler_args_if_applicable(
             scheduler,
-            # for ddim scheduler
             eta=0.0,  # ddim_eta
-            # for ancestral and sde schedulers
             generator=torch.Generator(device=unet.device).manual_seed(0),
         )
 
@@ -247,7 +243,6 @@ class TextToLatentsInvocation(BaseInvocation):
         exit_stack: ExitStack,
         do_classifier_free_guidance: bool = True,
     ) -> List[ControlNetData]:
-
         # assuming fixed dimensional scaling of 8:1 for image:latents
         control_height_resize = latents_shape[2] * 8
         control_width_resize = latents_shape[3] * 8
@@ -261,7 +256,7 @@ class TextToLatentsInvocation(BaseInvocation):
             control_list = control_input
         else:
             control_list = None
-        if (control_list is None):
+        if control_list is None:
             control_data = None
             # from above handling, any control that is not None should now be of type list[ControlField]
         else:
@@ -281,9 +276,7 @@ class TextToLatentsInvocation(BaseInvocation):
                 control_models.append(control_model)
                 control_image_field = control_info.image
-                input_image = context.services.images.get_pil_image(
-                    control_image_field.image_name
-                )
+                input_image = context.services.images.get_pil_image(control_image_field.image_name)
                 # self.image.image_type, self.image.image_name
                 # FIXME: still need to test with different widths, heights, devices, dtypes
                 # and add in batch_size, num_images_per_prompt?
@@ -322,9 +315,7 @@ class TextToLatentsInvocation(BaseInvocation):
         noise = context.services.latents.get(self.noise.latents_name)
 
         # Get the source node id (we are invoking the prepared node)
-        graph_execution_state = context.services.graph_execution_manager.get(
-            context.graph_execution_state_id
-        )
+        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
         source_node_id = graph_execution_state.prepared_source_mapping[self.id]
 
         def step_callback(state: PipelineIntermediateState):
@@ -333,19 +324,20 @@ class TextToLatentsInvocation(BaseInvocation):
         def _lora_loader():
             for lora in self.unet.loras:
                 lora_info = context.services.model_manager.get_model(
-                    **lora.dict(exclude={"weight"}), context=context,
+                    **lora.dict(exclude={"weight"}),
+                    context=context,
                 )
                 yield (lora_info.context.model, lora.weight)
                 del lora_info
             return
 
         unet_info = context.services.model_manager.get_model(
-            **self.unet.unet.dict(), context=context,
+            **self.unet.unet.dict(),
+            context=context,
         )
-        with ExitStack() as exit_stack,\
-             ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
-             unet_info as unet:
-
+        with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
+            unet_info.context.model, _lora_loader()
+        ), unet_info as unet:
             noise = noise.to(device=unet.device, dtype=unet.dtype)
 
             scheduler = get_scheduler(
@@ -358,7 +350,9 @@ class TextToLatentsInvocation(BaseInvocation):
             conditioning_data = self.get_conditioning_data(context, scheduler, unet)
 
             control_data = self.prep_control_data(
-                model=pipeline, context=context, control_input=self.control,
+                model=pipeline,
+                context=context,
+                control_input=self.control,
                 latents_shape=noise.shape,
                 # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
                 do_classifier_free_guidance=True,
@@ -379,7 +373,7 @@ class TextToLatentsInvocation(BaseInvocation):
             result_latents = result_latents.to("cpu")
             torch.cuda.empty_cache()
 
-            name = f'{context.graph_execution_state_id}__{self.id}'
+            name = f"{context.graph_execution_state_id}__{self.id}"
             context.services.latents.save(name, result_latents)
             return build_latents_output(latents_name=name, latents=result_latents)
 
@@ -390,11 +384,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
     type: Literal["l2l"] = "l2l"
 
     # Inputs
-    latents: Optional[LatentsField] = Field(
-        description="The latents to use as a base image")
-    strength: float = Field(
-        default=0.7, ge=0, le=1,
-        description="The strength of the latents to use")
+    latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
+    strength: float = Field(default=0.7, ge=0, le=1, description="The strength of the latents to use")
 
     # Schema customisation
     class Config(InvocationConfig):
@@ -406,7 +397,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
                     "model": "model",
                     "control": "control",
                     "cfg_scale": "number",
-                }
+                },
             },
         }
 
@@ -417,9 +408,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
         latent = context.services.latents.get(self.latents.latents_name)
 
        # Get the source node id (we are invoking the prepared node)
-        graph_execution_state = context.services.graph_execution_manager.get(
-            context.graph_execution_state_id
-        )
+        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
         source_node_id = graph_execution_state.prepared_source_mapping[self.id]
 
         def step_callback(state: PipelineIntermediateState):
@@ -428,19 +417,20 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
         def _lora_loader():
             for lora in self.unet.loras:
                 lora_info = context.services.model_manager.get_model(
-                    **lora.dict(exclude={"weight"}), context=context,
+                    **lora.dict(exclude={"weight"}),
+                    context=context,
                 )
                 yield (lora_info.context.model, lora.weight)
                 del lora_info
             return
 
         unet_info = context.services.model_manager.get_model(
-            **self.unet.unet.dict(), context=context,
+            **self.unet.unet.dict(),
+            context=context,
         )
-        with ExitStack() as exit_stack,\
-             ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
-             unet_info as unet:
-
+        with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
+            unet_info.context.model, _lora_loader()
+        ), unet_info as unet:
             noise = noise.to(device=unet.device, dtype=unet.dtype)
             latent = latent.to(device=unet.device, dtype=unet.dtype)
 
@@ -454,7 +444,9 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
             conditioning_data = self.get_conditioning_data(context, scheduler, unet)
 
             control_data = self.prep_control_data(
-                model=pipeline, context=context, control_input=self.control,
+                model=pipeline,
+                context=context,
+                control_input=self.control,
                 latents_shape=noise.shape,
                 # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
                 do_classifier_free_guidance=True,
@@ -462,8 +454,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
             )
 
             # TODO: Verify the noise is the right size
-            initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
-                latent, device=unet.device, dtype=latent.dtype
+            initial_latents = (
+                latent if self.strength < 1.0 else torch.zeros_like(latent, device=unet.device, dtype=latent.dtype)
             )
 
             timesteps, _ = pipeline.get_img2img_timesteps(
@@ -479,14 +471,14 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
                 num_inference_steps=self.steps,
                 conditioning_data=conditioning_data,
                 control_data=control_data,  # list[ControlNetData]
-                callback=step_callback
+                callback=step_callback,
             )
 
             # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
             result_latents = result_latents.to("cpu")
             torch.cuda.empty_cache()
 
-            name = f'{context.graph_execution_state_id}__{self.id}'
+            name = f"{context.graph_execution_state_id}__{self.id}"
             context.services.latents.save(name, result_latents)
             return build_latents_output(latents_name=name, latents=result_latents)
 
@@ -498,14 +490,13 @@ class LatentsToImageInvocation(BaseInvocation):
     type: Literal["l2i"] = "l2i"
 
     # Inputs
-    latents: Optional[LatentsField] = Field(
-        description="The latents to generate an image from")
+    latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
     vae: VaeField = Field(default=None, description="Vae submodel")
-    tiled: bool = Field(
-        default=False,
-        description="Decode latents by overlaping tiles(less memory consumption)")
-    fp32: bool = Field(DEFAULT_PRECISION=='float32', description="Decode in full precision")
-    metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
+    tiled: bool = Field(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
+    fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision")
+    metadata: Optional[CoreMetadata] = Field(
+        default=None, description="Optional core metadata to be written to the image"
+    )
 
     # Schema customisation
     class Config(InvocationConfig):
@@ -521,7 +512,8 @@ class LatentsToImageInvocation(BaseInvocation):
         latents = context.services.latents.get(self.latents.latents_name)
 
         vae_info = context.services.model_manager.get_model(
-            **self.vae.vae.dict(), context=context,
+            **self.vae.vae.dict(),
+            context=context,
         )
 
         with vae_info as vae:
@@ -588,8 +580,7 @@ class LatentsToImageInvocation(BaseInvocation):
         )
 
 
-LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear",
-                                     "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
+LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
 
 
 class ResizeLatentsInvocation(BaseInvocation):
@@ -598,36 +589,30 @@ class ResizeLatentsInvocation(BaseInvocation):
     type: Literal["lresize"] = "lresize"
 
     # Inputs
-    latents: Optional[LatentsField] = Field(
-        description="The latents to resize")
-    width: Union[int, None] = Field(default=512,
-        ge=64, multiple_of=8, description="The width to resize to (px)")
-    height: Union[int, None] = Field(default=512,
-        ge=64, multiple_of=8, description="The height to resize to (px)")
-    mode: LATENTS_INTERPOLATION_MODE = Field(
-        default="bilinear", description="The interpolation mode")
+    latents: Optional[LatentsField] = Field(description="The latents to resize")
+    width: Union[int, None] = Field(default=512, ge=64, multiple_of=8, description="The width to resize to (px)")
+    height: Union[int, None] = Field(default=512, ge=64, multiple_of=8, description="The height to resize to (px)")
+    mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
     antialias: bool = Field(
-        default=False,
-        description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
+        default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)"
+    )
 
     class Config(InvocationConfig):
         schema_extra = {
-            "ui": {
-                "title": "Resize Latents",
-                "tags": ["latents", "resize"]
-            },
+            "ui": {"title": "Resize Latents", "tags": ["latents", "resize"]},
         }
 
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = context.services.latents.get(self.latents.latents_name)
 
         # TODO:
-        device=choose_torch_device()
+        device = choose_torch_device()
 
         resized_latents = torch.nn.functional.interpolate(
-            latents.to(device), size=(self.height // 8, self.width // 8),
-            mode=self.mode, antialias=self.antialias
-            if self.mode in ["bilinear", "bicubic"] else False,
+            latents.to(device),
+            size=(self.height // 8, self.width // 8),
+            mode=self.mode,
+            antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
         )
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
@@ -646,35 +631,30 @@ class ScaleLatentsInvocation(BaseInvocation):
     type: Literal["lscale"] = "lscale"
 
     # Inputs
-    latents: Optional[LatentsField] = Field(
-        description="The latents to scale")
-    scale_factor: float = Field(
-        gt=0, description="The factor by which to scale the latents")
-    mode: LATENTS_INTERPOLATION_MODE = Field(
-        default="bilinear", description="The interpolation mode")
+    latents: Optional[LatentsField] = Field(description="The latents to scale")
+    scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
+    mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
     antialias: bool = Field(
-        default=False,
-        description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
+        default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)"
+    )
 
     class Config(InvocationConfig):
         schema_extra = {
-            "ui": {
-                "title": "Scale Latents",
-                "tags": ["latents", "scale"]
-            },
+            "ui": {"title": "Scale Latents", "tags": ["latents", "scale"]},
         }
 
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = context.services.latents.get(self.latents.latents_name)
 
         # TODO:
-        device=choose_torch_device()
+        device = choose_torch_device()
 
         # resizing
         resized_latents = torch.nn.functional.interpolate(
-            latents.to(device), scale_factor=self.scale_factor, mode=self.mode,
-            antialias=self.antialias
-            if self.mode in ["bilinear", "bicubic"] else False,
+            latents.to(device),
+            scale_factor=self.scale_factor,
+            mode=self.mode,
+            antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
         )
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
@@ -695,19 +675,13 @@ class ImageToLatentsInvocation(BaseInvocation):
 
     # Inputs
     image: Optional[ImageField] = Field(description="The image to encode")
     vae: VaeField = Field(default=None, description="Vae submodel")
-    tiled: bool = Field(
-        default=False,
-        description="Encode latents by overlaping tiles(less memory consumption)")
-    fp32: bool = Field(DEFAULT_PRECISION=='float32', description="Decode in full precision")
-
+    tiled: bool = Field(default=False, description="Encode latents by overlaping tiles(less memory consumption)")
+    fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision")
 
     # Schema customisation
     class Config(InvocationConfig):
         schema_extra = {
-            "ui": {
-                "title": "Image To Latents",
-                "tags": ["latents", "image"]
-            },
+            "ui": {"title": "Image To Latents", "tags": ["latents", "image"]},
         }
 
     @torch.no_grad()
@@ -717,9 +691,10 @@ class ImageToLatentsInvocation(BaseInvocation):
         # )
         image = context.services.images.get_pil_image(self.image.image_name)
 
-        #vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
+        # vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
         vae_info = context.services.model_manager.get_model(
-            **self.vae.vae.dict(), context=context,
+            **self.vae.vae.dict(),
+            context=context,
         )
 
         image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
@@ -746,12 +721,12 @@ class ImageToLatentsInvocation(BaseInvocation):
                     vae.post_quant_conv.to(orig_dtype)
                     vae.decoder.conv_in.to(orig_dtype)
                     vae.decoder.mid_block.to(orig_dtype)
-                #else:
+                # else:
                 # latents = latents.float()
 
             else:
                 vae.to(dtype=torch.float16)
-                #latents = latents.half()
+                # latents = latents.half()
 
             if self.tiled:
                 vae.enable_tiling()
@@ -762,9 +737,7 @@ class ImageToLatentsInvocation(BaseInvocation):
             image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
             with torch.inference_mode():
                 image_tensor_dist = vae.encode(image_tensor).latent_dist
-                latents = image_tensor_dist.sample().to(
-                    dtype=vae.dtype
-                )  # FIXME: uses torch.randn. make reproducible!
+                latents = image_tensor_dist.sample().to(dtype=vae.dtype)  # FIXME: uses torch.randn. make reproducible!
 
             latents = vae.config.scaling_factor * latents
             latents = latents.to(dtype=orig_dtype)
diff --git a/scripts/probe-model.py b/scripts/probe-model.py
index eca2a0985a..7281dafc3f 100755
--- a/scripts/probe-model.py
+++ b/scripts/probe-model.py
@@ -7,13 +7,10 @@ from invokeai.backend.model_management.model_probe import ModelProbe
 
 parser = argparse.ArgumentParser(description="Probe model type")
 parser.add_argument(
-    'model_path',
+    "model_path",
     type=Path,
 )
-args=parser.parse_args()
+args = parser.parse_args()
 
 info = ModelProbe().probe(args.model_path)
 print(info)
-
-
-