From da187d6a875cd27459ec1c811fbe9e48ffe94b65 Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Mon, 17 Jul 2023 15:51:56 -0400
Subject: [PATCH] API model downloads are now threaded and generate progress
 events

---
 invokeai/app/api/events.py                    |   4 +-
 invokeai/app/api/routers/models.py            |  56 ++++-----
 invokeai/app/services/events.py               | 109 ++++++++++++++---
 .../app/services/model_manager_service.py     |   4 +-
 .../backend/install/model_install_backend.py  | 113 +++++++++++++-----
 .../backend/model_management/model_manager.py |   6 +-
 invokeai/backend/util/util.py                 |  37 +++++-
 7 files changed, 247 insertions(+), 82 deletions(-)

diff --git a/invokeai/app/api/events.py b/invokeai/app/api/events.py
index 41414a9230..ff13b0ba57 100644
--- a/invokeai/app/api/events.py
+++ b/invokeai/app/api/events.py
@@ -9,7 +9,6 @@
 from fastapi_events.dispatcher import dispatch
 
 from ..services.events import EventServiceBase
 
-
 class FastAPIEventService(EventServiceBase):
     event_handler_id: int
     __queue: Queue
@@ -28,6 +27,9 @@ class FastAPIEventService(EventServiceBase):
         self.__queue.put(None)
 
     def dispatch(self, event_name: str, payload: Any) -> None:
+        # TODO: Remove next two debugging lines
+        from .dependencies import ApiDependencies
+        ApiDependencies.invoker.services.logger.debug(f'dispatch {event_name} / {payload}')
         self.__queue.put(dict(event_name=event_name, payload=payload))
 
     async def __dispatch_from_queue(self, stop_event: threading.Event):
diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py
index 923a3767a3..0af921d8f9 100644
--- a/invokeai/app/api/routers/models.py
+++ b/invokeai/app/api/routers/models.py
@@ -2,6 +2,7 @@
 
 import pathlib
+import threading
 from typing import Literal, List, Optional, Union
 
 from fastapi import Body, Path, Query, Response
@@ -106,50 +107,43 @@ async def update_model(
     "/import",
     operation_id="import_model",
     responses= {
-        201: {"description" : "The model imported successfully"},
-        404: {"description" : "The model could not be found"},
-        424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"},
-        409: {"description" : "There is already a model corresponding to this path or repo_id"},
+        200: {"description" : "The path was queued for import"},
     },
-    status_code=201,
-    response_model=ImportModelResponse
+    status_code=200
 )
 async def import_model(
     location: str = Body(description="A model path, repo_id or URL to import"),
     prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
         Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
-) -> ImportModelResponse:
-    """ Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically """
+) -> str:
+    """
+    Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured
+    automatically. This call launches a background thread to process the import and returns immediately, so it always succeeds.
+    Results are reported in the background as the following events:
+      - model_import_started(import_path)
+      - model_import_completed(import_path, AddModelResult)
+      - download_started(url)
+      - download_progress(url, downloaded_size, total_size)
+      - download_completed(url, status_code, download_path)
+    """
     items_to_import = {location}
     prediction_types = { x.value: x for x in SchedulerPredictionType }
     logger = ApiDependencies.invoker.services.logger
+    events = ApiDependencies.invoker.services.events
 
     try:
-        installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
-            items_to_import = items_to_import,
-            prediction_type_helper = lambda x: prediction_types.get(prediction_type)
-        )
-        info = installed_models.get(location)
-
-        if not info:
-            logger.error("Import failed")
-            raise HTTPException(status_code=424)
-
-        logger.info(f'Successfully imported {location}, got {info}')
-        model_raw = ApiDependencies.invoker.services.model_manager.list_model(
-            model_name=info.name,
-            base_model=info.base_model,
-            model_type=info.model_type
-        )
-        return parse_obj_as(ImportModelResponse, model_raw)
-
-    except ModelNotFoundException as e:
+        import_thread = threading.Thread(target = ApiDependencies.invoker.services.model_manager.heuristic_import,
+                                         kwargs = dict(items_to_import = items_to_import,
+                                                       prediction_type_helper = lambda x: prediction_types.get(prediction_type),
+                                                       event_bus = events,
+                                                       )
+                                         )
+        import_thread.start()
+        return 'request queued'
+    except Exception as e:
         logger.error(str(e))
-        raise HTTPException(status_code=404, detail=str(e))
-    except ValueError as e:
-        logger.error(str(e))
-        raise HTTPException(status_code=409, detail=str(e))
+        raise HTTPException(status_code=500, detail=str(e))
 
 @models_router.post(
     "/add",
diff --git a/invokeai/app/services/events.py b/invokeai/app/services/events.py
index 6c516c9b74..69fde412b9 100644
--- a/invokeai/app/services/events.py
+++ b/invokeai/app/services/events.py
@@ -1,9 +1,11 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
 
 from typing import Any, Optional
+from pathlib import Path
+
 from invokeai.app.models.image import ProgressImage
 from invokeai.app.util.misc import get_timestamp
-from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo
+from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo, AddModelResult
 
 class EventServiceBase:
     session_event: str = "session_event"
@@ -104,21 +106,15 @@ class EventServiceBase:
 
     def emit_model_load_started (
         self,
-        graph_execution_state_id: str,
-        node: dict,
-        source_node_id: str,
         model_name: str,
         base_model: BaseModelType,
         model_type: ModelType,
         submodel: SubModelType,
     ) -> None:
         """Emitted when a model is requested"""
-        self.__emit_session_event(
+        self.dispatch(
             event_name="model_load_started",
             payload=dict(
-                graph_execution_state_id=graph_execution_state_id,
-                node=node,
-                source_node_id=source_node_id,
                 model_name=model_name,
                 base_model=base_model,
                 model_type=model_type,
@@ -128,9 +124,6 @@ class EventServiceBase:
 
     def emit_model_load_completed(
         self,
-        graph_execution_state_id: str,
-        node: dict,
-        source_node_id: str,
         model_name: str,
         base_model: BaseModelType,
         model_type: ModelType,
@@ -138,12 +131,9 @@ class EventServiceBase:
         model_info: ModelInfo,
     ) -> None:
         """Emitted when a model is correctly loaded (returns model info)"""
-        self.__emit_session_event(
+        self.dispatch(
             event_name="model_load_completed",
             payload=dict(
-                graph_execution_state_id=graph_execution_state_id,
-                node=node,
-                source_node_id=source_node_id,
                 model_name=model_name,
                 base_model=base_model,
                 model_type=model_type,
@@ -151,3 +141,92 @@ class EventServiceBase:
                 model_info=model_info,
             ),
         )
+
+    def emit_model_import_started (
+        self,
+        import_path: str,  # can be a local path, URL or repo_id
+    )->None:
+        """Emitted when a model import commences"""
+        self.dispatch(
+            event_name="model_import_started",
+            payload=dict(
+                import_path = import_path,
+            )
+        )
+
+    def emit_model_import_completed (
+        self,
+        import_path: str,  # can be a local path, URL or repo_id
+        import_info: AddModelResult,
+        success: bool = True,
+        error: Optional[str] = None,
+    )->None:
+        """Emitted when a model import completes"""
+        self.dispatch(
+            event_name="model_import_completed",
+            payload=dict(
+                import_path = import_path,
+                import_info = import_info,
+                success = success,
+                error = error,
+            )
+        )
+
+    def emit_download_started (
+        self,
+        url: str,
+    )->None:
+        """Emitted when a download thread starts"""
+        self.dispatch(
+            event_name="download_started",
+            payload=dict(
+                url = url,
+            )
+        )
+
+    def emit_download_progress (
+        self,
+        url: str,
+        downloaded_size: int,
+        total_size: int,
+    )->None:
+        """
+        Emitted at intervals during a download process
+        :param url: Requested URL
+        :param downloaded_size: Bytes downloaded so far
+        :param total_size: Total bytes to download
+        """
+        self.dispatch(
+            event_name="download_progress",
+            payload=dict(
+                url = url,
+                downloaded_size = downloaded_size,
+                total_size = total_size,
+            )
+        )
+
+    def emit_download_completed (
+        self,
+        url: str,
+        status_code: int,
+        download_path: Path,
+    )->None:
+        """
+        Emitted when a download thread completes.
+        :param url: Requested URL
+        :param status_code: HTTP status code from request
+        :param download_path: Path to downloaded file
+        """
+        self.dispatch(
+            event_name="download_completed",
+            payload=dict(
+                url = url,
+                status_code = status_code,
+                download_path = download_path,
+            )
+        )
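All of the new emitters funnel through self.dispatch(event_name, payload), so any
EventServiceBase subclass that overrides dispatch() can observe the import and
download events. A minimal consumer sketch, not part of this patch (the
ConsoleEventService name is hypothetical; the payload keys come from the emitters
above):

    from typing import Any
    from invokeai.app.services.events import EventServiceBase

    class ConsoleEventService(EventServiceBase):
        """Hypothetical in-process bus that prints the new import/download events."""
        def dispatch(self, event_name: str, payload: Any) -> None:
            if event_name == "download_progress":
                # keys match emit_download_progress() above
                pct = 100 * payload["downloaded_size"] / max(payload["total_size"], 1)
                print(f"{payload['url']}: {pct:.1f}% downloaded")
            elif event_name in ("download_started", "download_completed",
                                "model_import_started", "model_import_completed"):
                print(f"{event_name}: {payload}")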
diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py
index 7dba1dff06..e238f10140 100644
--- a/invokeai/app/services/model_manager_service.py
+++ b/invokeai/app/services/model_manager_service.py
@@ -26,6 +26,7 @@
 import torch
 from invokeai.app.models.exceptions import CanceledException
 from ...backend.util import choose_precision, choose_torch_device
 from .config import InvokeAIAppConfig
+from .events import EventServiceBase
 
 if TYPE_CHECKING:
     from ..invocations.baseinvocation import BaseInvocation, InvocationContext
@@ -552,6 +553,7 @@ class ModelManagerService(ModelManagerServiceBase):
     def heuristic_import(self,
                          items_to_import: set[str],
                          prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None,
+                         event_bus: Optional[EventServiceBase]=None,
                          )->dict[str, AddModelResult]:
         '''Import a list of paths, repo_ids or URLs. Returns the set of
         successfully imported items.
@@ -569,7 +571,7 @@ class ModelManagerService(ModelManagerServiceBase):
         of the set is a dict corresponding to the newly-created OmegaConf stanza for
         that model.
         '''
-        return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
+        return self.mgr.heuristic_import(items_to_import, prediction_type_helper, event_bus=event_bus)
 
     def merge_models(
         self,
diff --git a/invokeai/backend/install/model_install_backend.py b/invokeai/backend/install/model_install_backend.py
index 559dac6f61..4be7a647e6 100644
--- a/invokeai/backend/install/model_install_backend.py
+++ b/invokeai/backend/install/model_install_backend.py
@@ -89,13 +89,16 @@ class ModelInstall(object):
                  config:InvokeAIAppConfig,
                  prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
                  model_manager: ModelManager = None,
-                 access_token:str = None):
+                 access_token:str = None,
+                 event_bus = None,  # EventServiceBase - untyped here to avoid circular import errors
+                 ):
         self.config = config
         self.mgr = model_manager or ModelManager(config.model_conf_path)
         self.datasets = OmegaConf.load(Dataset_path)
         self.prediction_helper = prediction_type_helper
         self.access_token = access_token or HfFolder.get_token()
         self.reverse_paths = self._reverse_paths(self.datasets)
+        self.event_bus = event_bus
 
     def all_models(self)->Dict[str,ModelLoadInfo]:
         '''
@@ -197,39 +200,63 @@ class ModelInstall(object):
         Returns a set of dict objects corresponding to newly-created stanzas in models.yaml.
         '''
+        if self.event_bus:
+            self.event_bus.emit_model_import_started(str(model_path_id_or_url))
+
         if not models_installed:
             models_installed = dict()
 
         # A little hack to allow nested routines to retrieve info on the requested ID
         self.current_id = model_path_id_or_url
         path = Path(model_path_id_or_url)
-        # checkpoint file, or similar
-        if path.is_file():
-            models_installed.update({str(path):self._install_path(path)})
 
-        # folders style or similar
-        elif path.is_dir() and any([(path/x).exists() for x in \
-                                    {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}
-                                    ]
-                                   ):
-            models_installed.update(self._install_path(path))
+        try:
+            # checkpoint file, or similar
+            if path.is_file():
+                models_installed.update({str(path):self._install_path(path)})
 
-        # recursive scan
-        elif path.is_dir():
-            for child in path.iterdir():
-                self.heuristic_import(child, models_installed=models_installed)
+            # folders style or similar
+            elif path.is_dir() and any([(path/x).exists() for x in \
+                                        {'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}
+                                        ]
+                                       ):
+                models_installed.update(self._install_path(path))
 
-        # huggingface repo
-        elif len(str(model_path_id_or_url).split('/')) == 2:
-            models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))})
+            # recursive scan
+            elif path.is_dir():
+                for child in path.iterdir():
+                    self.heuristic_import(child, models_installed=models_installed)
 
-        # a URL
-        elif str(model_path_id_or_url).startswith(("http:", "https:", "ftp:")):
-            models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)})
+            # huggingface repo
+            elif len(str(model_path_id_or_url).split('/')) == 2:
+                models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))})
 
-        else:
-            raise KeyError(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
+            # a URL
+            elif str(model_path_id_or_url).startswith(("http:", "https:", "ftp:")):
+                models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)})
 
+            else:
+                errmsg = f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping'
+                raise KeyError(errmsg)
+
+            if self.event_bus:
+                for path, add_model_result in models_installed.items():
+                    self.event_bus.emit_model_import_completed(
+                        str(path),
+                        import_info = add_model_result,
+                    )
+        except Exception as e:
+            if self.event_bus:
+                self.event_bus.emit_model_import_completed(
+                    str(path),
+                    import_info = None,
+                    success = False,
+                    error = str(e),
+                )
+                return models_installed
+            else:
+                raise
+
         return models_installed
 
     # install a model from a local path. The optional info parameter is there to prevent
@@ -238,10 +265,14 @@ class ModelInstall(object):
         info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
         if not info:
             logger.warning(f'Unable to parse format of {path}')
-            return None
+            raise ValueError(f'Unable to parse format of {path}')
+
         model_name = path.stem if path.is_file() else path.name
+
         if self.mgr.model_exists(model_name, info.base_type, info.model_type):
-            raise ValueError(f'A model named "{model_name}" is already installed.')
+            errmsg = f'A model named "{model_name}" is already installed.'
+            raise ValueError(errmsg)
+
         attributes = self._make_attributes(path,info)
         return self.mgr.add_model(model_name = model_name,
                                   base_model = info.base_type,
@@ -251,7 +282,7 @@ class ModelInstall(object):
 
     def _install_url(self, url: str)->AddModelResult:
         with TemporaryDirectory(dir=self.config.models_path) as staging:
-            location = download_with_resume(url,Path(staging))
+            location = download_with_resume(url,Path(staging),event_bus=self.event_bus)
             if not location:
                 logger.error(f'Unable to download {url}. Skipping.')
             info = ModelProbe().heuristic_probe(location)
@@ -384,7 +415,8 @@ class ModelInstall(object):
                 p = hf_download_with_resume(repo_id,
                                             model_dir=location,
                                             model_name=filename,
-                                            access_token = self.access_token
+                                            access_token = self.access_token,
+                                            event_bus = self.event_bus,
                                             )
             if p:
                 paths.append(p)
@@ -425,12 +457,15 @@ def hf_download_from_pretrained(
     return destination
 
 # ---------------------------------------------
+# TODO: This function is almost identical to invokeai.backend.util.download_with_resume
+# and should be merged
 def hf_download_with_resume(
     repo_id: str,
     model_dir: str,
     model_name: str,
     model_dest: Path = None,
     access_token: str = None,
+    event_bus = None,
 ) -> Path:
     model_dest = model_dest or Path(os.path.join(model_dir, model_name))
     os.makedirs(model_dir, exist_ok=True)
@@ -447,15 +482,22 @@ def hf_download_with_resume(
         open_mode = "ab"
 
     resp = requests.get(url, headers=header, stream=True)
-    total = int(resp.headers.get("content-length", 0))
+    content_length = int(resp.headers.get("content-length", 0))
+
+    if event_bus:
+        event_bus.emit_download_started(url)
 
     if (
         resp.status_code == 416
     ):  # "range not satisfiable", which means nothing to return
         logger.info(f"{model_name}: complete file found. Skipping.")
+        if event_bus:
+            event_bus.emit_download_completed(url,resp.status_code,model_dest)
         return model_dest
     elif resp.status_code == 404:
         logger.warning("File not found")
+        if event_bus:
+            event_bus.emit_download_completed(url,resp.status_code,None)
         return None
     elif resp.status_code != 200:
         logger.warning(f"{model_name}: {resp.reason}")
     else:
         logger.info(f"{model_name}: Downloading...")
 
+    MB10 = 10 * 1048576
+    downloaded = exist_size
+    previous_interval = 0
+
     try:
         with open(model_dest, open_mode) as file, tqdm(
             desc=model_name,
             initial=exist_size,
-            total=total + exist_size,
+            total=content_length + exist_size,
             unit="iB",
             unit_scale=True,
             unit_divisor=1000,
         ) as bar:
             for data in resp.iter_content(chunk_size=1024):
                 size = file.write(data)
                 bar.update(size)
+                downloaded += size
+                if event_bus and downloaded // MB10 > previous_interval:
+                    previous_interval = downloaded // MB10
+                    event_bus.emit_download_progress(url, downloaded, content_length)
+
     except Exception as e:
         logger.error(f"An error occurred while downloading {model_name}: {str(e)}")
+        if event_bus:
+            event_bus.emit_download_completed(url,500,None)
         return None
+
+    if event_bus:
+        event_bus.emit_download_completed(url,resp.status_code,model_dest)
+
     return model_dest
diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py
index c62f42b88d..9d1a017c4a 100644
--- a/invokeai/backend/model_management/model_manager.py
+++ b/invokeai/backend/model_management/model_manager.py
@@ -946,6 +946,7 @@ class ModelManager(object):
     def heuristic_import(self,
                          items_to_import: Set[str],
                          prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
+                         event_bus = None,  # EventServiceBase, untyped here due to a circular dependency
                          )->Dict[str, AddModelResult]:
         '''Import a list of paths, repo_ids or URLs. Returns the set of
         successfully imported items.
@@ -973,10 +974,13 @@ class ModelManager(object):
 
         installer = ModelInstall(config = self.app_config,
                                  prediction_type_helper = prediction_type_helper,
-                                 model_manager = self)
+                                 model_manager = self,
+                                 event_bus = event_bus,
+                                 )
         for thing in items_to_import:
             installed = installer.heuristic_import(thing)
             successfully_installed.update(installed)
+
         self.commit()
         return successfully_installed
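Both download helpers throttle progress events with the same integer-division
trick: downloaded // MB10 only increases once per 10 MiB, so at most one
download_progress event fires per 10 MiB written, regardless of the 1 KiB chunk
size. A standalone sketch of the pattern (the simulated loop is illustrative,
not from this patch):

    MB10 = 10 * 1048576   # 10 MiB reporting interval, matching the patch

    downloaded = 0
    previous_interval = 0
    for _ in range(25 * 1024 * 10):        # simulate 250 MiB arriving in 1 KiB chunks
        downloaded += 1024
        if downloaded // MB10 > previous_interval:
            previous_interval = downloaded // MB10
            print(f"progress event at {downloaded} bytes")   # fires 25 times, not 256,000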
diff --git a/invokeai/backend/util/util.py b/invokeai/backend/util/util.py
index 1cc632e483..7221be6af3 100644
--- a/invokeai/backend/util/util.py
+++ b/invokeai/backend/util/util.py
@@ -21,7 +21,6 @@
 from tqdm import tqdm
 
 import invokeai.backend.util.logging as logger
 from .devices import torch_dtype
 
-
 def log_txt_as_img(wh, xc, size=10):
     # wh a tuple of (width, height)
     # xc a list of captions to plot
@@ -285,7 +284,11 @@ def ask_user(question: str, answers: list):
 
 # -------------------------------------
-def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path:
+def download_with_resume(url: str,
+                         dest: Path,
+                         access_token: str = None,
+                         event_bus = None,  # EventServiceBase (circular import issues)
+                         ) -> Path:
     """
     Download a model file.
     :param url: https, http or ftp URL
@@ -323,8 +326,13 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path:
             os.remove(dest)
             exist_size = 0
 
+    if event_bus:
+        event_bus.emit_download_started(url)
+
     if resp.status_code == 416 or (content_length > 0 and exist_size == content_length):
         logger.warning(f"{dest}: complete file found. Skipping.")
+        if event_bus:
+            event_bus.emit_download_completed(url,resp.status_code,dest)
         return dest
     elif resp.status_code == 206 or exist_size > 0:
         logger.warning(f"{dest}: partial file found. Resuming...")
@@ -333,11 +341,20 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path:
     else:
         logger.info(f"{dest}: Downloading...")
 
-    try:
-        if content_length < 2000:
-            logger.error(f"ERROR DOWNLOADING {url}: {resp.text}")
-            return None
+    # If less than 2K, it's not a model - usually an HTML document of some sort
+    if content_length < 2000:
+        logger.error(f"ERROR DOWNLOADING {url}: {resp.text}")
+        if event_bus:
+            event_bus.emit_download_completed(url, 500, None)
+        return None
 
+    # these variables are used in progress reporting events
+    MB10 = 10 * 1048576
+    downloaded = exist_size
+    previous_interval = 0
+
+    try:
         with open(dest, open_mode) as file, tqdm(
             desc=str(dest),
             initial=exist_size,
@@ -349,10 +366,20 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path:
             for data in resp.iter_content(chunk_size=1024):
                 size = file.write(data)
                 bar.update(size)
+                downloaded += size
+                if event_bus and downloaded // MB10 > previous_interval:
+                    previous_interval = downloaded // MB10
+                    event_bus.emit_download_progress(url, downloaded, content_length)
+
     except Exception as e:
         logger.error(f"An error occurred while downloading {dest}: {str(e)}")
+        if event_bus:
+            event_bus.emit_download_completed(url,500,None)
         return None
 
+    if event_bus:
+        event_bus.emit_download_completed(url,resp.status_code,dest)
+
     return dest
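For reference, the new import flow as seen from an API client. This is a sketch:
the host, port, and "/api/v1/models" mount prefix are assumptions, as is the model
URL - only the "/import" route itself and its body parameters are defined in this
patch:

    import requests

    # Queue a model for import. The route now returns immediately with a
    # 200 status and the acknowledgement string "request queued"; it no
    # longer blocks until the model is probed and installed.
    resp = requests.post(
        "http://localhost:9090/api/v1/models/import",   # host and prefix assumed
        json={
            "location": "https://example.com/some-model.safetensors",  # hypothetical URL
            "prediction_type": "v_prediction",
        },
    )
    print(resp.status_code, resp.json())

    # Success or failure arrives later on the event bus:
    #   model_import_started -> download_started -> download_progress (per 10 MiB)
    #   -> download_completed -> model_import_completed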