Compare commits

...

6 Commits

8 changed files with 262 additions and 74 deletions

View File

@ -9,7 +9,6 @@ from fastapi_events.dispatcher import dispatch
from ..services.events import EventServiceBase from ..services.events import EventServiceBase
class FastAPIEventService(EventServiceBase): class FastAPIEventService(EventServiceBase):
event_handler_id: int event_handler_id: int
__queue: Queue __queue: Queue
@ -28,6 +27,9 @@ class FastAPIEventService(EventServiceBase):
self.__queue.put(None) self.__queue.put(None)
def dispatch(self, event_name: str, payload: Any) -> None: def dispatch(self, event_name: str, payload: Any) -> None:
# TODO: Remove next two debugging lines
from .dependencies import ApiDependencies
ApiDependencies.invoker.services.logger.debug(f'dispatch {event_name} / {payload}')
self.__queue.put(dict(event_name=event_name, payload=payload)) self.__queue.put(dict(event_name=event_name, payload=payload))
async def __dispatch_from_queue(self, stop_event: threading.Event): async def __dispatch_from_queue(self, stop_event: threading.Event):

View File

@ -2,6 +2,7 @@
import pathlib import pathlib
import threading
from typing import Literal, List, Optional, Union from typing import Literal, List, Optional, Union
from fastapi import Body, Path, Query, Response from fastapi import Body, Path, Query, Response
@ -127,54 +128,43 @@ async def update_model(
"/import", "/import",
operation_id="import_model", operation_id="import_model",
responses= { responses= {
201: {"description" : "The model imported successfully"}, 200: {"description" : "The path was queued for import"},
404: {"description" : "The model could not be found"},
415: {"description" : "Unrecognized file/folder format"},
424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"},
409: {"description" : "There is already a model corresponding to this path or repo_id"},
}, },
status_code=201, status_code=200
response_model=ImportModelResponse
) )
async def import_model( async def import_model(
location: str = Body(description="A model path, repo_id or URL to import"), location: str = Body(description="A model path, repo_id or URL to import"),
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \ prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"), Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
) -> ImportModelResponse: ) -> str:
""" Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically """ """
Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically.
This call launches a background thread to process the imported model and always succeeds. Results are reported in the background
as the following events:
- model_import_started(import_path:str)
- model_import_completed(import_path:str, import_info:AddModelResults, success:bool, error:str)
- download_started(url:str)
- download_progress(url:str, downloaded_bytes:int, total_bytes:int)
- download_completed(url:str, status_code:int, download_path:str)
"""
items_to_import = {location} items_to_import = {location}
prediction_types = { x.value: x for x in SchedulerPredictionType } prediction_types = { x.value: x for x in SchedulerPredictionType }
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
events = ApiDependencies.invoker.services.events
try: try:
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import( import_thread = threading.Thread(target = ApiDependencies.invoker.services.model_manager.heuristic_import,
items_to_import = items_to_import, kwargs = dict(items_to_import = items_to_import,
prediction_type_helper = lambda x: prediction_types.get(prediction_type) prediction_type_helper = lambda x: prediction_types.get(prediction_type),
event_bus = events,
) )
info = installed_models.get(location)
if not info:
logger.error("Import failed")
raise HTTPException(status_code=415)
logger.info(f'Successfully imported {location}, got {info}')
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=info.name,
base_model=info.base_model,
model_type=info.model_type
) )
return parse_obj_as(ImportModelResponse, model_raw) import_thread.start()
return 'request queued'
except ModelNotFoundException as e: except Exception as e:
logger.error(str(e)) logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
except InvalidModelException as e:
logger.error(str(e))
raise HTTPException(status_code=415)
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
@models_router.post( @models_router.post(
"/add", "/add",

View File

@ -1,9 +1,11 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any, Optional from typing import Any, Optional
from pathlib import Path
from invokeai.app.models.image import ProgressImage from invokeai.app.models.image import ProgressImage
from invokeai.app.util.misc import get_timestamp from invokeai.app.util.misc import get_timestamp
from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo, AddModelResult
class EventServiceBase: class EventServiceBase:
session_event: str = "session_event" session_event: str = "session_event"
@ -111,7 +113,7 @@ class EventServiceBase:
submodel: SubModelType, submodel: SubModelType,
) -> None: ) -> None:
"""Emitted when a model is requested""" """Emitted when a model is requested"""
self.__emit_session_event( self.dispatch(
event_name="model_load_started", event_name="model_load_started",
payload=dict( payload=dict(
graph_execution_state_id=graph_execution_state_id, graph_execution_state_id=graph_execution_state_id,
@ -132,7 +134,7 @@ class EventServiceBase:
model_info: ModelInfo, model_info: ModelInfo,
) -> None: ) -> None:
"""Emitted when a model is correctly loaded (returns model info)""" """Emitted when a model is correctly loaded (returns model info)"""
self.__emit_session_event( self.dispatch(
event_name="model_load_completed", event_name="model_load_completed",
payload=dict( payload=dict(
graph_execution_state_id=graph_execution_state_id, graph_execution_state_id=graph_execution_state_id,
@ -145,3 +147,92 @@ class EventServiceBase:
precision=str(model_info.precision), precision=str(model_info.precision),
), ),
) )
def emit_model_import_started (
self,
import_path: str, # can be a local path, URL or repo_id
)->None:
"""Emitted when a model import commences"""
self.dispatch(
event_name="model_import_started",
payload=dict(
import_path = import_path,
)
)
def emit_model_import_completed (
self,
import_path: str, # can be a local path, URL or repo_id
import_info: AddModelResult,
success: bool= True,
error: str = None,
)->None:
"""Emitted when a model import completes"""
self.dispatch(
event_name="model_import_completed",
payload=dict(
import_path = import_path,
import_info = import_info,
success = success,
error = error,
)
)
def emit_download_started (
self,
url: str,
)->None:
"""Emitted when a download thread starts"""
self.dispatch(
event_name="download_started",
payload=dict(
url = url,
)
)
def emit_download_progress (
self,
url: str,
downloaded_size: int,
total_size: int,
)->None:
"""
Emitted at intervals during a download process
:param url: Requested URL
:param downloaded_size: Bytes downloaded so far
:param total_size: Total bytes to download
"""
self.dispatch(
event_name="download_progress",
payload=dict(
url = url,
downloaded_size = downloaded_size,
total_size = total_size,
)
)
def emit_download_completed (
self,
url: str,
status_code: int,
download_path: Path,
)->None:
"""
Emitted when a download thread completes.
:param url: Requested URL
:param status_code: HTTP status code from request
:param download_path: Path to downloaded file
"""
self.dispatch(
event_name="download_completed",
payload=dict(
url = url,
status_code = status_code,
download_path = download_path,
)
)

View File

@ -26,6 +26,7 @@ import torch
from invokeai.app.models.exceptions import CanceledException from invokeai.app.models.exceptions import CanceledException
from ...backend.util import choose_precision, choose_torch_device from ...backend.util import choose_precision, choose_torch_device
from .config import InvokeAIAppConfig from .config import InvokeAIAppConfig
from .events import EventServiceBase
if TYPE_CHECKING: if TYPE_CHECKING:
from ..invocations.baseinvocation import BaseInvocation, InvocationContext from ..invocations.baseinvocation import BaseInvocation, InvocationContext
@ -542,6 +543,7 @@ class ModelManagerService(ModelManagerServiceBase):
def heuristic_import(self, def heuristic_import(self,
items_to_import: set[str], items_to_import: set[str],
prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None, prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None,
event_bus: Optional[EventServiceBase]=None,
)->dict[str, AddModelResult]: )->dict[str, AddModelResult]:
'''Import a list of paths, repo_ids or URLs. Returns the set of '''Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items. successfully imported items.
@ -559,7 +561,7 @@ class ModelManagerService(ModelManagerServiceBase):
of the set is a dict corresponding to the newly-created OmegaConf stanza for of the set is a dict corresponding to the newly-created OmegaConf stanza for
that model. that model.
''' '''
return self.mgr.heuristic_import(items_to_import, prediction_type_helper) return self.mgr.heuristic_import(items_to_import, prediction_type_helper, event_bus=event_bus)
def merge_models( def merge_models(
self, self,

View File

@ -89,13 +89,16 @@ class ModelInstall(object):
config:InvokeAIAppConfig, config:InvokeAIAppConfig,
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None, prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
model_manager: ModelManager = None, model_manager: ModelManager = None,
access_token:str = None): access_token:str = None,
event_bus = None, # EventServicesBase - getting circular import errors
):
self.config = config self.config = config
self.mgr = model_manager or ModelManager(config.model_conf_path) self.mgr = model_manager or ModelManager(config.model_conf_path)
self.datasets = OmegaConf.load(Dataset_path) self.datasets = OmegaConf.load(Dataset_path)
self.prediction_helper = prediction_type_helper self.prediction_helper = prediction_type_helper
self.access_token = access_token or HfFolder.get_token() self.access_token = access_token or HfFolder.get_token()
self.reverse_paths = self._reverse_paths(self.datasets) self.reverse_paths = self._reverse_paths(self.datasets)
self.event_bus = event_bus
def all_models(self)->Dict[str,ModelLoadInfo]: def all_models(self)->Dict[str,ModelLoadInfo]:
''' '''
@ -197,12 +200,17 @@ class ModelInstall(object):
Returns a set of dict objects corresponding to newly-created stanzas in models.yaml. Returns a set of dict objects corresponding to newly-created stanzas in models.yaml.
''' '''
if self.event_bus:
self.event_bus.emit_model_import_started(str(model_path_id_or_url))
if not models_installed: if not models_installed:
models_installed = dict() models_installed = dict()
# A little hack to allow nested routines to retrieve info on the requested ID # A little hack to allow nested routines to retrieve info on the requested ID
self.current_id = model_path_id_or_url self.current_id = model_path_id_or_url
path = Path(model_path_id_or_url) path = Path(model_path_id_or_url)
try:
# checkpoint file, or similar # checkpoint file, or similar
if path.is_file(): if path.is_file():
models_installed.update({str(path):self._install_path(path)}) models_installed.update({str(path):self._install_path(path)})
@ -228,7 +236,26 @@ class ModelInstall(object):
models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)}) models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)})
else: else:
raise KeyError(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping') errmsg = f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping'
raise KeyError(errmsg)
if self.event_bus:
for path, add_model_result in models_installed.items():
self.event_bus.emit_model_import_completed(
str(path),
import_info = add_model_result,
)
except Exception as e:
if self.event_bus:
self.event_bus.emit_model_import_completed(
str(path),
import_info = None,
success = False,
error = str(e),
)
return models_installed
else:
raise
return models_installed return models_installed
@ -238,10 +265,14 @@ class ModelInstall(object):
info = info or ModelProbe().heuristic_probe(path,self.prediction_helper) info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
if not info: if not info:
logger.warning(f'Unable to parse format of {path}') logger.warning(f'Unable to parse format of {path}')
return None raise ValueError(f'Unable to parse format of {path}')
model_name = path.stem if path.is_file() else path.name model_name = path.stem if path.is_file() else path.name
if self.mgr.model_exists(model_name, info.base_type, info.model_type): if self.mgr.model_exists(model_name, info.base_type, info.model_type):
raise ValueError(f'A model named "{model_name}" is already installed.') errmsg = f'A model named "{model_name}" is already installed.'
raise ValueError(errmsg)
attributes = self._make_attributes(path,info) attributes = self._make_attributes(path,info)
return self.mgr.add_model(model_name = model_name, return self.mgr.add_model(model_name = model_name,
base_model = info.base_type, base_model = info.base_type,
@ -251,7 +282,7 @@ class ModelInstall(object):
def _install_url(self, url: str)->AddModelResult: def _install_url(self, url: str)->AddModelResult:
with TemporaryDirectory(dir=self.config.models_path) as staging: with TemporaryDirectory(dir=self.config.models_path) as staging:
location = download_with_resume(url,Path(staging)) location = download_with_resume(url,Path(staging),event_bus=self.event_bus)
if not location: if not location:
logger.error(f'Unable to download {url}. Skipping.') logger.error(f'Unable to download {url}. Skipping.')
info = ModelProbe().heuristic_probe(location) info = ModelProbe().heuristic_probe(location)
@ -384,7 +415,8 @@ class ModelInstall(object):
p = hf_download_with_resume(repo_id, p = hf_download_with_resume(repo_id,
model_dir=location, model_dir=location,
model_name=filename, model_name=filename,
access_token = self.access_token access_token = self.access_token,
event_bus = self.event_bus,
) )
if p: if p:
paths.append(p) paths.append(p)
@ -425,12 +457,15 @@ def hf_download_from_pretrained(
return destination return destination
# --------------------------------------------- # ---------------------------------------------
# TODO: This function is almost identical to invokeai.backend.util.download_with_resume
# and should be merged
def hf_download_with_resume( def hf_download_with_resume(
repo_id: str, repo_id: str,
model_dir: str, model_dir: str,
model_name: str, model_name: str,
model_dest: Path = None, model_dest: Path = None,
access_token: str = None, access_token: str = None,
event_bus = None,
) -> Path: ) -> Path:
model_dest = model_dest or Path(os.path.join(model_dir, model_name)) model_dest = model_dest or Path(os.path.join(model_dir, model_name))
os.makedirs(model_dir, exist_ok=True) os.makedirs(model_dir, exist_ok=True)
@ -447,15 +482,22 @@ def hf_download_with_resume(
open_mode = "ab" open_mode = "ab"
resp = requests.get(url, headers=header, stream=True) resp = requests.get(url, headers=header, stream=True)
total = int(resp.headers.get("content-length", 0)) content_length = int(resp.headers.get("content-length", 0))
if event_bus:
event_bus.emit_download_started(url)
if ( if (
resp.status_code == 416 resp.status_code == 416
): # "range not satisfiable", which means nothing to return ): # "range not satisfiable", which means nothing to return
logger.info(f"{model_name}: complete file found. Skipping.") logger.info(f"{model_name}: complete file found. Skipping.")
if event_bus:
event_bus.emit_download_completed(url,resp.status_code,model_dest)
return model_dest return model_dest
elif resp.status_code == 404: elif resp.status_code == 404:
logger.warning("File not found") logger.warning("File not found")
if event_bus:
event_bus.emit_download_completed(url,resp.status_code,None)
return None return None
elif resp.status_code != 200: elif resp.status_code != 200:
logger.warning(f"{model_name}: {resp.reason}") logger.warning(f"{model_name}: {resp.reason}")
@ -464,11 +506,15 @@ def hf_download_with_resume(
else: else:
logger.info(f"{model_name}: Downloading...") logger.info(f"{model_name}: Downloading...")
MB10 = 10 * 1048576
downloaded = exist_size
previous_interval = 0
try: try:
with open(model_dest, open_mode) as file, tqdm( with open(model_dest, open_mode) as file, tqdm(
desc=model_name, desc=model_name,
initial=exist_size, initial=exist_size,
total=total + exist_size, total=content_length + exist_size,
unit="iB", unit="iB",
unit_scale=True, unit_scale=True,
unit_divisor=1000, unit_divisor=1000,
@ -476,9 +522,20 @@ def hf_download_with_resume(
for data in resp.iter_content(chunk_size=1024): for data in resp.iter_content(chunk_size=1024):
size = file.write(data) size = file.write(data)
bar.update(size) bar.update(size)
downloaded += size
if event_bus and downloaded // MB10 > previous_interval:
previous_interval = downloaded // MB10
event_bus.emit_download_progress(url, downloaded, content_length)
except Exception as e: except Exception as e:
logger.error(f"An error occurred while downloading {model_name}: {str(e)}") logger.error(f"An error occurred while downloading {model_name}: {str(e)}")
if event_bus:
event_bus.emit_download_completed(url,500,None)
return None return None
if event_bus:
event_bus.emit_download_completed(url,resp.status_code,model_dest)
return model_dest return model_dest

View File

@ -953,6 +953,7 @@ class ModelManager(object):
def heuristic_import(self, def heuristic_import(self,
items_to_import: Set[str], items_to_import: Set[str],
prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None, prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
event_bus = None, # EventServiceBase, with circular dependency issues
)->Dict[str, AddModelResult]: )->Dict[str, AddModelResult]:
'''Import a list of paths, repo_ids or URLs. Returns the set of '''Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items. successfully imported items.
@ -980,10 +981,13 @@ class ModelManager(object):
installer = ModelInstall(config = self.app_config, installer = ModelInstall(config = self.app_config,
prediction_type_helper = prediction_type_helper, prediction_type_helper = prediction_type_helper,
model_manager = self) model_manager = self,
event_bus = event_bus,
)
for thing in items_to_import: for thing in items_to_import:
installed = installer.heuristic_import(thing) installed = installer.heuristic_import(thing)
successfully_installed.update(installed) successfully_installed.update(installed)
self.commit() self.commit()
return successfully_installed return successfully_installed

View File

@ -21,7 +21,6 @@ from tqdm import tqdm
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from .devices import torch_dtype from .devices import torch_dtype
def log_txt_as_img(wh, xc, size=10): def log_txt_as_img(wh, xc, size=10):
# wh a tuple of (width, height) # wh a tuple of (width, height)
# xc a list of captions to plot # xc a list of captions to plot
@ -285,7 +284,11 @@ def ask_user(question: str, answers: list):
# ------------------------------------- # -------------------------------------
def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path: def download_with_resume(url: str,
dest: Path,
access_token: str = None,
event_bus = None, # EventServiceBase (circular import issues)
) -> Path:
""" """
Download a model file. Download a model file.
:param url: https, http or ftp URL :param url: https, http or ftp URL
@ -323,8 +326,13 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
os.remove(dest) os.remove(dest)
exist_size = 0 exist_size = 0
if event_bus:
event_bus.emit_download_started(url)
if resp.status_code == 416 or (content_length > 0 and exist_size == content_length): if resp.status_code == 416 or (content_length > 0 and exist_size == content_length):
logger.warning(f"{dest}: complete file found. Skipping.") logger.warning(f"{dest}: complete file found. Skipping.")
if event_bus:
event_bus.emit_download_completed(url,resp.status_code,dest)
return dest return dest
elif resp.status_code == 206 or exist_size > 0: elif resp.status_code == 206 or exist_size > 0:
logger.warning(f"{dest}: partial file found. Resuming...") logger.warning(f"{dest}: partial file found. Resuming...")
@ -333,11 +341,20 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
else: else:
logger.info(f"{dest}: Downloading...") logger.info(f"{dest}: Downloading...")
try: # If less than 2K, it's not a model - usually an HTML document of some sort
if content_length < 2000: if content_length < 2000:
logger.error(f"ERROR DOWNLOADING {url}: {resp.text}") logger.error(f"ERROR DOWNLOADING {url}: {resp.text}")
if event_bus:
event_bus.emit_download_completed(url, 500, None)
return None return None
# these variables are used in progress reporting events
MB10 = 10 * 1048576
downloaded = exist_size
previous_interval = 0
try:
with open(dest, open_mode) as file, tqdm( with open(dest, open_mode) as file, tqdm(
desc=str(dest), desc=str(dest),
initial=exist_size, initial=exist_size,
@ -349,10 +366,20 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
for data in resp.iter_content(chunk_size=1024): for data in resp.iter_content(chunk_size=1024):
size = file.write(data) size = file.write(data)
bar.update(size) bar.update(size)
downloaded += size
if event_bus and downloaded // MB10 > previous_interval:
previous_interval = downloaded // MB10
event_bus.emit_download_progress(url, downloaded, content_length)
except Exception as e: except Exception as e:
logger.error(f"An error occurred while downloading {dest}: {str(e)}") logger.error(f"An error occurred while downloading {dest}: {str(e)}")
if event_bus:
event_bus.emit_download_completed(url,500,None)
return None return None
if event_bus:
event_bus.emit_download_completed(url,resp.status_code,dest)
return dest return dest

View File

@ -144,4 +144,19 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
}) })
); );
}); });
/**
* Model import started
*/
socket.on('model_import_started', (import_path) => {
dispatch(
addToast(
makeToast({
title: 'Importing model ${import_path}',
status: 'info',
duration: 10000,
})
)
);
});
}; };