diff --git a/.dev_scripts/diff_images.py b/.dev_scripts/diff_images.py index e21cae214e..5208ed41ec 100644 --- a/.dev_scripts/diff_images.py +++ b/.dev_scripts/diff_images.py @@ -20,13 +20,13 @@ def calc_images_mean_L1(image1_path, image2_path): def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('image1_path') - parser.add_argument('image2_path') + parser.add_argument("image1_path") + parser.add_argument("image2_path") args = parser.parse_args() return args -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() mean_L1 = calc_images_mean_L1(args.image1_path, args.image2_path) print(mean_L1) diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 050d1b091f..a186daedf5 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -78,9 +78,7 @@ class ApiDependencies: image_record_storage = SqliteImageRecordStorage(db_location) image_file_storage = DiskImageFileStorage(f"{output_folder}/images") names = SimpleNameService() - latents = ForwardCacheLatentsStorage( - DiskLatentsStorage(f"{output_folder}/latents") - ) + latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")) board_record_storage = SqliteBoardRecordStorage(db_location) board_image_record_storage = SqliteBoardImageRecordStorage(db_location) @@ -125,9 +123,7 @@ class ApiDependencies: boards=boards, board_images=board_images, queue=MemoryInvocationQueue(), - graph_library=SqliteItemStorage[LibraryGraph]( - filename=db_location, table_name="graphs" - ), + graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"), graph_execution_manager=graph_execution_manager, processor=DefaultInvocationProcessor(), configuration=config, diff --git a/invokeai/app/api/routers/app_info.py b/invokeai/app/api/routers/app_info.py index e37184a77b..9d9e47d2ef 100644 --- a/invokeai/app/api/routers/app_info.py +++ b/invokeai/app/api/routers/app_info.py @@ -15,6 +15,7 @@ from invokeai.version import __version__ from ..dependencies import ApiDependencies from invokeai.backend.util.logging import logging + class LogLevel(int, Enum): NotSet = logging.NOTSET Debug = logging.DEBUG @@ -23,10 +24,12 @@ class LogLevel(int, Enum): Error = logging.ERROR Critical = logging.CRITICAL + class Upscaler(BaseModel): upscaling_method: str = Field(description="Name of upscaling method") upscaling_models: list[str] = Field(description="List of upscaling models for this method") - + + app_router = APIRouter(prefix="/v1/app", tags=["app"]) @@ -45,38 +48,30 @@ class AppConfig(BaseModel): watermarking_methods: list[str] = Field(description="List of invisible watermark methods") -@app_router.get( - "/version", operation_id="app_version", status_code=200, response_model=AppVersion -) +@app_router.get("/version", operation_id="app_version", status_code=200, response_model=AppVersion) async def get_version() -> AppVersion: return AppVersion(version=__version__) -@app_router.get( - "/config", operation_id="get_config", status_code=200, response_model=AppConfig -) +@app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig) async def get_config() -> AppConfig: - infill_methods = ['tile'] + infill_methods = ["tile"] if PatchMatch.patchmatch_available(): - infill_methods.append('patchmatch') - + infill_methods.append("patchmatch") upscaling_models = [] for model in typing.get_args(ESRGAN_MODELS): upscaling_models.append(str(Path(model).stem)) - upscaler = Upscaler( - upscaling_method = 'esrgan', - upscaling_models = upscaling_models - ) - + upscaler = Upscaler(upscaling_method="esrgan", upscaling_models=upscaling_models) + nsfw_methods = [] if SafetyChecker.safety_checker_available(): - nsfw_methods.append('nsfw_checker') + nsfw_methods.append("nsfw_checker") watermarking_methods = [] if InvisibleWatermark.invisible_watermark_available(): - watermarking_methods.append('invisible_watermark') - + watermarking_methods.append("invisible_watermark") + return AppConfig( infill_methods=infill_methods, upscaling_methods=[upscaler], @@ -84,25 +79,26 @@ async def get_config() -> AppConfig: watermarking_methods=watermarking_methods, ) + @app_router.get( "/logging", operation_id="get_log_level", - responses={200: {"description" : "The operation was successful"}}, - response_model = LogLevel, + responses={200: {"description": "The operation was successful"}}, + response_model=LogLevel, ) -async def get_log_level( -) -> LogLevel: +async def get_log_level() -> LogLevel: """Returns the log level""" return LogLevel(ApiDependencies.invoker.services.logger.level) + @app_router.post( "/logging", operation_id="set_log_level", - responses={200: {"description" : "The operation was successful"}}, - response_model = LogLevel, + responses={200: {"description": "The operation was successful"}}, + response_model=LogLevel, ) async def set_log_level( - level: LogLevel = Body(description="New log verbosity level"), + level: LogLevel = Body(description="New log verbosity level"), ) -> LogLevel: """Sets the log verbosity level""" ApiDependencies.invoker.services.logger.setLevel(level) diff --git a/invokeai/app/api/routers/board_images.py b/invokeai/app/api/routers/board_images.py index 651310af24..6cb073ca7c 100644 --- a/invokeai/app/api/routers/board_images.py +++ b/invokeai/app/api/routers/board_images.py @@ -52,4 +52,3 @@ async def remove_board_image( return result except Exception as e: raise HTTPException(status_code=500, detail="Failed to update board") - diff --git a/invokeai/app/api/routers/boards.py b/invokeai/app/api/routers/boards.py index f3de7f4952..69f4a8b3a7 100644 --- a/invokeai/app/api/routers/boards.py +++ b/invokeai/app/api/routers/boards.py @@ -18,9 +18,7 @@ class DeleteBoardResult(BaseModel): deleted_board_images: list[str] = Field( description="The image names of the board-images relationships that were deleted." ) - deleted_images: list[str] = Field( - description="The names of the images that were deleted." - ) + deleted_images: list[str] = Field(description="The names of the images that were deleted.") @boards_router.post( @@ -73,22 +71,16 @@ async def update_board( ) -> BoardDTO: """Updates a board""" try: - result = ApiDependencies.invoker.services.boards.update( - board_id=board_id, changes=changes - ) + result = ApiDependencies.invoker.services.boards.update(board_id=board_id, changes=changes) return result except Exception as e: raise HTTPException(status_code=500, detail="Failed to update board") -@boards_router.delete( - "/{board_id}", operation_id="delete_board", response_model=DeleteBoardResult -) +@boards_router.delete("/{board_id}", operation_id="delete_board", response_model=DeleteBoardResult) async def delete_board( board_id: str = Path(description="The id of board to delete"), - include_images: Optional[bool] = Query( - description="Permanently delete all images on the board", default=False - ), + include_images: Optional[bool] = Query(description="Permanently delete all images on the board", default=False), ) -> DeleteBoardResult: """Deletes a board""" try: @@ -96,9 +88,7 @@ async def delete_board( deleted_images = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board( board_id=board_id ) - ApiDependencies.invoker.services.images.delete_images_on_board( - board_id=board_id - ) + ApiDependencies.invoker.services.images.delete_images_on_board(board_id=board_id) ApiDependencies.invoker.services.boards.delete(board_id=board_id) return DeleteBoardResult( board_id=board_id, @@ -127,9 +117,7 @@ async def delete_board( async def list_boards( all: Optional[bool] = Query(default=None, description="Whether to list all boards"), offset: Optional[int] = Query(default=None, description="The page offset"), - limit: Optional[int] = Query( - default=None, description="The number of boards per page" - ), + limit: Optional[int] = Query(default=None, description="The number of boards per page"), ) -> Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]]: """Gets a list of boards""" if all: diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index 36e2e3d75d..498a1139e4 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -40,15 +40,9 @@ async def upload_image( response: Response, image_category: ImageCategory = Query(description="The category of the image"), is_intermediate: bool = Query(description="Whether this is an intermediate image"), - board_id: Optional[str] = Query( - default=None, description="The board to add this image to, if any" - ), - session_id: Optional[str] = Query( - default=None, description="The session ID associated with this upload, if any" - ), - crop_visible: Optional[bool] = Query( - default=False, description="Whether to crop the image" - ), + board_id: Optional[str] = Query(default=None, description="The board to add this image to, if any"), + session_id: Optional[str] = Query(default=None, description="The session ID associated with this upload, if any"), + crop_visible: Optional[bool] = Query(default=False, description="Whether to crop the image"), ) -> ImageDTO: """Uploads an image""" if not file.content_type.startswith("image"): @@ -115,9 +109,7 @@ async def clear_intermediates() -> int: ) async def update_image( image_name: str = Path(description="The name of the image to update"), - image_changes: ImageRecordChanges = Body( - description="The changes to apply to the image" - ), + image_changes: ImageRecordChanges = Body(description="The changes to apply to the image"), ) -> ImageDTO: """Updates an image""" @@ -212,15 +204,11 @@ async def get_image_thumbnail( """Gets a thumbnail image file""" try: - path = ApiDependencies.invoker.services.images.get_path( - image_name, thumbnail=True - ) + path = ApiDependencies.invoker.services.images.get_path(image_name, thumbnail=True) if not ApiDependencies.invoker.services.images.validate_path(path): raise HTTPException(status_code=404) - response = FileResponse( - path, media_type="image/webp", content_disposition_type="inline" - ) + response = FileResponse(path, media_type="image/webp", content_disposition_type="inline") response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}" return response except Exception as e: @@ -239,9 +227,7 @@ async def get_image_urls( try: image_url = ApiDependencies.invoker.services.images.get_url(image_name) - thumbnail_url = ApiDependencies.invoker.services.images.get_url( - image_name, thumbnail=True - ) + thumbnail_url = ApiDependencies.invoker.services.images.get_url(image_name, thumbnail=True) return ImageUrlsDTO( image_name=image_name, image_url=image_url, @@ -257,15 +243,9 @@ async def get_image_urls( response_model=OffsetPaginatedResults[ImageDTO], ) async def list_image_dtos( - image_origin: Optional[ResourceOrigin] = Query( - default=None, description="The origin of images to list." - ), - categories: Optional[list[ImageCategory]] = Query( - default=None, description="The categories of image to include." - ), - is_intermediate: Optional[bool] = Query( - default=None, description="Whether to list intermediate images." - ), + image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to list."), + categories: Optional[list[ImageCategory]] = Query(default=None, description="The categories of image to include."), + is_intermediate: Optional[bool] = Query(default=None, description="Whether to list intermediate images."), board_id: Optional[str] = Query( default=None, description="The board id to filter by. Use 'none' to find images without a board.", diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 759f6c9f59..3b2907b937 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -28,49 +28,52 @@ ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)] ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)] + class ModelsList(BaseModel): models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]] + @models_router.get( "/", operation_id="list_models", - responses={200: {"model": ModelsList }}, + responses={200: {"model": ModelsList}}, ) async def list_models( base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"), model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"), ) -> ModelsList: """Gets a list of models""" - if base_models and len(base_models)>0: + if base_models and len(base_models) > 0: models_raw = list() for base_model in base_models: models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type)) else: models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type) - models = parse_obj_as(ModelsList, { "models": models_raw }) + models = parse_obj_as(ModelsList, {"models": models_raw}) return models + @models_router.patch( "/{base_model}/{model_type}/{model_name}", operation_id="update_model", - responses={200: {"description" : "The model was updated successfully"}, - 400: {"description" : "Bad request"}, - 404: {"description" : "The model could not be found"}, - 409: {"description" : "There is already a model corresponding to the new name"}, - }, - status_code = 200, - response_model = UpdateModelResponse, + responses={ + 200: {"description": "The model was updated successfully"}, + 400: {"description": "Bad request"}, + 404: {"description": "The model could not be found"}, + 409: {"description": "There is already a model corresponding to the new name"}, + }, + status_code=200, + response_model=UpdateModelResponse, ) async def update_model( - base_model: BaseModelType = Path(description="Base model"), - model_type: ModelType = Path(description="The type of model"), - model_name: str = Path(description="model name"), - info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"), + base_model: BaseModelType = Path(description="Base model"), + model_type: ModelType = Path(description="The type of model"), + model_name: str = Path(description="model name"), + info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"), ) -> UpdateModelResponse: - """ Update model contents with a new config. If the model name or base fields are changed, then the model is renamed. """ + """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed.""" logger = ApiDependencies.invoker.services.logger - try: previous_info = ApiDependencies.invoker.services.model_manager.list_model( model_name=model_name, @@ -81,13 +84,13 @@ async def update_model( # rename operation requested if info.model_name != model_name or info.base_model != base_model: ApiDependencies.invoker.services.model_manager.rename_model( - base_model = base_model, - model_type = model_type, - model_name = model_name, - new_name = info.model_name, - new_base = info.base_model, + base_model=base_model, + model_type=model_type, + model_name=model_name, + new_name=info.model_name, + new_base=info.base_model, ) - logger.info(f'Successfully renamed {base_model}/{model_name}=>{info.base_model}/{info.model_name}') + logger.info(f"Successfully renamed {base_model}/{model_name}=>{info.base_model}/{info.model_name}") # update information to support an update of attributes model_name = info.model_name base_model = info.base_model @@ -96,16 +99,15 @@ async def update_model( base_model=base_model, model_type=model_type, ) - if new_info.get('path') != previous_info.get('path'): # model manager moved model path during rename - don't overwrite it - info.path = new_info.get('path') - + if new_info.get("path") != previous_info.get( + "path" + ): # model manager moved model path during rename - don't overwrite it + info.path = new_info.get("path") + ApiDependencies.invoker.services.model_manager.update_model( - model_name=model_name, - base_model=base_model, - model_type=model_type, - model_attributes=info.dict() + model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info.dict() ) - + model_raw = ApiDependencies.invoker.services.model_manager.list_model( model_name=model_name, base_model=base_model, @@ -123,49 +125,48 @@ async def update_model( return model_response + @models_router.post( "/import", operation_id="import_model", - responses= { - 201: {"description" : "The model imported successfully"}, - 404: {"description" : "The model could not be found"}, - 415: {"description" : "Unrecognized file/folder format"}, - 424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"}, - 409: {"description" : "There is already a model corresponding to this path or repo_id"}, + responses={ + 201: {"description": "The model imported successfully"}, + 404: {"description": "The model could not be found"}, + 415: {"description": "Unrecognized file/folder format"}, + 424: {"description": "The model appeared to import successfully, but could not be found in the model manager"}, + 409: {"description": "There is already a model corresponding to this path or repo_id"}, }, status_code=201, - response_model=ImportModelResponse + response_model=ImportModelResponse, ) async def import_model( - location: str = Body(description="A model path, repo_id or URL to import"), - prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \ - Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"), + location: str = Body(description="A model path, repo_id or URL to import"), + prediction_type: Optional[Literal["v_prediction", "epsilon", "sample"]] = Body( + description="Prediction type for SDv2 checkpoint files", default="v_prediction" + ), ) -> ImportModelResponse: - """ Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically """ - + """Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically""" + items_to_import = {location} - prediction_types = { x.value: x for x in SchedulerPredictionType } + prediction_types = {x.value: x for x in SchedulerPredictionType} logger = ApiDependencies.invoker.services.logger try: installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import( - items_to_import = items_to_import, - prediction_type_helper = lambda x: prediction_types.get(prediction_type) + items_to_import=items_to_import, prediction_type_helper=lambda x: prediction_types.get(prediction_type) ) info = installed_models.get(location) if not info: logger.error("Import failed") raise HTTPException(status_code=415) - - logger.info(f'Successfully imported {location}, got {info}') + + logger.info(f"Successfully imported {location}, got {info}") model_raw = ApiDependencies.invoker.services.model_manager.list_model( - model_name=info.name, - base_model=info.base_model, - model_type=info.model_type + model_name=info.name, base_model=info.base_model, model_type=info.model_type ) return parse_obj_as(ImportModelResponse, model_raw) - + except ModelNotFoundException as e: logger.error(str(e)) raise HTTPException(status_code=404, detail=str(e)) @@ -175,38 +176,34 @@ async def import_model( except ValueError as e: logger.error(str(e)) raise HTTPException(status_code=409, detail=str(e)) - + + @models_router.post( "/add", operation_id="add_model", - responses= { - 201: {"description" : "The model added successfully"}, - 404: {"description" : "The model could not be found"}, - 424: {"description" : "The model appeared to add successfully, but could not be found in the model manager"}, - 409: {"description" : "There is already a model corresponding to this path or repo_id"}, + responses={ + 201: {"description": "The model added successfully"}, + 404: {"description": "The model could not be found"}, + 424: {"description": "The model appeared to add successfully, but could not be found in the model manager"}, + 409: {"description": "There is already a model corresponding to this path or repo_id"}, }, status_code=201, - response_model=ImportModelResponse + response_model=ImportModelResponse, ) async def add_model( - info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"), + info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"), ) -> ImportModelResponse: - """ Add a model using the configuration information appropriate for its type. Only local models can be added by path""" - + """Add a model using the configuration information appropriate for its type. Only local models can be added by path""" + logger = ApiDependencies.invoker.services.logger try: ApiDependencies.invoker.services.model_manager.add_model( - info.model_name, - info.base_model, - info.model_type, - model_attributes = info.dict() + info.model_name, info.base_model, info.model_type, model_attributes=info.dict() ) - logger.info(f'Successfully added {info.model_name}') + logger.info(f"Successfully added {info.model_name}") model_raw = ApiDependencies.invoker.services.model_manager.list_model( - model_name=info.model_name, - base_model=info.base_model, - model_type=info.model_type + model_name=info.model_name, base_model=info.base_model, model_type=info.model_type ) return parse_obj_as(ImportModelResponse, model_raw) except ModelNotFoundException as e: @@ -216,66 +213,66 @@ async def add_model( logger.error(str(e)) raise HTTPException(status_code=409, detail=str(e)) - + @models_router.delete( "/{base_model}/{model_type}/{model_name}", operation_id="del_model", - responses={ - 204: { "description": "Model deleted successfully" }, - 404: { "description": "Model not found" } - }, - status_code = 204, - response_model = None, + responses={204: {"description": "Model deleted successfully"}, 404: {"description": "Model not found"}}, + status_code=204, + response_model=None, ) async def delete_model( - base_model: BaseModelType = Path(description="Base model"), - model_type: ModelType = Path(description="The type of model"), - model_name: str = Path(description="model name"), + base_model: BaseModelType = Path(description="Base model"), + model_type: ModelType = Path(description="The type of model"), + model_name: str = Path(description="model name"), ) -> Response: """Delete Model""" logger = ApiDependencies.invoker.services.logger - + try: - ApiDependencies.invoker.services.model_manager.del_model(model_name, - base_model = base_model, - model_type = model_type - ) + ApiDependencies.invoker.services.model_manager.del_model( + model_name, base_model=base_model, model_type=model_type + ) logger.info(f"Deleted model: {model_name}") return Response(status_code=204) except ModelNotFoundException as e: logger.error(str(e)) raise HTTPException(status_code=404, detail=str(e)) + @models_router.put( "/convert/{base_model}/{model_type}/{model_name}", operation_id="convert_model", responses={ - 200: { "description": "Model converted successfully" }, - 400: {"description" : "Bad request" }, - 404: { "description": "Model not found" }, + 200: {"description": "Model converted successfully"}, + 400: {"description": "Bad request"}, + 404: {"description": "Model not found"}, }, - status_code = 200, - response_model = ConvertModelResponse, + status_code=200, + response_model=ConvertModelResponse, ) async def convert_model( - base_model: BaseModelType = Path(description="Base model"), - model_type: ModelType = Path(description="The type of model"), - model_name: str = Path(description="model name"), - convert_dest_directory: Optional[str] = Query(default=None, description="Save the converted model to the designated directory"), + base_model: BaseModelType = Path(description="Base model"), + model_type: ModelType = Path(description="The type of model"), + model_name: str = Path(description="model name"), + convert_dest_directory: Optional[str] = Query( + default=None, description="Save the converted model to the designated directory" + ), ) -> ConvertModelResponse: """Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none.""" logger = ApiDependencies.invoker.services.logger try: logger.info(f"Converting model: {model_name}") dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None - ApiDependencies.invoker.services.model_manager.convert_model(model_name, - base_model = base_model, - model_type = model_type, - convert_dest_directory = dest, - ) - model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name, - base_model = base_model, - model_type = model_type) + ApiDependencies.invoker.services.model_manager.convert_model( + model_name, + base_model=base_model, + model_type=model_type, + convert_dest_directory=dest, + ) + model_raw = ApiDependencies.invoker.services.model_manager.list_model( + model_name, base_model=base_model, model_type=model_type + ) response = parse_obj_as(ConvertModelResponse, model_raw) except ModelNotFoundException as e: raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}") @@ -283,91 +280,101 @@ async def convert_model( raise HTTPException(status_code=400, detail=str(e)) return response + @models_router.get( "/search", operation_id="search_for_models", responses={ - 200: { "description": "Directory searched successfully" }, - 404: { "description": "Invalid directory path" }, + 200: {"description": "Directory searched successfully"}, + 404: {"description": "Invalid directory path"}, }, - status_code = 200, - response_model = List[pathlib.Path] + status_code=200, + response_model=List[pathlib.Path], ) async def search_for_models( - search_path: pathlib.Path = Query(description="Directory path to search for models") -)->List[pathlib.Path]: + search_path: pathlib.Path = Query(description="Directory path to search for models"), +) -> List[pathlib.Path]: if not search_path.is_dir(): - raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory") + raise HTTPException( + status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory" + ) return ApiDependencies.invoker.services.model_manager.search_for_models(search_path) + @models_router.get( "/ckpt_confs", operation_id="list_ckpt_configs", responses={ - 200: { "description" : "paths retrieved successfully" }, + 200: {"description": "paths retrieved successfully"}, }, - status_code = 200, - response_model = List[pathlib.Path] + status_code=200, + response_model=List[pathlib.Path], ) -async def list_ckpt_configs( -)->List[pathlib.Path]: +async def list_ckpt_configs() -> List[pathlib.Path]: """Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT.""" return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs() - - + + @models_router.post( "/sync", operation_id="sync_to_config", responses={ - 201: { "description": "synchronization successful" }, + 201: {"description": "synchronization successful"}, }, - status_code = 201, - response_model = bool + status_code=201, + response_model=bool, ) -async def sync_to_config( -)->bool: +async def sync_to_config() -> bool: """Call after making changes to models.yaml, autoimport directories or models directory to synchronize in-memory data structures with disk data structures.""" ApiDependencies.invoker.services.model_manager.sync_to_config() return True - + + @models_router.put( "/merge/{base_model}", operation_id="merge_models", responses={ - 200: { "description": "Model converted successfully" }, - 400: { "description": "Incompatible models" }, - 404: { "description": "One or more models not found" }, + 200: {"description": "Model converted successfully"}, + 400: {"description": "Incompatible models"}, + 404: {"description": "One or more models not found"}, }, - status_code = 200, - response_model = MergeModelResponse, + status_code=200, + response_model=MergeModelResponse, ) async def merge_models( - base_model: BaseModelType = Path(description="Base model"), - model_names: List[str] = Body(description="model name", min_items=2, max_items=3), - merged_model_name: Optional[str] = Body(description="Name of destination model"), - alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5), - interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"), - force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False), - merge_dest_directory: Optional[str] = Body(description="Save the merged model to the designated directory (with 'merged_model_name' appended)", default=None) + base_model: BaseModelType = Path(description="Base model"), + model_names: List[str] = Body(description="model name", min_items=2, max_items=3), + merged_model_name: Optional[str] = Body(description="Name of destination model"), + alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5), + interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"), + force: Optional[bool] = Body( + description="Force merging of models created with different versions of diffusers", default=False + ), + merge_dest_directory: Optional[str] = Body( + description="Save the merged model to the designated directory (with 'merged_model_name' appended)", + default=None, + ), ) -> MergeModelResponse: """Convert a checkpoint model into a diffusers model""" logger = ApiDependencies.invoker.services.logger try: logger.info(f"Merging models: {model_names} into {merge_dest_directory or ''}/{merged_model_name}") dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None - result = ApiDependencies.invoker.services.model_manager.merge_models(model_names, - base_model, - merged_model_name=merged_model_name or "+".join(model_names), - alpha=alpha, - interp=interp, - force=force, - merge_dest_directory = dest - ) - model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name, - base_model = base_model, - model_type = ModelType.Main, - ) + result = ApiDependencies.invoker.services.model_manager.merge_models( + model_names, + base_model, + merged_model_name=merged_model_name or "+".join(model_names), + alpha=alpha, + interp=interp, + force=force, + merge_dest_directory=dest, + ) + model_raw = ApiDependencies.invoker.services.model_manager.list_model( + result.name, + base_model=base_model, + model_type=ModelType.Main, + ) response = parse_obj_as(ConvertModelResponse, model_raw) except ModelNotFoundException: raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found") diff --git a/invokeai/app/api/routers/sessions.py b/invokeai/app/api/routers/sessions.py index da842a3968..e4ba2a353e 100644 --- a/invokeai/app/api/routers/sessions.py +++ b/invokeai/app/api/routers/sessions.py @@ -30,9 +30,7 @@ session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"]) }, ) async def create_session( - graph: Optional[Graph] = Body( - default=None, description="The graph to initialize the session with" - ) + graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with") ) -> GraphExecutionState: """Creates a new session, optionally initializing it with an invocation graph""" session = ApiDependencies.invoker.create_execution_state(graph) @@ -51,13 +49,9 @@ async def list_sessions( ) -> PaginatedResults[GraphExecutionState]: """Gets a list of sessions, optionally searching""" if query == "": - result = ApiDependencies.invoker.services.graph_execution_manager.list( - page, per_page - ) + result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page) else: - result = ApiDependencies.invoker.services.graph_execution_manager.search( - query, page, per_page - ) + result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page) return result @@ -91,9 +85,9 @@ async def get_session( ) async def add_node( session_id: str = Path(description="The id of the session"), - node: Annotated[ - Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore - ] = Body(description="The node to add"), + node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore + description="The node to add" + ), ) -> str: """Adds a node to the graph""" session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) @@ -124,9 +118,9 @@ async def add_node( async def update_node( session_id: str = Path(description="The id of the session"), node_path: str = Path(description="The path to the node in the graph"), - node: Annotated[ - Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore - ] = Body(description="The new node"), + node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore + description="The new node" + ), ) -> GraphExecutionState: """Updates a node in the graph and removes all linked edges""" session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) @@ -230,7 +224,7 @@ async def delete_edge( try: edge = Edge( source=EdgeConnection(node_id=from_node_id, field=from_field), - destination=EdgeConnection(node_id=to_node_id, field=to_field) + destination=EdgeConnection(node_id=to_node_id, field=to_field), ) session.delete_edge(edge) ApiDependencies.invoker.services.graph_execution_manager.set( @@ -255,9 +249,7 @@ async def delete_edge( ) async def invoke_session( session_id: str = Path(description="The id of the session to invoke"), - all: bool = Query( - default=False, description="Whether or not to invoke all remaining invocations" - ), + all: bool = Query(default=False, description="Whether or not to invoke all remaining invocations"), ) -> Response: """Invokes a session""" session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) @@ -274,9 +266,7 @@ async def invoke_session( @session_router.delete( "/{session_id}/invoke", operation_id="cancel_session_invoke", - responses={ - 202: {"description": "The invocation is canceled"} - }, + responses={202: {"description": "The invocation is canceled"}}, ) async def cancel_session_invoke( session_id: str = Path(description="The id of the session to cancel"), diff --git a/invokeai/app/api/sockets.py b/invokeai/app/api/sockets.py index f70d7a6609..4591bac540 100644 --- a/invokeai/app/api/sockets.py +++ b/invokeai/app/api/sockets.py @@ -16,9 +16,7 @@ class SocketIO: self.__sio.on("subscribe", handler=self._handle_sub) self.__sio.on("unsubscribe", handler=self._handle_unsub) - local_handler.register( - event_name=EventServiceBase.session_event, _func=self._handle_session_event - ) + local_handler.register(event_name=EventServiceBase.session_event, _func=self._handle_session_event) async def _handle_session_event(self, event: Event): await self.__sio.emit( diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 044271779c..1f3a166e62 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -16,9 +16,10 @@ from fastapi_events.middleware import EventHandlerASGIMiddleware from pathlib import Path from pydantic.schema import schema -#This should come early so that modules can log their initialization properly +# This should come early so that modules can log their initialization properly from .services.config import InvokeAIAppConfig from ..backend.util.logging import InvokeAILogger + app_config = InvokeAIAppConfig.get_config() app_config.parse_args() logger = InvokeAILogger.getLogger(config=app_config) @@ -27,7 +28,7 @@ from invokeai.version.invokeai_version import __version__ # we call this early so that the message appears before # other invokeai initialization messages if app_config.version: - print(f'InvokeAI version {__version__}') + print(f"InvokeAI version {__version__}") sys.exit(0) import invokeai.frontend.web as web_dir @@ -37,17 +38,18 @@ from .api.dependencies import ApiDependencies from .api.routers import sessions, models, images, boards, board_images, app_info from .api.sockets import SocketIO from .invocations.baseinvocation import BaseInvocation - + import torch import invokeai.backend.util.hotfixes + if torch.backends.mps.is_available(): import invokeai.backend.util.mps_fixes # fix for windows mimetypes registry entries being borked # see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352 -mimetypes.add_type('application/javascript', '.js') -mimetypes.add_type('text/css', '.css') +mimetypes.add_type("application/javascript", ".js") +mimetypes.add_type("text/css", ".css") # Create the app # TODO: create this all in a method so configuration/etc. can be passed in? @@ -57,14 +59,13 @@ app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None) event_handler_id: int = id(app) app.add_middleware( EventHandlerASGIMiddleware, - handlers=[ - local_handler - ], # TODO: consider doing this in services to support different configurations + handlers=[local_handler], # TODO: consider doing this in services to support different configurations middleware_id=event_handler_id, ) socket_io = SocketIO(app) + # Add startup event to load dependencies @app.on_event("startup") async def startup_event(): @@ -76,9 +77,7 @@ async def startup_event(): allow_headers=app_config.allow_headers, ) - ApiDependencies.initialize( - config=app_config, event_handler_id=event_handler_id, logger=logger - ) + ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger) # Shut down threads @@ -103,7 +102,8 @@ app.include_router(boards.boards_router, prefix="/api") app.include_router(board_images.board_images_router, prefix="/api") -app.include_router(app_info.app_router, prefix='/api') +app.include_router(app_info.app_router, prefix="/api") + # Build a custom OpenAPI to include all outputs # TODO: can outputs be included on metadata of invocation schemas somehow? @@ -144,6 +144,7 @@ def custom_openapi(): invoker_schema["output"] = outputs_ref from invokeai.backend.model_management.models import get_model_config_enums + for model_config_format_enum in set(get_model_config_enums()): name = model_config_format_enum.__qualname__ @@ -166,7 +167,8 @@ def custom_openapi(): app.openapi = custom_openapi # Override API doc favicons -app.mount("/static", StaticFiles(directory=Path(web_dir.__path__[0], 'static/dream_web')), name="static") +app.mount("/static", StaticFiles(directory=Path(web_dir.__path__[0], "static/dream_web")), name="static") + @app.get("/docs", include_in_schema=False) def overridden_swagger(): @@ -187,11 +189,8 @@ def overridden_redoc(): # Must mount *after* the other routes else it borks em -app.mount("/", - StaticFiles(directory=Path(web_dir.__path__[0],"dist"), - html=True - ), name="ui" - ) +app.mount("/", StaticFiles(directory=Path(web_dir.__path__[0], "dist"), html=True), name="ui") + def invoke_api(): def find_port(port: int): @@ -203,10 +202,11 @@ def invoke_api(): return find_port(port=port + 1) else: return port - + from invokeai.backend.install.check_root import check_invokeai_root + check_invokeai_root(app_config) # note, may exit with an exception if root not set up - + port = find_port(app_config.port) if port != app_config.port: logger.warn(f"Port {app_config.port} in use, using port {port}") @@ -217,5 +217,6 @@ def invoke_api(): server = uvicorn.Server(config) loop.run_until_complete(server.serve()) + if __name__ == "__main__": invoke_api() diff --git a/invokeai/app/cli/commands.py b/invokeai/app/cli/commands.py index bffb2988dc..64ea6034fc 100644 --- a/invokeai/app/cli/commands.py +++ b/invokeai/app/cli/commands.py @@ -14,8 +14,14 @@ from ..services.graph import GraphExecutionState, LibraryGraph, Edge from ..services.invoker import Invoker -def add_field_argument(command_parser, name: str, field, default_override = None): - default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory() +def add_field_argument(command_parser, name: str, field, default_override=None): + default = ( + default_override + if default_override is not None + else field.default + if field.default_factory is None + else field.default_factory() + ) if get_origin(field.type_) == Literal: allowed_values = get_args(field.type_) allowed_types = set() @@ -47,8 +53,8 @@ def add_parsers( commands: list[type], command_field: str = "type", exclude_fields: list[str] = ["id", "type"], - add_arguments: Union[Callable[[argparse.ArgumentParser], None],None] = None - ): + add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None, +): """Adds parsers for each command to the subparsers""" # Create subparsers for each command @@ -61,7 +67,7 @@ def add_parsers( add_arguments(command_parser) # Convert all fields to arguments - fields = command.__fields__ # type: ignore + fields = command.__fields__ # type: ignore for name, field in fields.items(): if name in exclude_fields: continue @@ -70,13 +76,11 @@ def add_parsers( def add_graph_parsers( - subparsers, - graphs: list[LibraryGraph], - add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None + subparsers, graphs: list[LibraryGraph], add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None ): for graph in graphs: command_parser = subparsers.add_parser(graph.name, help=graph.description) - + if add_arguments is not None: add_arguments(command_parser) @@ -128,6 +132,7 @@ class CliContext: class ExitCli(Exception): """Exception to exit the CLI""" + pass @@ -155,7 +160,7 @@ class BaseCommand(ABC, BaseModel): @classmethod def get_commands_map(cls): # Get the type strings out of the literals and into a dictionary - return dict(map(lambda t: (get_args(get_type_hints(t)['type'])[0], t),BaseCommand.get_all_subclasses())) + return dict(map(lambda t: (get_args(get_type_hints(t)["type"])[0], t), BaseCommand.get_all_subclasses())) @abstractmethod def run(self, context: CliContext) -> None: @@ -165,7 +170,8 @@ class BaseCommand(ABC, BaseModel): class ExitCommand(BaseCommand): """Exits the CLI""" - type: Literal['exit'] = 'exit' + + type: Literal["exit"] = "exit" def run(self, context: CliContext) -> None: raise ExitCli() @@ -173,7 +179,8 @@ class ExitCommand(BaseCommand): class HelpCommand(BaseCommand): """Shows help""" - type: Literal['help'] = 'help' + + type: Literal["help"] = "help" def run(self, context: CliContext) -> None: context.parser.print_help() @@ -183,11 +190,7 @@ def get_graph_execution_history( graph_execution_state: GraphExecutionState, ) -> Iterable[str]: """Gets the history of fully-executed invocations for a graph execution""" - return ( - n - for n in reversed(graph_execution_state.executed_history) - if n in graph_execution_state.graph.nodes - ) + return (n for n in reversed(graph_execution_state.executed_history) if n in graph_execution_state.graph.nodes) def get_invocation_command(invocation) -> str: @@ -218,7 +221,8 @@ def get_invocation_command(invocation) -> str: class HistoryCommand(BaseCommand): """Shows the invocation history""" - type: Literal['history'] = 'history' + + type: Literal["history"] = "history" # Inputs # fmt: off @@ -235,7 +239,8 @@ class HistoryCommand(BaseCommand): class SetDefaultCommand(BaseCommand): """Sets a default value for a field""" - type: Literal['default'] = 'default' + + type: Literal["default"] = "default" # Inputs # fmt: off @@ -253,7 +258,8 @@ class SetDefaultCommand(BaseCommand): class DrawGraphCommand(BaseCommand): """Debugs a graph""" - type: Literal['draw_graph'] = 'draw_graph' + + type: Literal["draw_graph"] = "draw_graph" def run(self, context: CliContext) -> None: session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id) @@ -271,7 +277,8 @@ class DrawGraphCommand(BaseCommand): class DrawExecutionGraphCommand(BaseCommand): """Debugs an execution graph""" - type: Literal['draw_xgraph'] = 'draw_xgraph' + + type: Literal["draw_xgraph"] = "draw_xgraph" def run(self, context: CliContext) -> None: session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id) @@ -286,6 +293,7 @@ class DrawExecutionGraphCommand(BaseCommand): plt.axis("off") plt.show() + class SortedHelpFormatter(argparse.HelpFormatter): def _iter_indented_subactions(self, action): try: diff --git a/invokeai/app/cli/completer.py b/invokeai/app/cli/completer.py index 79274dab8c..fd10034dd7 100644 --- a/invokeai/app/cli/completer.py +++ b/invokeai/app/cli/completer.py @@ -19,8 +19,8 @@ from ..services.invocation_services import InvocationServices # singleton object, class variable completer = None + class Completer(object): - def __init__(self, model_manager: ModelManager): self.commands = self.get_commands() self.matches = None @@ -43,7 +43,7 @@ class Completer(object): except IndexError: pass options = options or list(self.parse_commands().keys()) - + if not text: # first time self.matches = options else: @@ -56,17 +56,17 @@ class Completer(object): return match @classmethod - def get_commands(self)->List[object]: + def get_commands(self) -> List[object]: """ Return a list of all the client commands and invocations. """ return BaseCommand.get_commands() + BaseInvocation.get_invocations() - def get_current_command(self, buffer: str)->tuple[str, str]: + def get_current_command(self, buffer: str) -> tuple[str, str]: """ Parse the readline buffer to find the most recent command and its switch. """ - if len(buffer)==0: + if len(buffer) == 0: return None, None tokens = shlex.split(buffer) command = None @@ -78,11 +78,11 @@ class Completer(object): else: switch = t # don't try to autocomplete switches that are already complete - if switch and buffer.endswith(' '): - switch=None - return command or '', switch or '' + if switch and buffer.endswith(" "): + switch = None + return command or "", switch or "" - def parse_commands(self)->Dict[str, List[str]]: + def parse_commands(self) -> Dict[str, List[str]]: """ Return a dict in which the keys are the command name and the values are the parameters the command takes. @@ -90,11 +90,11 @@ class Completer(object): result = dict() for command in self.commands: hints = get_type_hints(command) - name = get_args(hints['type'])[0] - result.update({name:hints}) + name = get_args(hints["type"])[0] + result.update({name: hints}) return result - def get_command_options(self, command: str, switch: str)->List[str]: + def get_command_options(self, command: str, switch: str) -> List[str]: """ Return all the parameters that can be passed to the command as command-line switches. Returns None if the command is unrecognized. @@ -102,42 +102,46 @@ class Completer(object): parsed_commands = self.parse_commands() if command not in parsed_commands: return None - + # handle switches in the format "-foo=bar" argument = None - if switch and '=' in switch: - switch, argument = switch.split('=') - - parameter = switch.strip('-') + if switch and "=" in switch: + switch, argument = switch.split("=") + + parameter = switch.strip("-") if parameter in parsed_commands[command]: if argument is None: return self.get_parameter_options(parameter, parsed_commands[command][parameter]) else: - return [f"--{parameter}={x}" for x in self.get_parameter_options(parameter, parsed_commands[command][parameter])] + return [ + f"--{parameter}={x}" + for x in self.get_parameter_options(parameter, parsed_commands[command][parameter]) + ] else: return [f"--{x}" for x in parsed_commands[command].keys()] - def get_parameter_options(self, parameter: str, typehint)->List[str]: + def get_parameter_options(self, parameter: str, typehint) -> List[str]: """ Given a parameter type (such as Literal), offers autocompletions. """ if get_origin(typehint) == Literal: return get_args(typehint) - if parameter == 'model': + if parameter == "model": return self.manager.model_names() - + def _pre_input_hook(self): if self.linebuffer: readline.insert_text(self.linebuffer) readline.redisplay() self.linebuffer = None - + + def set_autocompleter(services: InvocationServices) -> Completer: global completer - + if completer: return completer - + completer = Completer(services.model_manager) readline.set_completer(completer.complete) @@ -162,8 +166,6 @@ def set_autocompleter(services: InvocationServices) -> Completer: pass except OSError: # file likely corrupted newname = f"{histfile}.old" - logger.error( - f"Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}" - ) + logger.error(f"Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}") histfile.replace(Path(newname)) atexit.register(readline.write_history_file, histfile) diff --git a/invokeai/app/cli_app.py b/invokeai/app/cli_app.py index 341e9e5b7e..bad95bb559 100644 --- a/invokeai/app/cli_app.py +++ b/invokeai/app/cli_app.py @@ -13,6 +13,7 @@ from pydantic.fields import Field # This should come early so that the logger can pick up its configuration options from .services.config import InvokeAIAppConfig from invokeai.backend.util.logging import InvokeAILogger + config = InvokeAIAppConfig.get_config() config.parse_args() logger = InvokeAILogger().getLogger(config=config) @@ -20,7 +21,7 @@ from invokeai.version.invokeai_version import __version__ # we call this early so that the message appears before other invokeai initialization messages if config.version: - print(f'InvokeAI version {__version__}') + print(f"InvokeAI version {__version__}") sys.exit(0) from invokeai.app.services.board_image_record_storage import ( @@ -36,18 +37,21 @@ from invokeai.app.services.image_record_storage import SqliteImageRecordStorage from invokeai.app.services.images import ImageService, ImageServiceDependencies from invokeai.app.services.resource_name import SimpleNameService from invokeai.app.services.urls import LocalUrlService -from .services.default_graphs import (default_text_to_image_graph_id, - create_system_graphs) +from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage -from .cli.commands import (BaseCommand, CliContext, ExitCli, - SortedHelpFormatter, add_graph_parsers, add_parsers) +from .cli.commands import BaseCommand, CliContext, ExitCli, SortedHelpFormatter, add_graph_parsers, add_parsers from .cli.completer import set_autocompleter from .invocations.baseinvocation import BaseInvocation from .services.events import EventServiceBase -from .services.graph import (Edge, EdgeConnection, GraphExecutionState, - GraphInvocation, LibraryGraph, - are_connection_types_compatible) +from .services.graph import ( + Edge, + EdgeConnection, + GraphExecutionState, + GraphInvocation, + LibraryGraph, + are_connection_types_compatible, +) from .services.image_file_storage import DiskImageFileStorage from .services.invocation_queue import MemoryInvocationQueue from .services.invocation_services import InvocationServices @@ -58,6 +62,7 @@ from .services.sqlite import SqliteItemStorage import torch import invokeai.backend.util.hotfixes + if torch.backends.mps.is_available(): import invokeai.backend.util.mps_fixes @@ -69,6 +74,7 @@ class CliCommand(BaseModel): class InvalidArgs(Exception): pass + def add_invocation_args(command_parser): # Add linking capability command_parser.add_argument( @@ -113,7 +119,7 @@ def get_command_parser(services: InvocationServices) -> argparse.ArgumentParser: return parser -class NodeField(): +class NodeField: alias: str node_path: str field: str @@ -126,15 +132,20 @@ class NodeField(): self.field_type = field_type -def fields_from_type_hints(hints: dict[str, type], node_path: str) -> dict[str,NodeField]: - return {k:NodeField(alias=k, node_path=node_path, field=k, field_type=v) for k, v in hints.items()} +def fields_from_type_hints(hints: dict[str, type], node_path: str) -> dict[str, NodeField]: + return {k: NodeField(alias=k, node_path=node_path, field=k, field_type=v) for k, v in hints.items()} def get_node_input_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField: """Gets the node field for the specified field alias""" exposed_input = next(e for e in graph.exposed_inputs if e.alias == field_alias) node_type = type(graph.graph.get_node(exposed_input.node_path)) - return NodeField(alias=exposed_input.alias, node_path=f'{node_id}.{exposed_input.node_path}', field=exposed_input.field, field_type=get_type_hints(node_type)[exposed_input.field]) + return NodeField( + alias=exposed_input.alias, + node_path=f"{node_id}.{exposed_input.node_path}", + field=exposed_input.field, + field_type=get_type_hints(node_type)[exposed_input.field], + ) def get_node_output_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField: @@ -142,7 +153,12 @@ def get_node_output_field(graph: LibraryGraph, field_alias: str, node_id: str) - exposed_output = next(e for e in graph.exposed_outputs if e.alias == field_alias) node_type = type(graph.graph.get_node(exposed_output.node_path)) node_output_type = node_type.get_output_type() - return NodeField(alias=exposed_output.alias, node_path=f'{node_id}.{exposed_output.node_path}', field=exposed_output.field, field_type=get_type_hints(node_output_type)[exposed_output.field]) + return NodeField( + alias=exposed_output.alias, + node_path=f"{node_id}.{exposed_output.node_path}", + field=exposed_output.field, + field_type=get_type_hints(node_output_type)[exposed_output.field], + ) def get_node_inputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]: @@ -165,9 +181,7 @@ def get_node_outputs(invocation: BaseInvocation, context: CliContext) -> dict[st return {e.alias: get_node_output_field(graph, e.alias, invocation.id) for e in graph.exposed_outputs} -def generate_matching_edges( - a: BaseInvocation, b: BaseInvocation, context: CliContext -) -> list[Edge]: +def generate_matching_edges(a: BaseInvocation, b: BaseInvocation, context: CliContext) -> list[Edge]: """Generates all possible edges between two invocations""" afields = get_node_outputs(a, context) bfields = get_node_inputs(b, context) @@ -179,12 +193,14 @@ def generate_matching_edges( matching_fields = matching_fields.difference(invalid_fields) # Validate types - matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f].field_type, bfields[f].field_type)] + matching_fields = [ + f for f in matching_fields if are_connection_types_compatible(afields[f].field_type, bfields[f].field_type) + ] edges = [ Edge( source=EdgeConnection(node_id=afields[alias].node_path, field=afields[alias].field), - destination=EdgeConnection(node_id=bfields[alias].node_path, field=bfields[alias].field) + destination=EdgeConnection(node_id=bfields[alias].node_path, field=bfields[alias].field), ) for alias in matching_fields ] @@ -193,6 +209,7 @@ def generate_matching_edges( class SessionError(Exception): """Raised when a session error has occurred""" + pass @@ -209,22 +226,23 @@ def invoke_all(context: CliContext): context.invoker.services.logger.error( f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {context.session.errors[n]}" ) - + raise SessionError() + def invoke_cli(): - logger.info(f'InvokeAI version {__version__}') + logger.info(f"InvokeAI version {__version__}") # get the optional list of invocations to execute on the command line parser = config.get_parser() - parser.add_argument('commands',nargs='*') + parser.add_argument("commands", nargs="*") invocation_commands = parser.parse_args().commands # get the optional file to read commands from. # Simplest is to use it for STDIN if infile := config.from_file: - sys.stdin = open(infile,"r") - - model_manager = ModelManagerService(config,logger) + sys.stdin = open(infile, "r") + + model_manager = ModelManagerService(config, logger) events = EventServiceBase() output_folder = config.output_path @@ -234,13 +252,13 @@ def invoke_cli(): db_location = ":memory:" else: db_location = config.db_path - db_location.parent.mkdir(parents=True,exist_ok=True) + db_location.parent.mkdir(parents=True, exist_ok=True) logger.info(f'InvokeAI database location is "{db_location}"') graph_execution_manager = SqliteItemStorage[GraphExecutionState]( - filename=db_location, table_name="graph_executions" - ) + filename=db_location, table_name="graph_executions" + ) urls = LocalUrlService() image_record_storage = SqliteImageRecordStorage(db_location) @@ -281,24 +299,21 @@ def invoke_cli(): graph_execution_manager=graph_execution_manager, ) ) - + services = InvocationServices( model_manager=model_manager, events=events, - latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')), + latents=ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")), images=images, boards=boards, board_images=board_images, queue=MemoryInvocationQueue(), - graph_library=SqliteItemStorage[LibraryGraph]( - filename=db_location, table_name="graphs" - ), + graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"), graph_execution_manager=graph_execution_manager, processor=DefaultInvocationProcessor(), logger=logger, configuration=config, ) - system_graphs = create_system_graphs(services.graph_library) system_graph_names = set([g.name for g in system_graphs]) @@ -308,7 +323,7 @@ def invoke_cli(): session: GraphExecutionState = invoker.create_execution_state() parser = get_command_parser(services) - re_negid = re.compile('^-[0-9]+$') + re_negid = re.compile("^-[0-9]+$") # Uncomment to print out previous sessions at startup # print(services.session_manager.list()) @@ -318,7 +333,7 @@ def invoke_cli(): command_line_args_exist = len(invocation_commands) > 0 done = False - + while not done: try: if command_line_args_exist: @@ -332,7 +347,7 @@ def invoke_cli(): try: # Refresh the state of the session - #history = list(get_graph_execution_history(context.session)) + # history = list(get_graph_execution_history(context.session)) history = list(reversed(context.nodes_added)) # Split the command for piping @@ -353,17 +368,17 @@ def invoke_cli(): args[field_name] = field_default # Parse invocation - command: CliCommand = None # type:ignore + command: CliCommand = None # type:ignore system_graph: Optional[LibraryGraph] = None - if args['type'] in system_graph_names: - system_graph = next(filter(lambda g: g.name == args['type'], system_graphs)) + if args["type"] in system_graph_names: + system_graph = next(filter(lambda g: g.name == args["type"], system_graphs)) invocation = GraphInvocation(graph=system_graph.graph, id=str(current_id)) for exposed_input in system_graph.exposed_inputs: if exposed_input.alias in args: node = invocation.graph.get_node(exposed_input.node_path) field = exposed_input.field setattr(node, field, args[exposed_input.alias]) - command = CliCommand(command = invocation) + command = CliCommand(command=invocation) context.graph_nodes[invocation.id] = system_graph.id else: args["id"] = current_id @@ -385,17 +400,13 @@ def invoke_cli(): # Pipe previous command output (if there was a previous command) edges: list[Edge] = list() if len(history) > 0 or current_id != start_id: - from_id = ( - history[0] if current_id == start_id else str(current_id - 1) - ) + from_id = history[0] if current_id == start_id else str(current_id - 1) from_node = ( next(filter(lambda n: n[0].id == from_id, new_invocations))[0] if current_id != start_id else context.session.graph.get_node(from_id) ) - matching_edges = generate_matching_edges( - from_node, command.command, context - ) + matching_edges = generate_matching_edges(from_node, command.command, context) edges.extend(matching_edges) # Parse provided links @@ -406,16 +417,18 @@ def invoke_cli(): node_id = str(current_id + int(node_id)) link_node = context.session.graph.get_node(node_id) - matching_edges = generate_matching_edges( - link_node, command.command, context - ) + matching_edges = generate_matching_edges(link_node, command.command, context) matching_destinations = [e.destination for e in matching_edges] edges = [e for e in edges if e.destination not in matching_destinations] edges.extend(matching_edges) if "link" in args and args["link"]: for link in args["link"]: - edges = [e for e in edges if e.destination.node_id != command.command.id or e.destination.field != link[2]] + edges = [ + e + for e in edges + if e.destination.node_id != command.command.id or e.destination.field != link[2] + ] node_id = link[0] if re_negid.match(node_id): @@ -428,7 +441,7 @@ def invoke_cli(): edges.append( Edge( source=EdgeConnection(node_id=node_output.node_path, field=node_output.field), - destination=EdgeConnection(node_id=node_input.node_path, field=node_input.field) + destination=EdgeConnection(node_id=node_input.node_path, field=node_input.field), ) ) diff --git a/invokeai/app/invocations/__init__.py b/invokeai/app/invocations/__init__.py index 0a451ff618..6407a1cdee 100644 --- a/invokeai/app/invocations/__init__.py +++ b/invokeai/app/invocations/__init__.py @@ -4,9 +4,5 @@ __all__ = [] dirname = os.path.dirname(os.path.abspath(__file__)) for f in os.listdir(dirname): - if ( - f != "__init__.py" - and os.path.isfile("%s/%s" % (dirname, f)) - and f[-3:] == ".py" - ): + if f != "__init__.py" and os.path.isfile("%s/%s" % (dirname, f)) and f[-3:] == ".py": __all__.append(f[:-3]) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 4c7314bd2b..758ab2e787 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -4,8 +4,7 @@ from __future__ import annotations from abc import ABC, abstractmethod from inspect import signature -from typing import (TYPE_CHECKING, Dict, List, Literal, TypedDict, get_args, - get_type_hints) +from typing import TYPE_CHECKING, Dict, List, Literal, TypedDict, get_args, get_type_hints from pydantic import BaseConfig, BaseModel, Field diff --git a/invokeai/app/invocations/collections.py b/invokeai/app/invocations/collections.py index 5446757eb0..01c003da96 100644 --- a/invokeai/app/invocations/collections.py +++ b/invokeai/app/invocations/collections.py @@ -8,8 +8,7 @@ from pydantic import Field, validator from invokeai.app.models.image import ImageField from invokeai.app.util.misc import SEED_MAX, get_random_seed -from .baseinvocation import (BaseInvocation, BaseInvocationOutput, - InvocationConfig, InvocationContext, UIConfig) +from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext, UIConfig class IntCollectionOutput(BaseInvocationOutput): @@ -27,8 +26,7 @@ class FloatCollectionOutput(BaseInvocationOutput): type: Literal["float_collection"] = "float_collection" # Outputs - collection: list[float] = Field( - default=[], description="The float collection") + collection: list[float] = Field(default=[], description="The float collection") class ImageCollectionOutput(BaseInvocationOutput): @@ -37,8 +35,7 @@ class ImageCollectionOutput(BaseInvocationOutput): type: Literal["image_collection"] = "image_collection" # Outputs - collection: list[ImageField] = Field( - default=[], description="The output images") + collection: list[ImageField] = Field(default=[], description="The output images") class Config: schema_extra = {"required": ["type", "collection"]} @@ -56,10 +53,7 @@ class RangeInvocation(BaseInvocation): class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Range", - "tags": ["range", "integer", "collection"] - }, + "ui": {"title": "Range", "tags": ["range", "integer", "collection"]}, } @validator("stop") @@ -69,9 +63,7 @@ class RangeInvocation(BaseInvocation): return v def invoke(self, context: InvocationContext) -> IntCollectionOutput: - return IntCollectionOutput( - collection=list(range(self.start, self.stop, self.step)) - ) + return IntCollectionOutput(collection=list(range(self.start, self.stop, self.step))) class RangeOfSizeInvocation(BaseInvocation): @@ -86,18 +78,11 @@ class RangeOfSizeInvocation(BaseInvocation): class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Sized Range", - "tags": ["range", "integer", "size", "collection"] - }, + "ui": {"title": "Sized Range", "tags": ["range", "integer", "size", "collection"]}, } def invoke(self, context: InvocationContext) -> IntCollectionOutput: - return IntCollectionOutput( - collection=list( - range( - self.start, self.start + self.size, - self.step))) + return IntCollectionOutput(collection=list(range(self.start, self.start + self.size, self.step))) class RandomRangeInvocation(BaseInvocation): @@ -107,9 +92,7 @@ class RandomRangeInvocation(BaseInvocation): # Inputs low: int = Field(default=0, description="The inclusive low value") - high: int = Field( - default=np.iinfo(np.int32).max, description="The exclusive high value" - ) + high: int = Field(default=np.iinfo(np.int32).max, description="The exclusive high value") size: int = Field(default=1, description="The number of values to generate") seed: int = Field( ge=0, @@ -120,19 +103,12 @@ class RandomRangeInvocation(BaseInvocation): class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Random Range", - "tags": ["range", "integer", "random", "collection"] - }, + "ui": {"title": "Random Range", "tags": ["range", "integer", "random", "collection"]}, } def invoke(self, context: InvocationContext) -> IntCollectionOutput: rng = np.random.default_rng(self.seed) - return IntCollectionOutput( - collection=list( - rng.integers( - low=self.low, high=self.high, - size=self.size))) + return IntCollectionOutput(collection=list(rng.integers(low=self.low, high=self.high, size=self.size))) class ImageCollectionInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 54f7bc6683..ada7a06a57 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -11,64 +11,63 @@ from ...backend.model_management import BaseModelType, ModelType, SubModelType, import torch from compel import Compel, ReturnedEmbeddingsType -from compel.prompt_parser import (Blend, Conjunction, - CrossAttentionControlSubstitute, - FlattenedPrompt, Fragment) +from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment from ...backend.util.devices import torch_dtype from ...backend.model_management import ModelType from ...backend.model_management.models import ModelNotFoundException from ...backend.model_management.lora import ModelPatcher from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent -from .baseinvocation import (BaseInvocation, BaseInvocationOutput, - InvocationConfig, InvocationContext) +from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext from .model import ClipField from dataclasses import dataclass class ConditioningField(BaseModel): - conditioning_name: Optional[str] = Field( - default=None, description="The name of conditioning data") + conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data") class Config: schema_extra = {"required": ["conditioning_name"]} + @dataclass class BasicConditioningInfo: - #type: Literal["basic_conditioning"] = "basic_conditioning" + # type: Literal["basic_conditioning"] = "basic_conditioning" embeds: torch.Tensor extra_conditioning: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo] # weight: float # mode: ConditioningAlgo + @dataclass class SDXLConditioningInfo(BasicConditioningInfo): - #type: Literal["sdxl_conditioning"] = "sdxl_conditioning" + # type: Literal["sdxl_conditioning"] = "sdxl_conditioning" pooled_embeds: torch.Tensor add_time_ids: torch.Tensor -ConditioningInfoType = Annotated[ - Union[BasicConditioningInfo, SDXLConditioningInfo], - Field(discriminator="type") -] + +ConditioningInfoType = Annotated[Union[BasicConditioningInfo, SDXLConditioningInfo], Field(discriminator="type")] + @dataclass class ConditioningFieldData: conditionings: List[Union[BasicConditioningInfo, SDXLConditioningInfo]] - #unconditioned: Optional[torch.Tensor] + # unconditioned: Optional[torch.Tensor] -#class ConditioningAlgo(str, Enum): + +# class ConditioningAlgo(str, Enum): # Compose = "compose" # ComposeEx = "compose_ex" # PerpNeg = "perp_neg" + class CompelOutput(BaseInvocationOutput): """Compel parser output""" - #fmt: off + # fmt: off type: Literal["compel_output"] = "compel_output" conditioning: ConditioningField = Field(default=None, description="Conditioning") - #fmt: on + # fmt: on class CompelInvocation(BaseInvocation): @@ -82,33 +81,28 @@ class CompelInvocation(BaseInvocation): # Schema customisation class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Prompt (Compel)", - "tags": ["prompt", "compel"], - "type_hints": { - "model": "model" - } - }, + "ui": {"title": "Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}}, } @torch.no_grad() def invoke(self, context: InvocationContext) -> CompelOutput: tokenizer_info = context.services.model_manager.get_model( - **self.clip.tokenizer.dict(), context=context, + **self.clip.tokenizer.dict(), + context=context, ) text_encoder_info = context.services.model_manager.get_model( - **self.clip.text_encoder.dict(), context=context, + **self.clip.text_encoder.dict(), + context=context, ) def _lora_loader(): for lora in self.clip.loras: - lora_info = context.services.model_manager.get_model( - **lora.dict(exclude={"weight"}), context=context) + lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context) yield (lora_info.context.model, lora.weight) del lora_info return - #loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] + # loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] ti_list = [] for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt): @@ -124,15 +118,18 @@ class CompelInvocation(BaseInvocation): ) except ModelNotFoundException: # print(e) - #import traceback - #print(traceback.format_exc()) - print(f"Warn: trigger: \"{trigger}\" not found") - - with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\ - ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\ - ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),\ - text_encoder_info as text_encoder: + # import traceback + # print(traceback.format_exc()) + print(f'Warn: trigger: "{trigger}" not found') + with ModelPatcher.apply_lora_text_encoder( + text_encoder_info.context.model, _lora_loader() + ), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as ( + tokenizer, + ti_manager, + ), ModelPatcher.apply_clip_skip( + text_encoder_info.context.model, self.clip.skipped_layers + ), text_encoder_info as text_encoder: compel = Compel( tokenizer=tokenizer, text_encoder=text_encoder, @@ -147,14 +144,12 @@ class CompelInvocation(BaseInvocation): if context.services.configuration.log_tokenization: log_tokenization_for_prompt_object(prompt, tokenizer) - c, options = compel.build_conditioning_tensor_for_prompt_object( - prompt) + c, options = compel.build_conditioning_tensor_for_prompt_object(prompt) ec = InvokeAIDiffuserComponent.ExtraConditioningInfo( - tokens_count_including_eos_bos=get_max_token_count( - tokenizer, conjunction), - cross_attention_control_args=options.get( - "cross_attention_control", None),) + tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction), + cross_attention_control_args=options.get("cross_attention_control", None), + ) c = c.detach().to("cpu") @@ -176,24 +171,26 @@ class CompelInvocation(BaseInvocation): ), ) + class SDXLPromptInvocationBase: def run_clip_raw(self, context, clip_field, prompt, get_pooled): tokenizer_info = context.services.model_manager.get_model( - **clip_field.tokenizer.dict(), context=context, + **clip_field.tokenizer.dict(), + context=context, ) text_encoder_info = context.services.model_manager.get_model( - **clip_field.text_encoder.dict(), context=context, + **clip_field.text_encoder.dict(), + context=context, ) def _lora_loader(): for lora in clip_field.loras: - lora_info = context.services.model_manager.get_model( - **lora.dict(exclude={"weight"}), context=context) + lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context) yield (lora_info.context.model, lora.weight) del lora_info return - #loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] + # loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] ti_list = [] for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt): @@ -209,15 +206,18 @@ class SDXLPromptInvocationBase: ) except ModelNotFoundException: # print(e) - #import traceback - #print(traceback.format_exc()) - print(f"Warn: trigger: \"{trigger}\" not found") - - with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\ - ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\ - ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),\ - text_encoder_info as text_encoder: + # import traceback + # print(traceback.format_exc()) + print(f'Warn: trigger: "{trigger}" not found') + with ModelPatcher.apply_lora_text_encoder( + text_encoder_info.context.model, _lora_loader() + ), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as ( + tokenizer, + ti_manager, + ), ModelPatcher.apply_clip_skip( + text_encoder_info.context.model, clip_field.skipped_layers + ), text_encoder_info as text_encoder: text_inputs = tokenizer( prompt, padding="max_length", @@ -249,21 +249,22 @@ class SDXLPromptInvocationBase: def run_clip_compel(self, context, clip_field, prompt, get_pooled): tokenizer_info = context.services.model_manager.get_model( - **clip_field.tokenizer.dict(), context=context, + **clip_field.tokenizer.dict(), + context=context, ) text_encoder_info = context.services.model_manager.get_model( - **clip_field.text_encoder.dict(), context=context, + **clip_field.text_encoder.dict(), + context=context, ) def _lora_loader(): for lora in clip_field.loras: - lora_info = context.services.model_manager.get_model( - **lora.dict(exclude={"weight"}), context=context) + lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context) yield (lora_info.context.model, lora.weight) del lora_info return - #loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] + # loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] ti_list = [] for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt): @@ -279,22 +280,25 @@ class SDXLPromptInvocationBase: ) except ModelNotFoundException: # print(e) - #import traceback - #print(traceback.format_exc()) - print(f"Warn: trigger: \"{trigger}\" not found") - - with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\ - ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\ - ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),\ - text_encoder_info as text_encoder: + # import traceback + # print(traceback.format_exc()) + print(f'Warn: trigger: "{trigger}" not found') + with ModelPatcher.apply_lora_text_encoder( + text_encoder_info.context.model, _lora_loader() + ), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as ( + tokenizer, + ti_manager, + ), ModelPatcher.apply_clip_skip( + text_encoder_info.context.model, clip_field.skipped_layers + ), text_encoder_info as text_encoder: compel = Compel( tokenizer=tokenizer, text_encoder=text_encoder, textual_inversion_manager=ti_manager, dtype_for_device_getter=torch_dtype, truncate_long_prompts=True, # TODO: - returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip + returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip requires_pooled=True, ) @@ -328,6 +332,7 @@ class SDXLPromptInvocationBase: return c, c_pooled, ec + class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): """Parse prompt using compel package to conditioning.""" @@ -347,13 +352,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): # Schema customisation class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "SDXL Prompt (Compel)", - "tags": ["prompt", "compel"], - "type_hints": { - "model": "model" - } - }, + "ui": {"title": "SDXL Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}}, } @torch.no_grad() @@ -368,9 +367,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): crop_coords = (self.crop_top, self.crop_left) target_size = (self.target_height, self.target_width) - add_time_ids = torch.tensor([ - original_size + crop_coords + target_size - ]) + add_time_ids = torch.tensor([original_size + crop_coords + target_size]) conditioning_data = ConditioningFieldData( conditionings=[ @@ -392,12 +389,13 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): ), ) + class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): """Parse prompt using compel package to conditioning.""" type: Literal["sdxl_refiner_compel_prompt"] = "sdxl_refiner_compel_prompt" - style: str = Field(default="", description="Style prompt") # TODO: ? + style: str = Field(default="", description="Style prompt") # TODO: ? original_width: int = Field(1024, description="") original_height: int = Field(1024, description="") crop_top: int = Field(0, description="") @@ -411,9 +409,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase "ui": { "title": "SDXL Refiner Prompt (Compel)", "tags": ["prompt", "compel"], - "type_hints": { - "model": "model" - } + "type_hints": {"model": "model"}, }, } @@ -424,9 +420,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase original_size = (self.original_height, self.original_width) crop_coords = (self.crop_top, self.crop_left) - add_time_ids = torch.tensor([ - original_size + crop_coords + (self.aesthetic_score,) - ]) + add_time_ids = torch.tensor([original_size + crop_coords + (self.aesthetic_score,)]) conditioning_data = ConditioningFieldData( conditionings=[ @@ -434,7 +428,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase embeds=c2, pooled_embeds=c2_pooled, add_time_ids=add_time_ids, - extra_conditioning=ec2, # or None + extra_conditioning=ec2, # or None ) ] ) @@ -448,6 +442,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase ), ) + class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): """Pass unmodified prompt to conditioning without compel processing.""" @@ -467,13 +462,7 @@ class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): # Schema customisation class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "SDXL Prompt (Raw)", - "tags": ["prompt", "compel"], - "type_hints": { - "model": "model" - } - }, + "ui": {"title": "SDXL Prompt (Raw)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}}, } @torch.no_grad() @@ -488,9 +477,7 @@ class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): crop_coords = (self.crop_top, self.crop_left) target_size = (self.target_height, self.target_width) - add_time_ids = torch.tensor([ - original_size + crop_coords + target_size - ]) + add_time_ids = torch.tensor([original_size + crop_coords + target_size]) conditioning_data = ConditioningFieldData( conditionings=[ @@ -512,12 +499,13 @@ class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): ), ) + class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): """Parse prompt using compel package to conditioning.""" type: Literal["sdxl_refiner_raw_prompt"] = "sdxl_refiner_raw_prompt" - style: str = Field(default="", description="Style prompt") # TODO: ? + style: str = Field(default="", description="Style prompt") # TODO: ? original_width: int = Field(1024, description="") original_height: int = Field(1024, description="") crop_top: int = Field(0, description="") @@ -531,9 +519,7 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): "ui": { "title": "SDXL Refiner Prompt (Raw)", "tags": ["prompt", "compel"], - "type_hints": { - "model": "model" - } + "type_hints": {"model": "model"}, }, } @@ -544,9 +530,7 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): original_size = (self.original_height, self.original_width) crop_coords = (self.crop_top, self.crop_left) - add_time_ids = torch.tensor([ - original_size + crop_coords + (self.aesthetic_score,) - ]) + add_time_ids = torch.tensor([original_size + crop_coords + (self.aesthetic_score,)]) conditioning_data = ConditioningFieldData( conditionings=[ @@ -554,7 +538,7 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): embeds=c2, pooled_embeds=c2_pooled, add_time_ids=add_time_ids, - extra_conditioning=ec2, # or None + extra_conditioning=ec2, # or None ) ] ) @@ -571,11 +555,14 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): class ClipSkipInvocationOutput(BaseInvocationOutput): """Clip skip node output""" + type: Literal["clip_skip_output"] = "clip_skip_output" clip: ClipField = Field(None, description="Clip with skipped layers") + class ClipSkipInvocation(BaseInvocation): """Skip layers in clip text_encoder model.""" + type: Literal["clip_skip"] = "clip_skip" clip: ClipField = Field(None, description="Clip to use") @@ -583,10 +570,7 @@ class ClipSkipInvocation(BaseInvocation): class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "CLIP Skip", - "tags": ["clip", "skip"] - }, + "ui": {"title": "CLIP Skip", "tags": ["clip", "skip"]}, } def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput: @@ -597,46 +581,26 @@ class ClipSkipInvocation(BaseInvocation): def get_max_token_count( - tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], - truncate_if_too_long=False) -> int: + tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], truncate_if_too_long=False +) -> int: if type(prompt) is Blend: blend: Blend = prompt - return max( - [ - get_max_token_count(tokenizer, p, truncate_if_too_long) - for p in blend.prompts - ] - ) + return max([get_max_token_count(tokenizer, p, truncate_if_too_long) for p in blend.prompts]) elif type(prompt) is Conjunction: conjunction: Conjunction = prompt - return sum( - [ - get_max_token_count(tokenizer, p, truncate_if_too_long) - for p in conjunction.prompts - ] - ) + return sum([get_max_token_count(tokenizer, p, truncate_if_too_long) for p in conjunction.prompts]) else: - return len( - get_tokens_for_prompt_object( - tokenizer, prompt, truncate_if_too_long)) + return len(get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long)) -def get_tokens_for_prompt_object( - tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True -) -> List[str]: +def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> List[str]: if type(parsed_prompt) is Blend: - raise ValueError( - "Blend is not supported here - you need to get tokens for each of its .children" - ) + raise ValueError("Blend is not supported here - you need to get tokens for each of its .children") text_fragments = [ x.text if type(x) is Fragment - else ( - " ".join([f.text for f in x.original]) - if type(x) is CrossAttentionControlSubstitute - else str(x) - ) + else (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else str(x)) for x in parsed_prompt.children ] text = " ".join(text_fragments) @@ -647,25 +611,17 @@ def get_tokens_for_prompt_object( return tokens -def log_tokenization_for_conjunction( - c: Conjunction, tokenizer, display_label_prefix=None -): +def log_tokenization_for_conjunction(c: Conjunction, tokenizer, display_label_prefix=None): display_label_prefix = display_label_prefix or "" for i, p in enumerate(c.prompts): if len(c.prompts) > 1: this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})" else: this_display_label_prefix = display_label_prefix - log_tokenization_for_prompt_object( - p, - tokenizer, - display_label_prefix=this_display_label_prefix - ) + log_tokenization_for_prompt_object(p, tokenizer, display_label_prefix=this_display_label_prefix) -def log_tokenization_for_prompt_object( - p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None -): +def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None): display_label_prefix = display_label_prefix or "" if type(p) is Blend: blend: Blend = p @@ -702,13 +658,10 @@ def log_tokenization_for_prompt_object( ) else: text = " ".join([x.text for x in flattened_prompt.children]) - log_tokenization_for_text( - text, tokenizer, display_label=display_label_prefix - ) + log_tokenization_for_text(text, tokenizer, display_label=display_label_prefix) -def log_tokenization_for_text( - text, tokenizer, display_label=None, truncate_if_too_long=False): +def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False): """shows how the prompt is tokenized # usually tokens have '' to indicate end-of-word, # but for readability it has been replaced with ' ' diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 2087594b7b..d2b2d44526 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -6,20 +6,29 @@ from typing import Dict, List, Literal, Optional, Union import cv2 import numpy as np -from controlnet_aux import (CannyDetector, ContentShuffleDetector, HEDdetector, - LeresDetector, LineartAnimeDetector, - LineartDetector, MediapipeFaceDetector, - MidasDetector, MLSDdetector, NormalBaeDetector, - OpenposeDetector, PidiNetDetector, SamDetector, - ZoeDetector) +from controlnet_aux import ( + CannyDetector, + ContentShuffleDetector, + HEDdetector, + LeresDetector, + LineartAnimeDetector, + LineartDetector, + MediapipeFaceDetector, + MidasDetector, + MLSDdetector, + NormalBaeDetector, + OpenposeDetector, + PidiNetDetector, + SamDetector, + ZoeDetector, +) from controlnet_aux.util import HWC3, ade_palette from PIL import Image from pydantic import BaseModel, Field, validator from ...backend.model_management import BaseModelType, ModelType from ..models.image import ImageCategory, ImageField, ResourceOrigin -from .baseinvocation import (BaseInvocation, BaseInvocationOutput, - InvocationConfig, InvocationContext) +from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext from ..models.image import ImageOutput, PILInvocationConfig CONTROLNET_DEFAULT_MODELS = [ @@ -34,7 +43,6 @@ CONTROLNET_DEFAULT_MODELS = [ "lllyasviel/sd-controlnet-scribble", "lllyasviel/sd-controlnet-normal", "lllyasviel/sd-controlnet-mlsd", - ############################################# # lllyasviel sd v1.5, ControlNet v1.1 models ############################################# @@ -56,7 +64,6 @@ CONTROLNET_DEFAULT_MODELS = [ "lllyasviel/control_v11e_sd15_shuffle", "lllyasviel/control_v11e_sd15_ip2p", "lllyasviel/control_v11f1e_sd15_tile", - ################################################# # thibaud sd v2.1 models (ControlNet v1.0? or v1.1? ################################################## @@ -71,7 +78,6 @@ CONTROLNET_DEFAULT_MODELS = [ "thibaud/controlnet-sd21-lineart-diffusers", "thibaud/controlnet-sd21-normalbae-diffusers", "thibaud/controlnet-sd21-ade20k-diffusers", - ############################################## # ControlNetMediaPipeface, ControlNet v1.1 ############################################## @@ -83,10 +89,17 @@ CONTROLNET_DEFAULT_MODELS = [ ] CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)] -CONTROLNET_MODE_VALUES = Literal[tuple( - ["balanced", "more_prompt", "more_control", "unbalanced"])] -CONTROLNET_RESIZE_VALUES = Literal[tuple( - ["just_resize", "crop_resize", "fill_resize", "just_resize_simple",])] +CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control", "unbalanced"])] +CONTROLNET_RESIZE_VALUES = Literal[ + tuple( + [ + "just_resize", + "crop_resize", + "fill_resize", + "just_resize_simple", + ] + ) +] class ControlNetModelField(BaseModel): @@ -98,21 +111,17 @@ class ControlNetModelField(BaseModel): class ControlField(BaseModel): image: ImageField = Field(default=None, description="The control image") - control_model: Optional[ControlNetModelField] = Field( - default=None, description="The ControlNet model to use") + control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use") # control_weight: Optional[float] = Field(default=1, description="weight given to controlnet") - control_weight: Union[float, List[float]] = Field( - default=1, description="The weight given to the ControlNet") + control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet") begin_step_percent: float = Field( - default=0, ge=0, le=1, - description="When the ControlNet is first applied (% of total steps)") + default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)" + ) end_step_percent: float = Field( - default=1, ge=0, le=1, - description="When the ControlNet is last applied (% of total steps)") - control_mode: CONTROLNET_MODE_VALUES = Field( - default="balanced", description="The control mode to use") - resize_mode: CONTROLNET_RESIZE_VALUES = Field( - default="just_resize", description="The resize mode to use") + default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)" + ) + control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use") + resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use") @validator("control_weight") def validate_control_weight(cls, v): @@ -120,11 +129,10 @@ class ControlField(BaseModel): if isinstance(v, list): for i in v: if i < -1 or i > 2: - raise ValueError( - 'Control weights must be within -1 to 2 range') + raise ValueError("Control weights must be within -1 to 2 range") else: if v < -1 or v > 2: - raise ValueError('Control weights must be within -1 to 2 range') + raise ValueError("Control weights must be within -1 to 2 range") return v class Config: @@ -136,12 +144,13 @@ class ControlField(BaseModel): "control_model": "controlnet_model", # "control_weight": "number", } - } + }, } class ControlOutput(BaseInvocationOutput): """node output for ControlNet info""" + # fmt: off type: Literal["control_output"] = "control_output" control: ControlField = Field(default=None, description="The control info") @@ -150,6 +159,7 @@ class ControlOutput(BaseInvocationOutput): class ControlNetInvocation(BaseInvocation): """Collects ControlNet info to pass to other nodes""" + # fmt: off type: Literal["controlnet"] = "controlnet" # Inputs @@ -176,7 +186,7 @@ class ControlNetInvocation(BaseInvocation): # "cfg_scale": "float", "cfg_scale": "number", "control_weight": "float", - } + }, }, } @@ -205,10 +215,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig): class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Image Processor", - "tags": ["image", "processor"] - }, + "ui": {"title": "Image Processor", "tags": ["image", "processor"]}, } def run_processor(self, image): @@ -233,7 +240,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig): image_category=ImageCategory.CONTROL, session_id=context.graph_execution_state_id, node_id=self.id, - is_intermediate=self.is_intermediate + is_intermediate=self.is_intermediate, ) """Builds an ImageOutput and its ImageField""" @@ -248,9 +255,9 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig): ) -class CannyImageProcessorInvocation( - ImageProcessorInvocation, PILInvocationConfig): +class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): """Canny edge detection for ControlNet""" + # fmt: off type: Literal["canny_image_processor"] = "canny_image_processor" # Input @@ -260,22 +267,18 @@ class CannyImageProcessorInvocation( class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Canny Processor", - "tags": ["controlnet", "canny", "image", "processor"] - }, + "ui": {"title": "Canny Processor", "tags": ["controlnet", "canny", "image", "processor"]}, } def run_processor(self, image): canny_processor = CannyDetector() - processed_image = canny_processor( - image, self.low_threshold, self.high_threshold) + processed_image = canny_processor(image, self.low_threshold, self.high_threshold) return processed_image -class HedImageProcessorInvocation( - ImageProcessorInvocation, PILInvocationConfig): +class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): """Applies HED edge detection to image""" + # fmt: off type: Literal["hed_image_processor"] = "hed_image_processor" # Inputs @@ -288,27 +291,25 @@ class HedImageProcessorInvocation( class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Softedge(HED) Processor", - "tags": ["controlnet", "softedge", "hed", "image", "processor"] - }, + "ui": {"title": "Softedge(HED) Processor", "tags": ["controlnet", "softedge", "hed", "image", "processor"]}, } def run_processor(self, image): hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators") - processed_image = hed_processor(image, - detect_resolution=self.detect_resolution, - image_resolution=self.image_resolution, - # safe not supported in controlnet_aux v0.0.3 - # safe=self.safe, - scribble=self.scribble, - ) + processed_image = hed_processor( + image, + detect_resolution=self.detect_resolution, + image_resolution=self.image_resolution, + # safe not supported in controlnet_aux v0.0.3 + # safe=self.safe, + scribble=self.scribble, + ) return processed_image -class LineartImageProcessorInvocation( - ImageProcessorInvocation, PILInvocationConfig): +class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): """Applies line art processing to image""" + # fmt: off type: Literal["lineart_image_processor"] = "lineart_image_processor" # Inputs @@ -319,24 +320,20 @@ class LineartImageProcessorInvocation( class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Lineart Processor", - "tags": ["controlnet", "lineart", "image", "processor"] - }, + "ui": {"title": "Lineart Processor", "tags": ["controlnet", "lineart", "image", "processor"]}, } def run_processor(self, image): - lineart_processor = LineartDetector.from_pretrained( - "lllyasviel/Annotators") + lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators") processed_image = lineart_processor( - image, detect_resolution=self.detect_resolution, - image_resolution=self.image_resolution, coarse=self.coarse) + image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution, coarse=self.coarse + ) return processed_image -class LineartAnimeImageProcessorInvocation( - ImageProcessorInvocation, PILInvocationConfig): +class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): """Applies line art anime processing to image""" + # fmt: off type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor" # Inputs @@ -348,23 +345,23 @@ class LineartAnimeImageProcessorInvocation( schema_extra = { "ui": { "title": "Lineart Anime Processor", - "tags": ["controlnet", "lineart", "anime", "image", "processor"] + "tags": ["controlnet", "lineart", "anime", "image", "processor"], }, } def run_processor(self, image): - processor = LineartAnimeDetector.from_pretrained( - "lllyasviel/Annotators") - processed_image = processor(image, - detect_resolution=self.detect_resolution, - image_resolution=self.image_resolution, - ) + processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators") + processed_image = processor( + image, + detect_resolution=self.detect_resolution, + image_resolution=self.image_resolution, + ) return processed_image -class OpenposeImageProcessorInvocation( - ImageProcessorInvocation, PILInvocationConfig): +class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): """Applies Openpose processing to image""" + # fmt: off type: Literal["openpose_image_processor"] = "openpose_image_processor" # Inputs @@ -375,25 +372,23 @@ class OpenposeImageProcessorInvocation( class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Openpose Processor", - "tags": ["controlnet", "openpose", "image", "processor"] - }, + "ui": {"title": "Openpose Processor", "tags": ["controlnet", "openpose", "image", "processor"]}, } def run_processor(self, image): - openpose_processor = OpenposeDetector.from_pretrained( - "lllyasviel/Annotators") + openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators") processed_image = openpose_processor( - image, detect_resolution=self.detect_resolution, + image, + detect_resolution=self.detect_resolution, image_resolution=self.image_resolution, - hand_and_face=self.hand_and_face,) + hand_and_face=self.hand_and_face, + ) return processed_image -class MidasDepthImageProcessorInvocation( - ImageProcessorInvocation, PILInvocationConfig): +class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): """Applies Midas depth processing to image""" + # fmt: off type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor" # Inputs @@ -405,26 +400,24 @@ class MidasDepthImageProcessorInvocation( class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Midas (Depth) Processor", - "tags": ["controlnet", "midas", "depth", "image", "processor"] - }, + "ui": {"title": "Midas (Depth) Processor", "tags": ["controlnet", "midas", "depth", "image", "processor"]}, } def run_processor(self, image): midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators") - processed_image = midas_processor(image, - a=np.pi * self.a_mult, - bg_th=self.bg_th, - # dept_and_normal not supported in controlnet_aux v0.0.3 - # depth_and_normal=self.depth_and_normal, - ) + processed_image = midas_processor( + image, + a=np.pi * self.a_mult, + bg_th=self.bg_th, + # dept_and_normal not supported in controlnet_aux v0.0.3 + # depth_and_normal=self.depth_and_normal, + ) return processed_image -class NormalbaeImageProcessorInvocation( - ImageProcessorInvocation, PILInvocationConfig): +class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): """Applies NormalBae processing to image""" + # fmt: off type: Literal["normalbae_image_processor"] = "normalbae_image_processor" # Inputs @@ -434,24 +427,20 @@ class NormalbaeImageProcessorInvocation( class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Normal BAE Processor", - "tags": ["controlnet", "normal", "bae", "image", "processor"] - }, + "ui": {"title": "Normal BAE Processor", "tags": ["controlnet", "normal", "bae", "image", "processor"]}, } def run_processor(self, image): - normalbae_processor = NormalBaeDetector.from_pretrained( - "lllyasviel/Annotators") + normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators") processed_image = normalbae_processor( - image, detect_resolution=self.detect_resolution, - image_resolution=self.image_resolution) + image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution + ) return processed_image -class MlsdImageProcessorInvocation( - ImageProcessorInvocation, PILInvocationConfig): +class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): """Applies MLSD processing to image""" + # fmt: off type: Literal["mlsd_image_processor"] = "mlsd_image_processor" # Inputs @@ -463,24 +452,24 @@ class MlsdImageProcessorInvocation( class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "MLSD Processor", - "tags": ["controlnet", "mlsd", "image", "processor"] - }, + "ui": {"title": "MLSD Processor", "tags": ["controlnet", "mlsd", "image", "processor"]}, } def run_processor(self, image): mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators") processed_image = mlsd_processor( - image, detect_resolution=self.detect_resolution, - image_resolution=self.image_resolution, thr_v=self.thr_v, - thr_d=self.thr_d) + image, + detect_resolution=self.detect_resolution, + image_resolution=self.image_resolution, + thr_v=self.thr_v, + thr_d=self.thr_d, + ) return processed_image -class PidiImageProcessorInvocation( - ImageProcessorInvocation, PILInvocationConfig): +class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): """Applies PIDI processing to image""" + # fmt: off type: Literal["pidi_image_processor"] = "pidi_image_processor" # Inputs @@ -492,25 +481,24 @@ class PidiImageProcessorInvocation( class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "PIDI Processor", - "tags": ["controlnet", "pidi", "image", "processor"] - }, + "ui": {"title": "PIDI Processor", "tags": ["controlnet", "pidi", "image", "processor"]}, } def run_processor(self, image): - pidi_processor = PidiNetDetector.from_pretrained( - "lllyasviel/Annotators") + pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators") processed_image = pidi_processor( - image, detect_resolution=self.detect_resolution, - image_resolution=self.image_resolution, safe=self.safe, - scribble=self.scribble) + image, + detect_resolution=self.detect_resolution, + image_resolution=self.image_resolution, + safe=self.safe, + scribble=self.scribble, + ) return processed_image -class ContentShuffleImageProcessorInvocation( - ImageProcessorInvocation, PILInvocationConfig): +class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): """Applies content shuffle processing to image""" + # fmt: off type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor" # Inputs @@ -525,48 +513,45 @@ class ContentShuffleImageProcessorInvocation( schema_extra = { "ui": { "title": "Content Shuffle Processor", - "tags": ["controlnet", "contentshuffle", "image", "processor"] + "tags": ["controlnet", "contentshuffle", "image", "processor"], }, } def run_processor(self, image): content_shuffle_processor = ContentShuffleDetector() - processed_image = content_shuffle_processor(image, - detect_resolution=self.detect_resolution, - image_resolution=self.image_resolution, - h=self.h, - w=self.w, - f=self.f - ) + processed_image = content_shuffle_processor( + image, + detect_resolution=self.detect_resolution, + image_resolution=self.image_resolution, + h=self.h, + w=self.w, + f=self.f, + ) return processed_image # should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13 -class ZoeDepthImageProcessorInvocation( - ImageProcessorInvocation, PILInvocationConfig): +class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): """Applies Zoe depth processing to image""" + # fmt: off type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor" # fmt: on class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Zoe (Depth) Processor", - "tags": ["controlnet", "zoe", "depth", "image", "processor"] - }, + "ui": {"title": "Zoe (Depth) Processor", "tags": ["controlnet", "zoe", "depth", "image", "processor"]}, } def run_processor(self, image): - zoe_depth_processor = ZoeDetector.from_pretrained( - "lllyasviel/Annotators") + zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators") processed_image = zoe_depth_processor(image) return processed_image -class MediapipeFaceProcessorInvocation( - ImageProcessorInvocation, PILInvocationConfig): +class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): """Applies mediapipe face processing to image""" + # fmt: off type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor" # Inputs @@ -576,26 +561,22 @@ class MediapipeFaceProcessorInvocation( class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Mediapipe Processor", - "tags": ["controlnet", "mediapipe", "image", "processor"] - }, + "ui": {"title": "Mediapipe Processor", "tags": ["controlnet", "mediapipe", "image", "processor"]}, } def run_processor(self, image): # MediaPipeFaceDetector throws an error if image has alpha channel # so convert to RGB if needed - if image.mode == 'RGBA': - image = image.convert('RGB') + if image.mode == "RGBA": + image = image.convert("RGB") mediapipe_face_processor = MediapipeFaceDetector() - processed_image = mediapipe_face_processor( - image, max_faces=self.max_faces, min_confidence=self.min_confidence) + processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence) return processed_image -class LeresImageProcessorInvocation( - ImageProcessorInvocation, PILInvocationConfig): +class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): """Applies leres processing to image""" + # fmt: off type: Literal["leres_image_processor"] = "leres_image_processor" # Inputs @@ -608,24 +589,23 @@ class LeresImageProcessorInvocation( class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Leres (Depth) Processor", - "tags": ["controlnet", "leres", "depth", "image", "processor"] - }, + "ui": {"title": "Leres (Depth) Processor", "tags": ["controlnet", "leres", "depth", "image", "processor"]}, } def run_processor(self, image): leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators") processed_image = leres_processor( - image, thr_a=self.thr_a, thr_b=self.thr_b, boost=self.boost, + image, + thr_a=self.thr_a, + thr_b=self.thr_b, + boost=self.boost, detect_resolution=self.detect_resolution, - image_resolution=self.image_resolution) + image_resolution=self.image_resolution, + ) return processed_image -class TileResamplerProcessorInvocation( - ImageProcessorInvocation, PILInvocationConfig): - +class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): # fmt: off type: Literal["tile_image_processor"] = "tile_image_processor" # Inputs @@ -637,16 +617,17 @@ class TileResamplerProcessorInvocation( schema_extra = { "ui": { "title": "Tile Resample Processor", - "tags": ["controlnet", "tile", "resample", "image", "processor"] + "tags": ["controlnet", "tile", "resample", "image", "processor"], }, } # tile_resample copied from sd-webui-controlnet/scripts/processor.py - def tile_resample(self, - np_img: np.ndarray, - res=512, # never used? - down_sampling_rate=1.0, - ): + def tile_resample( + self, + np_img: np.ndarray, + res=512, # never used? + down_sampling_rate=1.0, + ): np_img = HWC3(np_img) if down_sampling_rate < 1.1: return np_img @@ -658,36 +639,41 @@ class TileResamplerProcessorInvocation( def run_processor(self, img): np_img = np.array(img, dtype=np.uint8) - processed_np_image = self.tile_resample(np_img, - # res=self.tile_size, - down_sampling_rate=self.down_sampling_rate - ) + processed_np_image = self.tile_resample( + np_img, + # res=self.tile_size, + down_sampling_rate=self.down_sampling_rate, + ) processed_image = Image.fromarray(processed_np_image) return processed_image -class SegmentAnythingProcessorInvocation( - ImageProcessorInvocation, PILInvocationConfig): +class SegmentAnythingProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): """Applies segment anything processing to image""" + # fmt: off type: Literal["segment_anything_processor"] = "segment_anything_processor" # fmt: on class Config(InvocationConfig): - schema_extra = {"ui": {"title": "Segment Anything Processor", "tags": [ - "controlnet", "segment", "anything", "sam", "image", "processor"]}, } + schema_extra = { + "ui": { + "title": "Segment Anything Processor", + "tags": ["controlnet", "segment", "anything", "sam", "image", "processor"], + }, + } def run_processor(self, image): # segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints") segment_anything_processor = SamDetectorReproducibleColors.from_pretrained( - "ybelkada/segment-anything", subfolder="checkpoints") + "ybelkada/segment-anything", subfolder="checkpoints" + ) np_img = np.array(image, dtype=np.uint8) processed_image = segment_anything_processor(np_img) return processed_image class SamDetectorReproducibleColors(SamDetector): - # overriding SamDetector.show_anns() method to use reproducible colors for segmentation image # base class show_anns() method randomizes colors, # which seems to also lead to non-reproducible image generation @@ -695,19 +681,15 @@ class SamDetectorReproducibleColors(SamDetector): def show_anns(self, anns: List[Dict]): if len(anns) == 0: return - sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True) - h, w = anns[0]['segmentation'].shape - final_img = Image.fromarray( - np.zeros((h, w, 3), dtype=np.uint8), mode="RGB") + sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True) + h, w = anns[0]["segmentation"].shape + final_img = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8), mode="RGB") palette = ade_palette() for i, ann in enumerate(sorted_anns): - m = ann['segmentation'] + m = ann["segmentation"] img = np.empty((m.shape[0], m.shape[1], 3), dtype=np.uint8) # doing modulo just in case number of annotated regions exceeds number of colors in palette ann_color = palette[i % len(palette)] img[:, :] = ann_color - final_img.paste( - Image.fromarray(img, mode="RGB"), - (0, 0), - Image.fromarray(np.uint8(m * 255))) + final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m * 255))) return np.array(final_img, dtype=np.uint8) diff --git a/invokeai/app/invocations/cv.py b/invokeai/app/invocations/cv.py index cd7eaebeec..bd3a4adbe4 100644 --- a/invokeai/app/invocations/cv.py +++ b/invokeai/app/invocations/cv.py @@ -37,10 +37,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig): class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "OpenCV Inpaint", - "tags": ["opencv", "inpaint"] - }, + "ui": {"title": "OpenCV Inpaint", "tags": ["opencv", "inpaint"]}, } def invoke(self, context: InvocationContext) -> ImageOutput: diff --git a/invokeai/app/invocations/generate.py b/invokeai/app/invocations/generate.py index 910a7edf8b..d48c9f922e 100644 --- a/invokeai/app/invocations/generate.py +++ b/invokeai/app/invocations/generate.py @@ -6,8 +6,7 @@ from typing import Literal, Optional, get_args import torch from pydantic import Field -from invokeai.app.models.image import (ColorField, ImageCategory, ImageField, - ResourceOrigin) +from invokeai.app.models.image import ColorField, ImageCategory, ImageField, ResourceOrigin from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.backend.generator.inpaint import infill_methods @@ -25,13 +24,12 @@ from contextlib import contextmanager, ExitStack, ContextDecorator SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())] INFILL_METHODS = Literal[tuple(infill_methods())] -DEFAULT_INFILL_METHOD = ( - "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile" -) +DEFAULT_INFILL_METHOD = "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile" from .latent import get_scheduler + class OldModelContext(ContextDecorator): model: StableDiffusionGeneratorPipeline @@ -44,6 +42,7 @@ class OldModelContext(ContextDecorator): def __exit__(self, *exc): return False + class OldModelInfo: name: str hash: str @@ -64,20 +63,34 @@ class InpaintInvocation(BaseInvocation): positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation") negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation") - seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed) - steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image") - width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", ) - height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", ) - cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", ) - scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" ) + seed: int = Field( + ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed + ) + steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image") + width: int = Field( + default=512, + multiple_of=8, + gt=0, + description="The width of the resulting image", + ) + height: int = Field( + default=512, + multiple_of=8, + gt=0, + description="The height of the resulting image", + ) + cfg_scale: float = Field( + default=7.5, + ge=1, + description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", + ) + scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use") unet: UNetField = Field(default=None, description="UNet model") vae: VaeField = Field(default=None, description="Vae model") # Inputs image: Optional[ImageField] = Field(description="The input image") - strength: float = Field( - default=0.75, gt=0, le=1, description="The strength of the original image" - ) + strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the original image") fit: bool = Field( default=True, description="Whether or not the result should be fit to the aspect ratio of the input image", @@ -86,18 +99,10 @@ class InpaintInvocation(BaseInvocation): # Inputs mask: Optional[ImageField] = Field(description="The mask") seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)") - seam_blur: int = Field( - default=16, ge=0, description="The seam inpaint blur radius (px)" - ) - seam_strength: float = Field( - default=0.75, gt=0, le=1, description="The seam inpaint strength" - ) - seam_steps: int = Field( - default=30, ge=1, description="The number of steps to use for seam inpaint" - ) - tile_size: int = Field( - default=32, ge=1, description="The tile infill method size (px)" - ) + seam_blur: int = Field(default=16, ge=0, description="The seam inpaint blur radius (px)") + seam_strength: float = Field(default=0.75, gt=0, le=1, description="The seam inpaint strength") + seam_steps: int = Field(default=30, ge=1, description="The number of steps to use for seam inpaint") + tile_size: int = Field(default=32, ge=1, description="The tile infill method size (px)") infill_method: INFILL_METHODS = Field( default=DEFAULT_INFILL_METHOD, description="The method used to infill empty regions (px)", @@ -128,10 +133,7 @@ class InpaintInvocation(BaseInvocation): # Schema customisation class Config(InvocationConfig): schema_extra = { - "ui": { - "tags": ["stable-diffusion", "image"], - "title": "Inpaint" - }, + "ui": {"tags": ["stable-diffusion", "image"], "title": "Inpaint"}, } def dispatch_progress( @@ -162,18 +164,23 @@ class InpaintInvocation(BaseInvocation): def _lora_loader(): for lora in self.unet.loras: lora_info = context.services.model_manager.get_model( - **lora.dict(exclude={"weight"}), context=context,) + **lora.dict(exclude={"weight"}), + context=context, + ) yield (lora_info.context.model, lora.weight) del lora_info return - - unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context,) - vae_info = context.services.model_manager.get_model(**self.vae.vae.dict(), context=context,) - with vae_info as vae,\ - ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ - unet_info as unet: + unet_info = context.services.model_manager.get_model( + **self.unet.unet.dict(), + context=context, + ) + vae_info = context.services.model_manager.get_model( + **self.vae.vae.dict(), + context=context, + ) + with vae_info as vae, ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), unet_info as unet: device = context.services.model_manager.mgr.cache.execution_device dtype = context.services.model_manager.mgr.cache.precision @@ -197,21 +204,11 @@ class InpaintInvocation(BaseInvocation): ) def invoke(self, context: InvocationContext) -> ImageOutput: - image = ( - None - if self.image is None - else context.services.images.get_pil_image(self.image.image_name) - ) - mask = ( - None - if self.mask is None - else context.services.images.get_pil_image(self.mask.image_name) - ) + image = None if self.image is None else context.services.images.get_pil_image(self.image.image_name) + mask = None if self.mask is None else context.services.images.get_pil_image(self.mask.image_name) # Get the source node id (we are invoking the prepared node) - graph_execution_state = context.services.graph_execution_manager.get( - context.graph_execution_state_id - ) + graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) source_node_id = graph_execution_state.prepared_source_mapping[self.id] scheduler = get_scheduler( diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 26ee9e391d..3f40ea3cbe 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -9,9 +9,13 @@ from pathlib import Path from typing import Union from invokeai.app.invocations.metadata import CoreMetadata from ..models.image import ( - ImageCategory, ImageField, ResourceOrigin, - PILInvocationConfig, ImageOutput, MaskOutput, -) + ImageCategory, + ImageField, + ResourceOrigin, + PILInvocationConfig, + ImageOutput, + MaskOutput, +) from .baseinvocation import ( BaseInvocation, InvocationContext, @@ -20,6 +24,7 @@ from .baseinvocation import ( from invokeai.backend.image_util.safety_checker import SafetyChecker from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark + class LoadImageInvocation(BaseInvocation): """Load an image and provide it as output.""" @@ -34,10 +39,7 @@ class LoadImageInvocation(BaseInvocation): class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Load Image", - "tags": ["image", "load"] - }, + "ui": {"title": "Load Image", "tags": ["image", "load"]}, } def invoke(self, context: InvocationContext) -> ImageOutput: @@ -56,16 +58,11 @@ class ShowImageInvocation(BaseInvocation): type: Literal["show_image"] = "show_image" # Inputs - image: Optional[ImageField] = Field( - default=None, description="The image to show" - ) + image: Optional[ImageField] = Field(default=None, description="The image to show") class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Show Image", - "tags": ["image", "show"] - }, + "ui": {"title": "Show Image", "tags": ["image", "show"]}, } def invoke(self, context: InvocationContext) -> ImageOutput: @@ -98,18 +95,13 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig): class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Crop Image", - "tags": ["image", "crop"] - }, + "ui": {"title": "Crop Image", "tags": ["image", "crop"]}, } def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) - image_crop = Image.new( - mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0) - ) + image_crop = Image.new(mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0)) image_crop.paste(image, (-self.x, -self.y)) image_dto = context.services.images.create( @@ -144,21 +136,14 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig): class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Paste Image", - "tags": ["image", "paste"] - }, + "ui": {"title": "Paste Image", "tags": ["image", "paste"]}, } def invoke(self, context: InvocationContext) -> ImageOutput: base_image = context.services.images.get_pil_image(self.base_image.image_name) image = context.services.images.get_pil_image(self.image.image_name) mask = ( - None - if self.mask is None - else ImageOps.invert( - context.services.images.get_pil_image(self.mask.image_name) - ) + None if self.mask is None else ImageOps.invert(context.services.images.get_pil_image(self.mask.image_name)) ) # TODO: probably shouldn't invert mask here... should user be required to do it? @@ -167,9 +152,7 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig): max_x = max(base_image.width, image.width + self.x) max_y = max(base_image.height, image.height + self.y) - new_image = Image.new( - mode="RGBA", size=(max_x - min_x, max_y - min_y), color=(0, 0, 0, 0) - ) + new_image = Image.new(mode="RGBA", size=(max_x - min_x, max_y - min_y), color=(0, 0, 0, 0)) new_image.paste(base_image, (abs(min_x), abs(min_y))) new_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask) @@ -202,10 +185,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig): class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Mask From Alpha", - "tags": ["image", "mask", "alpha"] - }, + "ui": {"title": "Mask From Alpha", "tags": ["image", "mask", "alpha"]}, } def invoke(self, context: InvocationContext) -> MaskOutput: @@ -244,10 +224,7 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig): class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Multiply Images", - "tags": ["image", "multiply"] - }, + "ui": {"title": "Multiply Images", "tags": ["image", "multiply"]}, } def invoke(self, context: InvocationContext) -> ImageOutput: @@ -288,10 +265,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig): class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Image Channel", - "tags": ["image", "channel"] - }, + "ui": {"title": "Image Channel", "tags": ["image", "channel"]}, } def invoke(self, context: InvocationContext) -> ImageOutput: @@ -331,10 +305,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig): class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Convert Image", - "tags": ["image", "convert"] - }, + "ui": {"title": "Convert Image", "tags": ["image", "convert"]}, } def invoke(self, context: InvocationContext) -> ImageOutput: @@ -357,6 +328,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig): height=image_dto.height, ) + class ImageBlurInvocation(BaseInvocation, PILInvocationConfig): """Blurs an image""" @@ -371,19 +343,14 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig): class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Blur Image", - "tags": ["image", "blur"] - }, + "ui": {"title": "Blur Image", "tags": ["image", "blur"]}, } def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) blur = ( - ImageFilter.GaussianBlur(self.radius) - if self.blur_type == "gaussian" - else ImageFilter.BoxBlur(self.radius) + ImageFilter.GaussianBlur(self.radius) if self.blur_type == "gaussian" else ImageFilter.BoxBlur(self.radius) ) blur_image = image.filter(blur) @@ -438,10 +405,7 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig): class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Resize Image", - "tags": ["image", "resize"] - }, + "ui": {"title": "Resize Image", "tags": ["image", "resize"]}, } def invoke(self, context: InvocationContext) -> ImageOutput: @@ -484,10 +448,7 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig): class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Scale Image", - "tags": ["image", "scale"] - }, + "ui": {"title": "Scale Image", "tags": ["image", "scale"]}, } def invoke(self, context: InvocationContext) -> ImageOutput: @@ -532,10 +493,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig): class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Image Linear Interpolation", - "tags": ["image", "linear", "interpolation", "lerp"] - }, + "ui": {"title": "Image Linear Interpolation", "tags": ["image", "linear", "interpolation", "lerp"]}, } def invoke(self, context: InvocationContext) -> ImageOutput: @@ -561,6 +519,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig): height=image_dto.height, ) + class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig): """Inverse linear interpolation of all pixels of an image""" @@ -577,7 +536,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig): schema_extra = { "ui": { "title": "Image Inverse Linear Interpolation", - "tags": ["image", "linear", "interpolation", "inverse"] + "tags": ["image", "linear", "interpolation", "inverse"], }, } @@ -585,12 +544,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig): image = context.services.images.get_pil_image(self.image.image_name) image_arr = numpy.asarray(image, dtype=numpy.float32) - image_arr = ( - numpy.minimum( - numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1 - ) - * 255 - ) + image_arr = numpy.minimum(numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1) * 255 ilerp_image = Image.fromarray(numpy.uint8(image_arr)) @@ -609,6 +563,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig): height=image_dto.height, ) + class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig): """Add blur to NSFW-flagged images""" @@ -622,22 +577,19 @@ class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig): class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Blur NSFW Images", - "tags": ["image", "nsfw", "checker"] - }, + "ui": {"title": "Blur NSFW Images", "tags": ["image", "nsfw", "checker"]}, } def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) - + logger = context.services.logger logger.debug("Running NSFW checker") if SafetyChecker.has_nsfw_concept(image): logger.info("A potentially NSFW image has been detected. Image will be blurred.") blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32)) caution = self._get_caution_img() - blurry_image.paste(caution,(0,0),caution) + blurry_image.paste(caution, (0, 0), caution) image = blurry_image image_dto = context.services.images.create( @@ -649,20 +601,22 @@ class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig): is_intermediate=self.is_intermediate, metadata=self.metadata.dict() if self.metadata else None, ) - + return ImageOutput( image=ImageField(image_name=image_dto.image_name), width=image_dto.width, height=image_dto.height, ) - - def _get_caution_img(self)->Image: + + def _get_caution_img(self) -> Image: import invokeai.app.assets.images as image_assets - caution = Image.open(Path(image_assets.__path__[0]) / 'caution.png') - return caution.resize((caution.width // 2, caution.height //2)) + + caution = Image.open(Path(image_assets.__path__[0]) / "caution.png") + return caution.resize((caution.width // 2, caution.height // 2)) + class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig): - """ Add an invisible watermark to an image """ + """Add an invisible watermark to an image""" # fmt: off type: Literal["img_watermark"] = "img_watermark" @@ -675,10 +629,7 @@ class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig): class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Add Invisible Watermark", - "tags": ["image", "watermark", "invisible"] - }, + "ui": {"title": "Add Invisible Watermark", "tags": ["image", "watermark", "invisible"]}, } def invoke(self, context: InvocationContext) -> ImageOutput: @@ -699,6 +650,3 @@ class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig): width=image_dto.width, height=image_dto.height, ) - - - diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index ff8900fc50..cd5b2f9a11 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -30,9 +30,7 @@ def infill_methods() -> list[str]: INFILL_METHODS = Literal[tuple(infill_methods())] -DEFAULT_INFILL_METHOD = ( - "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile" -) +DEFAULT_INFILL_METHOD = "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile" def infill_patchmatch(im: Image.Image) -> Image.Image: @@ -44,9 +42,7 @@ def infill_patchmatch(im: Image.Image) -> Image.Image: return im # Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though) - im_patched_np = PatchMatch.inpaint( - im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3 - ) + im_patched_np = PatchMatch.inpaint(im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3) im_patched = Image.fromarray(im_patched_np, mode="RGB") return im_patched @@ -68,9 +64,7 @@ def get_tile_images(image: np.ndarray, width=8, height=8): ) -def tile_fill_missing( - im: Image.Image, tile_size: int = 16, seed: Optional[int] = None -) -> Image.Image: +def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int] = None) -> Image.Image: # Only fill if there's an alpha layer if im.mode != "RGBA": return im @@ -103,9 +97,7 @@ def tile_fill_missing( # Find all invalid tiles and replace with a random valid tile replace_count = (tiles_mask == False).sum() rng = np.random.default_rng(seed=seed) - tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[ - rng.choice(filtered_tiles.shape[0], replace_count), :, :, : - ] + tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count), :, :, :] # Convert back to an image tiles_all = tiles_all.reshape(tshape) @@ -126,9 +118,7 @@ class InfillColorInvocation(BaseInvocation): """Infills transparent areas of an image with a solid color""" type: Literal["infill_rgba"] = "infill_rgba" - image: Optional[ImageField] = Field( - default=None, description="The image to infill" - ) + image: Optional[ImageField] = Field(default=None, description="The image to infill") color: ColorField = Field( default=ColorField(r=127, g=127, b=127, a=255), description="The color to use to infill", @@ -136,10 +126,7 @@ class InfillColorInvocation(BaseInvocation): class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Color Infill", - "tags": ["image", "inpaint", "color", "infill"] - }, + "ui": {"title": "Color Infill", "tags": ["image", "inpaint", "color", "infill"]}, } def invoke(self, context: InvocationContext) -> ImageOutput: @@ -171,9 +158,7 @@ class InfillTileInvocation(BaseInvocation): type: Literal["infill_tile"] = "infill_tile" - image: Optional[ImageField] = Field( - default=None, description="The image to infill" - ) + image: Optional[ImageField] = Field(default=None, description="The image to infill") tile_size: int = Field(default=32, ge=1, description="The tile size (px)") seed: int = Field( ge=0, @@ -184,18 +169,13 @@ class InfillTileInvocation(BaseInvocation): class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Tile Infill", - "tags": ["image", "inpaint", "tile", "infill"] - }, + "ui": {"title": "Tile Infill", "tags": ["image", "inpaint", "tile", "infill"]}, } def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) - infilled = tile_fill_missing( - image.copy(), seed=self.seed, tile_size=self.tile_size - ) + infilled = tile_fill_missing(image.copy(), seed=self.seed, tile_size=self.tile_size) infilled.paste(image, (0, 0), image.split()[-1]) image_dto = context.services.images.create( @@ -219,16 +199,11 @@ class InfillPatchMatchInvocation(BaseInvocation): type: Literal["infill_patchmatch"] = "infill_patchmatch" - image: Optional[ImageField] = Field( - default=None, description="The image to infill" - ) + image: Optional[ImageField] = Field(default=None, description="The image to infill") class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Patch Match Infill", - "tags": ["image", "inpaint", "patchmatch", "infill"] - }, + "ui": {"title": "Patch Match Infill", "tags": ["image", "inpaint", "patchmatch", "infill"]}, } def invoke(self, context: InvocationContext) -> ImageOutput: diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 429f2e539c..f51b91c2d6 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -17,16 +17,17 @@ from invokeai.backend.model_management.models.base import ModelType from ...backend.model_management.lora import ModelPatcher from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion.diffusers_pipeline import ( - ConditioningData, ControlNetData, StableDiffusionGeneratorPipeline, - image_resized_to_grid_as_tensor) -from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \ - PostprocessingSettings + ConditioningData, + ControlNetData, + StableDiffusionGeneratorPipeline, + image_resized_to_grid_as_tensor, +) +from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP from ...backend.model_management import ModelPatcher from ...backend.util.devices import choose_torch_device, torch_dtype, choose_precision from ..models.image import ImageCategory, ImageField, ResourceOrigin -from .baseinvocation import (BaseInvocation, BaseInvocationOutput, - InvocationConfig, InvocationContext) +from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext from .compel import ConditioningField from .controlnet_image_processors import ControlField from .image import ImageOutput @@ -47,8 +48,7 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device()) class LatentsField(BaseModel): """A latents field used for passing latents between invocations""" - latents_name: Optional[str] = Field( - default=None, description="The name of the latents") + latents_name: Optional[str] = Field(default=None, description="The name of the latents") class Config: schema_extra = {"required": ["latents_name"]} @@ -56,14 +56,15 @@ class LatentsField(BaseModel): class LatentsOutput(BaseInvocationOutput): """Base class for invocations that output latents""" - #fmt: off + + # fmt: off type: Literal["latents_output"] = "latents_output" # Inputs latents: LatentsField = Field(default=None, description="The output latents") width: int = Field(description="The width of the latents in pixels") height: int = Field(description="The height of the latents in pixels") - #fmt: on + # fmt: on def build_latents_output(latents_name: str, latents: torch.Tensor): @@ -74,9 +75,7 @@ def build_latents_output(latents_name: str, latents: torch.Tensor): ) -SAMPLER_NAME_VALUES = Literal[ - tuple(list(SCHEDULER_MAP.keys())) -] +SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))] def get_scheduler( @@ -84,11 +83,10 @@ def get_scheduler( scheduler_info: ModelInfo, scheduler_name: str, ) -> Scheduler: - scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get( - scheduler_name, SCHEDULER_MAP['ddim'] - ) + scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"]) orig_scheduler_info = context.services.model_manager.get_model( - **scheduler_info.dict(), context=context, + **scheduler_info.dict(), + context=context, ) with orig_scheduler_info as orig_scheduler: scheduler_config = orig_scheduler.config @@ -103,7 +101,7 @@ def get_scheduler( scheduler = scheduler_class.from_config(scheduler_config) # hack copied over from generate.py - if not hasattr(scheduler, 'uses_inpainting_model'): + if not hasattr(scheduler, "uses_inpainting_model"): scheduler.uses_inpainting_model = lambda: False return scheduler @@ -124,8 +122,8 @@ class TextToLatentsInvocation(BaseInvocation): scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" ) unet: UNetField = Field(default=None, description="UNet submodel") control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use") - #seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", ) - #seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'") + # seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", ) + # seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'") # fmt: on @validator("cfg_scale") @@ -134,10 +132,10 @@ class TextToLatentsInvocation(BaseInvocation): if isinstance(v, list): for i in v: if i < 1: - raise ValueError('cfg_scale must be greater than 1') + raise ValueError("cfg_scale must be greater than 1") else: if v < 1: - raise ValueError('cfg_scale must be greater than 1') + raise ValueError("cfg_scale must be greater than 1") return v # Schema customisation @@ -150,8 +148,8 @@ class TextToLatentsInvocation(BaseInvocation): "model": "model", "control": "control", # "cfg_scale": "float", - "cfg_scale": "number" - } + "cfg_scale": "number", + }, }, } @@ -191,16 +189,14 @@ class TextToLatentsInvocation(BaseInvocation): threshold=0.0, # threshold, warmup=0.2, # warmup, h_symmetry_time_pct=None, # h_symmetry_time_pct, - v_symmetry_time_pct=None # v_symmetry_time_pct, + v_symmetry_time_pct=None, # v_symmetry_time_pct, ), ) conditioning_data = conditioning_data.add_scheduler_args_if_applicable( scheduler, - # for ddim scheduler eta=0.0, # ddim_eta - # for ancestral and sde schedulers generator=torch.Generator(device=unet.device).manual_seed(0), ) @@ -248,7 +244,6 @@ class TextToLatentsInvocation(BaseInvocation): exit_stack: ExitStack, do_classifier_free_guidance: bool = True, ) -> List[ControlNetData]: - # assuming fixed dimensional scaling of 8:1 for image:latents control_height_resize = latents_shape[2] * 8 control_width_resize = latents_shape[3] * 8 @@ -262,7 +257,7 @@ class TextToLatentsInvocation(BaseInvocation): control_list = control_input else: control_list = None - if (control_list is None): + if control_list is None: control_data = None # from above handling, any control that is not None should now be of type list[ControlField] else: @@ -282,9 +277,7 @@ class TextToLatentsInvocation(BaseInvocation): control_models.append(control_model) control_image_field = control_info.image - input_image = context.services.images.get_pil_image( - control_image_field.image_name - ) + input_image = context.services.images.get_pil_image(control_image_field.image_name) # self.image.image_type, self.image.image_name # FIXME: still need to test with different widths, heights, devices, dtypes # and add in batch_size, num_images_per_prompt? @@ -322,9 +315,7 @@ class TextToLatentsInvocation(BaseInvocation): noise = context.services.latents.get(self.noise.latents_name) # Get the source node id (we are invoking the prepared node) - graph_execution_state = context.services.graph_execution_manager.get( - context.graph_execution_state_id - ) + graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) source_node_id = graph_execution_state.prepared_source_mapping[self.id] def step_callback(state: PipelineIntermediateState): @@ -333,19 +324,20 @@ class TextToLatentsInvocation(BaseInvocation): def _lora_loader(): for lora in self.unet.loras: lora_info = context.services.model_manager.get_model( - **lora.dict(exclude={"weight"}), context=context, + **lora.dict(exclude={"weight"}), + context=context, ) yield (lora_info.context.model, lora.weight) del lora_info return unet_info = context.services.model_manager.get_model( - **self.unet.unet.dict(), context=context, + **self.unet.unet.dict(), + context=context, ) - with ExitStack() as exit_stack,\ - ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ - unet_info as unet: - + with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet( + unet_info.context.model, _lora_loader() + ), unet_info as unet: noise = noise.to(device=unet.device, dtype=unet.dtype) scheduler = get_scheduler( @@ -358,7 +350,9 @@ class TextToLatentsInvocation(BaseInvocation): conditioning_data = self.get_conditioning_data(context, scheduler, unet) control_data = self.prep_control_data( - model=pipeline, context=context, control_input=self.control, + model=pipeline, + context=context, + control_input=self.control, latents_shape=noise.shape, # do_classifier_free_guidance=(self.cfg_scale >= 1.0)) do_classifier_free_guidance=True, @@ -379,7 +373,7 @@ class TextToLatentsInvocation(BaseInvocation): result_latents = result_latents.to("cpu") torch.cuda.empty_cache() - name = f'{context.graph_execution_state_id}__{self.id}' + name = f"{context.graph_execution_state_id}__{self.id}" context.services.latents.save(name, result_latents) return build_latents_output(latents_name=name, latents=result_latents) @@ -390,11 +384,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): type: Literal["l2l"] = "l2l" # Inputs - latents: Optional[LatentsField] = Field( - description="The latents to use as a base image") - strength: float = Field( - default=0.7, ge=0, le=1, - description="The strength of the latents to use") + latents: Optional[LatentsField] = Field(description="The latents to use as a base image") + strength: float = Field(default=0.7, ge=0, le=1, description="The strength of the latents to use") # Schema customisation class Config(InvocationConfig): @@ -406,7 +397,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): "model": "model", "control": "control", "cfg_scale": "number", - } + }, }, } @@ -416,9 +407,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): latent = context.services.latents.get(self.latents.latents_name) # Get the source node id (we are invoking the prepared node) - graph_execution_state = context.services.graph_execution_manager.get( - context.graph_execution_state_id - ) + graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) source_node_id = graph_execution_state.prepared_source_mapping[self.id] def step_callback(state: PipelineIntermediateState): @@ -427,19 +416,20 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): def _lora_loader(): for lora in self.unet.loras: lora_info = context.services.model_manager.get_model( - **lora.dict(exclude={"weight"}), context=context, + **lora.dict(exclude={"weight"}), + context=context, ) yield (lora_info.context.model, lora.weight) del lora_info return unet_info = context.services.model_manager.get_model( - **self.unet.unet.dict(), context=context, + **self.unet.unet.dict(), + context=context, ) - with ExitStack() as exit_stack,\ - ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ - unet_info as unet: - + with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet( + unet_info.context.model, _lora_loader() + ), unet_info as unet: noise = noise.to(device=unet.device, dtype=unet.dtype) latent = latent.to(device=unet.device, dtype=unet.dtype) @@ -453,7 +443,9 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): conditioning_data = self.get_conditioning_data(context, scheduler, unet) control_data = self.prep_control_data( - model=pipeline, context=context, control_input=self.control, + model=pipeline, + context=context, + control_input=self.control, latents_shape=noise.shape, # do_classifier_free_guidance=(self.cfg_scale >= 1.0)) do_classifier_free_guidance=True, @@ -461,8 +453,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): ) # TODO: Verify the noise is the right size - initial_latents = latent if self.strength < 1.0 else torch.zeros_like( - latent, device=unet.device, dtype=latent.dtype + initial_latents = ( + latent if self.strength < 1.0 else torch.zeros_like(latent, device=unet.device, dtype=latent.dtype) ) timesteps, _ = pipeline.get_img2img_timesteps( @@ -478,14 +470,14 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): num_inference_steps=self.steps, conditioning_data=conditioning_data, control_data=control_data, # list[ControlNetData] - callback=step_callback + callback=step_callback, ) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 result_latents = result_latents.to("cpu") torch.cuda.empty_cache() - name = f'{context.graph_execution_state_id}__{self.id}' + name = f"{context.graph_execution_state_id}__{self.id}" context.services.latents.save(name, result_latents) return build_latents_output(latents_name=name, latents=result_latents) @@ -497,14 +489,13 @@ class LatentsToImageInvocation(BaseInvocation): type: Literal["l2i"] = "l2i" # Inputs - latents: Optional[LatentsField] = Field( - description="The latents to generate an image from") + latents: Optional[LatentsField] = Field(description="The latents to generate an image from") vae: VaeField = Field(default=None, description="Vae submodel") - tiled: bool = Field( - default=False, - description="Decode latents by overlapping tiles(less memory consumption)") - fp32: bool = Field(DEFAULT_PRECISION=='float32', description="Decode in full precision") - metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image") + tiled: bool = Field(default=False, description="Decode latents by overlapping tiles(less memory consumption)") + fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision") + metadata: Optional[CoreMetadata] = Field( + default=None, description="Optional core metadata to be written to the image" + ) # Schema customisation class Config(InvocationConfig): @@ -520,7 +511,8 @@ class LatentsToImageInvocation(BaseInvocation): latents = context.services.latents.get(self.latents.latents_name) vae_info = context.services.model_manager.get_model( - **self.vae.vae.dict(), context=context, + **self.vae.vae.dict(), + context=context, ) with vae_info as vae: @@ -587,8 +579,7 @@ class LatentsToImageInvocation(BaseInvocation): ) -LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", - "bilinear", "bicubic", "trilinear", "area", "nearest-exact"] +LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"] class ResizeLatentsInvocation(BaseInvocation): @@ -597,36 +588,30 @@ class ResizeLatentsInvocation(BaseInvocation): type: Literal["lresize"] = "lresize" # Inputs - latents: Optional[LatentsField] = Field( - description="The latents to resize") - width: Union[int, None] = Field(default=512, - ge=64, multiple_of=8, description="The width to resize to (px)") - height: Union[int, None] = Field(default=512, - ge=64, multiple_of=8, description="The height to resize to (px)") - mode: LATENTS_INTERPOLATION_MODE = Field( - default="bilinear", description="The interpolation mode") + latents: Optional[LatentsField] = Field(description="The latents to resize") + width: Union[int, None] = Field(default=512, ge=64, multiple_of=8, description="The width to resize to (px)") + height: Union[int, None] = Field(default=512, ge=64, multiple_of=8, description="The height to resize to (px)") + mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode") antialias: bool = Field( - default=False, - description="Whether or not to antialias (applied in bilinear and bicubic modes only)") + default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)" + ) class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Resize Latents", - "tags": ["latents", "resize"] - }, + "ui": {"title": "Resize Latents", "tags": ["latents", "resize"]}, } def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.services.latents.get(self.latents.latents_name) # TODO: - device=choose_torch_device() + device = choose_torch_device() resized_latents = torch.nn.functional.interpolate( - latents.to(device), size=(self.height // 8, self.width // 8), - mode=self.mode, antialias=self.antialias - if self.mode in ["bilinear", "bicubic"] else False, + latents.to(device), + size=(self.height // 8, self.width // 8), + mode=self.mode, + antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False, ) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 @@ -645,35 +630,30 @@ class ScaleLatentsInvocation(BaseInvocation): type: Literal["lscale"] = "lscale" # Inputs - latents: Optional[LatentsField] = Field( - description="The latents to scale") - scale_factor: float = Field( - gt=0, description="The factor by which to scale the latents") - mode: LATENTS_INTERPOLATION_MODE = Field( - default="bilinear", description="The interpolation mode") + latents: Optional[LatentsField] = Field(description="The latents to scale") + scale_factor: float = Field(gt=0, description="The factor by which to scale the latents") + mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode") antialias: bool = Field( - default=False, - description="Whether or not to antialias (applied in bilinear and bicubic modes only)") + default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)" + ) class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Scale Latents", - "tags": ["latents", "scale"] - }, + "ui": {"title": "Scale Latents", "tags": ["latents", "scale"]}, } def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.services.latents.get(self.latents.latents_name) # TODO: - device=choose_torch_device() + device = choose_torch_device() # resizing resized_latents = torch.nn.functional.interpolate( - latents.to(device), scale_factor=self.scale_factor, mode=self.mode, - antialias=self.antialias - if self.mode in ["bilinear", "bicubic"] else False, + latents.to(device), + scale_factor=self.scale_factor, + mode=self.mode, + antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False, ) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 @@ -694,19 +674,13 @@ class ImageToLatentsInvocation(BaseInvocation): # Inputs image: Optional[ImageField] = Field(description="The image to encode") vae: VaeField = Field(default=None, description="Vae submodel") - tiled: bool = Field( - default=False, - description="Encode latents by overlaping tiles(less memory consumption)") - fp32: bool = Field(DEFAULT_PRECISION=='float32', description="Decode in full precision") - + tiled: bool = Field(default=False, description="Encode latents by overlaping tiles(less memory consumption)") + fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision") # Schema customisation class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Image To Latents", - "tags": ["latents", "image"] - }, + "ui": {"title": "Image To Latents", "tags": ["latents", "image"]}, } @torch.no_grad() @@ -716,9 +690,10 @@ class ImageToLatentsInvocation(BaseInvocation): # ) image = context.services.images.get_pil_image(self.image.image_name) - #vae_info = context.services.model_manager.get_model(**self.vae.vae.dict()) + # vae_info = context.services.model_manager.get_model(**self.vae.vae.dict()) vae_info = context.services.model_manager.get_model( - **self.vae.vae.dict(), context=context, + **self.vae.vae.dict(), + context=context, ) image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB")) @@ -745,12 +720,12 @@ class ImageToLatentsInvocation(BaseInvocation): vae.post_quant_conv.to(orig_dtype) vae.decoder.conv_in.to(orig_dtype) vae.decoder.mid_block.to(orig_dtype) - #else: + # else: # latents = latents.float() else: vae.to(dtype=torch.float16) - #latents = latents.half() + # latents = latents.half() if self.tiled: vae.enable_tiling() @@ -761,9 +736,7 @@ class ImageToLatentsInvocation(BaseInvocation): image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype) with torch.inference_mode(): image_tensor_dist = vae.encode(image_tensor).latent_dist - latents = image_tensor_dist.sample().to( - dtype=vae.dtype - ) # FIXME: uses torch.randn. make reproducible! + latents = image_tensor_dist.sample().to(dtype=vae.dtype) # FIXME: uses torch.randn. make reproducible! latents = vae.config.scaling_factor * latents latents = latents.to(dtype=orig_dtype) diff --git a/invokeai/app/invocations/math.py b/invokeai/app/invocations/math.py index 92cff04bf7..32b1ab2a39 100644 --- a/invokeai/app/invocations/math.py +++ b/invokeai/app/invocations/math.py @@ -54,10 +54,7 @@ class AddInvocation(BaseInvocation, MathInvocationConfig): class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Add", - "tags": ["math", "add"] - }, + "ui": {"title": "Add", "tags": ["math", "add"]}, } def invoke(self, context: InvocationContext) -> IntOutput: @@ -75,10 +72,7 @@ class SubtractInvocation(BaseInvocation, MathInvocationConfig): class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Subtract", - "tags": ["math", "subtract"] - }, + "ui": {"title": "Subtract", "tags": ["math", "subtract"]}, } def invoke(self, context: InvocationContext) -> IntOutput: @@ -96,10 +90,7 @@ class MultiplyInvocation(BaseInvocation, MathInvocationConfig): class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Multiply", - "tags": ["math", "multiply"] - }, + "ui": {"title": "Multiply", "tags": ["math", "multiply"]}, } def invoke(self, context: InvocationContext) -> IntOutput: @@ -117,10 +108,7 @@ class DivideInvocation(BaseInvocation, MathInvocationConfig): class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Divide", - "tags": ["math", "divide"] - }, + "ui": {"title": "Divide", "tags": ["math", "divide"]}, } def invoke(self, context: InvocationContext) -> IntOutput: @@ -140,10 +128,7 @@ class RandomIntInvocation(BaseInvocation): class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Random Integer", - "tags": ["math", "random", "integer"] - }, + "ui": {"title": "Random Integer", "tags": ["math", "random", "integer"]}, } def invoke(self, context: InvocationContext) -> IntOutput: diff --git a/invokeai/app/invocations/metadata.py b/invokeai/app/invocations/metadata.py index b0742d0419..3588ef4ebe 100644 --- a/invokeai/app/invocations/metadata.py +++ b/invokeai/app/invocations/metadata.py @@ -11,6 +11,7 @@ from invokeai.app.invocations.baseinvocation import ( from invokeai.app.invocations.controlnet_image_processors import ControlField from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField + class LoRAMetadataField(BaseModel): """LoRA metadata for an image generated in InvokeAI.""" @@ -37,9 +38,7 @@ class CoreMetadata(BaseModel): description="The number of skipped CLIP layers", ) model: MainModelField = Field(description="The main model used for inference") - controlnets: list[ControlField] = Field( - description="The ControlNets used for inference" - ) + controlnets: list[ControlField] = Field(description="The ControlNets used for inference") loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference") vae: Union[VAEModelField, None] = Field( default=None, @@ -51,38 +50,24 @@ class CoreMetadata(BaseModel): default=None, description="The strength used for latents-to-latents", ) - init_image: Union[str, None] = Field( - default=None, description="The name of the initial image" - ) + init_image: Union[str, None] = Field(default=None, description="The name of the initial image") # SDXL - positive_style_prompt: Union[str, None] = Field( - default=None, description="The positive style prompt parameter" - ) - negative_style_prompt: Union[str, None] = Field( - default=None, description="The negative style prompt parameter" - ) + positive_style_prompt: Union[str, None] = Field(default=None, description="The positive style prompt parameter") + negative_style_prompt: Union[str, None] = Field(default=None, description="The negative style prompt parameter") # SDXL Refiner - refiner_model: Union[MainModelField, None] = Field( - default=None, description="The SDXL Refiner model used" - ) + refiner_model: Union[MainModelField, None] = Field(default=None, description="The SDXL Refiner model used") refiner_cfg_scale: Union[float, None] = Field( default=None, description="The classifier-free guidance scale parameter used for the refiner", ) - refiner_steps: Union[int, None] = Field( - default=None, description="The number of steps used for the refiner" - ) - refiner_scheduler: Union[str, None] = Field( - default=None, description="The scheduler used for the refiner" - ) + refiner_steps: Union[int, None] = Field(default=None, description="The number of steps used for the refiner") + refiner_scheduler: Union[str, None] = Field(default=None, description="The scheduler used for the refiner") refiner_aesthetic_store: Union[float, None] = Field( default=None, description="The aesthetic score used for the refiner" ) - refiner_start: Union[float, None] = Field( - default=None, description="The start value used for refiner denoising" - ) + refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising") class ImageMetadata(BaseModel): @@ -92,9 +77,7 @@ class ImageMetadata(BaseModel): default=None, description="The image's core metadata, if it was created in the Linear or Canvas UI", ) - graph: Optional[dict] = Field( - default=None, description="The graph that created the image" - ) + graph: Optional[dict] = Field(default=None, description="The graph that created the image") class MetadataAccumulatorOutput(BaseInvocationOutput): @@ -126,50 +109,34 @@ class MetadataAccumulatorInvocation(BaseInvocation): description="The number of skipped CLIP layers", ) model: MainModelField = Field(description="The main model used for inference") - controlnets: list[ControlField] = Field( - description="The ControlNets used for inference" - ) + controlnets: list[ControlField] = Field(description="The ControlNets used for inference") loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference") strength: Union[float, None] = Field( default=None, description="The strength used for latents-to-latents", ) - init_image: Union[str, None] = Field( - default=None, description="The name of the initial image" - ) + init_image: Union[str, None] = Field(default=None, description="The name of the initial image") vae: Union[VAEModelField, None] = Field( default=None, description="The VAE used for decoding, if the main model's default was not used", ) # SDXL - positive_style_prompt: Union[str, None] = Field( - default=None, description="The positive style prompt parameter" - ) - negative_style_prompt: Union[str, None] = Field( - default=None, description="The negative style prompt parameter" - ) + positive_style_prompt: Union[str, None] = Field(default=None, description="The positive style prompt parameter") + negative_style_prompt: Union[str, None] = Field(default=None, description="The negative style prompt parameter") # SDXL Refiner - refiner_model: Union[MainModelField, None] = Field( - default=None, description="The SDXL Refiner model used" - ) + refiner_model: Union[MainModelField, None] = Field(default=None, description="The SDXL Refiner model used") refiner_cfg_scale: Union[float, None] = Field( default=None, description="The classifier-free guidance scale parameter used for the refiner", ) - refiner_steps: Union[int, None] = Field( - default=None, description="The number of steps used for the refiner" - ) - refiner_scheduler: Union[str, None] = Field( - default=None, description="The scheduler used for the refiner" - ) + refiner_steps: Union[int, None] = Field(default=None, description="The number of steps used for the refiner") + refiner_scheduler: Union[str, None] = Field(default=None, description="The scheduler used for the refiner") refiner_aesthetic_store: Union[float, None] = Field( default=None, description="The aesthetic score used for the refiner" ) - refiner_start: Union[float, None] = Field( - default=None, description="The start value used for refiner denoising" - ) + refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising") class Config(InvocationConfig): schema_extra = { diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 125ec4461f..c19e5c5c9a 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -4,17 +4,14 @@ from typing import List, Literal, Optional, Union from pydantic import BaseModel, Field from ...backend.model_management import BaseModelType, ModelType, SubModelType -from .baseinvocation import (BaseInvocation, BaseInvocationOutput, - InvocationConfig, InvocationContext) +from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext class ModelInfo(BaseModel): model_name: str = Field(description="Info to load submodel") base_model: BaseModelType = Field(description="Base model") model_type: ModelType = Field(description="Info to load submodel") - submodel: Optional[SubModelType] = Field( - default=None, description="Info to load submodel" - ) + submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel") class LoraInfo(ModelInfo): @@ -33,6 +30,7 @@ class ClipField(BaseModel): skipped_layers: int = Field(description="Number of skipped layers in text_encoder") loras: List[LoraInfo] = Field(description="Loras to apply on model loading") + class VaeField(BaseModel): # TODO: better naming? vae: ModelInfo = Field(description="Info to load vae submodel") @@ -49,6 +47,7 @@ class ModelLoaderOutput(BaseInvocationOutput): vae: VaeField = Field(default=None, description="Vae submodel") # fmt: on + class MainModelField(BaseModel): """Main model field""" @@ -63,6 +62,7 @@ class LoRAModelField(BaseModel): model_name: str = Field(description="Name of the LoRA model") base_model: BaseModelType = Field(description="Base model") + class MainModelLoaderInvocation(BaseInvocation): """Loads a main model, outputting its submodels.""" @@ -181,7 +181,7 @@ class MainModelLoaderInvocation(BaseInvocation): ), ) - + class LoraLoaderOutput(BaseInvocationOutput): """Model loader output""" @@ -198,9 +198,7 @@ class LoraLoaderInvocation(BaseInvocation): type: Literal["lora_loader"] = "lora_loader" - lora: Union[LoRAModelField, None] = Field( - default=None, description="Lora model name" - ) + lora: Union[LoRAModelField, None] = Field(default=None, description="Lora model name") weight: float = Field(default=0.75, description="With what weight to apply lora") unet: Optional[UNetField] = Field(description="UNet model for applying lora") @@ -229,14 +227,10 @@ class LoraLoaderInvocation(BaseInvocation): ): raise Exception(f"Unkown lora name: {lora_name}!") - if self.unet is not None and any( - lora.model_name == lora_name for lora in self.unet.loras - ): + if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras): raise Exception(f'Lora "{lora_name}" already applied to unet') - if self.clip is not None and any( - lora.model_name == lora_name for lora in self.clip.loras - ): + if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras): raise Exception(f'Lora "{lora_name}" already applied to clip') output = LoraLoaderOutput() diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py index bff1fa7ecd..2bec128b87 100644 --- a/invokeai/app/invocations/onnx.py +++ b/invokeai/app/invocations/onnx.py @@ -16,8 +16,7 @@ from diffusers.schedulers import SchedulerMixin as Scheduler from ..models.image import ImageCategory, ImageField, ResourceOrigin from ...backend.model_management import ONNXModelPatcher from ...backend.util import choose_torch_device -from .baseinvocation import (BaseInvocation, BaseInvocationOutput, - InvocationConfig, InvocationContext) +from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext from .compel import ConditioningField from .controlnet_image_processors import ControlField from .image import ImageOutput @@ -49,9 +48,8 @@ ORT_TO_NP_TYPE = { "tensor(double)": np.float64, } -PRECISION_VALUES = Literal[ - tuple(list(ORT_TO_NP_TYPE.keys())) -] +PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))] + class ONNXPromptInvocation(BaseInvocation): type: Literal["prompt_onnx"] = "prompt_onnx" @@ -66,25 +64,25 @@ class ONNXPromptInvocation(BaseInvocation): text_encoder_info = context.services.model_manager.get_model( **self.clip.text_encoder.dict(), ) - with tokenizer_info as orig_tokenizer,\ - text_encoder_info as text_encoder,\ - ExitStack() as stack: - - #loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras] - loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] + with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder, ExitStack() as stack: + # loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras] + loras = [ + (context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) + for lora in self.clip.loras + ] ti_list = [] for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt): name = trigger[1:-1] try: ti_list.append( - #stack.enter_context( + # stack.enter_context( # context.services.model_manager.get_model( # model_name=name, # base_model=self.clip.text_encoder.base_model, # model_type=ModelType.TextualInversion, # ) - #) + # ) context.services.model_manager.get_model( model_name=name, base_model=self.clip.text_encoder.base_model, @@ -92,15 +90,15 @@ class ONNXPromptInvocation(BaseInvocation): ).context.model ) except Exception: - #print(e) - #import traceback - #print(traceback.format_exc()) - print(f"Warn: trigger: \"{trigger}\" not found") + # print(e) + # import traceback + # print(traceback.format_exc()) + print(f'Warn: trigger: "{trigger}" not found') if loras or ti_list: text_encoder.release_session() - with ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras),\ - ONNXModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager): - + with ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras), ONNXModelPatcher.apply_ti( + orig_tokenizer, text_encoder, ti_list + ) as (tokenizer, ti_manager): text_encoder.create_session() # copy from @@ -128,7 +126,6 @@ class ONNXPromptInvocation(BaseInvocation): prompt_embeds = text_encoder(input_ids=text_input_ids.astype(np.int32))[0] - conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" # TODO: hacky but works ;D maybe rename latents somehow? @@ -140,6 +137,7 @@ class ONNXPromptInvocation(BaseInvocation): ), ) + # Text to image class ONNXTextToLatentsInvocation(BaseInvocation): """Generates latents from conditionings.""" @@ -157,8 +155,8 @@ class ONNXTextToLatentsInvocation(BaseInvocation): precision: PRECISION_VALUES = Field(default = "tensor(float16)", description="The precision to use when generating latents") unet: UNetField = Field(default=None, description="UNet submodel") control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use") - #seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", ) - #seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'") + # seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", ) + # seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'") # fmt: on @validator("cfg_scale") @@ -167,10 +165,10 @@ class ONNXTextToLatentsInvocation(BaseInvocation): if isinstance(v, list): for i in v: if i < 1: - raise ValueError('cfg_scale must be greater than 1') + raise ValueError("cfg_scale must be greater than 1") else: if v < 1: - raise ValueError('cfg_scale must be greater than 1') + raise ValueError("cfg_scale must be greater than 1") return v # Schema customisation @@ -179,11 +177,11 @@ class ONNXTextToLatentsInvocation(BaseInvocation): "ui": { "tags": ["latents"], "type_hints": { - "model": "model", - "control": "control", - # "cfg_scale": "float", - "cfg_scale": "number" - } + "model": "model", + "control": "control", + # "cfg_scale": "float", + "cfg_scale": "number", + }, }, } @@ -192,8 +190,7 @@ class ONNXTextToLatentsInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> LatentsOutput: c, _ = context.services.latents.get(self.positive_conditioning.conditioning_name) uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name) - graph_execution_state = context.services.graph_execution_manager.get( - context.graph_execution_state_id) + graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) source_node_id = graph_execution_state.prepared_source_mapping[self.id] if isinstance(c, torch.Tensor): c = c.cpu().numpy() @@ -211,9 +208,9 @@ class ONNXTextToLatentsInvocation(BaseInvocation): # get the initial random noise unless the user supplied it do_classifier_free_guidance = True - #latents_dtype = prompt_embeds.dtype - #latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8) - #if latents.shape != latents_shape: + # latents_dtype = prompt_embeds.dtype + # latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8) + # if latents.shape != latents_shape: # raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") scheduler = get_scheduler( @@ -229,8 +226,8 @@ class ONNXTextToLatentsInvocation(BaseInvocation): return torch.from_numpy(latent).to(device) def dispatch_progress( - self, context: InvocationContext, source_node_id: str, - intermediate_state: PipelineIntermediateState) -> None: + self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState + ) -> None: stable_diffusion_step_callback( context=context, intermediate_state=intermediate_state, @@ -249,16 +246,17 @@ class ONNXTextToLatentsInvocation(BaseInvocation): unet_info = context.services.model_manager.get_model(**self.unet.unet.dict()) - with unet_info as unet,\ - ExitStack() as stack: - - #loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras] - loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras] + with unet_info as unet, ExitStack() as stack: + # loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras] + loras = [ + (context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) + for lora in self.unet.loras + ] if loras: unet.release_session() with ONNXModelPatcher.apply_lora_unet(unet, loras): - # TODO: + # TODO: _, _, h, w = latents.shape unet.create_session(h, w) @@ -290,28 +288,21 @@ class ONNXTextToLatentsInvocation(BaseInvocation): latents = torch2numpy(scheduler_output.prev_sample) state = PipelineIntermediateState( - run_id= "test", - step=i, - timestep=timestep, - latents=scheduler_output.prev_sample - ) - dispatch_progress( - self, - context=context, - source_node_id=source_node_id, - intermediate_state=state + run_id="test", step=i, timestep=timestep, latents=scheduler_output.prev_sample ) + dispatch_progress(self, context=context, source_node_id=source_node_id, intermediate_state=state) # call the callback, if provided - #if callback is not None and i % callback_steps == 0: + # if callback is not None and i % callback_steps == 0: # callback(i, t, latents) torch.cuda.empty_cache() - name = f'{context.graph_execution_state_id}__{self.id}' + name = f"{context.graph_execution_state_id}__{self.id}" context.services.latents.save(name, latents) return build_latents_output(latents_name=name, latents=torch.from_numpy(latents)) + # Latent to image class ONNXLatentsToImageInvocation(BaseInvocation): """Generates an image from latents.""" @@ -321,8 +312,10 @@ class ONNXLatentsToImageInvocation(BaseInvocation): # Inputs latents: Optional[LatentsField] = Field(description="The latents to generate an image from") vae: VaeField = Field(default=None, description="Vae submodel") - metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image") - #tiled: bool = Field(default=False, description="Decode latents by overlaping tiles(less memory consumption)") + metadata: Optional[CoreMetadata] = Field( + default=None, description="Optional core metadata to be written to the image" + ) + # tiled: bool = Field(default=False, description="Decode latents by overlaping tiles(less memory consumption)") # Schema customisation class Config(InvocationConfig): @@ -353,15 +346,12 @@ class ONNXLatentsToImageInvocation(BaseInvocation): latents = 1 / 0.18215 * latents # image = self.vae_decoder(latent_sample=latents)[0] # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 - image = np.concatenate( - [vae(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] - ) + image = np.concatenate([vae(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]) image = np.clip(image / 2 + 0.5, 0, 1) image = image.transpose((0, 2, 3, 1)) image = VaeImageProcessor.numpy_to_pil(image)[0] - torch.cuda.empty_cache() image_dto = context.services.images.create( @@ -380,17 +370,19 @@ class ONNXLatentsToImageInvocation(BaseInvocation): height=image_dto.height, ) + class ONNXModelLoaderOutput(BaseInvocationOutput): """Model loader output""" - #fmt: off + # fmt: off type: Literal["model_loader_output_onnx"] = "model_loader_output_onnx" unet: UNetField = Field(default=None, description="UNet submodel") clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels") vae_decoder: VaeField = Field(default=None, description="Vae submodel") vae_encoder: VaeField = Field(default=None, description="Vae submodel") - #fmt: on + # fmt: on + class ONNXSD1ModelLoaderInvocation(BaseInvocation): """Loading submodels of selected model.""" @@ -403,16 +395,10 @@ class ONNXSD1ModelLoaderInvocation(BaseInvocation): # Schema customisation class Config(InvocationConfig): schema_extra = { - "ui": { - "tags": ["model", "loader"], - "type_hints": { - "model_name": "model" # TODO: rename to model_name? - } - }, + "ui": {"tags": ["model", "loader"], "type_hints": {"model_name": "model"}}, # TODO: rename to model_name? } def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput: - model_name = "stable-diffusion-v1-5" base_model = BaseModelType.StableDiffusion1 @@ -424,7 +410,6 @@ class ONNXSD1ModelLoaderInvocation(BaseInvocation): ): raise Exception(f"Unkown model name: {model_name}!") - return ONNXModelLoaderOutput( unet=UNetField( unet=ModelInfo( @@ -471,9 +456,10 @@ class ONNXSD1ModelLoaderInvocation(BaseInvocation): model_type=ModelType.ONNX, submodel=SubModelType.VaeEncoder, ), - ) + ), ) + class OnnxModelField(BaseModel): """Onnx model field""" @@ -481,6 +467,7 @@ class OnnxModelField(BaseModel): base_model: BaseModelType = Field(description="Base model") model_type: ModelType = Field(description="Model Type") + class OnnxModelLoaderInvocation(BaseInvocation): """Loads a main model, outputting its submodels.""" @@ -587,5 +574,5 @@ class OnnxModelLoaderInvocation(BaseInvocation): model_type=model_type, submodel=SubModelType.VaeEncoder, ), - ) - ) \ No newline at end of file + ), + ) diff --git a/invokeai/app/invocations/param_easing.py b/invokeai/app/invocations/param_easing.py index 70d3ddc7d2..f910e5379c 100644 --- a/invokeai/app/invocations/param_easing.py +++ b/invokeai/app/invocations/param_easing.py @@ -12,16 +12,37 @@ import matplotlib.pyplot as plt from easing_functions import ( LinearInOut, - QuadEaseInOut, QuadEaseIn, QuadEaseOut, - CubicEaseInOut, CubicEaseIn, CubicEaseOut, - QuarticEaseInOut, QuarticEaseIn, QuarticEaseOut, - QuinticEaseInOut, QuinticEaseIn, QuinticEaseOut, - SineEaseInOut, SineEaseIn, SineEaseOut, - CircularEaseIn, CircularEaseInOut, CircularEaseOut, - ExponentialEaseInOut, ExponentialEaseIn, ExponentialEaseOut, - ElasticEaseIn, ElasticEaseInOut, ElasticEaseOut, - BackEaseIn, BackEaseInOut, BackEaseOut, - BounceEaseIn, BounceEaseInOut, BounceEaseOut) + QuadEaseInOut, + QuadEaseIn, + QuadEaseOut, + CubicEaseInOut, + CubicEaseIn, + CubicEaseOut, + QuarticEaseInOut, + QuarticEaseIn, + QuarticEaseOut, + QuinticEaseInOut, + QuinticEaseIn, + QuinticEaseOut, + SineEaseInOut, + SineEaseIn, + SineEaseOut, + CircularEaseIn, + CircularEaseInOut, + CircularEaseOut, + ExponentialEaseInOut, + ExponentialEaseIn, + ExponentialEaseOut, + ElasticEaseIn, + ElasticEaseInOut, + ElasticEaseOut, + BackEaseIn, + BackEaseInOut, + BackEaseOut, + BounceEaseIn, + BounceEaseInOut, + BounceEaseOut, +) from .baseinvocation import ( BaseInvocation, @@ -45,17 +66,12 @@ class FloatLinearRangeInvocation(BaseInvocation): class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Linear Range (Float)", - "tags": ["math", "float", "linear", "range"] - }, + "ui": {"title": "Linear Range (Float)", "tags": ["math", "float", "linear", "range"]}, } def invoke(self, context: InvocationContext) -> FloatCollectionOutput: param_list = list(np.linspace(self.start, self.stop, self.steps)) - return FloatCollectionOutput( - collection=param_list - ) + return FloatCollectionOutput(collection=param_list) EASING_FUNCTIONS_MAP = { @@ -92,9 +108,7 @@ EASING_FUNCTIONS_MAP = { "BounceInOut": BounceEaseInOut, } -EASING_FUNCTION_KEYS: Any = Literal[ - tuple(list(EASING_FUNCTIONS_MAP.keys())) -] +EASING_FUNCTION_KEYS: Any = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))] # actually I think for now could just use CollectionOutput (which is list[Any] @@ -123,13 +137,9 @@ class StepParamEasingInvocation(BaseInvocation): class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Param Easing By Step", - "tags": ["param", "step", "easing"] - }, + "ui": {"title": "Param Easing By Step", "tags": ["param", "step", "easing"]}, } - def invoke(self, context: InvocationContext) -> FloatCollectionOutput: log_diagnostics = False # convert from start_step_percent to nearest step <= (steps * start_step_percent) @@ -170,12 +180,13 @@ class StepParamEasingInvocation(BaseInvocation): # and create reverse copy of list[1:end-1] # but if even then number_of_steps/2 === ceil(number_of_steps/2), so can just use ceil always - base_easing_duration = int(np.ceil(num_easing_steps/2.0)) - if log_diagnostics: context.services.logger.debug("base easing duration: " + str(base_easing_duration)) - even_num_steps = (num_easing_steps % 2 == 0) # even number of steps - easing_function = easing_class(start=self.start_value, - end=self.end_value, - duration=base_easing_duration - 1) + base_easing_duration = int(np.ceil(num_easing_steps / 2.0)) + if log_diagnostics: + context.services.logger.debug("base easing duration: " + str(base_easing_duration)) + even_num_steps = num_easing_steps % 2 == 0 # even number of steps + easing_function = easing_class( + start=self.start_value, end=self.end_value, duration=base_easing_duration - 1 + ) base_easing_vals = list() for step_index in range(base_easing_duration): easing_val = easing_function.ease(step_index) @@ -214,9 +225,7 @@ class StepParamEasingInvocation(BaseInvocation): # else: # no mirroring (default) - easing_function = easing_class(start=self.start_value, - end=self.end_value, - duration=num_easing_steps - 1) + easing_function = easing_class(start=self.start_value, end=self.end_value, duration=num_easing_steps - 1) for step_index in range(num_easing_steps): step_val = easing_function.ease(step_index) easing_list.append(step_val) @@ -240,13 +249,11 @@ class StepParamEasingInvocation(BaseInvocation): ax = plt.gca() ax.xaxis.set_major_locator(MaxNLocator(integer=True)) buf = io.BytesIO() - plt.savefig(buf, format='png') + plt.savefig(buf, format="png") buf.seek(0) im = PIL.Image.open(buf) im.show() buf.close() # output array of size steps, each entry list[i] is param value for step i - return FloatCollectionOutput( - collection=param_list - ) + return FloatCollectionOutput(collection=param_list) diff --git a/invokeai/app/invocations/params.py b/invokeai/app/invocations/params.py index 0f01d65948..127eefa21d 100644 --- a/invokeai/app/invocations/params.py +++ b/invokeai/app/invocations/params.py @@ -4,67 +4,63 @@ from typing import Literal from pydantic import Field -from .baseinvocation import (BaseInvocation, BaseInvocationOutput, - InvocationConfig, InvocationContext) +from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext from .math import FloatOutput, IntOutput # Pass-through parameter nodes - used by subgraphs + class ParamIntInvocation(BaseInvocation): """An integer parameter""" - #fmt: off + + # fmt: off type: Literal["param_int"] = "param_int" a: int = Field(default=0, description="The integer value") - #fmt: on + # fmt: on class Config(InvocationConfig): - schema_extra = { - "ui": { - "tags": ["param", "integer"], - "title": "Integer Parameter" - }, - } + schema_extra = { + "ui": {"tags": ["param", "integer"], "title": "Integer Parameter"}, + } def invoke(self, context: InvocationContext) -> IntOutput: return IntOutput(a=self.a) + class ParamFloatInvocation(BaseInvocation): """A float parameter""" - #fmt: off + + # fmt: off type: Literal["param_float"] = "param_float" param: float = Field(default=0.0, description="The float value") - #fmt: on + # fmt: on class Config(InvocationConfig): - schema_extra = { - "ui": { - "tags": ["param", "float"], - "title": "Float Parameter" - }, - } + schema_extra = { + "ui": {"tags": ["param", "float"], "title": "Float Parameter"}, + } def invoke(self, context: InvocationContext) -> FloatOutput: return FloatOutput(param=self.param) + class StringOutput(BaseInvocationOutput): """A string output""" + type: Literal["string_output"] = "string_output" text: str = Field(default=None, description="The output string") class ParamStringInvocation(BaseInvocation): """A string parameter""" - type: Literal['param_string'] = 'param_string' - text: str = Field(default='', description='The string value') + + type: Literal["param_string"] = "param_string" + text: str = Field(default="", description="The string value") class Config(InvocationConfig): - schema_extra = { - "ui": { - "tags": ["param", "string"], - "title": "String Parameter" - }, - } + schema_extra = { + "ui": {"tags": ["param", "string"], "title": "String Parameter"}, + } def invoke(self, context: InvocationContext) -> StringOutput: return StringOutput(text=self.text) - \ No newline at end of file diff --git a/invokeai/app/invocations/prompt.py b/invokeai/app/invocations/prompt.py index 5d07a88759..83a397ddcf 100644 --- a/invokeai/app/invocations/prompt.py +++ b/invokeai/app/invocations/prompt.py @@ -7,19 +7,21 @@ from pydantic import Field, validator from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator + class PromptOutput(BaseInvocationOutput): """Base class for invocations that output a prompt""" - #fmt: off + + # fmt: off type: Literal["prompt"] = "prompt" prompt: str = Field(default=None, description="The output prompt") - #fmt: on + # fmt: on class Config: schema_extra = { - 'required': [ - 'type', - 'prompt', + "required": [ + "type", + "prompt", ] } @@ -44,16 +46,11 @@ class DynamicPromptInvocation(BaseInvocation): type: Literal["dynamic_prompt"] = "dynamic_prompt" prompt: str = Field(description="The prompt to parse with dynamicprompts") max_prompts: int = Field(default=1, description="The number of prompts to generate") - combinatorial: bool = Field( - default=False, description="Whether to use the combinatorial generator" - ) + combinatorial: bool = Field(default=False, description="Whether to use the combinatorial generator") class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Dynamic Prompt", - "tags": ["prompt", "dynamic"] - }, + "ui": {"title": "Dynamic Prompt", "tags": ["prompt", "dynamic"]}, } def invoke(self, context: InvocationContext) -> PromptCollectionOutput: @@ -65,10 +62,11 @@ class DynamicPromptInvocation(BaseInvocation): prompts = generator.generate(self.prompt, num_images=self.max_prompts) return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts)) - + class PromptsFromFileInvocation(BaseInvocation): - '''Loads prompts from a text file''' + """Loads prompts from a text file""" + # fmt: off type: Literal['prompt_from_file'] = 'prompt_from_file' @@ -78,14 +76,11 @@ class PromptsFromFileInvocation(BaseInvocation): post_prompt: Optional[str] = Field(description="String to append to each prompt") start_line: int = Field(default=1, ge=1, description="Line in the file to start start from") max_prompts: int = Field(default=1, ge=0, description="Max lines to read from file (0=all)") - #fmt: on + # fmt: on class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Prompts From File", - "tags": ["prompt", "file"] - }, + "ui": {"title": "Prompts From File", "tags": ["prompt", "file"]}, } @validator("file_path") @@ -103,11 +98,13 @@ class PromptsFromFileInvocation(BaseInvocation): with open(file_path) as f: for i, line in enumerate(f): if i >= start_line and i < end_line: - prompts.append((pre_prompt or '') + line.strip() + (post_prompt or '')) + prompts.append((pre_prompt or "") + line.strip() + (post_prompt or "")) if i >= end_line: break return prompts def invoke(self, context: InvocationContext) -> PromptCollectionOutput: - prompts = self.promptsFromFile(self.file_path, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts) + prompts = self.promptsFromFile( + self.file_path, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts + ) return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts)) diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index 3b4d3a9d86..303a2fedd9 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -7,13 +7,13 @@ from pydantic import Field, validator from ...backend.model_management import ModelType, SubModelType from invokeai.app.util.step_callback import stable_diffusion_xl_step_callback -from .baseinvocation import (BaseInvocation, BaseInvocationOutput, - InvocationConfig, InvocationContext) +from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext from .model import UNetField, ClipField, VaeField, MainModelField, ModelInfo from .compel import ConditioningField from .latent import LatentsField, SAMPLER_NAME_VALUES, LatentsOutput, get_scheduler, build_latents_output + class SDXLModelLoaderOutput(BaseInvocationOutput): """SDXL base model loader output""" @@ -26,16 +26,19 @@ class SDXLModelLoaderOutput(BaseInvocationOutput): vae: VaeField = Field(default=None, description="Vae submodel") # fmt: on + class SDXLRefinerModelLoaderOutput(BaseInvocationOutput): """SDXL refiner model loader output""" + # fmt: off type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output" unet: UNetField = Field(default=None, description="UNet submodel") clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels") vae: VaeField = Field(default=None, description="Vae submodel") # fmt: on - #fmt: on - + # fmt: on + + class SDXLModelLoaderInvocation(BaseInvocation): """Loads an sdxl base model, outputting its submodels.""" @@ -125,8 +128,10 @@ class SDXLModelLoaderInvocation(BaseInvocation): ), ) + class SDXLRefinerModelLoaderInvocation(BaseInvocation): """Loads an sdxl refiner model, outputting its submodels.""" + type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader" model: MainModelField = Field(description="The model to load") @@ -196,7 +201,8 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation): ), ), ) - + + # Text to image class SDXLTextToLatentsInvocation(BaseInvocation): """Generates latents from conditionings.""" @@ -213,9 +219,9 @@ class SDXLTextToLatentsInvocation(BaseInvocation): scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" ) unet: UNetField = Field(default=None, description="UNet submodel") denoising_end: float = Field(default=1.0, gt=0, le=1, description="") - #control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use") - #seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", ) - #seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'") + # control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use") + # seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", ) + # seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'") # fmt: on @validator("cfg_scale") @@ -224,10 +230,10 @@ class SDXLTextToLatentsInvocation(BaseInvocation): if isinstance(v, list): for i in v: if i < 1: - raise ValueError('cfg_scale must be greater than 1') + raise ValueError("cfg_scale must be greater than 1") else: if v < 1: - raise ValueError('cfg_scale must be greater than 1') + raise ValueError("cfg_scale must be greater than 1") return v # Schema customisation @@ -237,10 +243,10 @@ class SDXLTextToLatentsInvocation(BaseInvocation): "title": "SDXL Text To Latents", "tags": ["latents"], "type_hints": { - "model": "model", - # "cfg_scale": "float", - "cfg_scale": "number" - } + "model": "model", + # "cfg_scale": "float", + "cfg_scale": "number", + }, }, } @@ -265,9 +271,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation): # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375 @torch.no_grad() def invoke(self, context: InvocationContext) -> LatentsOutput: - graph_execution_state = context.services.graph_execution_manager.get( - context.graph_execution_state_id - ) + graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) source_node_id = graph_execution_state.prepared_source_mapping[self.id] latents = context.services.latents.get(self.noise.latents_name) @@ -293,14 +297,10 @@ class SDXLTextToLatentsInvocation(BaseInvocation): latents = latents * scheduler.init_noise_sigma - - unet_info = context.services.model_manager.get_model( - **self.unet.unet.dict(), context=context - ) + unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context) do_classifier_free_guidance = True cross_attention_kwargs = None with unet_info as unet: - extra_step_kwargs = dict() if "eta" in set(inspect.signature(scheduler.step).parameters.keys()): extra_step_kwargs.update( @@ -350,10 +350,10 @@ class SDXLTextToLatentsInvocation(BaseInvocation): if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond) - #del noise_pred_uncond - #del noise_pred_text + # del noise_pred_uncond + # del noise_pred_text - #if do_classifier_free_guidance and guidance_rescale > 0.0: + # if do_classifier_free_guidance and guidance_rescale > 0.0: # # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf # noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) @@ -364,7 +364,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0): progress_bar.update() self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps) - #if callback is not None and i % callback_steps == 0: + # if callback is not None and i % callback_steps == 0: # callback(i, t, latents) else: negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype) @@ -378,13 +378,13 @@ class SDXLTextToLatentsInvocation(BaseInvocation): with tqdm(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - #latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + # latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = scheduler.scale_model_input(latents, t) - #import gc - #gc.collect() - #torch.cuda.empty_cache() + # import gc + # gc.collect() + # torch.cuda.empty_cache() # predict the noise residual @@ -411,42 +411,41 @@ class SDXLTextToLatentsInvocation(BaseInvocation): # perform guidance noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond) - #del noise_pred_text - #del noise_pred_uncond - #import gc - #gc.collect() - #torch.cuda.empty_cache() + # del noise_pred_text + # del noise_pred_uncond + # import gc + # gc.collect() + # torch.cuda.empty_cache() - #if do_classifier_free_guidance and guidance_rescale > 0.0: + # if do_classifier_free_guidance and guidance_rescale > 0.0: # # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf # noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) # compute the previous noisy sample x_t -> x_t-1 latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - #del noise_pred - #import gc - #gc.collect() - #torch.cuda.empty_cache() + # del noise_pred + # import gc + # gc.collect() + # torch.cuda.empty_cache() # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0): progress_bar.update() self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps) - #if callback is not None and i % callback_steps == 0: + # if callback is not None and i % callback_steps == 0: # callback(i, t, latents) - - ################# latents = latents.to("cpu") torch.cuda.empty_cache() - name = f'{context.graph_execution_state_id}__{self.id}' + name = f"{context.graph_execution_state_id}__{self.id}" context.services.latents.save(name, latents) return build_latents_output(latents_name=name, latents=latents) + class SDXLLatentsToLatentsInvocation(BaseInvocation): """Generates latents from conditionings.""" @@ -466,9 +465,9 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation): denoising_start: float = Field(default=0.0, ge=0, le=1, description="") denoising_end: float = Field(default=1.0, ge=0, le=1, description="") - #control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use") - #seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", ) - #seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'") + # control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use") + # seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", ) + # seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'") # fmt: on @validator("cfg_scale") @@ -477,10 +476,10 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation): if isinstance(v, list): for i in v: if i < 1: - raise ValueError('cfg_scale must be greater than 1') + raise ValueError("cfg_scale must be greater than 1") else: if v < 1: - raise ValueError('cfg_scale must be greater than 1') + raise ValueError("cfg_scale must be greater than 1") return v # Schema customisation @@ -490,10 +489,10 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation): "title": "SDXL Latents to Latents", "tags": ["latents"], "type_hints": { - "model": "model", - # "cfg_scale": "float", - "cfg_scale": "number" - } + "model": "model", + # "cfg_scale": "float", + "cfg_scale": "number", + }, }, } @@ -518,9 +517,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation): # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375 @torch.no_grad() def invoke(self, context: InvocationContext) -> LatentsOutput: - graph_execution_state = context.services.graph_execution_manager.get( - context.graph_execution_state_id - ) + graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) source_node_id = graph_execution_state.prepared_source_mapping[self.id] latents = context.services.latents.get(self.latents.latents_name) @@ -545,7 +542,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation): scheduler.set_timesteps(num_inference_steps) t_start = int(round(self.denoising_start * num_inference_steps)) - timesteps = scheduler.timesteps[t_start * scheduler.order:] + timesteps = scheduler.timesteps[t_start * scheduler.order :] num_inference_steps = num_inference_steps - t_start # apply noise(if provided) @@ -555,12 +552,12 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation): del noise unet_info = context.services.model_manager.get_model( - **self.unet.unet.dict(), context=context, + **self.unet.unet.dict(), + context=context, ) do_classifier_free_guidance = True cross_attention_kwargs = None with unet_info as unet: - # apply scheduler extra args extra_step_kwargs = dict() if "eta" in set(inspect.signature(scheduler.step).parameters.keys()): @@ -611,10 +608,10 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation): if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond) - #del noise_pred_uncond - #del noise_pred_text + # del noise_pred_uncond + # del noise_pred_text - #if do_classifier_free_guidance and guidance_rescale > 0.0: + # if do_classifier_free_guidance and guidance_rescale > 0.0: # # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf # noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) @@ -625,7 +622,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0): progress_bar.update() self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps) - #if callback is not None and i % callback_steps == 0: + # if callback is not None and i % callback_steps == 0: # callback(i, t, latents) else: negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype) @@ -639,13 +636,13 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation): with tqdm(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - #latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + # latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = scheduler.scale_model_input(latents, t) - #import gc - #gc.collect() - #torch.cuda.empty_cache() + # import gc + # gc.collect() + # torch.cuda.empty_cache() # predict the noise residual @@ -672,38 +669,36 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation): # perform guidance noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond) - #del noise_pred_text - #del noise_pred_uncond - #import gc - #gc.collect() - #torch.cuda.empty_cache() + # del noise_pred_text + # del noise_pred_uncond + # import gc + # gc.collect() + # torch.cuda.empty_cache() - #if do_classifier_free_guidance and guidance_rescale > 0.0: + # if do_classifier_free_guidance and guidance_rescale > 0.0: # # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf # noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) # compute the previous noisy sample x_t -> x_t-1 latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - #del noise_pred - #import gc - #gc.collect() - #torch.cuda.empty_cache() + # del noise_pred + # import gc + # gc.collect() + # torch.cuda.empty_cache() # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0): progress_bar.update() self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps) - #if callback is not None and i % callback_steps == 0: + # if callback is not None and i % callback_steps == 0: # callback(i, t, latents) - - ################# latents = latents.to("cpu") torch.cuda.empty_cache() - name = f'{context.graph_execution_state_id}__{self.id}' + name = f"{context.graph_execution_state_id}__{self.id}" context.services.latents.save(name, latents) return build_latents_output(latents_name=name, latents=latents) diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index 1a45a631df..fd220223db 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -29,16 +29,11 @@ class ESRGANInvocation(BaseInvocation): type: Literal["esrgan"] = "esrgan" image: Union[ImageField, None] = Field(default=None, description="The input image") - model_name: ESRGAN_MODELS = Field( - default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use" - ) + model_name: ESRGAN_MODELS = Field(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use") class Config(InvocationConfig): schema_extra = { - "ui": { - "title": "Upscale (RealESRGAN)", - "tags": ["image", "upscale", "realesrgan"] - }, + "ui": {"title": "Upscale (RealESRGAN)", "tags": ["image", "upscale", "realesrgan"]}, } def invoke(self, context: InvocationContext) -> ImageOutput: @@ -108,9 +103,7 @@ class ESRGANInvocation(BaseInvocation): upscaled_image, img_mode = upsampler.enhance(cv_image) # back to PIL - pil_image = Image.fromarray( - cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB) - ).convert("RGBA") + pil_image = Image.fromarray(cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)).convert("RGBA") image_dto = context.services.images.create( image=pil_image, diff --git a/invokeai/app/models/exceptions.py b/invokeai/app/models/exceptions.py index 32ad3b8f03..662e1948ce 100644 --- a/invokeai/app/models/exceptions.py +++ b/invokeai/app/models/exceptions.py @@ -1,3 +1,4 @@ class CanceledException(Exception): """Execution canceled by user.""" + pass diff --git a/invokeai/app/models/image.py b/invokeai/app/models/image.py index 1183cabc54..2a5a0f9d3b 100644 --- a/invokeai/app/models/image.py +++ b/invokeai/app/models/image.py @@ -8,6 +8,7 @@ from ..invocations.baseinvocation import ( InvocationConfig, ) + class ImageField(BaseModel): """An image field used for passing image objects between invocations""" @@ -34,6 +35,7 @@ class ProgressImage(BaseModel): height: int = Field(description="The effective height of the image in pixels") dataURL: str = Field(description="The image data as a b64 data URL") + class PILInvocationConfig(BaseModel): """Helper class to provide all PIL invocations with additional config""" @@ -44,6 +46,7 @@ class PILInvocationConfig(BaseModel): }, } + class ImageOutput(BaseInvocationOutput): """Base class for invocations that output an image""" @@ -76,6 +79,7 @@ class MaskOutput(BaseInvocationOutput): ] } + class ResourceOrigin(str, Enum, metaclass=MetaEnum): """The origin of a resource (eg image). @@ -132,5 +136,3 @@ class InvalidImageCategoryException(ValueError): def __init__(self, message="Invalid image category."): super().__init__(message) - - diff --git a/invokeai/app/services/board_image_record_storage.py b/invokeai/app/services/board_image_record_storage.py index 491972bd32..f0007c8cef 100644 --- a/invokeai/app/services/board_image_record_storage.py +++ b/invokeai/app/services/board_image_record_storage.py @@ -207,9 +207,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase): raise e finally: self._lock.release() - return OffsetPaginatedResults( - items=images, offset=offset, limit=limit, total=count - ) + return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count) def get_all_board_image_names_for_board(self, board_id: str) -> list[str]: try: diff --git a/invokeai/app/services/board_images.py b/invokeai/app/services/board_images.py index b9f9663603..22332d6c29 100644 --- a/invokeai/app/services/board_images.py +++ b/invokeai/app/services/board_images.py @@ -102,9 +102,7 @@ class BoardImagesService(BoardImagesServiceABC): self, board_id: str, ) -> list[str]: - return self._services.board_image_records.get_all_board_image_names_for_board( - board_id - ) + return self._services.board_image_records.get_all_board_image_names_for_board(board_id) def get_board_for_image( self, @@ -114,9 +112,7 @@ class BoardImagesService(BoardImagesServiceABC): return board_id -def board_record_to_dto( - board_record: BoardRecord, cover_image_name: Optional[str], image_count: int -) -> BoardDTO: +def board_record_to_dto(board_record: BoardRecord, cover_image_name: Optional[str], image_count: int) -> BoardDTO: """Converts a board record to a board DTO.""" return BoardDTO( **board_record.dict(exclude={"cover_image_name"}), diff --git a/invokeai/app/services/board_record_storage.py b/invokeai/app/services/board_record_storage.py index 15ea9cc5a7..2fad7b0ab3 100644 --- a/invokeai/app/services/board_record_storage.py +++ b/invokeai/app/services/board_record_storage.py @@ -15,9 +15,7 @@ from pydantic import BaseModel, Field, Extra class BoardChanges(BaseModel, extra=Extra.forbid): board_name: Optional[str] = Field(description="The board's new name.") - cover_image_name: Optional[str] = Field( - description="The name of the board's new cover image." - ) + cover_image_name: Optional[str] = Field(description="The name of the board's new cover image.") class BoardRecordNotFoundException(Exception): @@ -292,9 +290,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase): count = cast(int, self._cursor.fetchone()[0]) - return OffsetPaginatedResults[BoardRecord]( - items=boards, offset=offset, limit=limit, total=count - ) + return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count) except sqlite3.Error as e: self._conn.rollback() diff --git a/invokeai/app/services/boards.py b/invokeai/app/services/boards.py index 9361322e6c..53d30b2e85 100644 --- a/invokeai/app/services/boards.py +++ b/invokeai/app/services/boards.py @@ -108,16 +108,12 @@ class BoardService(BoardServiceABC): def get_dto(self, board_id: str) -> BoardDTO: board_record = self._services.board_records.get(board_id) - cover_image = self._services.image_records.get_most_recent_image_for_board( - board_record.board_id - ) + cover_image = self._services.image_records.get_most_recent_image_for_board(board_record.board_id) if cover_image: cover_image_name = cover_image.image_name else: cover_image_name = None - image_count = self._services.board_image_records.get_image_count_for_board( - board_id - ) + image_count = self._services.board_image_records.get_image_count_for_board(board_id) return board_record_to_dto(board_record, cover_image_name, image_count) def update( @@ -126,60 +122,44 @@ class BoardService(BoardServiceABC): changes: BoardChanges, ) -> BoardDTO: board_record = self._services.board_records.update(board_id, changes) - cover_image = self._services.image_records.get_most_recent_image_for_board( - board_record.board_id - ) + cover_image = self._services.image_records.get_most_recent_image_for_board(board_record.board_id) if cover_image: cover_image_name = cover_image.image_name else: cover_image_name = None - image_count = self._services.board_image_records.get_image_count_for_board( - board_id - ) + image_count = self._services.board_image_records.get_image_count_for_board(board_id) return board_record_to_dto(board_record, cover_image_name, image_count) def delete(self, board_id: str) -> None: self._services.board_records.delete(board_id) - def get_many( - self, offset: int = 0, limit: int = 10 - ) -> OffsetPaginatedResults[BoardDTO]: + def get_many(self, offset: int = 0, limit: int = 10) -> OffsetPaginatedResults[BoardDTO]: board_records = self._services.board_records.get_many(offset, limit) board_dtos = [] for r in board_records.items: - cover_image = self._services.image_records.get_most_recent_image_for_board( - r.board_id - ) + cover_image = self._services.image_records.get_most_recent_image_for_board(r.board_id) if cover_image: cover_image_name = cover_image.image_name else: cover_image_name = None - image_count = self._services.board_image_records.get_image_count_for_board( - r.board_id - ) + image_count = self._services.board_image_records.get_image_count_for_board(r.board_id) board_dtos.append(board_record_to_dto(r, cover_image_name, image_count)) - return OffsetPaginatedResults[BoardDTO]( - items=board_dtos, offset=offset, limit=limit, total=len(board_dtos) - ) + return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos)) def get_all(self) -> list[BoardDTO]: board_records = self._services.board_records.get_all() board_dtos = [] for r in board_records: - cover_image = self._services.image_records.get_most_recent_image_for_board( - r.board_id - ) + cover_image = self._services.image_records.get_most_recent_image_for_board(r.board_id) if cover_image: cover_image_name = cover_image.image_name else: cover_image_name = None - image_count = self._services.board_image_records.get_image_count_for_board( - r.board_id - ) + image_count = self._services.board_image_records.get_image_count_for_board(r.board_id) board_dtos.append(board_record_to_dto(r, cover_image_name, image_count)) - return board_dtos \ No newline at end of file + return board_dtos diff --git a/invokeai/app/services/config.py b/invokeai/app/services/config.py index dfcff86ca0..98855fe879 100644 --- a/invokeai/app/services/config.py +++ b/invokeai/app/services/config.py @@ -1,6 +1,6 @@ # Copyright (c) 2023 Lincoln Stein (https://github.com/lstein) and the InvokeAI Development Team -'''Invokeai configuration system. +"""Invokeai configuration system. Arguments and fields are taken from the pydantic definition of the model. Defaults can be set by creating a yaml configuration file that @@ -158,7 +158,7 @@ two configs are kept in separate sections of the config file: outdir: outputs ... -''' +""" from __future__ import annotations import argparse import pydoc @@ -170,64 +170,68 @@ from pathlib import Path from pydantic import BaseSettings, Field, parse_obj_as from typing import ClassVar, Dict, List, Set, Literal, Union, get_origin, get_type_hints, get_args -INIT_FILE = Path('invokeai.yaml') -MODEL_CORE = Path('models/core') -DB_FILE = Path('invokeai.db') -LEGACY_INIT_FILE = Path('invokeai.init') +INIT_FILE = Path("invokeai.yaml") +MODEL_CORE = Path("models/core") +DB_FILE = Path("invokeai.db") +LEGACY_INIT_FILE = Path("invokeai.init") + class InvokeAISettings(BaseSettings): - ''' + """ Runtime configuration settings in which default values are read from an omegaconf .yaml file. - ''' - initconf : ClassVar[DictConfig] = None - argparse_groups : ClassVar[Dict] = {} + """ - def parse_args(self, argv: list=sys.argv[1:]): + initconf: ClassVar[DictConfig] = None + argparse_groups: ClassVar[Dict] = {} + + def parse_args(self, argv: list = sys.argv[1:]): parser = self.get_parser() opt = parser.parse_args(argv) for name in self.__fields__: if name not in self._excluded(): - setattr(self, name, getattr(opt,name)) + setattr(self, name, getattr(opt, name)) - def to_yaml(self)->str: + def to_yaml(self) -> str: """ Return a YAML string representing our settings. This can be used as the contents of `invokeai.yaml` to restore settings later. """ cls = self.__class__ - type = get_args(get_type_hints(cls)['type'])[0] - field_dict = dict({type:dict()}) - for name,field in self.__fields__.items(): + type = get_args(get_type_hints(cls)["type"])[0] + field_dict = dict({type: dict()}) + for name, field in self.__fields__.items(): if name in cls._excluded_from_yaml(): continue category = field.field_info.extra.get("category") or "Uncategorized" - value = getattr(self,name) + value = getattr(self, name) if category not in field_dict[type]: field_dict[type][category] = dict() # keep paths as strings to make it easier to read - field_dict[type][category][name] = str(value) if isinstance(value,Path) else value + field_dict[type][category][name] = str(value) if isinstance(value, Path) else value conf = OmegaConf.create(field_dict) return OmegaConf.to_yaml(conf) @classmethod def add_parser_arguments(cls, parser): - if 'type' in get_type_hints(cls): - settings_stanza = get_args(get_type_hints(cls)['type'])[0] + if "type" in get_type_hints(cls): + settings_stanza = get_args(get_type_hints(cls)["type"])[0] else: settings_stanza = "Uncategorized" - env_prefix = cls.Config.env_prefix if hasattr(cls.Config,'env_prefix') else settings_stanza.upper() + env_prefix = cls.Config.env_prefix if hasattr(cls.Config, "env_prefix") else settings_stanza.upper() - initconf = cls.initconf.get(settings_stanza) \ - if cls.initconf and settings_stanza in cls.initconf \ - else OmegaConf.create() + initconf = ( + cls.initconf.get(settings_stanza) + if cls.initconf and settings_stanza in cls.initconf + else OmegaConf.create() + ) # create an upcase version of the environment in # order to achieve case-insensitive environment # variables (the way Windows does) upcase_environ = dict() - for key,value in os.environ.items(): + for key, value in os.environ.items(): upcase_environ[key.upper()] = value fields = cls.__fields__ @@ -237,8 +241,8 @@ class InvokeAISettings(BaseSettings): if name not in cls._excluded(): current_default = field.default - category = field.field_info.extra.get("category","Uncategorized") - env_name = env_prefix + '_' + name + category = field.field_info.extra.get("category", "Uncategorized") + env_name = env_prefix + "_" + name if category in initconf and name in initconf.get(category): field.default = initconf.get(category).get(name) if env_name.upper() in upcase_environ: @@ -248,15 +252,15 @@ class InvokeAISettings(BaseSettings): field.default = current_default @classmethod - def cmd_name(self, command_field: str='type')->str: + def cmd_name(self, command_field: str = "type") -> str: hints = get_type_hints(self) if command_field in hints: return get_args(hints[command_field])[0] else: - return 'Uncategorized' + return "Uncategorized" @classmethod - def get_parser(cls)->ArgumentParser: + def get_parser(cls) -> ArgumentParser: parser = PagingArgumentParser( prog=cls.cmd_name(), description=cls.__doc__, @@ -269,24 +273,41 @@ class InvokeAISettings(BaseSettings): parser.add_parser(cls.cmd_name(), help=cls.__doc__) @classmethod - def _excluded(self)->List[str]: + def _excluded(self) -> List[str]: # internal fields that shouldn't be exposed as command line options - return ['type','initconf'] - + return ["type", "initconf"] + @classmethod - def _excluded_from_yaml(self)->List[str]: + def _excluded_from_yaml(self) -> List[str]: # combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options - return ['type','initconf', 'gpu_mem_reserved', 'max_loaded_models', 'version', 'from_file', 'model', 'restore', 'root', 'nsfw_checker'] + return [ + "type", + "initconf", + "gpu_mem_reserved", + "max_loaded_models", + "version", + "from_file", + "model", + "restore", + "root", + "nsfw_checker", + ] class Config: - env_file_encoding = 'utf-8' + env_file_encoding = "utf-8" arbitrary_types_allowed = True case_sensitive = True @classmethod - def add_field_argument(cls, command_parser, name: str, field, default_override = None): + def add_field_argument(cls, command_parser, name: str, field, default_override=None): field_type = get_type_hints(cls).get(name) - default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory() + default = ( + default_override + if default_override is not None + else field.default + if field.default_factory is None + else field.default_factory() + ) if category := field.field_info.extra.get("category"): if category not in cls.argparse_groups: cls.argparse_groups[category] = command_parser.add_argument_group(category) @@ -315,10 +336,10 @@ class InvokeAISettings(BaseSettings): argparse_group.add_argument( f"--{name}", dest=name, - nargs='*', + nargs="*", type=field.type_, default=default, - action=argparse.BooleanOptionalAction if field.type_==bool else 'store', + action=argparse.BooleanOptionalAction if field.type_ == bool else "store", help=field.field_info.description, ) else: @@ -327,31 +348,35 @@ class InvokeAISettings(BaseSettings): dest=name, type=field.type_, default=default, - action=argparse.BooleanOptionalAction if field.type_==bool else 'store', + action=argparse.BooleanOptionalAction if field.type_ == bool else "store", help=field.field_info.description, ) -def _find_root()->Path: + + +def _find_root() -> Path: venv = Path(os.environ.get("VIRTUAL_ENV") or ".") if os.environ.get("INVOKEAI_ROOT"): root = Path(os.environ.get("INVOKEAI_ROOT")).resolve() - elif any([(venv.parent/x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE, MODEL_CORE]]): + elif any([(venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE, MODEL_CORE]]): root = (venv.parent).resolve() else: root = Path("~/invokeai").expanduser().resolve() return root + class InvokeAIAppConfig(InvokeAISettings): - ''' -Generate images using Stable Diffusion. Use "invokeai" to launch -the command-line client (recommended for experts only), or -"invokeai-web" to launch the web server. Global options -can be changed by editing the file "INVOKEAI_ROOT/invokeai.yaml" or by -setting environment variables INVOKEAI_. - ''' + """ + Generate images using Stable Diffusion. Use "invokeai" to launch + the command-line client (recommended for experts only), or + "invokeai-web" to launch the web server. Global options + can be changed by editing the file "INVOKEAI_ROOT/invokeai.yaml" or by + setting environment variables INVOKEAI_. + """ + singleton_config: ClassVar[InvokeAIAppConfig] = None singleton_init: ClassVar[Dict] = None - #fmt: off + # fmt: off type: Literal["InvokeAI"] = "InvokeAI" host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server') port : int = Field(default=9090, description="Port to bind to", category='Web Server') @@ -399,16 +424,16 @@ setting environment variables INVOKEAI_. log_level : Literal[tuple(["debug","info","warning","error","critical"])] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging") version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other") - #fmt: on + # fmt: on - def parse_args(self, argv: List[str]=None, conf: DictConfig = None, clobber=False): - ''' + def parse_args(self, argv: List[str] = None, conf: DictConfig = None, clobber=False): + """ Update settings with contents of init file, environment, and command-line settings. :param conf: alternate Omegaconf dictionary object :param argv: aternate sys.argv list :param clobber: ovewrite any initialization parameters passed during initialization - ''' + """ # Set the runtime root directory. We parse command-line switches here # in order to pick up the --root_dir option. super().parse_args(argv) @@ -425,135 +450,139 @@ setting environment variables INVOKEAI_. if self.singleton_init and not clobber: hints = get_type_hints(self.__class__) for k in self.singleton_init: - setattr(self,k,parse_obj_as(hints[k],self.singleton_init[k])) + setattr(self, k, parse_obj_as(hints[k], self.singleton_init[k])) @classmethod - def get_config(cls,**kwargs)->InvokeAIAppConfig: - ''' + def get_config(cls, **kwargs) -> InvokeAIAppConfig: + """ This returns a singleton InvokeAIAppConfig configuration object. - ''' - if cls.singleton_config is None \ - or type(cls.singleton_config)!=cls \ - or (kwargs and cls.singleton_init != kwargs): + """ + if ( + cls.singleton_config is None + or type(cls.singleton_config) != cls + or (kwargs and cls.singleton_init != kwargs) + ): cls.singleton_config = cls(**kwargs) cls.singleton_init = kwargs return cls.singleton_config @property - def root_path(self)->Path: - ''' + def root_path(self) -> Path: + """ Path to the runtime root directory - ''' + """ if self.root: return Path(self.root).expanduser().absolute() else: return self.find_root() @property - def root_dir(self)->Path: - ''' + def root_dir(self) -> Path: + """ Alias for above. - ''' + """ return self.root_path - def _resolve(self,partial_path:Path)->Path: + def _resolve(self, partial_path: Path) -> Path: return (self.root_path / partial_path).resolve() @property - def init_file_path(self)->Path: - ''' + def init_file_path(self) -> Path: + """ Path to invokeai.yaml - ''' + """ return self._resolve(INIT_FILE) @property - def output_path(self)->Path: - ''' + def output_path(self) -> Path: + """ Path to defaults outputs directory. - ''' + """ return self._resolve(self.outdir) @property - def db_path(self)->Path: - ''' + def db_path(self) -> Path: + """ Path to the invokeai.db file. - ''' + """ return self._resolve(self.db_dir) / DB_FILE @property - def model_conf_path(self)->Path: - ''' + def model_conf_path(self) -> Path: + """ Path to models configuration file. - ''' + """ return self._resolve(self.conf_path) @property - def legacy_conf_path(self)->Path: - ''' + def legacy_conf_path(self) -> Path: + """ Path to directory of legacy configuration files (e.g. v1-inference.yaml) - ''' + """ return self._resolve(self.legacy_conf_dir) @property - def models_path(self)->Path: - ''' + def models_path(self) -> Path: + """ Path to the models directory - ''' + """ return self._resolve(self.models_dir) @property - def autoconvert_path(self)->Path: - ''' + def autoconvert_path(self) -> Path: + """ Path to the directory containing models to be imported automatically at startup. - ''' + """ return self._resolve(self.autoconvert_dir) if self.autoconvert_dir else None # the following methods support legacy calls leftover from the Globals era @property - def full_precision(self)->bool: + def full_precision(self) -> bool: """Return true if precision set to float32""" - return self.precision=='float32' + return self.precision == "float32" @property - def disable_xformers(self)->bool: + def disable_xformers(self) -> bool: """Return true if xformers_enabled is false""" return not self.xformers_enabled @property - def try_patchmatch(self)->bool: + def try_patchmatch(self) -> bool: """Return true if patchmatch true""" return self.patchmatch @property - def nsfw_checker(self)->bool: - """ NSFW node is always active and disabled from Web UIe""" + def nsfw_checker(self) -> bool: + """NSFW node is always active and disabled from Web UIe""" return True @property - def invisible_watermark(self)->bool: - """ invisible watermark node is always active and disabled from Web UIe""" + def invisible_watermark(self) -> bool: + """invisible watermark node is always active and disabled from Web UIe""" return True - + @staticmethod - def find_root()->Path: - ''' + def find_root() -> Path: + """ Choose the runtime root directory when not specified on command line or init file. - ''' + """ return _find_root() class PagingArgumentParser(argparse.ArgumentParser): - ''' + """ A custom ArgumentParser that uses pydoc to page its output. It also supports reading defaults from an init file. - ''' + """ + def print_help(self, file=None): text = self.format_help() pydoc.pager(text) -def get_invokeai_config(**kwargs)->InvokeAIAppConfig: - ''' + +def get_invokeai_config(**kwargs) -> InvokeAIAppConfig: + """ Legacy function which returns InvokeAIAppConfig.get_config() - ''' + """ return InvokeAIAppConfig.get_config(**kwargs) diff --git a/invokeai/app/services/default_graphs.py b/invokeai/app/services/default_graphs.py index 22e35d1d6b..cafb6f0339 100644 --- a/invokeai/app/services/default_graphs.py +++ b/invokeai/app/services/default_graphs.py @@ -7,57 +7,80 @@ from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Gr from .item_storage import ItemStorageABC -default_text_to_image_graph_id = '539b2af5-2b4d-4d8c-8071-e54a3255fc74' +default_text_to_image_graph_id = "539b2af5-2b4d-4d8c-8071-e54a3255fc74" def create_text_to_image() -> LibraryGraph: return LibraryGraph( id=default_text_to_image_graph_id, - name='t2i', - description='Converts text to an image', + name="t2i", + description="Converts text to an image", graph=Graph( nodes={ - 'width': ParamIntInvocation(id='width', a=512), - 'height': ParamIntInvocation(id='height', a=512), - 'seed': ParamIntInvocation(id='seed', a=-1), - '3': NoiseInvocation(id='3'), - '4': CompelInvocation(id='4'), - '5': CompelInvocation(id='5'), - '6': TextToLatentsInvocation(id='6'), - '7': LatentsToImageInvocation(id='7'), - '8': ImageNSFWBlurInvocation(id='8'), + "width": ParamIntInvocation(id="width", a=512), + "height": ParamIntInvocation(id="height", a=512), + "seed": ParamIntInvocation(id="seed", a=-1), + "3": NoiseInvocation(id="3"), + "4": CompelInvocation(id="4"), + "5": CompelInvocation(id="5"), + "6": TextToLatentsInvocation(id="6"), + "7": LatentsToImageInvocation(id="7"), + "8": ImageNSFWBlurInvocation(id="8"), }, edges=[ - Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')), - Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')), - Edge(source=EdgeConnection(node_id='seed', field='a'), destination=EdgeConnection(node_id='3', field='seed')), - Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='6', field='noise')), - Edge(source=EdgeConnection(node_id='6', field='latents'), destination=EdgeConnection(node_id='7', field='latents')), - Edge(source=EdgeConnection(node_id='4', field='conditioning'), destination=EdgeConnection(node_id='6', field='positive_conditioning')), - Edge(source=EdgeConnection(node_id='5', field='conditioning'), destination=EdgeConnection(node_id='6', field='negative_conditioning')), - Edge(source=EdgeConnection(node_id='7', field='image'), destination=EdgeConnection(node_id='8', field='image')), - ] + Edge( + source=EdgeConnection(node_id="width", field="a"), + destination=EdgeConnection(node_id="3", field="width"), + ), + Edge( + source=EdgeConnection(node_id="height", field="a"), + destination=EdgeConnection(node_id="3", field="height"), + ), + Edge( + source=EdgeConnection(node_id="seed", field="a"), + destination=EdgeConnection(node_id="3", field="seed"), + ), + Edge( + source=EdgeConnection(node_id="3", field="noise"), + destination=EdgeConnection(node_id="6", field="noise"), + ), + Edge( + source=EdgeConnection(node_id="6", field="latents"), + destination=EdgeConnection(node_id="7", field="latents"), + ), + Edge( + source=EdgeConnection(node_id="4", field="conditioning"), + destination=EdgeConnection(node_id="6", field="positive_conditioning"), + ), + Edge( + source=EdgeConnection(node_id="5", field="conditioning"), + destination=EdgeConnection(node_id="6", field="negative_conditioning"), + ), + Edge( + source=EdgeConnection(node_id="7", field="image"), + destination=EdgeConnection(node_id="8", field="image"), + ), + ], ), exposed_inputs=[ - ExposedNodeInput(node_path='4', field='prompt', alias='positive_prompt'), - ExposedNodeInput(node_path='5', field='prompt', alias='negative_prompt'), - ExposedNodeInput(node_path='width', field='a', alias='width'), - ExposedNodeInput(node_path='height', field='a', alias='height'), - ExposedNodeInput(node_path='seed', field='a', alias='seed'), + ExposedNodeInput(node_path="4", field="prompt", alias="positive_prompt"), + ExposedNodeInput(node_path="5", field="prompt", alias="negative_prompt"), + ExposedNodeInput(node_path="width", field="a", alias="width"), + ExposedNodeInput(node_path="height", field="a", alias="height"), + ExposedNodeInput(node_path="seed", field="a", alias="seed"), ], - exposed_outputs=[ - ExposedNodeOutput(node_path='8', field='image', alias='image') - ]) + exposed_outputs=[ExposedNodeOutput(node_path="8", field="image", alias="image")], + ) def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]: """Creates the default system graphs, or adds new versions if the old ones don't match""" - + # TODO: Uncomment this when we are ready to fix this up to prevent breaking changes graphs: list[LibraryGraph] = list() # text_to_image = graph_library.get(default_text_to_image_graph_id) - + # # TODO: Check if the graph is the same as the default one, and if not, update it # #if text_to_image is None: text_to_image = create_text_to_image() diff --git a/invokeai/app/services/events.py b/invokeai/app/services/events.py index 73d74de2d9..30fa89bd29 100644 --- a/invokeai/app/services/events.py +++ b/invokeai/app/services/events.py @@ -44,9 +44,7 @@ class EventServiceBase: graph_execution_state_id=graph_execution_state_id, node=node, source_node_id=source_node_id, - progress_image=progress_image.dict() - if progress_image is not None - else None, + progress_image=progress_image.dict() if progress_image is not None else None, step=step, total_steps=total_steps, ), @@ -90,9 +88,7 @@ class EventServiceBase: ), ) - def emit_invocation_started( - self, graph_execution_state_id: str, node: dict, source_node_id: str - ) -> None: + def emit_invocation_started(self, graph_execution_state_id: str, node: dict, source_node_id: str) -> None: """Emitted when an invocation has started""" self.__emit_session_event( event_name="invocation_started", diff --git a/invokeai/app/services/graph.py b/invokeai/app/services/graph.py index 24096da29b..d7f021df14 100644 --- a/invokeai/app/services/graph.py +++ b/invokeai/app/services/graph.py @@ -28,6 +28,7 @@ from ..invocations.baseinvocation import ( # in 3.10 this would be "from types import NoneType" NoneType = type(None) + class EdgeConnection(BaseModel): node_id: str = Field(description="The id of the node for this edge connection") field: str = Field(description="The field for this connection") @@ -61,6 +62,7 @@ def get_input_field(node: BaseInvocation, field: str) -> Any: node_input_field = node_inputs.get(field) or None return node_input_field + def is_union_subtype(t1, t2): t1_args = get_args(t1) t2_args = get_args(t2) @@ -71,6 +73,7 @@ def is_union_subtype(t1, t2): # t1 is a Union, check that all of its types are in t2_args return all(arg in t2_args for arg in t1_args) + def is_list_or_contains_list(t): t_args = get_args(t) @@ -154,15 +157,17 @@ class GraphInvocationOutput(BaseInvocationOutput): class Config: schema_extra = { - 'required': [ - 'type', - 'image', + "required": [ + "type", + "image", ] } + # TODO: Fill this out and move to invocations class GraphInvocation(BaseInvocation): """Execute a graph""" + type: Literal["graph"] = "graph" # TODO: figure out how to create a default here @@ -182,23 +187,21 @@ class IterateInvocationOutput(BaseInvocationOutput): class Config: schema_extra = { - 'required': [ - 'type', - 'item', + "required": [ + "type", + "item", ] } + # TODO: Fill this out and move to invocations class IterateInvocation(BaseInvocation): """Iterates over a list of items""" + type: Literal["iterate"] = "iterate" - collection: list[Any] = Field( - description="The list of items to iterate over", default_factory=list - ) - index: int = Field( - description="The index, will be provided on executed iterators", default=0 - ) + collection: list[Any] = Field(description="The list of items to iterate over", default_factory=list) + index: int = Field(description="The index, will be provided on executed iterators", default=0) def invoke(self, context: InvocationContext) -> IterateInvocationOutput: """Produces the outputs as values""" @@ -212,12 +215,13 @@ class CollectInvocationOutput(BaseInvocationOutput): class Config: schema_extra = { - 'required': [ - 'type', - 'collection', + "required": [ + "type", + "collection", ] } + class CollectInvocation(BaseInvocation): """Collects values into a collection""" @@ -269,9 +273,7 @@ class Graph(BaseModel): if node_path in self.nodes: return (self, node_path) - node_id = ( - node_path if "." not in node_path else node_path[: node_path.index(".")] - ) + node_id = node_path if "." not in node_path else node_path[: node_path.index(".")] if node_id not in self.nodes: raise NodeNotFoundError(f"Node {node_path} not found in graph") @@ -333,9 +335,7 @@ class Graph(BaseModel): return False # Validate all edges reference nodes in the graph - node_ids = set( - [e.source.node_id for e in self.edges] + [e.destination.node_id for e in self.edges] - ) + node_ids = set([e.source.node_id for e in self.edges] + [e.destination.node_id for e in self.edges]) if not all((self.has_node(node_id) for node_id in node_ids)): return False @@ -361,22 +361,14 @@ class Graph(BaseModel): # Validate all iterators # TODO: may need to validate all iterators in subgraphs so edge connections in parent graphs will be available if not all( - ( - self._is_iterator_connection_valid(n.id) - for n in self.nodes.values() - if isinstance(n, IterateInvocation) - ) + (self._is_iterator_connection_valid(n.id) for n in self.nodes.values() if isinstance(n, IterateInvocation)) ): return False # Validate all collectors # TODO: may need to validate all collectors in subgraphs so edge connections in parent graphs will be available if not all( - ( - self._is_collector_connection_valid(n.id) - for n in self.nodes.values() - if isinstance(n, CollectInvocation) - ) + (self._is_collector_connection_valid(n.id) for n in self.nodes.values() if isinstance(n, CollectInvocation)) ): return False @@ -395,48 +387,51 @@ class Graph(BaseModel): # Validate that an edge to this node+field doesn't already exist input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field) if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation): - raise InvalidEdgeError(f'Edge to node {edge.destination.node_id} field {edge.destination.field} already exists') + raise InvalidEdgeError( + f"Edge to node {edge.destination.node_id} field {edge.destination.field} already exists" + ) # Validate that no cycles would be created g = self.nx_graph_flat() g.add_edge(edge.source.node_id, edge.destination.node_id) if not nx.is_directed_acyclic_graph(g): - raise InvalidEdgeError(f'Edge creates a cycle in the graph: {edge.source.node_id} -> {edge.destination.node_id}') + raise InvalidEdgeError( + f"Edge creates a cycle in the graph: {edge.source.node_id} -> {edge.destination.node_id}" + ) # Validate that the field types are compatible - if not are_connections_compatible( - from_node, edge.source.field, to_node, edge.destination.field - ): - raise InvalidEdgeError(f'Fields are incompatible: cannot connect {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}') + if not are_connections_compatible(from_node, edge.source.field, to_node, edge.destination.field): + raise InvalidEdgeError( + f"Fields are incompatible: cannot connect {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}" + ) # Validate if iterator output type matches iterator input type (if this edge results in both being set) if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection": - if not self._is_iterator_connection_valid( - edge.destination.node_id, new_input=edge.source - ): - raise InvalidEdgeError(f'Iterator input type does not match iterator output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}') + if not self._is_iterator_connection_valid(edge.destination.node_id, new_input=edge.source): + raise InvalidEdgeError( + f"Iterator input type does not match iterator output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}" + ) # Validate if iterator input type matches output type (if this edge results in both being set) if isinstance(from_node, IterateInvocation) and edge.source.field == "item": - if not self._is_iterator_connection_valid( - edge.source.node_id, new_output=edge.destination - ): - raise InvalidEdgeError(f'Iterator output type does not match iterator input type:, {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}') + if not self._is_iterator_connection_valid(edge.source.node_id, new_output=edge.destination): + raise InvalidEdgeError( + f"Iterator output type does not match iterator input type:, {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}" + ) # Validate if collector input type matches output type (if this edge results in both being set) if isinstance(to_node, CollectInvocation) and edge.destination.field == "item": - if not self._is_collector_connection_valid( - edge.destination.node_id, new_input=edge.source - ): - raise InvalidEdgeError(f'Collector output type does not match collector input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}') + if not self._is_collector_connection_valid(edge.destination.node_id, new_input=edge.source): + raise InvalidEdgeError( + f"Collector output type does not match collector input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}" + ) # Validate if collector output type matches input type (if this edge results in both being set) if isinstance(from_node, CollectInvocation) and edge.source.field == "collection": - if not self._is_collector_connection_valid( - edge.source.node_id, new_output=edge.destination - ): - raise InvalidEdgeError(f'Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}') - + if not self._is_collector_connection_valid(edge.source.node_id, new_output=edge.destination): + raise InvalidEdgeError( + f"Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}" + ) def has_node(self, node_path: str) -> bool: """Determines whether or not a node exists in the graph.""" @@ -465,17 +460,13 @@ class Graph(BaseModel): # Ensure the node type matches the new node if type(node) != type(new_node): - raise TypeError( - f"Node {node_path} is type {type(node)} but new node is type {type(new_node)}" - ) + raise TypeError(f"Node {node_path} is type {type(node)} but new node is type {type(new_node)}") # Ensure the new id is either the same or is not in the graph prefix = None if "." not in node_path else node_path[: node_path.rindex(".")] new_path = self._get_node_path(new_node.id, prefix=prefix) if new_node.id != node.id and self.has_node(new_path): - raise NodeAlreadyInGraphError( - "Node with id {new_node.id} already exists in graph" - ) + raise NodeAlreadyInGraphError("Node with id {new_node.id} already exists in graph") # Set the new node in the graph graph.nodes[new_node.id] = new_node @@ -497,9 +488,7 @@ class Graph(BaseModel): graph.add_edge( Edge( source=edge.source, - destination=EdgeConnection( - node_id=new_graph_node_path, field=edge.destination.field - ) + destination=EdgeConnection(node_id=new_graph_node_path, field=edge.destination.field), ) ) @@ -512,16 +501,12 @@ class Graph(BaseModel): ) graph.add_edge( Edge( - source=EdgeConnection( - node_id=new_graph_node_path, field=edge.source.field - ), - destination=edge.destination + source=EdgeConnection(node_id=new_graph_node_path, field=edge.source.field), + destination=edge.destination, ) ) - def _get_input_edges( - self, node_path: str, field: Optional[str] = None - ) -> list[Edge]: + def _get_input_edges(self, node_path: str, field: Optional[str] = None) -> list[Edge]: """Gets all input edges for a node""" edges = self._get_input_edges_and_graphs(node_path) @@ -538,7 +523,7 @@ class Graph(BaseModel): destination=EdgeConnection( node_id=self._get_node_path(e.destination.node_id, prefix=prefix), field=e.destination.field, - ) + ), ) for _, prefix, e in filtered_edges ] @@ -550,32 +535,20 @@ class Graph(BaseModel): edges = list() # Return any input edges that appear in this graph - edges.extend( - [(self, prefix, e) for e in self.edges if e.destination.node_id == node_path] - ) + edges.extend([(self, prefix, e) for e in self.edges if e.destination.node_id == node_path]) - node_id = ( - node_path if "." not in node_path else node_path[: node_path.index(".")] - ) + node_id = node_path if "." not in node_path else node_path[: node_path.index(".")] node = self.nodes[node_id] if isinstance(node, GraphInvocation): graph = node.graph - graph_path = ( - node.id - if prefix is None or prefix == "" - else self._get_node_path(node.id, prefix=prefix) - ) - graph_edges = graph._get_input_edges_and_graphs( - node_path[(len(node_id) + 1) :], prefix=graph_path - ) + graph_path = node.id if prefix is None or prefix == "" else self._get_node_path(node.id, prefix=prefix) + graph_edges = graph._get_input_edges_and_graphs(node_path[(len(node_id) + 1) :], prefix=graph_path) edges.extend(graph_edges) return edges - def _get_output_edges( - self, node_path: str, field: str - ) -> list[Edge]: + def _get_output_edges(self, node_path: str, field: str) -> list[Edge]: """Gets all output edges for a node""" edges = self._get_output_edges_and_graphs(node_path) @@ -592,7 +565,7 @@ class Graph(BaseModel): destination=EdgeConnection( node_id=self._get_node_path(e.destination.node_id, prefix=prefix), field=e.destination.field, - ) + ), ) for _, prefix, e in filtered_edges ] @@ -604,25 +577,15 @@ class Graph(BaseModel): edges = list() # Return any input edges that appear in this graph - edges.extend( - [(self, prefix, e) for e in self.edges if e.source.node_id == node_path] - ) + edges.extend([(self, prefix, e) for e in self.edges if e.source.node_id == node_path]) - node_id = ( - node_path if "." not in node_path else node_path[: node_path.index(".")] - ) + node_id = node_path if "." not in node_path else node_path[: node_path.index(".")] node = self.nodes[node_id] if isinstance(node, GraphInvocation): graph = node.graph - graph_path = ( - node.id - if prefix is None or prefix == "" - else self._get_node_path(node.id, prefix=prefix) - ) - graph_edges = graph._get_output_edges_and_graphs( - node_path[(len(node_id) + 1) :], prefix=graph_path - ) + graph_path = node.id if prefix is None or prefix == "" else self._get_node_path(node.id, prefix=prefix) + graph_edges = graph._get_output_edges_and_graphs(node_path[(len(node_id) + 1) :], prefix=graph_path) edges.extend(graph_edges) return edges @@ -646,12 +609,8 @@ class Graph(BaseModel): return False # Get input and output fields (the fields linked to the iterator's input/output) - input_field = get_output_field( - self.get_node(inputs[0].node_id), inputs[0].field - ) - output_fields = list( - [get_input_field(self.get_node(e.node_id), e.field) for e in outputs] - ) + input_field = get_output_field(self.get_node(inputs[0].node_id), inputs[0].field) + output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs]) # Input type must be a list if get_origin(input_field) != list: @@ -659,12 +618,7 @@ class Graph(BaseModel): # Validate that all outputs match the input type input_field_item_type = get_args(input_field)[0] - if not all( - ( - are_connection_types_compatible(input_field_item_type, f) - for f in output_fields - ) - ): + if not all((are_connection_types_compatible(input_field_item_type, f) for f in output_fields)): return False return True @@ -684,35 +638,21 @@ class Graph(BaseModel): outputs.append(new_output) # Get input and output fields (the fields linked to the iterator's input/output) - input_fields = list( - [get_output_field(self.get_node(e.node_id), e.field) for e in inputs] - ) - output_fields = list( - [get_input_field(self.get_node(e.node_id), e.field) for e in outputs] - ) + input_fields = list([get_output_field(self.get_node(e.node_id), e.field) for e in inputs]) + output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs]) # Validate that all inputs are derived from or match a single type input_field_types = set( [ t for input_field in input_fields - for t in ( - [input_field] - if get_origin(input_field) == None - else get_args(input_field) - ) + for t in ([input_field] if get_origin(input_field) == None else get_args(input_field)) if t != NoneType ] ) # Get unique types type_tree = nx.DiGraph() type_tree.add_nodes_from(input_field_types) - type_tree.add_edges_from( - [ - e - for e in itertools.permutations(input_field_types, 2) - if issubclass(e[1], e[0]) - ] - ) + type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])]) type_degrees = type_tree.in_degree(type_tree.nodes) if sum((t[1] == 0 for t in type_degrees)) != 1: # type: ignore return False # There is more than one root type @@ -729,9 +669,7 @@ class Graph(BaseModel): return False # Verify that all outputs match the input type (are a base class or the same class) - if not all( - (issubclass(input_root_type, get_args(f)[0]) for f in output_fields) - ): + if not all((issubclass(input_root_type, get_args(f)[0]) for f in output_fields)): return False return True @@ -751,9 +689,7 @@ class Graph(BaseModel): g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges])) return g - def nx_graph_flat( - self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None - ) -> nx.DiGraph: + def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None) -> nx.DiGraph: """Returns a flattened NetworkX DiGraph, including all subgraphs (but not with iterations expanded)""" g = nx_graph or nx.DiGraph() @@ -762,26 +698,18 @@ class Graph(BaseModel): [ self._get_node_path(n.id, prefix) for n in self.nodes.values() - if not isinstance(n, GraphInvocation) - and not isinstance(n, IterateInvocation) + if not isinstance(n, GraphInvocation) and not isinstance(n, IterateInvocation) ] ) # Expand graph nodes - for sgn in ( - gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation) - ): + for sgn in (gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)): g = sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix)) # TODO: figure out if iteration nodes need to be expanded unique_edges = set([(e.source.node_id, e.destination.node_id) for e in self.edges]) - g.add_edges_from( - [ - (self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) - for e in unique_edges - ] - ) + g.add_edges_from([(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) for e in unique_edges]) return g @@ -800,23 +728,19 @@ class GraphExecutionState(BaseModel): ) # Nodes that have been executed - executed: set[str] = Field( - description="The set of node ids that have been executed", default_factory=set - ) + executed: set[str] = Field(description="The set of node ids that have been executed", default_factory=set) executed_history: list[str] = Field( description="The list of node ids that have been executed, in order of execution", default_factory=list, ) # The results of executed nodes - results: dict[ - str, Annotated[InvocationOutputsUnion, Field(discriminator="type")] - ] = Field(description="The results of node executions", default_factory=dict) + results: dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]] = Field( + description="The results of node executions", default_factory=dict + ) # Errors raised when executing nodes - errors: dict[str, str] = Field( - description="Errors raised when executing nodes", default_factory=dict - ) + errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict) # Map of prepared/executed nodes to their original nodes prepared_source_mapping: dict[str, str] = Field( @@ -832,16 +756,16 @@ class GraphExecutionState(BaseModel): class Config: schema_extra = { - 'required': [ - 'id', - 'graph', - 'execution_graph', - 'executed', - 'executed_history', - 'results', - 'errors', - 'prepared_source_mapping', - 'source_prepared_mapping', + "required": [ + "id", + "graph", + "execution_graph", + "executed", + "executed_history", + "results", + "errors", + "prepared_source_mapping", + "source_prepared_mapping", ] } @@ -899,9 +823,7 @@ class GraphExecutionState(BaseModel): """Returns true if the graph has any errors""" return len(self.errors) > 0 - def _create_execution_node( - self, node_path: str, iteration_node_map: list[tuple[str, str]] - ) -> list[str]: + def _create_execution_node(self, node_path: str, iteration_node_map: list[tuple[str, str]]) -> list[str]: """Prepares an iteration node and connects all edges, returning the new node id""" node = self.graph.get_node(node_path) @@ -911,20 +833,12 @@ class GraphExecutionState(BaseModel): # If this is an iterator node, we must create a copy for each iteration if isinstance(node, IterateInvocation): # Get input collection edge (should error if there are no inputs) - input_collection_edge = next( - iter(self.graph._get_input_edges(node_path, "collection")) - ) + input_collection_edge = next(iter(self.graph._get_input_edges(node_path, "collection"))) input_collection_prepared_node_id = next( - n[1] - for n in iteration_node_map - if n[0] == input_collection_edge.source.node_id - ) - input_collection_prepared_node_output = self.results[ - input_collection_prepared_node_id - ] - input_collection = getattr( - input_collection_prepared_node_output, input_collection_edge.source.field + n[1] for n in iteration_node_map if n[0] == input_collection_edge.source.node_id ) + input_collection_prepared_node_output = self.results[input_collection_prepared_node_id] + input_collection = getattr(input_collection_prepared_node_output, input_collection_edge.source.field) self_iteration_count = len(input_collection) new_nodes = list() @@ -939,9 +853,7 @@ class GraphExecutionState(BaseModel): # For collect nodes, this may contain multiple inputs to the same field new_edges = list() for edge in input_edges: - for input_node_id in ( - n[1] for n in iteration_node_map if n[0] == edge.source.node_id - ): + for input_node_id in (n[1] for n in iteration_node_map if n[0] == edge.source.node_id): new_edge = Edge( source=EdgeConnection(node_id=input_node_id, field=edge.source.field), destination=EdgeConnection(node_id="", field=edge.destination.field), @@ -982,11 +894,7 @@ class GraphExecutionState(BaseModel): def _iterator_graph(self) -> nx.DiGraph: """Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node""" g = self.graph.nx_graph_flat() - collectors = ( - n - for n in self.graph.nodes - if isinstance(self.graph.get_node(n), CollectInvocation) - ) + collectors = (n for n in self.graph.nodes if isinstance(self.graph.get_node(n), CollectInvocation)) for c in collectors: g.remove_edges_from(list(g.in_edges(c))) return g @@ -994,11 +902,7 @@ class GraphExecutionState(BaseModel): def _get_node_iterators(self, node_id: str) -> list[str]: """Gets iterators for a node""" g = self._iterator_graph() - iterators = [ - n - for n in nx.ancestors(g, node_id) - if isinstance(self.graph.get_node(n), IterateInvocation) - ] + iterators = [n for n in nx.ancestors(g, node_id) if isinstance(self.graph.get_node(n), IterateInvocation)] return iterators def _prepare(self) -> Optional[str]: @@ -1045,29 +949,18 @@ class GraphExecutionState(BaseModel): if isinstance(next_node, CollectInvocation): # Collapse all iterator input mappings and create a single execution node for the collect invocation all_iteration_mappings = list( - itertools.chain( - *( - ((s, p) for p in self.source_prepared_mapping[s]) - for s in next_node_parents - ) - ) + itertools.chain(*(((s, p) for p in self.source_prepared_mapping[s]) for s in next_node_parents)) ) # all_iteration_mappings = list(set(itertools.chain(*prepared_parent_mappings))) - create_results = self._create_execution_node( - next_node_id, all_iteration_mappings - ) + create_results = self._create_execution_node(next_node_id, all_iteration_mappings) if create_results is not None: new_node_ids.extend(create_results) else: # Iterators or normal nodes # Get all iterator combinations for this node # Will produce a list of lists of prepared iterator nodes, from which results can be iterated iterator_nodes = self._get_node_iterators(next_node_id) - iterator_nodes_prepared = [ - list(self.source_prepared_mapping[n]) for n in iterator_nodes - ] - iterator_node_prepared_combinations = list( - itertools.product(*iterator_nodes_prepared) - ) + iterator_nodes_prepared = [list(self.source_prepared_mapping[n]) for n in iterator_nodes] + iterator_node_prepared_combinations = list(itertools.product(*iterator_nodes_prepared)) # Select the correct prepared parents for each iteration # For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator @@ -1096,31 +989,16 @@ class GraphExecutionState(BaseModel): return next(iter(prepared_nodes)) # Check if the requested node is an iterator - prepared_iterator = next( - (n for n in prepared_nodes if n in prepared_iterator_nodes), None - ) + prepared_iterator = next((n for n in prepared_nodes if n in prepared_iterator_nodes), None) if prepared_iterator is not None: return prepared_iterator # Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source) - iterator_source_node_mapping = [ - (n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes - ] - parent_iterators = [ - itn - for itn in iterator_source_node_mapping - if nx.has_path(graph, itn[1], source_node_path) - ] + iterator_source_node_mapping = [(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes] + parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_path)] return next( - ( - n - for n in prepared_nodes - if all( - nx.has_path(execution_graph, pit[0], n) - for pit in parent_iterators - ) - ), + (n for n in prepared_nodes if all(nx.has_path(execution_graph, pit[0], n) for pit in parent_iterators)), None, ) @@ -1130,13 +1008,13 @@ class GraphExecutionState(BaseModel): # Depth-first search with pre-order traversal is a depth-first topological sort sorted_nodes = nx.dfs_preorder_nodes(g) - + next_node = next( ( n for n in sorted_nodes - if n not in self.executed # the node must not already be executed... - and all((e[0] in self.executed for e in g.in_edges(n))) # ...and all its inputs must be executed + if n not in self.executed # the node must not already be executed... + and all((e[0] in self.executed for e in g.in_edges(n))) # ...and all its inputs must be executed ), None, ) @@ -1221,15 +1099,18 @@ class ExposedNodeOutput(BaseModel): field: str = Field(description="The field name of the output") alias: str = Field(description="The alias of the output") + class LibraryGraph(BaseModel): id: str = Field(description="The unique identifier for this library graph", default_factory=uuid.uuid4) graph: Graph = Field(description="The graph") name: str = Field(description="The name of the graph") description: str = Field(description="The description of the graph") exposed_inputs: list[ExposedNodeInput] = Field(description="The inputs exposed by this graph", default_factory=list) - exposed_outputs: list[ExposedNodeOutput] = Field(description="The outputs exposed by this graph", default_factory=list) + exposed_outputs: list[ExposedNodeOutput] = Field( + description="The outputs exposed by this graph", default_factory=list + ) - @validator('exposed_inputs', 'exposed_outputs') + @validator("exposed_inputs", "exposed_outputs") def validate_exposed_aliases(cls, v): if len(v) != len(set(i.alias for i in v)): raise ValueError("Duplicate exposed alias") @@ -1237,23 +1118,27 @@ class LibraryGraph(BaseModel): @root_validator def validate_exposed_nodes(cls, values): - graph = values['graph'] + graph = values["graph"] # Validate exposed inputs - for exposed_input in values['exposed_inputs']: + for exposed_input in values["exposed_inputs"]: if not graph.has_node(exposed_input.node_path): raise ValueError(f"Exposed input node {exposed_input.node_path} does not exist") node = graph.get_node(exposed_input.node_path) if get_input_field(node, exposed_input.field) is None: - raise ValueError(f"Exposed input field {exposed_input.field} does not exist on node {exposed_input.node_path}") + raise ValueError( + f"Exposed input field {exposed_input.field} does not exist on node {exposed_input.node_path}" + ) # Validate exposed outputs - for exposed_output in values['exposed_outputs']: + for exposed_output in values["exposed_outputs"]: if not graph.has_node(exposed_output.node_path): raise ValueError(f"Exposed output node {exposed_output.node_path} does not exist") node = graph.get_node(exposed_output.node_path) if get_output_field(node, exposed_output.field) is None: - raise ValueError(f"Exposed output field {exposed_output.field} does not exist on node {exposed_output.node_path}") + raise ValueError( + f"Exposed output field {exposed_output.field} does not exist on node {exposed_output.node_path}" + ) return values diff --git a/invokeai/app/services/image_file_storage.py b/invokeai/app/services/image_file_storage.py index 60ae613748..fb8563a3e4 100644 --- a/invokeai/app/services/image_file_storage.py +++ b/invokeai/app/services/image_file_storage.py @@ -85,9 +85,7 @@ class DiskImageFileStorage(ImageFileStorageBase): self.__cache_ids = Queue() self.__max_cache_size = 10 # TODO: get this from config - self.__output_folder: Path = ( - output_folder if isinstance(output_folder, Path) else Path(output_folder) - ) + self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder) self.__thumbnails_folder = self.__output_folder / "thumbnails" # Validate required output folders at launch @@ -120,7 +118,7 @@ class DiskImageFileStorage(ImageFileStorageBase): image_path = self.get_path(image_name) pnginfo = PngImagePlugin.PngInfo() - + if metadata is not None: pnginfo.add_text("invokeai_metadata", json.dumps(metadata)) if graph is not None: @@ -183,9 +181,7 @@ class DiskImageFileStorage(ImageFileStorageBase): def __set_cache(self, image_name: Path, image: PILImageType): if not image_name in self.__cache: self.__cache[image_name] = image - self.__cache_ids.put( - image_name - ) # TODO: this should refresh position for LRU cache + self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache if len(self.__cache) > self.__max_cache_size: cache_id = self.__cache_ids.get() if cache_id in self.__cache: diff --git a/invokeai/app/services/image_record_storage.py b/invokeai/app/services/image_record_storage.py index eb69679a35..8c274ab8f9 100644 --- a/invokeai/app/services/image_record_storage.py +++ b/invokeai/app/services/image_record_storage.py @@ -426,9 +426,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): finally: self._lock.release() - return OffsetPaginatedResults( - items=images, offset=offset, limit=limit, total=count - ) + return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count) def delete(self, image_name: str) -> None: try: @@ -466,7 +464,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): finally: self._lock.release() - def delete_intermediates(self) -> list[str]: try: self._lock.acquire() @@ -505,9 +502,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): is_intermediate: bool = False, ) -> datetime: try: - metadata_json = ( - None if metadata is None else json.dumps(metadata) - ) + metadata_json = None if metadata is None else json.dumps(metadata) self._lock.acquire() self._cursor.execute( """--sql diff --git a/invokeai/app/services/images.py b/invokeai/app/services/images.py index 007f4021d4..f8376eb626 100644 --- a/invokeai/app/services/images.py +++ b/invokeai/app/services/images.py @@ -217,12 +217,8 @@ class ImageService(ImageServiceABC): session_id=session_id, ) if board_id is not None: - self._services.board_image_records.add_image_to_board( - board_id=board_id, image_name=image_name - ) - self._services.image_files.save( - image_name=image_name, image=image, metadata=metadata, graph=graph - ) + self._services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name) + self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, graph=graph) image_dto = self.get_dto(image_name) return image_dto @@ -297,9 +293,7 @@ class ImageService(ImageServiceABC): if not image_record.session_id: return ImageMetadata() - session_raw = self._services.graph_execution_manager.get_raw( - image_record.session_id - ) + session_raw = self._services.graph_execution_manager.get_raw(image_record.session_id) graph = None if session_raw: @@ -364,9 +358,7 @@ class ImageService(ImageServiceABC): r, self._services.urls.get_image_url(r.image_name), self._services.urls.get_image_url(r.image_name, True), - self._services.board_image_records.get_board_for_image( - r.image_name - ), + self._services.board_image_records.get_board_for_image(r.image_name), ), results.items, ) @@ -398,11 +390,7 @@ class ImageService(ImageServiceABC): def delete_images_on_board(self, board_id: str): try: - image_names = ( - self._services.board_image_records.get_all_board_image_names_for_board( - board_id - ) - ) + image_names = self._services.board_image_records.get_all_board_image_names_for_board(board_id) for image_name in image_names: self._services.image_files.delete(image_name) self._services.image_records.delete_many(image_names) diff --git a/invokeai/app/services/invocation_queue.py b/invokeai/app/services/invocation_queue.py index eb78a542a6..963d500aa8 100644 --- a/invokeai/app/services/invocation_queue.py +++ b/invokeai/app/services/invocation_queue.py @@ -7,6 +7,7 @@ from queue import Queue from pydantic import BaseModel, Field from typing import Optional + class InvocationQueueItem(BaseModel): graph_execution_state_id: str = Field(description="The ID of the graph execution state") invocation_id: str = Field(description="The ID of the node being invoked") @@ -45,9 +46,11 @@ class MemoryInvocationQueue(InvocationQueueABC): def get(self) -> InvocationQueueItem: item = self.__queue.get() - while isinstance(item, InvocationQueueItem) \ - and item.graph_execution_state_id in self.__cancellations \ - and self.__cancellations[item.graph_execution_state_id] > item.timestamp: + while ( + isinstance(item, InvocationQueueItem) + and item.graph_execution_state_id in self.__cancellations + and self.__cancellations[item.graph_execution_state_id] > item.timestamp + ): item = self.__queue.get() # Clear old items diff --git a/invokeai/app/services/invoker.py b/invokeai/app/services/invoker.py index 951d3b17c4..1a7b0de27e 100644 --- a/invokeai/app/services/invoker.py +++ b/invokeai/app/services/invoker.py @@ -7,6 +7,7 @@ from .graph import Graph, GraphExecutionState from .invocation_queue import InvocationQueueItem from .invocation_services import InvocationServices + class Invoker: """The invoker, used to execute invocations""" @@ -16,9 +17,7 @@ class Invoker: self.services = services self._start() - def invoke( - self, graph_execution_state: GraphExecutionState, invoke_all: bool = False - ) -> Optional[str]: + def invoke(self, graph_execution_state: GraphExecutionState, invoke_all: bool = False) -> Optional[str]: """Determines the next node to invoke and enqueues it, preparing if needed. Returns the id of the queued node, or `None` if there are no nodes left to enqueue.""" diff --git a/invokeai/app/services/item_storage.py b/invokeai/app/services/item_storage.py index 709d88bf97..5fe4eb7456 100644 --- a/invokeai/app/services/item_storage.py +++ b/invokeai/app/services/item_storage.py @@ -9,13 +9,15 @@ T = TypeVar("T", bound=BaseModel) class PaginatedResults(GenericModel, Generic[T]): """Paginated results""" - #fmt: off + + # fmt: off items: list[T] = Field(description="Items") page: int = Field(description="Current Page") pages: int = Field(description="Total number of pages") per_page: int = Field(description="Number of items per page") total: int = Field(description="Total number of items in result") - #fmt: on + # fmt: on + class ItemStorageABC(ABC, Generic[T]): _on_changed_callbacks: list[Callable[[T], None]] @@ -48,9 +50,7 @@ class ItemStorageABC(ABC, Generic[T]): pass @abstractmethod - def search( - self, query: str, page: int = 0, per_page: int = 10 - ) -> PaginatedResults[T]: + def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]: pass def on_changed(self, on_changed: Callable[[T], None]) -> None: diff --git a/invokeai/app/services/latent_storage.py b/invokeai/app/services/latent_storage.py index 0e23d6d018..7e781c49ec 100644 --- a/invokeai/app/services/latent_storage.py +++ b/invokeai/app/services/latent_storage.py @@ -7,6 +7,7 @@ from typing import Dict, Union, Optional import torch + class LatentsStorageBase(ABC): """Responsible for storing and retrieving latents.""" @@ -25,7 +26,7 @@ class LatentsStorageBase(ABC): class ForwardCacheLatentsStorage(LatentsStorageBase): """Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage""" - + __cache: Dict[str, torch.Tensor] __cache_ids: Queue __max_cache_size: int @@ -87,8 +88,6 @@ class DiskLatentsStorage(LatentsStorageBase): def delete(self, name: str) -> None: latent_path = self.get_path(name) latent_path.unlink() - def get_path(self, name: str) -> Path: return self.__output_folder / name - diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index f7d3b3a7a7..c84cc3d189 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -103,7 +103,7 @@ class ModelManagerServiceBase(ABC): } """ pass - + @abstractmethod def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict: """ @@ -125,7 +125,7 @@ class ModelManagerServiceBase(ABC): base_model: BaseModelType, model_type: ModelType, model_attributes: dict, - clobber: bool = False + clobber: bool = False, ) -> AddModelResult: """ Update the named model with a dictionary of attributes. Will fail with an @@ -148,12 +148,12 @@ class ModelManagerServiceBase(ABC): Update the named model with a dictionary of attributes. Will fail with a ModelNotFoundException if the name does not already exist. - On a successful update, the config will be changed in memory. Will fail - with an assertion error if provided attributes are incorrect or + On a successful update, the config will be changed in memory. Will fail + with an assertion error if provided attributes are incorrect or the model name is missing. Call commit() to write changes to disk. """ pass - + @abstractmethod def del_model( self, @@ -169,21 +169,20 @@ class ModelManagerServiceBase(ABC): pass @abstractmethod - def rename_model(self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - new_name: str, - ): + def rename_model( + self, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + new_name: str, + ): """ Rename the indicated model. """ pass @abstractmethod - def list_checkpoint_configs( - self - )->List[Path]: + def list_checkpoint_configs(self) -> List[Path]: """ List the checkpoint config paths from ROOT/configs/stable-diffusion. """ @@ -194,7 +193,7 @@ class ModelManagerServiceBase(ABC): self, model_name: str, base_model: BaseModelType, - model_type: Union[ModelType.Main,ModelType.Vae], + model_type: Union[ModelType.Main, ModelType.Vae], ) -> AddModelResult: """ Convert a checkpoint file into a diffusers folder, deleting the cached @@ -211,11 +210,12 @@ class ModelManagerServiceBase(ABC): pass @abstractmethod - def heuristic_import(self, - items_to_import: set[str], - prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None, - )->dict[str, AddModelResult]: - '''Import a list of paths, repo_ids or URLs. Returns the set of + def heuristic_import( + self, + items_to_import: set[str], + prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None, + ) -> dict[str, AddModelResult]: + """Import a list of paths, repo_ids or URLs. Returns the set of successfully imported items. :param items_to_import: Set of strings corresponding to models to be imported. :param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType. @@ -230,19 +230,23 @@ class ModelManagerServiceBase(ABC): The result is a set of successfully installed models. Each element of the set is a dict corresponding to the newly-created OmegaConf stanza for that model. - ''' + """ pass @abstractmethod def merge_models( - self, - model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"), - base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"), - merged_model_name: str = Field(default=None, description="Name of destination model after merging"), - alpha: Optional[float] = 0.5, - interp: Optional[MergeInterpolationMethod] = None, - force: Optional[bool] = False, - merge_dest_directory: Optional[Path] = None + self, + model_names: List[str] = Field( + default=None, min_items=2, max_items=3, description="List of model names to merge" + ), + base_model: Union[BaseModelType, str] = Field( + default=None, description="Base model shared by all models to be merged" + ), + merged_model_name: str = Field(default=None, description="Name of destination model after merging"), + alpha: Optional[float] = 0.5, + interp: Optional[MergeInterpolationMethod] = None, + force: Optional[bool] = False, + merge_dest_directory: Optional[Path] = None, ) -> AddModelResult: """ Merge two to three diffusrs pipeline models and save as a new model. @@ -250,27 +254,27 @@ class ModelManagerServiceBase(ABC): :param base_model: Base model to use for all models :param merged_model_name: Name of destination merged model :param alpha: Alpha strength to apply to 2d and 3d model - :param interp: Interpolation method. None (default) + :param interp: Interpolation method. None (default) :param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended) """ pass @abstractmethod - def search_for_models(self, directory: Path)->List[Path]: + def search_for_models(self, directory: Path) -> List[Path]: """ Return list of all models found in the designated directory. """ pass - + @abstractmethod def sync_to_config(self): """ - Re-read models.yaml, rescan the models directory, and reimport models + Re-read models.yaml, rescan the models directory, and reimport models in the autoimport directories. Call after making changes outside the model manager API. """ pass - + @abstractmethod def commit(self, conf_file: Optional[Path] = None) -> None: """ @@ -280,9 +284,11 @@ class ModelManagerServiceBase(ABC): """ pass + # simple implementation class ModelManagerService(ModelManagerServiceBase): """Responsible for managing models on disk and in memory""" + def __init__( self, config: InvokeAIAppConfig, @@ -298,17 +304,17 @@ class ModelManagerService(ModelManagerServiceBase): config_file = config.model_conf_path else: config_file = config.root_dir / "configs/models.yaml" - - logger.debug(f'Config file={config_file}') + + logger.debug(f"Config file={config_file}") device = torch.device(choose_torch_device()) - device_name = torch.cuda.get_device_name() if device==torch.device('cuda') else '' - logger.info(f'GPU device = {device} {device_name}') + device_name = torch.cuda.get_device_name() if device == torch.device("cuda") else "" + logger.info(f"GPU device = {device} {device_name}") precision = config.precision if precision == "auto": precision = choose_precision(device) - dtype = torch.float32 if precision == 'float32' else torch.float16 + dtype = torch.float32 if precision == "float32" else torch.float16 # this is transitional backward compatibility # support for the deprecated `max_loaded_models` @@ -316,9 +322,7 @@ class ModelManagerService(ModelManagerServiceBase): # cache size is set to 2.5 GB times # the number of max_loaded_models. Otherwise # use new `max_cache_size` config setting - max_cache_size = config.max_cache_size \ - if hasattr(config,'max_cache_size') \ - else config.max_loaded_models * 2.5 + max_cache_size = config.max_cache_size if hasattr(config, "max_cache_size") else config.max_loaded_models * 2.5 logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB") @@ -332,7 +336,7 @@ class ModelManagerService(ModelManagerServiceBase): sequential_offload=sequential_offload, logger=logger, ) - logger.info('Model manager service initialized') + logger.info("Model manager service initialized") def get_model( self, @@ -371,7 +375,7 @@ class ModelManagerService(ModelManagerServiceBase): base_model=base_model, model_type=model_type, submodel=submodel, - model_info=model_info + model_info=model_info, ) return model_info @@ -405,9 +409,7 @@ class ModelManagerService(ModelManagerServiceBase): return self.mgr.model_names() def list_models( - self, - base_model: Optional[BaseModelType] = None, - model_type: Optional[ModelType] = None + self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None ) -> list[dict]: """ Return a list of models. @@ -418,9 +420,7 @@ class ModelManagerService(ModelManagerServiceBase): """ Return information about the model using the same format as list_models() """ - return self.mgr.list_model(model_name=model_name, - base_model=base_model, - model_type=model_type) + return self.mgr.list_model(model_name=model_name, base_model=base_model, model_type=model_type) def add_model( self, @@ -429,7 +429,7 @@ class ModelManagerService(ModelManagerServiceBase): model_type: ModelType, model_attributes: dict, clobber: bool = False, - )->None: + ) -> None: """ Update the named model with a dictionary of attributes. Will fail with an assertion error if the name already exists. Pass clobber=True to overwrite. @@ -437,7 +437,7 @@ class ModelManagerService(ModelManagerServiceBase): with an assertion error if provided attributes are incorrect or the model name is missing. Call commit() to write changes to disk. """ - self.logger.debug(f'add/update model {model_name}') + self.logger.debug(f"add/update model {model_name}") return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber) def update_model( @@ -450,15 +450,15 @@ class ModelManagerService(ModelManagerServiceBase): """ Update the named model with a dictionary of attributes. Will fail with a ModelNotFoundException exception if the name does not already exist. - On a successful update, the config will be changed in memory. Will fail - with an assertion error if provided attributes are incorrect or + On a successful update, the config will be changed in memory. Will fail + with an assertion error if provided attributes are incorrect or the model name is missing. Call commit() to write changes to disk. """ - self.logger.debug(f'update model {model_name}') + self.logger.debug(f"update model {model_name}") if not self.model_exists(model_name, base_model, model_type): raise ModelNotFoundException(f"Unknown model {model_name}") return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True) - + def del_model( self, model_name: str, @@ -470,7 +470,7 @@ class ModelManagerService(ModelManagerServiceBase): then the underlying weight file or diffusers directory will be deleted as well. """ - self.logger.debug(f'delete model {model_name}') + self.logger.debug(f"delete model {model_name}") self.mgr.del_model(model_name, base_model, model_type) self.mgr.commit() @@ -478,8 +478,10 @@ class ModelManagerService(ModelManagerServiceBase): self, model_name: str, base_model: BaseModelType, - model_type: Union[ModelType.Main,ModelType.Vae], - convert_dest_directory: Optional[Path] = Field(default=None, description="Optional directory location for merged model"), + model_type: Union[ModelType.Main, ModelType.Vae], + convert_dest_directory: Optional[Path] = Field( + default=None, description="Optional directory location for merged model" + ), ) -> AddModelResult: """ Convert a checkpoint file into a diffusers folder, deleting the cached @@ -494,10 +496,10 @@ class ModelManagerService(ModelManagerServiceBase): also raise a ValueError in the event that there is a similarly-named diffusers directory already in place. """ - self.logger.debug(f'convert model {model_name}') + self.logger.debug(f"convert model {model_name}") return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory) - def commit(self, conf_file: Optional[Path]=None): + def commit(self, conf_file: Optional[Path] = None): """ Write current configuration out to the indicated file. If no conf_file is provided, then replaces the @@ -524,7 +526,7 @@ class ModelManagerService(ModelManagerServiceBase): base_model=base_model, model_type=model_type, submodel=submodel, - model_info=model_info + model_info=model_info, ) else: context.services.events.emit_model_load_started( @@ -535,16 +537,16 @@ class ModelManagerService(ModelManagerServiceBase): submodel=submodel, ) - @property def logger(self): return self.mgr.logger - def heuristic_import(self, - items_to_import: set[str], - prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None, - )->dict[str, AddModelResult]: - '''Import a list of paths, repo_ids or URLs. Returns the set of + def heuristic_import( + self, + items_to_import: set[str], + prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None, + ) -> dict[str, AddModelResult]: + """Import a list of paths, repo_ids or URLs. Returns the set of successfully imported items. :param items_to_import: Set of strings corresponding to models to be imported. :param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType. @@ -559,18 +561,24 @@ class ModelManagerService(ModelManagerServiceBase): The result is a set of successfully installed models. Each element of the set is a dict corresponding to the newly-created OmegaConf stanza for that model. - ''' - return self.mgr.heuristic_import(items_to_import, prediction_type_helper) + """ + return self.mgr.heuristic_import(items_to_import, prediction_type_helper) def merge_models( - self, - model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"), - base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"), - merged_model_name: str = Field(default=None, description="Name of destination model after merging"), - alpha: Optional[float] = 0.5, - interp: Optional[MergeInterpolationMethod] = None, - force: Optional[bool] = False, - merge_dest_directory: Optional[Path] = Field(default=None, description="Optional directory location for merged model"), + self, + model_names: List[str] = Field( + default=None, min_items=2, max_items=3, description="List of model names to merge" + ), + base_model: Union[BaseModelType, str] = Field( + default=None, description="Base model shared by all models to be merged" + ), + merged_model_name: str = Field(default=None, description="Name of destination model after merging"), + alpha: Optional[float] = 0.5, + interp: Optional[MergeInterpolationMethod] = None, + force: Optional[bool] = False, + merge_dest_directory: Optional[Path] = Field( + default=None, description="Optional directory location for merged model" + ), ) -> AddModelResult: """ Merge two to three diffusrs pipeline models and save as a new model. @@ -578,25 +586,25 @@ class ModelManagerService(ModelManagerServiceBase): :param base_model: Base model to use for all models :param merged_model_name: Name of destination merged model :param alpha: Alpha strength to apply to 2d and 3d model - :param interp: Interpolation method. None (default) + :param interp: Interpolation method. None (default) :param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended) """ merger = ModelMerger(self.mgr) try: result = merger.merge_diffusion_models_and_save( - model_names = model_names, - base_model = base_model, - merged_model_name = merged_model_name, - alpha = alpha, - interp = interp, - force = force, + model_names=model_names, + base_model=base_model, + merged_model_name=merged_model_name, + alpha=alpha, + interp=interp, + force=force, merge_dest_directory=merge_dest_directory, ) except AssertionError as e: raise ValueError(e) return result - def search_for_models(self, directory: Path)->List[Path]: + def search_for_models(self, directory: Path) -> List[Path]: """ Return list of all models found in the designated directory. """ @@ -605,28 +613,29 @@ class ModelManagerService(ModelManagerServiceBase): def sync_to_config(self): """ - Re-read models.yaml, rescan the models directory, and reimport models + Re-read models.yaml, rescan the models directory, and reimport models in the autoimport directories. Call after making changes outside the model manager API. """ return self.mgr.sync_to_config() - def list_checkpoint_configs(self)->List[Path]: + def list_checkpoint_configs(self) -> List[Path]: """ List the checkpoint config paths from ROOT/configs/stable-diffusion. """ config = self.mgr.app_config conf_path = config.legacy_conf_path root_path = config.root_path - return [(conf_path / x).relative_to(root_path) for x in conf_path.glob('**/*.yaml')] + return [(conf_path / x).relative_to(root_path) for x in conf_path.glob("**/*.yaml")] - def rename_model(self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - new_name: str = None, - new_base: BaseModelType = None, - ): + def rename_model( + self, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + new_name: str = None, + new_base: BaseModelType = None, + ): """ Rename the indicated model. Can provide a new name and/or a new base. :param model_name: Current name of the model @@ -635,10 +644,10 @@ class ModelManagerService(ModelManagerServiceBase): :param new_name: New name for the model :param new_base: New base for the model """ - self.mgr.rename_model(base_model = base_model, - model_type = model_type, - model_name = model_name, - new_name = new_name, - new_base = new_base, - ) - + self.mgr.rename_model( + base_model=base_model, + model_type=model_type, + model_name=model_name, + new_name=new_name, + new_base=new_base, + ) diff --git a/invokeai/app/services/models/board_record.py b/invokeai/app/services/models/board_record.py index bf5401b209..658698e794 100644 --- a/invokeai/app/services/models/board_record.py +++ b/invokeai/app/services/models/board_record.py @@ -11,30 +11,20 @@ class BoardRecord(BaseModel): """The unique ID of the board.""" board_name: str = Field(description="The name of the board.") """The name of the board.""" - created_at: Union[datetime, str] = Field( - description="The created timestamp of the board." - ) + created_at: Union[datetime, str] = Field(description="The created timestamp of the board.") """The created timestamp of the image.""" - updated_at: Union[datetime, str] = Field( - description="The updated timestamp of the board." - ) + updated_at: Union[datetime, str] = Field(description="The updated timestamp of the board.") """The updated timestamp of the image.""" - deleted_at: Union[datetime, str, None] = Field( - description="The deleted timestamp of the board." - ) + deleted_at: Union[datetime, str, None] = Field(description="The deleted timestamp of the board.") """The updated timestamp of the image.""" - cover_image_name: Optional[str] = Field( - description="The name of the cover image of the board." - ) + cover_image_name: Optional[str] = Field(description="The name of the cover image of the board.") """The name of the cover image of the board.""" class BoardDTO(BoardRecord): """Deserialized board record with cover image URL and image count.""" - cover_image_name: Optional[str] = Field( - description="The name of the board's cover image." - ) + cover_image_name: Optional[str] = Field(description="The name of the board's cover image.") """The URL of the thumbnail of the most recent image in the board.""" image_count: int = Field(description="The number of images in the board.") """The number of images in the board.""" diff --git a/invokeai/app/services/models/image_record.py b/invokeai/app/services/models/image_record.py index cf10f6e8b2..a105d03ba8 100644 --- a/invokeai/app/services/models/image_record.py +++ b/invokeai/app/services/models/image_record.py @@ -20,17 +20,11 @@ class ImageRecord(BaseModel): """The actual width of the image in px. This may be different from the width in metadata.""" height: int = Field(description="The height of the image in px.") """The actual height of the image in px. This may be different from the height in metadata.""" - created_at: Union[datetime.datetime, str] = Field( - description="The created timestamp of the image." - ) + created_at: Union[datetime.datetime, str] = Field(description="The created timestamp of the image.") """The created timestamp of the image.""" - updated_at: Union[datetime.datetime, str] = Field( - description="The updated timestamp of the image." - ) + updated_at: Union[datetime.datetime, str] = Field(description="The updated timestamp of the image.") """The updated timestamp of the image.""" - deleted_at: Union[datetime.datetime, str, None] = Field( - description="The deleted timestamp of the image." - ) + deleted_at: Union[datetime.datetime, str, None] = Field(description="The deleted timestamp of the image.") """The deleted timestamp of the image.""" is_intermediate: bool = Field(description="Whether this is an intermediate image.") """Whether this is an intermediate image.""" @@ -55,18 +49,14 @@ class ImageRecordChanges(BaseModel, extra=Extra.forbid): - `is_intermediate`: change the image's `is_intermediate` flag """ - image_category: Optional[ImageCategory] = Field( - description="The image's new category." - ) + image_category: Optional[ImageCategory] = Field(description="The image's new category.") """The image's new category.""" session_id: Optional[StrictStr] = Field( default=None, description="The image's new session ID.", ) """The image's new session ID.""" - is_intermediate: Optional[StrictBool] = Field( - default=None, description="The image's new `is_intermediate` flag." - ) + is_intermediate: Optional[StrictBool] = Field(default=None, description="The image's new `is_intermediate` flag.") """The image's new `is_intermediate` flag.""" @@ -84,9 +74,7 @@ class ImageUrlsDTO(BaseModel): class ImageDTO(ImageRecord, ImageUrlsDTO): """Deserialized image record, enriched for the frontend.""" - board_id: Optional[str] = Field( - description="The id of the board the image belongs to, if one exists." - ) + board_id: Optional[str] = Field(description="The id of the board the image belongs to, if one exists.") """The id of the board the image belongs to, if one exists.""" pass @@ -110,12 +98,8 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord: # TODO: do we really need to handle default values here? ideally the data is the correct shape... image_name = image_dict.get("image_name", "unknown") - image_origin = ResourceOrigin( - image_dict.get("image_origin", ResourceOrigin.INTERNAL.value) - ) - image_category = ImageCategory( - image_dict.get("image_category", ImageCategory.GENERAL.value) - ) + image_origin = ResourceOrigin(image_dict.get("image_origin", ResourceOrigin.INTERNAL.value)) + image_category = ImageCategory(image_dict.get("image_category", ImageCategory.GENERAL.value)) width = image_dict.get("width", 0) height = image_dict.get("height", 0) session_id = image_dict.get("session_id", None) diff --git a/invokeai/app/services/processor.py b/invokeai/app/services/processor.py index 5995e4ffc3..50fe217e05 100644 --- a/invokeai/app/services/processor.py +++ b/invokeai/app/services/processor.py @@ -8,6 +8,8 @@ from .invoker import InvocationProcessorABC, Invoker from ..models.exceptions import CanceledException import invokeai.backend.util.logging as logger + + class DefaultInvocationProcessor(InvocationProcessorABC): __invoker_thread: Thread __stop_event: Event @@ -24,9 +26,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC): target=self.__process, kwargs=dict(stop_event=self.__stop_event), ) - self.__invoker_thread.daemon = ( - True # TODO: make async and do not use threads - ) + self.__invoker_thread.daemon = True # TODO: make async and do not use threads self.__invoker_thread.start() def stop(self, *args, **kwargs) -> None: @@ -47,10 +47,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC): continue try: - graph_execution_state = ( - self.__invoker.services.graph_execution_manager.get( - queue_item.graph_execution_state_id - ) + graph_execution_state = self.__invoker.services.graph_execution_manager.get( + queue_item.graph_execution_state_id ) except Exception as e: self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e) @@ -60,11 +58,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC): error=traceback.format_exc(), ) continue - + try: - invocation = graph_execution_state.execution_graph.get_node( - queue_item.invocation_id - ) + invocation = graph_execution_state.execution_graph.get_node(queue_item.invocation_id) except Exception as e: self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e) self.__invoker.services.events.emit_invocation_retrieval_error( @@ -82,7 +78,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC): self.__invoker.services.events.emit_invocation_started( graph_execution_state_id=graph_execution_state.id, node=invocation.dict(), - source_node_id=source_node_id + source_node_id=source_node_id, ) # Invoke @@ -95,18 +91,14 @@ class DefaultInvocationProcessor(InvocationProcessorABC): ) # Check queue to see if this is canceled, and skip if so - if self.__invoker.services.queue.is_canceled( - graph_execution_state.id - ): + if self.__invoker.services.queue.is_canceled(graph_execution_state.id): continue # Save outputs and history graph_execution_state.complete(invocation.id, outputs) # Save the state changes - self.__invoker.services.graph_execution_manager.set( - graph_execution_state - ) + self.__invoker.services.graph_execution_manager.set(graph_execution_state) # Send complete event self.__invoker.services.events.emit_invocation_complete( @@ -130,9 +122,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC): graph_execution_state.set_node_error(invocation.id, error) # Save the state changes - self.__invoker.services.graph_execution_manager.set( - graph_execution_state - ) + self.__invoker.services.graph_execution_manager.set(graph_execution_state) self.__invoker.services.logger.error("Error while invoking:\n%s" % e) # Send error event @@ -147,9 +137,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC): pass # Check queue to see if this is canceled, and skip if so - if self.__invoker.services.queue.is_canceled( - graph_execution_state.id - ): + if self.__invoker.services.queue.is_canceled(graph_execution_state.id): continue # Queue any further commands if invoking all @@ -164,12 +152,10 @@ class DefaultInvocationProcessor(InvocationProcessorABC): node=invocation.dict(), source_node_id=source_node_id, error_type=e.__class__.__name__, - error=traceback.format_exc() + error=traceback.format_exc(), ) elif is_complete: - self.__invoker.services.events.emit_graph_execution_complete( - graph_execution_state.id - ) + self.__invoker.services.events.emit_graph_execution_complete(graph_execution_state.id) except KeyboardInterrupt: pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor diff --git a/invokeai/app/services/sqlite.py b/invokeai/app/services/sqlite.py index 8902415096..855f3f1939 100644 --- a/invokeai/app/services/sqlite.py +++ b/invokeai/app/services/sqlite.py @@ -66,9 +66,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]): def get(self, id: str) -> Optional[T]: try: self._lock.acquire() - self._cursor.execute( - f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),) - ) + self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)) result = self._cursor.fetchone() finally: self._lock.release() @@ -81,9 +79,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]): def get_raw(self, id: str) -> Optional[str]: try: self._lock.acquire() - self._cursor.execute( - f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),) - ) + self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)) result = self._cursor.fetchone() finally: self._lock.release() @@ -96,9 +92,7 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]): def delete(self, id: str): try: self._lock.acquire() - self._cursor.execute( - f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),) - ) + self._cursor.execute(f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),)) self._conn.commit() finally: self._lock.release() @@ -122,13 +116,9 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]): pageCount = int(count / per_page) + 1 - return PaginatedResults[T]( - items=items, page=page, pages=pageCount, per_page=per_page, total=count - ) + return PaginatedResults[T](items=items, page=page, pages=pageCount, per_page=per_page, total=count) - def search( - self, query: str, page: int = 0, per_page: int = 10 - ) -> PaginatedResults[T]: + def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]: try: self._lock.acquire() self._cursor.execute( @@ -149,6 +139,4 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]): pageCount = int(count / per_page) + 1 - return PaginatedResults[T]( - items=items, page=page, pages=pageCount, per_page=per_page, total=count - ) + return PaginatedResults[T](items=items, page=page, pages=pageCount, per_page=per_page, total=count) diff --git a/invokeai/app/util/controlnet_utils.py b/invokeai/app/util/controlnet_utils.py index 342fa147c5..6f2c49fec2 100644 --- a/invokeai/app/util/controlnet_utils.py +++ b/invokeai/app/util/controlnet_utils.py @@ -17,16 +17,8 @@ from controlnet_aux.util import HWC3, resize_image # If you use this, please Cite "High Quality Edge Thinning using Pure Python", Lvmin Zhang, In Mikubill/sd-webui-controlnet. lvmin_kernels_raw = [ - np.array([ - [-1, -1, -1], - [0, 1, 0], - [1, 1, 1] - ], dtype=np.int32), - np.array([ - [0, -1, -1], - [1, 1, -1], - [0, 1, 0] - ], dtype=np.int32) + np.array([[-1, -1, -1], [0, 1, 0], [1, 1, 1]], dtype=np.int32), + np.array([[0, -1, -1], [1, 1, -1], [0, 1, 0]], dtype=np.int32), ] lvmin_kernels = [] @@ -36,16 +28,8 @@ lvmin_kernels += [np.rot90(x, k=2, axes=(0, 1)) for x in lvmin_kernels_raw] lvmin_kernels += [np.rot90(x, k=3, axes=(0, 1)) for x in lvmin_kernels_raw] lvmin_prunings_raw = [ - np.array([ - [-1, -1, -1], - [-1, 1, -1], - [0, 0, -1] - ], dtype=np.int32), - np.array([ - [-1, -1, -1], - [-1, 1, -1], - [-1, 0, 0] - ], dtype=np.int32) + np.array([[-1, -1, -1], [-1, 1, -1], [0, 0, -1]], dtype=np.int32), + np.array([[-1, -1, -1], [-1, 1, -1], [-1, 0, 0]], dtype=np.int32), ] lvmin_prunings = [] @@ -99,10 +83,10 @@ def nake_nms(x): ################################################################################ # FIXME: not using yet, if used in the future will most likely require modification of preprocessors def pixel_perfect_resolution( - image: np.ndarray, - target_H: int, - target_W: int, - resize_mode: str, + image: np.ndarray, + target_H: int, + target_W: int, + resize_mode: str, ) -> int: """ Calculate the estimated resolution for resizing an image while preserving aspect ratio. @@ -135,7 +119,7 @@ def pixel_perfect_resolution( if resize_mode == "fill_resize": estimation = min(k0, k1) * float(min(raw_H, raw_W)) - else: # "crop_resize" or "just_resize" (or possibly "just_resize_simple"?) + else: # "crop_resize" or "just_resize" (or possibly "just_resize_simple"?) estimation = max(k0, k1) * float(min(raw_H, raw_W)) # print(f"Pixel Perfect Computation:") @@ -154,13 +138,7 @@ def pixel_perfect_resolution( # modified for InvokeAI ########################################################################### # def detectmap_proc(detected_map, module, resize_mode, h, w): -def np_img_resize( - np_img: np.ndarray, - resize_mode: str, - h: int, - w: int, - device: torch.device = torch.device('cpu') -): +def np_img_resize(np_img: np.ndarray, resize_mode: str, h: int, w: int, device: torch.device = torch.device("cpu")): # if 'inpaint' in module: # np_img = np_img.astype(np.float32) # else: @@ -184,15 +162,14 @@ def np_img_resize( # below is very boring but do not change these. If you change these Apple or Mac may fail. y = torch.from_numpy(y) y = y.float() / 255.0 - y = rearrange(y, 'h w c -> 1 c h w') + y = rearrange(y, "h w c -> 1 c h w") y = y.clone() # y = y.to(devices.get_device_for("controlnet")) y = y.to(device) y = y.clone() return y - def high_quality_resize(x: np.ndarray, - size): + def high_quality_resize(x: np.ndarray, size): # Written by lvmin # Super high-quality control map up-scaling, considering binary, seg, and one-pixel edges inpaint_mask = None @@ -244,7 +221,7 @@ def np_img_resize( return y # if resize_mode == external_code.ResizeMode.RESIZE: - if resize_mode == "just_resize": # RESIZE + if resize_mode == "just_resize": # RESIZE np_img = high_quality_resize(np_img, (w, h)) np_img = safe_numpy(np_img) return get_pytorch_control(np_img), np_img @@ -270,20 +247,21 @@ def np_img_resize( new_h, new_w, _ = np_img.shape pad_h = max(0, (h - new_h) // 2) pad_w = max(0, (w - new_w) // 2) - high_quality_background[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = np_img + high_quality_background[pad_h : pad_h + new_h, pad_w : pad_w + new_w] = np_img np_img = high_quality_background np_img = safe_numpy(np_img) return get_pytorch_control(np_img), np_img - else: # resize_mode == "crop_resize" (INNER_FIT) + else: # resize_mode == "crop_resize" (INNER_FIT) k = max(k0, k1) np_img = high_quality_resize(np_img, (safeint(old_w * k), safeint(old_h * k))) new_h, new_w, _ = np_img.shape pad_h = max(0, (new_h - h) // 2) pad_w = max(0, (new_w - w) // 2) - np_img = np_img[pad_h:pad_h + h, pad_w:pad_w + w] + np_img = np_img[pad_h : pad_h + h, pad_w : pad_w + w] np_img = safe_numpy(np_img) return get_pytorch_control(np_img), np_img + def prepare_control_image( # image used to be Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor, List[torch.Tensor]] # but now should be able to assume that image is a single PIL.Image, which simplifies things @@ -301,15 +279,17 @@ def prepare_control_image( resize_mode="just_resize_simple", ): # FIXME: implement "crop_resize_simple" and "fill_resize_simple", or pull them out - if (resize_mode == "just_resize_simple" or - resize_mode == "crop_resize_simple" or - resize_mode == "fill_resize_simple"): + if ( + resize_mode == "just_resize_simple" + or resize_mode == "crop_resize_simple" + or resize_mode == "fill_resize_simple" + ): image = image.convert("RGB") - if (resize_mode == "just_resize_simple"): + if resize_mode == "just_resize_simple": image = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]) - elif (resize_mode == "crop_resize_simple"): # not yet implemented + elif resize_mode == "crop_resize_simple": # not yet implemented pass - elif (resize_mode == "fill_resize_simple"): # not yet implemented + elif resize_mode == "fill_resize_simple": # not yet implemented pass nimage = np.array(image) nimage = nimage[None, :] @@ -320,7 +300,7 @@ def prepare_control_image( timage = torch.from_numpy(nimage) # use fancy lvmin controlnet resizing - elif (resize_mode == "just_resize" or resize_mode == "crop_resize" or resize_mode == "fill_resize"): + elif resize_mode == "just_resize" or resize_mode == "crop_resize" or resize_mode == "fill_resize": nimage = np.array(image) timage, nimage = np_img_resize( np_img=nimage, @@ -336,7 +316,7 @@ def prepare_control_image( exit(1) timage = timage.to(device=device, dtype=dtype) - cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced") + cfg_injection = control_mode == "more_control" or control_mode == "unbalanced" if do_classifier_free_guidance and not cfg_injection: timage = torch.cat([timage] * 2) return timage diff --git a/invokeai/app/util/step_callback.py b/invokeai/app/util/step_callback.py index 1e8939b0bf..994d83e705 100644 --- a/invokeai/app/util/step_callback.py +++ b/invokeai/app/util/step_callback.py @@ -9,19 +9,16 @@ from ...backend.stable_diffusion import PipelineIntermediateState from invokeai.app.services.config import InvokeAIAppConfig -def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix = None): +def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None): latent_image = samples[0].permute(1, 2, 0) @ latent_rgb_factors if smooth_matrix is not None: latent_image = latent_image.unsqueeze(0).permute(3, 0, 1, 2) - latent_image = torch.nn.functional.conv2d(latent_image, smooth_matrix.reshape((1,1,3,3)), padding=1) + latent_image = torch.nn.functional.conv2d(latent_image, smooth_matrix.reshape((1, 1, 3, 3)), padding=1) latent_image = latent_image.permute(1, 2, 3, 0).squeeze(0) latents_ubyte = ( - ((latent_image + 1) / 2) - .clamp(0, 1) # change scale from -1..1 to 0..1 - .mul(0xFF) # to 0..255 - .byte() + ((latent_image + 1) / 2).clamp(0, 1).mul(0xFF).byte() # change scale from -1..1 to 0..1 # to 0..255 ).cpu() return Image.fromarray(latents_ubyte.numpy()) @@ -92,6 +89,7 @@ def stable_diffusion_step_callback( total_steps=node["steps"], ) + def stable_diffusion_xl_step_callback( context: InvocationContext, node: dict, @@ -106,9 +104,9 @@ def stable_diffusion_xl_step_callback( sdxl_latent_rgb_factors = torch.tensor( [ # R G B - [ 0.3816, 0.4930, 0.5320], - [-0.3753, 0.1631, 0.1739], - [ 0.1770, 0.3588, -0.2048], + [0.3816, 0.4930, 0.5320], + [-0.3753, 0.1631, 0.1739], + [0.1770, 0.3588, -0.2048], [-0.4350, -0.2644, -0.4289], ], dtype=sample.dtype, @@ -117,9 +115,9 @@ def stable_diffusion_xl_step_callback( sdxl_smooth_matrix = torch.tensor( [ - #[ 0.0478, 0.1285, 0.0478], - #[ 0.1285, 0.2948, 0.1285], - #[ 0.0478, 0.1285, 0.0478], + # [ 0.0478, 0.1285, 0.0478], + # [ 0.1285, 0.2948, 0.1285], + # [ 0.0478, 0.1285, 0.0478], [0.0358, 0.0964, 0.0358], [0.0964, 0.4711, 0.0964], [0.0358, 0.0964, 0.0358], @@ -143,4 +141,4 @@ def stable_diffusion_xl_step_callback( progress_image=ProgressImage(width=width, height=height, dataURL=dataURL), step=step, total_steps=total_steps, - ) \ No newline at end of file + ) diff --git a/invokeai/backend/__init__.py b/invokeai/backend/__init__.py index a33b70ceb4..aa2a1f1ca6 100644 --- a/invokeai/backend/__init__.py +++ b/invokeai/backend/__init__.py @@ -1,15 +1,6 @@ """ Initialization file for invokeai.backend """ -from .generator import ( - InvokeAIGeneratorBasicParams, - InvokeAIGenerator, - InvokeAIGeneratorOutput, - Img2Img, - Inpaint -) -from .model_management import ( - ModelManager, ModelCache, BaseModelType, - ModelType, SubModelType, ModelInfo - ) +from .generator import InvokeAIGeneratorBasicParams, InvokeAIGenerator, InvokeAIGeneratorOutput, Img2Img, Inpaint +from .model_management import ModelManager, ModelCache, BaseModelType, ModelType, SubModelType, ModelInfo from .model_management.models import SilenceWarnings diff --git a/invokeai/backend/generator/base.py b/invokeai/backend/generator/base.py index 8af8d399d5..af3231a7d1 100644 --- a/invokeai/backend/generator/base.py +++ b/invokeai/backend/generator/base.py @@ -33,61 +33,66 @@ from ..stable_diffusion.schedulers import SCHEDULER_MAP downsampling = 8 + @dataclass class InvokeAIGeneratorBasicParams: - seed: Optional[int]=None - width: int=512 - height: int=512 - cfg_scale: float=7.5 - steps: int=20 - ddim_eta: float=0.0 - scheduler: str='ddim' - precision: str='float16' - perlin: float=0.0 - threshold: float=0.0 - seamless: bool=False - seamless_axes: List[str]=field(default_factory=lambda: ['x', 'y']) - h_symmetry_time_pct: Optional[float]=None - v_symmetry_time_pct: Optional[float]=None + seed: Optional[int] = None + width: int = 512 + height: int = 512 + cfg_scale: float = 7.5 + steps: int = 20 + ddim_eta: float = 0.0 + scheduler: str = "ddim" + precision: str = "float16" + perlin: float = 0.0 + threshold: float = 0.0 + seamless: bool = False + seamless_axes: List[str] = field(default_factory=lambda: ["x", "y"]) + h_symmetry_time_pct: Optional[float] = None + v_symmetry_time_pct: Optional[float] = None variation_amount: float = 0.0 - with_variations: list=field(default_factory=list) + with_variations: list = field(default_factory=list) + @dataclass class InvokeAIGeneratorOutput: - ''' + """ InvokeAIGeneratorOutput is a dataclass that contains the outputs of a generation operation, including the image, its seed, the model name used to generate the image and the model hash, as well as all the generate() parameters that went into generating the image (in .params, also available as attributes) - ''' + """ + image: Image.Image seed: int model_hash: str attention_maps_images: List[Image.Image] params: Namespace + # we are interposing a wrapper around the original Generator classes so that # old code that calls Generate will continue to work. class InvokeAIGenerator(metaclass=ABCMeta): - def __init__(self, - model_info: dict, - params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(), - **kwargs, - ): - self.model_info=model_info - self.params=params + def __init__( + self, + model_info: dict, + params: InvokeAIGeneratorBasicParams = InvokeAIGeneratorBasicParams(), + **kwargs, + ): + self.model_info = model_info + self.params = params self.kwargs = kwargs def generate( self, conditioning: tuple, scheduler, - callback: Optional[Callable]=None, - step_callback: Optional[Callable]=None, - iterations: int=1, + callback: Optional[Callable] = None, + step_callback: Optional[Callable] = None, + iterations: int = 1, **keyword_args, - )->Iterator[InvokeAIGeneratorOutput]: - ''' + ) -> Iterator[InvokeAIGeneratorOutput]: + """ Return an iterator across the indicated number of generations. Each time the iterator is called it will return an InvokeAIGeneratorOutput object. Use like this: @@ -107,7 +112,7 @@ class InvokeAIGenerator(metaclass=ABCMeta): for o in outputs: print(o.image, o.seed) - ''' + """ generator_args = dataclasses.asdict(self.params) generator_args.update(keyword_args) @@ -118,22 +123,21 @@ class InvokeAIGenerator(metaclass=ABCMeta): gen_class = self._generator_class() generator = gen_class(model, self.params.precision, **self.kwargs) if self.params.variation_amount > 0: - generator.set_variation(generator_args.get('seed'), - generator_args.get('variation_amount'), - generator_args.get('with_variations') - ) + generator.set_variation( + generator_args.get("seed"), + generator_args.get("variation_amount"), + generator_args.get("with_variations"), + ) if isinstance(model, DiffusionPipeline): for component in [model.unet, model.vae]: - configure_model_padding(component, - generator_args.get('seamless',False), - generator_args.get('seamless_axes') - ) + configure_model_padding( + component, generator_args.get("seamless", False), generator_args.get("seamless_axes") + ) else: - configure_model_padding(model, - generator_args.get('seamless',False), - generator_args.get('seamless_axes') - ) + configure_model_padding( + model, generator_args.get("seamless", False), generator_args.get("seamless_axes") + ) iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1) for i in iteration_count: @@ -147,66 +151,66 @@ class InvokeAIGenerator(metaclass=ABCMeta): image=results[0][0], seed=results[0][1], attention_maps_images=results[0][2], - model_hash = model_hash, - params=Namespace(model_name=model_name,**generator_args), + model_hash=model_hash, + params=Namespace(model_name=model_name, **generator_args), ) if callback: callback(output) yield output @classmethod - def schedulers(self)->List[str]: - ''' + def schedulers(self) -> List[str]: + """ Return list of all the schedulers that we currently handle. - ''' + """ return list(SCHEDULER_MAP.keys()) def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]): return generator_class(model, self.params.precision) @classmethod - def _generator_class(cls)->Type[Generator]: - ''' + def _generator_class(cls) -> Type[Generator]: + """ In derived classes return the name of the generator to apply. If you don't override will return the name of the derived class, which nicely parallels the generator class names. - ''' + """ return Generator + # ------------------------------------ class Img2Img(InvokeAIGenerator): - def generate(self, - init_image: Union[Image.Image, torch.FloatTensor], - strength: float=0.75, - **keyword_args - )->Iterator[InvokeAIGeneratorOutput]: - return super().generate(init_image=init_image, - strength=strength, - **keyword_args - ) + def generate( + self, init_image: Union[Image.Image, torch.FloatTensor], strength: float = 0.75, **keyword_args + ) -> Iterator[InvokeAIGeneratorOutput]: + return super().generate(init_image=init_image, strength=strength, **keyword_args) + @classmethod def _generator_class(cls): from .img2img import Img2Img + return Img2Img + # ------------------------------------ # Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff class Inpaint(Img2Img): - def generate(self, - mask_image: Union[Image.Image, torch.FloatTensor], - # Seam settings - when 0, doesn't fill seam - seam_size: int = 96, - seam_blur: int = 16, - seam_strength: float = 0.7, - seam_steps: int = 30, - tile_size: int = 32, - inpaint_replace=False, - infill_method=None, - inpaint_width=None, - inpaint_height=None, - inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF), - **keyword_args - )->Iterator[InvokeAIGeneratorOutput]: + def generate( + self, + mask_image: Union[Image.Image, torch.FloatTensor], + # Seam settings - when 0, doesn't fill seam + seam_size: int = 96, + seam_blur: int = 16, + seam_strength: float = 0.7, + seam_steps: int = 30, + tile_size: int = 32, + inpaint_replace=False, + infill_method=None, + inpaint_width=None, + inpaint_height=None, + inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF), + **keyword_args, + ) -> Iterator[InvokeAIGeneratorOutput]: return super().generate( mask_image=mask_image, seam_size=seam_size, @@ -219,13 +223,16 @@ class Inpaint(Img2Img): inpaint_width=inpaint_width, inpaint_height=inpaint_height, inpaint_fill=inpaint_fill, - **keyword_args + **keyword_args, ) + @classmethod def _generator_class(cls): from .inpaint import Inpaint + return Inpaint + class Generator: downsampling_factor: int latent_channels: int @@ -251,9 +258,7 @@ class Generator: Returns a function returning an image derived from the prompt and the initial image Return value depends on the seed at the time you call it """ - raise NotImplementedError( - "image_iterator() must be implemented in a descendent class" - ) + raise NotImplementedError("image_iterator() must be implemented in a descendent class") def set_variation(self, seed, variation_amount, with_variations): self.seed = seed @@ -280,9 +285,7 @@ class Generator: scope = nullcontext self.free_gpu_mem = free_gpu_mem attention_maps_images = [] - attention_maps_callback = lambda saver: attention_maps_images.append( - saver.get_stacked_maps_image() - ) + attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image()) make_image = self.get_make_image( sampler=sampler, init_image=init_image, @@ -327,11 +330,7 @@ class Generator: results.append([image, seed, attention_maps_images]) if image_callback is not None: - attention_maps_image = ( - None - if len(attention_maps_images) == 0 - else attention_maps_images[-1] - ) + attention_maps_image = None if len(attention_maps_images) == 0 else attention_maps_images[-1] image_callback( image, seed, @@ -342,9 +341,7 @@ class Generator: seed = self.new_seed() # Free up memory from the last generation. - clear_cuda_cache = ( - kwargs["clear_cuda_cache"] if "clear_cuda_cache" in kwargs else None - ) + clear_cuda_cache = kwargs["clear_cuda_cache"] if "clear_cuda_cache" in kwargs else None if clear_cuda_cache is not None: clear_cuda_cache() @@ -371,14 +368,8 @@ class Generator: # Get the original alpha channel of the mask if there is one. # Otherwise it is some other black/white image format ('1', 'L' or 'RGB') - pil_init_mask = ( - init_mask.getchannel("A") - if init_mask.mode == "RGBA" - else init_mask.convert("L") - ) - pil_init_image = init_image.convert( - "RGBA" - ) # Add an alpha channel if one doesn't exist + pil_init_mask = init_mask.getchannel("A") if init_mask.mode == "RGBA" else init_mask.convert("L") + pil_init_image = init_image.convert("RGBA") # Add an alpha channel if one doesn't exist # Build an image with only visible pixels from source to use as reference for color-matching. init_rgb_pixels = np.asarray(init_image.convert("RGB"), dtype=np.uint8) @@ -404,10 +395,7 @@ class Generator: np_matched_result[:, :, :] = ( ( ( - ( - np_matched_result[:, :, :].astype(np.float32) - - gen_means[None, None, :] - ) + (np_matched_result[:, :, :].astype(np.float32) - gen_means[None, None, :]) / gen_std[None, None, :] ) * init_std[None, None, :] @@ -433,9 +421,7 @@ class Generator: else: blurred_init_mask = pil_init_mask - multiplied_blurred_init_mask = ImageChops.multiply( - blurred_init_mask, self.pil_image.split()[-1] - ) + multiplied_blurred_init_mask = ImageChops.multiply(blurred_init_mask, self.pil_image.split()[-1]) # Paste original on color-corrected generation (using blurred mask) matched_result.paste(init_image, (0, 0), mask=multiplied_blurred_init_mask) @@ -461,10 +447,7 @@ class Generator: latent_image = samples[0].permute(1, 2, 0) @ v1_5_latent_rgb_factors latents_ubyte = ( - ((latent_image + 1) / 2) - .clamp(0, 1) # change scale from -1..1 to 0..1 - .mul(0xFF) # to 0..255 - .byte() + ((latent_image + 1) / 2).clamp(0, 1).mul(0xFF).byte() # change scale from -1..1 to 0..1 # to 0..255 ).cpu() return Image.fromarray(latents_ubyte.numpy()) @@ -494,9 +477,7 @@ class Generator: temp_height = int((height + 7) / 8) * 8 noise = torch.stack( [ - rand_perlin_2d( - (temp_height, temp_width), (8, 8), device=self.model.device - ).to(fixdevice) + rand_perlin_2d((temp_height, temp_width), (8, 8), device=self.model.device).to(fixdevice) for _ in range(input_channels) ], dim=0, @@ -573,8 +554,6 @@ class Generator: device=device, ) if self.perlin > 0.0: - perlin_noise = self.get_perlin_noise( - width // self.downsampling_factor, height // self.downsampling_factor - ) + perlin_noise = self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor) x = (1 - self.perlin) * x + self.perlin * perlin_noise return x diff --git a/invokeai/backend/generator/img2img.py b/invokeai/backend/generator/img2img.py index b3b0e8f510..5490b2325c 100644 --- a/invokeai/backend/generator/img2img.py +++ b/invokeai/backend/generator/img2img.py @@ -77,10 +77,7 @@ class Img2Img(Generator): callback=step_callback, seed=seed, ) - if ( - pipeline_output.attention_map_saver is not None - and attention_maps_callback is not None - ): + if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None: attention_maps_callback(pipeline_output.attention_map_saver) return pipeline.numpy_to_pil(pipeline_output.images)[0] @@ -91,7 +88,5 @@ class Img2Img(Generator): x = torch.randn_like(like, device=device) if self.perlin > 0.0: shape = like.shape - x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise( - shape[3], shape[2] - ) + x = (1 - self.perlin) * x + self.perlin * self.get_perlin_noise(shape[3], shape[2]) return x diff --git a/invokeai/backend/generator/inpaint.py b/invokeai/backend/generator/inpaint.py index c91fe0c6a7..7aeb3d4809 100644 --- a/invokeai/backend/generator/inpaint.py +++ b/invokeai/backend/generator/inpaint.py @@ -68,15 +68,11 @@ class Inpaint(Img2Img): return im # Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though) - im_patched_np = PatchMatch.inpaint( - im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3 - ) + im_patched_np = PatchMatch.inpaint(im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3) im_patched = Image.fromarray(im_patched_np, mode="RGB") return im_patched - def tile_fill_missing( - self, im: Image.Image, tile_size: int = 16, seed: Optional[int] = None - ) -> Image.Image: + def tile_fill_missing(self, im: Image.Image, tile_size: int = 16, seed: Optional[int] = None) -> Image.Image: # Only fill if there's an alpha layer if im.mode != "RGBA": return im @@ -127,15 +123,11 @@ class Inpaint(Img2Img): return si - def mask_edge( - self, mask: Image.Image, edge_size: int, edge_blur: int - ) -> Image.Image: + def mask_edge(self, mask: Image.Image, edge_size: int, edge_blur: int) -> Image.Image: npimg = np.asarray(mask, dtype=np.uint8) # Detect any partially transparent regions - npgradient = np.uint8( - 255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0)) - ) + npgradient = np.uint8(255 * (1.0 - np.floor(np.abs(0.5 - np.float32(npimg) / 255.0) * 2.0))) # Detect hard edges npedge = cv2.Canny(npimg, threshold1=100, threshold2=200) @@ -144,9 +136,7 @@ class Inpaint(Img2Img): npmask = npgradient + npedge # Expand - npmask = cv2.dilate( - npmask, np.ones((3, 3), np.uint8), iterations=int(edge_size / 2) - ) + npmask = cv2.dilate(npmask, np.ones((3, 3), np.uint8), iterations=int(edge_size / 2)) new_mask = Image.fromarray(npmask) @@ -242,25 +232,19 @@ class Inpaint(Img2Img): if infill_method == "patchmatch" and PatchMatch.patchmatch_available(): init_filled = self.infill_patchmatch(self.pil_image.copy()) elif infill_method == "tile": - init_filled = self.tile_fill_missing( - self.pil_image.copy(), seed=self.seed, tile_size=tile_size - ) + init_filled = self.tile_fill_missing(self.pil_image.copy(), seed=self.seed, tile_size=tile_size) elif infill_method == "solid": solid_bg = Image.new("RGBA", init_image.size, inpaint_fill) init_filled = Image.alpha_composite(solid_bg, init_image) else: - raise ValueError( - f"Non-supported infill type {infill_method}", infill_method - ) + raise ValueError(f"Non-supported infill type {infill_method}", infill_method) init_filled.paste(init_image, (0, 0), init_image.split()[-1]) # Resize if requested for inpainting if inpaint_width and inpaint_height: init_filled = init_filled.resize((inpaint_width, inpaint_height)) - debug_image( - init_filled, "init_filled", debug_status=self.enable_image_debugging - ) + debug_image(init_filled, "init_filled", debug_status=self.enable_image_debugging) # Create init tensor init_image = image_resized_to_grid_as_tensor(init_filled.convert("RGB")) @@ -289,9 +273,7 @@ class Inpaint(Img2Img): "mask_image AFTER multiply with pil_image", debug_status=self.enable_image_debugging, ) - mask: torch.FloatTensor = image_resized_to_grid_as_tensor( - mask_image, normalize=False - ) + mask: torch.FloatTensor = image_resized_to_grid_as_tensor(mask_image, normalize=False) else: mask: torch.FloatTensor = mask_image @@ -302,9 +284,9 @@ class Inpaint(Img2Img): # todo: support cross-attention control uc, c, _ = conditioning - conditioning_data = ConditioningData( - uc, c, cfg_scale - ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta) + conditioning_data = ConditioningData(uc, c, cfg_scale).add_scheduler_args_if_applicable( + pipeline.scheduler, eta=ddim_eta + ) def make_image(x_T: torch.Tensor, seed: int): pipeline_output = pipeline.inpaint_from_embeddings( @@ -318,15 +300,10 @@ class Inpaint(Img2Img): seed=seed, ) - if ( - pipeline_output.attention_map_saver is not None - and attention_maps_callback is not None - ): + if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None: attention_maps_callback(pipeline_output.attention_map_saver) - result = self.postprocess_size_and_mask( - pipeline.numpy_to_pil(pipeline_output.images)[0] - ) + result = self.postprocess_size_and_mask(pipeline.numpy_to_pil(pipeline_output.images)[0]) # Seam paint if this is our first pass (seam_size set to 0 during seam painting) if seam_size > 0: diff --git a/invokeai/backend/image_util/__init__.py b/invokeai/backend/image_util/__init__.py index 410f003f6a..9fa635b389 100644 --- a/invokeai/backend/image_util/__init__.py +++ b/invokeai/backend/image_util/__init__.py @@ -8,9 +8,7 @@ from .txt2mask import Txt2Mask from .util import InitImageResizer, make_grid -def debug_image( - debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False -): +def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False): if not debug_status: return diff --git a/invokeai/backend/image_util/invisible_watermark.py b/invokeai/backend/image_util/invisible_watermark.py index 0fb72b095d..4605daea43 100644 --- a/invokeai/backend/image_util/invisible_watermark.py +++ b/invokeai/backend/image_util/invisible_watermark.py @@ -9,26 +9,26 @@ from PIL import Image from imwatermark import WatermarkEncoder from invokeai.app.services.config import InvokeAIAppConfig import invokeai.backend.util.logging as logger + config = InvokeAIAppConfig.get_config() + class InvisibleWatermark: """ Wrapper around InvisibleWatermark module. """ - + @classmethod def invisible_watermark_available(self) -> bool: return config.invisible_watermark @classmethod - def add_watermark(self, image: Image, watermark_text:str) -> Image: + def add_watermark(self, image: Image, watermark_text: str) -> Image: if not self.invisible_watermark_available(): return image logger.debug(f'Applying invisible watermark "{watermark_text}"') bgr = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR) encoder = WatermarkEncoder() - encoder.set_watermark('bytes', watermark_text.encode('utf-8')) - bgr_encoded = encoder.encode(bgr, 'dwtDct') - return Image.fromarray( - cv2.cvtColor(bgr_encoded, cv2.COLOR_BGR2RGB) - ).convert("RGBA") + encoder.set_watermark("bytes", watermark_text.encode("utf-8")) + bgr_encoded = encoder.encode(bgr, "dwtDct") + return Image.fromarray(cv2.cvtColor(bgr_encoded, cv2.COLOR_BGR2RGB)).convert("RGBA") diff --git a/invokeai/backend/image_util/patchmatch.py b/invokeai/backend/image_util/patchmatch.py index 2e65f08d9f..98055f60c8 100644 --- a/invokeai/backend/image_util/patchmatch.py +++ b/invokeai/backend/image_util/patchmatch.py @@ -7,8 +7,10 @@ be suppressed or deferred import numpy as np import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig + config = InvokeAIAppConfig.get_config() + class PatchMatch: """ Thin class wrapper around the patchmatch function. diff --git a/invokeai/backend/image_util/pngwriter.py b/invokeai/backend/image_util/pngwriter.py index 452bbfc783..d3f53d7e72 100644 --- a/invokeai/backend/image_util/pngwriter.py +++ b/invokeai/backend/image_util/pngwriter.py @@ -34,9 +34,7 @@ class PngWriter: # saves image named _image_ to outdir/name, writing metadata from prompt # returns full path of output - def save_image_and_prompt_to_png( - self, image, dream_prompt, name, metadata=None, compress_level=6 - ): + def save_image_and_prompt_to_png(self, image, dream_prompt, name, metadata=None, compress_level=6): path = os.path.join(self.outdir, name) info = PngImagePlugin.PngInfo() info.add_text("Dream", dream_prompt) @@ -114,8 +112,6 @@ class PromptFormatter: if opt.variation_amount > 0: switches.append(f"-v{opt.variation_amount}") if opt.with_variations: - formatted_variations = ",".join( - f"{seed}:{weight}" for seed, weight in opt.with_variations - ) + formatted_variations = ",".join(f"{seed}:{weight}" for seed, weight in opt.with_variations) switches.append(f"-V{formatted_variations}") return " ".join(switches) diff --git a/invokeai/backend/image_util/safety_checker.py b/invokeai/backend/image_util/safety_checker.py index 7a31c9422d..483e563b82 100644 --- a/invokeai/backend/image_util/safety_checker.py +++ b/invokeai/backend/image_util/safety_checker.py @@ -9,14 +9,17 @@ from invokeai.backend import SilenceWarnings from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.util.devices import choose_torch_device import invokeai.backend.util.logging as logger + config = InvokeAIAppConfig.get_config() -CHECKER_PATH = 'core/convert/stable-diffusion-safety-checker' +CHECKER_PATH = "core/convert/stable-diffusion-safety-checker" + class SafetyChecker: """ Wrapper around SafetyChecker model. """ + safety_checker = None feature_extractor = None tried_load: bool = False @@ -25,21 +28,19 @@ class SafetyChecker: def _load_safety_checker(self): if self.tried_load: return - + if config.nsfw_checker: try: from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from transformers import AutoFeatureExtractor - self.safety_checker = StableDiffusionSafetyChecker.from_pretrained( - config.models_path / CHECKER_PATH - ) - self.feature_extractor = AutoFeatureExtractor.from_pretrained( - config.models_path / CHECKER_PATH) - logger.info('NSFW checker initialized') + + self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / CHECKER_PATH) + self.feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / CHECKER_PATH) + logger.info("NSFW checker initialized") except Exception as e: - logger.warning(f'Could not load NSFW checker: {str(e)}') + logger.warning(f"Could not load NSFW checker: {str(e)}") else: - logger.info('NSFW checker loading disabled') + logger.info("NSFW checker loading disabled") self.tried_load = True @classmethod @@ -51,7 +52,7 @@ class SafetyChecker: def has_nsfw_concept(self, image: Image) -> bool: if not self.safety_checker_available(): return False - + device = choose_torch_device() features = self.feature_extractor([image], return_tensors="pt") features.to(device) diff --git a/invokeai/backend/image_util/seamless.py b/invokeai/backend/image_util/seamless.py index 4fbc0cd78e..6fb2617901 100644 --- a/invokeai/backend/image_util/seamless.py +++ b/invokeai/backend/image_util/seamless.py @@ -5,12 +5,8 @@ def _conv_forward_asymmetric(self, input, weight, bias): """ Patch for Conv2d._conv_forward that supports asymmetric padding """ - working = nn.functional.pad( - input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"] - ) - working = nn.functional.pad( - working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"] - ) + working = nn.functional.pad(input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"]) + working = nn.functional.pad(working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"]) return nn.functional.conv2d( working, weight, @@ -32,18 +28,14 @@ def configure_model_padding(model, seamless, seamless_axes): if seamless: m.asymmetric_padding_mode = {} m.asymmetric_padding = {} - m.asymmetric_padding_mode["x"] = ( - "circular" if ("x" in seamless_axes) else "constant" - ) + m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant" m.asymmetric_padding["x"] = ( m._reversed_padding_repeated_twice[0], m._reversed_padding_repeated_twice[1], 0, 0, ) - m.asymmetric_padding_mode["y"] = ( - "circular" if ("y" in seamless_axes) else "constant" - ) + m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant" m.asymmetric_padding["y"] = ( 0, 0, diff --git a/invokeai/backend/image_util/txt2mask.py b/invokeai/backend/image_util/txt2mask.py index 429c9b63fb..12db54b0db 100644 --- a/invokeai/backend/image_util/txt2mask.py +++ b/invokeai/backend/image_util/txt2mask.py @@ -39,23 +39,18 @@ CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined" CLIPSEG_SIZE = 352 config = InvokeAIAppConfig.get_config() + class SegmentedGrayscale(object): def __init__(self, image: Image, heatmap: torch.Tensor): self.heatmap = heatmap self.image = image def to_grayscale(self, invert: bool = False) -> Image: - return self._rescale( - Image.fromarray( - np.uint8(255 - self.heatmap * 255 if invert else self.heatmap * 255) - ) - ) + return self._rescale(Image.fromarray(np.uint8(255 - self.heatmap * 255 if invert else self.heatmap * 255))) def to_mask(self, threshold: float = 0.5) -> Image: discrete_heatmap = self.heatmap.lt(threshold).int() - return self._rescale( - Image.fromarray(np.uint8(discrete_heatmap * 255), mode="L") - ) + return self._rescale(Image.fromarray(np.uint8(discrete_heatmap * 255), mode="L")) def to_transparent(self, invert: bool = False) -> Image: transparent_image = self.image.copy() @@ -67,11 +62,7 @@ class SegmentedGrayscale(object): # unscales and uncrops the 352x352 heatmap so that it matches the image again def _rescale(self, heatmap: Image) -> Image: - size = ( - self.image.width - if (self.image.width > self.image.height) - else self.image.height - ) + size = self.image.width if (self.image.width > self.image.height) else self.image.height resized_image = heatmap.resize((size, size), resample=Image.Resampling.LANCZOS) return resized_image.crop((0, 0, self.image.width, self.image.height)) @@ -87,12 +78,8 @@ class Txt2Mask(object): # BUG: we are not doing anything with the device option at this time self.device = device - self.processor = AutoProcessor.from_pretrained( - CLIPSEG_MODEL, cache_dir=config.cache_dir - ) - self.model = CLIPSegForImageSegmentation.from_pretrained( - CLIPSEG_MODEL, cache_dir=config.cache_dir - ) + self.processor = AutoProcessor.from_pretrained(CLIPSEG_MODEL, cache_dir=config.cache_dir) + self.model = CLIPSegForImageSegmentation.from_pretrained(CLIPSEG_MODEL, cache_dir=config.cache_dir) @torch.no_grad() def segment(self, image, prompt: str) -> SegmentedGrayscale: @@ -107,9 +94,7 @@ class Txt2Mask(object): image = ImageOps.exif_transpose(image) img = self._scale_and_crop(image) - inputs = self.processor( - text=[prompt], images=[img], padding=True, return_tensors="pt" - ) + inputs = self.processor(text=[prompt], images=[img], padding=True, return_tensors="pt") outputs = self.model(**inputs) heatmap = torch.sigmoid(outputs.logits) return SegmentedGrayscale(image, heatmap) diff --git a/invokeai/backend/install/check_root.py b/invokeai/backend/install/check_root.py index e8dde681bf..ded9e66635 100644 --- a/invokeai/backend/install/check_root.py +++ b/invokeai/backend/install/check_root.py @@ -6,28 +6,31 @@ from invokeai.app.services.config import ( InvokeAIAppConfig, ) + def check_invokeai_root(config: InvokeAIAppConfig): try: - assert config.model_conf_path.exists(), f'{config.model_conf_path} not found' - assert config.db_path.parent.exists(), f'{config.db_path.parent} not found' - assert config.models_path.exists(), f'{config.models_path} not found' + assert config.model_conf_path.exists(), f"{config.model_conf_path} not found" + assert config.db_path.parent.exists(), f"{config.db_path.parent} not found" + assert config.models_path.exists(), f"{config.models_path} not found" for model in [ - 'CLIP-ViT-bigG-14-laion2B-39B-b160k', - 'bert-base-uncased', - 'clip-vit-large-patch14', - 'sd-vae-ft-mse', - 'stable-diffusion-2-clip', - 'stable-diffusion-safety-checker']: - path = config.models_path / f'core/convert/{model}' - assert path.exists(), f'{path} is missing' + "CLIP-ViT-bigG-14-laion2B-39B-b160k", + "bert-base-uncased", + "clip-vit-large-patch14", + "sd-vae-ft-mse", + "stable-diffusion-2-clip", + "stable-diffusion-safety-checker", + ]: + path = config.models_path / f"core/convert/{model}" + assert path.exists(), f"{path} is missing" except Exception as e: print() - print(f'An exception has occurred: {str(e)}') - print('== STARTUP ABORTED ==') - print('** One or more necessary files is missing from your InvokeAI root directory **') - print('** Please rerun the configuration script to fix this problem. **') - print('** From the launcher, selection option [7]. **') - print('** From the command line, activate the virtual environment and run "invokeai-configure --yes --skip-sd-weights" **') - input('Press any key to continue...') + print(f"An exception has occurred: {str(e)}") + print("== STARTUP ABORTED ==") + print("** One or more necessary files is missing from your InvokeAI root directory **") + print("** Please rerun the configuration script to fix this problem. **") + print("** From the launcher, selection option [7]. **") + print( + '** From the command line, activate the virtual environment and run "invokeai-configure --yes --skip-sd-weights" **' + ) + input("Press any key to continue...") sys.exit(0) - diff --git a/invokeai/backend/install/invokeai_configure.py b/invokeai/backend/install/invokeai_configure.py index d3ba4f4ae2..972e6668c4 100755 --- a/invokeai/backend/install/invokeai_configure.py +++ b/invokeai/backend/install/invokeai_configure.py @@ -60,9 +60,7 @@ from invokeai.backend.install.model_install_backend import ( InstallSelections, ModelInstall, ) -from invokeai.backend.model_management.model_probe import ( - ModelType, BaseModelType - ) +from invokeai.backend.model_management.model_probe import ModelType, BaseModelType warnings.filterwarnings("ignore") transformers.logging.set_verbosity_error() @@ -77,7 +75,7 @@ Model_dir = "models" Default_config_file = config.model_conf_path SD_Configs = config.legacy_conf_path -PRECISION_CHOICES = ['auto','float16','float32'] +PRECISION_CHOICES = ["auto", "float16", "float32"] INIT_FILE_PREAMBLE = """# InvokeAI initialization file # This is the InvokeAI initialization file, which contains command-line default values. @@ -85,7 +83,8 @@ INIT_FILE_PREAMBLE = """# InvokeAI initialization file # or renaming it and then running invokeai-configure again. """ -logger=InvokeAILogger.getLogger() +logger = InvokeAILogger.getLogger() + # -------------------------------------------- def postscript(errors: None): @@ -108,7 +107,9 @@ Add the '--help' argument to see all of the command-line switches available for """ else: - message = "\n** There were errors during installation. It is possible some of the models were not fully downloaded.\n" + message = ( + "\n** There were errors during installation. It is possible some of the models were not fully downloaded.\n" + ) for err in errors: message += f"\t - {err}\n" message += "Please check the logs above and correct any issues." @@ -169,9 +170,7 @@ def download_with_progress_bar(model_url: str, model_dest: str, label: str = "th logger.info(f"Installing {label} model file {model_url}...") if not os.path.exists(model_dest): os.makedirs(os.path.dirname(model_dest), exist_ok=True) - request.urlretrieve( - model_url, model_dest, ProgressBar(os.path.basename(model_dest)) - ) + request.urlretrieve(model_url, model_dest, ProgressBar(os.path.basename(model_dest))) logger.info("...downloaded successfully") else: logger.info("...exists") @@ -182,90 +181,93 @@ def download_with_progress_bar(model_url: str, model_dest: str, label: str = "th def download_conversion_models(): - target_dir = config.root_path / 'models/core/convert' + target_dir = config.root_path / "models/core/convert" kwargs = dict() # for future use try: - logger.info('Downloading core tokenizers and text encoders') + logger.info("Downloading core tokenizers and text encoders") # bert with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=DeprecationWarning) bert = BertTokenizerFast.from_pretrained("bert-base-uncased", **kwargs) - bert.save_pretrained(target_dir / 'bert-base-uncased', safe_serialization=True) - + bert.save_pretrained(target_dir / "bert-base-uncased", safe_serialization=True) + # sd-1 - repo_id = 'openai/clip-vit-large-patch14' - hf_download_from_pretrained(CLIPTokenizer, repo_id, target_dir / 'clip-vit-large-patch14') - hf_download_from_pretrained(CLIPTextModel, repo_id, target_dir / 'clip-vit-large-patch14') + repo_id = "openai/clip-vit-large-patch14" + hf_download_from_pretrained(CLIPTokenizer, repo_id, target_dir / "clip-vit-large-patch14") + hf_download_from_pretrained(CLIPTextModel, repo_id, target_dir / "clip-vit-large-patch14") # sd-2 repo_id = "stabilityai/stable-diffusion-2" pipeline = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer", **kwargs) - pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'tokenizer', safe_serialization=True) + pipeline.save_pretrained(target_dir / "stable-diffusion-2-clip" / "tokenizer", safe_serialization=True) pipeline = CLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder", **kwargs) - pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'text_encoder', safe_serialization=True) + pipeline.save_pretrained(target_dir / "stable-diffusion-2-clip" / "text_encoder", safe_serialization=True) # sd-xl - tokenizer_2 repo_id = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" - _, model_name = repo_id.split('/') + _, model_name = repo_id.split("/") pipeline = CLIPTokenizer.from_pretrained(repo_id, **kwargs) pipeline.save_pretrained(target_dir / model_name, safe_serialization=True) - + pipeline = CLIPTextConfig.from_pretrained(repo_id, **kwargs) pipeline.save_pretrained(target_dir / model_name, safe_serialization=True) - + # VAE - logger.info('Downloading stable diffusion VAE') - vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse', **kwargs) - vae.save_pretrained(target_dir / 'sd-vae-ft-mse', safe_serialization=True) + logger.info("Downloading stable diffusion VAE") + vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", **kwargs) + vae.save_pretrained(target_dir / "sd-vae-ft-mse", safe_serialization=True) # safety checking - logger.info('Downloading safety checker') + logger.info("Downloading safety checker") repo_id = "CompVis/stable-diffusion-safety-checker" - pipeline = AutoFeatureExtractor.from_pretrained(repo_id,**kwargs) - pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True) + pipeline = AutoFeatureExtractor.from_pretrained(repo_id, **kwargs) + pipeline.save_pretrained(target_dir / "stable-diffusion-safety-checker", safe_serialization=True) - pipeline = StableDiffusionSafetyChecker.from_pretrained(repo_id,**kwargs) - pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True) + pipeline = StableDiffusionSafetyChecker.from_pretrained(repo_id, **kwargs) + pipeline.save_pretrained(target_dir / "stable-diffusion-safety-checker", safe_serialization=True) except KeyboardInterrupt: raise except Exception as e: logger.error(str(e)) + # --------------------------------------------- def download_realesrgan(): logger.info("Installing ESRGAN Upscaling models...") URLs = [ dict( - url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", - dest = "core/upscaling/realesrgan/RealESRGAN_x4plus.pth", - description = "RealESRGAN_x4plus.pth", + url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", + dest="core/upscaling/realesrgan/RealESRGAN_x4plus.pth", + description="RealESRGAN_x4plus.pth", ), dict( - url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth", - dest = "core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth", - description = "RealESRGAN_x4plus_anime_6B.pth", + url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth", + dest="core/upscaling/realesrgan/RealESRGAN_x4plus_anime_6B.pth", + description="RealESRGAN_x4plus_anime_6B.pth", ), dict( - url= "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth", - dest= "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth", - description = "ESRGAN_SRx4_DF2KOST_official.pth", + url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth", + dest="core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth", + description="ESRGAN_SRx4_DF2KOST_official.pth", ), dict( - url= "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth", - dest= "core/upscaling/realesrgan/RealESRGAN_x2plus.pth", - description = "RealESRGAN_x2plus.pth", + url="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth", + dest="core/upscaling/realesrgan/RealESRGAN_x2plus.pth", + description="RealESRGAN_x2plus.pth", ), ] for model in URLs: - download_with_progress_bar(model['url'], config.models_path / model['dest'], model['description']) + download_with_progress_bar(model["url"], config.models_path / model["dest"], model["description"]) + # --------------------------------------------- def download_support_models(): download_realesrgan() download_conversion_models() + # ------------------------------------- def get_root(root: str = None) -> str: if root: @@ -275,6 +277,7 @@ def get_root(root: str = None) -> str: else: return str(config.root_path) + # ------------------------------------- class editOptsForm(CyclingForm, npyscreen.FormMultiPage): # for responsive resizing - disabled @@ -283,14 +286,14 @@ class editOptsForm(CyclingForm, npyscreen.FormMultiPage): def create(self): program_opts = self.parentApp.program_opts old_opts = self.parentApp.invokeai_opts - first_time = not (config.root_path / 'invokeai.yaml').exists() + first_time = not (config.root_path / "invokeai.yaml").exists() access_token = HfFolder.get_token() window_width, window_height = get_terminal_size() label = """Configure startup settings. You can come back and change these later. Use ctrl-N and ctrl-P to move to the ext and

revious fields. Use cursor arrows to make a checkbox selection, and space to toggle. """ - for i in textwrap.wrap(label,width=window_width-6): + for i in textwrap.wrap(label, width=window_width - 6): self.add_widget_intelligent( npyscreen.FixedText, value=i, @@ -300,7 +303,7 @@ Use cursor arrows to make a checkbox selection, and space to toggle. self.nextrely += 1 label = """HuggingFace access token (OPTIONAL) for automatic model downloads. See https://huggingface.co/settings/tokens.""" - for line in textwrap.wrap(label,width=window_width-6): + for line in textwrap.wrap(label, width=window_width - 6): self.add_widget_intelligent( npyscreen.FixedText, value=line, @@ -343,7 +346,7 @@ Use cursor arrows to make a checkbox selection, and space to toggle. relx=50, scroll_exit=True, ) - self.nextrely -=1 + self.nextrely -= 1 self.always_use_cpu = self.add_widget_intelligent( npyscreen.Checkbox, name="Force CPU to be used on GPU systems", @@ -351,10 +354,8 @@ Use cursor arrows to make a checkbox selection, and space to toggle. relx=80, scroll_exit=True, ) - precision = old_opts.precision or ( - "float32" if program_opts.full_precision else "auto" - ) - self.nextrely +=1 + precision = old_opts.precision or ("float32" if program_opts.full_precision else "auto") + self.nextrely += 1 self.add_widget_intelligent( npyscreen.TitleFixedText, name="Floating Point Precision", @@ -363,10 +364,10 @@ Use cursor arrows to make a checkbox selection, and space to toggle. color="CONTROL", scroll_exit=True, ) - self.nextrely -=1 + self.nextrely -= 1 self.precision = self.add_widget_intelligent( SingleSelectColumns, - columns = 3, + columns=3, name="Precision", values=PRECISION_CHOICES, value=PRECISION_CHOICES.index(precision), @@ -398,25 +399,25 @@ Use cursor arrows to make a checkbox selection, and space to toggle. scroll_exit=True, ) self.autoimport_dirs = {} - self.autoimport_dirs['autoimport_dir'] = self.add_widget_intelligent( - FileBox, - name=f'Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models', - value=str(config.root_path / config.autoimport_dir), - select_dir=True, - must_exist=False, - use_two_lines=False, - labelColor="GOOD", - begin_entry_at=32, - max_height = 3, - scroll_exit=True - ) + self.autoimport_dirs["autoimport_dir"] = self.add_widget_intelligent( + FileBox, + name=f"Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models", + value=str(config.root_path / config.autoimport_dir), + select_dir=True, + must_exist=False, + use_two_lines=False, + labelColor="GOOD", + begin_entry_at=32, + max_height=3, + scroll_exit=True, + ) self.nextrely += 1 label = """BY DOWNLOADING THE STABLE DIFFUSION WEIGHT FILES, YOU AGREE TO HAVE READ AND ACCEPTED THE CREATIVEML RESPONSIBLE AI LICENSES LOCATED AT https://huggingface.co/spaces/CompVis/stable-diffusion-license and https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENSE.md """ - for i in textwrap.wrap(label,width=window_width-6): + for i in textwrap.wrap(label, width=window_width - 6): self.add_widget_intelligent( npyscreen.FixedText, value=i, @@ -431,11 +432,7 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS scroll_exit=True, ) self.nextrely += 1 - label = ( - "DONE" - if program_opts.skip_sd_weights or program_opts.default_only - else "NEXT" - ) + label = "DONE" if program_opts.skip_sd_weights or program_opts.default_only else "NEXT" self.ok_button = self.add_widget_intelligent( CenteredButtonPress, name=label, @@ -454,13 +451,11 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS self.editing = False else: self.editing = True - + def validate_field_values(self, opt: Namespace) -> bool: bad_fields = [] if not opt.license_acceptance: - bad_fields.append( - "Please accept the license terms before proceeding to model downloads" - ) + bad_fields.append("Please accept the license terms before proceeding to model downloads") if not Path(opt.outdir).parent.exists(): bad_fields.append( f"The output directory does not seem to be valid. Please check that {str(Path(opt.outdir).parent)} is an existing directory." @@ -478,11 +473,11 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS new_opts = Namespace() for attr in [ - "outdir", - "free_gpu_mem", - "max_cache_size", - "xformers_enabled", - "always_use_cpu", + "outdir", + "free_gpu_mem", + "max_cache_size", + "xformers_enabled", + "always_use_cpu", ]: setattr(new_opts, attr, getattr(self, attr).value) @@ -495,7 +490,7 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS new_opts.hf_token = self.hf_token.value new_opts.license_acceptance = self.license_acceptance.value new_opts.precision = PRECISION_CHOICES[self.precision.value[0]] - + return new_opts @@ -534,19 +529,20 @@ def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Nam editApp.run() return editApp.new_opts() + def default_startup_options(init_file: Path) -> Namespace: opts = InvokeAIAppConfig.get_config() return opts + def default_user_selections(program_opts: Namespace) -> InstallSelections: - try: installer = ModelInstall(config) except omegaconf.errors.ConfigKeyError: - logger.warning('Your models.yaml file is corrupt or out of date. Reinitializing') + logger.warning("Your models.yaml file is corrupt or out of date. Reinitializing") initialize_rootdir(config.root_path, True) installer = ModelInstall(config) - + models = installer.all_models() return InstallSelections( install_models=[models[installer.default_model()].path or models[installer.default_model()].repo_id] @@ -556,55 +552,46 @@ def default_user_selections(program_opts: Namespace) -> InstallSelections: else list(), ) + # ------------------------------------- def initialize_rootdir(root: Path, yes_to_all: bool = False): logger.info("Initializing InvokeAI runtime directory") - for name in ( - "models", - "databases", - "text-inversion-output", - "text-inversion-training-data", - "configs" - ): + for name in ("models", "databases", "text-inversion-output", "text-inversion-training-data", "configs"): os.makedirs(os.path.join(root, name), exist_ok=True) for model_type in ModelType: - Path(root, 'autoimport', model_type.value).mkdir(parents=True, exist_ok=True) + Path(root, "autoimport", model_type.value).mkdir(parents=True, exist_ok=True) configs_src = Path(configs.__path__[0]) configs_dest = root / "configs" if not os.path.samefile(configs_src, configs_dest): shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True) - dest = root / 'models' + dest = root / "models" for model_base in BaseModelType: for model_type in ModelType: path = dest / model_base.value / model_type.value path.mkdir(parents=True, exist_ok=True) - path = dest / 'core' + path = dest / "core" path.mkdir(parents=True, exist_ok=True) maybe_create_models_yaml(root) + def maybe_create_models_yaml(root: Path): - models_yaml = root / 'configs' / 'models.yaml' + models_yaml = root / "configs" / "models.yaml" if models_yaml.exists(): - if OmegaConf.load(models_yaml).get('__metadata__'): # up to date + if OmegaConf.load(models_yaml).get("__metadata__"): # up to date return else: - logger.info('Creating new models.yaml, original saved as models.yaml.orig') - models_yaml.rename(models_yaml.parent / 'models.yaml.orig') - - with open(models_yaml,'w') as yaml_file: - yaml_file.write(yaml.dump({'__metadata__': - {'version':'3.0.0'} - } - ) - ) - + logger.info("Creating new models.yaml, original saved as models.yaml.orig") + models_yaml.rename(models_yaml.parent / "models.yaml.orig") + + with open(models_yaml, "w") as yaml_file: + yaml_file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}})) + + # ------------------------------------- -def run_console_ui( - program_opts: Namespace, initfile: Path = None -) -> (Namespace, Namespace): +def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace, Namespace): # parse_args() will read from init file if present invokeai_opts = default_startup_options(initfile) invokeai_opts.root = program_opts.root @@ -616,8 +603,9 @@ def run_console_ui( # the install-models application spawns a subprocess to install # models, and will crash unless this is set before running. import torch + torch.multiprocessing.set_start_method("spawn") - + editApp = EditOptApplication(program_opts, invokeai_opts) editApp.run() if editApp.user_cancelled: @@ -634,39 +622,42 @@ def write_opts(opts: Namespace, init_file: Path): # this will load current settings new_config = InvokeAIAppConfig.get_config() new_config.root = config.root - - for key,value in opts.__dict__.items(): - if hasattr(new_config,key): - setattr(new_config,key,value) - with open(init_file,'w', encoding='utf-8') as file: + for key, value in opts.__dict__.items(): + if hasattr(new_config, key): + setattr(new_config, key, value) + + with open(init_file, "w", encoding="utf-8") as file: file.write(new_config.to_yaml()) - if hasattr(opts,'hf_token') and opts.hf_token: + if hasattr(opts, "hf_token") and opts.hf_token: HfLogin(opts.hf_token) + # ------------------------------------- def default_output_dir() -> Path: return config.root_path / "outputs" + # ------------------------------------- def write_default_options(program_opts: Namespace, initfile: Path): opt = default_startup_options(initfile) write_opts(opt, initfile) + # ------------------------------------- # Here we bring in # the legacy Args object in order to parse # the old init file and write out the new # yaml format. -def migrate_init_file(legacy_format:Path): - old = legacy_parser.parse_args([f'@{str(legacy_format)}']) +def migrate_init_file(legacy_format: Path): + old = legacy_parser.parse_args([f"@{str(legacy_format)}"]) new = InvokeAIAppConfig.get_config() fields = list(get_type_hints(InvokeAIAppConfig).keys()) for attr in fields: - if hasattr(old,attr): - setattr(new,attr,getattr(old,attr)) + if hasattr(old, attr): + setattr(new, attr, getattr(old, attr)) # a few places where the field names have changed and we have to # manually add in the new names/values @@ -674,40 +665,43 @@ def migrate_init_file(legacy_format:Path): new.conf_path = old.conf new.root = legacy_format.parent.resolve() - invokeai_yaml = legacy_format.parent / 'invokeai.yaml' - with open(invokeai_yaml,"w", encoding="utf-8") as outfile: + invokeai_yaml = legacy_format.parent / "invokeai.yaml" + with open(invokeai_yaml, "w", encoding="utf-8") as outfile: outfile.write(new.to_yaml()) - legacy_format.replace(legacy_format.parent / 'invokeai.init.orig') + legacy_format.replace(legacy_format.parent / "invokeai.init.orig") + # ------------------------------------- def migrate_models(root: Path): from invokeai.backend.install.migrate_to_3 import do_migrate + do_migrate(root, root) -def migrate_if_needed(opt: Namespace, root: Path)->bool: - # We check for to see if the runtime directory is correctly initialized. - old_init_file = root / 'invokeai.init' - new_init_file = root / 'invokeai.yaml' - old_hub = root / 'models/hub' - migration_needed = (old_init_file.exists() and not new_init_file.exists()) and old_hub.exists() - - if migration_needed: - if opt.yes_to_all or \ - yes_or_no(f'{str(config.root_path)} appears to be a 2.3 format root directory. Convert to version 3.0?'): - logger.info('** Migrating invokeai.init to invokeai.yaml') +def migrate_if_needed(opt: Namespace, root: Path) -> bool: + # We check for to see if the runtime directory is correctly initialized. + old_init_file = root / "invokeai.init" + new_init_file = root / "invokeai.yaml" + old_hub = root / "models/hub" + migration_needed = (old_init_file.exists() and not new_init_file.exists()) and old_hub.exists() + + if migration_needed: + if opt.yes_to_all or yes_or_no( + f"{str(config.root_path)} appears to be a 2.3 format root directory. Convert to version 3.0?" + ): + logger.info("** Migrating invokeai.init to invokeai.yaml") migrate_init_file(old_init_file) - config.parse_args(argv=[],conf=OmegaConf.load(new_init_file)) + config.parse_args(argv=[], conf=OmegaConf.load(new_init_file)) if old_hub.exists(): migrate_models(config.root_path) else: - print('Cannot continue without conversion. Aborting.') - + print("Cannot continue without conversion. Aborting.") + return migration_needed - + # ------------------------------------- def main(): parser = argparse.ArgumentParser(description="InvokeAI model downloader") @@ -764,9 +758,9 @@ def main(): invoke_args = [] if opt.root: - invoke_args.extend(['--root',opt.root]) + invoke_args.extend(["--root", opt.root]) if opt.full_precision: - invoke_args.extend(['--precision','float32']) + invoke_args.extend(["--precision", "float32"]) config.parse_args(invoke_args) logger = InvokeAILogger().getLogger(config=config) @@ -782,22 +776,18 @@ def main(): initialize_rootdir(config.root_path, opt.yes_to_all) models_to_download = default_user_selections(opt) - new_init_file = config.root_path / 'invokeai.yaml' + new_init_file = config.root_path / "invokeai.yaml" if opt.yes_to_all: write_default_options(opt, new_init_file) - init_options = Namespace( - precision="float32" if opt.full_precision else "float16" - ) + init_options = Namespace(precision="float32" if opt.full_precision else "float16") else: init_options, models_to_download = run_console_ui(opt, new_init_file) if init_options: write_opts(init_options, new_init_file) else: - logger.info( - '\n** CANCELLED AT USER\'S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n' - ) + logger.info('\n** CANCELLED AT USER\'S REQUEST. USE THE "invoke.sh" LAUNCHER TO RUN LATER **\n') sys.exit(0) - + if opt.skip_support_models: logger.info("Skipping support models at user's request") else: @@ -811,7 +801,7 @@ def main(): postscript(errors=errors) if not opt.yes_to_all: - input('Press any key to continue...') + input("Press any key to continue...") except KeyboardInterrupt: print("\nGoodbye! Come back soon.") diff --git a/invokeai/backend/install/legacy_arg_parsing.py b/invokeai/backend/install/legacy_arg_parsing.py index 684c50c77d..f6a0682f2a 100644 --- a/invokeai/backend/install/legacy_arg_parsing.py +++ b/invokeai/backend/install/legacy_arg_parsing.py @@ -47,17 +47,18 @@ PRECISION_CHOICES = [ "float16", ] + class FileArgumentParser(ArgumentParser): """ Supports reading defaults from an init file. """ + def convert_arg_line_to_args(self, arg_line): return shlex.split(arg_line, comments=True) legacy_parser = FileArgumentParser( - description= - """ + description=""" Generate images using Stable Diffusion. Use --web to launch the web interface. Use --from_file to load prompts from a file path or standard input ("-"). @@ -65,304 +66,279 @@ Generate images using Stable Diffusion. Other command-line arguments are defaults that can usually be overridden prompt the command prompt. """, - fromfile_prefix_chars='@', + fromfile_prefix_chars="@", ) -general_group = legacy_parser.add_argument_group('General') -model_group = legacy_parser.add_argument_group('Model selection') -file_group = legacy_parser.add_argument_group('Input/output') -web_server_group = legacy_parser.add_argument_group('Web server') -render_group = legacy_parser.add_argument_group('Rendering') -postprocessing_group = legacy_parser.add_argument_group('Postprocessing') -deprecated_group = legacy_parser.add_argument_group('Deprecated options') +general_group = legacy_parser.add_argument_group("General") +model_group = legacy_parser.add_argument_group("Model selection") +file_group = legacy_parser.add_argument_group("Input/output") +web_server_group = legacy_parser.add_argument_group("Web server") +render_group = legacy_parser.add_argument_group("Rendering") +postprocessing_group = legacy_parser.add_argument_group("Postprocessing") +deprecated_group = legacy_parser.add_argument_group("Deprecated options") -deprecated_group.add_argument('--laion400m') -deprecated_group.add_argument('--weights') # deprecated -general_group.add_argument( - '--version','-V', - action='store_true', - help='Print InvokeAI version number' -) +deprecated_group.add_argument("--laion400m") +deprecated_group.add_argument("--weights") # deprecated +general_group.add_argument("--version", "-V", action="store_true", help="Print InvokeAI version number") model_group.add_argument( - '--root_dir', + "--root_dir", default=None, help='Path to directory containing "models", "outputs" and "configs". If not present will read from environment variable INVOKEAI_ROOT. Defaults to ~/invokeai.', ) model_group.add_argument( - '--config', - '-c', - '-config', - dest='conf', - default='./configs/models.yaml', - help='Path to configuration file for alternate models.', + "--config", + "-c", + "-config", + dest="conf", + default="./configs/models.yaml", + help="Path to configuration file for alternate models.", ) model_group.add_argument( - '--model', + "--model", help='Indicates which diffusion model to load (defaults to "default" stanza in configs/models.yaml)', ) model_group.add_argument( - '--weight_dirs', - nargs='+', + "--weight_dirs", + nargs="+", type=str, - help='List of one or more directories that will be auto-scanned for new model weights to import', + help="List of one or more directories that will be auto-scanned for new model weights to import", ) model_group.add_argument( - '--png_compression','-z', + "--png_compression", + "-z", type=int, default=6, - choices=range(0,9), - dest='png_compression', - help='level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.' + choices=range(0, 9), + dest="png_compression", + help="level of PNG compression, from 0 (none) to 9 (maximum). Default is 6.", ) model_group.add_argument( - '-F', - '--full_precision', - dest='full_precision', - action='store_true', - help='Deprecated way to set --precision=float32', + "-F", + "--full_precision", + dest="full_precision", + action="store_true", + help="Deprecated way to set --precision=float32", ) model_group.add_argument( - '--max_loaded_models', - dest='max_loaded_models', + "--max_loaded_models", + dest="max_loaded_models", type=int, default=2, - help='Maximum number of models to keep in memory for fast switching, including the one in GPU', + help="Maximum number of models to keep in memory for fast switching, including the one in GPU", ) model_group.add_argument( - '--free_gpu_mem', - dest='free_gpu_mem', - action='store_true', - help='Force free gpu memory before final decoding', + "--free_gpu_mem", + dest="free_gpu_mem", + action="store_true", + help="Force free gpu memory before final decoding", ) model_group.add_argument( - '--sequential_guidance', - dest='sequential_guidance', - action='store_true', - help="Calculate guidance in serial instead of in parallel, lowering memory requirement " - "at the expense of speed", + "--sequential_guidance", + dest="sequential_guidance", + action="store_true", + help="Calculate guidance in serial instead of in parallel, lowering memory requirement " "at the expense of speed", ) model_group.add_argument( - '--xformers', + "--xformers", action=argparse.BooleanOptionalAction, default=True, - help='Enable/disable xformers support (default enabled if installed)', + help="Enable/disable xformers support (default enabled if installed)", ) model_group.add_argument( - "--always_use_cpu", - dest="always_use_cpu", - action="store_true", - help="Force use of CPU even if GPU is available" + "--always_use_cpu", dest="always_use_cpu", action="store_true", help="Force use of CPU even if GPU is available" ) model_group.add_argument( - '--precision', - dest='precision', + "--precision", + dest="precision", type=str, choices=PRECISION_CHOICES, - metavar='PRECISION', + metavar="PRECISION", help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}', - default='auto', + default="auto", ) model_group.add_argument( - '--ckpt_convert', + "--ckpt_convert", action=argparse.BooleanOptionalAction, - dest='ckpt_convert', + dest="ckpt_convert", default=True, - help='Deprecated option. Legacy ckpt files are now always converted to diffusers when loaded.' + help="Deprecated option. Legacy ckpt files are now always converted to diffusers when loaded.", ) model_group.add_argument( - '--internet', + "--internet", action=argparse.BooleanOptionalAction, - dest='internet_available', + dest="internet_available", default=True, - help='Indicate whether internet is available for just-in-time model downloading (default: probe automatically).', + help="Indicate whether internet is available for just-in-time model downloading (default: probe automatically).", ) model_group.add_argument( - '--nsfw_checker', - '--safety_checker', + "--nsfw_checker", + "--safety_checker", action=argparse.BooleanOptionalAction, - dest='safety_checker', + dest="safety_checker", default=False, - help='Check for and blur potentially NSFW images. Use --no-nsfw_checker to disable.', + help="Check for and blur potentially NSFW images. Use --no-nsfw_checker to disable.", ) model_group.add_argument( - '--autoimport', + "--autoimport", default=None, type=str, - help='Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly', + help="Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly", ) model_group.add_argument( - '--autoconvert', + "--autoconvert", default=None, type=str, - help='Check the indicated directory for .ckpt/.safetensors weights files at startup and import as optimized diffuser models', + help="Check the indicated directory for .ckpt/.safetensors weights files at startup and import as optimized diffuser models", ) model_group.add_argument( - '--patchmatch', + "--patchmatch", action=argparse.BooleanOptionalAction, default=True, - help='Load the patchmatch extension for outpainting. Use --no-patchmatch to disable.', + help="Load the patchmatch extension for outpainting. Use --no-patchmatch to disable.", ) file_group.add_argument( - '--from_file', - dest='infile', + "--from_file", + dest="infile", type=str, - help='If specified, load prompts from this file', + help="If specified, load prompts from this file", ) file_group.add_argument( - '--outdir', - '-o', + "--outdir", + "-o", type=str, - help='Directory to save generated images and a log of prompts and seeds. Default: ROOTDIR/outputs', - default='outputs', + help="Directory to save generated images and a log of prompts and seeds. Default: ROOTDIR/outputs", + default="outputs", ) file_group.add_argument( - '--prompt_as_dir', - '-p', - action='store_true', - help='Place images in subdirectories named after the prompt.', + "--prompt_as_dir", + "-p", + action="store_true", + help="Place images in subdirectories named after the prompt.", ) render_group.add_argument( - '--fnformat', - default='{prefix}.{seed}.png', + "--fnformat", + default="{prefix}.{seed}.png", type=str, - help='Overwrite the filename format. You can use any argument as wildcard enclosed in curly braces. Default is {prefix}.{seed}.png', + help="Overwrite the filename format. You can use any argument as wildcard enclosed in curly braces. Default is {prefix}.{seed}.png", ) +render_group.add_argument("-s", "--steps", type=int, default=50, help="Number of steps") render_group.add_argument( - '-s', - '--steps', + "-W", + "--width", type=int, - default=50, - help='Number of steps' + help="Image width, multiple of 64", ) render_group.add_argument( - '-W', - '--width', + "-H", + "--height", type=int, - help='Image width, multiple of 64', + help="Image height, multiple of 64", ) render_group.add_argument( - '-H', - '--height', - type=int, - help='Image height, multiple of 64', -) -render_group.add_argument( - '-C', - '--cfg_scale', + "-C", + "--cfg_scale", default=7.5, type=float, help='Classifier free guidance (CFG) scale - higher numbers cause generator to "try" harder.', ) render_group.add_argument( - '--sampler', - '-A', - '-m', - dest='sampler_name', + "--sampler", + "-A", + "-m", + dest="sampler_name", type=str, choices=SAMPLER_CHOICES, - metavar='SAMPLER_NAME', + metavar="SAMPLER_NAME", help=f'Set the default sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}', - default='k_lms', + default="k_lms", ) render_group.add_argument( - '--log_tokenization', - '-t', - action='store_true', - help='shows how the prompt is split into tokens' + "--log_tokenization", "-t", action="store_true", help="shows how the prompt is split into tokens" ) render_group.add_argument( - '-f', - '--strength', + "-f", + "--strength", type=float, - help='img2img strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely', + help="img2img strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely", ) render_group.add_argument( - '-T', - '-fit', - '--fit', + "-T", + "-fit", + "--fit", action=argparse.BooleanOptionalAction, - help='If specified, will resize the input image to fit within the dimensions of width x height (512x512 default)', + help="If specified, will resize the input image to fit within the dimensions of width x height (512x512 default)", ) +render_group.add_argument("--grid", "-g", action=argparse.BooleanOptionalAction, help="generate a grid") render_group.add_argument( - '--grid', - '-g', - action=argparse.BooleanOptionalAction, - help='generate a grid' -) -render_group.add_argument( - '--embedding_directory', - '--embedding_path', - dest='embedding_path', - default='embeddings', + "--embedding_directory", + "--embedding_path", + dest="embedding_path", + default="embeddings", type=str, - help='Path to a directory containing .bin and/or .pt files, or a single .bin/.pt file. You may use subdirectories. (default is ROOTDIR/embeddings)' + help="Path to a directory containing .bin and/or .pt files, or a single .bin/.pt file. You may use subdirectories. (default is ROOTDIR/embeddings)", ) render_group.add_argument( - '--lora_directory', - dest='lora_path', - default='loras', + "--lora_directory", + dest="lora_path", + default="loras", type=str, - help='Path to a directory containing LoRA files; subdirectories are not supported. (default is ROOTDIR/loras)' + help="Path to a directory containing LoRA files; subdirectories are not supported. (default is ROOTDIR/loras)", ) render_group.add_argument( - '--embeddings', + "--embeddings", action=argparse.BooleanOptionalAction, default=True, - help='Enable embedding directory (default). Use --no-embeddings to disable.', + help="Enable embedding directory (default). Use --no-embeddings to disable.", ) +render_group.add_argument("--enable_image_debugging", action="store_true", help="Generates debugging image to display") render_group.add_argument( - '--enable_image_debugging', - action='store_true', - help='Generates debugging image to display' -) -render_group.add_argument( - '--karras_max', + "--karras_max", type=int, default=None, - help="control the point at which the K* samplers will shift from using the Karras noise schedule (good for low step counts) to the LatentDiffusion noise schedule (good for high step counts). Set to 0 to use LatentDiffusion for all step values, and to a high value (e.g. 1000) to use Karras for all step values. [29]." + help="control the point at which the K* samplers will shift from using the Karras noise schedule (good for low step counts) to the LatentDiffusion noise schedule (good for high step counts). Set to 0 to use LatentDiffusion for all step values, and to a high value (e.g. 1000) to use Karras for all step values. [29].", ) # Restoration related args postprocessing_group.add_argument( - '--no_restore', - dest='restore', - action='store_false', - help='Disable face restoration with GFPGAN or codeformer', + "--no_restore", + dest="restore", + action="store_false", + help="Disable face restoration with GFPGAN or codeformer", ) postprocessing_group.add_argument( - '--no_upscale', - dest='esrgan', - action='store_false', - help='Disable upscaling with ESRGAN', + "--no_upscale", + dest="esrgan", + action="store_false", + help="Disable upscaling with ESRGAN", ) postprocessing_group.add_argument( - '--esrgan_bg_tile', + "--esrgan_bg_tile", type=int, default=400, - help='Tile size for background sampler, 0 for no tile during testing. Default: 400.', + help="Tile size for background sampler, 0 for no tile during testing. Default: 400.", ) postprocessing_group.add_argument( - '--esrgan_denoise_str', + "--esrgan_denoise_str", type=float, default=0.75, - help='esrgan denoise str. 0 is no denoise, 1 is max denoise. Default: 0.75', + help="esrgan denoise str. 0 is no denoise, 1 is max denoise. Default: 0.75", ) postprocessing_group.add_argument( - '--gfpgan_model_path', + "--gfpgan_model_path", type=str, - default='./models/gfpgan/GFPGANv1.4.pth', - help='Indicates the path to the GFPGAN model', + default="./models/gfpgan/GFPGANv1.4.pth", + help="Indicates the path to the GFPGAN model", ) web_server_group.add_argument( - '--web', - dest='web', - action='store_true', - help='Start in web server mode.', + "--web", + dest="web", + action="store_true", + help="Start in web server mode.", ) web_server_group.add_argument( - '--web_develop', - dest='web_develop', - action='store_true', - help='Start in web server development mode.', + "--web_develop", + dest="web_develop", + action="store_true", + help="Start in web server development mode.", ) web_server_group.add_argument( "--web_verbose", @@ -376,32 +352,27 @@ web_server_group.add_argument( help="Additional allowed origins, comma-separated", ) web_server_group.add_argument( - '--host', + "--host", type=str, - default='127.0.0.1', - help='Web server: Host or IP to listen on. Set to 0.0.0.0 to accept traffic from other devices on your network.' + default="127.0.0.1", + help="Web server: Host or IP to listen on. Set to 0.0.0.0 to accept traffic from other devices on your network.", ) +web_server_group.add_argument("--port", type=int, default="9090", help="Web server: Port to listen on") web_server_group.add_argument( - '--port', - type=int, - default='9090', - help='Web server: Port to listen on' -) -web_server_group.add_argument( - '--certfile', + "--certfile", type=str, default=None, - help='Web server: Path to certificate file to use for SSL. Use together with --keyfile' + help="Web server: Path to certificate file to use for SSL. Use together with --keyfile", ) web_server_group.add_argument( - '--keyfile', + "--keyfile", type=str, default=None, - help='Web server: Path to private key file to use for SSL. Use together with --certfile' + help="Web server: Path to private key file to use for SSL. Use together with --certfile", ) web_server_group.add_argument( - '--gui', - dest='gui', - action='store_true', - help='Start InvokeAI GUI', + "--gui", + dest="gui", + action="store_true", + help="Start InvokeAI GUI", ) diff --git a/invokeai/backend/install/migrate_to_3.py b/invokeai/backend/install/migrate_to_3.py index 85de3f5e69..9152e46951 100644 --- a/invokeai/backend/install/migrate_to_3.py +++ b/invokeai/backend/install/migrate_to_3.py @@ -1,7 +1,7 @@ -''' +""" Migrate the models directory and models.yaml file from an existing InvokeAI 2.3 installation to 3.0.0. -''' +""" import os import argparse @@ -29,14 +29,13 @@ from transformers import ( import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.model_management import ModelManager -from invokeai.backend.model_management.model_probe import ( - ModelProbe, ModelType, BaseModelType, ModelProbeInfo - ) +from invokeai.backend.model_management.model_probe import ModelProbe, ModelType, BaseModelType, ModelProbeInfo warnings.filterwarnings("ignore") transformers.logging.set_verbosity_error() diffusers.logging.set_verbosity_error() + # holder for paths that we will migrate @dataclass class ModelPaths: @@ -45,81 +44,82 @@ class ModelPaths: loras: Path controlnets: Path + class MigrateTo3(object): - def __init__(self, - from_root: Path, - to_models: Path, - model_manager: ModelManager, - src_paths: ModelPaths, - ): + def __init__( + self, + from_root: Path, + to_models: Path, + model_manager: ModelManager, + src_paths: ModelPaths, + ): self.root_directory = from_root self.dest_models = to_models self.mgr = model_manager self.src_paths = src_paths - + @classmethod def initialize_yaml(cls, yaml_file: Path): - with open(yaml_file, 'w') as file: - file.write( - yaml.dump( - { - '__metadata__': {'version':'3.0.0'} - } - ) - ) - + with open(yaml_file, "w") as file: + file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}})) + def create_directory_structure(self): - ''' + """ Create the basic directory structure for the models folder. - ''' - for model_base in [BaseModelType.StableDiffusion1,BaseModelType.StableDiffusion2]: - for model_type in [ModelType.Main, ModelType.Vae, ModelType.Lora, - ModelType.ControlNet,ModelType.TextualInversion]: + """ + for model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]: + for model_type in [ + ModelType.Main, + ModelType.Vae, + ModelType.Lora, + ModelType.ControlNet, + ModelType.TextualInversion, + ]: path = self.dest_models / model_base.value / model_type.value path.mkdir(parents=True, exist_ok=True) - path = self.dest_models / 'core' + path = self.dest_models / "core" path.mkdir(parents=True, exist_ok=True) @staticmethod - def copy_file(src:Path,dest:Path): - ''' + def copy_file(src: Path, dest: Path): + """ copy a single file with logging - ''' + """ if dest.exists(): - logger.info(f'Skipping existing {str(dest)}') + logger.info(f"Skipping existing {str(dest)}") return - logger.info(f'Copying {str(src)} to {str(dest)}') + logger.info(f"Copying {str(src)} to {str(dest)}") try: shutil.copy(src, dest) except Exception as e: - logger.error(f'COPY FAILED: {str(e)}') + logger.error(f"COPY FAILED: {str(e)}") @staticmethod - def copy_dir(src:Path,dest:Path): - ''' + def copy_dir(src: Path, dest: Path): + """ Recursively copy a directory with logging - ''' + """ if dest.exists(): - logger.info(f'Skipping existing {str(dest)}') + logger.info(f"Skipping existing {str(dest)}") return - - logger.info(f'Copying {str(src)} to {str(dest)}') + + logger.info(f"Copying {str(src)} to {str(dest)}") try: shutil.copytree(src, dest) except Exception as e: - logger.error(f'COPY FAILED: {str(e)}') + logger.error(f"COPY FAILED: {str(e)}") def migrate_models(self, src_dir: Path): - ''' + """ Recursively walk through src directory, probe anything that looks like a model, and copy the model into the appropriate location within the destination models directory. - ''' + """ directories_scanned = set() for root, dirs, files in os.walk(src_dir): for d in dirs: try: - model = Path(root,d) + model = Path(root, d) info = ModelProbe().heuristic_probe(model) if not info: continue @@ -136,9 +136,9 @@ class MigrateTo3(object): # don't copy raw learned_embeds.bin or pytorch_lora_weights.bin # let them be copied as part of a tree copy operation try: - if f in {'learned_embeds.bin','pytorch_lora_weights.bin'}: + if f in {"learned_embeds.bin", "pytorch_lora_weights.bin"}: continue - model = Path(root,f) + model = Path(root, f) if model.parent in directories_scanned: continue info = ModelProbe().heuristic_probe(model) @@ -154,148 +154,146 @@ class MigrateTo3(object): logger.error(str(e)) def migrate_support_models(self): - ''' + """ Copy the clipseg, upscaler, and restoration models to their new locations. - ''' + """ dest_directory = self.dest_models - if (self.root_directory / 'models/clipseg').exists(): - self.copy_dir(self.root_directory / 'models/clipseg', dest_directory / 'core/misc/clipseg') - if (self.root_directory / 'models/realesrgan').exists(): - self.copy_dir(self.root_directory / 'models/realesrgan', dest_directory / 'core/upscaling/realesrgan') - for d in ['codeformer','gfpgan']: - path = self.root_directory / 'models' / d + if (self.root_directory / "models/clipseg").exists(): + self.copy_dir(self.root_directory / "models/clipseg", dest_directory / "core/misc/clipseg") + if (self.root_directory / "models/realesrgan").exists(): + self.copy_dir(self.root_directory / "models/realesrgan", dest_directory / "core/upscaling/realesrgan") + for d in ["codeformer", "gfpgan"]: + path = self.root_directory / "models" / d if path.exists(): - self.copy_dir(path,dest_directory / f'core/face_restoration/{d}') + self.copy_dir(path, dest_directory / f"core/face_restoration/{d}") def migrate_tuning_models(self): - ''' + """ Migrate the embeddings, loras and controlnets directories to their new homes. - ''' + """ for src in [self.src_paths.embeddings, self.src_paths.loras, self.src_paths.controlnets]: if not src: continue if src.is_dir(): - logger.info(f'Scanning {src}') + logger.info(f"Scanning {src}") self.migrate_models(src) else: - logger.info(f'{src} directory not found; skipping') + logger.info(f"{src} directory not found; skipping") continue def migrate_conversion_models(self): - ''' + """ Migrate all the models that are needed by the ckpt_to_diffusers conversion script. - ''' + """ dest_directory = self.dest_models kwargs = dict( - cache_dir = self.root_directory / 'models/hub', - #local_files_only = True + cache_dir=self.root_directory / "models/hub", + # local_files_only = True ) try: - logger.info('Migrating core tokenizers and text encoders') - target_dir = dest_directory / 'core' / 'convert' + logger.info("Migrating core tokenizers and text encoders") + target_dir = dest_directory / "core" / "convert" - self._migrate_pretrained(BertTokenizerFast, - repo_id='bert-base-uncased', - dest = target_dir / 'bert-base-uncased', - **kwargs) + self._migrate_pretrained( + BertTokenizerFast, repo_id="bert-base-uncased", dest=target_dir / "bert-base-uncased", **kwargs + ) # sd-1 - repo_id = 'openai/clip-vit-large-patch14' - self._migrate_pretrained(CLIPTokenizer, - repo_id= repo_id, - dest= target_dir / 'clip-vit-large-patch14', - **kwargs) - self._migrate_pretrained(CLIPTextModel, - repo_id = repo_id, - dest = target_dir / 'clip-vit-large-patch14', - force = True, - **kwargs) + repo_id = "openai/clip-vit-large-patch14" + self._migrate_pretrained( + CLIPTokenizer, repo_id=repo_id, dest=target_dir / "clip-vit-large-patch14", **kwargs + ) + self._migrate_pretrained( + CLIPTextModel, repo_id=repo_id, dest=target_dir / "clip-vit-large-patch14", force=True, **kwargs + ) # sd-2 repo_id = "stabilityai/stable-diffusion-2" - self._migrate_pretrained(CLIPTokenizer, - repo_id = repo_id, - dest = target_dir / 'stable-diffusion-2-clip' / 'tokenizer', - **{'subfolder':'tokenizer',**kwargs} - ) - self._migrate_pretrained(CLIPTextModel, - repo_id = repo_id, - dest = target_dir / 'stable-diffusion-2-clip' / 'text_encoder', - **{'subfolder':'text_encoder',**kwargs} - ) + self._migrate_pretrained( + CLIPTokenizer, + repo_id=repo_id, + dest=target_dir / "stable-diffusion-2-clip" / "tokenizer", + **{"subfolder": "tokenizer", **kwargs}, + ) + self._migrate_pretrained( + CLIPTextModel, + repo_id=repo_id, + dest=target_dir / "stable-diffusion-2-clip" / "text_encoder", + **{"subfolder": "text_encoder", **kwargs}, + ) # VAE - logger.info('Migrating stable diffusion VAE') - self._migrate_pretrained(AutoencoderKL, - repo_id = 'stabilityai/sd-vae-ft-mse', - dest = target_dir / 'sd-vae-ft-mse', - **kwargs) - + logger.info("Migrating stable diffusion VAE") + self._migrate_pretrained( + AutoencoderKL, repo_id="stabilityai/sd-vae-ft-mse", dest=target_dir / "sd-vae-ft-mse", **kwargs + ) + # safety checking - logger.info('Migrating safety checker') + logger.info("Migrating safety checker") repo_id = "CompVis/stable-diffusion-safety-checker" - self._migrate_pretrained(AutoFeatureExtractor, - repo_id = repo_id, - dest = target_dir / 'stable-diffusion-safety-checker', - **kwargs) - self._migrate_pretrained(StableDiffusionSafetyChecker, - repo_id = repo_id, - dest = target_dir / 'stable-diffusion-safety-checker', - **kwargs) + self._migrate_pretrained( + AutoFeatureExtractor, repo_id=repo_id, dest=target_dir / "stable-diffusion-safety-checker", **kwargs + ) + self._migrate_pretrained( + StableDiffusionSafetyChecker, + repo_id=repo_id, + dest=target_dir / "stable-diffusion-safety-checker", + **kwargs, + ) except KeyboardInterrupt: raise except Exception as e: logger.error(str(e)) - def _model_probe_to_path(self, info: ModelProbeInfo)->Path: + def _model_probe_to_path(self, info: ModelProbeInfo) -> Path: return Path(self.dest_models, info.base_type.value, info.model_type.value) - def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, force:bool=False, **kwargs): + def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, force: bool = False, **kwargs): if dest.exists() and not force: - logger.info(f'Skipping existing {dest}') + logger.info(f"Skipping existing {dest}") return model = model_class.from_pretrained(repo_id, **kwargs) self._save_pretrained(model, dest, overwrite=force) - def _save_pretrained(self, model, dest: Path, overwrite: bool=False): + def _save_pretrained(self, model, dest: Path, overwrite: bool = False): model_name = dest.name if overwrite: model.save_pretrained(dest, safe_serialization=True) else: - download_path = dest.with_name(f'{model_name}.downloading') + download_path = dest.with_name(f"{model_name}.downloading") model.save_pretrained(download_path, safe_serialization=True) download_path.replace(dest) - def _download_vae(self, repo_id: str, subfolder:str=None)->Path: - vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / 'models/hub', subfolder=subfolder) + def _download_vae(self, repo_id: str, subfolder: str = None) -> Path: + vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / "models/hub", subfolder=subfolder) info = ModelProbe().heuristic_probe(vae) - _, model_name = repo_id.split('/') + _, model_name = repo_id.split("/") dest = self._model_probe_to_path(info) / self.unique_name(model_name, info) vae.save_pretrained(dest, safe_serialization=True) return dest - def _vae_path(self, vae: Union[str,dict])->Path: - ''' + def _vae_path(self, vae: Union[str, dict]) -> Path: + """ Convert 2.3 VAE stanza to a straight path. - ''' + """ vae_path = None - + # First get a path - if isinstance(vae,str): + if isinstance(vae, str): vae_path = vae - elif isinstance(vae,DictConfig): - if p := vae.get('path'): + elif isinstance(vae, DictConfig): + if p := vae.get("path"): vae_path = p - elif repo_id := vae.get('repo_id'): - if repo_id=='stabilityai/sd-vae-ft-mse': # this guy is already downloaded - vae_path = 'models/core/convert/sd-vae-ft-mse' + elif repo_id := vae.get("repo_id"): + if repo_id == "stabilityai/sd-vae-ft-mse": # this guy is already downloaded + vae_path = "models/core/convert/sd-vae-ft-mse" return vae_path else: - vae_path = self._download_vae(repo_id, vae.get('subfolder')) + vae_path = self._download_vae(repo_id, vae.get("subfolder")) assert vae_path is not None, "Couldn't find VAE for this model" @@ -307,152 +305,144 @@ class MigrateTo3(object): dest = self._model_probe_to_path(info) / vae_path.name if not dest.exists(): if vae_path.is_dir(): - self.copy_dir(vae_path,dest) + self.copy_dir(vae_path, dest) else: - self.copy_file(vae_path,dest) + self.copy_file(vae_path, dest) vae_path = dest if vae_path.is_relative_to(self.dest_models): rel_path = vae_path.relative_to(self.dest_models) - return Path('models',rel_path) + return Path("models", rel_path) else: return vae_path - def migrate_repo_id(self, repo_id: str, model_name: str=None, **extra_config): - ''' + def migrate_repo_id(self, repo_id: str, model_name: str = None, **extra_config): + """ Migrate a locally-cached diffusers pipeline identified with a repo_id - ''' + """ dest_dir = self.dest_models - - cache = self.root_directory / 'models/hub' + + cache = self.root_directory / "models/hub" kwargs = dict( - cache_dir = cache, - safety_checker = None, + cache_dir=cache, + safety_checker=None, # local_files_only = True, ) - owner,repo_name = repo_id.split('/') + owner, repo_name = repo_id.split("/") model_name = model_name or repo_name - model = cache / '--'.join(['models',owner,repo_name]) - - if len(list(model.glob('snapshots/**/model_index.json')))==0: + model = cache / "--".join(["models", owner, repo_name]) + + if len(list(model.glob("snapshots/**/model_index.json"))) == 0: return - revisions = [x.name for x in model.glob('refs/*')] + revisions = [x.name for x in model.glob("refs/*")] # if an fp16 is available we use that - revision = 'fp16' if len(revisions) > 1 and 'fp16' in revisions else revisions[0] - pipeline = StableDiffusionPipeline.from_pretrained( - repo_id, - revision=revision, - **kwargs) + revision = "fp16" if len(revisions) > 1 and "fp16" in revisions else revisions[0] + pipeline = StableDiffusionPipeline.from_pretrained(repo_id, revision=revision, **kwargs) info = ModelProbe().heuristic_probe(pipeline) if not info: return if self.mgr.model_exists(model_name, info.base_type, info.model_type): - logger.warning(f'A model named {model_name} already exists at the destination. Skipping migration.') + logger.warning(f"A model named {model_name} already exists at the destination. Skipping migration.") return dest = self._model_probe_to_path(info) / model_name self._save_pretrained(pipeline, dest) - - rel_path = Path('models',dest.relative_to(dest_dir)) + + rel_path = Path("models", dest.relative_to(dest_dir)) self._add_model(model_name, info, rel_path, **extra_config) - def migrate_path(self, location: Path, model_name: str=None, **extra_config): - ''' + def migrate_path(self, location: Path, model_name: str = None, **extra_config): + """ Migrate a model referred to using 'weights' or 'path' - ''' + """ # handle relative paths dest_dir = self.dest_models location = self.root_directory / location model_name = model_name or location.stem - + info = ModelProbe().heuristic_probe(location) if not info: return - + if self.mgr.model_exists(model_name, info.base_type, info.model_type): - logger.warning(f'A model named {model_name} already exists at the destination. Skipping migration.') + logger.warning(f"A model named {model_name} already exists at the destination. Skipping migration.") return # uh oh, weights is in the old models directory - move it into the new one if Path(location).is_relative_to(self.src_paths.models): dest = Path(dest_dir, info.base_type.value, info.model_type.value, location.name) if location.is_dir(): - self.copy_dir(location,dest) + self.copy_dir(location, dest) else: - self.copy_file(location,dest) - location = Path('models', info.base_type.value, info.model_type.value, location.name) + self.copy_file(location, dest) + location = Path("models", info.base_type.value, info.model_type.value, location.name) self._add_model(model_name, info, location, **extra_config) - def _add_model(self, - model_name: str, - info: ModelProbeInfo, - location: Path, - **extra_config): + def _add_model(self, model_name: str, info: ModelProbeInfo, location: Path, **extra_config): if info.model_type != ModelType.Main: return - - self.mgr.add_model( - model_name = model_name, - base_model = info.base_type, - model_type = info.model_type, - clobber = True, - model_attributes = { - 'path': str(location), - 'description': f'A {info.base_type.value} {info.model_type.value} model', - 'model_format': info.format, - 'variant': info.variant_type.value, - **extra_config, - } - ) - - def migrate_defined_models(self): - ''' - Migrate models defined in models.yaml - ''' - # find any models referred to in old models.yaml - conf = OmegaConf.load(self.root_directory / 'configs/models.yaml') - - for model_name, stanza in conf.items(): + self.mgr.add_model( + model_name=model_name, + base_model=info.base_type, + model_type=info.model_type, + clobber=True, + model_attributes={ + "path": str(location), + "description": f"A {info.base_type.value} {info.model_type.value} model", + "model_format": info.format, + "variant": info.variant_type.value, + **extra_config, + }, + ) + + def migrate_defined_models(self): + """ + Migrate models defined in models.yaml + """ + # find any models referred to in old models.yaml + conf = OmegaConf.load(self.root_directory / "configs/models.yaml") + + for model_name, stanza in conf.items(): try: passthru_args = {} - - if vae := stanza.get('vae'): + + if vae := stanza.get("vae"): try: - passthru_args['vae'] = str(self._vae_path(vae)) + passthru_args["vae"] = str(self._vae_path(vae)) except Exception as e: logger.warning(f'Could not find a VAE matching "{vae}" for model "{model_name}"') logger.warning(str(e)) - if config := stanza.get('config'): - passthru_args['config'] = config + if config := stanza.get("config"): + passthru_args["config"] = config - if description:= stanza.get('description'): - passthru_args['description'] = description - - if repo_id := stanza.get('repo_id'): - logger.info(f'Migrating diffusers model {model_name}') + if description := stanza.get("description"): + passthru_args["description"] = description + + if repo_id := stanza.get("repo_id"): + logger.info(f"Migrating diffusers model {model_name}") self.migrate_repo_id(repo_id, model_name, **passthru_args) - elif location := stanza.get('weights'): - logger.info(f'Migrating checkpoint model {model_name}') + elif location := stanza.get("weights"): + logger.info(f"Migrating checkpoint model {model_name}") self.migrate_path(Path(location), model_name, **passthru_args) - - elif location := stanza.get('path'): - logger.info(f'Migrating diffusers model {model_name}') + + elif location := stanza.get("path"): + logger.info(f"Migrating diffusers model {model_name}") self.migrate_path(Path(location), model_name, **passthru_args) - + except KeyboardInterrupt: raise except Exception as e: logger.error(str(e)) - + def migrate(self): self.create_directory_structure() # the configure script is doing this @@ -461,67 +451,71 @@ class MigrateTo3(object): self.migrate_tuning_models() self.migrate_defined_models() -def _parse_legacy_initfile(root: Path, initfile: Path)->ModelPaths: - ''' + +def _parse_legacy_initfile(root: Path, initfile: Path) -> ModelPaths: + """ Returns tuple of (embedding_path, lora_path, controlnet_path) - ''' - parser = argparse.ArgumentParser(fromfile_prefix_chars='@') + """ + parser = argparse.ArgumentParser(fromfile_prefix_chars="@") parser.add_argument( - '--embedding_directory', - '--embedding_path', + "--embedding_directory", + "--embedding_path", type=Path, - dest='embedding_path', - default=Path('embeddings'), + dest="embedding_path", + default=Path("embeddings"), ) parser.add_argument( - '--lora_directory', - dest='lora_path', + "--lora_directory", + dest="lora_path", type=Path, - default=Path('loras'), + default=Path("loras"), ) - opt,_ = parser.parse_known_args([f'@{str(initfile)}']) + opt, _ = parser.parse_known_args([f"@{str(initfile)}"]) return ModelPaths( - models = root / 'models', - embeddings = root / str(opt.embedding_path).strip('"'), - loras = root / str(opt.lora_path).strip('"'), - controlnets = root / 'controlnets', + models=root / "models", + embeddings=root / str(opt.embedding_path).strip('"'), + loras=root / str(opt.lora_path).strip('"'), + controlnets=root / "controlnets", ) -def _parse_legacy_yamlfile(root: Path, initfile: Path)->ModelPaths: - ''' + +def _parse_legacy_yamlfile(root: Path, initfile: Path) -> ModelPaths: + """ Returns tuple of (embedding_path, lora_path, controlnet_path) - ''' + """ # Don't use the config object because it is unforgiving of version updates # Just use omegaconf directly opt = OmegaConf.load(initfile) paths = opt.InvokeAI.Paths - models = paths.get('models_dir','models') - embeddings = paths.get('embedding_dir','embeddings') - loras = paths.get('lora_dir','loras') - controlnets = paths.get('controlnet_dir','controlnets') + models = paths.get("models_dir", "models") + embeddings = paths.get("embedding_dir", "embeddings") + loras = paths.get("lora_dir", "loras") + controlnets = paths.get("controlnet_dir", "controlnets") return ModelPaths( - models = root / models, - embeddings = root / embeddings, - loras = root /loras, - controlnets = root / controlnets, + models=root / models, + embeddings=root / embeddings, + loras=root / loras, + controlnets=root / controlnets, ) - + + def get_legacy_embeddings(root: Path) -> ModelPaths: - path = root / 'invokeai.init' + path = root / "invokeai.init" if path.exists(): return _parse_legacy_initfile(root, path) - path = root / 'invokeai.yaml' + path = root / "invokeai.yaml" if path.exists(): return _parse_legacy_yamlfile(root, path) + def do_migrate(src_directory: Path, dest_directory: Path): """ Migrate models from src to dest InvokeAI root directories """ - config_file = dest_directory / 'configs' / 'models.yaml.3' - dest_models = dest_directory / 'models.3' - - version_3 = (dest_directory / 'models' / 'core').exists() + config_file = dest_directory / "configs" / "models.yaml.3" + dest_models = dest_directory / "models.3" + + version_3 = (dest_directory / "models" / "core").exists() # Here we create the destination models.yaml file. # If we are writing into a version 3 directory and the @@ -530,80 +524,80 @@ def do_migrate(src_directory: Path, dest_directory: Path): # create a new empty one. if version_3: # write into the dest directory try: - shutil.copy(dest_directory / 'configs' / 'models.yaml', config_file) + shutil.copy(dest_directory / "configs" / "models.yaml", config_file) except: MigrateTo3.initialize_yaml(config_file) - mgr = ModelManager(config_file) # important to initialize BEFORE moving the models directory - (dest_directory / 'models').replace(dest_models) + mgr = ModelManager(config_file) # important to initialize BEFORE moving the models directory + (dest_directory / "models").replace(dest_models) else: MigrateTo3.initialize_yaml(config_file) mgr = ModelManager(config_file) - + paths = get_legacy_embeddings(src_directory) - migrator = MigrateTo3( - from_root = src_directory, - to_models = dest_models, - model_manager = mgr, - src_paths = paths - ) + migrator = MigrateTo3(from_root=src_directory, to_models=dest_models, model_manager=mgr, src_paths=paths) migrator.migrate() print("Migration successful.") if not version_3: - (dest_directory / 'models').replace(src_directory / 'models.orig') - print(f'Original models directory moved to {dest_directory}/models.orig') - - (dest_directory / 'configs' / 'models.yaml').replace(src_directory / 'configs' / 'models.yaml.orig') - print(f'Original models.yaml file moved to {dest_directory}/configs/models.yaml.orig') - - config_file.replace(config_file.with_suffix('')) - dest_models.replace(dest_models.with_suffix('')) - + (dest_directory / "models").replace(src_directory / "models.orig") + print(f"Original models directory moved to {dest_directory}/models.orig") + + (dest_directory / "configs" / "models.yaml").replace(src_directory / "configs" / "models.yaml.orig") + print(f"Original models.yaml file moved to {dest_directory}/configs/models.yaml.orig") + + config_file.replace(config_file.with_suffix("")) + dest_models.replace(dest_models.with_suffix("")) + + def main(): - parser = argparse.ArgumentParser(prog="invokeai-migrate3", - description=""" + parser = argparse.ArgumentParser( + prog="invokeai-migrate3", + description=""" This will copy and convert the models directory and the configs/models.yaml from the InvokeAI 2.3 format '--from-directory' root to the InvokeAI 3.0 '--to-directory' root. These may be abbreviated '--from' and '--to'.a The old models directory and config file will be renamed 'models.orig' and 'models.yaml.orig' respectively. It is safe to provide the same directory for both arguments, but it is better to use the invokeai_configure -script, which will perform a full upgrade in place.""" - ) - parser.add_argument('--from-directory', - dest='src_root', - type=Path, - required=True, - help='Source InvokeAI 2.3 root directory (containing "invokeai.init" or "invokeai.yaml")' - ) - parser.add_argument('--to-directory', - dest='dest_root', - type=Path, - required=True, - help='Destination InvokeAI 3.0 directory (containing "invokeai.yaml")' - ) +script, which will perform a full upgrade in place.""", + ) + parser.add_argument( + "--from-directory", + dest="src_root", + type=Path, + required=True, + help='Source InvokeAI 2.3 root directory (containing "invokeai.init" or "invokeai.yaml")', + ) + parser.add_argument( + "--to-directory", + dest="dest_root", + type=Path, + required=True, + help='Destination InvokeAI 3.0 directory (containing "invokeai.yaml")', + ) args = parser.parse_args() src_root = args.src_root assert src_root.is_dir(), f"{src_root} is not a valid directory" - assert (src_root / 'models').is_dir(), f"{src_root} does not contain a 'models' subdirectory" - assert (src_root / 'models' / 'hub').exists(), f"{src_root} does not contain a version 2.3 models directory" - assert (src_root / 'invokeai.init').exists() or (src_root / 'invokeai.yaml').exists(), f"{src_root} does not contain an InvokeAI init file." + assert (src_root / "models").is_dir(), f"{src_root} does not contain a 'models' subdirectory" + assert (src_root / "models" / "hub").exists(), f"{src_root} does not contain a version 2.3 models directory" + assert (src_root / "invokeai.init").exists() or ( + src_root / "invokeai.yaml" + ).exists(), f"{src_root} does not contain an InvokeAI init file." dest_root = args.dest_root assert dest_root.is_dir(), f"{dest_root} is not a valid directory" config = InvokeAIAppConfig.get_config() - config.parse_args(['--root',str(dest_root)]) + config.parse_args(["--root", str(dest_root)]) # TODO: revisit - don't rely on invokeai.yaml to exist yet! - dest_is_setup = (dest_root / 'models/core').exists() and (dest_root / 'databases').exists() + dest_is_setup = (dest_root / "models/core").exists() and (dest_root / "databases").exists() if not dest_is_setup: import invokeai.frontend.install.invokeai_configure from invokeai.backend.install.invokeai_configure import initialize_rootdir + initialize_rootdir(dest_root, True) - do_migrate(src_root,dest_root) + do_migrate(src_root, dest_root) -if __name__ == '__main__': + +if __name__ == "__main__": main() - - - diff --git a/invokeai/backend/install/model_install_backend.py b/invokeai/backend/install/model_install_backend.py index 5aed8cc034..b3ab88b5dd 100644 --- a/invokeai/backend/install/model_install_backend.py +++ b/invokeai/backend/install/model_install_backend.py @@ -4,7 +4,7 @@ Utility (backend) functions used by model_install.py import os import shutil import warnings -from dataclasses import dataclass,field +from dataclasses import dataclass, field from pathlib import Path from tempfile import TemporaryDirectory from typing import List, Dict, Callable, Union, Set @@ -28,7 +28,7 @@ warnings.filterwarnings("ignore") # --------------------------globals----------------------- config = InvokeAIAppConfig.get_config() -logger = InvokeAILogger.getLogger(name='InvokeAI') +logger = InvokeAILogger.getLogger(name="InvokeAI") # the initial "configs" dir is now bundled in the `invokeai.configs` package Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml" @@ -45,59 +45,63 @@ Config_preamble = """ LEGACY_CONFIGS = { BaseModelType.StableDiffusion1: { - ModelVariantType.Normal: 'v1-inference.yaml', - ModelVariantType.Inpaint: 'v1-inpainting-inference.yaml', + ModelVariantType.Normal: "v1-inference.yaml", + ModelVariantType.Inpaint: "v1-inpainting-inference.yaml", }, - BaseModelType.StableDiffusion2: { ModelVariantType.Normal: { - SchedulerPredictionType.Epsilon: 'v2-inference.yaml', - SchedulerPredictionType.VPrediction: 'v2-inference-v.yaml', + SchedulerPredictionType.Epsilon: "v2-inference.yaml", + SchedulerPredictionType.VPrediction: "v2-inference-v.yaml", }, ModelVariantType.Inpaint: { - SchedulerPredictionType.Epsilon: 'v2-inpainting-inference.yaml', - SchedulerPredictionType.VPrediction: 'v2-inpainting-inference-v.yaml', - } + SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml", + SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml", + }, }, - BaseModelType.StableDiffusionXL: { - ModelVariantType.Normal: 'sd_xl_base.yaml', + ModelVariantType.Normal: "sd_xl_base.yaml", }, - BaseModelType.StableDiffusionXLRefiner: { - ModelVariantType.Normal: 'sd_xl_refiner.yaml', + ModelVariantType.Normal: "sd_xl_refiner.yaml", }, } + @dataclass class ModelInstallList: - '''Class for listing models to be installed/removed''' + """Class for listing models to be installed/removed""" + install_models: List[str] = field(default_factory=list) remove_models: List[str] = field(default_factory=list) -@dataclass -class InstallSelections(): - install_models: List[str]= field(default_factory=list) - remove_models: List[str]=field(default_factory=list) @dataclass -class ModelLoadInfo(): +class InstallSelections: + install_models: List[str] = field(default_factory=list) + remove_models: List[str] = field(default_factory=list) + + +@dataclass +class ModelLoadInfo: name: str model_type: ModelType base_type: BaseModelType path: Path = None repo_id: str = None - description: str = '' + description: str = "" installed: bool = False recommended: bool = False default: bool = False + class ModelInstall(object): - def __init__(self, - config:InvokeAIAppConfig, - prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None, - model_manager: ModelManager = None, - access_token:str = None): + def __init__( + self, + config: InvokeAIAppConfig, + prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None, + model_manager: ModelManager = None, + access_token: str = None, + ): self.config = config self.mgr = model_manager or ModelManager(config.model_conf_path) self.datasets = OmegaConf.load(Dataset_path) @@ -105,66 +109,66 @@ class ModelInstall(object): self.access_token = access_token or HfFolder.get_token() self.reverse_paths = self._reverse_paths(self.datasets) - def all_models(self)->Dict[str,ModelLoadInfo]: - ''' + def all_models(self) -> Dict[str, ModelLoadInfo]: + """ Return dict of model_key=>ModelLoadInfo objects. This method consolidates and simplifies the entries in both models.yaml and INITIAL_MODELS.yaml so that they can be treated uniformly. It also sorts the models alphabetically by their name, to improve the display somewhat. - ''' + """ model_dict = dict() - + # first populate with the entries in INITIAL_MODELS.yaml for key, value in self.datasets.items(): - name,base,model_type = ModelManager.parse_key(key) - value['name'] = name - value['base_type'] = base - value['model_type'] = model_type + name, base, model_type = ModelManager.parse_key(key) + value["name"] = name + value["base_type"] = base + value["model_type"] = model_type model_dict[key] = ModelLoadInfo(**value) # supplement with entries in models.yaml installed_models = self.mgr.list_models() - + for md in installed_models: - base = md['base_model'] - model_type = md['model_type'] - name = md['model_name'] + base = md["base_model"] + model_type = md["model_type"] + name = md["model_name"] key = ModelManager.create_key(name, base, model_type) if key in model_dict: model_dict[key].installed = True else: model_dict[key] = ModelLoadInfo( - name = name, - base_type = base, - model_type = model_type, - path = value.get('path'), - installed = True, + name=name, + base_type=base, + model_type=model_type, + path=value.get("path"), + installed=True, ) - return {x : model_dict[x] for x in sorted(model_dict.keys(),key=lambda y: model_dict[y].name.lower())} + return {x: model_dict[x] for x in sorted(model_dict.keys(), key=lambda y: model_dict[y].name.lower())} def list_models(self, model_type): installed = self.mgr.list_models(model_type=model_type) - print(f'Installed models of type `{model_type}`:') + print(f"Installed models of type `{model_type}`:") for i in installed: print(f"{i['model_name']}\t{i['base_model']}\t{i['path']}") # logic here a little reversed to maintain backward compatibility - def starter_models(self, all_models: bool=False)->Set[str]: + def starter_models(self, all_models: bool = False) -> Set[str]: models = set() for key, value in self.datasets.items(): - name,base,model_type = ModelManager.parse_key(key) + name, base, model_type = ModelManager.parse_key(key) if all_models or model_type in [ModelType.Main, ModelType.Vae]: models.add(key) return models - def recommended_models(self)->Set[str]: + def recommended_models(self) -> Set[str]: starters = self.starter_models(all_models=True) - return set([x for x in starters if self.datasets[x].get('recommended',False)]) - - def default_model(self)->str: + return set([x for x in starters if self.datasets[x].get("recommended", False)]) + + def default_model(self) -> str: starters = self.starter_models() - defaults = [x for x in starters if self.datasets[x].get('default',False)] + defaults = [x for x in starters if self.datasets[x].get("default", False)] return defaults[0] def install(self, selections: InstallSelections): @@ -173,54 +177,57 @@ class ModelInstall(object): job = 1 jobs = len(selections.remove_models) + len(selections.install_models) - + # remove requested models for key in selections.remove_models: - name,base,mtype = self.mgr.parse_key(key) - logger.info(f'Deleting {mtype} model {name} [{job}/{jobs}]') + name, base, mtype = self.mgr.parse_key(key) + logger.info(f"Deleting {mtype} model {name} [{job}/{jobs}]") try: - self.mgr.del_model(name,base,mtype) + self.mgr.del_model(name, base, mtype) except FileNotFoundError as e: logger.warning(e) job += 1 - + # add requested models for path in selections.install_models: - logger.info(f'Installing {path} [{job}/{jobs}]') + logger.info(f"Installing {path} [{job}/{jobs}]") try: self.heuristic_import(path) except (ValueError, KeyError) as e: logger.error(str(e)) job += 1 - + dlogging.set_verbosity(verbosity) self.mgr.commit() - def heuristic_import(self, - model_path_id_or_url: Union[str,Path], - models_installed: Set[Path]=None, - )->Dict[str, AddModelResult]: - ''' + def heuristic_import( + self, + model_path_id_or_url: Union[str, Path], + models_installed: Set[Path] = None, + ) -> Dict[str, AddModelResult]: + """ :param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL :param models_installed: Set of installed models, used for recursive invocation Returns a set of dict objects corresponding to newly-created stanzas in models.yaml. - ''' + """ if not models_installed: models_installed = dict() - + # A little hack to allow nested routines to retrieve info on the requested ID self.current_id = model_path_id_or_url path = Path(model_path_id_or_url) # checkpoint file, or similar if path.is_file(): - models_installed.update({str(path):self._install_path(path)}) + models_installed.update({str(path): self._install_path(path)}) # folders style or similar - elif path.is_dir() and any([(path/x).exists() for x in \ - {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'} - ] - ): + elif path.is_dir() and any( + [ + (path / x).exists() + for x in {"config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"} + ] + ): models_installed.update({str(model_path_id_or_url): self._install_path(path)}) # recursive scan @@ -229,7 +236,7 @@ class ModelInstall(object): self.heuristic_import(child, models_installed=models_installed) # huggingface repo - elif len(str(model_path_id_or_url).split('/')) == 2: + elif len(str(model_path_id_or_url).split("/")) == 2: models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))}) # a URL @@ -237,42 +244,43 @@ class ModelInstall(object): models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)}) else: - raise KeyError(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping') + raise KeyError(f"{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping") return models_installed # install a model from a local path. The optional info parameter is there to prevent # the model from being probed twice in the event that it has already been probed. - def _install_path(self, path: Path, info: ModelProbeInfo=None)->AddModelResult: - info = info or ModelProbe().heuristic_probe(path,self.prediction_helper) + def _install_path(self, path: Path, info: ModelProbeInfo = None) -> AddModelResult: + info = info or ModelProbe().heuristic_probe(path, self.prediction_helper) if not info: - logger.warning(f'Unable to parse format of {path}') + logger.warning(f"Unable to parse format of {path}") return None model_name = path.stem if path.is_file() else path.name if self.mgr.model_exists(model_name, info.base_type, info.model_type): raise ValueError(f'A model named "{model_name}" is already installed.') - attributes = self._make_attributes(path,info) - return self.mgr.add_model(model_name = model_name, - base_model = info.base_type, - model_type = info.model_type, - model_attributes = attributes, - ) + attributes = self._make_attributes(path, info) + return self.mgr.add_model( + model_name=model_name, + base_model=info.base_type, + model_type=info.model_type, + model_attributes=attributes, + ) - def _install_url(self, url: str)->AddModelResult: + def _install_url(self, url: str) -> AddModelResult: with TemporaryDirectory(dir=self.config.models_path) as staging: - location = download_with_resume(url,Path(staging)) + location = download_with_resume(url, Path(staging)) if not location: - logger.error(f'Unable to download {url}. Skipping.') + logger.error(f"Unable to download {url}. Skipping.") info = ModelProbe().heuristic_probe(location) dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name - models_path = shutil.move(location,dest) + models_path = shutil.move(location, dest) # staged version will be garbage-collected at this time return self._install_path(Path(models_path), info) - def _install_repo(self, repo_id: str)->AddModelResult: + def _install_repo(self, repo_id: str) -> AddModelResult: hinfo = HfApi().model_info(repo_id) - + # we try to figure out how to download this most economically # list all the files in the repo files = [x.rfilename for x in hinfo.siblings] @@ -280,42 +288,49 @@ class ModelInstall(object): with TemporaryDirectory(dir=self.config.models_path) as staging: staging = Path(staging) - if 'model_index.json' in files: - location = self._download_hf_pipeline(repo_id, staging) # pipeline + if "model_index.json" in files: + location = self._download_hf_pipeline(repo_id, staging) # pipeline else: - for suffix in ['safetensors','bin']: - if f'pytorch_lora_weights.{suffix}' in files: - location = self._download_hf_model(repo_id, ['pytorch_lora_weights.bin'], staging) # LoRA + for suffix in ["safetensors", "bin"]: + if f"pytorch_lora_weights.{suffix}" in files: + location = self._download_hf_model(repo_id, ["pytorch_lora_weights.bin"], staging) # LoRA break - elif self.config.precision=='float16' and f'diffusion_pytorch_model.fp16.{suffix}' in files: # vae, controlnet or some other standalone - files = ['config.json', f'diffusion_pytorch_model.fp16.{suffix}'] + elif ( + self.config.precision == "float16" and f"diffusion_pytorch_model.fp16.{suffix}" in files + ): # vae, controlnet or some other standalone + files = ["config.json", f"diffusion_pytorch_model.fp16.{suffix}"] location = self._download_hf_model(repo_id, files, staging) break - elif f'diffusion_pytorch_model.{suffix}' in files: - files = ['config.json', f'diffusion_pytorch_model.{suffix}'] + elif f"diffusion_pytorch_model.{suffix}" in files: + files = ["config.json", f"diffusion_pytorch_model.{suffix}"] location = self._download_hf_model(repo_id, files, staging) break - elif f'learned_embeds.{suffix}' in files: - location = self._download_hf_model(repo_id, [f'learned_embeds.{suffix}'], staging) + elif f"learned_embeds.{suffix}" in files: + location = self._download_hf_model(repo_id, [f"learned_embeds.{suffix}"], staging) break if not location: - logger.warning(f'Could not determine type of repo {repo_id}. Skipping install.') + logger.warning(f"Could not determine type of repo {repo_id}. Skipping install.") return {} info = ModelProbe().heuristic_probe(location, self.prediction_helper) if not info: - logger.warning(f'Could not probe {location}. Skipping install.') + logger.warning(f"Could not probe {location}. Skipping install.") return {} - dest = self.config.models_path / info.base_type.value / info.model_type.value / self._get_model_name(repo_id,location) + dest = ( + self.config.models_path + / info.base_type.value + / info.model_type.value + / self._get_model_name(repo_id, location) + ) if dest.exists(): shutil.rmtree(dest) - shutil.copytree(location,dest) + shutil.copytree(location, dest) return self._install_path(dest, info) - def _get_model_name(self,path_name: str, location: Path)->str: - ''' + def _get_model_name(self, path_name: str, location: Path) -> str: + """ Calculate a name for the model - primitive implementation. - ''' + """ if key := self.reverse_paths.get(path_name): (name, base, mtype) = ModelManager.parse_key(key) return name @@ -324,99 +339,103 @@ class ModelInstall(object): else: return location.stem - def _make_attributes(self, path: Path, info: ModelProbeInfo)->dict: + def _make_attributes(self, path: Path, info: ModelProbeInfo) -> dict: model_name = path.name if path.is_dir() else path.stem - description = f'{info.base_type.value} {info.model_type.value} model {model_name}' + description = f"{info.base_type.value} {info.model_type.value} model {model_name}" if key := self.reverse_paths.get(self.current_id): if key in self.datasets: - description = self.datasets[key].get('description') or description + description = self.datasets[key].get("description") or description rel_path = self.relative_to_root(path) attributes = dict( - path = str(rel_path), - description = str(description), - model_format = info.format, - ) + path=str(rel_path), + description=str(description), + model_format=info.format, + ) legacy_conf = None if info.model_type == ModelType.Main: - attributes.update(dict(variant = info.variant_type,)) - if info.format=="checkpoint": + attributes.update( + dict( + variant=info.variant_type, + ) + ) + if info.format == "checkpoint": try: - possible_conf = path.with_suffix('.yaml') + possible_conf = path.with_suffix(".yaml") if possible_conf.exists(): legacy_conf = str(self.relative_to_root(possible_conf)) elif info.base_type == BaseModelType.StableDiffusion2: - legacy_conf = Path(self.config.legacy_conf_dir, LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type]) + legacy_conf = Path( + self.config.legacy_conf_dir, + LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type], + ) else: - legacy_conf = Path(self.config.legacy_conf_dir, LEGACY_CONFIGS[info.base_type][info.variant_type]) + legacy_conf = Path( + self.config.legacy_conf_dir, LEGACY_CONFIGS[info.base_type][info.variant_type] + ) except KeyError: - legacy_conf = Path(self.config.legacy_conf_dir, 'v1-inference.yaml') # best guess - - if info.model_type == ModelType.ControlNet and info.format=="checkpoint": - possible_conf = path.with_suffix('.yaml') + legacy_conf = Path(self.config.legacy_conf_dir, "v1-inference.yaml") # best guess + + if info.model_type == ModelType.ControlNet and info.format == "checkpoint": + possible_conf = path.with_suffix(".yaml") if possible_conf.exists(): legacy_conf = str(self.relative_to_root(possible_conf)) if legacy_conf: - attributes.update( - dict( - config = str(legacy_conf) - ) - ) + attributes.update(dict(config=str(legacy_conf))) return attributes - def relative_to_root(self, path: Path)->Path: + def relative_to_root(self, path: Path) -> Path: root = self.config.root_path if path.is_relative_to(root): return path.relative_to(root) else: return path - def _download_hf_pipeline(self, repo_id: str, staging: Path)->Path: - ''' + def _download_hf_pipeline(self, repo_id: str, staging: Path) -> Path: + """ This retrieves a StableDiffusion model from cache or remote and then does a save_pretrained() to the indicated staging area. - ''' - _,name = repo_id.split("/") - revisions = ['fp16','main'] if self.config.precision=='float16' else ['main'] + """ + _, name = repo_id.split("/") + revisions = ["fp16", "main"] if self.config.precision == "float16" else ["main"] model = None for revision in revisions: try: - model = DiffusionPipeline.from_pretrained(repo_id,revision=revision,safety_checker=None) + model = DiffusionPipeline.from_pretrained(repo_id, revision=revision, safety_checker=None) except: # most errors are due to fp16 not being present. Fix this to catch other errors pass if model: break if not model: - logger.error(f'Diffusers model {repo_id} could not be downloaded. Skipping.') + logger.error(f"Diffusers model {repo_id} could not be downloaded. Skipping.") return None model.save_pretrained(staging / name, safe_serialization=True) return staging / name - def _download_hf_model(self, repo_id: str, files: List[str], staging: Path)->Path: - _,name = repo_id.split("/") + def _download_hf_model(self, repo_id: str, files: List[str], staging: Path) -> Path: + _, name = repo_id.split("/") location = staging / name paths = list() for filename in files: - p = hf_download_with_resume(repo_id, - model_dir=location, - model_name=filename, - access_token = self.access_token - ) + p = hf_download_with_resume( + repo_id, model_dir=location, model_name=filename, access_token=self.access_token + ) if p: paths.append(p) else: - logger.warning(f'Could not download {filename} from {repo_id}.') - - return location if len(paths)>0 else None + logger.warning(f"Could not download {filename} from {repo_id}.") + + return location if len(paths) > 0 else None @classmethod - def _reverse_paths(cls,datasets)->dict: - ''' + def _reverse_paths(cls, datasets) -> dict: + """ Reverse mapping from repo_id/path to destination name. - ''' - return {v.get('path') or v.get('repo_id') : k for k, v in datasets.items()} + """ + return {v.get("path") or v.get("repo_id"): k for k, v in datasets.items()} + # ------------------------------------- def yes_or_no(prompt: str, default_yes=True): @@ -427,13 +446,12 @@ def yes_or_no(prompt: str, default_yes=True): else: return response[0] in ("y", "Y") + # --------------------------------------------- -def hf_download_from_pretrained( - model_class: object, model_name: str, destination: Path, **kwargs -): - logger = InvokeAILogger.getLogger('InvokeAI') - logger.addFilter(lambda x: 'fp16 is not a valid' not in x.getMessage()) - +def hf_download_from_pretrained(model_class: object, model_name: str, destination: Path, **kwargs): + logger = InvokeAILogger.getLogger("InvokeAI") + logger.addFilter(lambda x: "fp16 is not a valid" not in x.getMessage()) + model = model_class.from_pretrained( model_name, resume_download=True, @@ -442,13 +460,14 @@ def hf_download_from_pretrained( model.save_pretrained(destination, safe_serialization=True) return destination + # --------------------------------------------- def hf_download_with_resume( - repo_id: str, - model_dir: str, - model_name: str, - model_dest: Path = None, - access_token: str = None, + repo_id: str, + model_dir: str, + model_name: str, + model_dest: Path = None, + access_token: str = None, ) -> Path: model_dest = model_dest or Path(os.path.join(model_dir, model_name)) os.makedirs(model_dir, exist_ok=True) @@ -467,9 +486,7 @@ def hf_download_with_resume( resp = requests.get(url, headers=header, stream=True) total = int(resp.headers.get("content-length", 0)) - if ( - resp.status_code == 416 - ): # "range not satisfiable", which means nothing to return + if resp.status_code == 416: # "range not satisfiable", which means nothing to return logger.info(f"{model_name}: complete file found. Skipping.") return model_dest elif resp.status_code == 404: @@ -498,5 +515,3 @@ def hf_download_with_resume( logger.error(f"An error occurred while downloading {model_name}: {str(e)}") return None return model_dest - - diff --git a/invokeai/backend/model_management/__init__.py b/invokeai/backend/model_management/__init__.py index 011540e05e..cf057f3a89 100644 --- a/invokeai/backend/model_management/__init__.py +++ b/invokeai/backend/model_management/__init__.py @@ -4,6 +4,12 @@ Initialization file for invokeai.backend.model_management from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType from .model_cache import ModelCache from .lora import ModelPatcher, ONNXModelPatcher -from .models import BaseModelType, ModelType, SubModelType, ModelVariantType, ModelNotFoundException, DuplicateModelException +from .models import ( + BaseModelType, + ModelType, + SubModelType, + ModelVariantType, + ModelNotFoundException, + DuplicateModelException, +) from .model_merge import ModelMerger, MergeInterpolationMethod - diff --git a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py index 0124da7f56..2c62b8b192 100644 --- a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py +++ b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py @@ -56,9 +56,7 @@ from diffusers.schedulers import ( ) from diffusers.utils import is_accelerate_available, is_omegaconf_available, is_safetensors_available from diffusers.utils.import_utils import BACKENDS_MAPPING -from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import ( - LDMBertConfig, LDMBertModel -) +from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker @@ -85,6 +83,7 @@ if is_accelerate_available(): logger = InvokeAILogger.getLogger(__name__) CONVERT_MODEL_ROOT = InvokeAIAppConfig.get_config().root_path / MODEL_CORE / "convert" + def shave_segments(path, n_shave_prefix_segments=1): """ Removes segments. Positive values shave the first segments, negative shave the last segments. @@ -509,9 +508,7 @@ def convert_ldm_unet_checkpoint( paths = renew_resnet_paths(resnets) meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) if len(attentions): paths = renew_attention_paths(attentions) @@ -796,7 +793,7 @@ def convert_ldm_bert_checkpoint(checkpoint, config): def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder=None): if text_encoder is None: - config = CLIPTextConfig.from_pretrained(CONVERT_MODEL_ROOT / 'clip-vit-large-patch14') + config = CLIPTextConfig.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") ctx = init_empty_weights if is_accelerate_available() else nullcontext with ctx(): @@ -1008,7 +1005,9 @@ def stable_unclip_image_encoder(original_config): elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder": feature_extractor = CLIPImageProcessor() # InvokeAI doesn't use CLIPVisionModelWithProjection so it isn't in the core - if this code is hit a download will occur - image_encoder = CLIPVisionModelWithProjection.from_pretrained(CONVERT_MODEL_ROOT / "CLIP-ViT-H-14-laion2B-s32B-b79K") + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + CONVERT_MODEL_ROOT / "CLIP-ViT-H-14-laion2B-s32B-b79K" + ) else: raise NotImplementedError( f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}" @@ -1071,17 +1070,17 @@ def convert_controlnet_checkpoint( extract_ema, use_linear_projection=None, cross_attention_dim=None, - precision: torch.dtype=torch.float32, + precision: torch.dtype = torch.float32, ): ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True) ctrlnet_config["upcast_attention"] = upcast_attention ctrlnet_config.pop("sample_size") original_config = ctrlnet_config.copy() - - ctrlnet_config.pop('addition_embed_type') - ctrlnet_config.pop('addition_time_embed_dim') - ctrlnet_config.pop('transformer_layers_per_block') + + ctrlnet_config.pop("addition_embed_type") + ctrlnet_config.pop("addition_time_embed_dim") + ctrlnet_config.pop("transformer_layers_per_block") if use_linear_projection is not None: ctrlnet_config["use_linear_projection"] = use_linear_projection @@ -1111,6 +1110,7 @@ def convert_controlnet_checkpoint( return controlnet.to(precision) + # TO DO - PASS PRECISION def download_from_original_stable_diffusion_ckpt( checkpoint_path: str, @@ -1249,8 +1249,8 @@ def download_from_original_stable_diffusion_ckpt( # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21 while "state_dict" in checkpoint: checkpoint = checkpoint["state_dict"] - - logger.debug(f'model_type = {model_type}; original_config_file = {original_config_file}') + + logger.debug(f"model_type = {model_type}; original_config_file = {original_config_file}") if original_config_file is None: key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" @@ -1258,7 +1258,9 @@ def download_from_original_stable_diffusion_ckpt( key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias" # model_type = "v1" - config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" + config_url = ( + "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" + ) if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024: # model_type = "v2" @@ -1277,7 +1279,10 @@ def download_from_original_stable_diffusion_ckpt( original_config_file = BytesIO(requests.get(config_url).content) original_config = OmegaConf.load(original_config_file) - if model_version == BaseModelType.StableDiffusion2 and original_config["model"]["params"]["parameterization"] == "v": + if ( + model_version == BaseModelType.StableDiffusion2 + and original_config["model"]["params"]["parameterization"] == "v" + ): prediction_type = "v_prediction" upcast_attention = True image_size = 768 @@ -1436,7 +1441,7 @@ def download_from_original_stable_diffusion_ckpt( config_kwargs = {"subfolder": "text_encoder"} text_model = convert_open_clip_checkpoint(checkpoint, config_name, **config_kwargs) - tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / 'stable-diffusion-2-clip', subfolder="tokenizer") + tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "stable-diffusion-2-clip", subfolder="tokenizer") if stable_unclip is None: if controlnet: @@ -1491,7 +1496,9 @@ def download_from_original_stable_diffusion_ckpt( prior = PriorTransformer.from_pretrained(karlo_model, subfolder="prior") prior_tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") - prior_text_model = CLIPTextModelWithProjection.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") + prior_text_model = CLIPTextModelWithProjection.from_pretrained( + CONVERT_MODEL_ROOT / "clip-vit-large-patch14" + ) prior_scheduler = UnCLIPScheduler.from_pretrained(karlo_model, subfolder="prior_scheduler") prior_scheduler = DDPMScheduler.from_config(prior_scheduler.config) @@ -1533,11 +1540,19 @@ def download_from_original_stable_diffusion_ckpt( text_model = convert_ldm_clip_checkpoint( checkpoint, local_files_only=local_files_only, text_encoder=text_encoder ) - tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") if tokenizer is None else tokenizer + tokenizer = ( + CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") + if tokenizer is None + else tokenizer + ) if load_safety_checker: - safety_checker = StableDiffusionSafetyChecker.from_pretrained(CONVERT_MODEL_ROOT / "stable-diffusion-safety-checker") - feature_extractor = AutoFeatureExtractor.from_pretrained(CONVERT_MODEL_ROOT / "stable-diffusion-safety-checker") + safety_checker = StableDiffusionSafetyChecker.from_pretrained( + CONVERT_MODEL_ROOT / "stable-diffusion-safety-checker" + ) + feature_extractor = AutoFeatureExtractor.from_pretrained( + CONVERT_MODEL_ROOT / "stable-diffusion-safety-checker" + ) else: safety_checker = None feature_extractor = None @@ -1567,7 +1582,7 @@ def download_from_original_stable_diffusion_ckpt( if model_type == "SDXL": tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "clip-vit-large-patch14") text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only) - + tokenizer_name = CONVERT_MODEL_ROOT / "CLIP-ViT-bigG-14-laion2B-39B-b160k" tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_name, pad_token="!") @@ -1577,7 +1592,7 @@ def download_from_original_stable_diffusion_ckpt( checkpoint, config_name, prefix="conditioner.embedders.1.model.", has_projection=True, **config_kwargs ) - pipe = StableDiffusionXLPipeline ( + pipe = StableDiffusionXLPipeline( vae=vae.to(precision), text_encoder=text_encoder, tokenizer=tokenizer, @@ -1686,24 +1701,22 @@ def download_controlnet_from_original_ckpt( return controlnet -def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size: int) -> AutoencoderKL: - vae_config = create_vae_diffusers_config( - vae_config, image_size=image_size - ) - converted_vae_checkpoint = convert_ldm_vae_checkpoint( - checkpoint, vae_config - ) +def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size: int) -> AutoencoderKL: + vae_config = create_vae_diffusers_config(vae_config, image_size=image_size) + + converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) vae = AutoencoderKL(**vae_config) vae.load_state_dict(converted_vae_checkpoint) return vae + def convert_ckpt_to_diffusers( - checkpoint_path: Union[str, Path], - dump_path: Union[str, Path], - use_safetensors: bool=True, - **kwargs, + checkpoint_path: Union[str, Path], + dump_path: Union[str, Path], + use_safetensors: bool = True, + **kwargs, ): """ Takes all the arguments of download_from_original_stable_diffusion_ckpt(), @@ -1717,10 +1730,11 @@ def convert_ckpt_to_diffusers( safe_serialization=use_safetensors and is_safetensors_available(), ) + def convert_controlnet_to_diffusers( - checkpoint_path: Union[str, Path], - dump_path: Union[str, Path], - **kwargs, + checkpoint_path: Union[str, Path], + dump_path: Union[str, Path], + **kwargs, ): """ Takes all the arguments of download_controlnet_from_original_ckpt(), diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py index 8d64a7b5f7..dc43e3c6e6 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_management/lora.py @@ -22,14 +22,15 @@ from transformers import CLIPTextModel, CLIPTokenizer # TODO: rename and split this file -class LoRALayerBase: - #rank: Optional[int] - #alpha: Optional[float] - #bias: Optional[torch.Tensor] - #layer_key: str - #@property - #def scale(self): +class LoRALayerBase: + # rank: Optional[int] + # alpha: Optional[float] + # bias: Optional[torch.Tensor] + # layer_key: str + + # @property + # def scale(self): # return self.alpha / self.rank if (self.alpha and self.rank) else 1.0 def __init__( @@ -42,11 +43,7 @@ class LoRALayerBase: else: self.alpha = None - if ( - "bias_indices" in values - and "bias_values" in values - and "bias_size" in values - ): + if "bias_indices" in values and "bias_values" in values and "bias_size" in values: self.bias = torch.sparse_coo_tensor( values["bias_indices"], values["bias_values"], @@ -56,13 +53,13 @@ class LoRALayerBase: else: self.bias = None - self.rank = None # set in layer implementation + self.rank = None # set in layer implementation self.layer_key = layer_key def forward( self, module: torch.nn.Module, - input_h: Any, # for real looks like Tuple[torch.nn.Tensor] but not sure + input_h: Any, # for real looks like Tuple[torch.nn.Tensor] but not sure multiplier: float, ): if type(module) == torch.nn.Conv2d: @@ -82,12 +79,16 @@ class LoRALayerBase: bias = self.bias if self.bias is not None else 0 scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0 - return op( - *input_h, - (weight + bias).view(module.weight.shape), - None, - **extra_args, - ) * multiplier * scale + return ( + op( + *input_h, + (weight + bias).view(module.weight.shape), + None, + **extra_args, + ) + * multiplier + * scale + ) def get_weight(self): raise NotImplementedError() @@ -110,9 +111,9 @@ class LoRALayerBase: # TODO: find and debug lora/locon with bias class LoRALayer(LoRALayerBase): - #up: torch.Tensor - #mid: Optional[torch.Tensor] - #down: torch.Tensor + # up: torch.Tensor + # mid: Optional[torch.Tensor] + # down: torch.Tensor def __init__( self, @@ -162,12 +163,12 @@ class LoRALayer(LoRALayerBase): class LoHALayer(LoRALayerBase): - #w1_a: torch.Tensor - #w1_b: torch.Tensor - #w2_a: torch.Tensor - #w2_b: torch.Tensor - #t1: Optional[torch.Tensor] = None - #t2: Optional[torch.Tensor] = None + # w1_a: torch.Tensor + # w1_b: torch.Tensor + # w2_a: torch.Tensor + # w2_b: torch.Tensor + # t1: Optional[torch.Tensor] = None + # t2: Optional[torch.Tensor] = None def __init__( self, @@ -198,12 +199,8 @@ class LoHALayer(LoRALayerBase): weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b) else: - rebuild1 = torch.einsum( - "i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a - ) - rebuild2 = torch.einsum( - "i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a - ) + rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a) + rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a) weight = rebuild1 * rebuild2 return weight @@ -234,20 +231,20 @@ class LoHALayer(LoRALayerBase): class LoKRLayer(LoRALayerBase): - #w1: Optional[torch.Tensor] = None - #w1_a: Optional[torch.Tensor] = None - #w1_b: Optional[torch.Tensor] = None - #w2: Optional[torch.Tensor] = None - #w2_a: Optional[torch.Tensor] = None - #w2_b: Optional[torch.Tensor] = None - #t2: Optional[torch.Tensor] = None + # w1: Optional[torch.Tensor] = None + # w1_a: Optional[torch.Tensor] = None + # w1_b: Optional[torch.Tensor] = None + # w2: Optional[torch.Tensor] = None + # w2_a: Optional[torch.Tensor] = None + # w2_b: Optional[torch.Tensor] = None + # t2: Optional[torch.Tensor] = None def __init__( self, layer_key: str, values: dict, ): - super().__init__(layer_key, values) + super().__init__(layer_key, values) if "lokr_w1" in values: self.w1 = values["lokr_w1"] @@ -277,7 +274,7 @@ class LoKRLayer(LoRALayerBase): elif "lokr_w2_b" in values: self.rank = values["lokr_w2_b"].shape[0] else: - self.rank = None # unscaled + self.rank = None # unscaled def get_weight(self): w1 = self.w1 @@ -289,7 +286,7 @@ class LoKRLayer(LoRALayerBase): if self.t2 is None: w2 = self.w2_a @ self.w2_b else: - w2 = torch.einsum('i j k l, i p, j r -> p r k l', self.t2, self.w2_a, self.w2_b) + w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b) if len(w2.shape) == 4: w1 = w1.unsqueeze(2).unsqueeze(2) @@ -328,7 +325,7 @@ class LoKRLayer(LoRALayerBase): self.t2 = self.t2.to(device=device, dtype=dtype) -class LoRAModel: #(torch.nn.Module): +class LoRAModel: # (torch.nn.Module): _name: str layers: Dict[str, LoRALayer] _device: torch.device @@ -356,7 +353,7 @@ class LoRAModel: #(torch.nn.Module): @property def dtype(self): - return self._dtype + return self._dtype def to( self, @@ -391,7 +388,7 @@ class LoRAModel: #(torch.nn.Module): model = cls( device=device, dtype=dtype, - name=file_path.stem, # TODO: + name=file_path.stem, # TODO: layers=dict(), ) @@ -403,7 +400,6 @@ class LoRAModel: #(torch.nn.Module): state_dict = cls._group_state(state_dict) for layer_key, values in state_dict.items(): - # lora and locon if "lora_down.weight" in values: layer = LoRALayer(layer_key, values) @@ -418,9 +414,7 @@ class LoRAModel: #(torch.nn.Module): else: # TODO: diff/ia3/... format - print( - f">> Encountered unknown lora layer module in {model.name}: {layer_key}" - ) + print(f">> Encountered unknown lora layer module in {model.name}: {layer_key}") return # lower memory consumption by removing already parsed layer values @@ -454,9 +448,10 @@ with LoRAHelper.apply_lora_unet(unet, loras): # unmodified unet """ + + # TODO: rename smth like ModelPatcher and add TI method? class ModelPatcher: - @staticmethod def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]: assert "." not in lora_key @@ -466,10 +461,10 @@ class ModelPatcher: module = model module_key = "" - key_parts = lora_key[len(prefix):].split('_') + key_parts = lora_key[len(prefix) :].split("_") submodule_name = key_parts.pop(0) - + while len(key_parts) > 0: try: module = module.get_submodule(submodule_name) @@ -488,7 +483,6 @@ class ModelPatcher: applied_loras: List[Tuple[LoRAModel, float]], layer_name: str, ): - def lora_forward(module, input_h, output): if len(applied_loras) == 0: return output @@ -502,7 +496,6 @@ class ModelPatcher: return lora_forward - @classmethod @contextmanager def apply_lora_unet( @@ -513,7 +506,6 @@ class ModelPatcher: with cls.apply_lora(unet, loras, "lora_unet_"): yield - @classmethod @contextmanager def apply_lora_text_encoder( @@ -524,7 +516,6 @@ class ModelPatcher: with cls.apply_lora(text_encoder, loras, "lora_te_"): yield - @classmethod @contextmanager def apply_lora( @@ -537,7 +528,7 @@ class ModelPatcher: try: with torch.no_grad(): for lora, lora_weight in loras: - #assert lora.device.type == "cpu" + # assert lora.device.type == "cpu" for layer_key, layer in lora.layers.items(): if not layer_key.startswith(prefix): continue @@ -547,7 +538,7 @@ class ModelPatcher: original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True) # enable autocast to calc fp16 loras on cpu - #with torch.autocast(device_type="cpu"): + # with torch.autocast(device_type="cpu"): layer.to(dtype=torch.float32) layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 layer_weight = layer.get_weight() * lora_weight * layer_scale @@ -558,14 +549,13 @@ class ModelPatcher: module.weight += layer_weight.to(device=module.weight.device, dtype=module.weight.dtype) - yield # wait for context manager exit + yield # wait for context manager exit finally: with torch.no_grad(): for module_key, weight in original_weights.items(): model.get_submodule(module_key).weight.copy_(weight) - @classmethod @contextmanager def apply_ti( @@ -613,7 +603,9 @@ class ModelPatcher: f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {model_embeddings.weight.data[token_id].shape[0]}." ) - model_embeddings.weight.data[token_id] = embedding.to(device=text_encoder.device, dtype=text_encoder.dtype) + model_embeddings.weight.data[token_id] = embedding.to( + device=text_encoder.device, dtype=text_encoder.dtype + ) ti_tokens.append(token_id) if len(ti_tokens) > 1: @@ -625,7 +617,6 @@ class ModelPatcher: if init_tokens_count and new_tokens_added: text_encoder.resize_token_embeddings(init_tokens_count) - @classmethod @contextmanager def apply_clip_skip( @@ -644,9 +635,10 @@ class ModelPatcher: while len(skipped_layers) > 0: text_encoder.text_model.encoder.layers.append(skipped_layers.pop()) + class TextualInversionModel: name: str - embedding: torch.Tensor # [n, 768]|[n, 1280] + embedding: torch.Tensor # [n, 768]|[n, 1280] @classmethod def from_checkpoint( @@ -658,8 +650,8 @@ class TextualInversionModel: if not isinstance(file_path, Path): file_path = Path(file_path) - result = cls() # TODO: - result.name = file_path.stem # TODO: + result = cls() # TODO: + result.name = file_path.stem # TODO: if file_path.suffix == ".safetensors": state_dict = load_file(file_path.absolute().as_posix(), device="cpu") @@ -670,7 +662,9 @@ class TextualInversionModel: # difference mostly in metadata if "string_to_param" in state_dict: if len(state_dict["string_to_param"]) > 1: - print(f"Warn: Embedding \"{file_path.name}\" contains multiple tokens, which is not supported. The first token will be used.") + print( + f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first token will be used.' + ) result.embedding = next(iter(state_dict["string_to_param"].values())) @@ -699,10 +693,7 @@ class TextualInversionManager(BaseTextualInversionManager): self.pad_tokens = dict() self.tokenizer = tokenizer - def expand_textual_inversion_token_ids_if_necessary( - self, token_ids: list[int] - ) -> list[int]: - + def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]: if len(self.pad_tokens) == 0: return token_ids @@ -721,7 +712,6 @@ class TextualInversionManager(BaseTextualInversionManager): class ONNXModelPatcher: - @classmethod @contextmanager def apply_lora_unet( @@ -732,7 +722,6 @@ class ONNXModelPatcher: with cls.apply_lora(unet, loras, "lora_unet_"): yield - @classmethod @contextmanager def apply_lora_text_encoder( @@ -754,6 +743,7 @@ class ONNXModelPatcher: prefix: str, ): from .models.base import IAIOnnxRuntimeModel + if not isinstance(model, IAIOnnxRuntimeModel): raise Exception("Only IAIOnnxRuntimeModel models supported") @@ -830,8 +820,6 @@ class ONNXModelPatcher: for name, orig_weight in orig_weights.items(): model.tensors[name] = orig_weight - - @classmethod @contextmanager def apply_ti( @@ -841,6 +829,7 @@ class ONNXModelPatcher: ti_list: List[Any], ) -> Tuple[CLIPTokenizer, TextualInversionManager]: from .models.base import IAIOnnxRuntimeModel + if not isinstance(text_encoder, IAIOnnxRuntimeModel): raise Exception("Only IAIOnnxRuntimeModel models supported") @@ -866,10 +855,7 @@ class ONNXModelPatcher: orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"] embeddings = np.concatenate( - ( - np.copy(orig_embeddings), - np.zeros((new_tokens_added, orig_embeddings.shape[1])) - ), + (np.copy(orig_embeddings), np.zeros((new_tokens_added, orig_embeddings.shape[1]))), axis=0, ) @@ -894,7 +880,9 @@ class ONNXModelPatcher: if len(ti_tokens) > 1: ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:] - text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = embeddings.astype(orig_embeddings.dtype) + text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = embeddings.astype( + orig_embeddings.dtype + ) yield ti_tokenizer, ti_manager diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management/model_cache.py index 79ddd624fc..ed69c07041 100644 --- a/invokeai/backend/model_management/model_cache.py +++ b/invokeai/backend/model_management/model_cache.py @@ -37,19 +37,22 @@ from .models import BaseModelType, ModelType, SubModelType, ModelBase DEFAULT_MAX_CACHE_SIZE = 6.0 # amount of GPU memory to hold in reserve for use by generations (GB) -DEFAULT_MAX_VRAM_CACHE_SIZE= 2.75 +DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75 # actual size of a gig GIG = 1073741824 + class ModelLocker(object): "Forward declaration" pass + class ModelCache(object): "Forward declaration" pass + class _CacheRecord: size: int model: Any @@ -79,22 +82,22 @@ class _CacheRecord: return self.model.device != self.cache.storage_device else: return False - + class ModelCache(object): def __init__( self, - max_cache_size: float=DEFAULT_MAX_CACHE_SIZE, - max_vram_cache_size: float=DEFAULT_MAX_VRAM_CACHE_SIZE, - execution_device: torch.device=torch.device('cuda'), - storage_device: torch.device=torch.device('cpu'), - precision: torch.dtype=torch.float16, - sequential_offload: bool=False, - lazy_offloading: bool=True, + max_cache_size: float = DEFAULT_MAX_CACHE_SIZE, + max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE, + execution_device: torch.device = torch.device("cuda"), + storage_device: torch.device = torch.device("cpu"), + precision: torch.dtype = torch.float16, + sequential_offload: bool = False, + lazy_offloading: bool = True, sha_chunksize: int = 16777216, - logger: types.ModuleType = logger + logger: types.ModuleType = logger, ): - ''' + """ :param max_cache_size: Maximum size of the RAM cache [6.0 GB] :param execution_device: Torch device to load active model into [torch.device('cuda')] :param storage_device: Torch device to save inactive model in [torch.device('cpu')] @@ -102,16 +105,16 @@ class ModelCache(object): :param lazy_offloading: Keep model in VRAM until another model needs to be loaded :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially :param sha_chunksize: Chunksize to use when calculating sha256 model hash - ''' + """ self.model_infos: Dict[str, ModelBase] = dict() # allow lazy offloading only when vram cache enabled self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0 - self.precision: torch.dtype=precision - self.max_cache_size: float=max_cache_size - self.max_vram_cache_size: float=max_vram_cache_size - self.execution_device: torch.device=execution_device - self.storage_device: torch.device=storage_device - self.sha_chunksize=sha_chunksize + self.precision: torch.dtype = precision + self.max_cache_size: float = max_cache_size + self.max_vram_cache_size: float = max_vram_cache_size + self.execution_device: torch.device = execution_device + self.storage_device: torch.device = storage_device + self.sha_chunksize = sha_chunksize self.logger = logger self._cached_models = dict() @@ -124,7 +127,6 @@ class ModelCache(object): model_type: ModelType, submodel_type: Optional[SubModelType] = None, ): - key = f"{model_path}:{base_model}:{model_type}" if submodel_type: key += f":{submodel_type}" @@ -163,7 +165,6 @@ class ModelCache(object): submodel: Optional[SubModelType] = None, gpu_load: bool = True, ) -> Any: - if not isinstance(model_path, Path): model_path = Path(model_path) @@ -186,7 +187,7 @@ class ModelCache(object): # TODO: lock for no copies on simultaneous calls? cache_entry = self._cached_models.get(key, None) if cache_entry is None: - self.logger.info(f'Loading model {model_path}, type {base_model}:{model_type}:{submodel}') + self.logger.info(f"Loading model {model_path}, type {base_model}:{model_type}:{submodel}") # this will remove older cached models until # there is sufficient room to load the requested model @@ -196,7 +197,7 @@ class ModelCache(object): gc.collect() model = model_info.get_model(child_type=submodel, torch_dtype=self.precision) if mem_used := model_info.get_size(submodel): - self.logger.debug(f'CPU RAM used for load: {(mem_used/GIG):.2f} GB') + self.logger.debug(f"CPU RAM used for load: {(mem_used/GIG):.2f} GB") cache_entry = _CacheRecord(self, model, mem_used) self._cached_models[key] = cache_entry @@ -209,13 +210,13 @@ class ModelCache(object): class ModelLocker(object): def __init__(self, cache, key, model, gpu_load, size_needed): - ''' + """ :param cache: The model_cache object :param key: The key of the model to lock in GPU :param model: The model to lock :param gpu_load: True if load into gpu :param size_needed: Size of the model to load - ''' + """ self.gpu_load = gpu_load self.cache = cache self.key = key @@ -224,7 +225,7 @@ class ModelCache(object): self.cache_entry = self.cache._cached_models[self.key] def __enter__(self) -> Any: - if not hasattr(self.model, 'to'): + if not hasattr(self.model, "to"): return self.model # NOTE that the model has to have the to() method in order for this @@ -234,22 +235,21 @@ class ModelCache(object): try: if self.cache.lazy_offloading: - self.cache._offload_unlocked_models(self.size_needed) - + self.cache._offload_unlocked_models(self.size_needed) + if self.model.device != self.cache.execution_device: - self.cache.logger.debug(f'Moving {self.key} into {self.cache.execution_device}') + self.cache.logger.debug(f"Moving {self.key} into {self.cache.execution_device}") with VRAMUsage() as mem: self.model.to(self.cache.execution_device) # move into GPU - self.cache.logger.debug(f'GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB') - - self.cache.logger.debug(f'Locking {self.key} in {self.cache.execution_device}') + self.cache.logger.debug(f"GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB") + + self.cache.logger.debug(f"Locking {self.key} in {self.cache.execution_device}") self.cache._print_cuda_stats() except: self.cache_entry.unlock() raise - # TODO: not fully understand # in the event that the caller wants the model in RAM, we # move it into CPU if it is in GPU and not locked @@ -259,7 +259,7 @@ class ModelCache(object): return self.model def __exit__(self, type, value, traceback): - if not hasattr(self.model, 'to'): + if not hasattr(self.model, "to"): return self.cache_entry.unlock() @@ -277,11 +277,11 @@ class ModelCache(object): self, model_path: Union[str, Path], ) -> str: - ''' + """ Given the HF repo id or path to a model on disk, returns a unique hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs :param model_path: Path to model file/directory on disk. - ''' + """ return self._local_model_hash(model_path) def cache_size(self) -> float: @@ -290,7 +290,7 @@ class ModelCache(object): return current_cache_size / GIG def _has_cuda(self) -> bool: - return self.execution_device.type == 'cuda' + return self.execution_device.type == "cuda" def _print_cuda_stats(self): vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG) @@ -306,18 +306,21 @@ class ModelCache(object): if model_info.locked: locked_models += 1 - self.logger.debug(f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ = {cached_models}/{loaded_models}/{locked_models}") - + self.logger.debug( + f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ = {cached_models}/{loaded_models}/{locked_models}" + ) def _make_cache_room(self, model_size): # calculate how much memory this model will require - #multiplier = 2 if self.precision==torch.float32 else 1 + # multiplier = 2 if self.precision==torch.float32 else 1 bytes_needed = model_size maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes current_size = sum([m.size for m in self._cached_models.values()]) if current_size + bytes_needed > maximum_size: - self.logger.debug(f'Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional {(bytes_needed/GIG):.2f} GB') + self.logger.debug( + f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional {(bytes_needed/GIG):.2f} GB" + ) self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}") @@ -339,7 +342,7 @@ class ModelCache(object): with suppress(RuntimeError): referrer.clear() cleared = True - #break + # break # repeat if referrers changes(due to frame clear), else exit loop if cleared: @@ -348,13 +351,17 @@ class ModelCache(object): break device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None - self.logger.debug(f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}, refs: {refs}") + self.logger.debug( + f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}, refs: {refs}" + ) # 2 refs: # 1 from cache_entry # 1 from getrefcount function - if not cache_entry.locked and refs <= 3 if 'onnx' in model_key else 2: - self.logger.debug(f'Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)') + if not cache_entry.locked and refs <= 3 if "onnx" in model_key else 2: + self.logger.debug( + f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)" + ) current_size -= cache_entry.size del self._cache_stack[pos] del self._cached_models[model_key] @@ -368,38 +375,36 @@ class ModelCache(object): self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}") - def _offload_unlocked_models(self, size_needed: int=0): + def _offload_unlocked_models(self, size_needed: int = 0): reserved = self.max_vram_cache_size * GIG vram_in_use = torch.cuda.memory_allocated() - self.logger.debug(f'{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB') - for model_key, cache_entry in sorted(self._cached_models.items(), key=lambda x:x[1].size): + self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB") + for model_key, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size): if vram_in_use <= reserved: break if not cache_entry.locked and cache_entry.loaded: - self.logger.debug(f'Offloading {model_key} from {self.execution_device} into {self.storage_device}') + self.logger.debug(f"Offloading {model_key} from {self.execution_device} into {self.storage_device}") with VRAMUsage() as mem: cache_entry.model.to(self.storage_device) - self.logger.debug(f'GPU VRAM freed: {(mem.vram_used/GIG):.2f} GB') + self.logger.debug(f"GPU VRAM freed: {(mem.vram_used/GIG):.2f} GB") vram_in_use += mem.vram_used # note vram_used is negative - self.logger.debug(f'{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB') + self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB") gc.collect() torch.cuda.empty_cache() - + def _local_model_hash(self, model_path: Union[str, Path]) -> str: sha = hashlib.sha256() path = Path(model_path) - + hashpath = path / "checksum.sha256" if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime: with open(hashpath) as f: hash = f.read() return hash - - self.logger.debug(f'computing hash of model {path.name}') - for file in list(path.rglob("*.ckpt")) \ - + list(path.rglob("*.safetensors")) \ - + list(path.rglob("*.pth")): + + self.logger.debug(f"computing hash of model {path.name}") + for file in list(path.rglob("*.ckpt")) + list(path.rglob("*.safetensors")) + list(path.rglob("*.pth")): with open(file, "rb") as f: while chunk := f.read(self.sha_chunksize): sha.update(chunk) @@ -408,11 +413,12 @@ class ModelCache(object): f.write(hash) return hash + class VRAMUsage(object): def __init__(self): self.vram = None self.vram_used = 0 - + def __enter__(self): self.vram = torch.cuda.memory_allocated() return self diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 184c95e5c1..4d41a7d0ac 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -249,20 +249,26 @@ from invokeai.backend.util import CUDA_DEVICE, Chdir from .model_cache import ModelCache, ModelLocker from .model_search import ModelSearch from .models import ( - BaseModelType, ModelType, SubModelType, - ModelError, SchedulerPredictionType, MODEL_CLASSES, + BaseModelType, + ModelType, + SubModelType, + ModelError, + SchedulerPredictionType, + MODEL_CLASSES, ModelConfigBase, - ModelNotFoundException, InvalidModelException, + ModelNotFoundException, + InvalidModelException, DuplicateModelException, ) # We are only starting to number the config file with release 3. # The config file version doesn't have to start at release version, but it will help # reduce confusion. -CONFIG_FILE_VERSION='3.0.0' +CONFIG_FILE_VERSION = "3.0.0" + @dataclass -class ModelInfo(): +class ModelInfo: context: ModelLocker name: str base_model: BaseModelType @@ -275,20 +281,24 @@ class ModelInfo(): def __enter__(self): return self.context.__enter__() - def __exit__(self,*args, **kwargs): + def __exit__(self, *args, **kwargs): self.context.__exit__(*args, **kwargs) + class AddModelResult(BaseModel): name: str = Field(description="The name of the model after installation") model_type: ModelType = Field(description="The type of model") base_model: BaseModelType = Field(description="The base model") config: ModelConfigBase = Field(description="The configuration of the model") + MAX_CACHE_SIZE = 6.0 # GB + class ConfigMeta(BaseModel): version: str + class ModelManager(object): """ High-level interface to model management. @@ -315,12 +325,12 @@ class ModelManager(object): if isinstance(config, (str, Path)): self.config_path = Path(config) if not self.config_path.exists(): - logger.warning(f'The file {self.config_path} was not found. Initializing a new file') + logger.warning(f"The file {self.config_path} was not found. Initializing a new file") self.initialize_model_config(self.config_path) config = OmegaConf.load(self.config_path) elif not isinstance(config, DictConfig): - raise ValueError('config argument must be an OmegaConf object, a Path or a string') + raise ValueError("config argument must be an OmegaConf object, a Path or a string") self.config_meta = ConfigMeta(**config.pop("__metadata__")) # TODO: metadata not found @@ -330,11 +340,11 @@ class ModelManager(object): self.logger = logger self.cache = ModelCache( max_cache_size=max_cache_size, - max_vram_cache_size = self.app_config.max_vram_cache_size, - execution_device = device_type, - precision = precision, - sequential_offload = sequential_offload, - logger = logger, + max_vram_cache_size=self.app_config.max_vram_cache_size, + execution_device=device_type, + precision=precision, + sequential_offload=sequential_offload, + logger=logger, ) self._read_models(config) @@ -348,7 +358,7 @@ class ModelManager(object): self.models = dict() for model_key, model_config in config.items(): - if model_key.startswith('_'): + if model_key.startswith("_"): continue model_name, base_model, model_type = self.parse_key(model_key) model_class = MODEL_CLASSES[base_model][model_type] @@ -395,7 +405,7 @@ class ModelManager(object): @classmethod def parse_key(cls, model_key: str) -> Tuple[str, BaseModelType, ModelType]: - base_model_str, model_type_str, model_name = model_key.split('/', 2) + base_model_str, model_type_str, model_name = model_key.split("/", 2) try: model_type = ModelType(model_type_str) except: @@ -414,20 +424,16 @@ class ModelManager(object): @classmethod def initialize_model_config(cls, config_path: Path): """Create empty config file""" - with open(config_path,'w') as yaml_file: - yaml_file.write(yaml.dump({'__metadata__': - {'version':'3.0.0'} - } - ) - ) + with open(config_path, "w") as yaml_file: + yaml_file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}})) def get_model( self, model_name: str, base_model: BaseModelType, model_type: ModelType, - submodel_type: Optional[SubModelType] = None - )->ModelInfo: + submodel_type: Optional[SubModelType] = None, + ) -> ModelInfo: """Given a model named identified in models.yaml, return an ModelInfo object describing it. :param model_name: symbolic name of the model in models.yaml @@ -451,7 +457,7 @@ class ModelManager(object): if not model_path.exists(): if model_class.save_to_config: self.models[model_key].error = ModelError.NotFound - raise Exception(f"Files for model \"{model_key}\" not found") + raise Exception(f'Files for model "{model_key}" not found') else: self.models.pop(model_key, None) @@ -473,7 +479,7 @@ class ModelManager(object): model_path = model_class.convert_if_required( base_model=base_model, - model_path=str(model_path), # TODO: refactor str/Path types logic + model_path=str(model_path), # TODO: refactor str/Path types logic output_path=dst_convert_path, config=model_config, ) @@ -490,17 +496,17 @@ class ModelManager(object): self.cache_keys[model_key] = set() self.cache_keys[model_key].add(model_context.key) - model_hash = "" # TODO: + model_hash = "" # TODO: return ModelInfo( - context = model_context, - name = model_name, - base_model = base_model, - type = submodel_type or model_type, - hash = model_hash, - location = model_path, # TODO: - precision = self.cache.precision, - _cache = self.cache, + context=model_context, + name=model_name, + base_model=base_model, + type=submodel_type or model_type, + hash=model_hash, + location=model_path, # TODO: + precision=self.cache.precision, + _cache=self.cache, ) def model_info( @@ -516,7 +522,7 @@ class ModelManager(object): if model_key in self.models: return self.models[model_key].dict(exclude_defaults=True) else: - return None # TODO: None or empty dict on not found + return None # TODO: None or empty dict on not found def model_names(self) -> List[Tuple[str, BaseModelType, ModelType]]: """ @@ -526,16 +532,16 @@ class ModelManager(object): return [(self.parse_key(x)) for x in self.models.keys()] def list_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, + self, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, ) -> dict: """ Returns a dict describing one installed model, using the combined format of the list_models() method. """ - models = self.list_models(base_model,model_type,model_name) + models = self.list_models(base_model, model_type, model_name) return models[0] if models else None def list_models( @@ -548,13 +554,17 @@ class ModelManager(object): Return a list of models. """ - model_keys = [self.create_key(model_name, base_model, model_type)] if model_name else sorted(self.models, key=str.casefold) + model_keys = ( + [self.create_key(model_name, base_model, model_type)] + if model_name + else sorted(self.models, key=str.casefold) + ) models = [] for model_key in model_keys: model_config = self.models.get(model_key) if not model_config: - self.logger.error(f'Unknown model {model_name}') - raise ModelNotFoundException(f'Unknown model {model_name}') + self.logger.error(f"Unknown model {model_name}") + raise ModelNotFoundException(f"Unknown model {model_name}") cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key) if base_model is not None and cur_base_model != base_model: @@ -571,8 +581,8 @@ class ModelManager(object): ) # expose paths as absolute to help web UI - if path := model_dict.get('path'): - model_dict['path'] = str(self.app_config.root_path / path) + if path := model_dict.get("path"): + model_dict["path"] = str(self.app_config.root_path / path) models.append(model_dict) return models @@ -641,15 +651,15 @@ class ModelManager(object): model_info(). """ # relativize paths as they go in - this makes it easier to move the root directory around - if path := model_attributes.get('path'): + if path := model_attributes.get("path"): if Path(path).is_relative_to(self.app_config.root_path): - model_attributes['path'] = str(Path(path).relative_to(self.app_config.root_path)) + model_attributes["path"] = str(Path(path).relative_to(self.app_config.root_path)) model_class = MODEL_CLASSES[base_model][model_type] model_config = model_class.create_config(**model_attributes) model_key = self.create_key(model_name, base_model, model_type) - if model_key in self.models and not clobber: + if model_key in self.models and not clobber: raise Exception(f'Attempt to overwrite existing model definition "{model_key}"') old_model = self.models.pop(model_key, None) @@ -675,23 +685,23 @@ class ModelManager(object): self.commit() return AddModelResult( - name = model_name, - model_type = model_type, - base_model = base_model, - config = model_config, + name=model_name, + model_type=model_type, + base_model=base_model, + config=model_config, ) def rename_model( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - new_name: str = None, - new_base: BaseModelType = None, + self, + model_name: str, + base_model: BaseModelType, + model_type: ModelType, + new_name: str = None, + new_base: BaseModelType = None, ): - ''' + """ Rename or rebase a model. - ''' + """ if new_name is None and new_base is None: self.logger.error("rename_model() called with neither a new_name nor a new_base. {model_name} unchanged.") return @@ -710,7 +720,13 @@ class ModelManager(object): # if this is a model file/directory that we manage ourselves, we need to move it if old_path.is_relative_to(self.app_config.models_path): - new_path = self.app_config.root_path / 'models' / BaseModelType(new_base).value / ModelType(model_type).value / new_name + new_path = ( + self.app_config.root_path + / "models" + / BaseModelType(new_base).value + / ModelType(model_type).value + / new_name + ) move(old_path, new_path) model_cfg.path = str(new_path.relative_to(self.app_config.root_path)) @@ -726,18 +742,18 @@ class ModelManager(object): for cache_id in cache_ids: self.cache.uncache_model(cache_id) - self.models.pop(model_key, None) # delete + self.models.pop(model_key, None) # delete self.models[new_key] = model_cfg self.commit() - def convert_model ( - self, - model_name: str, - base_model: BaseModelType, - model_type: Union[ModelType.Main,ModelType.Vae], - dest_directory: Optional[Path]=None, + def convert_model( + self, + model_name: str, + base_model: BaseModelType, + model_type: Union[ModelType.Main, ModelType.Vae], + dest_directory: Optional[Path] = None, ) -> AddModelResult: - ''' + """ Convert a checkpoint file into a diffusers folder, deleting the cached version and deleting the original checkpoint file if it is in the models directory. @@ -746,7 +762,7 @@ class ModelManager(object): :param model_type: Type of model ['vae' or 'main'] This will raise a ValueError unless the model is a checkpoint. - ''' + """ info = self.model_info(model_name, base_model, model_type) if info["model_format"] != "checkpoint": raise ValueError(f"not a checkpoint format model: {model_name}") @@ -754,27 +770,32 @@ class ModelManager(object): # We are taking advantage of a side effect of get_model() that converts check points # into cached diffusers directories stored at `location`. It doesn't matter # what submodeltype we request here, so we get the smallest. - submodel = {"submodel_type": SubModelType.Scheduler} if model_type==ModelType.Main else {} - model = self.get_model(model_name, - base_model, - model_type, - **submodel, - ) + submodel = {"submodel_type": SubModelType.Scheduler} if model_type == ModelType.Main else {} + model = self.get_model( + model_name, + base_model, + model_type, + **submodel, + ) checkpoint_path = self.app_config.root_path / info["path"] old_diffusers_path = self.app_config.models_path / model.location - new_diffusers_path = (dest_directory or self.app_config.models_path / base_model.value / model_type.value) / model_name + new_diffusers_path = ( + dest_directory or self.app_config.models_path / base_model.value / model_type.value + ) / model_name if new_diffusers_path.exists(): raise ValueError(f"A diffusers model already exists at {new_diffusers_path}") try: - move(old_diffusers_path,new_diffusers_path) + move(old_diffusers_path, new_diffusers_path) info["model_format"] = "diffusers" - info["path"] = str(new_diffusers_path) if dest_directory else str(new_diffusers_path.relative_to(self.app_config.root_path)) - info.pop('config') + info["path"] = ( + str(new_diffusers_path) + if dest_directory + else str(new_diffusers_path.relative_to(self.app_config.root_path)) + ) + info.pop("config") - result = self.add_model(model_name, base_model, model_type, - model_attributes = info, - clobber=True) + result = self.add_model(model_name, base_model, model_type, model_attributes=info, clobber=True) except: # something went wrong, so don't leave dangling diffusers model in directory or it will cause a duplicate model error! rmtree(new_diffusers_path) @@ -798,15 +819,12 @@ class ModelManager(object): found_models = [] for file in files: location = str(file.resolve()).replace("\\", "/") - if ( - "model.safetensors" not in location - and "diffusion_pytorch_model.safetensors" not in location - ): + if "model.safetensors" not in location and "diffusion_pytorch_model.safetensors" not in location: found_models.append({"name": file.stem, "location": location}) return search_folder, found_models - def commit(self, conf_file: Path=None) -> None: + def commit(self, conf_file: Path = None) -> None: """ Write current configuration out to the indicated file. """ @@ -824,7 +842,7 @@ class ModelManager(object): yaml_str = OmegaConf.to_yaml(data_to_save) config_file_path = conf_file or self.config_path - assert config_file_path is not None,'no config file path to write to' + assert config_file_path is not None, "no config file path to write to" config_file_path = self.app_config.root_path / config_file_path tmpfile = os.path.join(os.path.dirname(config_file_path), "new_config.tmp") try: @@ -857,11 +875,10 @@ class ModelManager(object): base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None, ): - loaded_files = set() new_models_found = False - self.logger.info(f'Scanning {self.app_config.models_path} for new models') + self.logger.info(f"Scanning {self.app_config.models_path} for new models") with Chdir(self.app_config.root_path): for model_key, model_config in list(self.models.items()): model_name, cur_base_model, cur_model_type = self.parse_key(model_key) @@ -887,10 +904,10 @@ class ModelManager(object): models_dir = self.app_config.models_path / cur_base_model.value / cur_model_type.value if not models_dir.exists(): - continue # TODO: or create all folders? + continue # TODO: or create all folders? for model_path in models_dir.iterdir(): - if model_path not in loaded_files: # TODO: check + if model_path not in loaded_files: # TODO: check model_name = model_path.name if model_path.is_dir() else model_path.stem model_key = self.create_key(model_name, cur_base_model, cur_model_type) @@ -900,7 +917,7 @@ class ModelManager(object): if model_path.is_relative_to(self.app_config.root_path): model_path = model_path.relative_to(self.app_config.root_path) - + model_config: ModelConfigBase = model_class.probe_config(str(model_path)) self.models[model_key] = model_config new_models_found = True @@ -916,11 +933,10 @@ class ModelManager(object): if (new_models_found or imported_models) and self.config_path: self.commit() - - def autoimport(self)->Dict[str, AddModelResult]: - ''' + def autoimport(self) -> Dict[str, AddModelResult]: + """ Scan the autoimport directory (if defined) and import new models, delete defunct models. - ''' + """ # avoid circular import from invokeai.backend.install.model_install_backend import ModelInstall from invokeai.frontend.install.model_install import ask_user_for_prediction_type @@ -939,7 +955,9 @@ class ModelManager(object): self.new_models_found.update(self.installer.heuristic_import(model)) def on_search_completed(self): - self.logger.info(f'Scanned {self._items_scanned} files and directories, imported {len(self.new_models_found)} models') + self.logger.info( + f"Scanned {self._items_scanned} files and directories, imported {len(self.new_models_found)} models" + ) def models_found(self): return self.new_models_found @@ -949,31 +967,37 @@ class ModelManager(object): # LS: hacky # Patch in the SD VAE from core so that it is available for use by the UI try: - self.heuristic_import({config.root_path / 'models/core/convert/sd-vae-ft-mse'}) + self.heuristic_import({config.root_path / "models/core/convert/sd-vae-ft-mse"}) except: pass - installer = ModelInstall(config = self.app_config, - model_manager = self, - prediction_type_helper = ask_user_for_prediction_type, - ) - known_paths = {config.root_path / x['path'] for x in self.list_models()} - directories = {config.root_path / x for x in [config.autoimport_dir, - config.lora_dir, - config.embedding_dir, - config.controlnet_dir, - ] if x - } + installer = ModelInstall( + config=self.app_config, + model_manager=self, + prediction_type_helper=ask_user_for_prediction_type, + ) + known_paths = {config.root_path / x["path"] for x in self.list_models()} + directories = { + config.root_path / x + for x in [ + config.autoimport_dir, + config.lora_dir, + config.embedding_dir, + config.controlnet_dir, + ] + if x + } scanner = ScanAndImport(directories, self.logger, ignore=known_paths, installer=installer) scanner.search() - + return scanner.models_found() - def heuristic_import(self, - items_to_import: Set[str], - prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None, - )->Dict[str, AddModelResult]: - '''Import a list of paths, repo_ids or URLs. Returns the set of + def heuristic_import( + self, + items_to_import: Set[str], + prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None, + ) -> Dict[str, AddModelResult]: + """Import a list of paths, repo_ids or URLs. Returns the set of successfully imported items. :param items_to_import: Set of strings corresponding to models to be imported. :param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType. @@ -992,14 +1016,15 @@ class ModelManager(object): May return the following exceptions: - ModelNotFoundException - one or more of the items to import is not a valid path, repo_id or URL - ValueError - a corresponding model already exists - ''' + """ # avoid circular import here from invokeai.backend.install.model_install_backend import ModelInstall + successfully_installed = dict() - installer = ModelInstall(config = self.app_config, - prediction_type_helper = prediction_type_helper, - model_manager = self) + installer = ModelInstall( + config=self.app_config, prediction_type_helper=prediction_type_helper, model_manager=self + ) for thing in items_to_import: installed = installer.heuristic_import(thing) successfully_installed.update(installed) diff --git a/invokeai/backend/model_management/model_merge.py b/invokeai/backend/model_management/model_merge.py index 6427b9e430..8cf3ce4ad0 100644 --- a/invokeai/backend/model_management/model_merge.py +++ b/invokeai/backend/model_management/model_merge.py @@ -17,23 +17,25 @@ import invokeai.backend.util.logging as logger from ...backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult + class MergeInterpolationMethod(str, Enum): WeightedSum = "weighted_sum" Sigmoid = "sigmoid" InvSigmoid = "inv_sigmoid" AddDifference = "add_difference" + class ModelMerger(object): def __init__(self, manager: ModelManager): self.manager = manager def merge_diffusion_models( - self, - model_paths: List[Path], - alpha: float = 0.5, - interp: MergeInterpolationMethod = None, - force: bool = False, - **kwargs, + self, + model_paths: List[Path], + alpha: float = 0.5, + interp: MergeInterpolationMethod = None, + force: bool = False, + **kwargs, ) -> DiffusionPipeline: """ :param model_paths: up to three models, designated by their local paths or HuggingFace repo_ids @@ -58,24 +60,23 @@ class ModelMerger(object): merged_pipe = pipe.merge( pretrained_model_name_or_path_list=model_paths, alpha=alpha, - interp=interp.value if interp else None, #diffusers API treats None as "weighted sum" + interp=interp.value if interp else None, # diffusers API treats None as "weighted sum" force=force, **kwargs, ) dlogging.set_verbosity(verbosity) return merged_pipe - - def merge_diffusion_models_and_save ( - self, - model_names: List[str], - base_model: Union[BaseModelType,str], - merged_model_name: str, - alpha: float = 0.5, - interp: MergeInterpolationMethod = None, - force: bool = False, - merge_dest_directory: Optional[Path] = None, - **kwargs, + def merge_diffusion_models_and_save( + self, + model_names: List[str], + base_model: Union[BaseModelType, str], + merged_model_name: str, + alpha: float = 0.5, + interp: MergeInterpolationMethod = None, + force: bool = False, + merge_dest_directory: Optional[Path] = None, + **kwargs, ) -> AddModelResult: """ :param models: up to three models, designated by their InvokeAI models.yaml model name @@ -94,39 +95,45 @@ class ModelMerger(object): config = self.manager.app_config base_model = BaseModelType(base_model) vae = None - + for mod in model_names: info = self.manager.list_model(mod, base_model=base_model, model_type=ModelType.Main) - assert info, f"model {mod}, base_model {base_model}, is unknown" - assert info["model_format"] == "diffusers", f"{mod} is not a diffusers model. It must be optimized before merging" - assert info["variant"] == "normal", f"{mod} is a {info['variant']} model, which cannot currently be merged" - assert len(model_names) <= 2 or \ - interp==MergeInterpolationMethod.AddDifference, "When merging three models, only the 'add_difference' merge method is supported" + assert info, f"model {mod}, base_model {base_model}, is unknown" + assert ( + info["model_format"] == "diffusers" + ), f"{mod} is not a diffusers model. It must be optimized before merging" + assert info["variant"] == "normal", f"{mod} is a {info['variant']} model, which cannot currently be merged" + assert ( + len(model_names) <= 2 or interp == MergeInterpolationMethod.AddDifference + ), "When merging three models, only the 'add_difference' merge method is supported" # pick up the first model's vae if mod == model_names[0]: vae = info.get("vae") model_paths.extend([config.root_path / info["path"]]) - merge_method = None if interp == 'weighted_sum' else MergeInterpolationMethod(interp) - logger.debug(f'interp = {interp}, merge_method={merge_method}') - merged_pipe = self.merge_diffusion_models( - model_paths, alpha, merge_method, force, **kwargs + merge_method = None if interp == "weighted_sum" else MergeInterpolationMethod(interp) + logger.debug(f"interp = {interp}, merge_method={merge_method}") + merged_pipe = self.merge_diffusion_models(model_paths, alpha, merge_method, force, **kwargs) + dump_path = ( + Path(merge_dest_directory) + if merge_dest_directory + else config.models_path / base_model.value / ModelType.Main.value ) - dump_path = Path(merge_dest_directory) if merge_dest_directory else config.models_path / base_model.value / ModelType.Main.value dump_path.mkdir(parents=True, exist_ok=True) dump_path = dump_path / merged_model_name merged_pipe.save_pretrained(dump_path, safe_serialization=1) attributes = dict( - path = str(dump_path), - description = f"Merge of models {', '.join(model_names)}", - model_format = "diffusers", - variant = ModelVariantType.Normal.value, - vae = vae, + path=str(dump_path), + description=f"Merge of models {', '.join(model_names)}", + model_format="diffusers", + variant=ModelVariantType.Normal.value, + vae=vae, + ) + return self.manager.add_model( + merged_model_name, + base_model=base_model, + model_type=ModelType.Main, + model_attributes=attributes, + clobber=True, ) - return self.manager.add_model(merged_model_name, - base_model = base_model, - model_type = ModelType.Main, - model_attributes = attributes, - clobber = True - ) diff --git a/invokeai/backend/model_management/model_probe.py b/invokeai/backend/model_management/model_probe.py index d2e5379024..ee14d8ba93 100644 --- a/invokeai/backend/model_management/model_probe.py +++ b/invokeai/backend/model_management/model_probe.py @@ -10,12 +10,16 @@ from typing import Callable, Literal, Union, Dict, Optional from picklescan.scanner import scan_file_path from .models import ( - BaseModelType, ModelType, ModelVariantType, - SchedulerPredictionType, SilenceWarnings, - InvalidModelException + BaseModelType, + ModelType, + ModelVariantType, + SchedulerPredictionType, + SilenceWarnings, + InvalidModelException, ) from .models.base import read_checkpoint_meta + @dataclass class ModelProbeInfo(object): model_type: ModelType @@ -23,70 +27,74 @@ class ModelProbeInfo(object): variant_type: ModelVariantType prediction_type: SchedulerPredictionType upcast_attention: bool - format: Literal['diffusers','checkpoint', 'lycoris', 'olive', 'onnx'] + format: Literal["diffusers", "checkpoint", "lycoris", "olive", "onnx"] image_size: int + class ProbeBase(object): - '''forward declaration''' + """forward declaration""" + pass + class ModelProbe(object): - PROBES = { - 'diffusers': { }, - 'checkpoint': { }, + "diffusers": {}, + "checkpoint": {}, } CLASS2TYPE = { - 'StableDiffusionPipeline' : ModelType.Main, - 'StableDiffusionInpaintPipeline' : ModelType.Main, - 'StableDiffusionXLPipeline' : ModelType.Main, - 'StableDiffusionXLImg2ImgPipeline' : ModelType.Main, - 'AutoencoderKL' : ModelType.Vae, - 'ControlNetModel' : ModelType.ControlNet, + "StableDiffusionPipeline": ModelType.Main, + "StableDiffusionInpaintPipeline": ModelType.Main, + "StableDiffusionXLPipeline": ModelType.Main, + "StableDiffusionXLImg2ImgPipeline": ModelType.Main, + "AutoencoderKL": ModelType.Vae, + "ControlNetModel": ModelType.ControlNet, } - + @classmethod - def register_probe(cls, - format: Literal['diffusers','checkpoint'], - model_type: ModelType, - probe_class: ProbeBase): + def register_probe(cls, format: Literal["diffusers", "checkpoint"], model_type: ModelType, probe_class: ProbeBase): cls.PROBES[format][model_type] = probe_class @classmethod - def heuristic_probe(cls, - model: Union[Dict, ModelMixin, Path], - prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None, - )->ModelProbeInfo: - if isinstance(model,Path): - return cls.probe(model_path=model,prediction_type_helper=prediction_type_helper) - elif isinstance(model,(dict,ModelMixin,ConfigMixin)): + def heuristic_probe( + cls, + model: Union[Dict, ModelMixin, Path], + prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None, + ) -> ModelProbeInfo: + if isinstance(model, Path): + return cls.probe(model_path=model, prediction_type_helper=prediction_type_helper) + elif isinstance(model, (dict, ModelMixin, ConfigMixin)): return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper) else: raise InvalidModelException("model parameter {model} is neither a Path, nor a model") @classmethod - def probe(cls, - model_path: Path, - model: Optional[Union[Dict, ModelMixin]] = None, - prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]] = None)->ModelProbeInfo: - ''' + def probe( + cls, + model_path: Path, + model: Optional[Union[Dict, ModelMixin]] = None, + prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None, + ) -> ModelProbeInfo: + """ Probe the model at model_path and return sufficient information about it to place it somewhere in the models directory hierarchy. If the model is already loaded into memory, you may provide it as model in order to avoid opening it a second time. The prediction_type_helper callable is a function that receives the path to the model and returns the BaseModelType. It is called to distinguish between V2-Base and V2-768 SD models. - ''' + """ if model_path: - format_type = 'diffusers' if model_path.is_dir() else 'checkpoint' + format_type = "diffusers" if model_path.is_dir() else "checkpoint" else: - format_type = 'diffusers' if isinstance(model,(ConfigMixin,ModelMixin)) else 'checkpoint' + format_type = "diffusers" if isinstance(model, (ConfigMixin, ModelMixin)) else "checkpoint" model_info = None try: - model_type = cls.get_model_type_from_folder(model_path, model) \ - if format_type == 'diffusers' \ - else cls.get_model_type_from_checkpoint(model_path, model) + model_type = ( + cls.get_model_type_from_folder(model_path, model) + if format_type == "diffusers" + else cls.get_model_type_from_checkpoint(model_path, model) + ) probe_class = cls.PROBES[format_type].get(model_type) if not probe_class: return None @@ -96,17 +104,23 @@ class ModelProbe(object): prediction_type = probe.get_scheduler_prediction_type() format = probe.get_format() model_info = ModelProbeInfo( - model_type = model_type, - base_type = base_type, - variant_type = variant_type, - prediction_type = prediction_type, - upcast_attention = (base_type==BaseModelType.StableDiffusion2 \ - and prediction_type==SchedulerPredictionType.VPrediction), - format = format, - image_size = 1024 if (base_type in {BaseModelType.StableDiffusionXL,BaseModelType.StableDiffusionXLRefiner}) else \ - 768 if (base_type==BaseModelType.StableDiffusion2 \ - and prediction_type==SchedulerPredictionType.VPrediction ) else \ - 512 + model_type=model_type, + base_type=base_type, + variant_type=variant_type, + prediction_type=prediction_type, + upcast_attention=( + base_type == BaseModelType.StableDiffusion2 + and prediction_type == SchedulerPredictionType.VPrediction + ), + format=format, + image_size=1024 + if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner}) + else 768 + if ( + base_type == BaseModelType.StableDiffusion2 + and prediction_type == SchedulerPredictionType.VPrediction + ) + else 512, ) except Exception: raise @@ -115,7 +129,7 @@ class ModelProbe(object): @classmethod def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict) -> ModelType: - if model_path.suffix not in ('.bin','.pt','.ckpt','.safetensors','.pth'): + if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"): return None if model_path.name == "learned_embeds.bin": @@ -142,32 +156,32 @@ class ModelProbe(object): # diffusers-ti if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()): return ModelType.TextualInversion - + raise InvalidModelException(f"Unable to determine model type for {model_path}") @classmethod - def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType: - ''' + def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin) -> ModelType: + """ Get the model type of a hugging-face style folder. - ''' + """ class_name = None if model: class_name = model.__class__.__name__ else: - if (folder_path / 'learned_embeds.bin').exists(): + if (folder_path / "learned_embeds.bin").exists(): return ModelType.TextualInversion - if (folder_path / 'pytorch_lora_weights.bin').exists(): + if (folder_path / "pytorch_lora_weights.bin").exists(): return ModelType.Lora - i = folder_path / 'model_index.json' - c = folder_path / 'config.json' + i = folder_path / "model_index.json" + c = folder_path / "config.json" config_path = i if i.exists() else c if c.exists() else None if config_path: - with open(config_path,'r') as file: + with open(config_path, "r") as file: conf = json.load(file) - class_name = conf['_class_name'] + class_name = conf["_class_name"] if class_name and (type := cls.CLASS2TYPE.get(class_name)): return type @@ -176,7 +190,7 @@ class ModelProbe(object): raise InvalidModelException(f"Unable to determine model type for {folder_path}") @classmethod - def _scan_and_load_checkpoint(cls,model_path: Path)->dict: + def _scan_and_load_checkpoint(cls, model_path: Path) -> dict: with SilenceWarnings(): if model_path.suffix.endswith((".ckpt", ".pt", ".bin")): cls._scan_model(model_path, model_path) @@ -186,55 +200,53 @@ class ModelProbe(object): @classmethod def _scan_model(cls, model_name, checkpoint): - """ - Apply picklescanner to the indicated checkpoint and issue a warning - and option to exit if an infected file is identified. - """ - # scan model - scan_result = scan_file_path(checkpoint) - if scan_result.infected_files != 0: - raise "The model {model_name} is potentially infected by malware. Aborting import." + """ + Apply picklescanner to the indicated checkpoint and issue a warning + and option to exit if an infected file is identified. + """ + # scan model + scan_result = scan_file_path(checkpoint) + if scan_result.infected_files != 0: + raise "The model {model_name} is potentially infected by malware. Aborting import." + ###################################################3 # Checkpoint probing ###################################################3 class ProbeBase(object): - def get_base_type(self)->BaseModelType: + def get_base_type(self) -> BaseModelType: pass - def get_variant_type(self)->ModelVariantType: - pass - - def get_scheduler_prediction_type(self)->SchedulerPredictionType: + def get_variant_type(self) -> ModelVariantType: pass - def get_format(self)->str: + def get_scheduler_prediction_type(self) -> SchedulerPredictionType: pass + def get_format(self) -> str: + pass + + class CheckpointProbeBase(ProbeBase): - def __init__(self, - checkpoint_path: Path, - checkpoint: dict, - helper: Callable[[Path],SchedulerPredictionType] = None - )->BaseModelType: + def __init__( + self, checkpoint_path: Path, checkpoint: dict, helper: Callable[[Path], SchedulerPredictionType] = None + ) -> BaseModelType: self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path) self.checkpoint_path = checkpoint_path self.helper = helper - def get_base_type(self)->BaseModelType: + def get_base_type(self) -> BaseModelType: pass - def get_format(self)->str: - return 'checkpoint' + def get_format(self) -> str: + return "checkpoint" - def get_variant_type(self)-> ModelVariantType: - model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path,self.checkpoint) + def get_variant_type(self) -> ModelVariantType: + model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path, self.checkpoint) if model_type != ModelType.Main: return ModelVariantType.Normal - state_dict = self.checkpoint.get('state_dict') or self.checkpoint - in_channels = state_dict[ - "model.diffusion_model.input_blocks.0.0.weight" - ].shape[1] + state_dict = self.checkpoint.get("state_dict") or self.checkpoint + in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1] if in_channels == 9: return ModelVariantType.Inpaint elif in_channels == 5: @@ -242,18 +254,21 @@ class CheckpointProbeBase(ProbeBase): elif in_channels == 4: return ModelVariantType.Normal else: - raise InvalidModelException(f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}") + raise InvalidModelException( + f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}" + ) + class PipelineCheckpointProbe(CheckpointProbeBase): - def get_base_type(self)->BaseModelType: + def get_base_type(self) -> BaseModelType: checkpoint = self.checkpoint - state_dict = self.checkpoint.get('state_dict') or checkpoint + state_dict = self.checkpoint.get("state_dict") or checkpoint key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" if key_name in state_dict and state_dict[key_name].shape[-1] == 768: return BaseModelType.StableDiffusion1 if key_name in state_dict and state_dict[key_name].shape[-1] == 1024: return BaseModelType.StableDiffusion2 - key_name = 'model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight' + key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight" if key_name in state_dict and state_dict[key_name].shape[-1] == 2048: return BaseModelType.StableDiffusionXL elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280: @@ -261,35 +276,38 @@ class PipelineCheckpointProbe(CheckpointProbeBase): else: raise InvalidModelException("Cannot determine base type") - def get_scheduler_prediction_type(self)->SchedulerPredictionType: + def get_scheduler_prediction_type(self) -> SchedulerPredictionType: type = self.get_base_type() if type == BaseModelType.StableDiffusion1: return SchedulerPredictionType.Epsilon checkpoint = self.checkpoint - state_dict = self.checkpoint.get('state_dict') or checkpoint + state_dict = self.checkpoint.get("state_dict") or checkpoint key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" if key_name in state_dict and state_dict[key_name].shape[-1] == 1024: - if 'global_step' in checkpoint: - if checkpoint['global_step'] == 220000: + if "global_step" in checkpoint: + if checkpoint["global_step"] == 220000: return SchedulerPredictionType.Epsilon elif checkpoint["global_step"] == 110000: return SchedulerPredictionType.VPrediction - if self.checkpoint_path and self.helper \ - and not self.checkpoint_path.with_suffix('.yaml').exists(): # if a .yaml config file exists, then this step not needed + if ( + self.checkpoint_path and self.helper and not self.checkpoint_path.with_suffix(".yaml").exists() + ): # if a .yaml config file exists, then this step not needed return self.helper(self.checkpoint_path) else: return None + class VaeCheckpointProbe(CheckpointProbeBase): - def get_base_type(self)->BaseModelType: + def get_base_type(self) -> BaseModelType: # I can't find any standalone 2.X VAEs to test with! return BaseModelType.StableDiffusion1 -class LoRACheckpointProbe(CheckpointProbeBase): - def get_format(self)->str: - return 'lycoris' - def get_base_type(self)->BaseModelType: +class LoRACheckpointProbe(CheckpointProbeBase): + def get_format(self) -> str: + return "lycoris" + + def get_base_type(self) -> BaseModelType: checkpoint = self.checkpoint key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight" key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a" @@ -307,16 +325,17 @@ class LoRACheckpointProbe(CheckpointProbeBase): else: return None + class TextualInversionCheckpointProbe(CheckpointProbeBase): - def get_format(self)->str: + def get_format(self) -> str: return None - def get_base_type(self)->BaseModelType: + def get_base_type(self) -> BaseModelType: checkpoint = self.checkpoint - if 'string_to_token' in checkpoint: - token_dim = list(checkpoint['string_to_param'].values())[0].shape[-1] - elif 'emb_params' in checkpoint: - token_dim = checkpoint['emb_params'].shape[-1] + if "string_to_token" in checkpoint: + token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1] + elif "emb_params" in checkpoint: + token_dim = checkpoint["emb_params"].shape[-1] else: token_dim = list(checkpoint.values())[0].shape[0] if token_dim == 768: @@ -326,12 +345,14 @@ class TextualInversionCheckpointProbe(CheckpointProbeBase): else: return None + class ControlNetCheckpointProbe(CheckpointProbeBase): - def get_base_type(self)->BaseModelType: + def get_base_type(self) -> BaseModelType: checkpoint = self.checkpoint - for key_name in ('control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight', - 'input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight' - ): + for key_name in ( + "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight", + "input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight", + ): if key_name not in checkpoint: continue if checkpoint[key_name].shape[-1] == 768: @@ -342,56 +363,54 @@ class ControlNetCheckpointProbe(CheckpointProbeBase): return self.helper(self.checkpoint_path) raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}") + ######################################################## # classes for probing folders ####################################################### class FolderProbeBase(ProbeBase): - def __init__(self, - folder_path: Path, - model: ModelMixin = None, - helper: Callable=None # not used - ): + def __init__(self, folder_path: Path, model: ModelMixin = None, helper: Callable = None): # not used self.model = model self.folder_path = folder_path - def get_variant_type(self)->ModelVariantType: + def get_variant_type(self) -> ModelVariantType: return ModelVariantType.Normal - def get_format(self)->str: - return 'diffusers' - + def get_format(self) -> str: + return "diffusers" + + class PipelineFolderProbe(FolderProbeBase): - def get_base_type(self)->BaseModelType: + def get_base_type(self) -> BaseModelType: if self.model: unet_conf = self.model.unet.config else: - with open(self.folder_path / 'unet' / 'config.json','r') as file: + with open(self.folder_path / "unet" / "config.json", "r") as file: unet_conf = json.load(file) - if unet_conf['cross_attention_dim'] == 768: - return BaseModelType.StableDiffusion1 - elif unet_conf['cross_attention_dim'] == 1024: + if unet_conf["cross_attention_dim"] == 768: + return BaseModelType.StableDiffusion1 + elif unet_conf["cross_attention_dim"] == 1024: return BaseModelType.StableDiffusion2 - elif unet_conf['cross_attention_dim'] == 1280: + elif unet_conf["cross_attention_dim"] == 1280: return BaseModelType.StableDiffusionXLRefiner - elif unet_conf['cross_attention_dim'] == 2048: + elif unet_conf["cross_attention_dim"] == 2048: return BaseModelType.StableDiffusionXL else: - raise InvalidModelException(f'Unknown base model for {self.folder_path}') + raise InvalidModelException(f"Unknown base model for {self.folder_path}") - def get_scheduler_prediction_type(self)->SchedulerPredictionType: + def get_scheduler_prediction_type(self) -> SchedulerPredictionType: if self.model: scheduler_conf = self.model.scheduler.config else: - with open(self.folder_path / 'scheduler' / 'scheduler_config.json','r') as file: + with open(self.folder_path / "scheduler" / "scheduler_config.json", "r") as file: scheduler_conf = json.load(file) - if scheduler_conf['prediction_type'] == "v_prediction": + if scheduler_conf["prediction_type"] == "v_prediction": return SchedulerPredictionType.VPrediction - elif scheduler_conf['prediction_type'] == 'epsilon': + elif scheduler_conf["prediction_type"] == "epsilon": return SchedulerPredictionType.Epsilon else: return None - - def get_variant_type(self)->ModelVariantType: + + def get_variant_type(self) -> ModelVariantType: # This only works for pipelines! Any kind of # exception results in our returning the # "normal" variant type @@ -399,11 +418,11 @@ class PipelineFolderProbe(FolderProbeBase): if self.model: conf = self.model.unet.config else: - config_file = self.folder_path / 'unet' / 'config.json' - with open(config_file,'r') as file: + config_file = self.folder_path / "unet" / "config.json" + with open(config_file, "r") as file: conf = json.load(file) - - in_channels = conf['in_channels'] + + in_channels = conf["in_channels"] if in_channels == 9: return ModelVariantType.Inpaint elif in_channels == 5: @@ -414,60 +433,67 @@ class PipelineFolderProbe(FolderProbeBase): pass return ModelVariantType.Normal + class VaeFolderProbe(FolderProbeBase): - def get_base_type(self)->BaseModelType: - config_file = self.folder_path / 'config.json' + def get_base_type(self) -> BaseModelType: + config_file = self.folder_path / "config.json" if not config_file.exists(): raise InvalidModelException(f"Cannot determine base type for {self.folder_path}") - with open(config_file,'r') as file: + with open(config_file, "r") as file: config = json.load(file) - return BaseModelType.StableDiffusionXL \ - if config.get('scaling_factor',0)==0.13025 and config.get('sample_size') in [512, 1024] \ + return ( + BaseModelType.StableDiffusionXL + if config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024] else BaseModelType.StableDiffusion1 + ) + class TextualInversionFolderProbe(FolderProbeBase): - def get_format(self)->str: + def get_format(self) -> str: return None - - def get_base_type(self)->BaseModelType: - path = self.folder_path / 'learned_embeds.bin' + + def get_base_type(self) -> BaseModelType: + path = self.folder_path / "learned_embeds.bin" if not path.exists(): return None checkpoint = ModelProbe._scan_and_load_checkpoint(path) - return TextualInversionCheckpointProbe(None,checkpoint=checkpoint).get_base_type() + return TextualInversionCheckpointProbe(None, checkpoint=checkpoint).get_base_type() + class ControlNetFolderProbe(FolderProbeBase): - def get_base_type(self)->BaseModelType: - config_file = self.folder_path / 'config.json' + def get_base_type(self) -> BaseModelType: + config_file = self.folder_path / "config.json" if not config_file.exists(): raise InvalidModelException(f"Cannot determine base type for {self.folder_path}") - with open(config_file,'r') as file: + with open(config_file, "r") as file: config = json.load(file) # no obvious way to distinguish between sd2-base and sd2-768 - return BaseModelType.StableDiffusion1 \ - if config['cross_attention_dim']==768 \ - else BaseModelType.StableDiffusion2 + return ( + BaseModelType.StableDiffusion1 if config["cross_attention_dim"] == 768 else BaseModelType.StableDiffusion2 + ) + class LoRAFolderProbe(FolderProbeBase): - def get_base_type(self)->BaseModelType: + def get_base_type(self) -> BaseModelType: model_file = None - for suffix in ['safetensors','bin']: - base_file = self.folder_path / f'pytorch_lora_weights.{suffix}' + for suffix in ["safetensors", "bin"]: + base_file = self.folder_path / f"pytorch_lora_weights.{suffix}" if base_file.exists(): model_file = base_file break if not model_file: - raise InvalidModelException('Unknown LoRA format encountered') - return LoRACheckpointProbe(model_file,None).get_base_type() + raise InvalidModelException("Unknown LoRA format encountered") + return LoRACheckpointProbe(model_file, None).get_base_type() + ############## register probe classes ###### -ModelProbe.register_probe('diffusers', ModelType.Main, PipelineFolderProbe) -ModelProbe.register_probe('diffusers', ModelType.Vae, VaeFolderProbe) -ModelProbe.register_probe('diffusers', ModelType.Lora, LoRAFolderProbe) -ModelProbe.register_probe('diffusers', ModelType.TextualInversion, TextualInversionFolderProbe) -ModelProbe.register_probe('diffusers', ModelType.ControlNet, ControlNetFolderProbe) -ModelProbe.register_probe('checkpoint', ModelType.Main, PipelineCheckpointProbe) -ModelProbe.register_probe('checkpoint', ModelType.Vae, VaeCheckpointProbe) -ModelProbe.register_probe('checkpoint', ModelType.Lora, LoRACheckpointProbe) -ModelProbe.register_probe('checkpoint', ModelType.TextualInversion, TextualInversionCheckpointProbe) -ModelProbe.register_probe('checkpoint', ModelType.ControlNet, ControlNetCheckpointProbe) +ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe) +ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe) +ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe) +ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe) +ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe) +ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe) +ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe) +ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe) +ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe) +ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe) diff --git a/invokeai/backend/model_management/model_search.py b/invokeai/backend/model_management/model_search.py index 5657bd9549..9c87d6c408 100644 --- a/invokeai/backend/model_management/model_search.py +++ b/invokeai/backend/model_management/model_search.py @@ -10,8 +10,9 @@ from pathlib import Path import invokeai.backend.util.logging as logger + class ModelSearch(ABC): - def __init__(self, directories: List[Path], logger: types.ModuleType=logger): + def __init__(self, directories: List[Path], logger: types.ModuleType = logger): """ Initialize a recursive model directory search. :param directories: List of directory Paths to recurse through @@ -56,18 +57,23 @@ class ModelSearch(ABC): def walk_directory(self, path: Path): for root, dirs, files in os.walk(path): - if str(Path(root).name).startswith('.'): + if str(Path(root).name).startswith("."): self._pruned_paths.add(root) if any([Path(root).is_relative_to(x) for x in self._pruned_paths]): continue - + self._items_scanned += len(dirs) + len(files) for d in dirs: path = Path(root) / d if path in self._scanned_paths or path.parent in self._scanned_dirs: self._scanned_dirs.add(path) continue - if any([(path/x).exists() for x in {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}]): + if any( + [ + (path / x).exists() + for x in {"config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"} + ] + ): try: self.on_model_found(path) self._models_found += 1 @@ -79,18 +85,19 @@ class ModelSearch(ABC): path = Path(root) / f if path.parent in self._scanned_dirs: continue - if path.suffix in {'.ckpt','.bin','.pth','.safetensors','.pt'}: + if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}: try: self.on_model_found(path) self._models_found += 1 except Exception as e: self.logger.warning(str(e)) + class FindModels(ModelSearch): def on_search_started(self): self.models_found: Set[Path] = set() - def on_model_found(self,model: Path): + def on_model_found(self, model: Path): self.models_found.add(model) def on_search_completed(self): @@ -99,5 +106,3 @@ class FindModels(ModelSearch): def list_models(self) -> List[Path]: self.search() return list(self.models_found) - - diff --git a/invokeai/backend/model_management/models/__init__.py b/invokeai/backend/model_management/models/__init__.py index 7cafea652e..931da1b159 100644 --- a/invokeai/backend/model_management/models/__init__.py +++ b/invokeai/backend/model_management/models/__init__.py @@ -3,15 +3,24 @@ from enum import Enum from pydantic import BaseModel from typing import Literal, get_origin from .base import ( - BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, - ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings, - ModelNotFoundException, InvalidModelException, DuplicateModelException - ) + BaseModelType, + ModelType, + SubModelType, + ModelBase, + ModelConfigBase, + ModelVariantType, + SchedulerPredictionType, + ModelError, + SilenceWarnings, + ModelNotFoundException, + InvalidModelException, + DuplicateModelException, +) from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model from .sdxl import StableDiffusionXLModel from .vae import VaeModel from .lora import LoRAModel -from .controlnet import ControlNetModel # TODO: +from .controlnet import ControlNetModel # TODO: from .textual_inversion import TextualInversionModel from .stable_diffusion_onnx import ONNXStableDiffusion1Model, ONNXStableDiffusion2Model @@ -51,18 +60,19 @@ MODEL_CLASSES = { ModelType.TextualInversion: TextualInversionModel, ModelType.ONNX: ONNXStableDiffusion2Model, }, - #BaseModelType.Kandinsky2_1: { + # BaseModelType.Kandinsky2_1: { # ModelType.Main: Kandinsky2_1Model, # ModelType.MoVQ: MoVQModel, # ModelType.Lora: LoRAModel, # ModelType.ControlNet: ControlNetModel, # ModelType.TextualInversion: TextualInversionModel, - #}, + # }, } MODEL_CONFIGS = list() OPENAPI_MODEL_CONFIGS = list() + class OpenAPIModelInfoBase(BaseModel): model_name: str base_model: BaseModelType @@ -78,27 +88,31 @@ for base_model, models in MODEL_CLASSES.items(): # LS: sort to get the checkpoint configs first, which makes # for a better template in the Swagger docs for cfg in sorted(model_configs, key=lambda x: str(x)): - model_name, cfg_name = cfg.__qualname__.split('.')[-2:] + model_name, cfg_name = cfg.__qualname__.split(".")[-2:] openapi_cfg_name = model_name + cfg_name if openapi_cfg_name in vars(): continue - api_wrapper = type(openapi_cfg_name, (cfg, OpenAPIModelInfoBase), dict( - __annotations__ = dict( - model_type=Literal[model_type.value], + api_wrapper = type( + openapi_cfg_name, + (cfg, OpenAPIModelInfoBase), + dict( + __annotations__=dict( + model_type=Literal[model_type.value], + ), ), - )) + ) - #globals()[openapi_cfg_name] = api_wrapper + # globals()[openapi_cfg_name] = api_wrapper vars()[openapi_cfg_name] = api_wrapper OPENAPI_MODEL_CONFIGS.append(api_wrapper) + def get_model_config_enums(): enums = list() for model_config in MODEL_CONFIGS: - - if hasattr(inspect,'get_annotations'): + if hasattr(inspect, "get_annotations"): fields = inspect.get_annotations(model_config) else: fields = model_config.__annotations__ @@ -115,7 +129,9 @@ def get_model_config_enums(): if isinstance(field, type) and issubclass(field, str) and issubclass(field, Enum): enums.append(field) - elif get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__): + elif get_origin(field) is Literal and all( + isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__ + ): enums.append(type(field.__args__[0])) elif field is None: @@ -125,4 +141,3 @@ def get_model_config_enums(): raise Exception(f"Unsupported format definition in {model_configs.__qualname__}") return enums - diff --git a/invokeai/backend/model_management/models/base.py b/invokeai/backend/model_management/models/base.py index d8b851411b..1a1041012d 100644 --- a/invokeai/backend/model_management/models/base.py +++ b/invokeai/backend/model_management/models/base.py @@ -20,32 +20,45 @@ from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callab import onnx from onnx import numpy_helper from onnx.external_data_helper import set_external_data -from onnxruntime import InferenceSession, OrtValue, SessionOptions, ExecutionMode, GraphOptimizationLevel, get_available_providers +from onnxruntime import ( + InferenceSession, + OrtValue, + SessionOptions, + ExecutionMode, + GraphOptimizationLevel, + get_available_providers, +) + class DuplicateModelException(Exception): pass + class InvalidModelException(Exception): pass + class ModelNotFoundException(Exception): pass + class BaseModelType(str, Enum): StableDiffusion1 = "sd-1" StableDiffusion2 = "sd-2" StableDiffusionXL = "sdxl" StableDiffusionXLRefiner = "sdxl-refiner" - #Kandinsky2_1 = "kandinsky-2.1" + # Kandinsky2_1 = "kandinsky-2.1" + class ModelType(str, Enum): ONNX = "onnx" Main = "main" Vae = "vae" Lora = "lora" - ControlNet = "controlnet" # used by model_probe + ControlNet = "controlnet" # used by model_probe TextualInversion = "embedding" + class SubModelType(str, Enum): UNet = "unet" TextEncoder = "text_encoder" @@ -57,23 +70,27 @@ class SubModelType(str, Enum): VaeEncoder = "vae_encoder" Scheduler = "scheduler" SafetyChecker = "safety_checker" - #MoVQ = "movq" + # MoVQ = "movq" + class ModelVariantType(str, Enum): Normal = "normal" Inpaint = "inpaint" Depth = "depth" + class SchedulerPredictionType(str, Enum): Epsilon = "epsilon" VPrediction = "v_prediction" Sample = "sample" - + + class ModelError(str, Enum): NotFound = "not_found" + class ModelConfigBase(BaseModel): - path: str # or Path + path: str # or Path description: Optional[str] = Field(None) model_format: Optional[str] = Field(None) error: Optional[ModelError] = Field(None) @@ -81,13 +98,17 @@ class ModelConfigBase(BaseModel): class Config: use_enum_values = True + class EmptyConfigLoader(ConfigMixin): @classmethod def load_config(cls, *args, **kwargs): cls.config_name = kwargs.pop("config_name") return super().load_config(*args, **kwargs) -T_co = TypeVar('T_co', covariant=True) + +T_co = TypeVar("T_co", covariant=True) + + class classproperty(Generic[T_co]): def __init__(self, fget: Callable[[Any], T_co]) -> None: self.fget = fget @@ -96,12 +117,13 @@ class classproperty(Generic[T_co]): return self.fget(owner) def __set__(self, instance: Optional[Any], value: Any) -> None: - raise AttributeError('cannot set attribute') + raise AttributeError("cannot set attribute") + class ModelBase(metaclass=ABCMeta): - #model_path: str - #base_model: BaseModelType - #model_type: ModelType + # model_path: str + # base_model: BaseModelType + # model_type: ModelType def __init__( self, @@ -120,7 +142,7 @@ class ModelBase(metaclass=ABCMeta): return None elif any(t is None for t in subtypes): raise Exception(f"Unsupported definition: {subtypes}") - + if subtypes[0] in ["diffusers", "transformers"]: res_type = sys.modules[subtypes[0]] subtypes = subtypes[1:] @@ -129,7 +151,6 @@ class ModelBase(metaclass=ABCMeta): res_type = sys.modules["diffusers"] res_type = getattr(res_type, "pipelines") - for subtype in subtypes: res_type = getattr(res_type, subtype) return res_type @@ -138,7 +159,7 @@ class ModelBase(metaclass=ABCMeta): def _get_configs(cls): with suppress(Exception): return cls.__configs - + configs = dict() for name in dir(cls): if name.startswith("__"): @@ -148,7 +169,7 @@ class ModelBase(metaclass=ABCMeta): if not isinstance(value, type) or not issubclass(value, ModelConfigBase): continue - if hasattr(inspect,'get_annotations'): + if hasattr(inspect, "get_annotations"): fields = inspect.get_annotations(value) else: fields = value.__annotations__ @@ -161,7 +182,9 @@ class ModelBase(metaclass=ABCMeta): for model_format in field: configs[model_format.value] = value - elif typing.get_origin(field) is Literal and all(isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__): + elif typing.get_origin(field) is Literal and all( + isinstance(arg, str) and isinstance(arg, Enum) for arg in field.__args__ + ): for model_format in field.__args__: configs[model_format.value] = value @@ -213,8 +236,8 @@ class ModelBase(metaclass=ABCMeta): class DiffusersModel(ModelBase): - #child_types: Dict[str, Type] - #child_sizes: Dict[str, int] + # child_types: Dict[str, Type] + # child_sizes: Dict[str, int] def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): super().__init__(model_path, base_model, model_type) @@ -224,7 +247,7 @@ class DiffusersModel(ModelBase): try: config_data = DiffusionPipeline.load_config(self.model_path) - #config_data = json.loads(os.path.join(self.model_path, "model_index.json")) + # config_data = json.loads(os.path.join(self.model_path, "model_index.json")) except: raise Exception("Invalid diffusers model! (model_index.json not found or invalid)") @@ -238,15 +261,12 @@ class DiffusersModel(ModelBase): self.child_types[child_name] = child_type self.child_sizes[child_name] = calc_model_size_by_fs(self.model_path, subfolder=child_name) - - def get_size(self, child_type: Optional[SubModelType] = None): if child_type is None: return sum(self.child_sizes.values()) else: return self.child_sizes[child_type] - def get_model( self, torch_dtype: Optional[torch.dtype], @@ -256,7 +276,7 @@ class DiffusersModel(ModelBase): if child_type is None: raise Exception("Child model type can't be null on diffusers model") if child_type not in self.child_types: - return None # TODO: or raise + return None # TODO: or raise if torch_dtype == torch.float16: variants = ["fp16", None] @@ -276,8 +296,8 @@ class DiffusersModel(ModelBase): ) break except Exception as e: - #print("====ERR LOAD====") - #print(f"{variant}: {e}") + # print("====ERR LOAD====") + # print(f"{variant}: {e}") pass else: raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model") @@ -286,15 +306,10 @@ class DiffusersModel(ModelBase): self.child_sizes[child_type] = calc_model_size_by_data(model) return model - #def convert_if_required(model_path: str, cache_path: str, config: Optional[dict]) -> str: + # def convert_if_required(model_path: str, cache_path: str, config: Optional[dict]) -> str: - -def calc_model_size_by_fs( - model_path: str, - subfolder: Optional[str] = None, - variant: Optional[str] = None -): +def calc_model_size_by_fs(model_path: str, subfolder: Optional[str] = None, variant: Optional[str] = None): if subfolder is not None: model_path = os.path.join(model_path, subfolder) @@ -336,12 +351,12 @@ def calc_model_size_by_fs( # calculate files size if there is no index file formats = [ - (".safetensors",), # safetensors - (".bin",), # torch - (".onnx", ".pb"), # onnx - (".msgpack",), # flax - (".ckpt",), # tf - (".h5",), # tf2 + (".safetensors",), # safetensors + (".bin",), # torch + (".onnx", ".pb"), # onnx + (".msgpack",), # flax + (".ckpt",), # tf + (".h5",), # tf2 ] for file_format in formats: @@ -354,9 +369,9 @@ def calc_model_size_by_fs( file_stats = os.stat(os.path.join(model_path, model_file)) model_size += file_stats.st_size return model_size - - #raise NotImplementedError(f"Unknown model structure! Files: {all_files}") - return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu + + # raise NotImplementedError(f"Unknown model structure! Files: {all_files}") + return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu def calc_model_size_by_data(model) -> int: @@ -377,18 +392,18 @@ def _calc_pipeline_by_data(pipeline) -> int: if submodel is not None and isinstance(submodel, torch.nn.Module): res += _calc_model_by_data(submodel) return res - + def _calc_model_by_data(model) -> int: - mem_params = sum([param.nelement()*param.element_size() for param in model.parameters()]) - mem_bufs = sum([buf.nelement()*buf.element_size() for buf in model.buffers()]) - mem = mem_params + mem_bufs # in bytes + mem_params = sum([param.nelement() * param.element_size() for param in model.parameters()]) + mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()]) + mem = mem_params + mem_bufs # in bytes return mem def _calc_onnx_model_by_data(model) -> int: - tensor_size = model.tensors.size() * 2 # The session doubles this - mem = tensor_size # in bytes + tensor_size = model.tensors.size() * 2 # The session doubles this + mem = tensor_size # in bytes return mem @@ -396,11 +411,15 @@ def _fast_safetensors_reader(path: str): checkpoint = dict() device = torch.device("meta") with open(path, "rb") as f: - definition_len = int.from_bytes(f.read(8), 'little') + definition_len = int.from_bytes(f.read(8), "little") definition_json = f.read(definition_len) definition = json.loads(definition_json) - if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {"pt", "torch", "pytorch"}: + if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in { + "pt", + "torch", + "pytorch", + }: raise Exception("Supported only pytorch safetensors files") definition.pop("__metadata__", None) @@ -419,6 +438,7 @@ def _fast_safetensors_reader(path: str): return checkpoint + def read_checkpoint_meta(path: Union[str, Path], scan: bool = False): if str(path).endswith(".safetensors"): try: @@ -430,33 +450,37 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False): if scan: scan_result = scan_file_path(path) if scan_result.infected_files != 0: - raise Exception(f"The model file \"{path}\" is potentially infected by malware. Aborting import.") + raise Exception(f'The model file "{path}" is potentially infected by malware. Aborting import.') checkpoint = torch.load(path, map_location=torch.device("meta")) return checkpoint + import warnings from diffusers import logging as diffusers_logging from transformers import logging as transformers_logging + class SilenceWarnings(object): def __init__(self): self.transformers_verbosity = transformers_logging.get_verbosity() self.diffusers_verbosity = diffusers_logging.get_verbosity() - + def __enter__(self): transformers_logging.set_verbosity_error() diffusers_logging.set_verbosity_error() - warnings.simplefilter('ignore') + warnings.simplefilter("ignore") def __exit__(self, type, value, traceback): transformers_logging.set_verbosity(self.transformers_verbosity) diffusers_logging.set_verbosity(self.diffusers_verbosity) - warnings.simplefilter('default') + warnings.simplefilter("default") + ONNX_WEIGHTS_NAME = "model.onnx" + + class IAIOnnxRuntimeModel: class _tensor_access: - def __init__(self, model): self.model = model self.indexes = dict() @@ -483,14 +507,14 @@ class IAIOnnxRuntimeModel: def items(self): raise NotImplementedError("tensor.items") - #return [(obj.name, obj) for obj in self.raw_proto] + # return [(obj.name, obj) for obj in self.raw_proto] def keys(self): return self.indexes.keys() def values(self): raise NotImplementedError("tensor.values") - #return [obj for obj in self.raw_proto] + # return [obj for obj in self.raw_proto] def size(self): bytesSum = 0 @@ -498,7 +522,6 @@ class IAIOnnxRuntimeModel: bytesSum += sys.getsizeof(node.raw_data) return bytesSum - class _access_helper: def __init__(self, raw_proto): self.indexes = dict() @@ -527,7 +550,7 @@ class IAIOnnxRuntimeModel: def values(self): return [obj for obj in self.raw_proto] - + def __init__(self, model_path: str, provider: Optional[str]): self.path = model_path self.session = None @@ -567,12 +590,12 @@ class IAIOnnxRuntimeModel: # TODO: integrate with model manager/cache def create_session(self, height=None, width=None): if self.session is None or self.session_width != width or self.session_height != height: - #onnx.save(self.proto, "tmp.onnx") - #onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False) + # onnx.save(self.proto, "tmp.onnx") + # onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False) # TODO: something to be able to get weight when they already moved outside of model proto - #(trimmed_model, external_data) = buffer_external_data_tensors(self.proto) + # (trimmed_model, external_data) = buffer_external_data_tensors(self.proto) sess = SessionOptions() - #self._external_data.update(**external_data) + # self._external_data.update(**external_data) # sess.add_external_initializers(list(self.data.keys()), list(self.data.values())) # sess.enable_profiling = True @@ -603,13 +626,14 @@ class IAIOnnxRuntimeModel: try: self.session = InferenceSession(self.proto.SerializeToString(), providers=providers, sess_options=sess) except Exception as e: - raise e - #self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options) + raise e + # self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options) # self.io_binding = self.session.io_binding() def release_session(self): self.session = None import gc + gc.collect() return @@ -655,4 +679,3 @@ class IAIOnnxRuntimeModel: # TODO: session options return cls(model_path, provider=provider) - diff --git a/invokeai/backend/model_management/models/controlnet.py b/invokeai/backend/model_management/models/controlnet.py index 41952af5d9..e075843a56 100644 --- a/invokeai/backend/model_management/models/controlnet.py +++ b/invokeai/backend/model_management/models/controlnet.py @@ -18,13 +18,15 @@ from .base import ( ) from invokeai.app.services.config import InvokeAIAppConfig + class ControlNetModelFormat(str, Enum): Checkpoint = "checkpoint" Diffusers = "diffusers" + class ControlNetModel(ModelBase): - #model_class: Type - #model_size: int + # model_class: Type + # model_size: int class DiffusersConfig(ModelConfigBase): model_format: Literal[ControlNetModelFormat.Diffusers] @@ -39,7 +41,7 @@ class ControlNetModel(ModelBase): try: config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json") - #config = json.loads(os.path.join(self.model_path, "config.json")) + # config = json.loads(os.path.join(self.model_path, "config.json")) except: raise Exception("Invalid controlnet model! (config.json not found or invalid)") @@ -67,7 +69,7 @@ class ControlNetModel(ModelBase): raise Exception("There is no child models in controlnet model") model = None - for variant in ['fp16',None]: + for variant in ["fp16", None]: try: model = self.model_class.from_pretrained( self.model_path, @@ -79,7 +81,7 @@ class ControlNetModel(ModelBase): pass if not model: raise ModelNotFoundException() - + # calc more accurate size self.model_size = calc_model_size_by_data(model) return model @@ -105,29 +107,30 @@ class ControlNetModel(ModelBase): @classmethod def convert_if_required( - cls, - model_path: str, - output_path: str, - config: ModelConfigBase, - base_model: BaseModelType, - ) -> str: - if cls.detect_format(model_path) == ControlNetModelFormat.Checkpoint: - return _convert_controlnet_ckpt_and_cache( - model_path = model_path, - model_config = config.config, - output_path = output_path, - base_model = base_model, - ) - else: - return model_path - -@classmethod -def _convert_controlnet_ckpt_and_cache( cls, model_path: str, output_path: str, + config: ModelConfigBase, base_model: BaseModelType, - model_config: ControlNetModel.CheckpointConfig, + ) -> str: + if cls.detect_format(model_path) == ControlNetModelFormat.Checkpoint: + return _convert_controlnet_ckpt_and_cache( + model_path=model_path, + model_config=config.config, + output_path=output_path, + base_model=base_model, + ) + else: + return model_path + + +@classmethod +def _convert_controlnet_ckpt_and_cache( + cls, + model_path: str, + output_path: str, + base_model: BaseModelType, + model_config: ControlNetModel.CheckpointConfig, ) -> str: """ Convert the controlnet from checkpoint format to diffusers format, @@ -144,12 +147,13 @@ def _convert_controlnet_ckpt_and_cache( # to avoid circular import errors from ..convert_ckpt_to_diffusers import convert_controlnet_to_diffusers + convert_controlnet_to_diffusers( weights, output_path, - original_config_file = app_config.root_path / model_config, - image_size = 512, - scan_needed = True, - from_safetensors = weights.suffix == ".safetensors" + original_config_file=app_config.root_path / model_config, + image_size=512, + scan_needed=True, + from_safetensors=weights.suffix == ".safetensors", ) return output_path diff --git a/invokeai/backend/model_management/models/lora.py b/invokeai/backend/model_management/models/lora.py index eb771841ec..7cf87d8fe7 100644 --- a/invokeai/backend/model_management/models/lora.py +++ b/invokeai/backend/model_management/models/lora.py @@ -12,18 +12,21 @@ from .base import ( InvalidModelException, ModelNotFoundException, ) + # TODO: naming from ..lora import LoRAModel as LoRAModelRaw + class LoRAModelFormat(str, Enum): LyCORIS = "lycoris" Diffusers = "diffusers" + class LoRAModel(ModelBase): - #model_size: int + # model_size: int class Config(ModelConfigBase): - model_format: LoRAModelFormat # TODO: + model_format: LoRAModelFormat # TODO: def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): assert model_type == ModelType.Lora diff --git a/invokeai/backend/model_management/models/sdxl.py b/invokeai/backend/model_management/models/sdxl.py index 9a0de4ea79..7fc3efb77c 100644 --- a/invokeai/backend/model_management/models/sdxl.py +++ b/invokeai/backend/model_management/models/sdxl.py @@ -15,12 +15,13 @@ from .base import ( ) from omegaconf import OmegaConf + class StableDiffusionXLModelFormat(str, Enum): Checkpoint = "checkpoint" Diffusers = "diffusers" - -class StableDiffusionXLModel(DiffusersModel): + +class StableDiffusionXLModel(DiffusersModel): # TODO: check that configs overwriten properly class DiffusersConfig(ModelConfigBase): model_format: Literal[StableDiffusionXLModelFormat.Diffusers] @@ -53,7 +54,7 @@ class StableDiffusionXLModel(DiffusersModel): else: checkpoint = read_checkpoint_meta(path) - checkpoint = checkpoint.get('state_dict', checkpoint) + checkpoint = checkpoint.get("state_dict", checkpoint) in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1] elif model_format == StableDiffusionXLModelFormat.Diffusers: @@ -61,7 +62,7 @@ class StableDiffusionXLModel(DiffusersModel): if os.path.exists(unet_config_path): with open(unet_config_path, "r") as f: unet_config = json.loads(f.read()) - in_channels = unet_config['in_channels'] + in_channels = unet_config["in_channels"] else: raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)") @@ -81,11 +82,10 @@ class StableDiffusionXLModel(DiffusersModel): if ckpt_config_path is None: # TO DO: implement picking pass - + return cls.create_config( path=path, model_format=model_format, - config=ckpt_config_path, variant=variant, ) @@ -114,11 +114,12 @@ class StableDiffusionXLModel(DiffusersModel): # source code changes, we simply translate here if isinstance(config, cls.CheckpointConfig): from invokeai.backend.model_management.models.stable_diffusion import _convert_ckpt_and_cache + return _convert_ckpt_and_cache( version=base_model, model_config=config, output_path=output_path, - use_safetensors=False, # corrupts sdxl models for some reason + use_safetensors=False, # corrupts sdxl models for some reason ) else: return model_path diff --git a/invokeai/backend/model_management/models/stable_diffusion.py b/invokeai/backend/model_management/models/stable_diffusion.py index dd0b78ebf7..76b4833f9c 100644 --- a/invokeai/backend/model_management/models/stable_diffusion.py +++ b/invokeai/backend/model_management/models/stable_diffusion.py @@ -26,8 +26,8 @@ class StableDiffusion1ModelFormat(str, Enum): Checkpoint = "checkpoint" Diffusers = "diffusers" -class StableDiffusion1Model(DiffusersModel): +class StableDiffusion1Model(DiffusersModel): class DiffusersConfig(ModelConfigBase): model_format: Literal[StableDiffusion1ModelFormat.Diffusers] vae: Optional[str] = Field(None) @@ -38,7 +38,7 @@ class StableDiffusion1Model(DiffusersModel): vae: Optional[str] = Field(None) config: str variant: ModelVariantType - + def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): assert base_model == BaseModelType.StableDiffusion1 assert model_type == ModelType.Main @@ -59,7 +59,7 @@ class StableDiffusion1Model(DiffusersModel): else: checkpoint = read_checkpoint_meta(path) - checkpoint = checkpoint.get('state_dict', checkpoint) + checkpoint = checkpoint.get("state_dict", checkpoint) in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1] elif model_format == StableDiffusion1ModelFormat.Diffusers: @@ -67,7 +67,7 @@ class StableDiffusion1Model(DiffusersModel): if os.path.exists(unet_config_path): with open(unet_config_path, "r") as f: unet_config = json.loads(f.read()) - in_channels = unet_config['in_channels'] + in_channels = unet_config["in_channels"] else: raise NotImplementedError(f"{path} is not a supported stable diffusion diffusers format") @@ -88,7 +88,6 @@ class StableDiffusion1Model(DiffusersModel): return cls.create_config( path=path, model_format=model_format, - config=ckpt_config_path, variant=variant, ) @@ -125,16 +124,17 @@ class StableDiffusion1Model(DiffusersModel): version=BaseModelType.StableDiffusion1, model_config=config, output_path=output_path, - ) + ) else: return model_path + class StableDiffusion2ModelFormat(str, Enum): Checkpoint = "checkpoint" Diffusers = "diffusers" -class StableDiffusion2Model(DiffusersModel): +class StableDiffusion2Model(DiffusersModel): # TODO: check that configs overwriten properly class DiffusersConfig(ModelConfigBase): model_format: Literal[StableDiffusion2ModelFormat.Diffusers] @@ -167,7 +167,7 @@ class StableDiffusion2Model(DiffusersModel): else: checkpoint = read_checkpoint_meta(path) - checkpoint = checkpoint.get('state_dict', checkpoint) + checkpoint = checkpoint.get("state_dict", checkpoint) in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1] elif model_format == StableDiffusion2ModelFormat.Diffusers: @@ -175,7 +175,7 @@ class StableDiffusion2Model(DiffusersModel): if os.path.exists(unet_config_path): with open(unet_config_path, "r") as f: unet_config = json.loads(f.read()) - in_channels = unet_config['in_channels'] + in_channels = unet_config["in_channels"] else: raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)") @@ -198,7 +198,6 @@ class StableDiffusion2Model(DiffusersModel): return cls.create_config( path=path, model_format=model_format, - config=ckpt_config_path, variant=variant, ) @@ -239,17 +238,19 @@ class StableDiffusion2Model(DiffusersModel): else: return model_path + # TODO: rework # pass precision - currently defaulting to fp16 def _convert_ckpt_and_cache( - version: BaseModelType, - model_config: Union[StableDiffusion1Model.CheckpointConfig, - StableDiffusion2Model.CheckpointConfig, - StableDiffusionXLModel.CheckpointConfig, - ], - output_path: str, - use_save_model: bool=False, - **kwargs, + version: BaseModelType, + model_config: Union[ + StableDiffusion1Model.CheckpointConfig, + StableDiffusion2Model.CheckpointConfig, + StableDiffusionXLModel.CheckpointConfig, + ], + output_path: str, + use_save_model: bool = False, + **kwargs, ) -> str: """ Convert the checkpoint model indicated in mconfig into a @@ -270,13 +271,14 @@ def _convert_ckpt_and_cache( from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers from ...util.devices import choose_torch_device, torch_dtype - model_base_to_model_type = {BaseModelType.StableDiffusion1: 'FrozenCLIPEmbedder', - BaseModelType.StableDiffusion2: 'FrozenOpenCLIPEmbedder', - BaseModelType.StableDiffusionXL: 'SDXL', - BaseModelType.StableDiffusionXLRefiner: 'SDXL-Refiner', - } - logger.info(f'Converting {weights} to diffusers format') - with SilenceWarnings(): + model_base_to_model_type = { + BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder", + BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder", + BaseModelType.StableDiffusionXL: "SDXL", + BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner", + } + logger.info(f"Converting {weights} to diffusers format") + with SilenceWarnings(): convert_ckpt_to_diffusers( weights, output_path, @@ -286,12 +288,13 @@ def _convert_ckpt_and_cache( original_config_file=config_file, extract_ema=True, scan_needed=True, - from_safetensors = weights.suffix == ".safetensors", - precision = torch_dtype(choose_torch_device()), + from_safetensors=weights.suffix == ".safetensors", + precision=torch_dtype(choose_torch_device()), **kwargs, ) return output_path + def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType): ckpt_configs = { BaseModelType.StableDiffusion1: { @@ -299,7 +302,7 @@ def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType): ModelVariantType.Inpaint: "v1-inpainting-inference.yaml", }, BaseModelType.StableDiffusion2: { - ModelVariantType.Normal: "v2-inference-v.yaml", # best guess, as we can't differentiate with base(512) + ModelVariantType.Normal: "v2-inference-v.yaml", # best guess, as we can't differentiate with base(512) ModelVariantType.Inpaint: "v2-inpainting-inference.yaml", ModelVariantType.Depth: "v2-midas-inference.yaml", }, @@ -321,8 +324,6 @@ def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType): if config_path.is_relative_to(app_config.root_path): config_path = config_path.relative_to(app_config.root_path) return str(config_path) - + except: return None - - diff --git a/invokeai/backend/model_management/models/stable_diffusion_onnx.py b/invokeai/backend/model_management/models/stable_diffusion_onnx.py index 0c769c01fc..03693e2c3e 100644 --- a/invokeai/backend/model_management/models/stable_diffusion_onnx.py +++ b/invokeai/backend/model_management/models/stable_diffusion_onnx.py @@ -26,13 +26,12 @@ class StableDiffusionOnnxModelFormat(str, Enum): Olive = "olive" Onnx = "onnx" -class ONNXStableDiffusion1Model(DiffusersModel): +class ONNXStableDiffusion1Model(DiffusersModel): class Config(ModelConfigBase): model_format: Literal[StableDiffusionOnnxModelFormat.Onnx] variant: ModelVariantType - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): assert base_model == BaseModelType.StableDiffusion1 assert model_type == ModelType.ONNX @@ -51,7 +50,7 @@ class ONNXStableDiffusion1Model(DiffusersModel): @classmethod def probe_config(cls, path: str, **kwargs): model_format = cls.detect_format(path) - in_channels = 4 # TODO: + in_channels = 4 # TODO: if in_channels == 9: variant = ModelVariantType.Inpaint @@ -60,11 +59,9 @@ class ONNXStableDiffusion1Model(DiffusersModel): else: raise Exception("Unkown stable diffusion 1.* model format") - return cls.create_config( path=path, model_format=model_format, - variant=variant, ) @@ -87,8 +84,8 @@ class ONNXStableDiffusion1Model(DiffusersModel): ) -> str: return model_path -class ONNXStableDiffusion2Model(DiffusersModel): +class ONNXStableDiffusion2Model(DiffusersModel): # TODO: check that configs overwriten properly class Config(ModelConfigBase): model_format: Literal[StableDiffusionOnnxModelFormat.Onnx] @@ -96,7 +93,6 @@ class ONNXStableDiffusion2Model(DiffusersModel): prediction_type: SchedulerPredictionType upcast_attention: bool - def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): assert base_model == BaseModelType.StableDiffusion2 assert model_type == ModelType.ONNX @@ -114,7 +110,7 @@ class ONNXStableDiffusion2Model(DiffusersModel): @classmethod def probe_config(cls, path: str, **kwargs): model_format = cls.detect_format(path) - in_channels = 4 # TODO: + in_channels = 4 # TODO: if in_channels == 9: variant = ModelVariantType.Inpaint @@ -136,7 +132,6 @@ class ONNXStableDiffusion2Model(DiffusersModel): return cls.create_config( path=path, model_format=model_format, - variant=variant, prediction_type=prediction_type, upcast_attention=upcast_attention, @@ -160,4 +155,3 @@ class ONNXStableDiffusion2Model(DiffusersModel): base_model: BaseModelType, ) -> str: return model_path - diff --git a/invokeai/backend/model_management/models/textual_inversion.py b/invokeai/backend/model_management/models/textual_inversion.py index eea0e85245..a949a15be1 100644 --- a/invokeai/backend/model_management/models/textual_inversion.py +++ b/invokeai/backend/model_management/models/textual_inversion.py @@ -11,11 +11,13 @@ from .base import ( ModelNotFoundException, InvalidModelException, ) + # TODO: naming from ..lora import TextualInversionModel as TextualInversionModelRaw + class TextualInversionModel(ModelBase): - #model_size: int + # model_size: int class Config(ModelConfigBase): model_format: None @@ -65,7 +67,7 @@ class TextualInversionModel(ModelBase): if os.path.isdir(path): if os.path.exists(os.path.join(path, "learned_embeds.bin")): - return None # diffusers-ti + return None # diffusers-ti if os.path.isfile(path): if any([path.endswith(f".{ext}") for ext in ["safetensors", "ckpt", "pt", "bin"]]): diff --git a/invokeai/backend/model_management/models/vae.py b/invokeai/backend/model_management/models/vae.py index f740615509..b15844bcf8 100644 --- a/invokeai/backend/model_management/models/vae.py +++ b/invokeai/backend/model_management/models/vae.py @@ -22,13 +22,15 @@ from invokeai.app.services.config import InvokeAIAppConfig from diffusers.utils import is_safetensors_available from omegaconf import OmegaConf + class VaeModelFormat(str, Enum): Checkpoint = "checkpoint" Diffusers = "diffusers" + class VaeModel(ModelBase): - #vae_class: Type - #model_size: int + # vae_class: Type + # model_size: int class Config(ModelConfigBase): model_format: VaeModelFormat @@ -39,7 +41,7 @@ class VaeModel(ModelBase): try: config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json") - #config = json.loads(os.path.join(self.model_path, "config.json")) + # config = json.loads(os.path.join(self.model_path, "config.json")) except: raise Exception("Invalid vae model! (config.json not found or invalid)") @@ -95,7 +97,7 @@ class VaeModel(ModelBase): cls, model_path: str, output_path: str, - config: ModelConfigBase, # empty config or config of parent model + config: ModelConfigBase, # empty config or config of parent model base_model: BaseModelType, ) -> str: if cls.detect_format(model_path) == VaeModelFormat.Checkpoint: @@ -108,6 +110,7 @@ class VaeModel(ModelBase): else: return model_path + # TODO: rework def _convert_vae_ckpt_and_cache( weights_path: str, @@ -138,13 +141,14 @@ def _convert_vae_ckpt_and_cache( 2.1 - 768 """ image_size = 512 - + # return cached version if it exists if output_path.exists(): return output_path if base_model in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}: from .stable_diffusion import _select_ckpt_config + # all sd models use same vae settings config_file = _select_ckpt_config(base_model, ModelVariantType.Normal) else: @@ -152,7 +156,8 @@ def _convert_vae_ckpt_and_cache( # this avoids circular import error from ..convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers - if weights_path.suffix == '.safetensors': + + if weights_path.suffix == ".safetensors": checkpoint = safetensors.torch.load_file(weights_path, device="cpu") else: checkpoint = torch.load(weights_path, map_location="cpu") @@ -161,15 +166,12 @@ def _convert_vae_ckpt_and_cache( if "state_dict" in checkpoint: checkpoint = checkpoint["state_dict"] - config = OmegaConf.load(app_config.root_path/config_file) + config = OmegaConf.load(app_config.root_path / config_file) vae_model = convert_ldm_vae_to_diffusers( - checkpoint = checkpoint, - vae_config = config, - image_size = image_size, - ) - vae_model.save_pretrained( - output_path, - safe_serialization=is_safetensors_available() + checkpoint=checkpoint, + vae_config=config, + image_size=image_size, ) + vae_model.save_pretrained(output_path, safe_serialization=is_safetensors_available()) return output_path diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index 8acfb100a6..624d47ff64 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -47,6 +47,7 @@ from .diffusion import ( ) from .offloading import FullyLoadedModelGroup, ModelGroup + @dataclass class PipelineIntermediateState: run_id: str @@ -72,7 +73,11 @@ class AddsMaskLatents: initial_image_latents: torch.Tensor def __call__( - self, latents: torch.Tensor, t: torch.Tensor, text_embeddings: torch.Tensor, **kwargs, + self, + latents: torch.Tensor, + t: torch.Tensor, + text_embeddings: torch.Tensor, + **kwargs, ) -> torch.Tensor: model_input = self.add_mask_channels(latents) return self.forward(model_input, t, text_embeddings, **kwargs) @@ -80,12 +85,8 @@ class AddsMaskLatents: def add_mask_channels(self, latents): batch_size = latents.size(0) # duplicate mask and latents for each batch - mask = einops.repeat( - self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size - ) - image_latents = einops.repeat( - self.initial_image_latents, "b c h w -> (repeat b) c h w", repeat=batch_size - ) + mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size) + image_latents = einops.repeat(self.initial_image_latents, "b c h w -> (repeat b) c h w", repeat=batch_size) # add mask and image as additional channels model_input, _ = einops.pack([latents, mask, image_latents], "b * h w") return model_input @@ -103,9 +104,7 @@ class AddsMaskGuidance: noise: torch.Tensor _debug: Optional[Callable] = None - def __call__( - self, step_output: Union[BaseOutput, SchedulerOutput], t: torch.Tensor, conditioning - ) -> BaseOutput: + def __call__(self, step_output: Union[BaseOutput, SchedulerOutput], t: torch.Tensor, conditioning) -> BaseOutput: output_class = step_output.__class__ # We'll create a new one with masked data. # The problem with taking SchedulerOutput instead of the model output is that we're less certain what's in it. @@ -116,11 +115,7 @@ class AddsMaskGuidance: # Mask anything that has the same shape as prev_sample, return others as-is. return output_class( { - k: ( - self.apply_mask(v, self._t_for_field(k, t)) - if are_like_tensors(prev_sample, v) - else v - ) + k: (self.apply_mask(v, self._t_for_field(k, t)) if are_like_tensors(prev_sample, v) else v) for k, v in step_output.items() } ) @@ -132,9 +127,7 @@ class AddsMaskGuidance: def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor: batch_size = latents.size(0) - mask = einops.repeat( - self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size - ) + mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size) if t.dim() == 0: # some schedulers expect t to be one-dimensional. # TODO: file diffusers bug about inconsistency? @@ -144,12 +137,8 @@ class AddsMaskGuidance: mask_latents = self.scheduler.add_noise(self.mask_latents, self.noise, t) # TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already? # mask_latents = self.scheduler.scale_model_input(mask_latents, t) - mask_latents = einops.repeat( - mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size - ) - masked_input = torch.lerp( - mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype) - ) + mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size) + masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype)) if self._debug: self._debug(masked_input, f"t={t} lerped") return masked_input @@ -159,9 +148,7 @@ def trim_to_multiple_of(*args, multiple_of=8): return tuple((x - x % multiple_of) for x in args) -def image_resized_to_grid_as_tensor( - image: PIL.Image.Image, normalize: bool = True, multiple_of=8 -) -> torch.FloatTensor: +def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool = True, multiple_of=8) -> torch.FloatTensor: """ :param image: input image @@ -211,6 +198,7 @@ class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]): raise AssertionError("why was that an empty generator?") return result + @dataclass class ControlNetData: model: ControlNetModel = Field(default=None) @@ -341,9 +329,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): # FIXME: can't currently register control module # control_model=control_model, ) - self.invokeai_diffuser = InvokeAIDiffuserComponent( - self.unet, self._unet_forward - ) + self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward) self._model_group = FullyLoadedModelGroup(execution_device or self.unet.device) self._model_group.install(*self._submodels) @@ -354,11 +340,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): if xformers is available, use it, otherwise use sliced attention. """ config = InvokeAIAppConfig.get_config() - if ( - torch.cuda.is_available() - and is_xformers_available() - and not config.disable_xformers - ): + if torch.cuda.is_available() and is_xformers_available() and not config.disable_xformers: self.enable_xformers_memory_efficient_attention() else: if self.device.type == "cpu" or self.device.type == "mps": @@ -369,9 +351,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): raise ValueError(f"unrecognized device {self.device}") # input tensor of [1, 4, h/8, w/8] # output tensor of [16, (h/8 * w/8), (h/8 * w/8)] - bytes_per_element_needed_for_baddbmm_duplication = ( - latents.element_size() + 4 - ) + bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4 max_size_required_for_baddbmm = ( 16 * latents.size(dim=2) @@ -380,9 +360,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): * latents.size(dim=3) * bytes_per_element_needed_for_baddbmm_duplication ) - if max_size_required_for_baddbmm > ( - mem_free * 3.0 / 4.0 - ): # 3.3 / 4.0 is from old Invoke code + if max_size_required_for_baddbmm > (mem_free * 3.0 / 4.0): # 3.3 / 4.0 is from old Invoke code self.enable_attention_slicing(slice_size="max") elif torch.backends.mps.is_available(): # diffusers recommends always enabling for mps @@ -470,7 +448,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): control_data: List[ControlNetData] = None, ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]: if self.scheduler.config.get("cpu_only", False): - scheduler_device = torch.device('cpu') + scheduler_device = torch.device("cpu") else: scheduler_device = self._model_group.device_for(self.unet) @@ -488,7 +466,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): run_id=run_id, additional_guidance=additional_guidance, control_data=control_data, - callback=callback, ) return result.latents, result.attention_map_saver @@ -511,9 +488,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): additional_guidance = [] extra_conditioning_info = conditioning_data.extra with self.invokeai_diffuser.custom_attention_context( - self.invokeai_diffuser.model, - extra_conditioning_info=extra_conditioning_info, - step_count=len(self.scheduler.timesteps), + self.invokeai_diffuser.model, + extra_conditioning_info=extra_conditioning_info, + step_count=len(self.scheduler.timesteps), ): yield PipelineIntermediateState( run_id=run_id, @@ -607,16 +584,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): # that are combined at higher level to make control_mode enum # soft_injection determines whether to do per-layer re-weighting adjustment (if True) # or default weighting (if False) - soft_injection = (control_mode == "more_prompt" or control_mode == "more_control") + soft_injection = control_mode == "more_prompt" or control_mode == "more_control" # cfg_injection = determines whether to apply ControlNet to only the conditional (if True) # or the default both conditional and unconditional (if False) - cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced") + cfg_injection = control_mode == "more_control" or control_mode == "unbalanced" first_control_step = math.floor(control_datum.begin_step_percent * total_step_count) last_control_step = math.ceil(control_datum.end_step_percent * total_step_count) # only apply controlnet if current step is within the controlnet's begin/end step range if step_index >= first_control_step and step_index <= last_control_step: - if cfg_injection: control_latent_input = unet_latent_input else: @@ -629,7 +605,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): encoder_hidden_states = conditioning_data.text_embeddings encoder_attention_mask = None else: - encoder_hidden_states, encoder_attention_mask = self.invokeai_diffuser._concat_conditionings_for_batch( + ( + encoder_hidden_states, + encoder_attention_mask, + ) = self.invokeai_diffuser._concat_conditionings_for_batch( conditioning_data.unconditioned_embeddings, conditioning_data.text_embeddings, ) @@ -646,9 +625,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): timestep=timestep, encoder_hidden_states=encoder_hidden_states, controlnet_cond=control_datum.image_tensor, - conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale + conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale encoder_attention_mask=encoder_attention_mask, - guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel + guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel return_dict=False, ) if cfg_injection: @@ -678,13 +657,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): step_index=step_index, total_step_count=total_step_count, down_block_additional_residuals=down_block_res_samples, # from controlnet(s) - mid_block_additional_residual=mid_block_res_sample, # from controlnet(s) + mid_block_additional_residual=mid_block_res_sample, # from controlnet(s) ) # compute the previous noisy sample x_t -> x_t-1 - step_output = self.scheduler.step( - noise_pred, timestep, latents, **conditioning_data.scheduler_args - ) + step_output = self.scheduler.step(noise_pred, timestep, latents, **conditioning_data.scheduler_args) # TODO: this additional_guidance extension point feels redundant with InvokeAIDiffusionComponent. # But the way things are now, scheduler runs _after_ that, so there was @@ -710,17 +687,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): # use of AddsMaskLatents. latents = AddsMaskLatents( self._unet_forward, - mask=torch.ones_like( - latents[:1, :1], device=latents.device, dtype=latents.dtype - ), - initial_image_latents=torch.zeros_like( - latents[:1], device=latents.device, dtype=latents.dtype - ), + mask=torch.ones_like(latents[:1, :1], device=latents.device, dtype=latents.dtype), + initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype), ).add_mask_channels(latents) # First three args should be positional, not keywords, so torch hooks can see them. return self.unet( - latents, t, text_embeddings, cross_attention_kwargs=cross_attention_kwargs, + latents, + t, + text_embeddings, + cross_attention_kwargs=cross_attention_kwargs, **kwargs, ).sample @@ -774,9 +750,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): ) -> InvokeAIStableDiffusionPipelineOutput: timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength) result_latents, result_attention_maps = self.latents_from_embeddings( - latents=initial_latents if strength < 1.0 else torch.zeros_like( - initial_latents, device=initial_latents.device, dtype=initial_latents.dtype - ), + latents=initial_latents + if strength < 1.0 + else torch.zeros_like(initial_latents, device=initial_latents.device, dtype=initial_latents.dtype), num_inference_steps=num_inference_steps, conditioning_data=conditioning_data, timesteps=timesteps, @@ -797,14 +773,12 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): ) return self.check_for_safety(output, dtype=conditioning_data.dtype) - def get_img2img_timesteps( - self, num_inference_steps: int, strength: float, device=None - ) -> (torch.Tensor, int): + def get_img2img_timesteps(self, num_inference_steps: int, strength: float, device=None) -> (torch.Tensor, int): img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components) assert img2img_pipeline.scheduler is self.scheduler if self.scheduler.config.get("cpu_only", False): - scheduler_device = torch.device('cpu') + scheduler_device = torch.device("cpu") else: scheduler_device = self._model_group.device_for(self.unet) @@ -849,18 +823,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): # 6. Prepare latent variables # can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents # because we have our own noise function - init_image_latents = self.non_noised_latents_from_image( - init_image, device=device, dtype=latents_dtype - ) + init_image_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype) if seed is not None: set_seed(seed) noise = noise_func(init_image_latents) if mask.dim() == 3: mask = mask.unsqueeze(0) - latent_mask = tv_resize( - mask, init_image_latents.shape[-2:], T.InterpolationMode.BILINEAR - ).to(device=device, dtype=latents_dtype) + latent_mask = tv_resize(mask, init_image_latents.shape[-2:], T.InterpolationMode.BILINEAR).to( + device=device, dtype=latents_dtype + ) guidance: List[Callable] = [] @@ -868,22 +840,20 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): # You'd think the inpainting model wouldn't be paying attention to the area it is going to repaint # (that's why there's a mask!) but it seems to really want that blanked out. masked_init_image = init_image * torch.where(mask < 0.5, 1, 0) - masked_latents = self.non_noised_latents_from_image( - masked_init_image, device=device, dtype=latents_dtype - ) + masked_latents = self.non_noised_latents_from_image(masked_init_image, device=device, dtype=latents_dtype) # TODO: we should probably pass this in so we don't have to try/finally around setting it. self.invokeai_diffuser.model_forward_callback = AddsMaskLatents( self._unet_forward, latent_mask, masked_latents ) else: - guidance.append( - AddsMaskGuidance(latent_mask, init_image_latents, self.scheduler, noise) - ) + guidance.append(AddsMaskGuidance(latent_mask, init_image_latents, self.scheduler, noise)) try: result_latents, result_attention_maps = self.latents_from_embeddings( - latents=init_image_latents if strength < 1.0 else torch.zeros_like( + latents=init_image_latents + if strength < 1.0 + else torch.zeros_like( init_image_latents, device=init_image_latents.device, dtype=init_image_latents.dtype ), num_inference_steps=num_inference_steps, @@ -914,18 +884,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): with torch.inference_mode(): self._model_group.load(self.vae) init_latent_dist = self.vae.encode(init_image).latent_dist - init_latents = init_latent_dist.sample().to( - dtype=dtype - ) # FIXME: uses torch.randn. make reproducible! + init_latents = init_latent_dist.sample().to(dtype=dtype) # FIXME: uses torch.randn. make reproducible! init_latents = 0.18215 * init_latents return init_latents def check_for_safety(self, output, dtype): with torch.inference_mode(): - screened_images, has_nsfw_concept = self.run_safety_checker( - output.images, dtype=dtype - ) + screened_images, has_nsfw_concept = self.run_safety_checker(output.images, dtype=dtype) screened_attention_map_saver = None if has_nsfw_concept is None or not has_nsfw_concept: screened_attention_map_saver = output.attention_map_saver @@ -949,9 +915,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): def debug_latents(self, latents, msg): from invokeai.backend.image_util import debug_image + with torch.inference_mode(): decoded = self.numpy_to_pil(self.decode_latents(latents)) for i, img in enumerate(decoded): - debug_image( - img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True - ) + debug_image(img, f"latents {msg} {i+1}/{len(decoded)}", debug_status=True) diff --git a/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py b/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py index 79a0982cfe..38763ebbee 100644 --- a/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py +++ b/invokeai/backend/stable_diffusion/diffusion/cross_attention_control.py @@ -17,6 +17,7 @@ from torch import nn import invokeai.backend.util.logging as logger from ...util import torch_dtype + class CrossAttentionType(enum.Enum): SELF = 1 TOKENS = 2 @@ -55,9 +56,7 @@ class Context: if name in self.self_cross_attention_module_identifiers: assert False, f"name {name} cannot appear more than once" self.self_cross_attention_module_identifiers.append(name) - for name, module in get_cross_attention_modules( - model, CrossAttentionType.TOKENS - ): + for name, module in get_cross_attention_modules(model, CrossAttentionType.TOKENS): if name in self.tokens_cross_attention_module_identifiers: assert False, f"name {name} cannot appear more than once" self.tokens_cross_attention_module_identifiers.append(name) @@ -68,9 +67,7 @@ class Context: else: self.tokens_cross_attention_action = Context.Action.SAVE - def request_apply_saved_attention_maps( - self, cross_attention_type: CrossAttentionType - ): + def request_apply_saved_attention_maps(self, cross_attention_type: CrossAttentionType): if cross_attention_type == CrossAttentionType.SELF: self.self_cross_attention_action = Context.Action.APPLY else: @@ -139,9 +136,7 @@ class Context: saved_attention_dict = self.saved_cross_attention_maps[identifier] if requested_dim is None: if saved_attention_dict["dim"] is not None: - raise RuntimeError( - f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}" - ) + raise RuntimeError(f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}") return saved_attention_dict["slices"][0] if saved_attention_dict["dim"] == requested_dim: @@ -154,21 +149,13 @@ class Context: if saved_attention_dict["dim"] is None: whole_saved_attention = saved_attention_dict["slices"][0] if requested_dim == 0: - return whole_saved_attention[ - requested_offset : requested_offset + slice_size - ] + return whole_saved_attention[requested_offset : requested_offset + slice_size] elif requested_dim == 1: - return whole_saved_attention[ - :, requested_offset : requested_offset + slice_size - ] + return whole_saved_attention[:, requested_offset : requested_offset + slice_size] - raise RuntimeError( - f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}" - ) + raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}") - def get_slicing_strategy( - self, identifier: str - ) -> tuple[Optional[int], Optional[int]]: + def get_slicing_strategy(self, identifier: str) -> tuple[Optional[int], Optional[int]]: saved_attention = self.saved_cross_attention_maps.get(identifier, None) if saved_attention is None: return None, None @@ -201,9 +188,7 @@ class InvokeAICrossAttentionMixin: def set_attention_slice_wrangler( self, - wrangler: Optional[ - Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor] - ], + wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]], ): """ Set custom attention calculator to be called when attention is calculated @@ -219,14 +204,10 @@ class InvokeAICrossAttentionMixin: """ self.attention_slice_wrangler = wrangler - def set_slicing_strategy_getter( - self, getter: Optional[Callable[[nn.Module], tuple[int, int]]] - ): + def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int, int]]]): self.slicing_strategy_getter = getter - def set_attention_slice_calculated_callback( - self, callback: Optional[Callable[[torch.Tensor], None]] - ): + def set_attention_slice_calculated_callback(self, callback: Optional[Callable[[torch.Tensor], None]]): self.attention_slice_calculated_callback = callback def einsum_lowest_level(self, query, key, value, dim, offset, slice_size): @@ -247,45 +228,31 @@ class InvokeAICrossAttentionMixin: ) # calculate attention slice by taking the best scores for each latent pixel - default_attention_slice = attention_scores.softmax( - dim=-1, dtype=attention_scores.dtype - ) + default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype) attention_slice_wrangler = self.attention_slice_wrangler if attention_slice_wrangler is not None: - attention_slice = attention_slice_wrangler( - self, default_attention_slice, dim, offset, slice_size - ) + attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size) else: attention_slice = default_attention_slice if self.attention_slice_calculated_callback is not None: - self.attention_slice_calculated_callback( - attention_slice, dim, offset, slice_size - ) + self.attention_slice_calculated_callback(attention_slice, dim, offset, slice_size) hidden_states = torch.bmm(attention_slice, value) return hidden_states def einsum_op_slice_dim0(self, q, k, v, slice_size): - r = torch.zeros( - q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype - ) + r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) for i in range(0, q.shape[0], slice_size): end = i + slice_size - r[i:end] = self.einsum_lowest_level( - q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size - ) + r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size) return r def einsum_op_slice_dim1(self, q, k, v, slice_size): - r = torch.zeros( - q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype - ) + r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) for i in range(0, q.shape[1], slice_size): end = i + slice_size - r[:, i:end] = self.einsum_lowest_level( - q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size - ) + r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size) return r def einsum_op_mps_v1(self, q, k, v): @@ -353,6 +320,7 @@ def restore_default_cross_attention( else: remove_attention_function(model) + def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context): """ Inject attention parameters and functions into the passed in model to enable cross attention editing. @@ -372,7 +340,7 @@ def setup_cross_attention_control_attention_processors(unet: UNet2DConditionMode indices = torch.arange(max_length, dtype=torch.long) for name, a0, a1, b0, b1 in context.arguments.edit_opcodes: if b0 < max_length: - if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0): + if name == "equal": # or (name == "replace" and a1 - a0 == b1 - b0): # these tokens have not been edited indices[b0:b1] = indices_target[a0:a1] mask[b0:b1] = 1 @@ -386,16 +354,14 @@ def setup_cross_attention_control_attention_processors(unet: UNet2DConditionMode else: # try to re-use an existing slice size default_slice_size = 4 - slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size) + slice_size = next( + (p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size + ) unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size)) -def get_cross_attention_modules( - model, which: CrossAttentionType -) -> list[tuple[str, InvokeAICrossAttentionMixin]]: - cross_attention_class: type = ( - InvokeAIDiffusersCrossAttention - ) +def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]: + cross_attention_class: type = InvokeAIDiffusersCrossAttention which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2" attention_module_tuples = [ (name, module) @@ -420,9 +386,7 @@ def get_cross_attention_modules( def inject_attention_function(unet, context: Context): # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276 - def attention_slice_wrangler( - module, suggested_attention_slice: torch.Tensor, dim, offset, slice_size - ): + def attention_slice_wrangler(module, suggested_attention_slice: torch.Tensor, dim, offset, slice_size): # memory_usage = suggested_attention_slice.element_size() * suggested_attention_slice.nelement() attention_slice = suggested_attention_slice @@ -430,9 +394,7 @@ def inject_attention_function(unet, context: Context): if context.get_should_save_maps(module.identifier): # print(module.identifier, "saving suggested_attention_slice of shape", # suggested_attention_slice.shape, "dim", dim, "offset", offset) - slice_to_save = ( - attention_slice.to("cpu") if dim is not None else attention_slice - ) + slice_to_save = attention_slice.to("cpu") if dim is not None else attention_slice context.save_slice( module.identifier, slice_to_save, @@ -442,31 +404,20 @@ def inject_attention_function(unet, context: Context): ) elif context.get_should_apply_saved_maps(module.identifier): # print(module.identifier, "applying saved attention slice for dim", dim, "offset", offset) - saved_attention_slice = context.get_slice( - module.identifier, dim, offset, slice_size - ) + saved_attention_slice = context.get_slice(module.identifier, dim, offset, slice_size) # slice may have been offloaded to CPU - saved_attention_slice = saved_attention_slice.to( - suggested_attention_slice.device - ) + saved_attention_slice = saved_attention_slice.to(suggested_attention_slice.device) if context.is_tokens_cross_attention(module.identifier): index_map = context.cross_attention_index_map - remapped_saved_attention_slice = torch.index_select( - saved_attention_slice, -1, index_map - ) + remapped_saved_attention_slice = torch.index_select(saved_attention_slice, -1, index_map) this_attention_slice = suggested_attention_slice - mask = context.cross_attention_mask.to( - torch_dtype(suggested_attention_slice.device) - ) + mask = context.cross_attention_mask.to(torch_dtype(suggested_attention_slice.device)) saved_mask = mask this_mask = 1 - mask - attention_slice = ( - remapped_saved_attention_slice * saved_mask - + this_attention_slice * this_mask - ) + attention_slice = remapped_saved_attention_slice * saved_mask + this_attention_slice * this_mask else: # just use everything attention_slice = saved_attention_slice @@ -480,14 +431,10 @@ def inject_attention_function(unet, context: Context): module.identifier = identifier try: module.set_attention_slice_wrangler(attention_slice_wrangler) - module.set_slicing_strategy_getter( - lambda module: context.get_slicing_strategy(identifier) - ) + module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier)) except AttributeError as e: if is_attribute_error_about(e, "set_attention_slice_wrangler"): - print( - f"TODO: implement set_attention_slice_wrangler for {type(module)}" - ) # TODO + print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") # TODO else: raise @@ -503,9 +450,7 @@ def remove_attention_function(unet): module.set_slicing_strategy_getter(None) except AttributeError as e: if is_attribute_error_about(e, "set_attention_slice_wrangler"): - print( - f"TODO: implement set_attention_slice_wrangler for {type(module)}" - ) + print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") else: raise @@ -530,9 +475,7 @@ def get_mem_free_total(device): return mem_free_total -class InvokeAIDiffusersCrossAttention( - diffusers.models.attention.Attention, InvokeAICrossAttentionMixin -): +class InvokeAIDiffusersCrossAttention(diffusers.models.attention.Attention, InvokeAICrossAttentionMixin): def __init__(self, **kwargs): super().__init__(**kwargs) InvokeAICrossAttentionMixin.__init__(self) @@ -641,11 +584,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor): # kwargs swap_cross_attn_context: SwapCrossAttnContext = None, ): - attention_type = ( - CrossAttentionType.SELF - if encoder_hidden_states is None - else CrossAttentionType.TOKENS - ) + attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS # if cross-attention control is not in play, just call through to the base implementation. if ( @@ -654,9 +593,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor): or not swap_cross_attn_context.wants_cross_attention_control(attention_type) ): # print(f"SwapCrossAttnContext for {attention_type} not active - passing request to superclass") - return super().__call__( - attn, hidden_states, encoder_hidden_states, attention_mask - ) + return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask) # else: # print(f"SwapCrossAttnContext for {attention_type} active") @@ -699,18 +636,10 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor): query_slice = query[start_idx:end_idx] original_key_slice = original_text_key[start_idx:end_idx] modified_key_slice = modified_text_key[start_idx:end_idx] - attn_mask_slice = ( - attention_mask[start_idx:end_idx] - if attention_mask is not None - else None - ) + attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None - original_attn_slice = attn.get_attention_scores( - query_slice, original_key_slice, attn_mask_slice - ) - modified_attn_slice = attn.get_attention_scores( - query_slice, modified_key_slice, attn_mask_slice - ) + original_attn_slice = attn.get_attention_scores(query_slice, original_key_slice, attn_mask_slice) + modified_attn_slice = attn.get_attention_scores(query_slice, modified_key_slice, attn_mask_slice) # because the prompt modifications may result in token sequences shifted forwards or backwards, # the original attention probabilities must be remapped to account for token index changes in the @@ -722,9 +651,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor): # only some tokens taken from the original attention probabilities. this is controlled by the mask. mask = swap_cross_attn_context.mask inverse_mask = 1 - mask - attn_slice = ( - remapped_original_attn_slice * mask + modified_attn_slice * inverse_mask - ) + attn_slice = remapped_original_attn_slice * mask + modified_attn_slice * inverse_mask del remapped_original_attn_slice, modified_attn_slice @@ -744,6 +671,4 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor): class SwapCrossAttnProcessor(SlicedSwapCrossAttnProcesser): def __init__(self): - super(SwapCrossAttnProcessor, self).__init__( - slice_size=int(1e9) - ) # massive slice size = don't slice + super(SwapCrossAttnProcessor, self).__init__(slice_size=int(1e9)) # massive slice size = don't slice diff --git a/invokeai/backend/stable_diffusion/diffusion/cross_attention_map_saving.py b/invokeai/backend/stable_diffusion/diffusion/cross_attention_map_saving.py index c489c2f0a9..b0174a455e 100644 --- a/invokeai/backend/stable_diffusion/diffusion/cross_attention_map_saving.py +++ b/invokeai/backend/stable_diffusion/diffusion/cross_attention_map_saving.py @@ -59,9 +59,7 @@ class AttentionMapSaver: for key, maps in self.collated_maps.items(): # maps has shape [(H*W), N] for N tokens # but we want [N, H, W] - this_scale_factor = math.sqrt( - maps.shape[0] / (latents_width * latents_height) - ) + this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height)) this_maps_height = int(float(latents_height) * this_scale_factor) this_maps_width = int(float(latents_width) * this_scale_factor) # and we need to do some dimension juggling @@ -72,9 +70,7 @@ class AttentionMapSaver: # scale to output size if necessary if this_scale_factor != 1: - maps = tv_resize( - maps, [latents_height, latents_width], InterpolationMode.BICUBIC - ) + maps = tv_resize(maps, [latents_height, latents_width], InterpolationMode.BICUBIC) # normalize maps_min = torch.min(maps) @@ -83,9 +79,7 @@ class AttentionMapSaver: maps_normalized = (maps - maps_min) / maps_range # expand to (-0.1, 1.1) and clamp maps_normalized_expanded = maps_normalized * 1.1 - 0.05 - maps_normalized_expanded_clamped = torch.clamp( - maps_normalized_expanded, 0, 1 - ) + maps_normalized_expanded_clamped = torch.clamp(maps_normalized_expanded, 0, 1) # merge together, producing a vertical stack maps_stacked = torch.reshape( diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index f44578cd47..272518e928 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -31,6 +31,7 @@ ModelForwardCallback: TypeAlias = Union[ Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], ] + @dataclass(frozen=True) class PostprocessingSettings: threshold: float @@ -81,14 +82,12 @@ class InvokeAIDiffuserComponent: @contextmanager def custom_attention_context( cls, - unet: UNet2DConditionModel, # note: also may futz with the text encoder depending on requested LoRAs + unet: UNet2DConditionModel, # note: also may futz with the text encoder depending on requested LoRAs extra_conditioning_info: Optional[ExtraConditioningInfo], - step_count: int + step_count: int, ): old_attn_processors = None - if extra_conditioning_info and ( - extra_conditioning_info.wants_cross_attention_control - ): + if extra_conditioning_info and (extra_conditioning_info.wants_cross_attention_control): old_attn_processors = unet.attn_processors # Load lora conditions into the model if extra_conditioning_info.wants_cross_attention_control: @@ -116,27 +115,15 @@ class InvokeAIDiffuserComponent: return saver.add_attention_maps(slice, key) - tokens_cross_attention_modules = get_cross_attention_modules( - self.model, CrossAttentionType.TOKENS - ) + tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS) for identifier, module in tokens_cross_attention_modules: - key = ( - "down" - if identifier.startswith("down") - else "up" - if identifier.startswith("up") - else "mid" - ) + key = "down" if identifier.startswith("down") else "up" if identifier.startswith("up") else "mid" module.set_attention_slice_calculated_callback( - lambda slice, dim, offset, slice_size, key=key: callback( - slice, dim, offset, slice_size, key - ) + lambda slice, dim, offset, slice_size, key=key: callback(slice, dim, offset, slice_size, key) ) def remove_attention_map_saving(self): - tokens_cross_attention_modules = get_cross_attention_modules( - self.model, CrossAttentionType.TOKENS - ) + tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS) for _, module in tokens_cross_attention_modules: module.set_attention_slice_calculated_callback(None) @@ -171,10 +158,8 @@ class InvokeAIDiffuserComponent: context: Context = self.cross_attention_control_context if self.cross_attention_control_context is not None: percent_through = step_index / total_step_count - cross_attention_control_types_to_do = ( - context.get_active_cross_attention_control_types_for_step( - percent_through - ) + cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step( + percent_through ) wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0 @@ -182,7 +167,11 @@ class InvokeAIDiffuserComponent: if wants_hybrid_conditioning: unconditioned_next_x, conditioned_next_x = self._apply_hybrid_conditioning( - x, sigma, unconditioning, conditioning, **kwargs, + x, + sigma, + unconditioning, + conditioning, + **kwargs, ) elif wants_cross_attention_control: ( @@ -201,7 +190,11 @@ class InvokeAIDiffuserComponent: unconditioned_next_x, conditioned_next_x, ) = self._apply_standard_conditioning_sequentially( - x, sigma, unconditioning, conditioning, **kwargs, + x, + sigma, + unconditioning, + conditioning, + **kwargs, ) else: @@ -209,12 +202,18 @@ class InvokeAIDiffuserComponent: unconditioned_next_x, conditioned_next_x, ) = self._apply_standard_conditioning( - x, sigma, unconditioning, conditioning, **kwargs, + x, + sigma, + unconditioning, + conditioning, + **kwargs, ) combined_next_x = self._combine( # unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale - unconditioned_next_x, conditioned_next_x, guidance_scale + unconditioned_next_x, + conditioned_next_x, + guidance_scale, ) return combined_next_x @@ -229,37 +228,47 @@ class InvokeAIDiffuserComponent: ) -> torch.Tensor: if postprocessing_settings is not None: percent_through = step_index / total_step_count - latents = self.apply_threshold( - postprocessing_settings, latents, percent_through - ) - latents = self.apply_symmetry( - postprocessing_settings, latents, percent_through - ) + latents = self.apply_threshold(postprocessing_settings, latents, percent_through) + latents = self.apply_symmetry(postprocessing_settings, latents, percent_through) return latents def _concat_conditionings_for_batch(self, unconditioning, conditioning): def _pad_conditioning(cond, target_len, encoder_attention_mask): - conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype) + conditioning_attention_mask = torch.ones( + (cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype + ) if cond.shape[1] < max_len: - conditioning_attention_mask = torch.cat([ - conditioning_attention_mask, - torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype), - ], dim=1) + conditioning_attention_mask = torch.cat( + [ + conditioning_attention_mask, + torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype), + ], + dim=1, + ) - cond = torch.cat([ - cond, - torch.zeros((cond.shape[0], max_len - cond.shape[1], cond.shape[2]), device=cond.device, dtype=cond.dtype), - ], dim=1) + cond = torch.cat( + [ + cond, + torch.zeros( + (cond.shape[0], max_len - cond.shape[1], cond.shape[2]), + device=cond.device, + dtype=cond.dtype, + ), + ], + dim=1, + ) if encoder_attention_mask is None: encoder_attention_mask = conditioning_attention_mask else: - encoder_attention_mask = torch.cat([ - encoder_attention_mask, - conditioning_attention_mask, - ]) - + encoder_attention_mask = torch.cat( + [ + encoder_attention_mask, + conditioning_attention_mask, + ] + ) + return cond, encoder_attention_mask encoder_attention_mask = None @@ -277,11 +286,11 @@ class InvokeAIDiffuserComponent: x_twice = torch.cat([x] * 2) sigma_twice = torch.cat([sigma] * 2) - both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch( - unconditioning, conditioning - ) + both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(unconditioning, conditioning) both_results = self.model_forward_callback( - x_twice, sigma_twice, both_conditionings, + x_twice, + sigma_twice, + both_conditionings, encoder_attention_mask=encoder_attention_mask, **kwargs, ) @@ -312,13 +321,17 @@ class InvokeAIDiffuserComponent: uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2) unconditioned_next_x = self.model_forward_callback( - x, sigma, unconditioning, + x, + sigma, + unconditioning, down_block_additional_residuals=uncond_down_block, mid_block_additional_residual=uncond_mid_block, **kwargs, ) conditioned_next_x = self.model_forward_callback( - x, sigma, conditioning, + x, + sigma, + conditioning, down_block_additional_residuals=cond_down_block, mid_block_additional_residual=cond_mid_block, **kwargs, @@ -335,13 +348,15 @@ class InvokeAIDiffuserComponent: for k in conditioning: if isinstance(conditioning[k], list): both_conditionings[k] = [ - torch.cat([unconditioning[k][i], conditioning[k][i]]) - for i in range(len(conditioning[k])) + torch.cat([unconditioning[k][i], conditioning[k][i]]) for i in range(len(conditioning[k])) ] else: both_conditionings[k] = torch.cat([unconditioning[k], conditioning[k]]) unconditioned_next_x, conditioned_next_x = self.model_forward_callback( - x_twice, sigma_twice, both_conditionings, **kwargs, + x_twice, + sigma_twice, + both_conditionings, + **kwargs, ).chunk(2) return unconditioned_next_x, conditioned_next_x @@ -388,9 +403,7 @@ class InvokeAIDiffuserComponent: ) # do requested cross attention types for conditioning (positive prompt) - cross_attn_processor_context.cross_attention_types_to_do = ( - cross_attention_control_types_to_do - ) + cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do conditioned_next_x = self.model_forward_callback( x, sigma, @@ -414,19 +427,14 @@ class InvokeAIDiffuserComponent: latents: torch.Tensor, percent_through: float, ) -> torch.Tensor: - if ( - postprocessing_settings.threshold is None - or postprocessing_settings.threshold == 0.0 - ): + if postprocessing_settings.threshold is None or postprocessing_settings.threshold == 0.0: return latents threshold = postprocessing_settings.threshold warmup = postprocessing_settings.warmup if percent_through < warmup: - current_threshold = threshold + threshold * 5 * ( - 1 - (percent_through / warmup) - ) + current_threshold = threshold + threshold * 5 * (1 - (percent_through / warmup)) else: current_threshold = threshold @@ -440,18 +448,10 @@ class InvokeAIDiffuserComponent: if self.debug_thresholding: std, mean = [i.item() for i in torch.std_mean(latents)] - outside = torch.count_nonzero( - (latents < -current_threshold) | (latents > current_threshold) - ) - logger.info( - f"Threshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})" - ) - logger.debug( - f"min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}" - ) - logger.debug( - f"{outside / latents.numel() * 100:.2f}% values outside threshold" - ) + outside = torch.count_nonzero((latents < -current_threshold) | (latents > current_threshold)) + logger.info(f"Threshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})") + logger.debug(f"min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}") + logger.debug(f"{outside / latents.numel() * 100:.2f}% values outside threshold") if maxval < current_threshold and minval > -current_threshold: return latents @@ -464,25 +464,17 @@ class InvokeAIDiffuserComponent: latents = torch.clone(latents) maxval = np.clip(maxval * scale, 1, current_threshold) num_altered += torch.count_nonzero(latents > maxval) - latents[latents > maxval] = ( - torch.rand_like(latents[latents > maxval]) * maxval - ) + latents[latents > maxval] = torch.rand_like(latents[latents > maxval]) * maxval if minval < -current_threshold: latents = torch.clone(latents) minval = np.clip(minval * scale, -current_threshold, -1) num_altered += torch.count_nonzero(latents < minval) - latents[latents < minval] = ( - torch.rand_like(latents[latents < minval]) * minval - ) + latents[latents < minval] = torch.rand_like(latents[latents < minval]) * minval if self.debug_thresholding: - logger.debug( - f"min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})" - ) - logger.debug( - f"{num_altered / latents.numel() * 100:.2f}% values altered" - ) + logger.debug(f"min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})") + logger.debug(f"{num_altered / latents.numel() * 100:.2f}% values altered") return latents @@ -501,15 +493,11 @@ class InvokeAIDiffuserComponent: # Check for out of bounds h_symmetry_time_pct = postprocessing_settings.h_symmetry_time_pct - if h_symmetry_time_pct is not None and ( - h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0 - ): + if h_symmetry_time_pct is not None and (h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0): h_symmetry_time_pct = None v_symmetry_time_pct = postprocessing_settings.v_symmetry_time_pct - if v_symmetry_time_pct is not None and ( - v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0 - ): + if v_symmetry_time_pct is not None and (v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0): v_symmetry_time_pct = None dev = latents.device.type @@ -554,9 +542,7 @@ class InvokeAIDiffuserComponent: def estimate_percent_through(self, step_index, sigma): if step_index is not None and self.cross_attention_control_context is not None: # percent_through will never reach 1.0 (but this is intended) - return float(step_index) / float( - self.cross_attention_control_context.step_count - ) + return float(step_index) / float(self.cross_attention_control_context.step_count) # find the best possible index of the current sigma in the sigma sequence smaller_sigmas = torch.nonzero(self.model.sigmas <= sigma) sigma_index = smaller_sigmas[-1].item() if smaller_sigmas.shape[0] > 0 else 0 @@ -567,19 +553,13 @@ class InvokeAIDiffuserComponent: # todo: make this work @classmethod - def apply_conjunction( - cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale - ): + def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale): x_in = torch.cat([x] * 2) t_in = torch.cat([t] * 2) # aka sigmas deltas = None uncond_latents = None - weighted_cond_list = ( - c_or_weighted_c_list - if type(c_or_weighted_c_list) is list - else [(c_or_weighted_c_list, 1)] - ) + weighted_cond_list = c_or_weighted_c_list if type(c_or_weighted_c_list) is list else [(c_or_weighted_c_list, 1)] # below is fugly omg conditionings = [uc] + [c for c, weight in weighted_cond_list] @@ -608,15 +588,11 @@ class InvokeAIDiffuserComponent: deltas = torch.cat((deltas, latents_b - uncond_latents)) # merge the weighted deltas together into a single merged delta - per_delta_weights = torch.tensor( - weights[1:], dtype=deltas.dtype, device=deltas.device - ) + per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device) normalize = False if normalize: per_delta_weights /= torch.sum(per_delta_weights) - reshaped_weights = per_delta_weights.reshape( - per_delta_weights.shape + (1, 1, 1) - ) + reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1)) deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True) # old_return_value = super().forward(x, sigma, uncond, cond, cond_scale) diff --git a/invokeai/backend/stable_diffusion/image_degradation/bsrgan.py b/invokeai/backend/stable_diffusion/image_degradation/bsrgan.py index 1760206073..493c8be781 100644 --- a/invokeai/backend/stable_diffusion/image_degradation/bsrgan.py +++ b/invokeai/backend/stable_diffusion/image_degradation/bsrgan.py @@ -261,9 +261,7 @@ def srmd_degradation(x, k, sf=3): year={2018} } """ - x = ndimage.filters.convolve( - x, np.expand_dims(k, axis=2), mode="wrap" - ) # 'nearest' | 'mirror' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap") # 'nearest' | 'mirror' x = bicubic_degradation(x, sf=sf) return x @@ -389,21 +387,15 @@ def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): noise_level = random.randint(noise_level1, noise_level2) rnum = np.random.rand() if rnum > 0.6: # add color Gaussian noise - img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype( - np.float32 - ) + img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) elif rnum < 0.4: # add grayscale Gaussian noise - img = img + np.random.normal( - 0, noise_level / 255.0, (*img.shape[:2], 1) - ).astype(np.float32) + img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) else: # add noise L = noise_level2 / 255.0 D = np.diag(np.random.rand(3)) U = orth(np.random.rand(3, 3)) conv = np.dot(np.dot(np.transpose(U), D), U) - img = img + np.random.multivariate_normal( - [0, 0, 0], np.abs(L**2 * conv), img.shape[:2] - ).astype(np.float32) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32) img = np.clip(img, 0.0, 1.0) return img @@ -413,21 +405,15 @@ def add_speckle_noise(img, noise_level1=2, noise_level2=25): img = np.clip(img, 0.0, 1.0) rnum = random.random() if rnum > 0.6: - img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype( - np.float32 - ) + img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) elif rnum < 0.4: - img += img * np.random.normal( - 0, noise_level / 255.0, (*img.shape[:2], 1) - ).astype(np.float32) + img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) else: L = noise_level2 / 255.0 D = np.diag(np.random.rand(3)) U = orth(np.random.rand(3, 3)) conv = np.dot(np.dot(np.transpose(U), D), U) - img += img * np.random.multivariate_normal( - [0, 0, 0], np.abs(L**2 * conv), img.shape[:2] - ).astype(np.float32) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32) img = np.clip(img, 0.0, 1.0) return img @@ -440,9 +426,7 @@ def add_Poisson_noise(img): else: img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0 - noise_gray = ( - np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray - ) + noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray img += noise_gray[:, :, np.newaxis] img = np.clip(img, 0.0, 1.0) return img @@ -451,9 +435,7 @@ def add_Poisson_noise(img): def add_JPEG_noise(img): quality_factor = random.randint(30, 95) img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) - result, encimg = cv2.imencode( - ".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor] - ) + result, encimg = cv2.imencode(".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) img = cv2.imdecode(encimg, 1) img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) return img @@ -540,9 +522,7 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf)) k_shifted = shift_pixel(k, sf) k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel - img = ndimage.filters.convolve( - img, np.expand_dims(k_shifted, axis=2), mode="mirror" - ) + img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode="mirror") img = img[0::sf, 0::sf, ...] # nearest downsampling img = np.clip(img, 0.0, 1.0) @@ -646,9 +626,7 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf)) k_shifted = shift_pixel(k, sf) k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel - image = ndimage.filters.convolve( - image, np.expand_dims(k_shifted, axis=2), mode="mirror" - ) + image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode="mirror") image = image[0::sf, 0::sf, ...] # nearest downsampling image = np.clip(image, 0.0, 1.0) @@ -796,9 +774,7 @@ if __name__ == "__main__": print(i) img_lq = deg_fn(img) print(img_lq) - img_lq_bicubic = albumentations.SmallestMaxSize( - max_size=h, interpolation=cv2.INTER_CUBIC - )(image=img)["image"] + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"] print(img_lq.shape) print("bicubic", img_lq_bicubic.shape) print(img_hq.shape) @@ -812,7 +788,5 @@ if __name__ == "__main__": (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), interpolation=0, ) - img_concat = np.concatenate( - [lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1 - ) + img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) util.imsave(img_concat, str(i) + ".png") diff --git a/invokeai/backend/stable_diffusion/image_degradation/bsrgan_light.py b/invokeai/backend/stable_diffusion/image_degradation/bsrgan_light.py index 1e8eee82b5..d0e0abadbc 100644 --- a/invokeai/backend/stable_diffusion/image_degradation/bsrgan_light.py +++ b/invokeai/backend/stable_diffusion/image_degradation/bsrgan_light.py @@ -261,9 +261,7 @@ def srmd_degradation(x, k, sf=3): year={2018} } """ - x = ndimage.filters.convolve( - x, np.expand_dims(k, axis=2), mode="wrap" - ) # 'nearest' | 'mirror' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode="wrap") # 'nearest' | 'mirror' x = bicubic_degradation(x, sf=sf) return x @@ -393,21 +391,15 @@ def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): noise_level = random.randint(noise_level1, noise_level2) rnum = np.random.rand() if rnum > 0.6: # add color Gaussian noise - img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype( - np.float32 - ) + img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) elif rnum < 0.4: # add grayscale Gaussian noise - img = img + np.random.normal( - 0, noise_level / 255.0, (*img.shape[:2], 1) - ).astype(np.float32) + img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) else: # add noise L = noise_level2 / 255.0 D = np.diag(np.random.rand(3)) U = orth(np.random.rand(3, 3)) conv = np.dot(np.dot(np.transpose(U), D), U) - img = img + np.random.multivariate_normal( - [0, 0, 0], np.abs(L**2 * conv), img.shape[:2] - ).astype(np.float32) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32) img = np.clip(img, 0.0, 1.0) return img @@ -417,21 +409,15 @@ def add_speckle_noise(img, noise_level1=2, noise_level2=25): img = np.clip(img, 0.0, 1.0) rnum = random.random() if rnum > 0.6: - img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype( - np.float32 - ) + img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) elif rnum < 0.4: - img += img * np.random.normal( - 0, noise_level / 255.0, (*img.shape[:2], 1) - ).astype(np.float32) + img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) else: L = noise_level2 / 255.0 D = np.diag(np.random.rand(3)) U = orth(np.random.rand(3, 3)) conv = np.dot(np.dot(np.transpose(U), D), U) - img += img * np.random.multivariate_normal( - [0, 0, 0], np.abs(L**2 * conv), img.shape[:2] - ).astype(np.float32) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L**2 * conv), img.shape[:2]).astype(np.float32) img = np.clip(img, 0.0, 1.0) return img @@ -444,9 +430,7 @@ def add_Poisson_noise(img): else: img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0 - noise_gray = ( - np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray - ) + noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray img += noise_gray[:, :, np.newaxis] img = np.clip(img, 0.0, 1.0) return img @@ -455,9 +439,7 @@ def add_Poisson_noise(img): def add_JPEG_noise(img): quality_factor = random.randint(80, 95) img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) - result, encimg = cv2.imencode( - ".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor] - ) + result, encimg = cv2.imencode(".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) img = cv2.imdecode(encimg, 1) img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) return img @@ -544,9 +526,7 @@ def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf)) k_shifted = shift_pixel(k, sf) k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel - img = ndimage.filters.convolve( - img, np.expand_dims(k_shifted, axis=2), mode="mirror" - ) + img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode="mirror") img = img[0::sf, 0::sf, ...] # nearest downsampling img = np.clip(img, 0.0, 1.0) @@ -653,9 +633,7 @@ def degradation_bsrgan_variant(image, sf=4, isp_model=None): k = fspecial("gaussian", 25, random.uniform(0.1, 0.6 * sf)) k_shifted = shift_pixel(k, sf) k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel - image = ndimage.filters.convolve( - image, np.expand_dims(k_shifted, axis=2), mode="mirror" - ) + image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode="mirror") image = image[0::sf, 0::sf, ...] # nearest downsampling image = np.clip(image, 0.0, 1.0) @@ -705,9 +683,9 @@ if __name__ == "__main__": img_lq = deg_fn(img)["image"] img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq) print(img_lq) - img_lq_bicubic = albumentations.SmallestMaxSize( - max_size=h, interpolation=cv2.INTER_CUBIC - )(image=img_hq)["image"] + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)[ + "image" + ] print(img_lq.shape) print("bicubic", img_lq_bicubic.shape) print(img_hq.shape) @@ -721,7 +699,5 @@ if __name__ == "__main__": (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), interpolation=0, ) - img_concat = np.concatenate( - [lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1 - ) + img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) util.imsave(img_concat, str(i) + ".png") diff --git a/invokeai/backend/stable_diffusion/image_degradation/utils_image.py b/invokeai/backend/stable_diffusion/image_degradation/utils_image.py index c4d37a24bf..d45ca602e6 100644 --- a/invokeai/backend/stable_diffusion/image_degradation/utils_image.py +++ b/invokeai/backend/stable_diffusion/image_degradation/utils_image.py @@ -11,6 +11,7 @@ from torchvision.utils import make_grid # import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py import invokeai.backend.util.logging as logger + os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" @@ -296,22 +297,14 @@ def single2uint16(img): def uint2tensor4(img): if img.ndim == 2: img = np.expand_dims(img, axis=2) - return ( - torch.from_numpy(np.ascontiguousarray(img)) - .permute(2, 0, 1) - .float() - .div(255.0) - .unsqueeze(0) - ) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.0).unsqueeze(0) # convert uint to 3-dimensional torch tensor def uint2tensor3(img): if img.ndim == 2: img = np.expand_dims(img, axis=2) - return ( - torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.0) - ) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.0) # convert 2/3/4-dimensional torch tensor to uint @@ -334,12 +327,7 @@ def single2tensor3(img): # convert single (HxWxC) to 4-dimensional torch tensor def single2tensor4(img): - return ( - torch.from_numpy(np.ascontiguousarray(img)) - .permute(2, 0, 1) - .float() - .unsqueeze(0) - ) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0) # convert torch tensor to single @@ -362,12 +350,7 @@ def tensor2single3(img): def single2tensor5(img): - return ( - torch.from_numpy(np.ascontiguousarray(img)) - .permute(2, 0, 1, 3) - .float() - .unsqueeze(0) - ) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0) def single32tensor5(img): @@ -385,9 +368,7 @@ def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) """ - tensor = ( - tensor.squeeze().float().cpu().clamp_(*min_max) - ) # squeeze first, then clamp + tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] n_dim = tensor.dim() if n_dim == 4: @@ -400,11 +381,7 @@ def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): elif n_dim == 2: img_np = tensor.numpy() else: - raise TypeError( - "Only support 4D, 3D and 2D tensor. But received with dimension: {:d}".format( - n_dim - ) - ) + raise TypeError("Only support 4D, 3D and 2D tensor. But received with dimension: {:d}".format(n_dim)) if out_type == np.uint8: img_np = (img_np * 255.0).round() # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. @@ -744,9 +721,7 @@ def ssim(img1, img2): sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 - ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ( - (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) - ) + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) return ssim_map.mean() @@ -767,9 +742,7 @@ def cubic(x): ) * (((absx > 1) * (absx <= 2)).type_as(absx)) -def calculate_weights_indices( - in_length, out_length, scale, kernel, kernel_width, antialiasing -): +def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): if (scale < 1) and (antialiasing): # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width kernel_width = kernel_width / scale @@ -793,9 +766,9 @@ def calculate_weights_indices( # The indices of the input pixels involved in computing the k-th output # pixel are in row k of the indices matrix. - indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace( - 0, P - 1, P - ).view(1, P).expand(out_length, P) + indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(1, P).expand( + out_length, P + ) # The weights used to compute the k-th output pixel are in row k of the # weights matrix. @@ -876,9 +849,7 @@ def imresize(img, scale, antialiasing=True): for i in range(out_H): idx = int(indices_H[i][0]) for j in range(out_C): - out_1[j, i, :] = ( - img_aug[j, idx : idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) - ) + out_1[j, i, :] = img_aug[j, idx : idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) # process W dimension # symmetric copying @@ -959,9 +930,7 @@ def imresize_np(img, scale, antialiasing=True): for i in range(out_H): idx = int(indices_H[i][0]) for j in range(out_C): - out_1[i, :, j] = ( - img_aug[idx : idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i]) - ) + out_1[i, :, j] = img_aug[idx : idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i]) # process W dimension # symmetric copying diff --git a/invokeai/backend/stable_diffusion/offloading.py b/invokeai/backend/stable_diffusion/offloading.py index d36b65872a..aa2426d514 100644 --- a/invokeai/backend/stable_diffusion/offloading.py +++ b/invokeai/backend/stable_diffusion/offloading.py @@ -95,10 +95,7 @@ class ModelGroup(metaclass=ABCMeta): pass def __repr__(self) -> str: - return ( - f"<{self.__class__.__name__} object at {id(self):x}: " - f"device={self.execution_device} >" - ) + return f"<{self.__class__.__name__} object at {id(self):x}: " f"device={self.execution_device} >" class LazilyLoadedModelGroup(ModelGroup): @@ -143,8 +140,7 @@ class LazilyLoadedModelGroup(ModelGroup): self.load(module) if len(forward_input) == 0: warnings.warn( - f"Hook for {module.__class__.__name__} got no input. " - f"Inputs must be positional, not keywords.", + f"Hook for {module.__class__.__name__} got no input. " f"Inputs must be positional, not keywords.", stacklevel=3, ) return send_to_device(forward_input, self.execution_device) @@ -161,9 +157,7 @@ class LazilyLoadedModelGroup(ModelGroup): self.clear_current_model() def _load(self, module: torch.nn.Module) -> torch.nn.Module: - assert ( - self.is_empty() - ), f"A model is already loaded: {self._current_model_ref()}" + assert self.is_empty(), f"A model is already loaded: {self._current_model_ref()}" module = module.to(self.execution_device) self.set_current_model(module) return module @@ -192,12 +186,8 @@ class LazilyLoadedModelGroup(ModelGroup): def device_for(self, model): if model not in self: - raise KeyError( - f"This does not manage this model {type(model).__name__}", model - ) - return ( - self.execution_device - ) # this implementation only dispatches to one device + raise KeyError(f"This does not manage this model {type(model).__name__}", model) + return self.execution_device # this implementation only dispatches to one device def ready(self): pass # always ready to load on-demand @@ -256,12 +246,8 @@ class FullyLoadedModelGroup(ModelGroup): def device_for(self, model): if model not in self: - raise KeyError( - "This does not manage this model f{type(model).__name__}", model - ) - return ( - self.execution_device - ) # this implementation only dispatches to one device + raise KeyError("This does not manage this model f{type(model).__name__}", model) + return self.execution_device # this implementation only dispatches to one device def __contains__(self, model): return model in self._models diff --git a/invokeai/backend/stable_diffusion/schedulers/__init__.py b/invokeai/backend/stable_diffusion/schedulers/__init__.py index b2df0df231..29a96eb3a5 100644 --- a/invokeai/backend/stable_diffusion/schedulers/__init__.py +++ b/invokeai/backend/stable_diffusion/schedulers/__init__.py @@ -1 +1 @@ -from .schedulers import SCHEDULER_MAP \ No newline at end of file +from .schedulers import SCHEDULER_MAP diff --git a/invokeai/backend/stable_diffusion/schedulers/schedulers.py b/invokeai/backend/stable_diffusion/schedulers/schedulers.py index 77c45d5eb8..2f62f8c477 100644 --- a/invokeai/backend/stable_diffusion/schedulers/schedulers.py +++ b/invokeai/backend/stable_diffusion/schedulers/schedulers.py @@ -1,7 +1,19 @@ -from diffusers import DDIMScheduler, DPMSolverMultistepScheduler, KDPM2DiscreteScheduler, \ - KDPM2AncestralDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, \ - HeunDiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, UniPCMultistepScheduler, \ - DPMSolverSinglestepScheduler, DEISMultistepScheduler, DDPMScheduler, DPMSolverSDEScheduler +from diffusers import ( + DDIMScheduler, + DPMSolverMultistepScheduler, + KDPM2DiscreteScheduler, + KDPM2AncestralDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + HeunDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + UniPCMultistepScheduler, + DPMSolverSinglestepScheduler, + DEISMultistepScheduler, + DDPMScheduler, + DPMSolverSDEScheduler, +) SCHEDULER_MAP = dict( ddim=(DDIMScheduler, dict()), @@ -21,9 +33,9 @@ SCHEDULER_MAP = dict( dpmpp_2s_k=(DPMSolverSinglestepScheduler, dict(use_karras_sigmas=True)), dpmpp_2m=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False)), dpmpp_2m_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True)), - dpmpp_2m_sde=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False, algorithm_type='sde-dpmsolver++')), - dpmpp_2m_sde_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True, algorithm_type='sde-dpmsolver++')), + dpmpp_2m_sde=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False, algorithm_type="sde-dpmsolver++")), + dpmpp_2m_sde_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True, algorithm_type="sde-dpmsolver++")), dpmpp_sde=(DPMSolverSDEScheduler, dict(use_karras_sigmas=False, noise_sampler_seed=0)), dpmpp_sde_k=(DPMSolverSDEScheduler, dict(use_karras_sigmas=True, noise_sampler_seed=0)), - unipc=(UniPCMultistepScheduler, dict(cpu_only=True)) + unipc=(UniPCMultistepScheduler, dict(cpu_only=True)), ) diff --git a/invokeai/backend/training/textual_inversion_training.py b/invokeai/backend/training/textual_inversion_training.py index aadf8c76a3..d92aa80b38 100644 --- a/invokeai/backend/training/textual_inversion_training.py +++ b/invokeai/backend/training/textual_inversion_training.py @@ -45,7 +45,7 @@ from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer # invokeai stuff -from invokeai.app.services.config import InvokeAIAppConfig,PagingArgumentParser +from invokeai.app.services.config import InvokeAIAppConfig, PagingArgumentParser from invokeai.app.services.model_manager_service import ModelManagerService from invokeai.backend.model_management.models import SubModelType @@ -75,24 +75,16 @@ check_min_version("0.10.0.dev0") logger = get_logger(__name__) -def save_progress( - text_encoder, placeholder_token_id, accelerator, placeholder_token, save_path -): +def save_progress(text_encoder, placeholder_token_id, accelerator, placeholder_token, save_path): logger.info("Saving embeddings") - learned_embeds = ( - accelerator.unwrap_model(text_encoder) - .get_input_embeddings() - .weight[placeholder_token_id] - ) + learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id] learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()} torch.save(learned_embeds_dict, save_path) def parse_args(): config = InvokeAIAppConfig.get_config() - parser = PagingArgumentParser( - description="Textual inversion training" - ) + parser = PagingArgumentParser(description="Textual inversion training") general_group = parser.add_argument_group("General") model_group = parser.add_argument_group("Models and Paths") image_group = parser.add_argument_group("Training Image Location and Options") @@ -221,9 +213,7 @@ def parse_args(): default=100, help="How many times to repeat the training data.", ) - training_group.add_argument( - "--seed", type=int, default=None, help="A seed for reproducible training." - ) + training_group.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") training_group.add_argument( "--train_batch_size", type=int, @@ -287,9 +277,7 @@ def parse_args(): default=0.999, help="The beta2 parameter for the Adam optimizer.", ) - training_group.add_argument( - "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use." - ) + training_group.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") training_group.add_argument( "--adam_epsilon", type=float, @@ -442,9 +430,7 @@ class TextualInversionDataset(Dataset): self.data_root / file_path for file_path in self.data_root.iterdir() if file_path.is_file() - and file_path.name.endswith( - (".png", ".PNG", ".jpg", ".JPG", ".jpeg", ".JPEG", ".gif", ".GIF") - ) + and file_path.name.endswith((".png", ".PNG", ".jpg", ".JPG", ".jpeg", ".JPEG", ".gif", ".GIF")) ] self.num_images = len(self.image_paths) @@ -460,11 +446,7 @@ class TextualInversionDataset(Dataset): "lanczos": PIL_INTERPOLATION["lanczos"], }[interpolation] - self.templates = ( - imagenet_style_templates_small - if learnable_property == "style" - else imagenet_templates_small - ) + self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p) def __len__(self): @@ -500,9 +482,7 @@ class TextualInversionDataset(Dataset): img.shape[0], img.shape[1], ) - img = img[ - (h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2 - ] + img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2] image = Image.fromarray(img) image = image.resize((self.size, self.size), resample=self.interpolation) @@ -515,9 +495,7 @@ class TextualInversionDataset(Dataset): return example -def get_full_repo_name( - model_id: str, organization: Optional[str] = None, token: Optional[str] = None -): +def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): if token is None: token = HfFolder.get_token() if organization is None: @@ -570,9 +548,7 @@ def do_textual_inversion_training( **kwargs, ): assert model, "Please specify a base model with --model" - assert ( - train_data_dir - ), "Please specify a directory containing the training images using --train_data_dir" + assert train_data_dir, "Please specify a directory containing the training images using --train_data_dir" assert placeholder_token, "Please specify a trigger term using --placeholder_token" env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != local_rank: @@ -593,7 +569,7 @@ def do_textual_inversion_training( project_config=accelerator_config, ) - model_manager = ModelManagerService(config,logger) + model_manager = ModelManagerService(config, logger) # Make one log on every process with the configuration for debugging. logging.basicConfig( @@ -633,13 +609,11 @@ def do_textual_inversion_training( os.makedirs(output_dir, exist_ok=True) known_models = model_manager.model_names() - model_name = model.split('/')[-1] + model_name = model.split("/")[-1] model_meta = next((mm for mm in known_models if mm[0].endswith(model_name)), None) assert model_meta is not None, f"Unknown model: {model}" model_info = model_manager.model_info(*model_meta) - assert ( - model_info['model_format'] == "diffusers" - ), "This script only works with models of type 'diffusers'" + assert model_info["model_format"] == "diffusers", "This script only works with models of type 'diffusers'" tokenizer_info = model_manager.get_model(*model_meta, submodel=SubModelType.Tokenizer) noise_scheduler_info = model_manager.get_model(*model_meta, submodel=SubModelType.Scheduler) text_encoder_info = model_manager.get_model(*model_meta, submodel=SubModelType.TextEncoder) @@ -650,9 +624,7 @@ def do_textual_inversion_training( if tokenizer_name: tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name, **pipeline_args) else: - tokenizer = CLIPTokenizer.from_pretrained( - tokenizer_info.location, subfolder='tokenizer', **pipeline_args - ) + tokenizer = CLIPTokenizer.from_pretrained(tokenizer_info.location, subfolder="tokenizer", **pipeline_args) # Load scheduler and models noise_scheduler = DDPMScheduler.from_pretrained( @@ -722,9 +694,7 @@ def do_textual_inversion_training( if is_xformers_available(): unet.enable_xformers_memory_efficient_attention() else: - raise ValueError( - "xformers is not available. Make sure it is installed correctly" - ) + raise ValueError("xformers is not available. Make sure it is installed correctly") # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices @@ -732,12 +702,7 @@ def do_textual_inversion_training( torch.backends.cuda.matmul.allow_tf32 = True if scale_lr: - learning_rate = ( - learning_rate - * gradient_accumulation_steps - * train_batch_size - * accelerator.num_processes - ) + learning_rate = learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes # Initialize the optimizer optimizer = torch.optim.AdamW( @@ -759,15 +724,11 @@ def do_textual_inversion_training( center_crop=center_crop, set="train", ) - train_dataloader = torch.utils.data.DataLoader( - train_dataset, batch_size=train_batch_size, shuffle=True - ) + train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True) # Scheduler and math around the number of training steps. overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil( - len(train_dataloader) / gradient_accumulation_steps - ) + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) if max_train_steps is None: max_train_steps = num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True @@ -797,9 +758,7 @@ def do_textual_inversion_training( vae.to(accelerator.device, dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil( - len(train_dataloader) / gradient_accumulation_steps - ) + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) if overrode_max_train_steps: max_train_steps = num_train_epochs * num_update_steps_per_epoch # Afterwards we recalculate our number of training epochs @@ -814,17 +773,13 @@ def do_textual_inversion_training( accelerator.init_trackers("textual_inversion", config=params) # Train! - total_batch_size = ( - train_batch_size * accelerator.num_processes * gradient_accumulation_steps - ) + total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num Epochs = {num_train_epochs}") logger.info(f" Instantaneous batch size per device = {train_batch_size}") - logger.info( - f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" - ) + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}") logger.info(f" Total optimization steps = {max_train_steps}") global_step = 0 @@ -843,9 +798,7 @@ def do_textual_inversion_training( path = dirs[-1] if len(dirs) > 0 else None if path is None: - accelerator.print( - f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run." - ) + accelerator.print(f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run.") resume_from_checkpoint = None else: accelerator.print(f"Resuming from checkpoint {path}") @@ -854,9 +807,7 @@ def do_textual_inversion_training( resume_global_step = global_step * gradient_accumulation_steps first_epoch = global_step // num_update_steps_per_epoch - resume_step = resume_global_step % ( - num_update_steps_per_epoch * gradient_accumulation_steps - ) + resume_step = resume_global_step % (num_update_steps_per_epoch * gradient_accumulation_steps) # Only show the progress bar once on each machine. progress_bar = tqdm( @@ -866,33 +817,20 @@ def do_textual_inversion_training( progress_bar.set_description("Steps") # keep original embeddings as reference - orig_embeds_params = ( - accelerator.unwrap_model(text_encoder) - .get_input_embeddings() - .weight.data.clone() - ) + orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone() for epoch in range(first_epoch, num_train_epochs): text_encoder.train() for step, batch in enumerate(train_dataloader): # Skip steps until we reach the resumed step - if ( - resume_step - and resume_from_checkpoint - and epoch == first_epoch - and step < resume_step - ): + if resume_step and resume_from_checkpoint and epoch == first_epoch and step < resume_step: if step % gradient_accumulation_steps == 0: progress_bar.update(1) continue with accelerator.accumulate(text_encoder): # Convert images to latent space - latents = ( - vae.encode(batch["pixel_values"].to(dtype=weight_dtype)) - .latent_dist.sample() - .detach() - ) + latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach() latents = latents * 0.18215 # Sample noise that we'll add to the latents @@ -912,14 +850,10 @@ def do_textual_inversion_training( noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Get the text embedding for conditioning - encoder_hidden_states = text_encoder(batch["input_ids"])[0].to( - dtype=weight_dtype - ) + encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype) # Predict the noise residual - model_pred = unet( - noisy_latents, timesteps, encoder_hidden_states - ).sample + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": @@ -927,9 +861,7 @@ def do_textual_inversion_training( elif noise_scheduler.config.prediction_type == "v_prediction": target = noise_scheduler.get_velocity(latents, noise, timesteps) else: - raise ValueError( - f"Unknown prediction type {noise_scheduler.config.prediction_type}" - ) + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") @@ -942,22 +874,16 @@ def do_textual_inversion_training( # Let's make sure we don't update any embedding weights besides the newly added token index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id with torch.no_grad(): - accelerator.unwrap_model( - text_encoder - ).get_input_embeddings().weight[ + accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ index_no_updates - ] = orig_embeds_params[ - index_no_updates - ] + ] = orig_embeds_params[index_no_updates] # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 if global_step % save_steps == 0: - save_path = os.path.join( - output_dir, f"learned_embeds-steps-{global_step}.bin" - ) + save_path = os.path.join(output_dir, f"learned_embeds-steps-{global_step}.bin") save_progress( text_encoder, placeholder_token_id, @@ -968,9 +894,7 @@ def do_textual_inversion_training( if global_step % checkpointing_steps == 0: if accelerator.is_main_process: - save_path = os.path.join( - output_dir, f"checkpoint-{global_step}" - ) + save_path = os.path.join(output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") @@ -985,9 +909,7 @@ def do_textual_inversion_training( accelerator.wait_for_everyone() if accelerator.is_main_process: if push_to_hub and only_save_embeds: - logger.warn( - "Enabling full model saving because --push_to_hub=True was specified." - ) + logger.warn("Enabling full model saving because --push_to_hub=True was specified.") save_full_model = True else: save_full_model = not only_save_embeds @@ -1012,8 +934,6 @@ def do_textual_inversion_training( ) if push_to_hub: - repo.push_to_hub( - commit_message="End of training", blocking=False, auto_lfs_prune=True - ) + repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) accelerator.end_training() diff --git a/invokeai/backend/util/__init__.py b/invokeai/backend/util/__init__.py index fadeff4d75..2e69af5382 100644 --- a/invokeai/backend/util/__init__.py +++ b/invokeai/backend/util/__init__.py @@ -11,12 +11,4 @@ from .devices import ( torch_dtype, ) from .log import write_log -from .util import ( - ask_user, - download_with_resume, - instantiate_from_config, - url_attachment_name, - Chdir -) - - +from .util import ask_user, download_with_resume, instantiate_from_config, url_attachment_name, Chdir diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py index 3fbdaba41a..eeabcc35db 100644 --- a/invokeai/backend/util/devices.py +++ b/invokeai/backend/util/devices.py @@ -12,6 +12,7 @@ CUDA_DEVICE = torch.device("cuda") MPS_DEVICE = torch.device("mps") config = InvokeAIAppConfig.get_config() + def choose_torch_device() -> torch.device: """Convenience routine for guessing which GPU device to run model on""" if config.always_use_cpu: diff --git a/invokeai/backend/util/hotfixes.py b/invokeai/backend/util/hotfixes.py index 67e46a426f..4710682ac1 100644 --- a/invokeai/backend/util/hotfixes.py +++ b/invokeai/backend/util/hotfixes.py @@ -20,6 +20,7 @@ from diffusers.models.controlnet import ControlNetConditioningEmbedding, Control # Modified ControlNetModel with encoder_attention_mask argument added + class ControlNetModel(ModelMixin, ConfigMixin): """ A ControlNet model. @@ -618,9 +619,7 @@ class ControlNetModel(ModelMixin, ConfigMixin): mid_block_res_sample = mid_block_res_sample * conditioning_scale if self.config.global_pool_conditions: - down_block_res_samples = [ - torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples - ] + down_block_res_samples = [torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples] mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) if not return_dict: @@ -630,5 +629,6 @@ class ControlNetModel(ModelMixin, ConfigMixin): down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample ) + diffusers.ControlNetModel = ControlNetModel -diffusers.models.controlnet.ControlNetModel = ControlNetModel \ No newline at end of file +diffusers.models.controlnet.ControlNetModel = ControlNetModel diff --git a/invokeai/backend/util/logging.py b/invokeai/backend/util/logging.py index d06c036506..3a8d721aa5 100644 --- a/invokeai/backend/util/logging.py +++ b/invokeai/backend/util/logging.py @@ -186,89 +186,109 @@ from invokeai.app.services.config import InvokeAIAppConfig, get_invokeai_config try: import syslog + SYSLOG_AVAILABLE = True except: SYSLOG_AVAILABLE = False + # module level functions def debug(msg, *args, **kwargs): InvokeAILogger.getLogger().debug(msg, *args, **kwargs) + def info(msg, *args, **kwargs): InvokeAILogger.getLogger().info(msg, *args, **kwargs) + def warning(msg, *args, **kwargs): InvokeAILogger.getLogger().warning(msg, *args, **kwargs) + def error(msg, *args, **kwargs): InvokeAILogger.getLogger().error(msg, *args, **kwargs) + def critical(msg, *args, **kwargs): InvokeAILogger.getLogger().critical(msg, *args, **kwargs) + def log(level, msg, *args, **kwargs): InvokeAILogger.getLogger().log(level, msg, *args, **kwargs) + def disable(level=logging.CRITICAL): InvokeAILogger.getLogger().disable(level) + def basicConfig(**kwargs): InvokeAILogger.getLogger().basicConfig(**kwargs) + def getLogger(name: str = None) -> logging.Logger: return InvokeAILogger.getLogger(name) -_FACILITY_MAP = dict( - LOG_KERN = syslog.LOG_KERN, - LOG_USER = syslog.LOG_USER, - LOG_MAIL = syslog.LOG_MAIL, - LOG_DAEMON = syslog.LOG_DAEMON, - LOG_AUTH = syslog.LOG_AUTH, - LOG_LPR = syslog.LOG_LPR, - LOG_NEWS = syslog.LOG_NEWS, - LOG_UUCP = syslog.LOG_UUCP, - LOG_CRON = syslog.LOG_CRON, - LOG_SYSLOG = syslog.LOG_SYSLOG, - LOG_LOCAL0 = syslog.LOG_LOCAL0, - LOG_LOCAL1 = syslog.LOG_LOCAL1, - LOG_LOCAL2 = syslog.LOG_LOCAL2, - LOG_LOCAL3 = syslog.LOG_LOCAL3, - LOG_LOCAL4 = syslog.LOG_LOCAL4, - LOG_LOCAL5 = syslog.LOG_LOCAL5, - LOG_LOCAL6 = syslog.LOG_LOCAL6, - LOG_LOCAL7 = syslog.LOG_LOCAL7, -) if SYSLOG_AVAILABLE else dict() - -_SOCK_MAP = dict( - SOCK_STREAM = socket.SOCK_STREAM, - SOCK_DGRAM = socket.SOCK_DGRAM, +_FACILITY_MAP = ( + dict( + LOG_KERN=syslog.LOG_KERN, + LOG_USER=syslog.LOG_USER, + LOG_MAIL=syslog.LOG_MAIL, + LOG_DAEMON=syslog.LOG_DAEMON, + LOG_AUTH=syslog.LOG_AUTH, + LOG_LPR=syslog.LOG_LPR, + LOG_NEWS=syslog.LOG_NEWS, + LOG_UUCP=syslog.LOG_UUCP, + LOG_CRON=syslog.LOG_CRON, + LOG_SYSLOG=syslog.LOG_SYSLOG, + LOG_LOCAL0=syslog.LOG_LOCAL0, + LOG_LOCAL1=syslog.LOG_LOCAL1, + LOG_LOCAL2=syslog.LOG_LOCAL2, + LOG_LOCAL3=syslog.LOG_LOCAL3, + LOG_LOCAL4=syslog.LOG_LOCAL4, + LOG_LOCAL5=syslog.LOG_LOCAL5, + LOG_LOCAL6=syslog.LOG_LOCAL6, + LOG_LOCAL7=syslog.LOG_LOCAL7, + ) + if SYSLOG_AVAILABLE + else dict() ) +_SOCK_MAP = dict( + SOCK_STREAM=socket.SOCK_STREAM, + SOCK_DGRAM=socket.SOCK_DGRAM, +) + + class InvokeAIFormatter(logging.Formatter): - ''' + """ Base class for logging formatter - ''' + """ + def format(self, record): formatter = logging.Formatter(self.log_fmt(record.levelno)) return formatter.format(record) @abstractmethod - def log_fmt(self, levelno: int)->str: + def log_fmt(self, levelno: int) -> str: pass - + + class InvokeAISyslogFormatter(InvokeAIFormatter): - ''' + """ Formatting for syslog - ''' - def log_fmt(self, levelno: int)->str: - return '%(name)s [%(process)d] <%(levelname)s> %(message)s' + """ + + def log_fmt(self, levelno: int) -> str: + return "%(name)s [%(process)d] <%(levelname)s> %(message)s" + class InvokeAILegacyLogFormatter(InvokeAIFormatter): - ''' + """ Formatting for the InvokeAI Logger (legacy version) - ''' + """ + FORMATS = { logging.DEBUG: " | %(message)s", logging.INFO: ">> %(message)s", @@ -276,20 +296,25 @@ class InvokeAILegacyLogFormatter(InvokeAIFormatter): logging.ERROR: "*** %(message)s", logging.CRITICAL: "### %(message)s", } - def log_fmt(self,levelno:int)->str: + + def log_fmt(self, levelno: int) -> str: return self.FORMATS.get(levelno) + class InvokeAIPlainLogFormatter(InvokeAIFormatter): - ''' + """ Custom Formatting for the InvokeAI Logger (plain version) - ''' - def log_fmt(self, levelno: int)->str: + """ + + def log_fmt(self, levelno: int) -> str: return "[%(asctime)s]::[%(name)s]::%(levelname)s --> %(message)s" + class InvokeAIColorLogFormatter(InvokeAIFormatter): - ''' + """ Custom Formatting for the InvokeAI Logger - ''' + """ + # Color Codes grey = "\x1b[38;20m" yellow = "\x1b[33;20m" @@ -308,32 +333,34 @@ class InvokeAIColorLogFormatter(InvokeAIFormatter): logging.INFO: grey + log_format + reset, logging.WARNING: yellow + log_format + reset, logging.ERROR: red + log_format + reset, - logging.CRITICAL: bold_red + log_format + reset + logging.CRITICAL: bold_red + log_format + reset, } - def log_fmt(self, levelno: int)->str: + def log_fmt(self, levelno: int) -> str: return self.FORMATS.get(levelno) + LOG_FORMATTERS = { - 'plain': InvokeAIPlainLogFormatter, - 'color': InvokeAIColorLogFormatter, - 'syslog': InvokeAISyslogFormatter, - 'legacy': InvokeAILegacyLogFormatter, + "plain": InvokeAIPlainLogFormatter, + "color": InvokeAIColorLogFormatter, + "syslog": InvokeAISyslogFormatter, + "legacy": InvokeAILegacyLogFormatter, } + class InvokeAILogger(object): loggers = dict() @classmethod - def getLogger(cls, - name: str = 'InvokeAI', - config: InvokeAIAppConfig=InvokeAIAppConfig.get_config())->logging.Logger: + def getLogger( + cls, name: str = "InvokeAI", config: InvokeAIAppConfig = InvokeAIAppConfig.get_config() + ) -> logging.Logger: if name in cls.loggers: logger = cls.loggers[name] logger.handlers.clear() else: logger = logging.getLogger(name) - logger.setLevel(config.log_level.upper()) # yes, strings work here + logger.setLevel(config.log_level.upper()) # yes, strings work here for ch in cls.getLoggers(config): logger.addHandler(ch) cls.loggers[name] = logger @@ -344,82 +371,80 @@ class InvokeAILogger(object): handler_strs = config.log_handlers handlers = list() for handler in handler_strs: - handler_name,*args = handler.split('=',2) + handler_name, *args = handler.split("=", 2) args = args[0] if len(args) > 0 else None # console and file get the fancy formatter. # syslog gets a simple one # http gets no custom formatter formatter = LOG_FORMATTERS[config.log_format] - if handler_name=='console': + if handler_name == "console": ch = logging.StreamHandler() ch.setFormatter(formatter()) handlers.append(ch) - - elif handler_name=='syslog': + + elif handler_name == "syslog": ch = cls._parse_syslog_args(args) handlers.append(ch) - - elif handler_name=='file': + + elif handler_name == "file": ch = cls._parse_file_args(args) ch.setFormatter(formatter()) handlers.append(ch) - - elif handler_name=='http': + + elif handler_name == "http": ch = cls._parse_http_args(args) handlers.append(ch) return handlers @staticmethod - def _parse_syslog_args( - args: str=None - )-> logging.Handler: + def _parse_syslog_args(args: str = None) -> logging.Handler: if not SYSLOG_AVAILABLE: raise ValueError("syslog is not available on this system") if not args: - args='/dev/log' if Path('/dev/log').exists() else 'address:localhost:514' + args = "/dev/log" if Path("/dev/log").exists() else "address:localhost:514" syslog_args = dict() try: - for a in args.split(','): - arg_name,*arg_value = a.split(':',2) - if arg_name=='address': - host,*port = arg_value - port = 514 if len(port)==0 else int(port[0]) - syslog_args['address'] = (host,port) - elif arg_name=='facility': - syslog_args['facility'] = _FACILITY_MAP[arg_value[0]] - elif arg_name=='socktype': - syslog_args['socktype'] = _SOCK_MAP[arg_value[0]] + for a in args.split(","): + arg_name, *arg_value = a.split(":", 2) + if arg_name == "address": + host, *port = arg_value + port = 514 if len(port) == 0 else int(port[0]) + syslog_args["address"] = (host, port) + elif arg_name == "facility": + syslog_args["facility"] = _FACILITY_MAP[arg_value[0]] + elif arg_name == "socktype": + syslog_args["socktype"] = _SOCK_MAP[arg_value[0]] else: - syslog_args['address'] = arg_name + syslog_args["address"] = arg_name except: raise ValueError(f"{args} is not a value argument list for syslog logging") return logging.handlers.SysLogHandler(**syslog_args) - + @staticmethod - def _parse_file_args(args: str=None)-> logging.Handler: + def _parse_file_args(args: str = None) -> logging.Handler: if not args: raise ValueError("please provide filename for file logging using format 'file=/path/to/logfile.txt'") return logging.FileHandler(args) @staticmethod - def _parse_http_args(args: str=None)-> logging.Handler: + def _parse_http_args(args: str = None) -> logging.Handler: if not args: raise ValueError("please provide destination for http logging using format 'http=url'") - arg_list = args.split(',') + arg_list = args.split(",") url = urllib.parse.urlparse(arg_list.pop(0)) - if url.scheme != 'http': + if url.scheme != "http": raise ValueError(f"the http logging module can only log to HTTP URLs, but {url.scheme} was specified") host = url.hostname path = url.path port = url.port or 80 - + syslog_args = dict() for a in arg_list: - arg_name, *arg_value = a.split(':',2) - if arg_name=='method': - arg_value = arg_value[0] if len(arg_value)>0 else 'GET' + arg_name, *arg_value = a.split(":", 2) + if arg_name == "method": + arg_value = arg_value[0] if len(arg_value) > 0 else "GET" syslog_args[arg_name] = arg_value else: # TODO: Provide support for SSL context and credentials pass - return logging.handlers.HTTPHandler(f'{host}:{port}',path,**syslog_args) + return logging.handlers.HTTPHandler(f"{host}:{port}", path, **syslog_args) diff --git a/invokeai/backend/util/mps_fixes.py b/invokeai/backend/util/mps_fixes.py index e4d8da20a9..409eff1c8b 100644 --- a/invokeai/backend/util/mps_fixes.py +++ b/invokeai/backend/util/mps_fixes.py @@ -8,6 +8,8 @@ if torch.backends.mps.is_available(): _torch_layer_norm = torch.nn.functional.layer_norm + + def new_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05): if input.device.type == "mps" and input.dtype == torch.float16: input = input.float() @@ -19,20 +21,26 @@ def new_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05): else: return _torch_layer_norm(input, normalized_shape, weight, bias, eps) + torch.nn.functional.layer_norm = new_layer_norm _torch_tensor_permute = torch.Tensor.permute + + def new_torch_tensor_permute(input, *dims): result = _torch_tensor_permute(input, *dims) if input.device == "mps" and input.dtype == torch.float16: result = result.contiguous() return result + torch.Tensor.permute = new_torch_tensor_permute _torch_lerp = torch.lerp + + def new_torch_lerp(input, end, weight, *, out=None): if input.device.type == "mps" and input.dtype == torch.float16: input = input.float() @@ -52,20 +60,36 @@ def new_torch_lerp(input, end, weight, *, out=None): else: return _torch_lerp(input, end, weight, out=out) + torch.lerp = new_torch_lerp _torch_interpolate = torch.nn.functional.interpolate -def new_torch_interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): + + +def new_torch_interpolate( + input, + size=None, + scale_factor=None, + mode="nearest", + align_corners=None, + recompute_scale_factor=None, + antialias=False, +): if input.device.type == "mps" and input.dtype == torch.float16: - return _torch_interpolate(input.float(), size, scale_factor, mode, align_corners, recompute_scale_factor, antialias).half() + return _torch_interpolate( + input.float(), size, scale_factor, mode, align_corners, recompute_scale_factor, antialias + ).half() else: return _torch_interpolate(input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias) + torch.nn.functional.interpolate = new_torch_interpolate # TODO: refactor it _SlicedAttnProcessor = diffusers.models.attention_processor.SlicedAttnProcessor + + class ChunkedSlicedAttnProcessor: r""" Processor for implementing sliced attention. @@ -78,7 +102,7 @@ class ChunkedSlicedAttnProcessor: def __init__(self, slice_size): assert isinstance(slice_size, int) - slice_size = 1 # TODO: maybe implement chunking in batches too when enough memory + slice_size = 1 # TODO: maybe implement chunking in batches too when enough memory self.slice_size = slice_size self._sliced_attn_processor = _SlicedAttnProcessor(slice_size) @@ -121,7 +145,9 @@ class ChunkedSlicedAttnProcessor: (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype ) - chunk_tmp_tensor = torch.empty(self.slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device) + chunk_tmp_tensor = torch.empty( + self.slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device + ) for i in range(batch_size_attention // self.slice_size): start_idx = i * self.slice_size @@ -131,7 +157,15 @@ class ChunkedSlicedAttnProcessor: key_slice = key[start_idx:end_idx] attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None - self.get_attention_scores_chunked(attn, query_slice, key_slice, attn_mask_slice, hidden_states[start_idx:end_idx], value[start_idx:end_idx], chunk_tmp_tensor) + self.get_attention_scores_chunked( + attn, + query_slice, + key_slice, + attn_mask_slice, + hidden_states[start_idx:end_idx], + value[start_idx:end_idx], + chunk_tmp_tensor, + ) hidden_states = attn.batch_to_head_dim(hidden_states) @@ -150,7 +184,6 @@ class ChunkedSlicedAttnProcessor: return hidden_states - def get_attention_scores_chunked(self, attn, query, key, attention_mask, hidden_states, value, chunk): # batch size = 1 assert query.shape[0] == 1 @@ -163,14 +196,14 @@ class ChunkedSlicedAttnProcessor: query = query.float() key = key.float() - #out_item_size = query.dtype.itemsize - #if attn.upcast_attention: + # out_item_size = query.dtype.itemsize + # if attn.upcast_attention: # out_item_size = torch.float32.itemsize out_item_size = query.element_size() if attn.upcast_attention: out_item_size = 4 - chunk_size = 2 ** 29 + chunk_size = 2**29 out_size = query.shape[1] * key.shape[1] * out_item_size chunks_count = min(query.shape[1], math.ceil((out_size - 1) / chunk_size)) @@ -181,8 +214,8 @@ class ChunkedSlicedAttnProcessor: def _get_chunk_view(tensor, start, length): if start + length > tensor.shape[1]: length = tensor.shape[1] - start - #print(f"view: [{tensor.shape[0]},{tensor.shape[1]},{tensor.shape[2]}] - start: {start}, length: {length}") - return tensor[:,start:start+length] + # print(f"view: [{tensor.shape[0]},{tensor.shape[1]},{tensor.shape[2]}] - start: {start}, length: {length}") + return tensor[:, start : start + length] for chunk_pos in range(0, query.shape[1], chunk_step): if attention_mask is not None: @@ -196,7 +229,7 @@ class ChunkedSlicedAttnProcessor: ) else: torch.baddbmm( - torch.zeros((1,1,1), device=query.device, dtype=query.dtype), + torch.zeros((1, 1, 1), device=query.device, dtype=query.dtype), _get_chunk_view(query, chunk_pos, chunk_step), key, beta=0, @@ -206,7 +239,7 @@ class ChunkedSlicedAttnProcessor: chunk = chunk.softmax(dim=-1) torch.bmm(chunk, value, out=_get_chunk_view(hidden_states, chunk_pos, chunk_step)) - #del chunk + # del chunk diffusers.models.attention_processor.SlicedAttnProcessor = ChunkedSlicedAttnProcessor diff --git a/invokeai/backend/util/util.py b/invokeai/backend/util/util.py index 1cc632e483..f3c182c063 100644 --- a/invokeai/backend/util/util.py +++ b/invokeai/backend/util/util.py @@ -32,9 +32,7 @@ def log_txt_as_img(wh, xc, size=10): draw = ImageDraw.Draw(txt) font = ImageFont.load_default() nc = int(40 * (wh[0] / 256)) - lines = "\n".join( - xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc) - ) + lines = "\n".join(xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)) try: draw.text((0, 0), lines, fill="black", font=font) @@ -81,9 +79,7 @@ def mean_flat(tensor): def count_params(model, verbose=False): total_params = sum(p.numel() for p in model.parameters()) if verbose: - logger.debug( - f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params." - ) + logger.debug(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") return total_params @@ -154,21 +150,12 @@ def parallel_data_prefetch( proc = Thread # spawn processes if target_data_type == "ndarray": - arguments = [ - [func, Q, part, i, use_worker_id] - for i, part in enumerate(np.array_split(data, n_proc)) - ] + arguments = [[func, Q, part, i, use_worker_id] for i, part in enumerate(np.array_split(data, n_proc))] else: - step = ( - int(len(data) / n_proc + 1) - if len(data) % n_proc != 0 - else int(len(data) / n_proc) - ) + step = int(len(data) / n_proc + 1) if len(data) % n_proc != 0 else int(len(data) / n_proc) arguments = [ [func, Q, part, i, use_worker_id] - for i, part in enumerate( - [data[i : i + step] for i in range(0, len(data), step)] - ) + for i, part in enumerate([data[i : i + step] for i in range(0, len(data), step)]) ] processes = [] for i in range(n_proc): @@ -220,9 +207,7 @@ def parallel_data_prefetch( return gather_res -def rand_perlin_2d( - shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3 -): +def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3): delta = (res[0] / shape[0], res[1] / shape[1]) d = (shape[0] // res[0], shape[1] // res[1]) @@ -265,9 +250,9 @@ def rand_perlin_2d( n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]).to(device) n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]).to(device) t = fade(grid[: shape[0], : shape[1]]) - noise = math.sqrt(2) * torch.lerp( - torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1] - ).to(device) + noise = math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]).to( + device + ) return noise.to(dtype=torch_dtype(device)) @@ -276,9 +261,7 @@ def ask_user(question: str, answers: list): user_prompt = f"\n>> {question} {answers}: " invalid_answer_msg = "Invalid answer. Please try again." - pose_question = chain( - [user_prompt], repeat("\n".join([invalid_answer_msg, user_prompt])) - ) + pose_question = chain([user_prompt], repeat("\n".join([invalid_answer_msg, user_prompt]))) user_answers = map(input, pose_question) valid_response = next(filter(answers.__contains__, user_answers)) return valid_response @@ -303,9 +286,7 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path if dest.is_dir(): try: - file_name = re.search( - 'filename="(.+)"', resp.headers.get("Content-Disposition") - ).group(1) + file_name = re.search('filename="(.+)"', resp.headers.get("Content-Disposition")).group(1) except: file_name = os.path.basename(url) dest = dest / file_name @@ -322,7 +303,7 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path logger.warning("corrupt existing file found. re-downloading") os.remove(dest) exist_size = 0 - + if resp.status_code == 416 or (content_length > 0 and exist_size == content_length): logger.warning(f"{dest}: complete file found. Skipping.") return dest @@ -377,16 +358,16 @@ def image_to_dataURL(image: Image.Image, image_format: str = "PNG") -> str: buffered = io.BytesIO() image.save(buffered, format=image_format) mime_type = Image.MIME.get(image_format.upper(), "image/" + image_format.lower()) - image_base64 = f"data:{mime_type};base64," + base64.b64encode( - buffered.getvalue() - ).decode("UTF-8") + image_base64 = f"data:{mime_type};base64," + base64.b64encode(buffered.getvalue()).decode("UTF-8") return image_base64 + class Chdir(object): - '''Context manager to chdir to desired directory and change back after context exits: + """Context manager to chdir to desired directory and change back after context exits: Args: path (Path): The path to the cwd - ''' + """ + def __init__(self, path: Path): self.path = path self.original = Path().absolute() @@ -394,5 +375,5 @@ class Chdir(object): def __enter__(self): os.chdir(self.path) - def __exit__(self,*args): + def __exit__(self, *args): os.chdir(self.original) diff --git a/invokeai/backend/web/invoke_ai_web_server.py b/invokeai/backend/web/invoke_ai_web_server.py index eec02cd9dc..88eb77551f 100644 --- a/invokeai/backend/web/invoke_ai_web_server.py +++ b/invokeai/backend/web/invoke_ai_web_server.py @@ -64,10 +64,7 @@ class InvokeAIWebServer: self.ALLOWED_EXTENSIONS = {"png", "jpg", "jpeg"} def allowed_file(self, filename: str) -> bool: - return ( - "." in filename - and filename.rsplit(".", 1)[1].lower() in self.ALLOWED_EXTENSIONS - ) + return "." in filename and filename.rsplit(".", 1)[1].lower() in self.ALLOWED_EXTENSIONS def run(self): self.setup_app() @@ -99,9 +96,7 @@ class InvokeAIWebServer: _cors = _cors.split(",") socketio_args["cors_allowed_origins"] = _cors - self.app = Flask( - __name__, static_url_path="", static_folder=frontend.__path__[0] - ) + self.app = Flask(__name__, static_url_path="", static_folder=frontend.__path__[0]) self.socketio = SocketIO(self.app, **socketio_args) @@ -192,9 +187,7 @@ class InvokeAIWebServer: (width, height) = pil_image.size - thumbnail_path = save_thumbnail( - pil_image, os.path.basename(file_path), self.thumbnail_image_path - ) + thumbnail_path = save_thumbnail(pil_image, os.path.basename(file_path), self.thumbnail_image_path) response = { "url": self.get_url_from_image_path(file_path), @@ -237,12 +230,8 @@ class InvokeAIWebServer: f"Point your browser at http{'s' if useSSL else ''}://localhost:{self.port} or use the host's DNS name or IP address." ) else: - logger.info( - "Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address." - ) - logger.info( - f"Point your browser at http{'s' if useSSL else ''}://{self.host}:{self.port}" - ) + logger.info("Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address.") + logger.info(f"Point your browser at http{'s' if useSSL else ''}://{self.host}:{self.port}") if not useSSL: self.socketio.run(app=self.app, host=self.host, port=self.port) else: @@ -392,22 +381,16 @@ class InvokeAIWebServer: @socketio.on("convertToDiffusers") def convert_to_diffusers(model_to_convert: dict): try: - if model_info := self.generate.model_manager.model_info( - model_name=model_to_convert["model_name"] - ): + if model_info := self.generate.model_manager.model_info(model_name=model_to_convert["model_name"]): if "weights" in model_info: ckpt_path = Path(model_info["weights"]) original_config_file = Path(model_info["config"]) model_name = model_to_convert["model_name"] model_description = model_info["description"] else: - self.socketio.emit( - "error", {"message": "Model is not a valid checkpoint file"} - ) + self.socketio.emit("error", {"message": "Model is not a valid checkpoint file"}) else: - self.socketio.emit( - "error", {"message": "Could not retrieve model info."} - ) + self.socketio.emit("error", {"message": "Could not retrieve model info."}) if not ckpt_path.is_absolute(): ckpt_path = Path(Globals.root, ckpt_path) @@ -415,22 +398,13 @@ class InvokeAIWebServer: if original_config_file and not original_config_file.is_absolute(): original_config_file = Path(Globals.root, original_config_file) - diffusers_path = Path( - ckpt_path.parent.absolute(), f"{model_name}_diffusers" - ) + diffusers_path = Path(ckpt_path.parent.absolute(), f"{model_name}_diffusers") if model_to_convert["save_location"] == "root": - diffusers_path = Path( - global_converted_ckpts_dir(), f"{model_name}_diffusers" - ) + diffusers_path = Path(global_converted_ckpts_dir(), f"{model_name}_diffusers") - if ( - model_to_convert["save_location"] == "custom" - and model_to_convert["custom_location"] is not None - ): - diffusers_path = Path( - model_to_convert["custom_location"], f"{model_name}_diffusers" - ) + if model_to_convert["save_location"] == "custom" and model_to_convert["custom_location"] is not None: + diffusers_path = Path(model_to_convert["custom_location"], f"{model_name}_diffusers") if diffusers_path.exists(): shutil.rmtree(diffusers_path) @@ -462,10 +436,7 @@ class InvokeAIWebServer: def merge_diffusers_models(model_merge_info: dict): try: models_to_merge = model_merge_info["models_to_merge"] - model_ids_or_paths = [ - self.generate.model_manager.model_name_or_path(x) - for x in models_to_merge - ] + model_ids_or_paths = [self.generate.model_manager.model_name_or_path(x) for x in models_to_merge] merged_pipe = merge_diffusion_models( model_ids_or_paths, model_merge_info["alpha"], @@ -487,15 +458,11 @@ class InvokeAIWebServer: commit_to_conf=opt.conf, ) - if vae := self.generate.model_manager.config[models_to_merge[0]].get( - "vae", None - ): + if vae := self.generate.model_manager.config[models_to_merge[0]].get("vae", None): logger.info(f"Using configured VAE assigned to {models_to_merge[0]}") merged_model_config.update(vae=vae) - self.generate.model_manager.import_diffuser_model( - dump_path, **merged_model_config - ) + self.generate.model_manager.import_diffuser_model(dump_path, **merged_model_config) new_model_list = self.generate.model_manager.list_models() socketio.emit( @@ -525,9 +492,7 @@ class InvokeAIWebServer: ) os.remove(thumbnail_path) except Exception as e: - socketio.emit( - "error", {"message": f"Unable to delete {f}: {str(e)}"} - ) + socketio.emit("error", {"message": f"Unable to delete {f}: {str(e)}"}) pass socketio.emit("tempFolderEmptied") @@ -550,9 +515,7 @@ class InvokeAIWebServer: (width, height) = pil_image.size - thumbnail_path = save_thumbnail( - pil_image, os.path.basename(new_path), self.thumbnail_image_path - ) + thumbnail_path = save_thumbnail(pil_image, os.path.basename(new_path), self.thumbnail_image_path) image_array = [ { @@ -577,18 +540,14 @@ class InvokeAIWebServer: @socketio.on("requestLatestImages") def handle_request_latest_images(category, latest_mtime): try: - base_path = ( - self.result_path if category == "result" else self.init_image_path - ) + base_path = self.result_path if category == "result" else self.init_image_path paths = [] for ext in ("*.png", "*.jpg", "*.jpeg"): paths.extend(glob.glob(os.path.join(base_path, ext))) - image_paths = sorted( - paths, key=lambda x: os.path.getmtime(x), reverse=True - ) + image_paths = sorted(paths, key=lambda x: os.path.getmtime(x), reverse=True) image_paths = list( filter( @@ -609,16 +568,12 @@ class InvokeAIWebServer: pil_image = Image.open(path) (width, height) = pil_image.size - thumbnail_path = save_thumbnail( - pil_image, os.path.basename(path), self.thumbnail_image_path - ) + thumbnail_path = save_thumbnail(pil_image, os.path.basename(path), self.thumbnail_image_path) image_array.append( { "url": self.get_url_from_image_path(path), - "thumbnail": self.get_url_from_image_path( - thumbnail_path - ), + "thumbnail": self.get_url_from_image_path(thumbnail_path), "mtime": os.path.getmtime(path), "metadata": metadata.get("sd-metadata"), "dreamPrompt": metadata.get("Dream"), @@ -628,9 +583,7 @@ class InvokeAIWebServer: } ) except Exception as e: - socketio.emit( - "error", {"message": f"Unable to load {path}: {str(e)}"} - ) + socketio.emit("error", {"message": f"Unable to load {path}: {str(e)}"}) pass socketio.emit( @@ -645,17 +598,13 @@ class InvokeAIWebServer: try: page_size = 50 - base_path = ( - self.result_path if category == "result" else self.init_image_path - ) + base_path = self.result_path if category == "result" else self.init_image_path paths = [] for ext in ("*.png", "*.jpg", "*.jpeg"): paths.extend(glob.glob(os.path.join(base_path, ext))) - image_paths = sorted( - paths, key=lambda x: os.path.getmtime(x), reverse=True - ) + image_paths = sorted(paths, key=lambda x: os.path.getmtime(x), reverse=True) if earliest_mtime: image_paths = list( @@ -679,16 +628,12 @@ class InvokeAIWebServer: pil_image = Image.open(path) (width, height) = pil_image.size - thumbnail_path = save_thumbnail( - pil_image, os.path.basename(path), self.thumbnail_image_path - ) + thumbnail_path = save_thumbnail(pil_image, os.path.basename(path), self.thumbnail_image_path) image_array.append( { "url": self.get_url_from_image_path(path), - "thumbnail": self.get_url_from_image_path( - thumbnail_path - ), + "thumbnail": self.get_url_from_image_path(thumbnail_path), "mtime": os.path.getmtime(path), "metadata": metadata.get("sd-metadata"), "dreamPrompt": metadata.get("Dream"), @@ -699,9 +644,7 @@ class InvokeAIWebServer: ) except Exception as e: logger.info(f"Unable to load {path}") - socketio.emit( - "error", {"message": f"Unable to load {path}: {str(e)}"} - ) + socketio.emit("error", {"message": f"Unable to load {path}: {str(e)}"}) pass socketio.emit( @@ -716,9 +659,7 @@ class InvokeAIWebServer: self.handle_exceptions(e) @socketio.on("generateImage") - def handle_generate_image_event( - generation_parameters, esrgan_parameters, facetool_parameters - ): + def handle_generate_image_event(generation_parameters, esrgan_parameters, facetool_parameters): try: # truncate long init_mask/init_img base64 if needed printable_parameters = { @@ -726,14 +667,10 @@ class InvokeAIWebServer: } if "init_img" in generation_parameters: - printable_parameters["init_img"] = ( - printable_parameters["init_img"][:64] + "..." - ) + printable_parameters["init_img"] = printable_parameters["init_img"][:64] + "..." if "init_mask" in generation_parameters: - printable_parameters["init_mask"] = ( - printable_parameters["init_mask"][:64] + "..." - ) + printable_parameters["init_mask"] = printable_parameters["init_mask"][:64] + "..." logger.info(f"Image Generation Parameters:\n\n{printable_parameters}\n") logger.info(f"ESRGAN Parameters: {esrgan_parameters}") @@ -750,18 +687,14 @@ class InvokeAIWebServer: @socketio.on("runPostprocessing") def handle_run_postprocessing(original_image, postprocessing_parameters): try: - logger.info( - f'Postprocessing requested for "{original_image["url"]}": {postprocessing_parameters}' - ) + logger.info(f'Postprocessing requested for "{original_image["url"]}": {postprocessing_parameters}') progress = Progress() socketio.emit("progressUpdate", progress.to_formatted_dict()) eventlet.sleep(0) - original_image_path = self.get_image_path_from_url( - original_image["url"] - ) + original_image_path = self.get_image_path_from_url(original_image["url"]) image = Image.open(original_image_path) @@ -801,14 +734,10 @@ class InvokeAIWebServer: strength=postprocessing_parameters["facetool_strength"], fidelity=postprocessing_parameters["codeformer_fidelity"], seed=seed, - device="cpu" - if str(self.generate.device) == "mps" - else self.generate.device, + device="cpu" if str(self.generate.device) == "mps" else self.generate.device, ) else: - raise TypeError( - f'{postprocessing_parameters["type"]} is not a valid postprocessing type' - ) + raise TypeError(f'{postprocessing_parameters["type"]} is not a valid postprocessing type') progress.set_current_status("common.statusSavingImage") socketio.emit("progressUpdate", progress.to_formatted_dict()) @@ -832,9 +761,7 @@ class InvokeAIWebServer: postprocessing=postprocessing_parameters["type"], ) - thumbnail_path = save_thumbnail( - image, os.path.basename(path), self.thumbnail_image_path - ) + thumbnail_path = save_thumbnail(image, os.path.basename(path), self.thumbnail_image_path) self.write_log_message( f'[Postprocessed] "{original_image_path}" > "{path}": {postprocessing_parameters}' @@ -901,17 +828,13 @@ class InvokeAIWebServer: "app_version": APP_VERSION, } - def generate_images( - self, generation_parameters, esrgan_parameters, facetool_parameters - ): + def generate_images(self, generation_parameters, esrgan_parameters, facetool_parameters): try: self.canceled.clear() step_index = 1 prior_variations = ( - generation_parameters["with_variations"] - if "with_variations" in generation_parameters - else [] + generation_parameters["with_variations"] if "with_variations" in generation_parameters else [] ) actual_generation_mode = generation_parameters["generation_mode"] @@ -943,9 +866,7 @@ class InvokeAIWebServer: original_bounding_box = generation_parameters["bounding_box"].copy() - initial_image = dataURL_to_image( - generation_parameters["init_img"] - ).convert("RGBA") + initial_image = dataURL_to_image(generation_parameters["init_img"]).convert("RGBA") """ The outpaint image and mask are pre-cropped by the UI, so the bounding box we pass @@ -962,13 +883,9 @@ class InvokeAIWebServer: generation_parameters["bounding_box"]["y"] = 0 # Convert mask dataURL to an image and convert to greyscale - mask_image = dataURL_to_image( - generation_parameters["init_mask"] - ).convert("L") + mask_image = dataURL_to_image(generation_parameters["init_mask"]).convert("L") - actual_generation_mode = get_canvas_generation_mode( - initial_image, mask_image - ) + actual_generation_mode = get_canvas_generation_mode(initial_image, mask_image) """ Apply the mask to the init image, creating a "mask" image with @@ -1018,9 +935,7 @@ class InvokeAIWebServer: elif generation_parameters["generation_mode"] == "img2img": init_img_url = generation_parameters["init_img"] init_img_path = self.get_image_path_from_url(init_img_url) - generation_parameters["init_img"] = Image.open(init_img_path).convert( - "RGB" - ) + generation_parameters["init_img"] = Image.open(init_img_path).convert("RGB") def image_progress(intermediate_state: PipelineIntermediateState): if self.canceled.is_set(): @@ -1046,9 +961,7 @@ class InvokeAIWebServer: } progress.set_current_step(step + 1) - progress.set_current_status( - f"{generation_messages[actual_generation_mode]}" - ) + progress.set_current_status(f"{generation_messages[actual_generation_mode]}") progress.set_current_status_has_steps(True) if ( @@ -1057,9 +970,7 @@ class InvokeAIWebServer: and step < generation_parameters["steps"] - 1 ): image = self.generate.sample_to_image(sample) - metadata = self.parameters_to_generated_image_metadata( - generation_parameters - ) + metadata = self.parameters_to_generated_image_metadata(generation_parameters) command = parameters_to_command(generation_parameters) (width, height) = image.size @@ -1140,15 +1051,10 @@ class InvokeAIWebServer: all_parameters = generation_parameters postprocessing = False - if ( - "variation_amount" in all_parameters - and all_parameters["variation_amount"] > 0 - ): + if "variation_amount" in all_parameters and all_parameters["variation_amount"] > 0: first_seed = first_seed or seed this_variation = [[seed, all_parameters["variation_amount"]]] - all_parameters["with_variations"] = ( - prior_variations + this_variation - ) + all_parameters["with_variations"] = prior_variations + this_variation all_parameters["seed"] = first_seed elif "with_variations" in all_parameters: all_parameters["seed"] = first_seed @@ -1186,9 +1092,7 @@ class InvokeAIWebServer: if facetool_parameters["type"] == "gfpgan": progress.set_current_status("common.statusRestoringFacesGFPGAN") elif facetool_parameters["type"] == "codeformer": - progress.set_current_status( - "common.statusRestoringFacesCodeFormer" - ) + progress.set_current_status("common.statusRestoringFacesCodeFormer") progress.set_current_status_has_steps(False) self.socketio.emit("progressUpdate", progress.to_formatted_dict()) @@ -1206,18 +1110,12 @@ class InvokeAIWebServer: strength=facetool_parameters["strength"], fidelity=facetool_parameters["codeformer_fidelity"], seed=seed, - device="cpu" - if str(self.generate.device) == "mps" - else self.generate.device, + device="cpu" if str(self.generate.device) == "mps" else self.generate.device, ) - all_parameters["codeformer_fidelity"] = facetool_parameters[ - "codeformer_fidelity" - ] + all_parameters["codeformer_fidelity"] = facetool_parameters["codeformer_fidelity"] postprocessing = True - all_parameters["facetool_strength"] = facetool_parameters[ - "strength" - ] + all_parameters["facetool_strength"] = facetool_parameters["strength"] all_parameters["facetool_type"] = facetool_parameters["type"] progress.set_current_status("common.statusSavingImage") @@ -1226,9 +1124,7 @@ class InvokeAIWebServer: # restore the stashed URLS and discard the paths, we are about to send the result to client all_parameters["init_img"] = ( - init_img_url - if generation_parameters["generation_mode"] == "img2img" - else "" + init_img_url if generation_parameters["generation_mode"] == "img2img" else "" ) if "init_mask" in all_parameters: @@ -1246,8 +1142,7 @@ class InvokeAIWebServer: generated_image_outdir = ( self.result_path - if generation_parameters["generation_mode"] - in ["txt2img", "img2img"] + if generation_parameters["generation_mode"] in ["txt2img", "img2img"] else self.temp_image_path ) @@ -1259,9 +1154,7 @@ class InvokeAIWebServer: postprocessing=postprocessing, ) - thumbnail_path = save_thumbnail( - image, os.path.basename(path), self.thumbnail_image_path - ) + thumbnail_path = save_thumbnail(image, os.path.basename(path), self.thumbnail_image_path) logger.info(f'Image generated: "{path}"\n') self.write_log_message(f'[Generated] "{path}": {command}') @@ -1281,14 +1174,10 @@ class InvokeAIWebServer: tokens = ( None if type(parsed_prompt) is Blend - else get_tokens_for_prompt_object( - model.tokenizer, parsed_prompt - ) + else get_tokens_for_prompt_object(model.tokenizer, parsed_prompt) ) attention_maps_image_base64_url = ( - None - if attention_maps_image is None - else image_to_dataURL(attention_maps_image) + None if attention_maps_image is None else image_to_dataURL(attention_maps_image) ) self.socketio.emit( @@ -1382,9 +1271,7 @@ class InvokeAIWebServer: } if parameters["facetool_type"] == "codeformer": - facetool_parameters["fidelity"] = float( - parameters["codeformer_fidelity"] - ) + facetool_parameters["fidelity"] = float(parameters["codeformer_fidelity"]) postprocessing.append(facetool_parameters) @@ -1398,9 +1285,7 @@ class InvokeAIWebServer: } ) - rfc_dict["postprocessing"] = ( - postprocessing if len(postprocessing) > 0 else None - ) + rfc_dict["postprocessing"] = postprocessing if len(postprocessing) > 0 else None # semantic drift rfc_dict["sampler"] = parameters["sampler_name"] @@ -1409,22 +1294,15 @@ class InvokeAIWebServer: variations = [] if "with_variations" in parameters: - variations = [ - {"seed": x[0], "weight": x[1]} - for x in parameters["with_variations"] - ] + variations = [{"seed": x[0], "weight": x[1]} for x in parameters["with_variations"]] rfc_dict["variations"] = variations if rfc_dict["type"] == "img2img": rfc_dict["strength"] = parameters["strength"] rfc_dict["fit"] = parameters["fit"] # TODO: Noncompliant - rfc_dict["orig_hash"] = calculate_init_img_hash( - self.get_image_path_from_url(parameters["init_img"]) - ) - rfc_dict["init_image_path"] = parameters[ - "init_img" - ] # TODO: Noncompliant + rfc_dict["orig_hash"] = calculate_init_img_hash(self.get_image_path_from_url(parameters["init_img"])) + rfc_dict["init_image_path"] = parameters["init_img"] # TODO: Noncompliant metadata["image"] = rfc_dict @@ -1433,9 +1311,7 @@ class InvokeAIWebServer: except Exception as e: self.handle_exceptions(e) - def parameters_to_post_processed_image_metadata( - self, parameters, original_image_path - ): + def parameters_to_post_processed_image_metadata(self, parameters, original_image_path): try: current_metadata = retrieve_metadata(original_image_path)["sd-metadata"] postprocessing_metadata = {} @@ -1447,9 +1323,7 @@ class InvokeAIWebServer: if "image" not in current_metadata: current_metadata["image"] = {} - orig_hash = calculate_init_img_hash( - self.get_image_path_from_url(original_image_path) - ) + orig_hash = calculate_init_img_hash(self.get_image_path_from_url(original_image_path)) postprocessing_metadata["orig_path"] = (original_image_path,) postprocessing_metadata["orig_hash"] = orig_hash @@ -1473,9 +1347,7 @@ class InvokeAIWebServer: if "postprocessing" in current_metadata["image"] and isinstance( current_metadata["image"]["postprocessing"], list ): - current_metadata["image"]["postprocessing"].append( - postprocessing_metadata - ) + current_metadata["image"]["postprocessing"].append(postprocessing_metadata) else: current_metadata["image"]["postprocessing"] = [postprocessing_metadata] @@ -1556,29 +1428,17 @@ class InvokeAIWebServer: """Given a url to an image used by the client, returns the absolute file path to that image""" try: if "init-images" in url: - return os.path.abspath( - os.path.join(self.init_image_path, os.path.basename(url)) - ) + return os.path.abspath(os.path.join(self.init_image_path, os.path.basename(url))) elif "mask-images" in url: - return os.path.abspath( - os.path.join(self.mask_image_path, os.path.basename(url)) - ) + return os.path.abspath(os.path.join(self.mask_image_path, os.path.basename(url))) elif "intermediates" in url: - return os.path.abspath( - os.path.join(self.intermediate_path, os.path.basename(url)) - ) + return os.path.abspath(os.path.join(self.intermediate_path, os.path.basename(url))) elif "temp-images" in url: - return os.path.abspath( - os.path.join(self.temp_image_path, os.path.basename(url)) - ) + return os.path.abspath(os.path.join(self.temp_image_path, os.path.basename(url))) elif "thumbnails" in url: - return os.path.abspath( - os.path.join(self.thumbnail_image_path, os.path.basename(url)) - ) + return os.path.abspath(os.path.join(self.thumbnail_image_path, os.path.basename(url))) else: - return os.path.abspath( - os.path.join(self.result_path, os.path.basename(url)) - ) + return os.path.abspath(os.path.join(self.result_path, os.path.basename(url))) except Exception as e: self.handle_exceptions(e) @@ -1632,18 +1492,14 @@ class Progress: self.total_steps = ( self._calculate_real_steps( steps=generation_parameters["steps"], - strength=generation_parameters["strength"] - if "strength" in generation_parameters - else None, + strength=generation_parameters["strength"] if "strength" in generation_parameters else None, has_init_image="init_img" in generation_parameters, ) if generation_parameters else 1 ) self.current_iteration = 1 - self.total_iterations = ( - generation_parameters["iterations"] if generation_parameters else 1 - ) + self.total_iterations = generation_parameters["iterations"] if generation_parameters else 1 self.current_status = "common.statusPreparing" self.is_processing = True self.current_status_has_steps = False @@ -1703,9 +1559,7 @@ class CanceledException(Exception): pass -def copy_image_from_bounding_box( - image: ImageType, x: int, y: int, width: int, height: int -) -> ImageType: +def copy_image_from_bounding_box(image: ImageType, x: int, y: int, width: int, height: int) -> ImageType: """ Returns a copy an image, cropped to a bounding box. """ @@ -1740,9 +1594,7 @@ def image_to_dataURL(image: ImageType, image_format: str = "PNG") -> str: buffered = io.BytesIO() image.save(buffered, format=image_format) mime_type = Image.MIME.get(image_format.upper(), "image/" + image_format.lower()) - image_base64 = f"data:{mime_type};base64," + base64.b64encode( - buffered.getvalue() - ).decode("UTF-8") + image_base64 = f"data:{mime_type};base64," + base64.b64encode(buffered.getvalue()).decode("UTF-8") return image_base64 diff --git a/invokeai/backend/web/modules/get_canvas_generation_mode.py b/invokeai/backend/web/modules/get_canvas_generation_mode.py index 55955cc33c..6d680016e7 100644 --- a/invokeai/backend/web/modules/get_canvas_generation_mode.py +++ b/invokeai/backend/web/modules/get_canvas_generation_mode.py @@ -40,9 +40,7 @@ def get_canvas_generation_mode( init_img_has_transparency = check_for_any_transparency(init_img) if init_img_has_transparency: - init_img_is_fully_transparent = ( - True if init_img_alpha_mask.getbbox() is None else False - ) + init_img_is_fully_transparent = True if init_img_alpha_mask.getbbox() is None else False """ Mask images are white in areas where no change should be made, black where changes diff --git a/invokeai/backend/web/modules/parameters.py b/invokeai/backend/web/modules/parameters.py index 440f21a947..8ab74adb92 100644 --- a/invokeai/backend/web/modules/parameters.py +++ b/invokeai/backend/web/modules/parameters.py @@ -10,7 +10,7 @@ SAMPLER_CHOICES = [ "lms_k", "pndm", "heun", - 'heun_k', + "heun_k", "euler", "euler_k", "euler_a", @@ -76,9 +76,7 @@ def parameters_to_command(params): if "variation_amount" in params and params["variation_amount"] > 0: switches.append(f'-v {params["variation_amount"]}') if "with_variations" in params: - seed_weight_pairs = ",".join( - f"{seed}:{weight}" for seed, weight in params["with_variations"] - ) + seed_weight_pairs = ",".join(f"{seed}:{weight}" for seed, weight in params["with_variations"]) switches.append(f"-V {seed_weight_pairs}") return " ".join(switches) diff --git a/invokeai/frontend/__init__.py b/invokeai/frontend/__init__.py index 98fdf870e9..19eafe46c4 100644 --- a/invokeai/frontend/__init__.py +++ b/invokeai/frontend/__init__.py @@ -1,3 +1,3 @@ -''' +""" Initialization file for invokeai.frontend -''' +""" diff --git a/invokeai/frontend/install/invokeai_update.py b/invokeai/frontend/install/invokeai_update.py index f73f670496..56d1a313c7 100644 --- a/invokeai/frontend/install/invokeai_update.py +++ b/invokeai/frontend/install/invokeai_update.py @@ -1,7 +1,7 @@ -''' +""" Minimalist updater script. Prompts user for the tag or branch to update to and runs pip install . -''' +""" import os import platform import pkg_resources @@ -15,10 +15,10 @@ from rich.style import Style from invokeai.version import __version__ -INVOKE_AI_SRC="https://github.com/invoke-ai/InvokeAI/archive" -INVOKE_AI_TAG="https://github.com/invoke-ai/InvokeAI/archive/refs/tags" -INVOKE_AI_BRANCH="https://github.com/invoke-ai/InvokeAI/archive/refs/heads" -INVOKE_AI_REL="https://api.github.com/repos/invoke-ai/InvokeAI/releases" +INVOKE_AI_SRC = "https://github.com/invoke-ai/InvokeAI/archive" +INVOKE_AI_TAG = "https://github.com/invoke-ai/InvokeAI/archive/refs/tags" +INVOKE_AI_BRANCH = "https://github.com/invoke-ai/InvokeAI/archive/refs/heads" +INVOKE_AI_REL = "https://api.github.com/repos/invoke-ai/InvokeAI/releases" OS = platform.uname().system ARCH = platform.uname().machine @@ -29,34 +29,38 @@ if OS == "Windows": else: console = Console(style=Style(color="grey74", bgcolor="grey19")) -def get_versions()->dict: + +def get_versions() -> dict: return requests.get(url=INVOKE_AI_REL).json() -def invokeai_is_running()->bool: + +def invokeai_is_running() -> bool: for p in psutil.process_iter(): try: cmdline = p.cmdline() - matches = [x for x in cmdline if x.endswith(('invokeai','invokeai.exe'))] + matches = [x for x in cmdline if x.endswith(("invokeai", "invokeai.exe"))] if matches: - print(f':exclamation: [bold red]An InvokeAI instance appears to be running as process {p.pid}[/red bold]') + print( + f":exclamation: [bold red]An InvokeAI instance appears to be running as process {p.pid}[/red bold]" + ) return True - except (psutil.AccessDenied,psutil.NoSuchProcess): + except (psutil.AccessDenied, psutil.NoSuchProcess): continue return False + def welcome(versions: dict): - @group() def text(): - yield f'InvokeAI Version: [bold yellow]{__version__}' - yield '' - yield 'This script will update InvokeAI to the latest release, or to a development version of your choice.' - yield '' - yield '[bold yellow]Options:' - yield f'''[1] Update to the latest official release ([italic]{versions[0]['tag_name']}[/italic]) + yield f"InvokeAI Version: [bold yellow]{__version__}" + yield "" + yield "This script will update InvokeAI to the latest release, or to a development version of your choice." + yield "" + yield "[bold yellow]Options:" + yield f"""[1] Update to the latest official release ([italic]{versions[0]['tag_name']}[/italic]) [2] Update to the bleeding-edge development version ([italic]main[/italic]) [3] Manually enter the [bold]tag name[/bold] for the version you wish to update to -[4] Manually enter the [bold]branch name[/bold] for the version you wish to update to''' +[4] Manually enter the [bold]branch name[/bold] for the version you wish to update to""" console.rule() print( @@ -72,20 +76,22 @@ def welcome(versions: dict): ) console.line() + def get_extras(): - extras = '' + extras = "" try: - dist = pkg_resources.get_distribution('xformers') - extras = '[xformers]' + dist = pkg_resources.get_distribution("xformers") + extras = "[xformers]" except pkg_resources.DistributionNotFound: pass return extras + def main(): versions = get_versions() if invokeai_is_running(): - print(f':exclamation: [bold red]Please terminate all running instances of InvokeAI before updating.[/red bold]') - input('Press any key to continue...') + print(f":exclamation: [bold red]Please terminate all running instances of InvokeAI before updating.[/red bold]") + input("Press any key to continue...") return welcome(versions) @@ -93,36 +99,36 @@ def main(): tag = None branch = None release = None - choice = Prompt.ask('Choice:',choices=['1','2','3','4'],default='1') - - if choice=='1': - release = versions[0]['tag_name'] - elif choice=='2': - release = 'main' - elif choice=='3': - tag = Prompt.ask('Enter an InvokeAI tag name') - elif choice=='4': - branch = Prompt.ask('Enter an InvokeAI branch name') + choice = Prompt.ask("Choice:", choices=["1", "2", "3", "4"], default="1") + + if choice == "1": + release = versions[0]["tag_name"] + elif choice == "2": + release = "main" + elif choice == "3": + tag = Prompt.ask("Enter an InvokeAI tag name") + elif choice == "4": + branch = Prompt.ask("Enter an InvokeAI branch name") extras = get_extras() - print(f':crossed_fingers: Upgrading to [yellow]{tag if tag else release}[/yellow]') + print(f":crossed_fingers: Upgrading to [yellow]{tag if tag else release}[/yellow]") if release: cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_SRC}/{release}.zip" --use-pep517 --upgrade' elif tag: cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_TAG}/{tag}.zip" --use-pep517 --upgrade' else: cmd = f'pip install "invokeai{extras} @ {INVOKE_AI_BRANCH}/{branch}.zip" --use-pep517 --upgrade' - print('') - print('') - if os.system(cmd)==0: - print(f':heavy_check_mark: Upgrade successful') + print("") + print("") + if os.system(cmd) == 0: + print(f":heavy_check_mark: Upgrade successful") else: - print(f':exclamation: [bold red]Upgrade failed[/red bold]') - + print(f":exclamation: [bold red]Upgrade failed[/red bold]") + + if __name__ == "__main__": try: main() except KeyboardInterrupt: pass - diff --git a/invokeai/frontend/install/model_install.py b/invokeai/frontend/install/model_install.py index 448ded6318..ea9efe1908 100644 --- a/invokeai/frontend/install/model_install.py +++ b/invokeai/frontend/install/model_install.py @@ -56,18 +56,18 @@ logger = InvokeAILogger.getLogger() # build a table mapping all non-printable characters to None # for stripping control characters # from https://stackoverflow.com/questions/92438/stripping-non-printable-characters-from-a-string-in-python -NOPRINT_TRANS_TABLE = { - i: None for i in range(0, sys.maxunicode + 1) if not chr(i).isprintable() -} +NOPRINT_TRANS_TABLE = {i: None for i in range(0, sys.maxunicode + 1) if not chr(i).isprintable()} -def make_printable(s:str)->str: - '''Replace non-printable characters in a string''' + +def make_printable(s: str) -> str: + """Replace non-printable characters in a string""" return s.translate(NOPRINT_TRANS_TABLE) + class addModelsForm(CyclingForm, npyscreen.FormMultiPage): # for responsive resizing set to False, but this seems to cause a crash! FIX_MINIMUM_SIZE_WHEN_CREATED = True - + # for persistence current_tab = 0 @@ -82,12 +82,12 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage): self.subprocess_connection = None if not config.model_conf_path.exists(): - with open(config.model_conf_path,'w') as file: - print('# InvokeAI model configuration file',file=file) + with open(config.model_conf_path, "w") as file: + print("# InvokeAI model configuration file", file=file) self.installer = ModelInstall(config) self.all_models = self.installer.all_models() self.starter_models = self.installer.starter_models() - self.model_labels = self._get_model_labels() + self.model_labels = self._get_model_labels() window_width, window_height = get_terminal_size() self.nextrely -= 1 @@ -101,17 +101,17 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage): self.tabs = self.add_widget_intelligent( SingleSelectColumns, values=[ - 'STARTER MODELS', - 'MORE MODELS', - 'CONTROLNETS', - 'LORA/LYCORIS', - 'TEXTUAL INVERSION', + "STARTER MODELS", + "MORE MODELS", + "CONTROLNETS", + "LORA/LYCORIS", + "TEXTUAL INVERSION", ], value=[self.current_tab], - columns = 5, - max_height = 2, + columns=5, + max_height=2, relx=8, - scroll_exit = True, + scroll_exit=True, ) self.tabs.on_changed = self._toggle_tables @@ -121,43 +121,41 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage): self.nextrely = top_of_table self.pipeline_models = self.add_pipeline_widgets( - model_type=ModelType.Main, - window_width=window_width, - exclude = self.starter_models + model_type=ModelType.Main, window_width=window_width, exclude=self.starter_models ) # self.pipeline_models['autoload_pending'] = True - bottom_of_table = max(bottom_of_table,self.nextrely) + bottom_of_table = max(bottom_of_table, self.nextrely) self.nextrely = top_of_table self.controlnet_models = self.add_model_widgets( model_type=ModelType.ControlNet, window_width=window_width, ) - bottom_of_table = max(bottom_of_table,self.nextrely) + bottom_of_table = max(bottom_of_table, self.nextrely) self.nextrely = top_of_table self.lora_models = self.add_model_widgets( model_type=ModelType.Lora, window_width=window_width, ) - bottom_of_table = max(bottom_of_table,self.nextrely) + bottom_of_table = max(bottom_of_table, self.nextrely) self.nextrely = top_of_table self.ti_models = self.add_model_widgets( model_type=ModelType.TextualInversion, window_width=window_width, ) - bottom_of_table = max(bottom_of_table,self.nextrely) - - self.nextrely = bottom_of_table+1 + bottom_of_table = max(bottom_of_table, self.nextrely) + + self.nextrely = bottom_of_table + 1 self.monitor = self.add_widget_intelligent( BufferBox, - name='Log Messages', + name="Log Messages", editable=False, - max_height = 8, + max_height=8, ) - + self.nextrely += 1 done_label = "APPLY CHANGES" back_label = "BACK" @@ -172,16 +170,14 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage): else: self.nextrely = current_position self.cancel_button = self.add_widget_intelligent( - npyscreen.ButtonPress, - name=cancel_label, - when_pressed_function=self.on_cancel + npyscreen.ButtonPress, name=cancel_label, when_pressed_function=self.on_cancel ) self.nextrely = current_position self.ok_button = self.add_widget_intelligent( npyscreen.ButtonPress, name=done_label, relx=(window_width - len(done_label)) // 2, - when_pressed_function=self.on_execute + when_pressed_function=self.on_execute, ) label = "APPLY CHANGES & EXIT" @@ -189,43 +185,41 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage): self.done = self.add_widget_intelligent( npyscreen.ButtonPress, name=label, - relx=window_width-len(label)-15, + relx=window_width - len(label) - 15, when_pressed_function=self.on_done, ) # This restores the selected page on return from an installation - for i in range(1,self.current_tab+1): + for i in range(1, self.current_tab + 1): self.tabs.h_cursor_line_down(1) self._toggle_tables([self.current_tab]) - ############# diffusers tab ########## - def add_starter_pipelines(self)->dict[str, npyscreen.widget]: - '''Add widgets responsible for selecting diffusers models''' + ############# diffusers tab ########## + def add_starter_pipelines(self) -> dict[str, npyscreen.widget]: + """Add widgets responsible for selecting diffusers models""" widgets = dict() models = self.all_models starters = self.starter_models starter_model_labels = self.model_labels - - self.installed_models = sorted( - [x for x in starters if models[x].installed] - ) + + self.installed_models = sorted([x for x in starters if models[x].installed]) widgets.update( - label1 = self.add_widget_intelligent( + label1=self.add_widget_intelligent( CenteredTitleText, name="Select from a starter set of Stable Diffusion models from HuggingFace.", editable=False, labelColor="CAUTION", ) ) - + self.nextrely -= 1 # if user has already installed some initial models, then don't patronize them # by showing more recommendations - show_recommended = len(self.installed_models)==0 + show_recommended = len(self.installed_models) == 0 keys = [x for x in models.keys() if x in starters] widgets.update( - models_selected = self.add_widget_intelligent( + models_selected=self.add_widget_intelligent( MultiSelectColumns, columns=1, name="Install Starter Models", @@ -233,40 +227,43 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage): value=[ keys.index(x) for x in keys - if (show_recommended and models[x].recommended) \ - or (x in self.installed_models) + if (show_recommended and models[x].recommended) or (x in self.installed_models) ], max_height=len(starters) + 1, relx=4, scroll_exit=True, ), - models = keys, + models=keys, ) self.nextrely += 1 return widgets ############# Add a set of model install widgets ######## - def add_model_widgets(self, - model_type: ModelType, - window_width: int=120, - install_prompt: str=None, - exclude: set=set(), - )->dict[str,npyscreen.widget]: - '''Generic code to create model selection widgets''' + def add_model_widgets( + self, + model_type: ModelType, + window_width: int = 120, + install_prompt: str = None, + exclude: set = set(), + ) -> dict[str, npyscreen.widget]: + """Generic code to create model selection widgets""" widgets = dict() - model_list = [x for x in self.all_models if self.all_models[x].model_type==model_type and not x in exclude] + model_list = [x for x in self.all_models if self.all_models[x].model_type == model_type and not x in exclude] model_labels = [self.model_labels[x] for x in model_list] - show_recommended = len(self.installed_models)==0 + show_recommended = len(self.installed_models) == 0 if len(model_list) > 0: max_width = max([len(x) for x in model_labels]) - columns = window_width // (max_width+8) # 8 characters for "[x] " and padding - columns = min(len(model_list),columns) or 1 - prompt = install_prompt or f"Select the desired {model_type.value.title()} models to install. Unchecked models will be purged from disk." + columns = window_width // (max_width + 8) # 8 characters for "[x] " and padding + columns = min(len(model_list), columns) or 1 + prompt = ( + install_prompt + or f"Select the desired {model_type.value.title()} models to install. Unchecked models will be purged from disk." + ) widgets.update( - label1 = self.add_widget_intelligent( + label1=self.add_widget_intelligent( CenteredTitleText, name=prompt, editable=False, @@ -275,7 +272,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage): ) widgets.update( - models_selected = self.add_widget_intelligent( + models_selected=self.add_widget_intelligent( MultiSelectColumns, columns=columns, name=f"Install {model_type} Models", @@ -283,21 +280,20 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage): value=[ model_list.index(x) for x in model_list - if (show_recommended and self.all_models[x].recommended) \ - or self.all_models[x].installed + if (show_recommended and self.all_models[x].recommended) or self.all_models[x].installed ], - max_height=len(model_list)//columns + 1, + max_height=len(model_list) // columns + 1, relx=4, scroll_exit=True, ), - models = model_list, + models=model_list, ) self.nextrely += 1 widgets.update( - download_ids = self.add_widget_intelligent( + download_ids=self.add_widget_intelligent( TextBox, - name = "Additional URLs, or HuggingFace repo_ids to install (Space separated. Use shift-control-V to paste):", + name="Additional URLs, or HuggingFace repo_ids to install (Space separated. Use shift-control-V to paste):", max_height=4, scroll_exit=True, editable=True, @@ -306,16 +302,17 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage): return widgets ### Tab for arbitrary diffusers widgets ### - def add_pipeline_widgets(self, - model_type: ModelType=ModelType.Main, - window_width: int=120, - **kwargs, - )->dict[str,npyscreen.widget]: - '''Similar to add_model_widgets() but adds some additional widgets at the bottom - to support the autoload directory''' + def add_pipeline_widgets( + self, + model_type: ModelType = ModelType.Main, + window_width: int = 120, + **kwargs, + ) -> dict[str, npyscreen.widget]: + """Similar to add_model_widgets() but adds some additional widgets at the bottom + to support the autoload directory""" widgets = self.add_model_widgets( - model_type = model_type, - window_width = window_width, + model_type=model_type, + window_width=window_width, install_prompt=f"Additional {model_type.value.title()} models already installed.", **kwargs, ) @@ -324,7 +321,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage): def resize(self): super().resize() - if (s := self.starter_pipelines.get("models_selected")): + if s := self.starter_pipelines.get("models_selected"): keys = [x for x in self.all_models.keys() if x in self.starter_models] s.values = [self.model_labels[x] for x in keys] @@ -339,27 +336,27 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage): ] for group in widgets: - for k,v in group.items(): + for k, v in group.items(): try: v.hidden = True v.editable = False except: pass - for k,v in widgets[selected_tab].items(): + for k, v in widgets[selected_tab].items(): try: v.hidden = False - if not isinstance(v,(npyscreen.FixedText, npyscreen.TitleFixedText, CenteredTitleText)): + if not isinstance(v, (npyscreen.FixedText, npyscreen.TitleFixedText, CenteredTitleText)): v.editable = True except: pass self.__class__.current_tab = selected_tab # for persistence self.display() - def _get_model_labels(self) -> dict[str,str]: + def _get_model_labels(self) -> dict[str, str]: window_width, window_height = get_terminal_size() checkbox_width = 4 spacing_width = 2 - + models = self.all_models label_width = max([len(models[x].name) for x in models]) description_width = window_width - label_width - checkbox_width - spacing_width @@ -367,30 +364,28 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage): result = dict() for x in models.keys(): description = models[x].description - description = description[0 : description_width - 3] + "..." \ - if description and len(description) > description_width \ - else description if description else '' - result[x] = f"%-{label_width}s %s" % (models[x].name, description) + description = ( + description[0 : description_width - 3] + "..." + if description and len(description) > description_width + else description + if description + else "" + ) + result[x] = f"%-{label_width}s %s" % (models[x].name, description) return result - + def _get_columns(self) -> int: window_width, window_height = get_terminal_size() - cols = ( - 4 - if window_width > 240 - else 3 - if window_width > 160 - else 2 - if window_width > 80 - else 1 - ) + cols = 4 if window_width > 240 else 3 if window_width > 160 else 2 if window_width > 80 else 1 return min(cols, len(self.installed_models)) - def confirm_deletions(self, selections: InstallSelections)->bool: + def confirm_deletions(self, selections: InstallSelections) -> bool: remove_models = selections.remove_models if len(remove_models) > 0: mods = "\n".join([ModelManager.parse_key(x)[0] for x in remove_models]) - return npyscreen.notify_ok_cancel(f"These unchecked models will be deleted from disk. Continue?\n---------\n{mods}") + return npyscreen.notify_ok_cancel( + f"These unchecked models will be deleted from disk. Continue?\n---------\n{mods}" + ) else: return True @@ -399,20 +394,20 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage): app = self.parentApp if not self.confirm_deletions(app.install_selections): return - - self.monitor.entry_widget.buffer(['Processing...'],scroll_end=True) + + self.monitor.entry_widget.buffer(["Processing..."], scroll_end=True) self.ok_button.hidden = True self.display() - + # for communication with the subprocess parent_conn, child_conn = Pipe() p = Process( - target = process_and_execute, + target=process_and_execute, kwargs=dict( - opt = app.program_opts, - selections = app.install_selections, - conn_out = child_conn, - ) + opt=app.program_opts, + selections=app.install_selections, + conn_out=child_conn, + ), ) p.start() child_conn.close() @@ -429,7 +424,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage): self.parentApp.setNextForm(None) self.parentApp.user_cancelled = True self.editing = False - + def on_done(self): self.marshall_arguments() if not self.confirm_deletions(self.parentApp.install_selections): @@ -437,75 +432,79 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage): self.parentApp.setNextForm(None) self.parentApp.user_cancelled = False self.editing = False - + ########## This routine monitors the child process that is performing model installation and removal ##### def while_waiting(self): - '''Called during idle periods. Main task is to update the Log Messages box with messages - from the child process that does the actual installation/removal''' + """Called during idle periods. Main task is to update the Log Messages box with messages + from the child process that does the actual installation/removal""" c = self.subprocess_connection if not c: return - + monitor_widget = self.monitor.entry_widget while c.poll(): try: - data = c.recv_bytes().decode('utf-8') - data.strip('\n') + data = c.recv_bytes().decode("utf-8") + data.strip("\n") # processing child is requesting user input to select the # right configuration file - if data.startswith('*need v2 config'): - _,model_path,*_ = data.split(":",2) + if data.startswith("*need v2 config"): + _, model_path, *_ = data.split(":", 2) self._return_v2_config(model_path) # processing child is done - elif data=='*done*': + elif data == "*done*": self._close_subprocess_and_regenerate_form() break # update the log message box else: - data=make_printable(data) - data=data.replace('[A','') + data = make_printable(data) + data = data.replace("[A", "") monitor_widget.buffer( - textwrap.wrap(data, - width=monitor_widget.width, - subsequent_indent=' ', - ), - scroll_end=True + textwrap.wrap( + data, + width=monitor_widget.width, + subsequent_indent=" ", + ), + scroll_end=True, ) self.display() - except (EOFError,OSError): + except (EOFError, OSError): self.subprocess_connection = None - def _return_v2_config(self,model_path: str): + def _return_v2_config(self, model_path: str): c = self.subprocess_connection model_name = Path(model_path).name message = select_stable_diffusion_config_file(model_name=model_name) - c.send_bytes(message.encode('utf-8')) + c.send_bytes(message.encode("utf-8")) def _close_subprocess_and_regenerate_form(self): app = self.parentApp self.subprocess_connection.close() self.subprocess_connection = None - self.monitor.entry_widget.buffer(['** Action Complete **']) + self.monitor.entry_widget.buffer(["** Action Complete **"]) self.display() - + # rebuild the form, saving and restoring some of the fields that need to be preserved. saved_messages = self.monitor.entry_widget.values # autoload_dir = str(config.root_path / self.pipeline_models['autoload_directory'].value) # autoscan = self.pipeline_models['autoscan_on_startup'].value - + app.main_form = app.addForm( - "MAIN", addModelsForm, name="Install Stable Diffusion Models", multipage=self.multipage, + "MAIN", + addModelsForm, + name="Install Stable Diffusion Models", + multipage=self.multipage, ) app.switchForm("MAIN") - + app.main_form.monitor.entry_widget.values = saved_messages - app.main_form.monitor.entry_widget.buffer([''],scroll_end=True) + app.main_form.monitor.entry_widget.buffer([""], scroll_end=True) # app.main_form.pipeline_models['autoload_directory'].value = autoload_dir # app.main_form.pipeline_models['autoscan_on_startup'].value = autoscan - + def marshall_arguments(self): """ Assemble arguments and store as attributes of the application: @@ -520,21 +519,29 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage): all_models = self.all_models # Defined models (in INITIAL_CONFIG.yaml or models.yaml) to add/remove - ui_sections = [self.starter_pipelines, self.pipeline_models, - self.controlnet_models, self.lora_models, self.ti_models] + ui_sections = [ + self.starter_pipelines, + self.pipeline_models, + self.controlnet_models, + self.lora_models, + self.ti_models, + ] for section in ui_sections: - if not 'models_selected' in section: + if not "models_selected" in section: continue - selected = set([section['models'][x] for x in section['models_selected'].value]) + selected = set([section["models"][x] for x in section["models_selected"].value]) models_to_install = [x for x in selected if not self.all_models[x].installed] - models_to_remove = [x for x in section['models'] if x not in selected and self.all_models[x].installed] + models_to_remove = [x for x in section["models"] if x not in selected and self.all_models[x].installed] selections.remove_models.extend(models_to_remove) - selections.install_models.extend(all_models[x].path or all_models[x].repo_id \ - for x in models_to_install if all_models[x].path or all_models[x].repo_id) + selections.install_models.extend( + all_models[x].path or all_models[x].repo_id + for x in models_to_install + if all_models[x].path or all_models[x].repo_id + ) # models located in the 'download_ids" section for section in ui_sections: - if downloads := section.get('download_ids'): + if downloads := section.get("download_ids"): selections.install_models.extend(downloads.value.split()) # load directory and whether to scan on startup @@ -543,8 +550,9 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage): # self.parentApp.autoload_pending = False # selections.autoscan_on_startup = self.pipeline_models['autoscan_on_startup'].value + class AddModelApplication(npyscreen.NPSAppManaged): - def __init__(self,opt): + def __init__(self, opt): super().__init__() self.program_opts = opt self.user_cancelled = False @@ -554,76 +562,83 @@ class AddModelApplication(npyscreen.NPSAppManaged): def onStart(self): npyscreen.setTheme(npyscreen.Themes.DefaultTheme) self.main_form = self.addForm( - "MAIN", addModelsForm, name="Install Stable Diffusion Models", cycle_widgets=False, + "MAIN", + addModelsForm, + name="Install Stable Diffusion Models", + cycle_widgets=False, ) -class StderrToMessage(): + +class StderrToMessage: def __init__(self, connection: Connection): self.connection = connection - def write(self, data:str): - self.connection.send_bytes(data.encode('utf-8')) + def write(self, data: str): + self.connection.send_bytes(data.encode("utf-8")) def flush(self): pass + # -------------------------------------------------------- -def ask_user_for_prediction_type(model_path: Path, - tui_conn: Connection=None - )->SchedulerPredictionType: +def ask_user_for_prediction_type(model_path: Path, tui_conn: Connection = None) -> SchedulerPredictionType: if tui_conn: - logger.debug('Waiting for user response...') - return _ask_user_for_pt_tui(model_path, tui_conn) + logger.debug("Waiting for user response...") + return _ask_user_for_pt_tui(model_path, tui_conn) else: return _ask_user_for_pt_cmdline(model_path) -def _ask_user_for_pt_cmdline(model_path: Path)->SchedulerPredictionType: + +def _ask_user_for_pt_cmdline(model_path: Path) -> SchedulerPredictionType: choices = [SchedulerPredictionType.Epsilon, SchedulerPredictionType.VPrediction, None] print( -f""" + f""" Please select the type of the V2 checkpoint named {model_path.name}: [1] A model based on Stable Diffusion v2 trained on 512 pixel images (SD-2-base) [2] A model based on Stable Diffusion v2 trained on 768 pixel images (SD-2-768) [3] Skip this model and come back later. """ - ) + ) choice = None ok = False while not ok: try: - choice = input('select> ').strip() - choice = choices[int(choice)-1] + choice = input("select> ").strip() + choice = choices[int(choice) - 1] ok = True except (ValueError, IndexError): - print(f'{choice} is not a valid choice') + print(f"{choice} is not a valid choice") except EOFError: return return choice - -def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection)->SchedulerPredictionType: + + +def _ask_user_for_pt_tui(model_path: Path, tui_conn: Connection) -> SchedulerPredictionType: try: - tui_conn.send_bytes(f'*need v2 config for:{model_path}'.encode('utf-8')) + tui_conn.send_bytes(f"*need v2 config for:{model_path}".encode("utf-8")) # note that we don't do any status checking here - response = tui_conn.recv_bytes().decode('utf-8') + response = tui_conn.recv_bytes().decode("utf-8") if response is None: return None - elif response == 'epsilon': + elif response == "epsilon": return SchedulerPredictionType.epsilon - elif response == 'v': + elif response == "v": return SchedulerPredictionType.VPrediction - elif response == 'abort': - logger.info('Conversion aborted') + elif response == "abort": + logger.info("Conversion aborted") return None else: return response except: return None - + + # -------------------------------------------------------- -def process_and_execute(opt: Namespace, - selections: InstallSelections, - conn_out: Connection=None, - ): +def process_and_execute( + opt: Namespace, + selections: InstallSelections, + conn_out: Connection = None, +): # set up so that stderr is sent to conn_out if conn_out: translator = StderrToMessage(conn_out) @@ -633,66 +648,57 @@ def process_and_execute(opt: Namespace, logger.handlers.clear() logger.addHandler(logging.StreamHandler(translator)) - installer = ModelInstall(config, prediction_type_helper=lambda x: ask_user_for_prediction_type(x,conn_out)) + installer = ModelInstall(config, prediction_type_helper=lambda x: ask_user_for_prediction_type(x, conn_out)) installer.install(selections) if conn_out: - conn_out.send_bytes('*done*'.encode('utf-8')) + conn_out.send_bytes("*done*".encode("utf-8")) conn_out.close() -def do_listings(opt)->bool: + +def do_listings(opt) -> bool: """List installed models of various sorts, and return True if any were requested.""" model_manager = ModelManager(config.model_conf_path) - if opt.list_models == 'diffusers': + if opt.list_models == "diffusers": print("Diffuser models:") model_manager.print_models() - elif opt.list_models == 'controlnets': + elif opt.list_models == "controlnets": print("Installed Controlnet Models:") cnm = model_manager.list_controlnet_models() - print(textwrap.indent("\n".join([x for x in cnm if cnm[x]]),prefix=' ')) - elif opt.list_models == 'loras': + print(textwrap.indent("\n".join([x for x in cnm if cnm[x]]), prefix=" ")) + elif opt.list_models == "loras": print("Installed LoRA/LyCORIS Models:") cnm = model_manager.list_lora_models() - print(textwrap.indent("\n".join([x for x in cnm if cnm[x]]),prefix=' ')) - elif opt.list_models == 'tis': + print(textwrap.indent("\n".join([x for x in cnm if cnm[x]]), prefix=" ")) + elif opt.list_models == "tis": print("Installed Textual Inversion Embeddings:") cnm = model_manager.list_ti_models() - print(textwrap.indent("\n".join([x for x in cnm if cnm[x]]),prefix=' ')) + print(textwrap.indent("\n".join([x for x in cnm if cnm[x]]), prefix=" ")) else: return False return True + # -------------------------------------------------------- def select_and_download_models(opt: Namespace): - precision = ( - "float32" - if opt.full_precision - else choose_precision(torch.device(choose_torch_device())) - ) + precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device())) config.precision = precision helper = lambda x: ask_user_for_prediction_type(x) # if do_listings(opt): # pass - + installer = ModelInstall(config, prediction_type_helper=helper) if opt.list_models: installer.list_models(opt.list_models) elif opt.add or opt.delete: - selections = InstallSelections( - install_models = opt.add or [], - remove_models = opt.delete or [] - ) + selections = InstallSelections(install_models=opt.add or [], remove_models=opt.delete or []) installer.install(selections) elif opt.default_only: - selections = InstallSelections( - install_models = installer.default_model() - ) + selections = InstallSelections(install_models=installer.default_model()) installer.install(selections) elif opt.yes_to_all: - selections = InstallSelections( - install_models = installer.recommended_models() - ) + selections = InstallSelections(install_models=installer.recommended_models()) installer.install(selections) # this is where the TUI is called @@ -707,15 +713,15 @@ def select_and_download_models(opt: Namespace): try: installApp.run() except KeyboardInterrupt as e: - if hasattr(installApp,'main_form'): - if installApp.main_form.subprocess \ - and installApp.main_form.subprocess.is_alive(): - logger.info('Terminating subprocesses') + if hasattr(installApp, "main_form"): + if installApp.main_form.subprocess and installApp.main_form.subprocess.is_alive(): + logger.info("Terminating subprocesses") installApp.main_form.subprocess.terminate() installApp.main_form.subprocess = None raise e process_and_execute(opt, installApp.install_selections) + # ------------------------------------- def main(): parser = argparse.ArgumentParser(description="InvokeAI model downloader") @@ -770,19 +776,17 @@ def main(): help="path to root of install directory", ) opt = parser.parse_args() - + invoke_args = [] if opt.root: - invoke_args.extend(['--root',opt.root]) + invoke_args.extend(["--root", opt.root]) if opt.full_precision: - invoke_args.extend(['--precision','float32']) + invoke_args.extend(["--precision", "float32"]) config.parse_args(invoke_args) logger = InvokeAILogger().getLogger(config=config) if not config.model_conf_path.exists(): - logger.info( - "Your InvokeAI root directory is not set up. Calling invokeai-configure." - ) + logger.info("Your InvokeAI root directory is not set up. Calling invokeai-configure.") from invokeai.frontend.install import invokeai_configure invokeai_configure() @@ -800,20 +804,18 @@ def main(): logger.info("Goodbye! Come back soon.") except widget.NotEnoughSpaceForWidget as e: if str(e).startswith("Height of 1 allocated"): - logger.error( - "Insufficient vertical space for the interface. Please make your window taller and try again" - ) - input('Press any key to continue...') + logger.error("Insufficient vertical space for the interface. Please make your window taller and try again") + input("Press any key to continue...") except Exception as e: if str(e).startswith("addwstr"): logger.error( "Insufficient horizontal space for the interface. Please make your window wider and try again." ) else: - print(f'An exception has occurred: {str(e)} Details:') + print(f"An exception has occurred: {str(e)} Details:") print(traceback.format_exc(), file=sys.stderr) - input('Press any key to continue...') - + input("Press any key to continue...") + # ------------------------------------- if __name__ == "__main__": diff --git a/invokeai/frontend/install/widgets.py b/invokeai/frontend/install/widgets.py index 65f1589589..10da15bf13 100644 --- a/invokeai/frontend/install/widgets.py +++ b/invokeai/frontend/install/widgets.py @@ -14,40 +14,42 @@ import textwrap import npyscreen.wgmultiline as wgmultiline from npyscreen import fmPopup from shutil import get_terminal_size -from curses import BUTTON2_CLICKED,BUTTON3_CLICKED +from curses import BUTTON2_CLICKED, BUTTON3_CLICKED # minimum size for UIs MIN_COLS = 130 MIN_LINES = 38 + # ------------------------------------- def set_terminal_size(columns: int, lines: int): ts = get_terminal_size() - width = max(columns,ts.columns) - height = max(lines,ts.lines) + width = max(columns, ts.columns) + height = max(lines, ts.lines) OS = platform.uname().system if OS == "Windows": pass # not working reliably - ask user to adjust the window - #_set_terminal_size_powershell(width,height) + # _set_terminal_size_powershell(width,height) elif OS in ["Darwin", "Linux"]: - _set_terminal_size_unix(width,height) + _set_terminal_size_unix(width, height) # check whether it worked.... ts = get_terminal_size() pause = False if ts.columns < columns: - print('\033[1mThis window is too narrow for the user interface.\033[0m') + print("\033[1mThis window is too narrow for the user interface.\033[0m") pause = True if ts.lines < lines: - print('\033[1mThis window is too short for the user interface.\033[0m') + print("\033[1mThis window is too short for the user interface.\033[0m") pause = True if pause: - input('Maximize the window then press any key to continue..') + input("Maximize the window then press any key to continue..") + def _set_terminal_size_powershell(width: int, height: int): - script=f''' + script = f""" $pshost = get-host $pswindow = $pshost.ui.rawui $newsize = $pswindow.buffersize @@ -58,8 +60,9 @@ $newsize = $pswindow.windowsize $newsize.height = {height} $newsize.width = {width} $pswindow.windowsize = $newsize -''' - subprocess.run(["powershell","-Command","-"],input=script,text=True) +""" + subprocess.run(["powershell", "-Command", "-"], input=script, text=True) + def _set_terminal_size_unix(width: int, height: int): import fcntl @@ -67,15 +70,16 @@ def _set_terminal_size_unix(width: int, height: int): # These terminals accept the size command and report that the # size changed, but they lie!!! - for bad_terminal in ['TERMINATOR_UUID', 'ALACRITTY_WINDOW_ID']: + for bad_terminal in ["TERMINATOR_UUID", "ALACRITTY_WINDOW_ID"]: if os.environ.get(bad_terminal): return - + winsize = struct.pack("HHHH", height, width, 0, 0) fcntl.ioctl(sys.stdout.fileno(), termios.TIOCSWINSZ, winsize) sys.stdout.write("\x1b[8;{height};{width}t".format(height=height, width=width)) sys.stdout.flush() + def set_min_terminal_size(min_cols: int, min_lines: int): # make sure there's enough room for the ui term_cols, term_lines = get_terminal_size() @@ -85,6 +89,7 @@ def set_min_terminal_size(min_cols: int, min_lines: int): lines = max(term_lines, min_lines) set_terminal_size(cols, lines) + class IntSlider(npyscreen.Slider): def translate_value(self): stri = "%2d / %2d" % (self.value, self.out_of) @@ -92,23 +97,25 @@ class IntSlider(npyscreen.Slider): stri = stri.rjust(l) return stri + # ------------------------------------- # fix npyscreen form so that cursor wraps both forward and backward class CyclingForm(object): def find_previous_editable(self, *args): done = False - n = self.editw-1 + n = self.editw - 1 while not done: if self._widgets__[n].editable and not self._widgets__[n].hidden: self.editw = n done = True n -= 1 - if n<0: + if n < 0: if self.cycle_widgets: - n = len(self._widgets__)-1 + n = len(self._widgets__) - 1 else: done = True - + + # ------------------------------------- class CenteredTitleText(npyscreen.TitleText): def __init__(self, *args, **keywords): @@ -159,7 +166,8 @@ class FloatSlider(npyscreen.Slider): class FloatTitleSlider(npyscreen.TitleText): _entry_type = FloatSlider -class SelectColumnBase(): + +class SelectColumnBase: def make_contained_widgets(self): self._my_widgets = [] column_width = self.width // self.columns @@ -217,11 +225,12 @@ class SelectColumnBase(): row_no = rel_y // self._contained_widget_height self.cursor_line = column_no * column_height + row_no if bstate & curses.BUTTON1_DOUBLE_CLICKED: - if hasattr(self,'on_mouse_double_click'): + if hasattr(self, "on_mouse_double_click"): self.on_mouse_double_click(self.cursor_line) self.display() -class MultiSelectColumns( SelectColumnBase, npyscreen.MultiSelect): + +class MultiSelectColumns(SelectColumnBase, npyscreen.MultiSelect): def __init__(self, screen, columns: int = 1, values: list = [], **keywords): self.columns = columns self.value_cnt = len(values) @@ -231,15 +240,17 @@ class MultiSelectColumns( SelectColumnBase, npyscreen.MultiSelect): def on_mouse_double_click(self, cursor_line): self.h_select_toggle(cursor_line) -class SingleSelectWithChanged(npyscreen.SelectOne): - def __init__(self,*args,**kwargs): - super().__init__(*args,**kwargs) - def h_select(self,ch): +class SingleSelectWithChanged(npyscreen.SelectOne): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def h_select(self, ch): super().h_select(ch) if self.on_changed: self.on_changed(self.value) + class SingleSelectColumns(SelectColumnBase, SingleSelectWithChanged): def __init__(self, screen, columns: int = 1, values: list = [], **keywords): self.columns = columns @@ -254,23 +265,25 @@ class SingleSelectColumns(SelectColumnBase, SingleSelectWithChanged): def when_cursor_moved(self): self.h_select(self.cursor_line) - def h_cursor_line_right(self,ch): - self.h_exit_down('bye bye') + def h_cursor_line_right(self, ch): + self.h_exit_down("bye bye") + class TextBoxInner(npyscreen.MultiLineEdit): - - def __init__(self,*args,**kwargs): - super().__init__(*args,**kwargs) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) self.yank = None - self.handlers.update({ - "^A": self.h_cursor_to_start, - "^E": self.h_cursor_to_end, - "^K": self.h_kill, - "^F": self.h_cursor_right, - "^B": self.h_cursor_left, - "^Y": self.h_yank, - "^V": self.h_paste, - }) + self.handlers.update( + { + "^A": self.h_cursor_to_start, + "^E": self.h_cursor_to_end, + "^K": self.h_kill, + "^F": self.h_cursor_right, + "^B": self.h_cursor_left, + "^Y": self.h_yank, + "^V": self.h_paste, + } + ) def h_cursor_to_start(self, input): self.cursor_position = 0 @@ -279,27 +292,27 @@ class TextBoxInner(npyscreen.MultiLineEdit): self.cursor_position = len(self.value) def h_kill(self, input): - self.yank = self.value[self.cursor_position:] - self.value = self.value[:self.cursor_position] + self.yank = self.value[self.cursor_position :] + self.value = self.value[: self.cursor_position] def h_yank(self, input): if self.yank: self.paste(self.yank) def paste(self, text: str): - self.value = self.value[:self.cursor_position] + text + self.value[self.cursor_position:] + self.value = self.value[: self.cursor_position] + text + self.value[self.cursor_position :] self.cursor_position += len(text) - def h_paste(self, input: int=0): + def h_paste(self, input: int = 0): try: text = pyperclip.paste() except ModuleNotFoundError: text = "To paste with the mouse on Linux, please install the 'xclip' program." self.paste(text) - + def handle_mouse_event(self, mouse_event): mouse_id, rel_x, rel_y, z, bstate = self.interpret_mouse_event(mouse_event) - if bstate & (BUTTON2_CLICKED|BUTTON3_CLICKED): + if bstate & (BUTTON2_CLICKED | BUTTON3_CLICKED): self.h_paste() # def update(self, clear=True): @@ -320,77 +333,87 @@ class TextBoxInner(npyscreen.MultiLineEdit): # self.rely, self.relx + WIDTH, curses.ACS_VLINE, HEIGHT # ) - # # draw corners - # self.parent.curses_pad.addch( - # self.rely, - # self.relx, - # curses.ACS_ULCORNER, - # ) - # self.parent.curses_pad.addch( - # self.rely, - # self.relx + WIDTH, - # curses.ACS_URCORNER, - # ) - # self.parent.curses_pad.addch( - # self.rely + HEIGHT, - # self.relx, - # curses.ACS_LLCORNER, - # ) - # self.parent.curses_pad.addch( - # self.rely + HEIGHT, - # self.relx + WIDTH, - # curses.ACS_LRCORNER, - # ) + # # draw corners + # self.parent.curses_pad.addch( + # self.rely, + # self.relx, + # curses.ACS_ULCORNER, + # ) + # self.parent.curses_pad.addch( + # self.rely, + # self.relx + WIDTH, + # curses.ACS_URCORNER, + # ) + # self.parent.curses_pad.addch( + # self.rely + HEIGHT, + # self.relx, + # curses.ACS_LLCORNER, + # ) + # self.parent.curses_pad.addch( + # self.rely + HEIGHT, + # self.relx + WIDTH, + # curses.ACS_LRCORNER, + # ) + + # # fool our superclass into thinking drawing area is smaller - this is really hacky but it seems to work + # (relx, rely, height, width) = (self.relx, self.rely, self.height, self.width) + # self.relx += 1 + # self.rely += 1 + # self.height -= 1 + # self.width -= 1 + # super().update(clear=False) + # (self.relx, self.rely, self.height, self.width) = (relx, rely, height, width) - # # fool our superclass into thinking drawing area is smaller - this is really hacky but it seems to work - # (relx, rely, height, width) = (self.relx, self.rely, self.height, self.width) - # self.relx += 1 - # self.rely += 1 - # self.height -= 1 - # self.width -= 1 - # super().update(clear=False) - # (self.relx, self.rely, self.height, self.width) = (relx, rely, height, width) class TextBox(npyscreen.BoxTitle): _contained_widget = TextBoxInner + class BufferBox(npyscreen.BoxTitle): _contained_widget = npyscreen.BufferPager + class ConfirmCancelPopup(fmPopup.ActionPopup): DEFAULT_COLUMNS = 100 + def on_ok(self): self.value = True + def on_cancel(self): self.value = False - + + class FileBox(npyscreen.BoxTitle): _contained_widget = npyscreen.Filename - + + class PrettyTextBox(npyscreen.BoxTitle): _contained_widget = TextBox - + + def _wrap_message_lines(message, line_length): lines = [] - for line in message.split('\n'): + for line in message.split("\n"): lines.extend(textwrap.wrap(line.rstrip(), line_length)) return lines - + + def _prepare_message(message): if isinstance(message, list) or isinstance(message, tuple): - return "\n".join([ s.rstrip() for s in message]) - #return "\n".join(message) + return "\n".join([s.rstrip() for s in message]) + # return "\n".join(message) else: return message - + + def select_stable_diffusion_config_file( - form_color: str='DANGER', - wrap:bool =True, - model_name:str='Unknown', + form_color: str = "DANGER", + wrap: bool = True, + model_name: str = "Unknown", ): message = f"Please select the correct base model for the V2 checkpoint named '{model_name}'. Press to skip installation." title = "CONFIG FILE SELECTION" - options=[ + options = [ "An SD v2.x base model (512 pixels; no 'parameterization:' line in its yaml file)", "An SD v2.x v-predictive model (768 pixels; 'parameterization: \"v\"' line in its yaml file)", "Skip installation for now and come back later", @@ -403,22 +426,22 @@ def select_stable_diffusion_config_file( lines=16, ) F.preserve_selected_widget = True - + mlw = F.add( wgmultiline.Pager, max_height=4, editable=False, ) - mlw_width = mlw.width-1 + mlw_width = mlw.width - 1 if wrap: message = _wrap_message_lines(message, mlw_width) mlw.values = message choice = F.add( npyscreen.SelectOne, - values = options, - value = [0], - max_height = len(options)+1, + values=options, + value=[0], + max_height=len(options) + 1, scroll_exit=True, ) @@ -426,6 +449,6 @@ def select_stable_diffusion_config_file( F.edit() if not F.value: return None - assert choice.value[0] in range(0,3),'invalid choice' - choices = ['epsilon','v','abort'] + assert choice.value[0] in range(0, 3), "invalid choice" + choices = ["epsilon", "v", "abort"] return choices[choice.value[0]] diff --git a/invokeai/frontend/legacy_launch_invokeai.py b/invokeai/frontend/legacy_launch_invokeai.py index 349fa5b945..e4509db6e5 100644 --- a/invokeai/frontend/legacy_launch_invokeai.py +++ b/invokeai/frontend/legacy_launch_invokeai.py @@ -2,18 +2,22 @@ import os import sys import argparse + def main(): parser = argparse.ArgumentParser() - parser.add_argument('--web', action='store_true') - opts,_ = parser.parse_known_args() + parser.add_argument("--web", action="store_true") + opts, _ = parser.parse_known_args() if opts.web: - sys.argv.pop(sys.argv.index('--web')) + sys.argv.pop(sys.argv.index("--web")) from invokeai.app.api_app import invoke_api + invoke_api() else: from invokeai.app.cli_app import invoke_cli + invoke_cli() -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/invokeai/frontend/merge/__init__.py b/invokeai/frontend/merge/__init__.py index f1fc66c39e..3a2e4474a5 100644 --- a/invokeai/frontend/merge/__init__.py +++ b/invokeai/frontend/merge/__init__.py @@ -2,4 +2,3 @@ Initialization file for invokeai.frontend.merge """ from .merge_diffusers import main as invokeai_merge_diffusers - diff --git a/invokeai/frontend/merge/merge_diffusers.py b/invokeai/frontend/merge/merge_diffusers.py index c20d913883..81a7b3f6c8 100644 --- a/invokeai/frontend/merge/merge_diffusers.py +++ b/invokeai/frontend/merge/merge_diffusers.py @@ -20,13 +20,17 @@ from omegaconf import OmegaConf import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.model_management import ( - ModelMerger, MergeInterpolationMethod, - ModelManager, ModelType, BaseModelType, + ModelMerger, + MergeInterpolationMethod, + ModelManager, + ModelType, + BaseModelType, ) from invokeai.frontend.install.widgets import FloatTitleSlider, TextBox, SingleSelectColumns config = InvokeAIAppConfig.get_config() + def _parse_args() -> Namespace: parser = argparse.ArgumentParser(description="InvokeAI model merging") parser.add_argument( @@ -134,14 +138,14 @@ class mergeModelsForm(npyscreen.FormMultiPageAction): self.base_select = self.add_widget_intelligent( SingleSelectColumns, values=[ - 'Models Built on SD-1.x', - 'Models Built on SD-2.x', + "Models Built on SD-1.x", + "Models Built on SD-2.x", ], value=[self.current_base], - columns = 4, - max_height = 2, + columns=4, + max_height=2, relx=8, - scroll_exit = True, + scroll_exit=True, ) self.base_select.on_changed = self._populate_models self.add_widget_intelligent( @@ -300,15 +304,11 @@ class mergeModelsForm(npyscreen.FormMultiPageAction): def validate_field_values(self) -> bool: bad_fields = [] model_names = self.model_names - selected_models = set( - (model_names[self.model1.value[0]], model_names[self.model2.value[0]]) - ) + selected_models = set((model_names[self.model1.value[0]], model_names[self.model2.value[0]])) if self.model3.value[0] > 0: selected_models.add(model_names[self.model3.value[0] - 1]) if len(selected_models) < 2: - bad_fields.append( - f"Please select two or three DIFFERENT models to compare. You selected {selected_models}" - ) + bad_fields.append(f"Please select two or three DIFFERENT models to compare. You selected {selected_models}") if len(bad_fields) > 0: message = "The following problems were detected and must be corrected:" for problem in bad_fields: @@ -318,7 +318,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction): else: return True - def get_model_names(self, base_model: BaseModelType=None) -> List[str]: + def get_model_names(self, base_model: BaseModelType = None) -> List[str]: model_names = [ info["name"] for info in self.model_manager.list_models(model_type=ModelType.Main, base_model=base_model) @@ -326,20 +326,21 @@ class mergeModelsForm(npyscreen.FormMultiPageAction): ] return sorted(model_names) - def _populate_models(self,value=None): + def _populate_models(self, value=None): base_model = tuple(BaseModelType)[value[0]] self.model_names = self.get_model_names(base_model) - + models_plus_none = self.model_names.copy() models_plus_none.insert(0, "None") self.model1.values = self.model_names self.model2.values = self.model_names self.model3.values = models_plus_none - + self.display() + class Mergeapp(npyscreen.NPSAppManaged): - def __init__(self, model_manager:ModelManager): + def __init__(self, model_manager: ModelManager): super().__init__() self.model_manager = model_manager @@ -367,9 +368,7 @@ def run_cli(args: Namespace): if not args.merged_model_name: args.merged_model_name = "+".join(args.model_names) - logger.info( - f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"' - ) + logger.info(f'No --merged_model_name provided. Defaulting to "{args.merged_model_name}"') model_manager = ModelManager(config.model_conf_path) assert ( @@ -383,7 +382,7 @@ def run_cli(args: Namespace): def main(): args = _parse_args() - config.parse_args(['--root',str(args.root_dir)]) + config.parse_args(["--root", str(args.root_dir)]) try: if args.front_end: @@ -392,13 +391,9 @@ def main(): run_cli(args) except widget.NotEnoughSpaceForWidget as e: if str(e).startswith("Height of 1 allocated"): - logger.error( - "You need to have at least two diffusers models defined in models.yaml in order to merge" - ) + logger.error("You need to have at least two diffusers models defined in models.yaml in order to merge") else: - logger.error( - "Not enough room for the user interface. Try making this window larger." - ) + logger.error("Not enough room for the user interface. Try making this window larger.") sys.exit(-1) except Exception as e: logger.error(e) diff --git a/invokeai/frontend/training/textual_inversion.py b/invokeai/frontend/training/textual_inversion.py index e1c7b3749f..25debf4bdc 100755 --- a/invokeai/frontend/training/textual_inversion.py +++ b/invokeai/frontend/training/textual_inversion.py @@ -23,16 +23,14 @@ from omegaconf import OmegaConf import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig -from ...backend.training import ( - do_textual_inversion_training, - parse_args -) +from ...backend.training import do_textual_inversion_training, parse_args TRAINING_DATA = "text-inversion-training-data" TRAINING_DIR = "text-inversion-output" CONF_FILE = "preferences.conf" config = None + class textualInversionForm(npyscreen.FormMultiPageAction): resolutions = [512, 768, 1024] lr_schedulers = [ @@ -111,9 +109,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction): npyscreen.TitleSelectOne, name="Learnable property:", values=self.learnable_properties, - value=self.learnable_properties.index( - saved_args.get("learnable_property", "object") - ), + value=self.learnable_properties.index(saved_args.get("learnable_property", "object")), max_height=4, scroll_exit=True, ) @@ -243,9 +239,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction): def initializer_changed(self): placeholder = self.placeholder_token.value self.prompt_token.value = f"(Trigger by using <{placeholder}> in your prompts)" - self.train_data_dir.value = str( - config.root_dir / TRAINING_DATA / placeholder - ) + self.train_data_dir.value = str(config.root_dir / TRAINING_DATA / placeholder) self.output_dir.value = str(config.root_dir / TRAINING_DIR / placeholder) self.resume_from_checkpoint.value = Path(self.output_dir.value).exists() @@ -254,9 +248,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction): self.parentApp.setNextForm(None) self.editing = False self.parentApp.ti_arguments = self.marshall_arguments() - npyscreen.notify( - "Launching textual inversion training. This will take a while..." - ) + npyscreen.notify("Launching textual inversion training. This will take a while...") else: self.editing = True @@ -266,13 +258,9 @@ class textualInversionForm(npyscreen.FormMultiPageAction): def validate_field_values(self) -> bool: bad_fields = [] if self.model.value is None: - bad_fields.append( - "Model Name must correspond to a known model in models.yaml" - ) + bad_fields.append("Model Name must correspond to a known model in models.yaml") if not re.match("^[a-zA-Z0-9.-]+$", self.placeholder_token.value): - bad_fields.append( - "Trigger term must only contain alphanumeric characters, the dot and hyphen" - ) + bad_fields.append("Trigger term must only contain alphanumeric characters, the dot and hyphen") if self.train_data_dir.value is None: bad_fields.append("Data Training Directory cannot be empty") if self.output_dir.value is None: @@ -288,16 +276,8 @@ class textualInversionForm(npyscreen.FormMultiPageAction): def get_model_names(self) -> Tuple[List[str], int]: conf = OmegaConf.load(config.root_dir / "configs/models.yaml") - model_names = [ - idx - for idx in sorted(list(conf.keys())) - if conf[idx].get("format", None) == "diffusers" - ] - defaults = [ - idx - for idx in range(len(model_names)) - if "default" in conf[model_names[idx]] - ] + model_names = [idx for idx in sorted(list(conf.keys())) if conf[idx].get("format", None) == "diffusers"] + defaults = [idx for idx in range(len(model_names)) if "default" in conf[model_names[idx]]] default = defaults[0] if len(defaults) > 0 else 0 return (model_names, default) @@ -310,9 +290,7 @@ class textualInversionForm(npyscreen.FormMultiPageAction): resolution=self.resolutions[self.resolution.value[0]], lr_scheduler=self.lr_schedulers[self.lr_scheduler.value[0]], mixed_precision=self.precisions[self.mixed_precision.value[0]], - learnable_property=self.learnable_properties[ - self.learnable_property.value[0] - ], + learnable_property=self.learnable_properties[self.learnable_property.value[0]], ) # all the strings and booleans @@ -374,9 +352,7 @@ def copy_to_embeddings_folder(args: dict): os.makedirs(destination, exist_ok=True) logger.info(f"Training completed. Copying learned_embeds.bin into {str(destination)}") shutil.copy(source, destination) - if ( - input("Delete training logs and intermediate checkpoints? [y] ") or "y" - ).startswith(("y", "Y")): + if (input("Delete training logs and intermediate checkpoints? [y] ") or "y").startswith(("y", "Y")): shutil.rmtree(Path(args["output_dir"])) else: logger.info(f'Keeping {args["output_dir"]}') @@ -423,7 +399,7 @@ def do_front_end(args: Namespace): save_args(args) try: - do_textual_inversion_training(InvokeAIAppConfig.get_config(),**args) + do_textual_inversion_training(InvokeAIAppConfig.get_config(), **args) copy_to_embeddings_folder(args) except Exception as e: logger.error("An exception occurred during training. The exception was:") @@ -434,19 +410,19 @@ def do_front_end(args: Namespace): def main(): global config - + args = parse_args() config = InvokeAIAppConfig.get_config() # change root if needed if args.root_dir: config.root = args.root_dir - + try: if args.front_end: do_front_end(args) else: - do_textual_inversion_training(config,**vars(args)) + do_textual_inversion_training(config, **vars(args)) except AssertionError as e: logger.error(e) sys.exit(-1) @@ -454,13 +430,9 @@ def main(): pass except (widget.NotEnoughSpaceForWidget, Exception) as e: if str(e).startswith("Height of 1 allocated"): - logger.error( - "You need to have at least one diffusers models defined in models.yaml in order to train" - ) + logger.error("You need to have at least one diffusers models defined in models.yaml in order to train") elif str(e).startswith("addwstr"): - logger.error( - "Not enough window space for the interface. Please make your window larger and try again." - ) + logger.error("Not enough window space for the interface. Please make your window larger and try again.") else: logger.error(e) sys.exit(-1) diff --git a/invokeai/frontend/web/__init__.py b/invokeai/frontend/web/__init__.py index 010129ece2..e9758b27b6 100644 --- a/invokeai/frontend/web/__init__.py +++ b/invokeai/frontend/web/__init__.py @@ -1,3 +1,3 @@ -''' +""" Initialization file for invokeai.frontend.web -''' +""" diff --git a/invokeai/version/__init__.py b/invokeai/version/__init__.py index 01ef84ea4d..21dfcad3ca 100644 --- a/invokeai/version/__init__.py +++ b/invokeai/version/__init__.py @@ -6,6 +6,7 @@ from .invokeai_version import __version__ __app_id__ = "invoke-ai/InvokeAI" __app_name__ = "InvokeAI" + def _ignore_xformers_triton_message_on_windows(): import logging diff --git a/notebooks/notebook_helpers.py b/notebooks/notebook_helpers.py index 663b212ac5..b87f18ba18 100644 --- a/notebooks/notebook_helpers.py +++ b/notebooks/notebook_helpers.py @@ -2,6 +2,7 @@ from torchvision.datasets.utils import download_url from ldm.util import instantiate_from_config import torch import os + # todo ? from google.colab import files from IPython.display import Image as ipyimg @@ -16,21 +17,21 @@ import time from omegaconf import OmegaConf from ldm.invoke.devices import choose_torch_device -def download_models(mode): +def download_models(mode): if mode == "superresolution": # this is the small bsr light model - url_conf = 'https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1' - url_ckpt = 'https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1' + url_conf = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1" + url_ckpt = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1" - path_conf = 'logs/diffusion/superresolution_bsr/configs/project.yaml' - path_ckpt = 'logs/diffusion/superresolution_bsr/checkpoints/last.ckpt' + path_conf = "logs/diffusion/superresolution_bsr/configs/project.yaml" + path_ckpt = "logs/diffusion/superresolution_bsr/checkpoints/last.ckpt" download_url(url_conf, path_conf) download_url(url_ckpt, path_ckpt) - path_conf = path_conf + '/?dl=1' # fix it - path_ckpt = path_ckpt + '/?dl=1' # fix it + path_conf = path_conf + "/?dl=1" # fix it + path_ckpt = path_ckpt + "/?dl=1" # fix it return path_conf, path_ckpt else: @@ -62,20 +63,20 @@ def get_custom_cond(mode): if mode == "superresolution": uploaded_img = files.upload() filename = next(iter(uploaded_img)) - name, filetype = filename.split(".") # todo assumes just one dot in name ! + name, filetype = filename.split(".") # todo assumes just one dot in name ! os.rename(f"{filename}", f"{dest}/{mode}/custom_{name}.{filetype}") elif mode == "text_conditional": - w = widgets.Text(value='A cake with cream!', disabled=True) + w = widgets.Text(value="A cake with cream!", disabled=True) display(w) - with open(f"{dest}/{mode}/custom_{w.value[:20]}.txt", 'w') as f: + with open(f"{dest}/{mode}/custom_{w.value[:20]}.txt", "w") as f: f.write(w.value) elif mode == "class_conditional": w = widgets.IntSlider(min=0, max=1000) display(w) - with open(f"{dest}/{mode}/custom.txt", 'w') as f: + with open(f"{dest}/{mode}/custom.txt", "w") as f: f.write(w.value) else: @@ -94,11 +95,7 @@ def select_cond_path(mode): path = os.path.join(path, mode) onlyfiles = [f for f in sorted(os.listdir(path))] - selected = widgets.RadioButtons( - options=onlyfiles, - description='Select conditioning:', - disabled=False - ) + selected = widgets.RadioButtons(options=onlyfiles, description="Select conditioning:", disabled=False) display(selected) selected_path = os.path.join(path, selected.value) return selected_path @@ -113,9 +110,9 @@ def get_cond(mode, selected_path): c = Image.open(selected_path) c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0) c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]], antialias=True) - c_up = rearrange(c_up, '1 c h w -> 1 h w c') - c = rearrange(c, '1 c h w -> 1 h w c') - c = 2. * c - 1. + c_up = rearrange(c_up, "1 c h w -> 1 h w c") + c = rearrange(c, "1 c h w -> 1 h w c") + c = 2.0 * c - 1.0 device = choose_torch_device() c = c.to(device) @@ -130,7 +127,6 @@ def visualize_cond_img(path): def run(model, selected_path, task, custom_steps, resize_enabled=False, classifier_ckpt=None, global_step=None): - example = get_cond(task, selected_path) save_intermediate_vid = False @@ -138,10 +134,10 @@ def run(model, selected_path, task, custom_steps, resize_enabled=False, classifi masked = False guider = None ckwargs = None - mode = 'ddim' + mode = "ddim" ddim_use_x0_pred = False - temperature = 1. - eta = 1. + temperature = 1.0 + eta = 1.0 make_progrow = True custom_shape = None @@ -152,14 +148,17 @@ def run(model, selected_path, task, custom_steps, resize_enabled=False, classifi ks = 128 stride = 64 vqf = 4 # - model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride), - "vqf": vqf, - "patch_distributed_vq": True, - "tie_braker": False, - "clip_max_weight": 0.5, - "clip_min_weight": 0.01, - "clip_max_tie_weight": 0.5, - "clip_min_tie_weight": 0.01} + model.split_input_params = { + "ks": (ks, ks), + "stride": (stride, stride), + "vqf": vqf, + "patch_distributed_vq": True, + "tie_braker": False, + "clip_max_weight": 0.5, + "clip_min_weight": 0.01, + "clip_max_tie_weight": 0.5, + "clip_min_tie_weight": 0.01, + } else: if hasattr(model, "split_input_params"): delattr(model, "split_input_params") @@ -170,53 +169,112 @@ def run(model, selected_path, task, custom_steps, resize_enabled=False, classifi for n in range(n_runs): if custom_shape is not None: x_T = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device) - x_T = repeat(x_T, '1 c h w -> b c h w', b=custom_shape[0]) + x_T = repeat(x_T, "1 c h w -> b c h w", b=custom_shape[0]) - logs = make_convolutional_sample(example, model, - mode=mode, custom_steps=custom_steps, - eta=eta, swap_mode=False , masked=masked, - invert_mask=invert_mask, quantize_x0=False, - custom_schedule=None, decode_interval=10, - resize_enabled=resize_enabled, custom_shape=custom_shape, - temperature=temperature, noise_dropout=0., - corrector=guider, corrector_kwargs=ckwargs, x_T=x_T, save_intermediate_vid=save_intermediate_vid, - make_progrow=make_progrow,ddim_use_x0_pred=ddim_use_x0_pred - ) + logs = make_convolutional_sample( + example, + model, + mode=mode, + custom_steps=custom_steps, + eta=eta, + swap_mode=False, + masked=masked, + invert_mask=invert_mask, + quantize_x0=False, + custom_schedule=None, + decode_interval=10, + resize_enabled=resize_enabled, + custom_shape=custom_shape, + temperature=temperature, + noise_dropout=0.0, + corrector=guider, + corrector_kwargs=ckwargs, + x_T=x_T, + save_intermediate_vid=save_intermediate_vid, + make_progrow=make_progrow, + ddim_use_x0_pred=ddim_use_x0_pred, + ) return logs @torch.no_grad() -def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None, - mask=None, x0=None, quantize_x0=False, img_callback=None, - temperature=1., noise_dropout=0., score_corrector=None, - corrector_kwargs=None, x_T=None, log_every_t=None - ): - +def convsample_ddim( + model, + cond, + steps, + shape, + eta=1.0, + callback=None, + normals_sequence=None, + mask=None, + x0=None, + quantize_x0=False, + img_callback=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + x_T=None, + log_every_t=None, +): ddim = DDIMSampler(model) bs = shape[0] # dont know where this comes from but wayne shape = shape[1:] # cut batch dim print(f"Sampling with eta = {eta}; steps: {steps}") - samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback, - normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta, - mask=mask, x0=x0, temperature=temperature, verbose=False, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, x_T=x_T) + samples, intermediates = ddim.sample( + steps, + batch_size=bs, + shape=shape, + conditioning=cond, + callback=callback, + normals_sequence=normals_sequence, + quantize_x0=quantize_x0, + eta=eta, + mask=mask, + x0=x0, + temperature=temperature, + verbose=False, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + ) return samples, intermediates @torch.no_grad() -def make_convolutional_sample(batch, model, mode="vanilla", custom_steps=None, eta=1.0, swap_mode=False, masked=False, - invert_mask=True, quantize_x0=False, custom_schedule=None, decode_interval=1000, - resize_enabled=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None, - corrector_kwargs=None, x_T=None, save_intermediate_vid=False, make_progrow=True,ddim_use_x0_pred=False): +def make_convolutional_sample( + batch, + model, + mode="vanilla", + custom_steps=None, + eta=1.0, + swap_mode=False, + masked=False, + invert_mask=True, + quantize_x0=False, + custom_schedule=None, + decode_interval=1000, + resize_enabled=False, + custom_shape=None, + temperature=1.0, + noise_dropout=0.0, + corrector=None, + corrector_kwargs=None, + x_T=None, + save_intermediate_vid=False, + make_progrow=True, + ddim_use_x0_pred=False, +): log = dict() - z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key, - return_first_stage_outputs=True, - force_c_encode=not (hasattr(model, 'split_input_params') - and model.cond_stage_key == 'coordinates_bbox'), - return_original_cond=True) + z, c, x, xrec, xc = model.get_input( + batch, + model.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=not (hasattr(model, "split_input_params") and model.cond_stage_key == "coordinates_bbox"), + return_original_cond=True, + ) log_every_t = 1 if save_intermediate_vid else None @@ -231,30 +289,41 @@ def make_convolutional_sample(batch, model, mode="vanilla", custom_steps=None, e if ismap(xc): log["original_conditioning"] = model.to_rgb(xc) - if hasattr(model, 'cond_stage_key'): + if hasattr(model, "cond_stage_key"): log[model.cond_stage_key] = model.to_rgb(xc) else: log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x) if model.cond_stage_model: log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x) - if model.cond_stage_key =='class_label': + if model.cond_stage_key == "class_label": log[model.cond_stage_key] = xc[model.cond_stage_key] with model.ema_scope("Plotting"): t0 = time.time() img_cb = None - sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape, - eta=eta, - quantize_x0=quantize_x0, img_callback=img_cb, mask=None, x0=z0, - temperature=temperature, noise_dropout=noise_dropout, - score_corrector=corrector, corrector_kwargs=corrector_kwargs, - x_T=x_T, log_every_t=log_every_t) + sample, intermediates = convsample_ddim( + model, + c, + steps=custom_steps, + shape=z.shape, + eta=eta, + quantize_x0=quantize_x0, + img_callback=img_cb, + mask=None, + x0=z0, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + ) t1 = time.time() if ddim_use_x0_pred: - sample = intermediates['pred_x0'][-1] + sample = intermediates["pred_x0"][-1] x_sample = model.decode_first_stage(sample) diff --git a/pyproject.toml b/pyproject.toml index bbb96b58fd..af7feeb33b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -189,3 +189,6 @@ output = "coverage/index.xml" [flake8] max-line-length = 120 + +[tool.black] +line-length = 120 diff --git a/scripts/configure_invokeai.py b/scripts/configure_invokeai.py index 0226fa1c2a..61d32b6df5 100755 --- a/scripts/configure_invokeai.py +++ b/scripts/configure_invokeai.py @@ -4,6 +4,6 @@ import warnings from invokeai.frontend.install import invokeai_configure as configure -if __name__ == '__main__': +if __name__ == "__main__": warnings.warn("configure_invokeai.py is deprecated, running 'invokeai-configure'...", DeprecationWarning) configure() diff --git a/scripts/controlnet_legacy_txt2img_example.py b/scripts/controlnet_legacy_txt2img_example.py index eb299c9d47..8400cc0290 100644 --- a/scripts/controlnet_legacy_txt2img_example.py +++ b/scripts/controlnet_legacy_txt2img_example.py @@ -28,11 +28,12 @@ canny_image.show() print("loading base model stable-diffusion-1.5") model_config_path = os.getcwd() + "/../configs/models.yaml" model_manager = ModelManager(model_config_path) -model = model_manager.get_model('stable-diffusion-1.5') +model = model_manager.get_model("stable-diffusion-1.5") print("loading control model lllyasviel/sd-controlnet-canny") -canny_controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", - torch_dtype=torch.float16).to("cuda") +canny_controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16).to( + "cuda" +) print("testing Txt2Img() constructor with control_model arg") txt2img_canny = Txt2Img(model, control_model=canny_controlnet) @@ -49,6 +50,3 @@ outputs = txt2img_canny.generate( generate_output = next(outputs) out_image = generate_output.image out_image.show() - - - diff --git a/scripts/dream.py b/scripts/dream.py index 66c7600c6f..12176db41e 100755 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -3,8 +3,9 @@ import warnings from invokeai.frontend.CLI import invokeai_command_line_interface as main -warnings.warn("dream.py is being deprecated, please run invoke.py for the " - "new UI/API or legacy_api.py for the old API", - DeprecationWarning) -main() +warnings.warn( + "dream.py is being deprecated, please run invoke.py for the " "new UI/API or legacy_api.py for the old API", + DeprecationWarning, +) +main() diff --git a/scripts/images2prompt.py b/scripts/images2prompt.py index 625be83482..058fc0da40 100755 --- a/scripts/images2prompt.py +++ b/scripts/images2prompt.py @@ -1,12 +1,14 @@ #!/usr/bin/env python -'''This script reads the "Invoke" Stable Diffusion prompt embedded in files generated by invoke.py''' +"""This script reads the "Invoke" Stable Diffusion prompt embedded in files generated by invoke.py""" import sys -from PIL import Image,PngImagePlugin +from PIL import Image, PngImagePlugin if len(sys.argv) < 2: print("Usage: file2prompt.py ...") - print("This script opens up the indicated invoke.py-generated PNG file(s) and prints out the prompt used to generate them.") + print( + "This script opens up the indicated invoke.py-generated PNG file(s) and prints out the prompt used to generate them." + ) exit(-1) filenames = sys.argv[1:] @@ -14,17 +16,13 @@ for f in filenames: try: im = Image.open(f) try: - prompt = im.text['Dream'] + prompt = im.text["Dream"] except KeyError: - prompt = '' - print(f'{f}: {prompt}') + prompt = "" + print(f"{f}: {prompt}") except FileNotFoundError: - sys.stderr.write(f'{f} not found\n') + sys.stderr.write(f"{f} not found\n") continue except PermissionError: - sys.stderr.write(f'{f} could not be opened due to inadequate permissions\n') + sys.stderr.write(f"{f} could not be opened due to inadequate permissions\n") continue - - - - diff --git a/scripts/invokeai-cli.py b/scripts/invokeai-cli.py index bb67d9d540..b32a892261 100755 --- a/scripts/invokeai-cli.py +++ b/scripts/invokeai-cli.py @@ -3,18 +3,22 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) import logging -logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) + +logging.getLogger("xformers").addFilter(lambda record: "A matching Triton is not available" not in record.getMessage()) import os import sys + def main(): # Change working directory to the repo root os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) # TODO: Parse some top-level args here. from invokeai.app.cli_app import invoke_cli + invoke_cli() -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/scripts/invokeai-model-install.py b/scripts/invokeai-model-install.py index 97bb499812..b4b0dbc8ef 100644 --- a/scripts/invokeai-model-install.py +++ b/scripts/invokeai-model-install.py @@ -1,3 +1,3 @@ from invokeai.frontend.install.model_install import main -main() +main() diff --git a/scripts/invokeai-web.py b/scripts/invokeai-web.py index 9ac7ee5cb9..829cc4b911 100755 --- a/scripts/invokeai-web.py +++ b/scripts/invokeai-web.py @@ -3,18 +3,21 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) import logging -logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) + +logging.getLogger("xformers").addFilter(lambda record: "A matching Triton is not available" not in record.getMessage()) import os import sys + def main(): # Change working directory to the repo root os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from invokeai.app.api_app import invoke_api + invoke_api() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/scripts/make_models_markdown_table.py b/scripts/make_models_markdown_table.py index 128ced371d..8e7c528a46 100755 --- a/scripts/make_models_markdown_table.py +++ b/scripts/make_models_markdown_table.py @@ -1,23 +1,24 @@ #!/usr/bin/env python -''' +""" This script is used at release time to generate a markdown table describing the starter models. This text is then manually copied into 050_INSTALL_MODELS.md. -''' +""" from omegaconf import OmegaConf from pathlib import Path def main(): - initial_models_file = Path(__file__).parent / '../invokeai/configs/INITIAL_MODELS.yaml' + initial_models_file = Path(__file__).parent / "../invokeai/configs/INITIAL_MODELS.yaml" models = OmegaConf.load(initial_models_file) - print('|Model Name | HuggingFace Repo ID | Description | URL |') - print('|---------- | ---------- | ----------- | --- |') + print("|Model Name | HuggingFace Repo ID | Description | URL |") + print("|---------- | ---------- | ----------- | --- |") for model in models: repo_id = models[model].repo_id - url = f'https://huggingface.co/{repo_id}' - print(f'|{model}|{repo_id}|{models[model].description}|{url} |') + url = f"https://huggingface.co/{repo_id}" + print(f"|{model}|{repo_id}|{models[model].description}|{url} |") -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/scripts/orig_scripts/img2img.py b/scripts/orig_scripts/img2img.py index 9f74f25bf2..2601bc2562 100644 --- a/scripts/orig_scripts/img2img.py +++ b/scripts/orig_scripts/img2img.py @@ -18,7 +18,7 @@ from pytorch_lightning import seed_everything from ldm.util import instantiate_from_config from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.plms import PLMSSampler -from ldm.invoke.devices import choose_torch_device +from ldm.invoke.devices import choose_torch_device def chunk(it, size): @@ -55,7 +55,7 @@ def load_img(path): image = np.array(image).astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image) - return 2.*image - 1. + return 2.0 * image - 1.0 def main(): @@ -66,33 +66,24 @@ def main(): type=str, nargs="?", default="a painting of a virus monster playing guitar", - help="the prompt to render" + help="the prompt to render", ) - parser.add_argument( - "--init-img", - type=str, - nargs="?", - help="path to the input image" - ) + parser.add_argument("--init-img", type=str, nargs="?", help="path to the input image") parser.add_argument( - "--outdir", - type=str, - nargs="?", - help="dir to write results to", - default="outputs/img2img-samples" + "--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/img2img-samples" ) parser.add_argument( "--skip_grid", - action='store_true', + action="store_true", help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", ) parser.add_argument( "--skip_save", - action='store_true', + action="store_true", help="do not save indiviual samples. For speed measurements.", ) @@ -105,12 +96,12 @@ def main(): parser.add_argument( "--plms", - action='store_true', + action="store_true", help="use plms sampling", ) parser.add_argument( "--fixed_code", - action='store_true', + action="store_true", help="if enabled, uses the same starting code across all samples ", ) @@ -187,11 +178,7 @@ def main(): help="the seed (for reproducible sampling)", ) parser.add_argument( - "--precision", - type=str, - help="evaluate at this precision", - choices=["full", "autocast"], - default="autocast" + "--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast" ) opt = parser.parse_args() @@ -232,18 +219,18 @@ def main(): assert os.path.isfile(opt.init_img) init_image = load_img(opt.init_img).to(device) - init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) + init_image = repeat(init_image, "1 ... -> b ...", b=batch_size) init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space sampler.make_schedule(ddim_num_steps=opt.ddim_steps, ddim_eta=opt.ddim_eta, verbose=False) - assert 0. <= opt.strength <= 1., 'can only work with strength in [0.0, 1.0]' + assert 0.0 <= opt.strength <= 1.0, "can only work with strength in [0.0, 1.0]" t_enc = int(opt.strength * opt.ddim_steps) print(f"target t_enc is {t_enc} steps") precision_scope = autocast if opt.precision == "autocast" else nullcontext - if device.type in ['mps', 'cpu']: - precision_scope = nullcontext # have to use f32 on mps + if device.type in ["mps", "cpu"]: + precision_scope = nullcontext # have to use f32 on mps with torch.no_grad(): with precision_scope(device.type): with model.ema_scope(): @@ -259,37 +246,42 @@ def main(): c = model.get_learned_conditioning(prompts) # encode (scaled latent) - z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device)) + z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(device)) # decode it - samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale, - unconditional_conditioning=uc,) + samples = sampler.decode( + z_enc, + c, + t_enc, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + ) x_samples = model.decode_first_stage(samples) x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) if not opt.skip_save: for x_sample in x_samples: - x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c") Image.fromarray(x_sample.astype(np.uint8)).save( - os.path.join(sample_path, f"{base_count:05}.png")) + os.path.join(sample_path, f"{base_count:05}.png") + ) base_count += 1 all_samples.append(x_samples) if not opt.skip_grid: # additionally, save as grid grid = torch.stack(all_samples, 0) - grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = rearrange(grid, "n b c h w -> (n b) c h w") grid = make_grid(grid, nrow=n_rows) # to image - grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() - Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy() + Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png")) grid_count += 1 toc = time.time() - print(f"Your samples are ready and waiting for you here: \n{outpath} \n" - f" \nEnjoy.") + print(f"Your samples are ready and waiting for you here: \n{outpath} \n" f" \nEnjoy.") if __name__ == "__main__": diff --git a/scripts/orig_scripts/inpaint.py b/scripts/orig_scripts/inpaint.py index b8245db322..8dfbbfb045 100644 --- a/scripts/orig_scripts/inpaint.py +++ b/scripts/orig_scripts/inpaint.py @@ -8,25 +8,26 @@ from main import instantiate_from_config from ldm.models.diffusion.ddim import DDIMSampler from ldm.invoke.devices import choose_torch_device + def make_batch(image, mask, device): image = np.array(Image.open(image).convert("RGB")) - image = image.astype(np.float32)/255.0 - image = image[None].transpose(0,3,1,2) + image = image.astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image) mask = np.array(Image.open(mask).convert("L")) - mask = mask.astype(np.float32)/255.0 - mask = mask[None,None] + mask = mask.astype(np.float32) / 255.0 + mask = mask[None, None] mask[mask < 0.5] = 0 mask[mask >= 0.5] = 1 mask = torch.from_numpy(mask) - masked_image = (1-mask)*image + masked_image = (1 - mask) * image batch = {"image": image, "mask": mask, "masked_image": masked_image} for k in batch: batch[k] = batch[k].to(device=device) - batch[k] = batch[k]*2.0-1.0 + batch[k] = batch[k] * 2.0 - 1.0 return batch @@ -58,11 +59,10 @@ if __name__ == "__main__": config = OmegaConf.load("models/ldm/inpainting_big/config.yaml") model = instantiate_from_config(config.model) - model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"], - strict=False) + model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"], strict=False) - device = choose_torch_device() - model = model.to(device) + device = choose_torch_device() + model = model.to(device) sampler = DDIMSampler(model) os.makedirs(opt.outdir, exist_ok=True) @@ -74,25 +74,19 @@ if __name__ == "__main__": # encode masked image and concat downsampled mask c = model.cond_stage_model.encode(batch["masked_image"]) - cc = torch.nn.functional.interpolate(batch["mask"], - size=c.shape[-2:]) + cc = torch.nn.functional.interpolate(batch["mask"], size=c.shape[-2:]) c = torch.cat((c, cc), dim=1) - shape = (c.shape[1]-1,)+c.shape[2:] - samples_ddim, _ = sampler.sample(S=opt.steps, - conditioning=c, - batch_size=c.shape[0], - shape=shape, - verbose=False) + shape = (c.shape[1] - 1,) + c.shape[2:] + samples_ddim, _ = sampler.sample( + S=opt.steps, conditioning=c, batch_size=c.shape[0], shape=shape, verbose=False + ) x_samples_ddim = model.decode_first_stage(samples_ddim) - image = torch.clamp((batch["image"]+1.0)/2.0, - min=0.0, max=1.0) - mask = torch.clamp((batch["mask"]+1.0)/2.0, - min=0.0, max=1.0) - predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0, - min=0.0, max=1.0) + image = torch.clamp((batch["image"] + 1.0) / 2.0, min=0.0, max=1.0) + mask = torch.clamp((batch["mask"] + 1.0) / 2.0, min=0.0, max=1.0) + predicted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) - inpainted = (1-mask)*image+mask*predicted_image - inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255 + inpainted = (1 - mask) * image + mask * predicted_image + inpainted = inpainted.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255 Image.fromarray(inpainted.astype(np.uint8)).save(outpath) diff --git a/scripts/orig_scripts/knn2img.py b/scripts/orig_scripts/knn2img.py index e6eaaecab5..845613479b 100644 --- a/scripts/orig_scripts/knn2img.py +++ b/scripts/orig_scripts/knn2img.py @@ -59,29 +59,24 @@ def load_model_from_config(config, ckpt, verbose=False): class Searcher(object): - def __init__(self, database, retriever_version='ViT-L/14'): + def __init__(self, database, retriever_version="ViT-L/14"): assert database in DATABASES # self.database = self.load_database(database) self.database_name = database - self.searcher_savedir = f'data/rdm/searchers/{self.database_name}' - self.database_path = f'data/rdm/retrieval_databases/{self.database_name}' + self.searcher_savedir = f"data/rdm/searchers/{self.database_name}" + self.database_path = f"data/rdm/retrieval_databases/{self.database_name}" self.retriever = self.load_retriever(version=retriever_version) - self.database = {'embedding': [], - 'img_id': [], - 'patch_coords': []} + self.database = {"embedding": [], "img_id": [], "patch_coords": []} self.load_database() self.load_searcher() - def train_searcher(self, k, - metric='dot_product', - searcher_savedir=None): - - print('Start training searcher') - searcher = scann.scann_ops_pybind.builder(self.database['embedding'] / - np.linalg.norm(self.database['embedding'], axis=1)[:, np.newaxis], - k, metric) + def train_searcher(self, k, metric="dot_product", searcher_savedir=None): + print("Start training searcher") + searcher = scann.scann_ops_pybind.builder( + self.database["embedding"] / np.linalg.norm(self.database["embedding"], axis=1)[:, np.newaxis], k, metric + ) self.searcher = searcher.score_brute_force().build() - print('Finish training searcher') + print("Finish training searcher") if searcher_savedir is not None: print(f'Save trained searcher under "{searcher_savedir}"') @@ -91,36 +86,40 @@ class Searcher(object): def load_single_file(self, saved_embeddings): compressed = np.load(saved_embeddings) self.database = {key: compressed[key] for key in compressed.files} - print('Finished loading of clip embeddings.') + print("Finished loading of clip embeddings.") def load_multi_files(self, data_archive): out_data = {key: [] for key in self.database} - for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'): + for d in tqdm(data_archive, desc=f"Loading datapool from {len(data_archive)} individual files."): for key in d.files: out_data[key].append(d[key]) return out_data def load_database(self): - print(f'Load saved patch embedding from "{self.database_path}"') - file_content = glob.glob(os.path.join(self.database_path, '*.npz')) + file_content = glob.glob(os.path.join(self.database_path, "*.npz")) if len(file_content) == 1: self.load_single_file(file_content[0]) elif len(file_content) > 1: data = [np.load(f) for f in file_content] - prefetched_data = parallel_data_prefetch(self.load_multi_files, data, - n_proc=min(len(data), cpu_count()), target_data_type='dict') + prefetched_data = parallel_data_prefetch( + self.load_multi_files, data, n_proc=min(len(data), cpu_count()), target_data_type="dict" + ) - self.database = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in - self.database} + self.database = { + key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in self.database + } else: raise ValueError(f'No npz-files in specified path "{self.database_path}" is this directory existing?') print(f'Finished loading of retrieval database of length {self.database["embedding"].shape[0]}.') - def load_retriever(self, version='ViT-L/14', ): + def load_retriever( + self, + version="ViT-L/14", + ): model = FrozenClipImageEmbedder(model=version) if torch.cuda.is_available(): model.cuda() @@ -128,14 +127,14 @@ class Searcher(object): return model def load_searcher(self): - print(f'load searcher for database {self.database_name} from {self.searcher_savedir}') + print(f"load searcher for database {self.database_name} from {self.searcher_savedir}") self.searcher = scann.scann_ops_pybind.load_searcher(self.searcher_savedir) - print('Finished loading searcher.') + print("Finished loading searcher.") def search(self, x, k): - if self.searcher is None and self.database['embedding'].shape[0] < 2e4: - self.train_searcher(k) # quickly fit searcher on the fly for small databases - assert self.searcher is not None, 'Cannot search with uninitialized searcher' + if self.searcher is None and self.database["embedding"].shape[0] < 2e4: + self.train_searcher(k) # quickly fit searcher on the fly for small databases + assert self.searcher is not None, "Cannot search with uninitialized searcher" if isinstance(x, torch.Tensor): x = x.detach().cpu().numpy() if len(x.shape) == 3: @@ -146,17 +145,19 @@ class Searcher(object): nns, distances = self.searcher.search_batched(query_embeddings, final_num_neighbors=k) end = time.time() - out_embeddings = self.database['embedding'][nns] - out_img_ids = self.database['img_id'][nns] - out_pc = self.database['patch_coords'][nns] + out_embeddings = self.database["embedding"][nns] + out_img_ids = self.database["img_id"][nns] + out_pc = self.database["patch_coords"][nns] - out = {'nn_embeddings': out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis], - 'img_ids': out_img_ids, - 'patch_coords': out_pc, - 'queries': x, - 'exec_time': end - start, - 'nns': nns, - 'q_embeddings': query_embeddings} + out = { + "nn_embeddings": out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis], + "img_ids": out_img_ids, + "patch_coords": out_pc, + "queries": x, + "exec_time": end - start, + "nns": nns, + "q_embeddings": query_embeddings, + } return out @@ -173,20 +174,16 @@ if __name__ == "__main__": type=str, nargs="?", default="a painting of a virus monster playing guitar", - help="the prompt to render" + help="the prompt to render", ) parser.add_argument( - "--outdir", - type=str, - nargs="?", - help="dir to write results to", - default="outputs/txt2img-samples" + "--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/txt2img-samples" ) parser.add_argument( "--skip_grid", - action='store_true', + action="store_true", help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", ) @@ -206,7 +203,7 @@ if __name__ == "__main__": parser.add_argument( "--plms", - action='store_true', + action="store_true", help="use plms sampling", ) @@ -287,14 +284,14 @@ if __name__ == "__main__": parser.add_argument( "--database", type=str, - default='artbench-surrealism', + default="artbench-surrealism", choices=DATABASES, help="The database used for the search, only applied when --use_neighbors=True", ) parser.add_argument( "--use_neighbors", default=False, - action='store_true', + action="store_true", help="Include neighbors in addition to text prompt for conditioning", ) parser.add_argument( @@ -358,41 +355,43 @@ if __name__ == "__main__": uc = None if searcher is not None: nn_dict = searcher(c, opt.knn) - c = torch.cat([c, torch.from_numpy(nn_dict['nn_embeddings']).cuda()], dim=1) + c = torch.cat([c, torch.from_numpy(nn_dict["nn_embeddings"]).cuda()], dim=1) if opt.scale != 1.0: uc = torch.zeros_like(c) if isinstance(prompts, tuple): prompts = list(prompts) shape = [16, opt.H // 16, opt.W // 16] # note: currently hardcoded for f16 model - samples_ddim, _ = sampler.sample(S=opt.ddim_steps, - conditioning=c, - batch_size=c.shape[0], - shape=shape, - verbose=False, - unconditional_guidance_scale=opt.scale, - unconditional_conditioning=uc, - eta=opt.ddim_eta, - ) + samples_ddim, _ = sampler.sample( + S=opt.ddim_steps, + conditioning=c, + batch_size=c.shape[0], + shape=shape, + verbose=False, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + eta=opt.ddim_eta, + ) x_samples_ddim = model.decode_first_stage(samples_ddim) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) for x_sample in x_samples_ddim: - x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c") Image.fromarray(x_sample.astype(np.uint8)).save( - os.path.join(sample_path, f"{base_count:05}.png")) + os.path.join(sample_path, f"{base_count:05}.png") + ) base_count += 1 all_samples.append(x_samples_ddim) if not opt.skip_grid: # additionally, save as grid grid = torch.stack(all_samples, 0) - grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = rearrange(grid, "n b c h w -> (n b) c h w") grid = make_grid(grid, nrow=n_rows) # to image - grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() - Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy() + Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png")) grid_count += 1 print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.") diff --git a/scripts/orig_scripts/main.py b/scripts/orig_scripts/main.py index 6a88f84380..8269809fbc 100644 --- a/scripts/orig_scripts/main.py +++ b/scripts/orig_scripts/main.py @@ -25,15 +25,19 @@ from pytorch_lightning.utilities import rank_zero_info from ldm.data.base import Txt2ImgIterableBaseDataset from ldm.util import instantiate_from_config + def fix_func(orig): - if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + def new_func(*args, **kw): device = kw.get("device", "mps") - kw["device"]="cpu" + kw["device"] = "cpu" return orig(*args, **kw).to(device) + return new_func return orig + torch.rand = fix_func(torch.rand) torch.rand_like = fix_func(torch.rand_like) torch.randn = fix_func(torch.randn) @@ -43,18 +47,19 @@ torch.randint_like = fix_func(torch.randint_like) torch.bernoulli = fix_func(torch.bernoulli) torch.multinomial = fix_func(torch.multinomial) + def load_model_from_config(config, ckpt, verbose=False): - print(f'Loading model from {ckpt}') - pl_sd = torch.load(ckpt, map_location='cpu') - sd = pl_sd['state_dict'] + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + sd = pl_sd["state_dict"] config.model.params.ckpt_path = ckpt model = instantiate_from_config(config.model) m, u = model.load_state_dict(sd, strict=False) if len(m) > 0 and verbose: - print('missing keys:') + print("missing keys:") print(m) if len(u) > 0 and verbose: - print('unexpected keys:') + print("unexpected keys:") print(u) if torch.cuda.is_available(): @@ -66,132 +71,130 @@ def get_parser(**parser_kwargs): def str2bool(v): if isinstance(v, bool): return v - if v.lower() in ('yes', 'true', 't', 'y', '1'): + if v.lower() in ("yes", "true", "t", "y", "1"): return True - elif v.lower() in ('no', 'false', 'f', 'n', '0'): + elif v.lower() in ("no", "false", "f", "n", "0"): return False else: - raise argparse.ArgumentTypeError('Boolean value expected.') + raise argparse.ArgumentTypeError("Boolean value expected.") parser = argparse.ArgumentParser(**parser_kwargs) parser.add_argument( - '-n', - '--name', + "-n", + "--name", type=str, const=True, - default='', - nargs='?', - help='postfix for logdir', + default="", + nargs="?", + help="postfix for logdir", ) parser.add_argument( - '-r', - '--resume', + "-r", + "--resume", type=str, const=True, - default='', - nargs='?', - help='resume from logdir or checkpoint in logdir', + default="", + nargs="?", + help="resume from logdir or checkpoint in logdir", ) parser.add_argument( - '-b', - '--base', - nargs='*', - metavar='base_config.yaml', - help='paths to base configs. Loaded from left-to-right. ' - 'Parameters can be overwritten or added with command-line options of the form `--key value`.', + "-b", + "--base", + nargs="*", + metavar="base_config.yaml", + help="paths to base configs. Loaded from left-to-right. " + "Parameters can be overwritten or added with command-line options of the form `--key value`.", default=list(), ) parser.add_argument( - '-t', - '--train', + "-t", + "--train", type=str2bool, const=True, default=False, - nargs='?', - help='train', + nargs="?", + help="train", ) parser.add_argument( - '--no-test', + "--no-test", type=str2bool, const=True, default=False, - nargs='?', - help='disable test', + nargs="?", + help="disable test", ) + parser.add_argument("-p", "--project", help="name of new or path to existing project") parser.add_argument( - '-p', '--project', help='name of new or path to existing project' - ) - parser.add_argument( - '-d', - '--debug', + "-d", + "--debug", type=str2bool, - nargs='?', + nargs="?", const=True, default=False, - help='enable post-mortem debugging', + help="enable post-mortem debugging", ) parser.add_argument( - '-s', - '--seed', + "-s", + "--seed", type=int, default=23, - help='seed for seed_everything', + help="seed for seed_everything", ) parser.add_argument( - '-f', - '--postfix', + "-f", + "--postfix", type=str, - default='', - help='post-postfix for default name', + default="", + help="post-postfix for default name", ) parser.add_argument( - '-l', - '--logdir', + "-l", + "--logdir", type=str, - default='logs', - help='directory for logging dat shit', + default="logs", + help="directory for logging dat shit", ) parser.add_argument( - '--scale_lr', + "--scale_lr", type=str2bool, - nargs='?', + nargs="?", const=True, default=True, - help='scale base-lr by ngpu * batch_size * n_accumulate', + help="scale base-lr by ngpu * batch_size * n_accumulate", ) parser.add_argument( - '--datadir_in_name', + "--datadir_in_name", type=str2bool, - nargs='?', + nargs="?", const=True, default=True, - help='Prepend the final directory in the data_root to the output directory name', + help="Prepend the final directory in the data_root to the output directory name", ) parser.add_argument( - '--actual_resume', + "--actual_resume", type=str, - default='', - help='Path to model to actually resume from', + default="", + help="Path to model to actually resume from", ) parser.add_argument( - '--data_root', + "--data_root", type=str, required=True, - help='Path to directory with training images', + help="Path to directory with training images", ) parser.add_argument( - '--embedding_manager_ckpt', + "--embedding_manager_ckpt", type=str, - default='', - help='Initialize embedding manager from a checkpoint', + default="", + help="Initialize embedding manager from a checkpoint", ) parser.add_argument( - '--init_word', + "--init_word", type=str, - help='Word to use as source for initial token embedding.', + help="Word to use as source for initial token embedding.", ) return parser @@ -226,9 +229,7 @@ def worker_init_fn(_): if isinstance(dataset, Txt2ImgIterableBaseDataset): split_size = dataset.num_records // worker_info.num_workers # reset num_records to the true number to retain reliable length information - dataset.sample_ids = dataset.valid_ids[ - worker_id * split_size : (worker_id + 1) * split_size - ] + dataset.sample_ids = dataset.valid_ids[worker_id * split_size : (worker_id + 1) * split_size] current_id = np.random.choice(len(np.random.get_state()[1]), 1) return np.random.seed(np.random.get_state()[1][current_id] + worker_id) else: @@ -252,25 +253,19 @@ class DataModuleFromConfig(pl.LightningDataModule): super().__init__() self.batch_size = batch_size self.dataset_configs = dict() - self.num_workers = ( - num_workers if num_workers is not None else batch_size * 2 - ) + self.num_workers = num_workers if num_workers is not None else batch_size * 2 self.use_worker_init_fn = use_worker_init_fn if train is not None: - self.dataset_configs['train'] = train + self.dataset_configs["train"] = train self.train_dataloader = self._train_dataloader if validation is not None: - self.dataset_configs['validation'] = validation - self.val_dataloader = partial( - self._val_dataloader, shuffle=shuffle_val_dataloader - ) + self.dataset_configs["validation"] = validation + self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader) if test is not None: - self.dataset_configs['test'] = test - self.test_dataloader = partial( - self._test_dataloader, shuffle=shuffle_test_loader - ) + self.dataset_configs["test"] = test + self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader) if predict is not None: - self.dataset_configs['predict'] = predict + self.dataset_configs["predict"] = predict self.predict_dataloader = self._predict_dataloader self.wrap = wrap @@ -279,24 +274,19 @@ class DataModuleFromConfig(pl.LightningDataModule): instantiate_from_config(data_cfg) def setup(self, stage=None): - self.datasets = dict( - (k, instantiate_from_config(self.dataset_configs[k])) - for k in self.dataset_configs - ) + self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs) if self.wrap: for k in self.datasets: self.datasets[k] = WrappedDataset(self.datasets[k]) def _train_dataloader(self): - is_iterable_dataset = isinstance( - self.datasets['train'], Txt2ImgIterableBaseDataset - ) + is_iterable_dataset = isinstance(self.datasets["train"], Txt2ImgIterableBaseDataset) if is_iterable_dataset or self.use_worker_init_fn: init_fn = worker_init_fn else: init_fn = None return DataLoader( - self.datasets['train'], + self.datasets["train"], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False if is_iterable_dataset else True, @@ -304,15 +294,12 @@ class DataModuleFromConfig(pl.LightningDataModule): ) def _val_dataloader(self, shuffle=False): - if ( - isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) - or self.use_worker_init_fn - ): + if isinstance(self.datasets["validation"], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: init_fn = worker_init_fn else: init_fn = None return DataLoader( - self.datasets['validation'], + self.datasets["validation"], batch_size=self.batch_size, num_workers=self.num_workers, worker_init_fn=init_fn, @@ -320,9 +307,7 @@ class DataModuleFromConfig(pl.LightningDataModule): ) def _test_dataloader(self, shuffle=False): - is_iterable_dataset = isinstance( - self.datasets['train'], Txt2ImgIterableBaseDataset - ) + is_iterable_dataset = isinstance(self.datasets["train"], Txt2ImgIterableBaseDataset) if is_iterable_dataset or self.use_worker_init_fn: init_fn = worker_init_fn else: @@ -332,7 +317,7 @@ class DataModuleFromConfig(pl.LightningDataModule): shuffle = shuffle and (not is_iterable_dataset) return DataLoader( - self.datasets['test'], + self.datasets["test"], batch_size=self.batch_size, num_workers=self.num_workers, worker_init_fn=init_fn, @@ -340,15 +325,12 @@ class DataModuleFromConfig(pl.LightningDataModule): ) def _predict_dataloader(self, shuffle=False): - if ( - isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) - or self.use_worker_init_fn - ): + if isinstance(self.datasets["predict"], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn: init_fn = worker_init_fn else: init_fn = None return DataLoader( - self.datasets['predict'], + self.datasets["predict"], batch_size=self.batch_size, num_workers=self.num_workers, worker_init_fn=init_fn, @@ -356,9 +338,7 @@ class DataModuleFromConfig(pl.LightningDataModule): class SetupCallback(Callback): - def __init__( - self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config - ): + def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config): super().__init__() self.resume = resume self.now = now @@ -370,8 +350,8 @@ class SetupCallback(Callback): def on_keyboard_interrupt(self, trainer, pl_module): if trainer.global_rank == 0: - print('Summoning checkpoint.') - ckpt_path = os.path.join(self.ckptdir, 'last.ckpt') + print("Summoning checkpoint.") + ckpt_path = os.path.join(self.ckptdir, "last.ckpt") trainer.save_checkpoint(ckpt_path) def on_pretrain_routine_start(self, trainer, pl_module): @@ -381,36 +361,31 @@ class SetupCallback(Callback): os.makedirs(self.ckptdir, exist_ok=True) os.makedirs(self.cfgdir, exist_ok=True) - if 'callbacks' in self.lightning_config: - if ( - 'metrics_over_trainsteps_checkpoint' - in self.lightning_config['callbacks'] - ): + if "callbacks" in self.lightning_config: + if "metrics_over_trainsteps_checkpoint" in self.lightning_config["callbacks"]: os.makedirs( - os.path.join(self.ckptdir, 'trainstep_checkpoints'), + os.path.join(self.ckptdir, "trainstep_checkpoints"), exist_ok=True, ) - print('Project config') + print("Project config") print(OmegaConf.to_yaml(self.config)) OmegaConf.save( self.config, - os.path.join(self.cfgdir, '{}-project.yaml'.format(self.now)), + os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)), ) - print('Lightning config') + print("Lightning config") print(OmegaConf.to_yaml(self.lightning_config)) OmegaConf.save( - OmegaConf.create({'lightning': self.lightning_config}), - os.path.join( - self.cfgdir, '{}-lightning.yaml'.format(self.now) - ), + OmegaConf.create({"lightning": self.lightning_config}), + os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)), ) else: # ModelCheckpoint callback created log directory --- remove it if not self.resume and os.path.exists(self.logdir): dst, name = os.path.split(self.logdir) - dst = os.path.join(dst, 'child_runs', name) + dst = os.path.join(dst, "child_runs", name) os.makedirs(os.path.split(dst)[0], exist_ok=True) try: os.rename(self.logdir, dst) @@ -435,10 +410,8 @@ class ImageLogger(Callback): self.rescale = rescale self.batch_freq = batch_frequency self.max_images = max_images - self.logger_log_images = { } - self.log_steps = [ - 2**n for n in range(int(np.log2(self.batch_freq)) + 1) - ] + self.logger_log_images = {} + self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)] if not increase_log_steps: self.log_steps = [self.batch_freq] self.clamp = clamp @@ -448,10 +421,8 @@ class ImageLogger(Callback): self.log_first_step = log_first_step @rank_zero_only - def log_local( - self, save_dir, split, images, global_step, current_epoch, batch_idx - ): - root = os.path.join(save_dir, 'images', split) + def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx): + root = os.path.join(save_dir, "images", split) for k in images: grid = torchvision.utils.make_grid(images[k], nrow=4) if self.rescale: @@ -459,22 +430,16 @@ class ImageLogger(Callback): grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) grid = grid.numpy() grid = (grid * 255).astype(np.uint8) - filename = '{}_gs-{:06}_e-{:06}_b-{:06}.png'.format( - k, global_step, current_epoch, batch_idx - ) + filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx) path = os.path.join(root, filename) os.makedirs(os.path.split(path)[0], exist_ok=True) Image.fromarray(grid).save(path) - def log_img(self, pl_module, batch, batch_idx, split='train'): - check_idx = ( - batch_idx if self.log_on_batch_idx else pl_module.global_step - ) + def log_img(self, pl_module, batch, batch_idx, split="train"): + check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step if ( self.check_frequency(check_idx) - and hasattr( # batch_idx % self.batch_freq == 0 - pl_module, 'log_images' - ) + and hasattr(pl_module, "log_images") # batch_idx % self.batch_freq == 0 and callable(pl_module.log_images) and self.max_images > 0 ): @@ -485,9 +450,7 @@ class ImageLogger(Callback): pl_module.eval() with torch.no_grad(): - images = pl_module.log_images( - batch, split=split, **self.log_images_kwargs - ) + images = pl_module.log_images(batch, split=split, **self.log_images_kwargs) for k in images: N = min(images[k].shape[0], self.max_images) @@ -506,18 +469,16 @@ class ImageLogger(Callback): batch_idx, ) - logger_log_images = self.logger_log_images.get( - logger, lambda *args, **kwargs: None - ) + logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None) logger_log_images(pl_module, images, pl_module.global_step, split) if is_train: pl_module.train() def check_frequency(self, check_idx): - if ( - (check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps) - ) and (check_idx > 0 or self.log_first_step): + if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and ( + check_idx > 0 or self.log_first_step + ): try: self.log_steps.pop(0) except IndexError as e: @@ -526,23 +487,15 @@ class ImageLogger(Callback): return True return False - def on_train_batch_end( - self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None - ): - if not self.disabled and ( - pl_module.global_step > 0 or self.log_first_step - ): - self.log_img(pl_module, batch, batch_idx, split='train') + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None): + if not self.disabled and (pl_module.global_step > 0 or self.log_first_step): + self.log_img(pl_module, batch, batch_idx, split="train") - def on_validation_batch_end( - self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None - ): + def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None): if not self.disabled and pl_module.global_step > 0: - self.log_img(pl_module, batch, batch_idx, split='val') - if hasattr(pl_module, 'calibrate_grad_norm'): - if ( - pl_module.calibrate_grad_norm and batch_idx % 25 == 0 - ) and batch_idx > 0: + self.log_img(pl_module, batch, batch_idx, split="val") + if hasattr(pl_module, "calibrate_grad_norm"): + if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0: self.log_gradients(trainer, pl_module, batch_idx=batch_idx) @@ -562,19 +515,17 @@ class CUDACallback(Callback): try: epoch_time = trainer.training_type_plugin.reduce(epoch_time) - rank_zero_info(f'Average Epoch time: {epoch_time:.2f} seconds') + rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds") if torch.cuda.is_available(): - max_memory = ( - torch.cuda.max_memory_allocated(trainer.root_gpu) / 2**20 - ) + max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2**20 max_memory = trainer.training_type_plugin.reduce(max_memory) - rank_zero_info(f'Average Peak memory {max_memory:.2f}MiB') + rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB") except AttributeError: pass -class ModeSwapCallback(Callback): +class ModeSwapCallback(Callback): def __init__(self, swap_step=2000): super().__init__() self.is_frozen = False @@ -589,7 +540,8 @@ class ModeSwapCallback(Callback): self.is_frozen = False trainer.optimizers = [pl_module.configure_opt_model()] -if __name__ == '__main__': + +if __name__ == "__main__": # custom parser to specify config files, train, test and debug mode, # postfix, resume. # `--key value` arguments are interpreted as arguments to the trainer. @@ -631,7 +583,7 @@ if __name__ == '__main__': # params: # key: value - now = datetime.datetime.now().strftime('%Y-%m-%dT%H-%M-%S') + now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") # add cwd for convenience and to make classes in this file available when # running as `python main.py` @@ -644,50 +596,47 @@ if __name__ == '__main__': opt, unknown = parser.parse_known_args() if opt.name and opt.resume: raise ValueError( - '-n/--name and -r/--resume cannot be specified both.' - 'If you want to resume training in a new log folder, ' - 'use -n/--name in combination with --resume_from_checkpoint' + "-n/--name and -r/--resume cannot be specified both." + "If you want to resume training in a new log folder, " + "use -n/--name in combination with --resume_from_checkpoint" ) if opt.resume: if not os.path.exists(opt.resume): - raise ValueError('Cannot find {}'.format(opt.resume)) + raise ValueError("Cannot find {}".format(opt.resume)) if os.path.isfile(opt.resume): - paths = opt.resume.split('/') + paths = opt.resume.split("/") # idx = len(paths)-paths[::-1].index("logs")+1 # logdir = "/".join(paths[:idx]) - logdir = '/'.join(paths[:-2]) + logdir = "/".join(paths[:-2]) ckpt = opt.resume else: assert os.path.isdir(opt.resume), opt.resume - logdir = opt.resume.rstrip('/') - ckpt = os.path.join(logdir, 'checkpoints', 'last.ckpt') + logdir = opt.resume.rstrip("/") + ckpt = os.path.join(logdir, "checkpoints", "last.ckpt") opt.resume_from_checkpoint = ckpt - base_configs = sorted( - glob.glob(os.path.join(logdir, 'configs/*.yaml')) - ) + base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml"))) opt.base = base_configs + opt.base - _tmp = logdir.split('/') + _tmp = logdir.split("/") nowname = _tmp[-1] else: if opt.name: - name = '_' + opt.name + name = "_" + opt.name elif opt.base: cfg_fname = os.path.split(opt.base[0])[-1] cfg_name = os.path.splitext(cfg_fname)[0] - name = '_' + cfg_name + name = "_" + cfg_name else: - name = '' + name = "" if opt.datadir_in_name: now = os.path.basename(os.path.normpath(opt.data_root)) + now - nowname = now + name + opt.postfix logdir = os.path.join(opt.logdir, nowname) - ckptdir = os.path.join(logdir, 'checkpoints') - cfgdir = os.path.join(logdir, 'configs') + ckptdir = os.path.join(logdir, "checkpoints") + cfgdir = os.path.join(logdir, "configs") seed_everything(opt.seed) try: @@ -695,19 +644,19 @@ if __name__ == '__main__': configs = [OmegaConf.load(cfg) for cfg in opt.base] cli = OmegaConf.from_dotlist(unknown) config = OmegaConf.merge(*configs, cli) - lightning_config = config.pop('lightning', OmegaConf.create()) + lightning_config = config.pop("lightning", OmegaConf.create()) # merge trainer cli with config - trainer_config = lightning_config.get('trainer', OmegaConf.create()) + trainer_config = lightning_config.get("trainer", OmegaConf.create()) # default to ddp - trainer_config['accelerator'] = 'auto' + trainer_config["accelerator"] = "auto" for k in nondefault_trainer_args(opt): trainer_config[k] = getattr(opt, k) - if not 'gpus' in trainer_config: - del trainer_config['accelerator'] + if not "gpus" in trainer_config: + del trainer_config["accelerator"] cpu = True else: - gpuinfo = trainer_config['gpus'] - print(f'Running on GPUs {gpuinfo}') + gpuinfo = trainer_config["gpus"] + print(f"Running on GPUs {gpuinfo}") cpu = False trainer_opt = argparse.Namespace(**trainer_config) lightning_config.trainer = trainer_config @@ -715,9 +664,7 @@ if __name__ == '__main__': # model # config.model.params.personalization_config.params.init_word = opt.init_word - config.model.params.personalization_config.params.embedding_manager_ckpt = ( - opt.embedding_manager_ckpt - ) + config.model.params.personalization_config.params.embedding_manager_ckpt = opt.embedding_manager_ckpt if opt.init_word: config.model.params.personalization_config.params.initializer_words = [opt.init_word] @@ -731,142 +678,128 @@ if __name__ == '__main__': trainer_kwargs = dict() # default logger configs - def_logger = 'csv' - def_logger_target = 'CSVLogger' + def_logger = "csv" + def_logger_target = "CSVLogger" default_logger_cfgs = { - 'wandb': { - 'target': 'pytorch_lightning.loggers.WandbLogger', - 'params': { - 'name': nowname, - 'save_dir': logdir, - 'offline': opt.debug, - 'id': nowname, + "wandb": { + "target": "pytorch_lightning.loggers.WandbLogger", + "params": { + "name": nowname, + "save_dir": logdir, + "offline": opt.debug, + "id": nowname, }, }, def_logger: { - 'target': 'pytorch_lightning.loggers.' + def_logger_target, - 'params': { - 'name': def_logger, - 'save_dir': logdir, + "target": "pytorch_lightning.loggers." + def_logger_target, + "params": { + "name": def_logger, + "save_dir": logdir, }, }, } default_logger_cfg = default_logger_cfgs[def_logger] - if 'logger' in lightning_config: + if "logger" in lightning_config: logger_cfg = lightning_config.logger else: logger_cfg = OmegaConf.create() logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg) - trainer_kwargs['logger'] = instantiate_from_config(logger_cfg) + trainer_kwargs["logger"] = instantiate_from_config(logger_cfg) # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to # specify which metric is used to determine best models default_modelckpt_cfg = { - 'target': 'pytorch_lightning.callbacks.ModelCheckpoint', - 'params': { - 'dirpath': ckptdir, - 'filename': '{epoch:06}', - 'verbose': True, - 'save_last': True, + "target": "pytorch_lightning.callbacks.ModelCheckpoint", + "params": { + "dirpath": ckptdir, + "filename": "{epoch:06}", + "verbose": True, + "save_last": True, }, } - if hasattr(model, 'monitor'): - print(f'Monitoring {model.monitor} as checkpoint metric.') - default_modelckpt_cfg['params']['monitor'] = model.monitor - default_modelckpt_cfg['params']['save_top_k'] = 1 + if hasattr(model, "monitor"): + print(f"Monitoring {model.monitor} as checkpoint metric.") + default_modelckpt_cfg["params"]["monitor"] = model.monitor + default_modelckpt_cfg["params"]["save_top_k"] = 1 - if 'modelcheckpoint' in lightning_config: + if "modelcheckpoint" in lightning_config: modelckpt_cfg = lightning_config.modelcheckpoint else: modelckpt_cfg = OmegaConf.create() modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg) - print(f'Merged modelckpt-cfg: \n{modelckpt_cfg}') - if version.parse(pl.__version__) < version.parse('1.4.0'): - trainer_kwargs['checkpoint_callback'] = instantiate_from_config( - modelckpt_cfg - ) + print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}") + if version.parse(pl.__version__) < version.parse("1.4.0"): + trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg) # add callback which sets up log directory default_callbacks_cfg = { - 'setup_callback': { - 'target': 'main.SetupCallback', - 'params': { - 'resume': opt.resume, - 'now': now, - 'logdir': logdir, - 'ckptdir': ckptdir, - 'cfgdir': cfgdir, - 'config': config, - 'lightning_config': lightning_config, + "setup_callback": { + "target": "main.SetupCallback", + "params": { + "resume": opt.resume, + "now": now, + "logdir": logdir, + "ckptdir": ckptdir, + "cfgdir": cfgdir, + "config": config, + "lightning_config": lightning_config, }, }, - 'image_logger': { - 'target': 'main.ImageLogger', - 'params': { - 'batch_frequency': 750, - 'max_images': 4, - 'clamp': True, + "image_logger": { + "target": "main.ImageLogger", + "params": { + "batch_frequency": 750, + "max_images": 4, + "clamp": True, }, }, - 'learning_rate_logger': { - 'target': 'main.LearningRateMonitor', - 'params': { - 'logging_interval': 'step', + "learning_rate_logger": { + "target": "main.LearningRateMonitor", + "params": { + "logging_interval": "step", # "log_momentum": True }, }, - 'cuda_callback': {'target': 'main.CUDACallback'}, + "cuda_callback": {"target": "main.CUDACallback"}, } - if version.parse(pl.__version__) >= version.parse('1.4.0'): - default_callbacks_cfg.update( - {'checkpoint_callback': modelckpt_cfg} - ) + if version.parse(pl.__version__) >= version.parse("1.4.0"): + default_callbacks_cfg.update({"checkpoint_callback": modelckpt_cfg}) - if 'callbacks' in lightning_config: + if "callbacks" in lightning_config: callbacks_cfg = lightning_config.callbacks else: callbacks_cfg = OmegaConf.create() - if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg: + if "metrics_over_trainsteps_checkpoint" in callbacks_cfg: print( - 'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.' + "Caution: Saving checkpoints every n train steps without deleting. This might require some free space." ) default_metrics_over_trainsteps_ckpt_dict = { - 'metrics_over_trainsteps_checkpoint': { - 'target': 'pytorch_lightning.callbacks.ModelCheckpoint', - 'params': { - 'dirpath': os.path.join( - ckptdir, 'trainstep_checkpoints' - ), - 'filename': '{epoch:06}-{step:09}', - 'verbose': True, - 'save_top_k': -1, - 'every_n_train_steps': 10000, - 'save_weights_only': True, + "metrics_over_trainsteps_checkpoint": { + "target": "pytorch_lightning.callbacks.ModelCheckpoint", + "params": { + "dirpath": os.path.join(ckptdir, "trainstep_checkpoints"), + "filename": "{epoch:06}-{step:09}", + "verbose": True, + "save_top_k": -1, + "every_n_train_steps": 10000, + "save_weights_only": True, }, } } - default_callbacks_cfg.update( - default_metrics_over_trainsteps_ckpt_dict - ) + default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict) callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg) - if 'ignore_keys_callback' in callbacks_cfg and hasattr( - trainer_opt, 'resume_from_checkpoint' - ): - callbacks_cfg.ignore_keys_callback.params[ - 'ckpt_path' - ] = trainer_opt.resume_from_checkpoint - elif 'ignore_keys_callback' in callbacks_cfg: - del callbacks_cfg['ignore_keys_callback'] + if "ignore_keys_callback" in callbacks_cfg and hasattr(trainer_opt, "resume_from_checkpoint"): + callbacks_cfg.ignore_keys_callback.params["ckpt_path"] = trainer_opt.resume_from_checkpoint + elif "ignore_keys_callback" in callbacks_cfg: + del callbacks_cfg["ignore_keys_callback"] - trainer_kwargs['callbacks'] = [ - instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg - ] - trainer_kwargs['max_steps'] = trainer_opt.max_steps + trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg] + trainer_kwargs["max_steps"] = trainer_opt.max_steps - if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): - trainer_opt.accelerator = 'mps' + if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + trainer_opt.accelerator = "mps" trainer_opt.detect_anomaly = False trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs) @@ -882,11 +815,9 @@ if __name__ == '__main__': # lightning still takes care of proper multiprocessing though data.prepare_data() data.setup() - print('#### Data #####') + print("#### Data #####") for k in data.datasets: - print( - f'{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}' - ) + print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}") # configure learning rate bs, base_lr = ( @@ -894,24 +825,20 @@ if __name__ == '__main__': config.model.base_learning_rate, ) if not cpu: - gpus = str(lightning_config.trainer.gpus).strip(', ').split(',') + gpus = str(lightning_config.trainer.gpus).strip(", ").split(",") ngpu = len(gpus) else: ngpu = 1 - if 'accumulate_grad_batches' in lightning_config.trainer: - accumulate_grad_batches = ( - lightning_config.trainer.accumulate_grad_batches - ) + if "accumulate_grad_batches" in lightning_config.trainer: + accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches else: accumulate_grad_batches = 1 - print(f'accumulate_grad_batches = {accumulate_grad_batches}') - lightning_config.trainer.accumulate_grad_batches = ( - accumulate_grad_batches - ) + print(f"accumulate_grad_batches = {accumulate_grad_batches}") + lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches if opt.scale_lr: model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr print( - 'Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)'.format( + "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format( model.learning_rate, accumulate_grad_batches, ngpu, @@ -921,15 +848,15 @@ if __name__ == '__main__': ) else: model.learning_rate = base_lr - print('++++ NOT USING LR SCALING ++++') - print(f'Setting learning rate to {model.learning_rate:.2e}') + print("++++ NOT USING LR SCALING ++++") + print(f"Setting learning rate to {model.learning_rate:.2e}") # allow checkpointing via USR1 def melk(*args, **kwargs): # run all checkpoint hooks if trainer.global_rank == 0: - print('Summoning checkpoint.') - ckpt_path = os.path.join(ckptdir, 'last.ckpt') + print("Summoning checkpoint.") + ckpt_path = os.path.join(ckptdir, "last.ckpt") trainer.save_checkpoint(ckpt_path) def divein(*args, **kwargs): @@ -964,7 +891,7 @@ if __name__ == '__main__': # move newly created debug project to debug_runs if opt.debug and not opt.resume and trainer.global_rank == 0: dst, name = os.path.split(logdir) - dst = os.path.join(dst, 'debug_runs', name) + dst = os.path.join(dst, "debug_runs", name) os.makedirs(os.path.split(dst)[0], exist_ok=True) os.rename(logdir, dst) # if trainer.global_rank == 0: diff --git a/scripts/orig_scripts/merge_embeddings.py b/scripts/orig_scripts/merge_embeddings.py index 97e72f9128..c2cf1acb1a 100755 --- a/scripts/orig_scripts/merge_embeddings.py +++ b/scripts/orig_scripts/merge_embeddings.py @@ -7,21 +7,30 @@ from functools import partial import torch -def get_placeholder_loop(placeholder_string, embedder, use_bert): - new_placeholder = None +def get_placeholder_loop(placeholder_string, embedder, use_bert): + new_placeholder = None while True: if new_placeholder is None: - new_placeholder = input(f"Placeholder string {placeholder_string} was already used. Please enter a replacement string: ") + new_placeholder = input( + f"Placeholder string {placeholder_string} was already used. Please enter a replacement string: " + ) else: - new_placeholder = input(f"Placeholder string '{new_placeholder}' maps to more than a single token. Please enter another string: ") + new_placeholder = input( + f"Placeholder string '{new_placeholder}' maps to more than a single token. Please enter another string: " + ) - token = get_bert_token_for_string(embedder.tknz_fn, new_placeholder) if use_bert else get_clip_token_for_string(embedder.tokenizer, new_placeholder) + token = ( + get_bert_token_for_string(embedder.tknz_fn, new_placeholder) + if use_bert + else get_clip_token_for_string(embedder.tokenizer, new_placeholder) + ) if token is not None: return new_placeholder, token + def get_clip_token_for_string(tokenizer, string): batch_encoding = tokenizer( string, @@ -30,7 +39,7 @@ def get_clip_token_for_string(tokenizer, string): return_length=True, return_overflowing_tokens=False, padding="max_length", - return_tensors="pt" + return_tensors="pt", ) tokens = batch_encoding["input_ids"] @@ -40,6 +49,7 @@ def get_clip_token_for_string(tokenizer, string): return None + def get_bert_token_for_string(tokenizer, string): token = tokenizer(string) if torch.count_nonzero(token) == 3: @@ -49,22 +59,17 @@ def get_bert_token_for_string(tokenizer, string): if __name__ == "__main__": - parser = argparse.ArgumentParser() parser.add_argument( "--root_dir", type=str, - default='.', - help="Path to the InvokeAI install directory containing 'models', 'outputs' and 'configs'." + default=".", + help="Path to the InvokeAI install directory containing 'models', 'outputs' and 'configs'.", ) parser.add_argument( - "--manager_ckpts", - type=str, - nargs="+", - required=True, - help="Paths to a set of embedding managers to be merged." + "--manager_ckpts", type=str, nargs="+", required=True, help="Paths to a set of embedding managers to be merged." ) parser.add_argument( @@ -75,13 +80,14 @@ if __name__ == "__main__": ) parser.add_argument( - "-sd", "--use_bert", + "-sd", + "--use_bert", action="store_true", - help="Flag to denote that we are not merging stable diffusion embeddings" + help="Flag to denote that we are not merging stable diffusion embeddings", ) args = parser.parse_args() - Globals.root=args.root_dir + Globals.root = args.root_dir if args.use_bert: embedder = BERTEmbedder(n_embed=1280, n_layer=32).cuda() diff --git a/scripts/orig_scripts/sample_diffusion.py b/scripts/orig_scripts/sample_diffusion.py index 876fe3c364..9f08b6702a 100644 --- a/scripts/orig_scripts/sample_diffusion.py +++ b/scripts/orig_scripts/sample_diffusion.py @@ -10,12 +10,13 @@ from PIL import Image from ldm.models.diffusion.ddim import DDIMSampler from ldm.util import instantiate_from_config -rescale = lambda x: (x + 1.) / 2. +rescale = lambda x: (x + 1.0) / 2.0 + def custom_to_pil(x): x = x.detach().cpu() - x = torch.clamp(x, -1., 1.) - x = (x + 1.) / 2. + x = torch.clamp(x, -1.0, 1.0) + x = (x + 1.0) / 2.0 x = x.permute(1, 2, 0).numpy() x = (255 * x).astype(np.uint8) x = Image.fromarray(x) @@ -51,49 +52,51 @@ def logs2pil(logs, keys=["sample"]): @torch.no_grad() -def convsample(model, shape, return_intermediates=True, - verbose=True, - make_prog_row=False): - - +def convsample(model, shape, return_intermediates=True, verbose=True, make_prog_row=False): if not make_prog_row: - return model.p_sample_loop(None, shape, - return_intermediates=return_intermediates, verbose=verbose) + return model.p_sample_loop(None, shape, return_intermediates=return_intermediates, verbose=verbose) else: - return model.progressive_denoising( - None, shape, verbose=True - ) + return model.progressive_denoising(None, shape, verbose=True) @torch.no_grad() -def convsample_ddim(model, steps, shape, eta=1.0 - ): +def convsample_ddim(model, steps, shape, eta=1.0): ddim = DDIMSampler(model) bs = shape[0] shape = shape[1:] - samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, eta=eta, verbose=False,) + samples, intermediates = ddim.sample( + steps, + batch_size=bs, + shape=shape, + eta=eta, + verbose=False, + ) return samples, intermediates @torch.no_grad() -def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=None, eta=1.0,): - - +def make_convolutional_sample( + model, + batch_size, + vanilla=False, + custom_steps=None, + eta=1.0, +): log = dict() - shape = [batch_size, - model.model.diffusion_model.in_channels, - model.model.diffusion_model.image_size, - model.model.diffusion_model.image_size] + shape = [ + batch_size, + model.model.diffusion_model.in_channels, + model.model.diffusion_model.image_size, + model.model.diffusion_model.image_size, + ] with model.ema_scope("Plotting"): t0 = time.time() if vanilla: - sample, progrow = convsample(model, shape, - make_prog_row=True) + sample, progrow = convsample(model, shape, make_prog_row=True) else: - sample, intermediates = convsample_ddim(model, steps=custom_steps, shape=shape, - eta=eta) + sample, intermediates = convsample_ddim(model, steps=custom_steps, shape=shape, eta=eta) t1 = time.time() @@ -101,32 +104,32 @@ def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=Non log["sample"] = x_sample log["time"] = t1 - t0 - log['throughput'] = sample.shape[0] / (t1 - t0) + log["throughput"] = sample.shape[0] / (t1 - t0) print(f'Throughput for this batch: {log["throughput"]}') return log + def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None, n_samples=50000, nplog=None): if vanilla: - print(f'Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.') + print(f"Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.") else: - print(f'Using DDIM sampling with {custom_steps} sampling steps and eta={eta}') - + print(f"Using DDIM sampling with {custom_steps} sampling steps and eta={eta}") tstart = time.time() - n_saved = len(glob.glob(os.path.join(logdir,'*.png')))-1 + n_saved = len(glob.glob(os.path.join(logdir, "*.png"))) - 1 # path = logdir if model.cond_stage_model is None: all_images = [] print(f"Running unconditional sampling for {n_samples} samples") for _ in trange(n_samples // batch_size, desc="Sampling Batches (unconditional)"): - logs = make_convolutional_sample(model, batch_size=batch_size, - vanilla=vanilla, custom_steps=custom_steps, - eta=eta) + logs = make_convolutional_sample( + model, batch_size=batch_size, vanilla=vanilla, custom_steps=custom_steps, eta=eta + ) n_saved = save_logs(logs, logdir, n_saved=n_saved, key="sample") all_images.extend([custom_to_np(logs["sample"])]) if n_saved >= n_samples: - print(f'Finish after generating {n_saved} samples') + print(f"Finish after generating {n_saved} samples") break all_img = np.concatenate(all_images, axis=0) all_img = all_img[:n_samples] @@ -135,7 +138,7 @@ def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None np.savez(nppath, all_img) else: - raise NotImplementedError('Currently only sampling for unconditional models supported.') + raise NotImplementedError("Currently only sampling for unconditional models supported.") print(f"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.") @@ -168,58 +171,33 @@ def get_parser(): nargs="?", help="load from logdir or checkpoint in logdir", ) - parser.add_argument( - "-n", - "--n_samples", - type=int, - nargs="?", - help="number of samples to draw", - default=50000 - ) + parser.add_argument("-n", "--n_samples", type=int, nargs="?", help="number of samples to draw", default=50000) parser.add_argument( "-e", "--eta", type=float, nargs="?", help="eta for ddim sampling (0.0 yields deterministic sampling)", - default=1.0 + default=1.0, ) parser.add_argument( "-v", "--vanilla_sample", default=False, - action='store_true', + action="store_true", help="vanilla sampling (default option is DDIM sampling)?", ) + parser.add_argument("-l", "--logdir", type=str, nargs="?", help="extra logdir", default="none") parser.add_argument( - "-l", - "--logdir", - type=str, - nargs="?", - help="extra logdir", - default="none" - ) - parser.add_argument( - "-c", - "--custom_steps", - type=int, - nargs="?", - help="number of steps for ddim and fastdpm sampling", - default=50 - ) - parser.add_argument( - "--batch_size", - type=int, - nargs="?", - help="the bs", - default=10 + "-c", "--custom_steps", type=int, nargs="?", help="number of steps for ddim and fastdpm sampling", default=50 ) + parser.add_argument("--batch_size", type=int, nargs="?", help="the bs", default=10) return parser def load_model_from_config(config, sd): model = instantiate_from_config(config) - model.load_state_dict(sd,strict=False) + model.load_state_dict(sd, strict=False) model.cuda() model.eval() return model @@ -233,8 +211,7 @@ def load_model(config, ckpt, gpu, eval_mode): else: pl_sd = {"state_dict": None} global_step = None - model = load_model_from_config(config.model, - pl_sd["state_dict"]) + model = load_model_from_config(config.model, pl_sd["state_dict"]) return model, global_step @@ -253,9 +230,9 @@ if __name__ == "__main__": if os.path.isfile(opt.resume): # paths = opt.resume.split("/") try: - logdir = '/'.join(opt.resume.split('/')[:-1]) + logdir = "/".join(opt.resume.split("/")[:-1]) # idx = len(paths)-paths[::-1].index("logs")+1 - print(f'Logdir is {logdir}') + print(f"Logdir is {logdir}") except ValueError: paths = opt.resume.split("/") idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt @@ -278,7 +255,8 @@ if __name__ == "__main__": if opt.logdir != "none": locallog = logdir.split(os.sep)[-1] - if locallog == "": locallog = logdir.split(os.sep)[-2] + if locallog == "": + locallog = logdir.split(os.sep)[-2] print(f"Switching logdir from '{logdir}' to '{os.path.join(opt.logdir, locallog)}'") logdir = os.path.join(opt.logdir, locallog) @@ -301,13 +279,19 @@ if __name__ == "__main__": sampling_file = os.path.join(logdir, "sampling_config.yaml") sampling_conf = vars(opt) - with open(sampling_file, 'w') as f: + with open(sampling_file, "w") as f: yaml.dump(sampling_conf, f, default_flow_style=False) print(sampling_conf) - - run(model, imglogdir, eta=opt.eta, - vanilla=opt.vanilla_sample, n_samples=opt.n_samples, custom_steps=opt.custom_steps, - batch_size=opt.batch_size, nplog=numpylogdir) + run( + model, + imglogdir, + eta=opt.eta, + vanilla=opt.vanilla_sample, + n_samples=opt.n_samples, + custom_steps=opt.custom_steps, + batch_size=opt.batch_size, + nplog=numpylogdir, + ) print("done.") diff --git a/scripts/orig_scripts/train_searcher.py b/scripts/orig_scripts/train_searcher.py index 1e7904889c..7ba57f1165 100644 --- a/scripts/orig_scripts/train_searcher.py +++ b/scripts/orig_scripts/train_searcher.py @@ -13,21 +13,26 @@ def search_bruteforce(searcher): return searcher.score_brute_force().build() -def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k, - partioning_trainsize, num_leaves, num_leaves_to_search): - return searcher.tree(num_leaves=num_leaves, - num_leaves_to_search=num_leaves_to_search, - training_sample_size=partioning_trainsize). \ - score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build() +def search_partioned_ah( + searcher, dims_per_block, aiq_threshold, reorder_k, partioning_trainsize, num_leaves, num_leaves_to_search +): + return ( + searcher.tree( + num_leaves=num_leaves, num_leaves_to_search=num_leaves_to_search, training_sample_size=partioning_trainsize + ) + .score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold) + .reorder(reorder_k) + .build() + ) def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k): - return searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder( - reorder_k).build() + return ( + searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build() + ) + def load_datapool(dpath): - - def load_single_file(saved_embeddings): compressed = np.load(saved_embeddings) database = {key: compressed[key] for key in compressed.files} @@ -35,23 +40,26 @@ def load_datapool(dpath): def load_multi_files(data_archive): database = {key: [] for key in data_archive[0].files} - for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'): + for d in tqdm(data_archive, desc=f"Loading datapool from {len(data_archive)} individual files."): for key in d.files: database[key].append(d[key]) return database print(f'Load saved patch embedding from "{dpath}"') - file_content = glob.glob(os.path.join(dpath, '*.npz')) + file_content = glob.glob(os.path.join(dpath, "*.npz")) if len(file_content) == 1: data_pool = load_single_file(file_content[0]) elif len(file_content) > 1: data = [np.load(f) for f in file_content] - prefetched_data = parallel_data_prefetch(load_multi_files, data, - n_proc=min(len(data), cpu_count()), target_data_type='dict') + prefetched_data = parallel_data_prefetch( + load_multi_files, data, n_proc=min(len(data), cpu_count()), target_data_type="dict" + ) - data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()} + data_pool = { + key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys() + } else: raise ValueError(f'No npz-files in specified path "{dpath}" is this directory existing?') @@ -59,16 +67,17 @@ def load_datapool(dpath): return data_pool -def train_searcher(opt, - metric='dot_product', - partioning_trainsize=None, - reorder_k=None, - # todo tune - aiq_thld=0.2, - dims_per_block=2, - num_leaves=None, - num_leaves_to_search=None,): - +def train_searcher( + opt, + metric="dot_product", + partioning_trainsize=None, + reorder_k=None, + # todo tune + aiq_thld=0.2, + dims_per_block=2, + num_leaves=None, + num_leaves_to_search=None, +): data_pool = load_datapool(opt.database) k = opt.knn @@ -77,71 +86,83 @@ def train_searcher(opt, # normalize # embeddings = - searcher = scann.scann_ops_pybind.builder(data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric) - pool_size = data_pool['embedding'].shape[0] + searcher = scann.scann_ops_pybind.builder( + data_pool["embedding"] / np.linalg.norm(data_pool["embedding"], axis=1)[:, np.newaxis], k, metric + ) + pool_size = data_pool["embedding"].shape[0] - print(*(['#'] * 100)) - print('Initializing scaNN searcher with the following values:') - print(f'k: {k}') - print(f'metric: {metric}') - print(f'reorder_k: {reorder_k}') - print(f'anisotropic_quantization_threshold: {aiq_thld}') - print(f'dims_per_block: {dims_per_block}') - print(*(['#'] * 100)) - print('Start training searcher....') - print(f'N samples in pool is {pool_size}') + print(*(["#"] * 100)) + print("Initializing scaNN searcher with the following values:") + print(f"k: {k}") + print(f"metric: {metric}") + print(f"reorder_k: {reorder_k}") + print(f"anisotropic_quantization_threshold: {aiq_thld}") + print(f"dims_per_block: {dims_per_block}") + print(*(["#"] * 100)) + print("Start training searcher....") + print(f"N samples in pool is {pool_size}") # this reflects the recommended design choices proposed at # https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md if pool_size < 2e4: - print('Using brute force search.') + print("Using brute force search.") searcher = search_bruteforce(searcher) elif 2e4 <= pool_size and pool_size < 1e5: - print('Using asymmetric hashing search and reordering.') + print("Using asymmetric hashing search and reordering.") searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k) else: - print('Using using partioning, asymmetric hashing search and reordering.') + print("Using using partioning, asymmetric hashing search and reordering.") if not partioning_trainsize: - partioning_trainsize = data_pool['embedding'].shape[0] // 10 + partioning_trainsize = data_pool["embedding"].shape[0] // 10 if not num_leaves: num_leaves = int(np.sqrt(pool_size)) if not num_leaves_to_search: num_leaves_to_search = max(num_leaves // 20, 1) - print('Partitioning params:') - print(f'num_leaves: {num_leaves}') - print(f'num_leaves_to_search: {num_leaves_to_search}') + print("Partitioning params:") + print(f"num_leaves: {num_leaves}") + print(f"num_leaves_to_search: {num_leaves_to_search}") # self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k) - searcher = search_partioned_ah(searcher, dims_per_block, aiq_thld, reorder_k, - partioning_trainsize, num_leaves, num_leaves_to_search) + searcher = search_partioned_ah( + searcher, dims_per_block, aiq_thld, reorder_k, partioning_trainsize, num_leaves, num_leaves_to_search + ) - print('Finish training searcher') + print("Finish training searcher") searcher_savedir = opt.target_path os.makedirs(searcher_savedir, exist_ok=True) searcher.serialize(searcher_savedir) print(f'Saved trained searcher under "{searcher_savedir}"') -if __name__ == '__main__': + +if __name__ == "__main__": sys.path.append(os.getcwd()) parser = argparse.ArgumentParser() - parser.add_argument('--database', - '-d', - default='data/rdm/retrieval_databases/openimages', - type=str, - help='path to folder containing the clip feature of the database') - parser.add_argument('--target_path', - '-t', - default='data/rdm/searchers/openimages', - type=str, - help='path to the target folder where the searcher shall be stored.') - parser.add_argument('--knn', - '-k', - default=20, - type=int, - help='number of nearest neighbors, for which the searcher shall be optimized') + parser.add_argument( + "--database", + "-d", + default="data/rdm/retrieval_databases/openimages", + type=str, + help="path to folder containing the clip feature of the database", + ) + parser.add_argument( + "--target_path", + "-t", + default="data/rdm/searchers/openimages", + type=str, + help="path to the target folder where the searcher shall be stored.", + ) + parser.add_argument( + "--knn", + "-k", + default=20, + type=int, + help="number of nearest neighbors, for which the searcher shall be optimized", + ) - opt, _ = parser.parse_known_args() + opt, _ = parser.parse_known_args() - train_searcher(opt,) \ No newline at end of file + train_searcher( + opt, + ) diff --git a/scripts/orig_scripts/txt2img.py b/scripts/orig_scripts/txt2img.py index 0d350d2c73..58767c122d 100644 --- a/scripts/orig_scripts/txt2img.py +++ b/scripts/orig_scripts/txt2img.py @@ -15,10 +15,11 @@ from contextlib import contextmanager, nullcontext import k_diffusion as K import torch.nn as nn -from ldm.util import instantiate_from_config +from ldm.util import instantiate_from_config from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.plms import PLMSSampler -from ldm.invoke.devices import choose_torch_device +from ldm.invoke.devices import choose_torch_device + def chunk(it, size): it = iter(it) @@ -53,23 +54,19 @@ def main(): type=str, nargs="?", default="a painting of a virus monster playing guitar", - help="the prompt to render" + help="the prompt to render", ) parser.add_argument( - "--outdir", - type=str, - nargs="?", - help="dir to write results to", - default="outputs/txt2img-samples" + "--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/txt2img-samples" ) parser.add_argument( "--skip_grid", - action='store_true', + action="store_true", help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", ) parser.add_argument( "--skip_save", - action='store_true', + action="store_true", help="do not save individual samples. For speed measurements.", ) parser.add_argument( @@ -80,22 +77,22 @@ def main(): ) parser.add_argument( "--plms", - action='store_true', + action="store_true", help="use plms sampling", ) parser.add_argument( "--klms", - action='store_true', + action="store_true", help="use klms sampling", ) parser.add_argument( "--laion400m", - action='store_true', + action="store_true", help="uses the LAION400M model", ) parser.add_argument( "--fixed_code", - action='store_true', + action="store_true", help="if enabled, uses the same starting code across samples ", ) parser.add_argument( @@ -176,11 +173,7 @@ def main(): help="the seed (for reproducible sampling)", ) parser.add_argument( - "--precision", - type=str, - help="evaluate at this precision", - choices=["full", "autocast"], - default="autocast" + "--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast" ) opt = parser.parse_args() @@ -190,17 +183,17 @@ def main(): opt.ckpt = "models/ldm/text2img-large/model.ckpt" opt.outdir = "outputs/txt2img-samples-laion400m" - config = OmegaConf.load(f"{opt.config}") model = load_model_from_config(config, f"{opt.ckpt}") seed_everything(opt.seed) device = torch.device(choose_torch_device()) - model = model.to(device) + model = model.to(device) - #for klms + # for klms model_wrap = K.external.CompVisDenoiser(model) + class CFGDenoiser(nn.Module): def __init__(self, model): super().__init__() @@ -232,10 +225,10 @@ def main(): print(f"reading prompts from {opt.from_file}") with open(opt.from_file, "r") as f: data = f.read().splitlines() - if (len(data) >= batch_size): + if len(data) >= batch_size: data = list(chunk(data, batch_size)) else: - while (len(data) < batch_size): + while len(data) < batch_size: data.append(data[-1]) data = [data] @@ -247,14 +240,14 @@ def main(): start_code = None if opt.fixed_code: shape = [opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f] - if device.type == 'mps': - start_code = torch.randn(shape, device='cpu').to(device) + if device.type == "mps": + start_code = torch.randn(shape, device="cpu").to(device) else: torch.randn(shape, device=device) - precision_scope = autocast if opt.precision=="autocast" else nullcontext - if device.type in ['mps', 'cpu']: - precision_scope = nullcontext # have to use f32 on mps + precision_scope = autocast if opt.precision == "autocast" else nullcontext + if device.type in ["mps", "cpu"]: + precision_scope = nullcontext # have to use f32 on mps with torch.no_grad(): with precision_scope(device.type): with model.ema_scope(): @@ -271,23 +264,25 @@ def main(): shape = [opt.C, opt.H // opt.f, opt.W // opt.f] if not opt.klms: - samples_ddim, _ = sampler.sample(S=opt.ddim_steps, - conditioning=c, - batch_size=opt.n_samples, - shape=shape, - verbose=False, - unconditional_guidance_scale=opt.scale, - unconditional_conditioning=uc, - eta=opt.ddim_eta, - x_T=start_code) + samples_ddim, _ = sampler.sample( + S=opt.ddim_steps, + conditioning=c, + batch_size=opt.n_samples, + shape=shape, + verbose=False, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + eta=opt.ddim_eta, + x_T=start_code, + ) else: sigmas = model_wrap.get_sigmas(opt.ddim_steps) if start_code: x = start_code else: - x = torch.randn([opt.n_samples, *shape], device=device) * sigmas[0] # for GPU draw + x = torch.randn([opt.n_samples, *shape], device=device) * sigmas[0] # for GPU draw model_wrap_cfg = CFGDenoiser(model_wrap) - extra_args = {'cond': c, 'uncond': uc, 'cond_scale': opt.scale} + extra_args = {"cond": c, "uncond": uc, "cond_scale": opt.scale} samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args) x_samples_ddim = model.decode_first_stage(samples_ddim) @@ -295,9 +290,10 @@ def main(): if not opt.skip_save: for x_sample in x_samples_ddim: - x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c") Image.fromarray(x_sample.astype(np.uint8)).save( - os.path.join(sample_path, f"{base_count:05}.png")) + os.path.join(sample_path, f"{base_count:05}.png") + ) base_count += 1 if not opt.skip_grid: @@ -306,18 +302,17 @@ def main(): if not opt.skip_grid: # additionally, save as grid grid = torch.stack(all_samples, 0) - grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = rearrange(grid, "n b c h w -> (n b) c h w") grid = make_grid(grid, nrow=n_rows) # to image - grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() - Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png')) + grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy() + Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png")) grid_count += 1 toc = time.time() - print(f"Your samples are ready and waiting for you here: \n{outpath} \n" - f" \nEnjoy.") + print(f"Your samples are ready and waiting for you here: \n{outpath} \n" f" \nEnjoy.") if __name__ == "__main__": diff --git a/scripts/probe-model.py b/scripts/probe-model.py index 46fd86984c..2c73ea0d79 100755 --- a/scripts/probe-model.py +++ b/scripts/probe-model.py @@ -6,6 +6,3 @@ from invokeai.backend.model_management.model_probe import ModelProbe info = ModelProbe().probe(Path(sys.argv[1])) print(info) - - - diff --git a/scripts/pypi_helper.py b/scripts/pypi_helper.py index 4dccd6383d..6c1f9b9033 100755 --- a/scripts/pypi_helper.py +++ b/scripts/pypi_helper.py @@ -5,7 +5,8 @@ import requests from invokeai.version import __version__ local_version = str(__version__).replace("-", "") -package_name = 'InvokeAI' +package_name = "InvokeAI" + def get_pypi_versions(package_name=package_name) -> list[str]: """Get the versions of the package from PyPI""" diff --git a/scripts/scan_models_directory.py b/scripts/scan_models_directory.py index 778a6c5ed5..de7278f5b7 100755 --- a/scripts/scan_models_directory.py +++ b/scripts/scan_models_directory.py @@ -1,8 +1,8 @@ #!/usr/bin/env python -''' +""" Scan the models directory and print out a new models.yaml -''' +""" import os import sys @@ -11,49 +11,51 @@ import argparse from pathlib import Path from omegaconf import OmegaConf + def main(): parser = argparse.ArgumentParser(description="Model directory scanner") - parser.add_argument('models_directory') - parser.add_argument('--all-models', - default=False, - action='store_true', - help='If true, then generates stanzas for all models; otherwise just diffusers' - ) - + parser.add_argument("models_directory") + parser.add_argument( + "--all-models", + default=False, + action="store_true", + help="If true, then generates stanzas for all models; otherwise just diffusers", + ) + args = parser.parse_args() directory = args.models_directory conf = OmegaConf.create() - conf['_version'] = '3.0.0' - + conf["_version"] = "3.0.0" + for root, dirs, files in os.walk(directory): - parents = root.split('/') - subpaths = parents[parents.index('models')+1:] + parents = root.split("/") + subpaths = parents[parents.index("models") + 1 :] if len(subpaths) < 2: continue base, model_type, *_ = subpaths - - if args.all_models or model_type=='diffusers': + + if args.all_models or model_type == "diffusers": for d in dirs: - conf[f'{base}/{model_type}/{d}'] = dict( - path = os.path.join(root,d), - description = f'{model_type} model {d}', - format = 'folder', - base = base, + conf[f"{base}/{model_type}/{d}"] = dict( + path=os.path.join(root, d), + description=f"{model_type} model {d}", + format="folder", + base=base, ) for f in files: basename = Path(f).stem format = Path(f).suffix[1:] - conf[f'{base}/{model_type}/{basename}'] = dict( - path = os.path.join(root,f), - description = f'{model_type} model {basename}', - format = format, - base = base, + conf[f"{base}/{model_type}/{basename}"] = dict( + path=os.path.join(root, f), + description=f"{model_type} model {basename}", + format=format, + base=base, ) - - OmegaConf.save(config=dict(sorted(conf.items())), f=sys.stdout) - -if __name__ == '__main__': + OmegaConf.save(config=dict(sorted(conf.items())), f=sys.stdout) + + +if __name__ == "__main__": main() diff --git a/scripts/sd-metadata.py b/scripts/sd-metadata.py index 6f9c73757d..1a27e73d95 100755 --- a/scripts/sd-metadata.py +++ b/scripts/sd-metadata.py @@ -13,10 +13,10 @@ filenames = sys.argv[1:] for f in filenames: try: metadata = retrieve_metadata(f) - print(f'{f}:\n',json.dumps(metadata['sd-metadata'], indent=4)) + print(f"{f}:\n", json.dumps(metadata["sd-metadata"], indent=4)) except FileNotFoundError: - sys.stderr.write(f'{f} not found\n') + sys.stderr.write(f"{f} not found\n") continue except PermissionError: - sys.stderr.write(f'{f} could not be opened due to inadequate permissions\n') + sys.stderr.write(f"{f} could not be opened due to inadequate permissions\n") continue diff --git a/tests/nodes/test_graph_execution_state.py b/tests/nodes/test_graph_execution_state.py index bc4a3f4176..e0ee120b54 100644 --- a/tests/nodes/test_graph_execution_state.py +++ b/tests/nodes/test_graph_execution_state.py @@ -42,26 +42,24 @@ def simple_graph(): def mock_services() -> InvocationServices: # NOTE: none of these are actually called by the test invocations return InvocationServices( - model_manager = None, # type: ignore - events = TestEventService(), - logger = None, # type: ignore - images = None, # type: ignore - latents = None, # type: ignore - boards = None, # type: ignore - board_images = None, # type: ignore - queue = MemoryInvocationQueue(), - graph_library=SqliteItemStorage[LibraryGraph]( - filename=sqlite_memory, table_name="graphs" + model_manager=None, # type: ignore + events=TestEventService(), + logger=None, # type: ignore + images=None, # type: ignore + latents=None, # type: ignore + boards=None, # type: ignore + board_images=None, # type: ignore + queue=MemoryInvocationQueue(), + graph_library=SqliteItemStorage[LibraryGraph](filename=sqlite_memory, table_name="graphs"), + graph_execution_manager=SqliteItemStorage[GraphExecutionState]( + filename=sqlite_memory, table_name="graph_executions" ), - graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), - processor = DefaultInvocationProcessor(), - configuration = None, # type: ignore + processor=DefaultInvocationProcessor(), + configuration=None, # type: ignore ) -def invoke_next( - g: GraphExecutionState, services: InvocationServices -) -> tuple[BaseInvocation, BaseInvocationOutput]: +def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[BaseInvocation, BaseInvocationOutput]: n = g.next() if n is None: return (None, None) @@ -130,9 +128,7 @@ def test_graph_state_expands_iterator(mock_services): def test_graph_state_collects(mock_services): graph = Graph() test_prompts = ["Banana sushi", "Cat sushi"] - graph.add_node( - PromptCollectionTestInvocation(id="1", collection=list(test_prompts)) - ) + graph.add_node(PromptCollectionTestInvocation(id="1", collection=list(test_prompts))) graph.add_node(IterateInvocation(id="2")) graph.add_node(PromptTestInvocation(id="3")) graph.add_node(CollectInvocation(id="4")) @@ -158,16 +154,10 @@ def test_graph_state_prepares_eagerly(mock_services): graph = Graph() test_prompts = ["Banana sushi", "Cat sushi"] - graph.add_node( - PromptCollectionTestInvocation( - id="prompt_collection", collection=list(test_prompts) - ) - ) + graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts))) graph.add_node(IterateInvocation(id="iterate")) graph.add_node(PromptTestInvocation(id="prompt_iterated")) - graph.add_edge( - create_edge("prompt_collection", "collection", "iterate", "collection") - ) + graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection")) graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt")) # separated, fully-preparable chain of nodes @@ -193,21 +183,13 @@ def test_graph_executes_depth_first(mock_services): graph = Graph() test_prompts = ["Banana sushi", "Cat sushi"] - graph.add_node( - PromptCollectionTestInvocation( - id="prompt_collection", collection=list(test_prompts) - ) - ) + graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts))) graph.add_node(IterateInvocation(id="iterate")) graph.add_node(PromptTestInvocation(id="prompt_iterated")) graph.add_node(PromptTestInvocation(id="prompt_successor")) - graph.add_edge( - create_edge("prompt_collection", "collection", "iterate", "collection") - ) + graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection")) graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt")) - graph.add_edge( - create_edge("prompt_iterated", "prompt", "prompt_successor", "prompt") - ) + graph.add_edge(create_edge("prompt_iterated", "prompt", "prompt_successor", "prompt")) g = GraphExecutionState(graph=graph) n1 = invoke_next(g, mock_services) diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py index 4741e7f58b..8eba6d468f 100644 --- a/tests/nodes/test_invoker.py +++ b/tests/nodes/test_invoker.py @@ -35,20 +35,20 @@ def simple_graph(): def mock_services() -> InvocationServices: # NOTE: none of these are actually called by the test invocations return InvocationServices( - model_manager = None, # type: ignore - events = TestEventService(), - logger = None, # type: ignore - images = None, # type: ignore - latents = None, # type: ignore - boards = None, # type: ignore - board_images = None, # type: ignore - queue = MemoryInvocationQueue(), - graph_library=SqliteItemStorage[LibraryGraph]( - filename=sqlite_memory, table_name="graphs" + model_manager=None, # type: ignore + events=TestEventService(), + logger=None, # type: ignore + images=None, # type: ignore + latents=None, # type: ignore + boards=None, # type: ignore + board_images=None, # type: ignore + queue=MemoryInvocationQueue(), + graph_library=SqliteItemStorage[LibraryGraph](filename=sqlite_memory, table_name="graphs"), + graph_execution_manager=SqliteItemStorage[GraphExecutionState]( + filename=sqlite_memory, table_name="graph_executions" ), - graph_execution_manager = SqliteItemStorage[GraphExecutionState](filename = sqlite_memory, table_name = 'graph_executions'), - processor = DefaultInvocationProcessor(), - configuration = None, # type: ignore + processor=DefaultInvocationProcessor(), + configuration=None, # type: ignore ) diff --git a/tests/nodes/test_node_graph.py b/tests/nodes/test_node_graph.py index e2a4c8b343..0f893cb14c 100644 --- a/tests/nodes/test_node_graph.py +++ b/tests/nodes/test_node_graph.py @@ -1,5 +1,21 @@ -from .test_nodes import ImageToImageTestInvocation, TextToImageTestInvocation, ListPassThroughInvocation, PromptTestInvocation -from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation +from .test_nodes import ( + ImageToImageTestInvocation, + TextToImageTestInvocation, + ListPassThroughInvocation, + PromptTestInvocation, +) +from invokeai.app.services.graph import ( + Edge, + Graph, + GraphInvocation, + InvalidEdgeError, + NodeAlreadyInGraphError, + NodeNotFoundError, + are_connections_compatible, + EdgeConnection, + CollectInvocation, + IterateInvocation, +) from invokeai.app.invocations.upscale import ESRGANInvocation from invokeai.app.invocations.image import * from invokeai.app.invocations.math import AddInvocation, SubtractInvocation @@ -11,35 +27,38 @@ import pytest # Helpers def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edge: return Edge( - source=EdgeConnection(node_id = from_id, field = from_field), - destination=EdgeConnection(node_id = to_id, field = to_field) + source=EdgeConnection(node_id=from_id, field=from_field), + destination=EdgeConnection(node_id=to_id, field=to_field), ) + # Tests def test_connections_are_compatible(): - from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") + from_node = TextToImageTestInvocation(id="1", prompt="Banana sushi") from_field = "image" - to_node = ESRGANInvocation(id = "2") + to_node = ESRGANInvocation(id="2") to_field = "image" result = are_connections_compatible(from_node, from_field, to_node, to_field) assert result == True + def test_connections_are_incompatible(): - from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") + from_node = TextToImageTestInvocation(id="1", prompt="Banana sushi") from_field = "image" - to_node = ESRGANInvocation(id = "2") + to_node = ESRGANInvocation(id="2") to_field = "strength" result = are_connections_compatible(from_node, from_field, to_node, to_field) assert result == False + def test_connections_incompatible_with_invalid_fields(): - from_node = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") + from_node = TextToImageTestInvocation(id="1", prompt="Banana sushi") from_field = "invalid_field" - to_node = ESRGANInvocation(id = "2") + to_node = ESRGANInvocation(id="2") to_field = "image" # From field is invalid @@ -53,138 +72,152 @@ def test_connections_incompatible_with_invalid_fields(): result = are_connections_compatible(from_node, from_field, to_node, to_field) assert result == False + def test_graph_can_add_node(): g = Graph() - n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") + n = TextToImageTestInvocation(id="1", prompt="Banana sushi") g.add_node(n) assert n.id in g.nodes + def test_graph_fails_to_add_node_with_duplicate_id(): g = Graph() - n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") + n = TextToImageTestInvocation(id="1", prompt="Banana sushi") g.add_node(n) - n2 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi the second") + n2 = TextToImageTestInvocation(id="1", prompt="Banana sushi the second") with pytest.raises(NodeAlreadyInGraphError): g.add_node(n2) + def test_graph_updates_node(): g = Graph() - n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") + n = TextToImageTestInvocation(id="1", prompt="Banana sushi") g.add_node(n) - n2 = TextToImageTestInvocation(id = "2", prompt = "Banana sushi the second") + n2 = TextToImageTestInvocation(id="2", prompt="Banana sushi the second") g.add_node(n2) - nu = TextToImageTestInvocation(id = "1", prompt = "Banana sushi updated") + nu = TextToImageTestInvocation(id="1", prompt="Banana sushi updated") g.update_node("1", nu) assert g.nodes["1"].prompt == "Banana sushi updated" + def test_graph_fails_to_update_node_if_type_changes(): g = Graph() - n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") + n = TextToImageTestInvocation(id="1", prompt="Banana sushi") g.add_node(n) - n2 = ESRGANInvocation(id = "2") + n2 = ESRGANInvocation(id="2") g.add_node(n2) - nu = ESRGANInvocation(id = "1") + nu = ESRGANInvocation(id="1") with pytest.raises(TypeError): g.update_node("1", nu) + def test_graph_allows_non_conflicting_id_change(): g = Graph() - n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") + n = TextToImageTestInvocation(id="1", prompt="Banana sushi") g.add_node(n) - n2 = ESRGANInvocation(id = "2") + n2 = ESRGANInvocation(id="2") g.add_node(n2) - e1 = create_edge(n.id,"image",n2.id,"image") + e1 = create_edge(n.id, "image", n2.id, "image") g.add_edge(e1) - - nu = TextToImageTestInvocation(id = "3", prompt = "Banana sushi") + + nu = TextToImageTestInvocation(id="3", prompt="Banana sushi") g.update_node("1", nu) with pytest.raises(NodeNotFoundError): g.get_node("1") - + assert g.get_node("3").prompt == "Banana sushi" assert len(g.edges) == 1 - assert Edge(source=EdgeConnection(node_id = "3", field = "image"), destination=EdgeConnection(node_id = "2", field = "image")) in g.edges + assert ( + Edge(source=EdgeConnection(node_id="3", field="image"), destination=EdgeConnection(node_id="2", field="image")) + in g.edges + ) + def test_graph_fails_to_update_node_id_if_conflict(): g = Graph() - n = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") + n = TextToImageTestInvocation(id="1", prompt="Banana sushi") g.add_node(n) - n2 = TextToImageTestInvocation(id = "2", prompt = "Banana sushi the second") + n2 = TextToImageTestInvocation(id="2", prompt="Banana sushi the second") g.add_node(n2) - - nu = TextToImageTestInvocation(id = "2", prompt = "Banana sushi") + + nu = TextToImageTestInvocation(id="2", prompt="Banana sushi") with pytest.raises(NodeAlreadyInGraphError): g.update_node("1", nu) + def test_graph_adds_edge(): g = Graph() - n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") - n2 = ESRGANInvocation(id = "2") + n1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") + n2 = ESRGANInvocation(id="2") g.add_node(n1) g.add_node(n2) - e = create_edge(n1.id,"image",n2.id,"image") + e = create_edge(n1.id, "image", n2.id, "image") g.add_edge(e) assert e in g.edges + def test_graph_fails_to_add_edge_with_cycle(): g = Graph() - n1 = ESRGANInvocation(id = "1") + n1 = ESRGANInvocation(id="1") g.add_node(n1) - e = create_edge(n1.id,"image",n1.id,"image") + e = create_edge(n1.id, "image", n1.id, "image") with pytest.raises(InvalidEdgeError): g.add_edge(e) + def test_graph_fails_to_add_edge_with_long_cycle(): g = Graph() - n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") - n2 = ESRGANInvocation(id = "2") - n3 = ESRGANInvocation(id = "3") + n1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") + n2 = ESRGANInvocation(id="2") + n3 = ESRGANInvocation(id="3") g.add_node(n1) g.add_node(n2) g.add_node(n3) - e1 = create_edge(n1.id,"image",n2.id,"image") - e2 = create_edge(n2.id,"image",n3.id,"image") - e3 = create_edge(n3.id,"image",n2.id,"image") + e1 = create_edge(n1.id, "image", n2.id, "image") + e2 = create_edge(n2.id, "image", n3.id, "image") + e3 = create_edge(n3.id, "image", n2.id, "image") g.add_edge(e1) g.add_edge(e2) with pytest.raises(InvalidEdgeError): g.add_edge(e3) + def test_graph_fails_to_add_edge_with_missing_node_id(): g = Graph() - n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") - n2 = ESRGANInvocation(id = "2") + n1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") + n2 = ESRGANInvocation(id="2") g.add_node(n1) g.add_node(n2) - e1 = create_edge("1","image","3","image") - e2 = create_edge("3","image","1","image") + e1 = create_edge("1", "image", "3", "image") + e2 = create_edge("3", "image", "1", "image") with pytest.raises(InvalidEdgeError): g.add_edge(e1) with pytest.raises(InvalidEdgeError): g.add_edge(e2) + def test_graph_fails_to_add_edge_when_destination_exists(): g = Graph() - n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") - n2 = ESRGANInvocation(id = "2") - n3 = ESRGANInvocation(id = "3") + n1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") + n2 = ESRGANInvocation(id="2") + n3 = ESRGANInvocation(id="3") g.add_node(n1) g.add_node(n2) g.add_node(n3) - e1 = create_edge(n1.id,"image",n2.id,"image") - e2 = create_edge(n1.id,"image",n3.id,"image") - e3 = create_edge(n2.id,"image",n3.id,"image") + e1 = create_edge(n1.id, "image", n2.id, "image") + e2 = create_edge(n1.id, "image", n3.id, "image") + e3 = create_edge(n2.id, "image", n3.id, "image") g.add_edge(e1) g.add_edge(e2) with pytest.raises(InvalidEdgeError): @@ -193,208 +226,223 @@ def test_graph_fails_to_add_edge_when_destination_exists(): def test_graph_fails_to_add_edge_with_mismatched_types(): g = Graph() - n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") - n2 = ESRGANInvocation(id = "2") + n1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") + n2 = ESRGANInvocation(id="2") g.add_node(n1) g.add_node(n2) - e1 = create_edge("1","image","2","strength") + e1 = create_edge("1", "image", "2", "strength") with pytest.raises(InvalidEdgeError): g.add_edge(e1) + def test_graph_connects_collector(): g = Graph() - n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") - n2 = TextToImageTestInvocation(id = "2", prompt = "Banana sushi 2") - n3 = CollectInvocation(id = "3") - n4 = ListPassThroughInvocation(id = "4") + n1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") + n2 = TextToImageTestInvocation(id="2", prompt="Banana sushi 2") + n3 = CollectInvocation(id="3") + n4 = ListPassThroughInvocation(id="4") g.add_node(n1) g.add_node(n2) g.add_node(n3) g.add_node(n4) - e1 = create_edge("1","image","3","item") - e2 = create_edge("2","image","3","item") - e3 = create_edge("3","collection","4","collection") + e1 = create_edge("1", "image", "3", "item") + e2 = create_edge("2", "image", "3", "item") + e3 = create_edge("3", "collection", "4", "collection") g.add_edge(e1) g.add_edge(e2) g.add_edge(e3) + # TODO: test that derived types mixed with base types are compatible + def test_graph_collector_invalid_with_varying_input_types(): g = Graph() - n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") - n2 = PromptTestInvocation(id = "2", prompt = "banana sushi 2") - n3 = CollectInvocation(id = "3") + n1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") + n2 = PromptTestInvocation(id="2", prompt="banana sushi 2") + n3 = CollectInvocation(id="3") g.add_node(n1) g.add_node(n2) g.add_node(n3) - e1 = create_edge("1","image","3","item") - e2 = create_edge("2","prompt","3","item") + e1 = create_edge("1", "image", "3", "item") + e2 = create_edge("2", "prompt", "3", "item") g.add_edge(e1) - + with pytest.raises(InvalidEdgeError): g.add_edge(e2) + def test_graph_collector_invalid_with_varying_input_output(): g = Graph() - n1 = PromptTestInvocation(id = "1", prompt = "Banana sushi") - n2 = PromptTestInvocation(id = "2", prompt = "Banana sushi 2") - n3 = CollectInvocation(id = "3") - n4 = ListPassThroughInvocation(id = "4") + n1 = PromptTestInvocation(id="1", prompt="Banana sushi") + n2 = PromptTestInvocation(id="2", prompt="Banana sushi 2") + n3 = CollectInvocation(id="3") + n4 = ListPassThroughInvocation(id="4") g.add_node(n1) g.add_node(n2) g.add_node(n3) g.add_node(n4) - e1 = create_edge("1","prompt","3","item") - e2 = create_edge("2","prompt","3","item") - e3 = create_edge("3","collection","4","collection") + e1 = create_edge("1", "prompt", "3", "item") + e2 = create_edge("2", "prompt", "3", "item") + e3 = create_edge("3", "collection", "4", "collection") g.add_edge(e1) g.add_edge(e2) with pytest.raises(InvalidEdgeError): g.add_edge(e3) + def test_graph_collector_invalid_with_non_list_output(): g = Graph() - n1 = PromptTestInvocation(id = "1", prompt = "Banana sushi") - n2 = PromptTestInvocation(id = "2", prompt = "Banana sushi 2") - n3 = CollectInvocation(id = "3") - n4 = PromptTestInvocation(id = "4") + n1 = PromptTestInvocation(id="1", prompt="Banana sushi") + n2 = PromptTestInvocation(id="2", prompt="Banana sushi 2") + n3 = CollectInvocation(id="3") + n4 = PromptTestInvocation(id="4") g.add_node(n1) g.add_node(n2) g.add_node(n3) g.add_node(n4) - e1 = create_edge("1","prompt","3","item") - e2 = create_edge("2","prompt","3","item") - e3 = create_edge("3","collection","4","prompt") + e1 = create_edge("1", "prompt", "3", "item") + e2 = create_edge("2", "prompt", "3", "item") + e3 = create_edge("3", "collection", "4", "prompt") g.add_edge(e1) g.add_edge(e2) with pytest.raises(InvalidEdgeError): g.add_edge(e3) + def test_graph_connects_iterator(): g = Graph() - n1 = ListPassThroughInvocation(id = "1") - n2 = IterateInvocation(id = "2") - n3 = ImageToImageTestInvocation(id = "3", prompt = "Banana sushi") + n1 = ListPassThroughInvocation(id="1") + n2 = IterateInvocation(id="2") + n3 = ImageToImageTestInvocation(id="3", prompt="Banana sushi") g.add_node(n1) g.add_node(n2) g.add_node(n3) - e1 = create_edge("1","collection","2","collection") - e2 = create_edge("2","item","3","image") + e1 = create_edge("1", "collection", "2", "collection") + e2 = create_edge("2", "item", "3", "image") g.add_edge(e1) g.add_edge(e2) + # TODO: TEST INVALID ITERATOR SCENARIOS + def test_graph_iterator_invalid_if_multiple_inputs(): g = Graph() - n1 = ListPassThroughInvocation(id = "1") - n2 = IterateInvocation(id = "2") - n3 = ImageToImageTestInvocation(id = "3", prompt = "Banana sushi") - n4 = ListPassThroughInvocation(id = "4") + n1 = ListPassThroughInvocation(id="1") + n2 = IterateInvocation(id="2") + n3 = ImageToImageTestInvocation(id="3", prompt="Banana sushi") + n4 = ListPassThroughInvocation(id="4") g.add_node(n1) g.add_node(n2) g.add_node(n3) g.add_node(n4) - e1 = create_edge("1","collection","2","collection") - e2 = create_edge("2","item","3","image") - e3 = create_edge("4","collection","2","collection") + e1 = create_edge("1", "collection", "2", "collection") + e2 = create_edge("2", "item", "3", "image") + e3 = create_edge("4", "collection", "2", "collection") g.add_edge(e1) g.add_edge(e2) with pytest.raises(InvalidEdgeError): g.add_edge(e3) + def test_graph_iterator_invalid_if_input_not_list(): g = Graph() - n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") - n2 = IterateInvocation(id = "2") + n1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") + n2 = IterateInvocation(id="2") g.add_node(n1) g.add_node(n2) - e1 = create_edge("1","collection","2","collection") + e1 = create_edge("1", "collection", "2", "collection") with pytest.raises(InvalidEdgeError): g.add_edge(e1) + def test_graph_iterator_invalid_if_output_and_input_types_different(): g = Graph() - n1 = ListPassThroughInvocation(id = "1") - n2 = IterateInvocation(id = "2") - n3 = PromptTestInvocation(id = "3", prompt = "Banana sushi") + n1 = ListPassThroughInvocation(id="1") + n2 = IterateInvocation(id="2") + n3 = PromptTestInvocation(id="3", prompt="Banana sushi") g.add_node(n1) g.add_node(n2) g.add_node(n3) - e1 = create_edge("1","collection","2","collection") - e2 = create_edge("2","item","3","prompt") + e1 = create_edge("1", "collection", "2", "collection") + e2 = create_edge("2", "item", "3", "prompt") g.add_edge(e1) with pytest.raises(InvalidEdgeError): g.add_edge(e2) + def test_graph_validates(): g = Graph() - n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") - n2 = ESRGANInvocation(id = "2") + n1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") + n2 = ESRGANInvocation(id="2") g.add_node(n1) g.add_node(n2) - e1 = create_edge("1","image","2","image") + e1 = create_edge("1", "image", "2", "image") g.add_edge(e1) assert g.is_valid() == True + def test_graph_invalid_if_edges_reference_missing_nodes(): g = Graph() - n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") + n1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") g.nodes[n1.id] = n1 - e1 = create_edge("1","image","2","image") + e1 = create_edge("1", "image", "2", "image") g.edges.append(e1) assert g.is_valid() == False + def test_graph_invalid_if_subgraph_invalid(): g = Graph() - n1 = GraphInvocation(id = "1") + n1 = GraphInvocation(id="1") n1.graph = Graph() - n1_1 = TextToImageTestInvocation(id = "2", prompt = "Banana sushi") + n1_1 = TextToImageTestInvocation(id="2", prompt="Banana sushi") n1.graph.nodes[n1_1.id] = n1_1 - e1 = create_edge("1","image","2","image") + e1 = create_edge("1", "image", "2", "image") n1.graph.edges.append(e1) g.nodes[n1.id] = n1 assert g.is_valid() == False + def test_graph_invalid_if_has_cycle(): g = Graph() - n1 = ESRGANInvocation(id = "1") - n2 = ESRGANInvocation(id = "2") + n1 = ESRGANInvocation(id="1") + n2 = ESRGANInvocation(id="2") g.nodes[n1.id] = n1 g.nodes[n2.id] = n2 - e1 = create_edge("1","image","2","image") - e2 = create_edge("2","image","1","image") + e1 = create_edge("1", "image", "2", "image") + e2 = create_edge("2", "image", "1", "image") g.edges.append(e1) g.edges.append(e2) assert g.is_valid() == False + def test_graph_invalid_with_invalid_connection(): g = Graph() - n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") - n2 = ESRGANInvocation(id = "2") + n1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") + n2 = ESRGANInvocation(id="2") g.nodes[n1.id] = n1 g.nodes[n2.id] = n2 - e1 = create_edge("1","image","2","strength") + e1 = create_edge("1", "image", "2", "strength") g.edges.append(e1) assert g.is_valid() == False @@ -403,47 +451,47 @@ def test_graph_invalid_with_invalid_connection(): # TODO: Subgraph operations def test_graph_gets_subgraph_node(): g = Graph() - n1 = GraphInvocation(id = "1") + n1 = GraphInvocation(id="1") n1.graph = Graph() n1.graph.add_node - n1_1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") + n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") n1.graph.add_node(n1_1) g.add_node(n1) - result = g.get_node('1.1') + result = g.get_node("1.1") assert result is not None - assert result.id == '1' + assert result.id == "1" assert result == n1_1 def test_graph_expands_subgraph(): g = Graph() - n1 = GraphInvocation(id = "1") + n1 = GraphInvocation(id="1") n1.graph = Graph() - n1_1 = AddInvocation(id = "1", a = 1, b = 2) - n1_2 = SubtractInvocation(id = "2", b = 3) + n1_1 = AddInvocation(id="1", a=1, b=2) + n1_2 = SubtractInvocation(id="2", b=3) n1.graph.add_node(n1_1) n1.graph.add_node(n1_2) - n1.graph.add_edge(create_edge("1","a","2","a")) + n1.graph.add_edge(create_edge("1", "a", "2", "a")) g.add_node(n1) - n2 = AddInvocation(id = "2", b = 5) + n2 = AddInvocation(id="2", b=5) g.add_node(n2) - g.add_edge(create_edge("1.2","a","2","a")) + g.add_edge(create_edge("1.2", "a", "2", "a")) dg = g.nx_graph_flat() - assert set(dg.nodes) == set(['1.1', '1.2', '2']) - assert set(dg.edges) == set([('1.1', '1.2'), ('1.2', '2')]) + assert set(dg.nodes) == set(["1.1", "1.2", "2"]) + assert set(dg.edges) == set([("1.1", "1.2"), ("1.2", "2")]) def test_graph_subgraph_t2i(): g = Graph() - n1 = GraphInvocation(id = "1") + n1 = GraphInvocation(id="1") # Get text to image default graph lg = create_text_to_image() @@ -451,28 +499,26 @@ def test_graph_subgraph_t2i(): g.add_node(n1) - n2 = ParamIntInvocation(id = "2", a = 512) - n3 = ParamIntInvocation(id = "3", a = 256) + n2 = ParamIntInvocation(id="2", a=512) + n3 = ParamIntInvocation(id="3", a=256) g.add_node(n2) g.add_node(n3) - g.add_edge(create_edge("2","a","1.width","a")) - g.add_edge(create_edge("3","a","1.height","a")) - - n4 = ShowImageInvocation(id = "4") + g.add_edge(create_edge("2", "a", "1.width", "a")) + g.add_edge(create_edge("3", "a", "1.height", "a")) + + n4 = ShowImageInvocation(id="4") g.add_node(n4) - g.add_edge(create_edge("1.8","image","4","image")) + g.add_edge(create_edge("1.8", "image", "4", "image")) # Validate dg = g.nx_graph_flat() - assert set(dg.nodes) == set(['1.width', '1.height', '1.seed', '1.3', '1.4', '1.5', '1.6', '1.7', '1.8', '2', '3', '4']) - expected_edges = [(f'1.{e.source.node_id}',f'1.{e.destination.node_id}') for e in lg.graph.edges] - expected_edges.extend([ - ('2','1.width'), - ('3','1.height'), - ('1.8','4') - ]) + assert set(dg.nodes) == set( + ["1.width", "1.height", "1.seed", "1.3", "1.4", "1.5", "1.6", "1.7", "1.8", "2", "3", "4"] + ) + expected_edges = [(f"1.{e.source.node_id}", f"1.{e.destination.node_id}") for e in lg.graph.edges] + expected_edges.extend([("2", "1.width"), ("3", "1.height"), ("1.8", "4")]) print(expected_edges) print(list(dg.edges)) assert set(dg.edges) == set(expected_edges) @@ -480,86 +526,90 @@ def test_graph_subgraph_t2i(): def test_graph_fails_to_get_missing_subgraph_node(): g = Graph() - n1 = GraphInvocation(id = "1") + n1 = GraphInvocation(id="1") n1.graph = Graph() n1.graph.add_node - n1_1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") + n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") n1.graph.add_node(n1_1) g.add_node(n1) with pytest.raises(NodeNotFoundError): - result = g.get_node('1.2') + result = g.get_node("1.2") + def test_graph_fails_to_enumerate_non_subgraph_node(): g = Graph() - n1 = GraphInvocation(id = "1") + n1 = GraphInvocation(id="1") n1.graph = Graph() n1.graph.add_node - n1_1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") + n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") n1.graph.add_node(n1_1) g.add_node(n1) - - n2 = ESRGANInvocation(id = "2") + + n2 = ESRGANInvocation(id="2") g.add_node(n2) with pytest.raises(NodeNotFoundError): - result = g.get_node('2.1') + result = g.get_node("2.1") + def test_graph_gets_networkx_graph(): g = Graph() - n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") - n2 = ESRGANInvocation(id = "2") + n1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") + n2 = ESRGANInvocation(id="2") g.add_node(n1) g.add_node(n2) - e = create_edge(n1.id,"image",n2.id,"image") + e = create_edge(n1.id, "image", n2.id, "image") g.add_edge(e) nxg = g.nx_graph() - assert '1' in nxg.nodes - assert '2' in nxg.nodes - assert ('1','2') in nxg.edges + assert "1" in nxg.nodes + assert "2" in nxg.nodes + assert ("1", "2") in nxg.edges # TODO: Graph serializes and deserializes def test_graph_can_serialize(): g = Graph() - n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") - n2 = ESRGANInvocation(id = "2") + n1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") + n2 = ESRGANInvocation(id="2") g.add_node(n1) g.add_node(n2) - e = create_edge(n1.id,"image",n2.id,"image") + e = create_edge(n1.id, "image", n2.id, "image") g.add_edge(e) # Not throwing on this line is sufficient json = g.json() + def test_graph_can_deserialize(): g = Graph() - n1 = TextToImageTestInvocation(id = "1", prompt = "Banana sushi") - n2 = ESRGANInvocation(id = "2") + n1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") + n2 = ESRGANInvocation(id="2") g.add_node(n1) g.add_node(n2) - e = create_edge(n1.id,"image",n2.id,"image") + e = create_edge(n1.id, "image", n2.id, "image") g.add_edge(e) json = g.json() g2 = Graph.parse_raw(json) assert g2 is not None - assert g2.nodes['1'] is not None - assert g2.nodes['2'] is not None + assert g2.nodes["1"] is not None + assert g2.nodes["2"] is not None assert len(g2.edges) == 1 - assert g2.edges[0].source.node_id == '1' - assert g2.edges[0].source.field == 'image' - assert g2.edges[0].destination.node_id == '2' - assert g2.edges[0].destination.field == 'image' + assert g2.edges[0].source.node_id == "1" + assert g2.edges[0].source.field == "image" + assert g2.edges[0].destination.node_id == "2" + assert g2.edges[0].destination.field == "image" + def test_graph_can_generate_schema(): # Not throwing on this line is sufficient # NOTE: if this test fails, it's PROBABLY because a new invocation type is breaking schema generation - schema = Graph.schema_json(indent = 2) + schema = Graph.schema_json(indent=2) diff --git a/tests/nodes/test_nodes.py b/tests/nodes/test_nodes.py index af011954c5..13338e9261 100644 --- a/tests/nodes/test_nodes.py +++ b/tests/nodes/test_nodes.py @@ -5,79 +5,92 @@ from invokeai.app.services.invocation_services import InvocationServices from pydantic import Field import pytest + # Define test invocations before importing anything that uses invocations class ListPassThroughInvocationOutput(BaseInvocationOutput): - type: Literal['test_list_output'] = 'test_list_output' + type: Literal["test_list_output"] = "test_list_output" collection: list[ImageField] = Field(default_factory=list) + class ListPassThroughInvocation(BaseInvocation): - type: Literal['test_list'] = 'test_list' + type: Literal["test_list"] = "test_list" collection: list[ImageField] = Field(default_factory=list) def invoke(self, context: InvocationContext) -> ListPassThroughInvocationOutput: - return ListPassThroughInvocationOutput(collection = self.collection) + return ListPassThroughInvocationOutput(collection=self.collection) + class PromptTestInvocationOutput(BaseInvocationOutput): - type: Literal['test_prompt_output'] = 'test_prompt_output' + type: Literal["test_prompt_output"] = "test_prompt_output" + + prompt: str = Field(default="") - prompt: str = Field(default = "") class PromptTestInvocation(BaseInvocation): - type: Literal['test_prompt'] = 'test_prompt' + type: Literal["test_prompt"] = "test_prompt" - prompt: str = Field(default = "") + prompt: str = Field(default="") def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput: - return PromptTestInvocationOutput(prompt = self.prompt) + return PromptTestInvocationOutput(prompt=self.prompt) + class ErrorInvocation(BaseInvocation): - type: Literal['test_error'] = 'test_error' + type: Literal["test_error"] = "test_error" def invoke(self, context: InvocationContext) -> PromptTestInvocationOutput: raise Exception("This invocation is supposed to fail") + class ImageTestInvocationOutput(BaseInvocationOutput): - type: Literal['test_image_output'] = 'test_image_output' + type: Literal["test_image_output"] = "test_image_output" image: ImageField = Field() -class TextToImageTestInvocation(BaseInvocation): - type: Literal['test_text_to_image'] = 'test_text_to_image' - prompt: str = Field(default = "") +class TextToImageTestInvocation(BaseInvocation): + type: Literal["test_text_to_image"] = "test_text_to_image" + + prompt: str = Field(default="") def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput: return ImageTestInvocationOutput(image=ImageField(image_name=self.id)) -class ImageToImageTestInvocation(BaseInvocation): - type: Literal['test_image_to_image'] = 'test_image_to_image' - prompt: str = Field(default = "") +class ImageToImageTestInvocation(BaseInvocation): + type: Literal["test_image_to_image"] = "test_image_to_image" + + prompt: str = Field(default="") image: Union[ImageField, None] = Field(default=None) def invoke(self, context: InvocationContext) -> ImageTestInvocationOutput: return ImageTestInvocationOutput(image=ImageField(image_name=self.id)) + class PromptCollectionTestInvocationOutput(BaseInvocationOutput): - type: Literal['test_prompt_collection_output'] = 'test_prompt_collection_output' + type: Literal["test_prompt_collection_output"] = "test_prompt_collection_output" collection: list[str] = Field(default_factory=list) + class PromptCollectionTestInvocation(BaseInvocation): - type: Literal['test_prompt_collection'] = 'test_prompt_collection' + type: Literal["test_prompt_collection"] = "test_prompt_collection" collection: list[str] = Field() def invoke(self, context: InvocationContext) -> PromptCollectionTestInvocationOutput: return PromptCollectionTestInvocationOutput(collection=self.collection.copy()) + from invokeai.app.services.events import EventServiceBase from invokeai.app.services.graph import Edge, EdgeConnection + def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edge: return Edge( - source=EdgeConnection(node_id = from_id, field = from_field), - destination=EdgeConnection(node_id = to_id, field = to_field)) + source=EdgeConnection(node_id=from_id, field=from_field), + destination=EdgeConnection(node_id=to_id, field=to_field), + ) class TestEvent: @@ -88,6 +101,7 @@ class TestEvent: self.event_name = event_name self.payload = payload + class TestEventService(EventServiceBase): events: list @@ -98,8 +112,10 @@ class TestEventService(EventServiceBase): def dispatch(self, event_name: str, payload: Any) -> None: pass + def wait_until(condition: Callable[[], bool], timeout: int = 10, interval: float = 0.1) -> None: import time + start_time = time.time() while time.time() - start_time < timeout: if condition(): diff --git a/tests/nodes/test_sqlite.py b/tests/nodes/test_sqlite.py index a803af5635..a9eb542e44 100644 --- a/tests/nodes/test_sqlite.py +++ b/tests/nodes/test_sqlite.py @@ -3,110 +3,131 @@ from pydantic import BaseModel, Field class TestModel(BaseModel): - id: str = Field(description = "ID") - name: str = Field(description = "Name") + id: str = Field(description="ID") + name: str = Field(description="Name") def test_sqlite_service_can_create_and_get(): - db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') - db.set(TestModel(id = '1', name = 'Test')) - assert db.get('1') == TestModel(id = '1', name = 'Test') + db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id") + db.set(TestModel(id="1", name="Test")) + assert db.get("1") == TestModel(id="1", name="Test") + def test_sqlite_service_can_list(): - db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') - db.set(TestModel(id = '1', name = 'Test')) - db.set(TestModel(id = '2', name = 'Test')) - db.set(TestModel(id = '3', name = 'Test')) + db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id") + db.set(TestModel(id="1", name="Test")) + db.set(TestModel(id="2", name="Test")) + db.set(TestModel(id="3", name="Test")) results = db.list() assert results.page == 0 assert results.pages == 1 assert results.per_page == 10 assert results.total == 3 - assert results.items == [TestModel(id = '1', name = 'Test'), TestModel(id = '2', name = 'Test'), TestModel(id = '3', name = 'Test')] + assert results.items == [ + TestModel(id="1", name="Test"), + TestModel(id="2", name="Test"), + TestModel(id="3", name="Test"), + ] + def test_sqlite_service_can_delete(): - db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') - db.set(TestModel(id = '1', name = 'Test')) - db.delete('1') - assert db.get('1') is None + db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id") + db.set(TestModel(id="1", name="Test")) + db.delete("1") + assert db.get("1") is None + def test_sqlite_service_calls_set_callback(): - db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') + db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id") called = False + def on_changed(item: TestModel): nonlocal called called = True + db.on_changed(on_changed) - db.set(TestModel(id = '1', name = 'Test')) + db.set(TestModel(id="1", name="Test")) assert called + def test_sqlite_service_calls_delete_callback(): - db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') + db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id") called = False + def on_deleted(item_id: str): nonlocal called called = True + db.on_deleted(on_deleted) - db.set(TestModel(id = '1', name = 'Test')) - db.delete('1') + db.set(TestModel(id="1", name="Test")) + db.delete("1") assert called + def test_sqlite_service_can_list_with_pagination(): - db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') - db.set(TestModel(id = '1', name = 'Test')) - db.set(TestModel(id = '2', name = 'Test')) - db.set(TestModel(id = '3', name = 'Test')) - results = db.list(page = 0, per_page = 2) + db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id") + db.set(TestModel(id="1", name="Test")) + db.set(TestModel(id="2", name="Test")) + db.set(TestModel(id="3", name="Test")) + results = db.list(page=0, per_page=2) assert results.page == 0 assert results.pages == 2 assert results.per_page == 2 assert results.total == 3 - assert results.items == [TestModel(id = '1', name = 'Test'), TestModel(id = '2', name = 'Test')] + assert results.items == [TestModel(id="1", name="Test"), TestModel(id="2", name="Test")] + def test_sqlite_service_can_list_with_pagination_and_offset(): - db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') - db.set(TestModel(id = '1', name = 'Test')) - db.set(TestModel(id = '2', name = 'Test')) - db.set(TestModel(id = '3', name = 'Test')) - results = db.list(page = 1, per_page = 2) + db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id") + db.set(TestModel(id="1", name="Test")) + db.set(TestModel(id="2", name="Test")) + db.set(TestModel(id="3", name="Test")) + results = db.list(page=1, per_page=2) assert results.page == 1 assert results.pages == 2 assert results.per_page == 2 assert results.total == 3 - assert results.items == [TestModel(id = '3', name = 'Test')] + assert results.items == [TestModel(id="3", name="Test")] + def test_sqlite_service_can_search(): - db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') - db.set(TestModel(id = '1', name = 'Test')) - db.set(TestModel(id = '2', name = 'Test')) - db.set(TestModel(id = '3', name = 'Test')) - results = db.search(query = 'Test') + db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id") + db.set(TestModel(id="1", name="Test")) + db.set(TestModel(id="2", name="Test")) + db.set(TestModel(id="3", name="Test")) + results = db.search(query="Test") assert results.page == 0 assert results.pages == 1 assert results.per_page == 10 assert results.total == 3 - assert results.items == [TestModel(id = '1', name = 'Test'), TestModel(id = '2', name = 'Test'), TestModel(id = '3', name = 'Test')] + assert results.items == [ + TestModel(id="1", name="Test"), + TestModel(id="2", name="Test"), + TestModel(id="3", name="Test"), + ] + def test_sqlite_service_can_search_with_pagination(): - db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') - db.set(TestModel(id = '1', name = 'Test')) - db.set(TestModel(id = '2', name = 'Test')) - db.set(TestModel(id = '3', name = 'Test')) - results = db.search(query = 'Test', page = 0, per_page = 2) + db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id") + db.set(TestModel(id="1", name="Test")) + db.set(TestModel(id="2", name="Test")) + db.set(TestModel(id="3", name="Test")) + results = db.search(query="Test", page=0, per_page=2) assert results.page == 0 assert results.pages == 2 assert results.per_page == 2 assert results.total == 3 - assert results.items == [TestModel(id = '1', name = 'Test'), TestModel(id = '2', name = 'Test')] + assert results.items == [TestModel(id="1", name="Test"), TestModel(id="2", name="Test")] + def test_sqlite_service_can_search_with_pagination_and_offset(): - db = SqliteItemStorage[TestModel](sqlite_memory, 'test', 'id') - db.set(TestModel(id = '1', name = 'Test')) - db.set(TestModel(id = '2', name = 'Test')) - db.set(TestModel(id = '3', name = 'Test')) - results = db.search(query = 'Test', page = 1, per_page = 2) + db = SqliteItemStorage[TestModel](sqlite_memory, "test", "id") + db.set(TestModel(id="1", name="Test")) + db.set(TestModel(id="2", name="Test")) + db.set(TestModel(id="3", name="Test")) + results = db.search(query="Test", page=1, per_page=2) assert results.page == 1 assert results.pages == 2 assert results.per_page == 2 assert results.total == 3 - assert results.items == [TestModel(id = '3', name = 'Test')] + assert results.items == [TestModel(id="3", name="Test")] diff --git a/tests/test_config.py b/tests/test_config.py index 498a47748e..5d3dc46aa4 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -5,87 +5,91 @@ import sys from omegaconf import OmegaConf from pathlib import Path -os.environ['INVOKEAI_ROOT']='/tmp' +os.environ["INVOKEAI_ROOT"] = "/tmp" from invokeai.app.services.config import InvokeAIAppConfig init1 = OmegaConf.create( -''' + """ InvokeAI: Features: always_use_cpu: false Memory/Performance: max_cache_size: 5 tiled_decode: false -''' +""" ) init2 = OmegaConf.create( -''' + """ InvokeAI: Features: always_use_cpu: true Memory/Performance: max_cache_size: 2 tiled_decode: true -''' +""" ) + def test_use_init(): # note that we explicitly set omegaconf dict and argv here # so that the values aren't read from ~invokeai/invokeai.yaml and # sys.argv respectively. conf1 = InvokeAIAppConfig.get_config() assert conf1 - conf1.parse_args(conf=init1,argv=[]) + conf1.parse_args(conf=init1, argv=[]) assert not conf1.tiled_decode - assert conf1.max_cache_size==5 + assert conf1.max_cache_size == 5 assert not conf1.always_use_cpu conf2 = InvokeAIAppConfig.get_config() assert conf2 - conf2.parse_args(conf=init2,argv=[]) + conf2.parse_args(conf=init2, argv=[]) assert conf2.tiled_decode - assert conf2.max_cache_size==2 - assert not hasattr(conf2,'invalid_attribute') - + assert conf2.max_cache_size == 2 + assert not hasattr(conf2, "invalid_attribute") + + def test_argv_override(): conf = InvokeAIAppConfig.get_config() - conf.parse_args(conf=init1,argv=['--always_use_cpu','--max_cache=10']) + conf.parse_args(conf=init1, argv=["--always_use_cpu", "--max_cache=10"]) assert conf.always_use_cpu - assert conf.max_cache_size==10 - assert conf.outdir==Path('outputs') # this is the default - + assert conf.max_cache_size == 10 + assert conf.outdir == Path("outputs") # this is the default + + def test_env_override(): - # argv overrides + # argv overrides conf = InvokeAIAppConfig() - conf.parse_args(conf=init1,argv=['--max_cache=10']) - assert conf.always_use_cpu==False - os.environ['INVOKEAI_always_use_cpu'] = 'True' - conf.parse_args(conf=init1,argv=['--max_cache=10']) - assert conf.always_use_cpu==True + conf.parse_args(conf=init1, argv=["--max_cache=10"]) + assert conf.always_use_cpu == False + os.environ["INVOKEAI_always_use_cpu"] = "True" + conf.parse_args(conf=init1, argv=["--max_cache=10"]) + assert conf.always_use_cpu == True # environment variables should be case insensitive - os.environ['InvokeAI_Max_Cache_Size'] = '15' + os.environ["InvokeAI_Max_Cache_Size"] = "15" conf = InvokeAIAppConfig() - conf.parse_args(conf=init1,argv=[]) + conf.parse_args(conf=init1, argv=[]) assert conf.max_cache_size == 15 conf = InvokeAIAppConfig() - conf.parse_args(conf=init1,argv=['--no-always_use_cpu','--max_cache=10']) - assert conf.always_use_cpu==False - assert conf.max_cache_size==10 + conf.parse_args(conf=init1, argv=["--no-always_use_cpu", "--max_cache=10"]) + assert conf.always_use_cpu == False + assert conf.max_cache_size == 10 conf = InvokeAIAppConfig.get_config(max_cache_size=20) - conf.parse_args(conf=init1,argv=[]) - assert conf.max_cache_size==20 + conf.parse_args(conf=init1, argv=[]) + assert conf.max_cache_size == 20 + def test_type_coercion(): conf = InvokeAIAppConfig().get_config() - conf.parse_args(argv=['--root=/tmp/foobar']) - assert conf.root==Path('/tmp/foobar') - assert isinstance(conf.root,Path) - conf = InvokeAIAppConfig.get_config(root='/tmp/different') - conf.parse_args(argv=['--root=/tmp/foobar']) - assert conf.root==Path('/tmp/different') - assert isinstance(conf.root,Path) + conf.parse_args(argv=["--root=/tmp/foobar"]) + assert conf.root == Path("/tmp/foobar") + assert isinstance(conf.root, Path) + conf = InvokeAIAppConfig.get_config(root="/tmp/different") + conf.parse_args(argv=["--root=/tmp/foobar"]) + assert conf.root == Path("/tmp/different") + assert isinstance(conf.root, Path) diff --git a/tests/test_path.py b/tests/test_path.py index 52936142c7..0da5f2f8dc 100644 --- a/tests/test_path.py +++ b/tests/test_path.py @@ -11,6 +11,7 @@ import invokeai.frontend.web.dist as frontend import invokeai.configs as configs import invokeai.app.assets.images as image_assets + class ConfigsTestCase(unittest.TestCase): """Test the configuration related imports and objects"""