mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
API model downloads are now threaded and generate progress events
This commit is contained in:
parent
1d3fda80aa
commit
da187d6a87
@ -9,7 +9,6 @@ from fastapi_events.dispatcher import dispatch
|
||||
|
||||
from ..services.events import EventServiceBase
|
||||
|
||||
|
||||
class FastAPIEventService(EventServiceBase):
|
||||
event_handler_id: int
|
||||
__queue: Queue
|
||||
@ -28,6 +27,9 @@ class FastAPIEventService(EventServiceBase):
|
||||
self.__queue.put(None)
|
||||
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
# TODO: Remove next two debugging lines
|
||||
from .dependencies import ApiDependencies
|
||||
ApiDependencies.invoker.services.logger.debug(f'dispatch {event_name} / {payload}')
|
||||
self.__queue.put(dict(event_name=event_name, payload=payload))
|
||||
|
||||
async def __dispatch_from_queue(self, stop_event: threading.Event):
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
|
||||
import pathlib
|
||||
import threading
|
||||
from typing import Literal, List, Optional, Union
|
||||
|
||||
from fastapi import Body, Path, Query, Response
|
||||
@ -106,50 +107,43 @@ async def update_model(
|
||||
"/import",
|
||||
operation_id="import_model",
|
||||
responses= {
|
||||
201: {"description" : "The model imported successfully"},
|
||||
404: {"description" : "The model could not be found"},
|
||||
424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"},
|
||||
409: {"description" : "There is already a model corresponding to this path or repo_id"},
|
||||
200: {"description" : "The path was queued for import"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=ImportModelResponse
|
||||
status_code=200
|
||||
)
|
||||
async def import_model(
|
||||
location: str = Body(description="A model path, repo_id or URL to import"),
|
||||
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
|
||||
Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
|
||||
) -> ImportModelResponse:
|
||||
""" Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically """
|
||||
) -> str:
|
||||
"""
|
||||
Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically.
|
||||
This call launches a background thread to process the imported model and always succeeds. Results are reported in the background
|
||||
as the following events:
|
||||
- model_import_started(import_path)
|
||||
- model_import_completed(import_path, AddModelResults)
|
||||
- download_started(url)
|
||||
- download_progress(url,downloaded_bytes,total_bytes)
|
||||
- download_completed(url,status_code,download_path)
|
||||
"""
|
||||
|
||||
items_to_import = {location}
|
||||
prediction_types = { x.value: x for x in SchedulerPredictionType }
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
events = ApiDependencies.invoker.services.events
|
||||
|
||||
try:
|
||||
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
|
||||
items_to_import = items_to_import,
|
||||
prediction_type_helper = lambda x: prediction_types.get(prediction_type)
|
||||
import_thread = threading.Thread(target = ApiDependencies.invoker.services.model_manager.heuristic_import,
|
||||
kwargs = dict(items_to_import = items_to_import,
|
||||
prediction_type_helper = lambda x: prediction_types.get(prediction_type),
|
||||
event_bus = events,
|
||||
)
|
||||
info = installed_models.get(location)
|
||||
|
||||
if not info:
|
||||
logger.error("Import failed")
|
||||
raise HTTPException(status_code=424)
|
||||
|
||||
logger.info(f'Successfully imported {location}, got {info}')
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=info.name,
|
||||
base_model=info.base_model,
|
||||
model_type=info.model_type
|
||||
)
|
||||
return parse_obj_as(ImportModelResponse, model_raw)
|
||||
|
||||
except ModelNotFoundException as e:
|
||||
import_thread.start()
|
||||
return 'request queued'
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@models_router.post(
|
||||
"/add",
|
||||
|
@ -1,9 +1,11 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Any, Optional
|
||||
from pathlib import Path
|
||||
|
||||
from invokeai.app.models.image import ProgressImage
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo
|
||||
from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo, AddModelResult
|
||||
|
||||
class EventServiceBase:
|
||||
session_event: str = "session_event"
|
||||
@ -104,21 +106,15 @@ class EventServiceBase:
|
||||
|
||||
def emit_model_load_started (
|
||||
self,
|
||||
graph_execution_state_id: str,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: SubModelType,
|
||||
) -> None:
|
||||
"""Emitted when a model is requested"""
|
||||
self.__emit_session_event(
|
||||
self.dispatch(
|
||||
event_name="model_load_started",
|
||||
payload=dict(
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
@ -128,9 +124,6 @@ class EventServiceBase:
|
||||
|
||||
def emit_model_load_completed(
|
||||
self,
|
||||
graph_execution_state_id: str,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
@ -138,12 +131,9 @@ class EventServiceBase:
|
||||
model_info: ModelInfo,
|
||||
) -> None:
|
||||
"""Emitted when a model is correctly loaded (returns model info)"""
|
||||
self.__emit_session_event(
|
||||
self.dispatch(
|
||||
event_name="model_load_completed",
|
||||
payload=dict(
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
model_type=model_type,
|
||||
@ -151,3 +141,92 @@ class EventServiceBase:
|
||||
model_info=model_info,
|
||||
),
|
||||
)
|
||||
|
||||
def emit_model_import_started (
|
||||
self,
|
||||
import_path: str, # can be a local path, URL or repo_id
|
||||
)->None:
|
||||
"""Emitted when a model import commences"""
|
||||
self.dispatch(
|
||||
event_name="model_import_started",
|
||||
payload=dict(
|
||||
import_path = import_path,
|
||||
)
|
||||
)
|
||||
|
||||
def emit_model_import_completed (
|
||||
self,
|
||||
import_path: str, # can be a local path, URL or repo_id
|
||||
import_info: AddModelResult,
|
||||
success: bool= True,
|
||||
error: str = None,
|
||||
|
||||
)->None:
|
||||
"""Emitted when a model import completes"""
|
||||
self.dispatch(
|
||||
event_name="model_import_completed",
|
||||
payload=dict(
|
||||
import_path = import_path,
|
||||
import_info = import_info,
|
||||
success = success,
|
||||
error = error,
|
||||
)
|
||||
)
|
||||
|
||||
def emit_download_started (
|
||||
self,
|
||||
url: str,
|
||||
|
||||
)->None:
|
||||
"""Emitted when a download thread starts"""
|
||||
self.dispatch(
|
||||
event_name="download_started",
|
||||
payload=dict(
|
||||
url = url,
|
||||
)
|
||||
)
|
||||
|
||||
def emit_download_progress (
|
||||
self,
|
||||
url: str,
|
||||
downloaded_size: int,
|
||||
total_size: int,
|
||||
)->None:
|
||||
"""
|
||||
Emitted at intervals during a download process
|
||||
:param url: Requested URL
|
||||
:param downloaded_size: Bytes downloaded so far
|
||||
:param total_size: Total bytes to download
|
||||
"""
|
||||
self.dispatch(
|
||||
event_name="download_progress",
|
||||
payload=dict(
|
||||
url = url,
|
||||
downloaded_size = downloaded_size,
|
||||
total_size = total_size,
|
||||
)
|
||||
)
|
||||
|
||||
def emit_download_completed (
|
||||
self,
|
||||
url: str,
|
||||
status_code: int,
|
||||
download_path: Path,
|
||||
|
||||
)->None:
|
||||
"""
|
||||
Emitted when a download thread completes.
|
||||
:param url: Requested URL
|
||||
:param status_code: HTTP status code from request
|
||||
:param download_path: Path to downloaded file
|
||||
"""
|
||||
self.dispatch(
|
||||
event_name="download_completed",
|
||||
payload=dict(
|
||||
url = url,
|
||||
status_code = status_code,
|
||||
download_path = download_path,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
|
@ -26,6 +26,7 @@ import torch
|
||||
from invokeai.app.models.exceptions import CanceledException
|
||||
from ...backend.util import choose_precision, choose_torch_device
|
||||
from .config import InvokeAIAppConfig
|
||||
from .events import EventServiceBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..invocations.baseinvocation import BaseInvocation, InvocationContext
|
||||
@ -552,6 +553,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
def heuristic_import(self,
|
||||
items_to_import: set[str],
|
||||
prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None,
|
||||
event_bus: Optional[EventServiceBase]=None,
|
||||
)->dict[str, AddModelResult]:
|
||||
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
||||
successfully imported items.
|
||||
@ -569,7 +571,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
of the set is a dict corresponding to the newly-created OmegaConf stanza for
|
||||
that model.
|
||||
'''
|
||||
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
|
||||
return self.mgr.heuristic_import(items_to_import, prediction_type_helper, event_bus=event_bus)
|
||||
|
||||
def merge_models(
|
||||
self,
|
||||
|
@ -89,13 +89,16 @@ class ModelInstall(object):
|
||||
config:InvokeAIAppConfig,
|
||||
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
||||
model_manager: ModelManager = None,
|
||||
access_token:str = None):
|
||||
access_token:str = None,
|
||||
event_bus = None, # EventServicesBase - getting circular import errors
|
||||
):
|
||||
self.config = config
|
||||
self.mgr = model_manager or ModelManager(config.model_conf_path)
|
||||
self.datasets = OmegaConf.load(Dataset_path)
|
||||
self.prediction_helper = prediction_type_helper
|
||||
self.access_token = access_token or HfFolder.get_token()
|
||||
self.reverse_paths = self._reverse_paths(self.datasets)
|
||||
self.event_bus = event_bus
|
||||
|
||||
def all_models(self)->Dict[str,ModelLoadInfo]:
|
||||
'''
|
||||
@ -197,12 +200,17 @@ class ModelInstall(object):
|
||||
Returns a set of dict objects corresponding to newly-created stanzas in models.yaml.
|
||||
'''
|
||||
|
||||
if self.event_bus:
|
||||
self.event_bus.emit_model_import_started(str(model_path_id_or_url))
|
||||
|
||||
if not models_installed:
|
||||
models_installed = dict()
|
||||
|
||||
# A little hack to allow nested routines to retrieve info on the requested ID
|
||||
self.current_id = model_path_id_or_url
|
||||
path = Path(model_path_id_or_url)
|
||||
|
||||
try:
|
||||
# checkpoint file, or similar
|
||||
if path.is_file():
|
||||
models_installed.update({str(path):self._install_path(path)})
|
||||
@ -228,7 +236,26 @@ class ModelInstall(object):
|
||||
models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)})
|
||||
|
||||
else:
|
||||
raise KeyError(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
|
||||
errmsg = f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping'
|
||||
raise KeyError(errmsg)
|
||||
|
||||
if self.event_bus:
|
||||
for path, add_model_result in models_installed.items():
|
||||
self.event_bus.emit_model_import_completed(
|
||||
str(path),
|
||||
import_info = add_model_result,
|
||||
)
|
||||
except Exception as e:
|
||||
if self.event_bus:
|
||||
self.event_bus.emit_model_import_completed(
|
||||
str(path),
|
||||
import_info = None,
|
||||
success = False,
|
||||
error = str(e),
|
||||
)
|
||||
return models_installed
|
||||
else:
|
||||
raise
|
||||
|
||||
return models_installed
|
||||
|
||||
@ -238,10 +265,14 @@ class ModelInstall(object):
|
||||
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
|
||||
if not info:
|
||||
logger.warning(f'Unable to parse format of {path}')
|
||||
return None
|
||||
raise ValueError(f'Unable to parse format of {path}')
|
||||
|
||||
model_name = path.stem if path.is_file() else path.name
|
||||
|
||||
if self.mgr.model_exists(model_name, info.base_type, info.model_type):
|
||||
raise ValueError(f'A model named "{model_name}" is already installed.')
|
||||
errmsg = f'A model named "{model_name}" is already installed.'
|
||||
raise ValueError(errmsg)
|
||||
|
||||
attributes = self._make_attributes(path,info)
|
||||
return self.mgr.add_model(model_name = model_name,
|
||||
base_model = info.base_type,
|
||||
@ -251,7 +282,7 @@ class ModelInstall(object):
|
||||
|
||||
def _install_url(self, url: str)->AddModelResult:
|
||||
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
||||
location = download_with_resume(url,Path(staging))
|
||||
location = download_with_resume(url,Path(staging),event_bus=self.event_bus)
|
||||
if not location:
|
||||
logger.error(f'Unable to download {url}. Skipping.')
|
||||
info = ModelProbe().heuristic_probe(location)
|
||||
@ -384,7 +415,8 @@ class ModelInstall(object):
|
||||
p = hf_download_with_resume(repo_id,
|
||||
model_dir=location,
|
||||
model_name=filename,
|
||||
access_token = self.access_token
|
||||
access_token = self.access_token,
|
||||
event_bus = self.event_bus,
|
||||
)
|
||||
if p:
|
||||
paths.append(p)
|
||||
@ -425,12 +457,15 @@ def hf_download_from_pretrained(
|
||||
return destination
|
||||
|
||||
# ---------------------------------------------
|
||||
# TODO: This function is almost identical to invokeai.backend.util.download_with_resume
|
||||
# and should be merged
|
||||
def hf_download_with_resume(
|
||||
repo_id: str,
|
||||
model_dir: str,
|
||||
model_name: str,
|
||||
model_dest: Path = None,
|
||||
access_token: str = None,
|
||||
event_bus = None,
|
||||
) -> Path:
|
||||
model_dest = model_dest or Path(os.path.join(model_dir, model_name))
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
@ -447,15 +482,22 @@ def hf_download_with_resume(
|
||||
open_mode = "ab"
|
||||
|
||||
resp = requests.get(url, headers=header, stream=True)
|
||||
total = int(resp.headers.get("content-length", 0))
|
||||
content_length = int(resp.headers.get("content-length", 0))
|
||||
|
||||
if event_bus:
|
||||
event_bus.emit_download_started(url)
|
||||
|
||||
if (
|
||||
resp.status_code == 416
|
||||
): # "range not satisfiable", which means nothing to return
|
||||
logger.info(f"{model_name}: complete file found. Skipping.")
|
||||
if event_bus:
|
||||
event_bus.emit_download_completed(url,resp.status_code,model_dest)
|
||||
return model_dest
|
||||
elif resp.status_code == 404:
|
||||
logger.warning("File not found")
|
||||
if event_bus:
|
||||
event_bus.emit_download_completed(url,resp.status_code,None)
|
||||
return None
|
||||
elif resp.status_code != 200:
|
||||
logger.warning(f"{model_name}: {resp.reason}")
|
||||
@ -464,11 +506,15 @@ def hf_download_with_resume(
|
||||
else:
|
||||
logger.info(f"{model_name}: Downloading...")
|
||||
|
||||
MB10 = 10 * 1048576
|
||||
downloaded = exist_size
|
||||
previous_interval = 0
|
||||
|
||||
try:
|
||||
with open(model_dest, open_mode) as file, tqdm(
|
||||
desc=model_name,
|
||||
initial=exist_size,
|
||||
total=total + exist_size,
|
||||
total=content_length + exist_size,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1000,
|
||||
@ -476,9 +522,20 @@ def hf_download_with_resume(
|
||||
for data in resp.iter_content(chunk_size=1024):
|
||||
size = file.write(data)
|
||||
bar.update(size)
|
||||
downloaded += size
|
||||
if event_bus and downloaded // MB10 > previous_interval:
|
||||
previous_interval = downloaded // MB10
|
||||
event_bus.emit_download_progress(url, downloaded, content_length)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred while downloading {model_name}: {str(e)}")
|
||||
if event_bus:
|
||||
event_bus.emit_download_completed(url,500,None)
|
||||
return None
|
||||
|
||||
if event_bus:
|
||||
event_bus.emit_download_completed(url,resp.status_code,model_dest)
|
||||
|
||||
return model_dest
|
||||
|
||||
|
||||
|
@ -946,6 +946,7 @@ class ModelManager(object):
|
||||
def heuristic_import(self,
|
||||
items_to_import: Set[str],
|
||||
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
|
||||
event_bus = None, # EventServiceBase, with circular dependency issues
|
||||
)->Dict[str, AddModelResult]:
|
||||
'''Import a list of paths, repo_ids or URLs. Returns the set of
|
||||
successfully imported items.
|
||||
@ -973,10 +974,13 @@ class ModelManager(object):
|
||||
|
||||
installer = ModelInstall(config = self.app_config,
|
||||
prediction_type_helper = prediction_type_helper,
|
||||
model_manager = self)
|
||||
model_manager = self,
|
||||
event_bus = event_bus,
|
||||
)
|
||||
for thing in items_to_import:
|
||||
installed = installer.heuristic_import(thing)
|
||||
successfully_installed.update(installed)
|
||||
|
||||
self.commit()
|
||||
return successfully_installed
|
||||
|
||||
|
@ -21,7 +21,6 @@ from tqdm import tqdm
|
||||
import invokeai.backend.util.logging as logger
|
||||
from .devices import torch_dtype
|
||||
|
||||
|
||||
def log_txt_as_img(wh, xc, size=10):
|
||||
# wh a tuple of (width, height)
|
||||
# xc a list of captions to plot
|
||||
@ -285,7 +284,11 @@ def ask_user(question: str, answers: list):
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path:
|
||||
def download_with_resume(url: str,
|
||||
dest: Path,
|
||||
access_token: str = None,
|
||||
event_bus = None, # EventServiceBase (circular import issues)
|
||||
) -> Path:
|
||||
"""
|
||||
Download a model file.
|
||||
:param url: https, http or ftp URL
|
||||
@ -323,8 +326,13 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
|
||||
os.remove(dest)
|
||||
exist_size = 0
|
||||
|
||||
if event_bus:
|
||||
event_bus.emit_download_started(url)
|
||||
|
||||
if resp.status_code == 416 or (content_length > 0 and exist_size == content_length):
|
||||
logger.warning(f"{dest}: complete file found. Skipping.")
|
||||
if event_bus:
|
||||
event_bus.emit_download_completed(url,resp.status_code,dest)
|
||||
return dest
|
||||
elif resp.status_code == 206 or exist_size > 0:
|
||||
logger.warning(f"{dest}: partial file found. Resuming...")
|
||||
@ -333,11 +341,20 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
|
||||
else:
|
||||
logger.info(f"{dest}: Downloading...")
|
||||
|
||||
try:
|
||||
# If less than 2K, it's not a model - usually an HTML document of some sort
|
||||
if content_length < 2000:
|
||||
logger.error(f"ERROR DOWNLOADING {url}: {resp.text}")
|
||||
if event_bus:
|
||||
event_bus.emit_download_completed(url, 500, None)
|
||||
return None
|
||||
|
||||
# these variables are used in progress reporting events
|
||||
MB10 = 10 * 1048576
|
||||
downloaded = exist_size
|
||||
previous_interval = 0
|
||||
|
||||
try:
|
||||
|
||||
with open(dest, open_mode) as file, tqdm(
|
||||
desc=str(dest),
|
||||
initial=exist_size,
|
||||
@ -349,10 +366,20 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
|
||||
for data in resp.iter_content(chunk_size=1024):
|
||||
size = file.write(data)
|
||||
bar.update(size)
|
||||
downloaded += size
|
||||
if event_bus and downloaded // MB10 > previous_interval:
|
||||
previous_interval = downloaded // MB10
|
||||
event_bus.emit_download_progress(url, downloaded, content_length)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred while downloading {dest}: {str(e)}")
|
||||
if event_bus:
|
||||
event_bus.emit_download_completed(url,500,None)
|
||||
return None
|
||||
|
||||
if event_bus:
|
||||
event_bus.emit_download_completed(url,resp.status_code,dest)
|
||||
|
||||
return dest
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user