API model downloads are now threaded and generate progress events

Lincoln Stein 2023-07-17 15:51:56 -04:00
parent 1d3fda80aa
commit da187d6a87
7 changed files with 247 additions and 82 deletions

View File

@@ -9,7 +9,6 @@ from fastapi_events.dispatcher import dispatch
from ..services.events import EventServiceBase
class FastAPIEventService(EventServiceBase):
event_handler_id: int
__queue: Queue
@@ -28,6 +27,9 @@ class FastAPIEventService(EventServiceBase):
self.__queue.put(None)
def dispatch(self, event_name: str, payload: Any) -> None:
# TODO: Remove next two debugging lines
from .dependencies import ApiDependencies
ApiDependencies.invoker.services.logger.debug(f'dispatch {event_name} / {payload}')
self.__queue.put(dict(event_name=event_name, payload=payload))
async def __dispatch_from_queue(self, stop_event: threading.Event):
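
For context, this service is the funnel all the new download events pass through: dispatch() never blocks the caller, it only enqueues, and a background task drains the queue. A minimal standalone sketch of the queue-backed pattern (illustrative names, not the InvokeAI classes):

import threading
from queue import Queue
from typing import Any

class ThreadedEventService:
    """Toy queue-backed dispatcher: producers enqueue, a daemon thread drains."""

    def __init__(self) -> None:
        self.__queue: Queue = Queue()
        self.__worker = threading.Thread(target=self.__drain, daemon=True)
        self.__worker.start()

    def dispatch(self, event_name: str, payload: Any) -> None:
        # Cheap and thread-safe from any caller; delivery happens elsewhere.
        self.__queue.put(dict(event_name=event_name, payload=payload))

    def stop(self) -> None:
        self.__queue.put(None)  # sentinel, mirroring the put(None) above
        self.__worker.join()

    def __drain(self) -> None:
        while True:
            event = self.__queue.get()
            if event is None:  # sentinel seen: shut down
                break
            print(f"delivering {event['event_name']}: {event['payload']}")

svc = ThreadedEventService()
svc.dispatch("download_started", dict(url="https://example.com/model.safetensors"))
svc.stop()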

View File

@@ -2,6 +2,7 @@
import pathlib
import threading
from typing import Literal, List, Optional, Union
from fastapi import Body, Path, Query, Response
@@ -106,50 +107,43 @@ async def update_model(
"/import",
operation_id="import_model",
responses= {
201: {"description" : "The model imported successfully"},
404: {"description" : "The model could not be found"},
424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"},
409: {"description" : "There is already a model corresponding to this path or repo_id"},
200: {"description" : "The path was queued for import"},
},
status_code=201,
response_model=ImportModelResponse
status_code=200
)
async def import_model(
location: str = Body(description="A model path, repo_id or URL to import"),
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
) -> ImportModelResponse:
""" Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically """
) -> str:
"""
Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically.
This call launches a background thread to process the imported model and returns immediately with confirmation that the
request was queued. Results are reported in the background as the following events:
- model_import_started(import_path)
- model_import_completed(import_path, AddModelResult)
- download_started(url)
- download_progress(url, downloaded_size, total_size)
- download_completed(url, status_code, download_path)
"""
items_to_import = {location}
prediction_types = { x.value: x for x in SchedulerPredictionType }
logger = ApiDependencies.invoker.services.logger
events = ApiDependencies.invoker.services.events
try:
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
items_to_import = items_to_import,
prediction_type_helper = lambda x: prediction_types.get(prediction_type)
)
info = installed_models.get(location)
if not info:
logger.error("Import failed")
raise HTTPException(status_code=424)
logger.info(f'Successfully imported {location}, got {info}')
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=info.name,
base_model=info.base_model,
model_type=info.model_type
)
return parse_obj_as(ImportModelResponse, model_raw)
except ModelNotFoundException as e:
import_thread = threading.Thread(target = ApiDependencies.invoker.services.model_manager.heuristic_import,
kwargs = dict(items_to_import = items_to_import,
prediction_type_helper = lambda x: prediction_types.get(prediction_type),
event_bus = events,
)
)
import_thread.start()
return 'request queued'
except Exception as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
raise HTTPException(status_code=500, detail=str(e))
@models_router.post(
"/add",

View File

@@ -1,9 +1,11 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any, Optional
from pathlib import Path
from invokeai.app.models.image import ProgressImage
from invokeai.app.util.misc import get_timestamp
from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo
from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo, AddModelResult
class EventServiceBase:
session_event: str = "session_event"
@@ -104,21 +106,15 @@ class EventServiceBase:
def emit_model_load_started (
self,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: SubModelType,
) -> None:
"""Emitted when a model is requested"""
self.__emit_session_event(
self.dispatch(
event_name="model_load_started",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
@@ -128,9 +124,6 @@ class EventServiceBase:
def emit_model_load_completed(
self,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
@@ -138,12 +131,9 @@ class EventServiceBase:
model_info: ModelInfo,
) -> None:
"""Emitted when a model is correctly loaded (returns model info)"""
self.__emit_session_event(
self.dispatch(
event_name="model_load_completed",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
model_name=model_name,
base_model=base_model,
model_type=model_type,
@@ -151,3 +141,92 @@ class EventServiceBase:
model_info=model_info,
),
)
def emit_model_import_started (
self,
import_path: str, # can be a local path, URL or repo_id
)->None:
"""Emitted when a model import commences"""
self.dispatch(
event_name="model_import_started",
payload=dict(
import_path = import_path,
)
)
def emit_model_import_completed (
self,
import_path: str, # can be a local path, URL or repo_id
import_info: AddModelResult,
success: bool= True,
error: str = None,
)->None:
"""Emitted when a model import completes"""
self.dispatch(
event_name="model_import_completed",
payload=dict(
import_path = import_path,
import_info = import_info,
success = success,
error = error,
)
)
def emit_download_started (
self,
url: str,
)->None:
"""Emitted when a download thread starts"""
self.dispatch(
event_name="download_started",
payload=dict(
url = url,
)
)
def emit_download_progress (
self,
url: str,
downloaded_size: int,
total_size: int,
)->None:
"""
Emitted at intervals during a download process
:param url: Requested URL
:param downloaded_size: Bytes downloaded so far
:param total_size: Total bytes to download
"""
self.dispatch(
event_name="download_progress",
payload=dict(
url = url,
downloaded_size = downloaded_size,
total_size = total_size,
)
)
def emit_download_completed (
self,
url: str,
status_code: int,
download_path: Path,
)->None:
"""
Emitted when a download thread completes.
:param url: Requested URL
:param status_code: HTTP status code from request
:param download_path: Path to downloaded file
"""
self.dispatch(
event_name="download_completed",
payload=dict(
url = url,
status_code = status_code,
download_path = download_path,
)
)
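
Taken together, a consumer should see roughly this stream for a single URL import: model_import_started, then download_started, periodic download_progress, download_completed, and finally model_import_completed. Because every emitter funnels into dispatch(), a duck-typed stand-in needs only that one method; a recording sketch (not the real EventServiceBase):

from pathlib import Path

class RecordingEventBus:
    """Stand-in with the same dispatch() funnel; records instead of delivering."""

    def __init__(self):
        self.events = []

    def dispatch(self, event_name, payload):
        self.events.append((event_name, payload))

    # Only the emitters exercised below; signatures mirror EventServiceBase.
    def emit_model_import_started(self, import_path):
        self.dispatch("model_import_started", dict(import_path=import_path))

    def emit_download_started(self, url):
        self.dispatch("download_started", dict(url=url))

    def emit_download_progress(self, url, downloaded_size, total_size):
        self.dispatch("download_progress",
                      dict(url=url, downloaded_size=downloaded_size, total_size=total_size))

    def emit_download_completed(self, url, status_code, download_path):
        self.dispatch("download_completed",
                      dict(url=url, status_code=status_code, download_path=download_path))

    def emit_model_import_completed(self, import_path, import_info, success=True, error=None):
        self.dispatch("model_import_completed",
                      dict(import_path=import_path, import_info=import_info,
                           success=success, error=error))

url = "https://example.com/model.safetensors"
bus = RecordingEventBus()
bus.emit_model_import_started(url)
bus.emit_download_started(url)
bus.emit_download_progress(url, 10 * 1048576, 50 * 1048576)
bus.emit_download_completed(url, 200, Path("/tmp/model.safetensors"))
bus.emit_model_import_completed(url, import_info=None)
for name, payload in bus.events:
    print(name, payload)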

View File

@@ -26,6 +26,7 @@ import torch
from invokeai.app.models.exceptions import CanceledException
from ...backend.util import choose_precision, choose_torch_device
from .config import InvokeAIAppConfig
from .events import EventServiceBase
if TYPE_CHECKING:
from ..invocations.baseinvocation import BaseInvocation, InvocationContext
@@ -552,6 +553,7 @@ class ModelManagerService(ModelManagerServiceBase):
def heuristic_import(self,
items_to_import: set[str],
prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None,
event_bus: Optional[EventServiceBase]=None,
)->dict[str, AddModelResult]:
'''Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items.
@@ -569,7 +571,7 @@ class ModelManagerService(ModelManagerServiceBase):
of the set is a dict corresponding to the newly-created OmegaConf stanza for
that model.
'''
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
return self.mgr.heuristic_import(items_to_import, prediction_type_helper, event_bus=event_bus)
def merge_models(
self,

View File

@@ -89,13 +89,16 @@ class ModelInstall(object):
config:InvokeAIAppConfig,
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
model_manager: ModelManager = None,
access_token:str = None):
access_token:str = None,
event_bus = None, # EventServiceBase - kept untyped to avoid circular import errors
):
self.config = config
self.mgr = model_manager or ModelManager(config.model_conf_path)
self.datasets = OmegaConf.load(Dataset_path)
self.prediction_helper = prediction_type_helper
self.access_token = access_token or HfFolder.get_token()
self.reverse_paths = self._reverse_paths(self.datasets)
self.event_bus = event_bus
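
The untyped event_bus parameter here (and the matching ones in model_manager.py and util.py below) dodges a circular import between the backend modules and app.services.events. A hedged alternative worth considering is a TYPE_CHECKING-guarded annotation, which gives type checkers the real type while importing nothing at runtime:

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Evaluated by type checkers only, never at runtime,
    # so the circular import cannot occur.
    from invokeai.app.services.events import EventServiceBase

class ModelInstall:  # abbreviated sketch, not the full class
    def __init__(self, event_bus: Optional["EventServiceBase"] = None):
        self.event_bus = event_bus  # string annotation is resolved lazily
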
def all_models(self)->Dict[str,ModelLoadInfo]:
'''
@@ -197,39 +200,63 @@ class ModelInstall(object):
Returns a set of dict objects corresponding to newly-created stanzas in models.yaml.
'''
if self.event_bus:
self.event_bus.emit_model_import_started(str(model_path_id_or_url))
if not models_installed:
models_installed = dict()
# A little hack to allow nested routines to retrieve info on the requested ID
self.current_id = model_path_id_or_url
path = Path(model_path_id_or_url)
# checkpoint file, or similar
if path.is_file():
models_installed.update({str(path):self._install_path(path)})
# folders style or similar
elif path.is_dir() and any([(path/x).exists() for x in \
{'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}
]
):
models_installed.update(self._install_path(path))
try:
# checkpoint file, or similar
if path.is_file():
models_installed.update({str(path):self._install_path(path)})
# recursive scan
elif path.is_dir():
for child in path.iterdir():
self.heuristic_import(child, models_installed=models_installed)
# folders style or similar
elif path.is_dir() and any([(path/x).exists() for x in \
{'config.json','model_index.json','learned_embeds.bin','pytorch_lora_weights.bin'}
]
):
models_installed.update(self._install_path(path))
# huggingface repo
elif len(str(model_path_id_or_url).split('/')) == 2:
models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))})
# recursive scan
elif path.is_dir():
for child in path.iterdir():
self.heuristic_import(child, models_installed=models_installed)
# a URL
elif str(model_path_id_or_url).startswith(("http:", "https:", "ftp:")):
models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)})
# huggingface repo
elif len(str(model_path_id_or_url).split('/')) == 2:
models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))})
else:
raise KeyError(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
# a URL
elif str(model_path_id_or_url).startswith(("http:", "https:", "ftp:")):
models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)})
else:
errmsg = f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping'
raise KeyError(errmsg)
if self.event_bus:
for path, add_model_result in models_installed.items():
self.event_bus.emit_model_import_completed(
str(path),
import_info = add_model_result,
)
except Exception as e:
if self.event_bus:
self.event_bus.emit_model_import_completed(
str(path),
import_info = None,
success = False,
error = str(e),
)
return models_installed
else:
raise
return models_installed
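
Note the asymmetry in the except branch above: with an event bus attached, failures are converted into a model_import_completed event with success=False and the partial results are returned; without one, the exception propagates as before. A toy illustration of that contract, with a stub bus:

class StubBus:
    def emit_model_import_completed(self, path, import_info, success=True, error=None):
        print(f"event: import of {path} success={success} error={error!r}")

def toy_import(path, event_bus=None):
    """Mirrors only the error-handling shape of heuristic_import."""
    models_installed = {}
    try:
        raise KeyError(f"{path} is not recognized as a local path, repo ID or URL")
    except Exception as e:
        if event_bus:
            event_bus.emit_model_import_completed(path, import_info=None,
                                                  success=False, error=str(e))
            return models_installed  # error surfaced as an event
        raise                        # no bus: the caller sees the exception

toy_import("not-a-model", event_bus=StubBus())  # prints a failure event
try:
    toy_import("not-a-model")                   # raises KeyError
except KeyError as e:
    print("raised:", e)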
# install a model from a local path. The optional info parameter is there to prevent
@@ -238,10 +265,14 @@ class ModelInstall(object):
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
if not info:
logger.warning(f'Unable to parse format of {path}')
return None
raise ValueError(f'Unable to parse format of {path}')
model_name = path.stem if path.is_file() else path.name
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
raise ValueError(f'A model named "{model_name}" is already installed.')
errmsg = f'A model named "{model_name}" is already installed.'
raise ValueError(errmsg)
attributes = self._make_attributes(path,info)
return self.mgr.add_model(model_name = model_name,
base_model = info.base_type,
@@ -251,7 +282,7 @@ class ModelInstall(object):
def _install_url(self, url: str)->AddModelResult:
with TemporaryDirectory(dir=self.config.models_path) as staging:
location = download_with_resume(url,Path(staging))
location = download_with_resume(url,Path(staging),event_bus=self.event_bus)
if not location:
logger.error(f'Unable to download {url}. Skipping.')
info = ModelProbe().heuristic_probe(location)
@@ -384,7 +415,8 @@ class ModelInstall(object):
p = hf_download_with_resume(repo_id,
model_dir=location,
model_name=filename,
access_token = self.access_token
access_token = self.access_token,
event_bus = self.event_bus,
)
if p:
paths.append(p)
@@ -425,12 +457,15 @@ def hf_download_from_pretrained(
return destination
# ---------------------------------------------
# TODO: This function is almost identical to invokeai.backend.util.download_with_resume
# and should be merged
def hf_download_with_resume(
repo_id: str,
model_dir: str,
model_name: str,
model_dest: Path = None,
access_token: str = None,
event_bus = None,
) -> Path:
model_dest = model_dest or Path(os.path.join(model_dir, model_name))
os.makedirs(model_dir, exist_ok=True)
@@ -447,15 +482,22 @@ def hf_download_with_resume(
open_mode = "ab"
resp = requests.get(url, headers=header, stream=True)
total = int(resp.headers.get("content-length", 0))
content_length = int(resp.headers.get("content-length", 0))
if event_bus:
event_bus.emit_download_started(url)
if (
resp.status_code == 416
): # "range not satisfiable", which means nothing to return
logger.info(f"{model_name}: complete file found. Skipping.")
if event_bus:
event_bus.emit_download_completed(url,resp.status_code,model_dest)
return model_dest
elif resp.status_code == 404:
logger.warning("File not found")
if event_bus:
event_bus.emit_download_completed(url,resp.status_code,None)
return None
elif resp.status_code != 200:
logger.warning(f"{model_name}: {resp.reason}")
@@ -464,11 +506,15 @@ def hf_download_with_resume(
else:
logger.info(f"{model_name}: Downloading...")
MB10 = 10 * 1048576
downloaded = exist_size
previous_interval = 0
try:
with open(model_dest, open_mode) as file, tqdm(
desc=model_name,
initial=exist_size,
total=total + exist_size,
total=content_length + exist_size,
unit="iB",
unit_scale=True,
unit_divisor=1000,
@@ -476,9 +522,20 @@ def hf_download_with_resume(
for data in resp.iter_content(chunk_size=1024):
size = file.write(data)
bar.update(size)
downloaded += size
if event_bus and downloaded // MB10 > previous_interval:
previous_interval = downloaded // MB10
event_bus.emit_download_progress(url, downloaded, content_length)
except Exception as e:
logger.error(f"An error occurred while downloading {model_name}: {str(e)}")
if event_bus:
event_bus.emit_download_completed(url,500,None)
return None
if event_bus:
event_bus.emit_download_completed(url,resp.status_code,model_dest)
return model_dest
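
The downloaded // MB10 > previous_interval test above throttles progress to one event per 10 MiB boundary crossed, independent of chunk size (and on a resume past 10 MiB the very first chunk fires an event, since previous_interval starts at 0). A quick simulation of when events fire under that rule:

MB10 = 10 * 1048576  # same constant as above

def emission_points(total_bytes, chunk=1048576, exist_size=0):
    """Byte counts at which a progress event would fire."""
    downloaded, previous_interval, points = exist_size, 0, []
    while downloaded < total_bytes:
        downloaded += min(chunk, total_bytes - downloaded)
        if downloaded // MB10 > previous_interval:
            previous_interval = downloaded // MB10
            points.append(downloaded)
    return points

# A 35 MiB download in 1 MiB chunks reports at the 10/20/30 MiB marks.
print([p // 1048576 for p in emission_points(35 * 1048576)])  # [10, 20, 30]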

View File

@@ -946,6 +946,7 @@ class ModelManager(object):
def heuristic_import(self,
items_to_import: Set[str],
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
event_bus = None, # EventServiceBase; untyped here because of circular dependency issues
)->Dict[str, AddModelResult]:
'''Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items.
@@ -973,10 +974,13 @@ class ModelManager(object):
installer = ModelInstall(config = self.app_config,
prediction_type_helper = prediction_type_helper,
model_manager = self)
model_manager = self,
event_bus = event_bus,
)
for thing in items_to_import:
installed = installer.heuristic_import(thing)
successfully_installed.update(installed)
self.commit()
return successfully_installed
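
Since each item is installed independently and a single commit() persists the batch at the end, one call can mix local paths, repo IDs and URLs. A hedged usage sketch (identifiers and paths are illustrative, not real models):

def import_batch(model_manager, event_bus=None):
    """model_manager: an initialized ModelManager; event_bus: an
    EventServiceBase, or None for callers such as the CLI."""
    results = model_manager.heuristic_import(
        items_to_import={
            "/models/checkpoints/foo.safetensors",  # local file (illustrative)
            "runwayml/stable-diffusion-v1-5",       # repo_id (illustrative)
            "https://example.com/bar.safetensors",  # URL (illustrative)
        },
        prediction_type_helper=lambda path: None,   # let probing decide
        event_bus=event_bus,
    )
    for source, add_result in results.items():
        print(source, "->", add_result)
    return results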

View File

@@ -21,7 +21,6 @@ from tqdm import tqdm
import invokeai.backend.util.logging as logger
from .devices import torch_dtype
def log_txt_as_img(wh, xc, size=10):
# wh a tuple of (width, height)
# xc a list of captions to plot
@@ -285,7 +284,11 @@ def ask_user(question: str, answers: list):
# -------------------------------------
def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path:
def download_with_resume(url: str,
dest: Path,
access_token: str = None,
event_bus = None, # EventServiceBase (circular import issues)
) -> Path:
"""
Download a model file.
:param url: https, http or ftp URL
@@ -323,8 +326,13 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
os.remove(dest)
exist_size = 0
if event_bus:
event_bus.emit_download_started(url)
if resp.status_code == 416 or (content_length > 0 and exist_size == content_length):
logger.warning(f"{dest}: complete file found. Skipping.")
if event_bus:
event_bus.emit_download_completed(url,resp.status_code,dest)
return dest
elif resp.status_code == 206 or exist_size > 0:
logger.warning(f"{dest}: partial file found. Resuming...")
@@ -333,11 +341,20 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
else:
logger.info(f"{dest}: Downloading...")
try:
if content_length < 2000:
logger.error(f"ERROR DOWNLOADING {url}: {resp.text}")
# If less than 2K, it's not a model - usually an HTML document of some sort
if content_length < 2000:
logger.error(f"ERROR DOWNLOADING {url}: {resp.text}")
if event_bus:
event_bus.emit_download_completed(url, 500, None)
return None
# These variables are used in progress-reporting events
MB10 = 10 * 1048576
downloaded = exist_size
previous_interval = 0
try:
with open(dest, open_mode) as file, tqdm(
desc=str(dest),
initial=exist_size,
@@ -349,10 +366,20 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
for data in resp.iter_content(chunk_size=1024):
size = file.write(data)
bar.update(size)
downloaded += size
if event_bus and downloaded // MB10 > previous_interval:
previous_interval = downloaded // MB10
event_bus.emit_download_progress(url, downloaded, content_length)
except Exception as e:
logger.error(f"An error occurred while downloading {dest}: {str(e)}")
if event_bus:
event_bus.emit_download_completed(url,500,None)
return None
if event_bus:
event_bus.emit_download_completed(url,resp.status_code,dest)
return dest
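
For reference, the resume behavior both downloaders rely on is the standard HTTP range handshake: send Range: bytes=<existing>- when a partial file is present, append on 206, and treat 416 as an already-complete file. A compact standalone sketch of just that handshake, without the event plumbing (requests is the only dependency; this is not the InvokeAI implementation):

from pathlib import Path
from typing import Optional

import requests

def resume_download(url: str, dest: Path) -> Optional[Path]:
    exist_size = dest.stat().st_size if dest.exists() else 0
    headers = {"Range": f"bytes={exist_size}-"} if exist_size else {}
    resp = requests.get(url, headers=headers, stream=True)
    if resp.status_code == 416:   # range not satisfiable: file already complete
        return dest
    if resp.status_code not in (200, 206):
        return None
    mode = "ab" if resp.status_code == 206 else "wb"  # append only on a true partial
    with open(dest, mode) as f:
        for chunk in resp.iter_content(chunk_size=1024):
            f.write(chunk)
    return dest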