mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00

Compare commits: fix/diffus… → lstein/thr… (6 commits)

Commits: d894c86db1, 399d505801, ad312fa1ec, 80ce014b1e, 1fd053b42d, da187d6a87

@@ -9,7 +9,6 @@ from fastapi_events.dispatcher import dispatch
 
 from ..services.events import EventServiceBase
 
-
 class FastAPIEventService(EventServiceBase):
     event_handler_id: int
     __queue: Queue

@@ -28,6 +27,9 @@ class FastAPIEventService(EventServiceBase):
         self.__queue.put(None)
 
     def dispatch(self, event_name: str, payload: Any) -> None:
+        # TODO: Remove next two debugging lines
+        from .dependencies import ApiDependencies
+        ApiDependencies.invoker.services.logger.debug(f'dispatch {event_name} / {payload}')
        self.__queue.put(dict(event_name=event_name, payload=payload))
 
     async def __dispatch_from_queue(self, stop_event: threading.Event):
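
`dispatch()` here is a producer: it enqueues events on a thread-safe `Queue` that an async task drains and forwards to fastapi_events. A minimal self-contained sketch of that queue-backed pattern (simplified; `QueueEventService` and the print-based forwarding are stand-ins for the real wiring):

import asyncio
import threading
from queue import Empty, Queue
from typing import Any

class QueueEventService:
    """Sketch of a queue-backed event dispatcher (simplified from the diff)."""

    def __init__(self) -> None:
        self.__queue: Queue = Queue()

    def dispatch(self, event_name: str, payload: Any) -> None:
        # Producer side: callers may enqueue events from any thread.
        self.__queue.put(dict(event_name=event_name, payload=payload))

    def stop(self) -> None:
        # A None sentinel tells the consumer loop to exit.
        self.__queue.put(None)

    async def dispatch_from_queue(self, stop_event: threading.Event) -> None:
        # Consumer side: drain the queue and forward each event.
        while not stop_event.is_set():
            try:
                event = self.__queue.get(block=False)
                if event is None:
                    break
                print(f"dispatch {event['event_name']} / {event['payload']}")
            except Empty:
                await asyncio.sleep(0.05)  # yield to the event loop
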
@@ -2,6 +2,7 @@
 
 
 import pathlib
+import threading
 from typing import Literal, List, Optional, Union
 
 from fastapi import Body, Path, Query, Response
@@ -127,54 +128,43 @@ async def update_model(
     "/import",
     operation_id="import_model",
     responses= {
-        201: {"description" : "The model imported successfully"},
-        404: {"description" : "The model could not be found"},
-        415: {"description" : "Unrecognized file/folder format"},
-        424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"},
-        409: {"description" : "There is already a model corresponding to this path or repo_id"},
+        200: {"description" : "The path was queued for import"},
     },
-    status_code=201,
-    response_model=ImportModelResponse
+    status_code=200
 )
 async def import_model(
     location: str = Body(description="A model path, repo_id or URL to import"),
     prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
         Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
-) -> ImportModelResponse:
-    """ Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically """
+) -> str:
+    """
+    Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically.
+    This call launches a background thread to process the imported model and always succeeds. Results are reported in the background
+    as the following events:
+    - model_import_started(import_path:str)
+    - model_import_completed(import_path:str, import_info:AddModelResults, success:bool, error:str)
+    - download_started(url:str)
+    - download_progress(url:str, downloaded_bytes:int, total_bytes:int)
+    - download_completed(url:str, status_code:int, download_path:str)
+    """
 
     items_to_import = {location}
     prediction_types = { x.value: x for x in SchedulerPredictionType }
     logger = ApiDependencies.invoker.services.logger
+    events = ApiDependencies.invoker.services.events
 
     try:
-        installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
-            items_to_import = items_to_import,
-            prediction_type_helper = lambda x: prediction_types.get(prediction_type)
-        )
-        info = installed_models.get(location)
-
-        if not info:
-            logger.error("Import failed")
-            raise HTTPException(status_code=415)
-
-        logger.info(f'Successfully imported {location}, got {info}')
-        model_raw = ApiDependencies.invoker.services.model_manager.list_model(
-            model_name=info.name,
-            base_model=info.base_model,
-            model_type=info.model_type
-        )
-        return parse_obj_as(ImportModelResponse, model_raw)
-    except ModelNotFoundException as e:
+        import_thread = threading.Thread(target = ApiDependencies.invoker.services.model_manager.heuristic_import,
+                                         kwargs = dict(items_to_import = items_to_import,
+                                                       prediction_type_helper = lambda x: prediction_types.get(prediction_type),
+                                                       event_bus = events,
+                                                       )
+                                         )
+        import_thread.start()
+        return 'request queued'
+    except Exception as e:
         logger.error(str(e))
-        raise HTTPException(status_code=404, detail=str(e))
-    except InvalidModelException as e:
-        logger.error(str(e))
-        raise HTTPException(status_code=415)
-    except ValueError as e:
-        logger.error(str(e))
-        raise HTTPException(status_code=409, detail=str(e))
+        raise HTTPException(status_code=500, detail=str(e))
 
 @models_router.post(
     "/add",
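
For reference, a hypothetical client call against the reworked endpoint (the `/api/v1/models/import` prefix and the port are assumptions, not taken from the diff):

import requests

# Hypothetical client; URL prefix and port are assumptions.
resp = requests.post(
    "http://localhost:9090/api/v1/models/import",
    # A single non-embedded Body(...) parameter is sent as the bare JSON value.
    json="https://example.com/some-model.safetensors",
    timeout=30,
)
print(resp.status_code)  # expected 200: the path was queued for import
print(resp.json())       # expected 'request queued'; results arrive as events
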
@@ -1,9 +1,11 @@
 # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
 
 from typing import Any, Optional
+from pathlib import Path
 
 from invokeai.app.models.image import ProgressImage
 from invokeai.app.util.misc import get_timestamp
-from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo
+from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo, AddModelResult
 
 class EventServiceBase:
     session_event: str = "session_event"

@@ -111,7 +113,7 @@ class EventServiceBase:
         submodel: SubModelType,
     ) -> None:
         """Emitted when a model is requested"""
-        self.__emit_session_event(
+        self.dispatch(
             event_name="model_load_started",
             payload=dict(
                 graph_execution_state_id=graph_execution_state_id,

@@ -132,7 +134,7 @@ class EventServiceBase:
         model_info: ModelInfo,
     ) -> None:
         """Emitted when a model is correctly loaded (returns model info)"""
-        self.__emit_session_event(
+        self.dispatch(
             event_name="model_load_completed",
             payload=dict(
                 graph_execution_state_id=graph_execution_state_id,

@@ -145,3 +147,92 @@ class EventServiceBase:
                 precision=str(model_info.precision),
             ),
         )
+
+    def emit_model_import_started (
+        self,
+        import_path: str, # can be a local path, URL or repo_id
+    )->None:
+        """Emitted when a model import commences"""
+        self.dispatch(
+            event_name="model_import_started",
+            payload=dict(
+                import_path = import_path,
+            )
+        )
+
+    def emit_model_import_completed (
+        self,
+        import_path: str, # can be a local path, URL or repo_id
+        import_info: AddModelResult,
+        success: bool= True,
+        error: str = None,
+    )->None:
+        """Emitted when a model import completes"""
+        self.dispatch(
+            event_name="model_import_completed",
+            payload=dict(
+                import_path = import_path,
+                import_info = import_info,
+                success = success,
+                error = error,
+            )
+        )
+
+    def emit_download_started (
+        self,
+        url: str,
+    )->None:
+        """Emitted when a download thread starts"""
+        self.dispatch(
+            event_name="download_started",
+            payload=dict(
+                url = url,
+            )
+        )
+
+    def emit_download_progress (
+        self,
+        url: str,
+        downloaded_size: int,
+        total_size: int,
+    )->None:
+        """
+        Emitted at intervals during a download process
+        :param url: Requested URL
+        :param downloaded_size: Bytes downloaded so far
+        :param total_size: Total bytes to download
+        """
+        self.dispatch(
+            event_name="download_progress",
+            payload=dict(
+                url = url,
+                downloaded_size = downloaded_size,
+                total_size = total_size,
+            )
+        )
+
+    def emit_download_completed (
+        self,
+        url: str,
+        status_code: int,
+        download_path: Path,
+    )->None:
+        """
+        Emitted when a download thread completes.
+        :param url: Requested URL
+        :param status_code: HTTP status code from request
+        :param download_path: Path to downloaded file
+        """
+        self.dispatch(
+            event_name="download_completed",
+            payload=dict(
+                url = url,
+                status_code = status_code,
+                download_path = download_path,
+            )
+        )
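
The new emitters all funnel through `dispatch()`, so a print-based stand-in is enough to inspect the payload shapes (a sketch; `PrintingEventService` is hypothetical):

from typing import Any

class PrintingEventService:
    """Stand-in bus: dispatch() prints instead of forwarding to the UI."""

    def dispatch(self, event_name: str, payload: Any) -> None:
        print(event_name, payload)

    def emit_download_progress(self, url: str, downloaded_size: int, total_size: int) -> None:
        # Mirrors the emitter added in the diff.
        self.dispatch(
            event_name="download_progress",
            payload=dict(url=url, downloaded_size=downloaded_size, total_size=total_size),
        )

bus = PrintingEventService()
bus.emit_download_progress("https://example.com/model.safetensors", 10_485_760, 2_000_000_000)
# -> download_progress {'url': '...', 'downloaded_size': 10485760, 'total_size': 2000000000}
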
@@ -26,6 +26,7 @@ import torch
 from invokeai.app.models.exceptions import CanceledException
 from ...backend.util import choose_precision, choose_torch_device
 from .config import InvokeAIAppConfig
+from .events import EventServiceBase
 
 if TYPE_CHECKING:
     from ..invocations.baseinvocation import BaseInvocation, InvocationContext

@@ -542,6 +543,7 @@ class ModelManagerService(ModelManagerServiceBase):
     def heuristic_import(self,
                          items_to_import: set[str],
                          prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None,
+                         event_bus: Optional[EventServiceBase]=None,
                          )->dict[str, AddModelResult]:
         '''Import a list of paths, repo_ids or URLs. Returns the set of
         successfully imported items.

@@ -559,7 +561,7 @@ class ModelManagerService(ModelManagerServiceBase):
         of the set is a dict corresponding to the newly-created OmegaConf stanza for
         that model.
         '''
-        return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
+        return self.mgr.heuristic_import(items_to_import, prediction_type_helper, event_bus=event_bus)
 
     def merge_models(
         self,
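
`prediction_type_helper` lets the installer ask, per checkpoint path, which SDv2 prediction type applies. A hypothetical helper (the filename heuristic and the string return value are illustrative; the real callable returns a `SchedulerPredictionType`):

from pathlib import Path

def guess_prediction_type(checkpoint: Path) -> str:
    # Hypothetical heuristic: 768-pixel SD-2.x checkpoints conventionally
    # use v-prediction; default everything else to epsilon.
    return "v_prediction" if "768" in checkpoint.name else "epsilon"

print(guess_prediction_type(Path("sd21-768-ema.ckpt")))  # v_prediction
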
@@ -89,13 +89,16 @@ class ModelInstall(object):
                  config:InvokeAIAppConfig,
                  prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
                  model_manager: ModelManager = None,
-                 access_token:str = None):
+                 access_token:str = None,
+                 event_bus = None, # EventServiceBase - getting circular import errors
+                 ):
         self.config = config
         self.mgr = model_manager or ModelManager(config.model_conf_path)
         self.datasets = OmegaConf.load(Dataset_path)
         self.prediction_helper = prediction_type_helper
         self.access_token = access_token or HfFolder.get_token()
         self.reverse_paths = self._reverse_paths(self.datasets)
+        self.event_bus = event_bus
 
     def all_models(self)->Dict[str,ModelLoadInfo]:
         '''

@@ -197,12 +200,17 @@ class ModelInstall(object):
         Returns a set of dict objects corresponding to newly-created stanzas in models.yaml.
         '''
 
+        if self.event_bus:
+            self.event_bus.emit_model_import_started(str(model_path_id_or_url))
+
         if not models_installed:
             models_installed = dict()
 
         # A little hack to allow nested routines to retrieve info on the requested ID
         self.current_id = model_path_id_or_url
         path = Path(model_path_id_or_url)
+        try:
             # checkpoint file, or similar
             if path.is_file():
                 models_installed.update({str(path):self._install_path(path)})

@@ -228,7 +236,26 @@ class ModelInstall(object):
                 models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)})
             else:
-                raise KeyError(f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping')
+                errmsg = f'{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping'
+                raise KeyError(errmsg)
+
+            if self.event_bus:
+                for path, add_model_result in models_installed.items():
+                    self.event_bus.emit_model_import_completed(
+                        str(path),
+                        import_info = add_model_result,
+                    )
+        except Exception as e:
+            if self.event_bus:
+                self.event_bus.emit_model_import_completed(
+                    str(path),
+                    import_info = None,
+                    success = False,
+                    error = str(e),
+                )
+                return models_installed
+            else:
+                raise
 
         return models_installed
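
The control flow above is an event bracket: emit `model_import_started` on entry, emit a `model_import_completed` per installed item, and on failure either report the error through the bus (swallowing the exception) or re-raise when no bus is attached. A minimal sketch of the pattern, with hypothetical names:

def import_with_events(bus, item, install_one):
    """Sketch of the started/completed bracket; bus may be None."""
    results = {}
    if bus:
        bus.emit_model_import_started(str(item))
    try:
        results[str(item)] = install_one(item)
        if bus:
            for path, result in results.items():
                bus.emit_model_import_completed(str(path), import_info=result)
    except Exception as e:
        if bus:
            # Report the failure as an event instead of propagating it.
            bus.emit_model_import_completed(str(item), import_info=None,
                                            success=False, error=str(e))
            return results
        raise
    return results
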
@@ -238,10 +265,14 @@ class ModelInstall(object):
         info = info or ModelProbe().heuristic_probe(path,self.prediction_helper)
         if not info:
             logger.warning(f'Unable to parse format of {path}')
-            return None
+            raise ValueError(f'Unable to parse format of {path}')
 
         model_name = path.stem if path.is_file() else path.name
 
         if self.mgr.model_exists(model_name, info.base_type, info.model_type):
-            raise ValueError(f'A model named "{model_name}" is already installed.')
+            errmsg = f'A model named "{model_name}" is already installed.'
+            raise ValueError(errmsg)
 
         attributes = self._make_attributes(path,info)
         return self.mgr.add_model(model_name = model_name,
                                   base_model = info.base_type,

@@ -251,7 +282,7 @@ class ModelInstall(object):
 
     def _install_url(self, url: str)->AddModelResult:
         with TemporaryDirectory(dir=self.config.models_path) as staging:
-            location = download_with_resume(url,Path(staging))
+            location = download_with_resume(url,Path(staging),event_bus=self.event_bus)
             if not location:
                 logger.error(f'Unable to download {url}. Skipping.')
             info = ModelProbe().heuristic_probe(location)

@@ -384,7 +415,8 @@ class ModelInstall(object):
             p = hf_download_with_resume(repo_id,
                                         model_dir=location,
                                         model_name=filename,
-                                        access_token = self.access_token
+                                        access_token = self.access_token,
+                                        event_bus = self.event_bus,
                                         )
             if p:
                 paths.append(p)

@@ -425,12 +457,15 @@ def hf_download_from_pretrained(
     return destination
 
 # ---------------------------------------------
+# TODO: This function is almost identical to invokeai.backend.util.download_with_resume
+# and should be merged
 def hf_download_with_resume(
     repo_id: str,
     model_dir: str,
     model_name: str,
     model_dest: Path = None,
     access_token: str = None,
+    event_bus = None,
 ) -> Path:
     model_dest = model_dest or Path(os.path.join(model_dir, model_name))
     os.makedirs(model_dir, exist_ok=True)

@@ -447,15 +482,22 @@ def hf_download_with_resume(
         open_mode = "ab"
 
     resp = requests.get(url, headers=header, stream=True)
-    total = int(resp.headers.get("content-length", 0))
+    content_length = int(resp.headers.get("content-length", 0))
+
+    if event_bus:
+        event_bus.emit_download_started(url)
 
     if (
         resp.status_code == 416
     ): # "range not satisfiable", which means nothing to return
         logger.info(f"{model_name}: complete file found. Skipping.")
+        if event_bus:
+            event_bus.emit_download_completed(url,resp.status_code,model_dest)
         return model_dest
     elif resp.status_code == 404:
         logger.warning("File not found")
+        if event_bus:
+            event_bus.emit_download_completed(url,resp.status_code,None)
         return None
     elif resp.status_code != 200:
         logger.warning(f"{model_name}: {resp.reason}")

@@ -464,11 +506,15 @@ def hf_download_with_resume(
     else:
         logger.info(f"{model_name}: Downloading...")
 
+    MB10 = 10 * 1048576
+    downloaded = exist_size
+    previous_interval = 0
+
     try:
         with open(model_dest, open_mode) as file, tqdm(
             desc=model_name,
             initial=exist_size,
-            total=total + exist_size,
+            total=content_length + exist_size,
             unit="iB",
             unit_scale=True,
             unit_divisor=1000,

@@ -476,9 +522,20 @@ def hf_download_with_resume(
             for data in resp.iter_content(chunk_size=1024):
                 size = file.write(data)
                 bar.update(size)
+                downloaded += size
+                if event_bus and downloaded // MB10 > previous_interval:
+                    previous_interval = downloaded // MB10
+                    event_bus.emit_download_progress(url, downloaded, content_length)
 
     except Exception as e:
         logger.error(f"An error occurred while downloading {model_name}: {str(e)}")
+        if event_bus:
+            event_bus.emit_download_completed(url,500,None)
         return None
+
+    if event_bus:
+        event_bus.emit_download_completed(url,resp.status_code,model_dest)
+
     return model_dest
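
Note how progress emission is throttled: `downloaded // MB10` only increments once per 10 MiB, so at most one `download_progress` event fires per 10 MiB of data rather than per 1 KiB chunk. A standalone sketch of the throttle:

MB10 = 10 * 1048576  # 10 MiB reporting interval, as in the diff

def throttled_progress(chunk_sizes, report):
    """Invoke report(downloaded) at most once per 10 MiB of new data."""
    downloaded = 0
    previous_interval = 0
    for size in chunk_sizes:
        downloaded += size
        if downloaded // MB10 > previous_interval:
            previous_interval = downloaded // MB10
            report(downloaded)

# 30,000 chunks of 1 KiB (~30 MB): report fires near the 10 MiB and 20 MiB marks.
throttled_progress([1024] * 30000, lambda n: print(f"{n} bytes"))
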
@@ -953,6 +953,7 @@ class ModelManager(object):
     def heuristic_import(self,
                          items_to_import: Set[str],
                          prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None,
+                         event_bus = None, # EventServiceBase, with circular dependency issues
                          )->Dict[str, AddModelResult]:
         '''Import a list of paths, repo_ids or URLs. Returns the set of
         successfully imported items.

@@ -980,10 +981,13 @@ class ModelManager(object):
 
         installer = ModelInstall(config = self.app_config,
                                  prediction_type_helper = prediction_type_helper,
-                                 model_manager = self)
+                                 model_manager = self,
+                                 event_bus = event_bus,
+                                 )
         for thing in items_to_import:
             installed = installer.heuristic_import(thing)
             successfully_installed.update(installed)
 
         self.commit()
         return successfully_installed
@@ -21,7 +21,6 @@ from tqdm import tqdm
 import invokeai.backend.util.logging as logger
 from .devices import torch_dtype
 
-
 def log_txt_as_img(wh, xc, size=10):
     # wh a tuple of (width, height)
     # xc a list of captions to plot

@@ -285,7 +284,11 @@ def ask_user(question: str, answers: list):
 
 
 # -------------------------------------
-def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path:
+def download_with_resume(url: str,
+                         dest: Path,
+                         access_token: str = None,
+                         event_bus = None, # EventServiceBase (circular import issues)
+                         ) -> Path:
     """
     Download a model file.
     :param url: https, http or ftp URL

@@ -323,8 +326,13 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
         os.remove(dest)
         exist_size = 0
 
+    if event_bus:
+        event_bus.emit_download_started(url)
+
     if resp.status_code == 416 or (content_length > 0 and exist_size == content_length):
         logger.warning(f"{dest}: complete file found. Skipping.")
+        if event_bus:
+            event_bus.emit_download_completed(url,resp.status_code,dest)
         return dest
     elif resp.status_code == 206 or exist_size > 0:
         logger.warning(f"{dest}: partial file found. Resuming...")

@@ -333,11 +341,20 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
     else:
         logger.info(f"{dest}: Downloading...")
 
-    try:
+    # If less than 2K, it's not a model - usually an HTML document of some sort
     if content_length < 2000:
         logger.error(f"ERROR DOWNLOADING {url}: {resp.text}")
+        if event_bus:
+            event_bus.emit_download_completed(url, 500, None)
         return None
 
+    # these variables are used in progress reporting events
+    MB10 = 10 * 1048576
+    downloaded = exist_size
+    previous_interval = 0
+
+    try:
         with open(dest, open_mode) as file, tqdm(
             desc=str(dest),
             initial=exist_size,

@@ -349,10 +366,20 @@ def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path
             for data in resp.iter_content(chunk_size=1024):
                 size = file.write(data)
                 bar.update(size)
+                downloaded += size
+                if event_bus and downloaded // MB10 > previous_interval:
+                    previous_interval = downloaded // MB10
+                    event_bus.emit_download_progress(url, downloaded, content_length)
 
     except Exception as e:
         logger.error(f"An error occurred while downloading {dest}: {str(e)}")
+        if event_bus:
+            event_bus.emit_download_completed(url,500,None)
         return None
 
+    if event_bus:
+        event_bus.emit_download_completed(url,resp.status_code,dest)
+
     return dest
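
A stub bus is enough to exercise `download_with_resume` end to end; only the three emit methods the downloader calls are needed (a sketch; the commented call shows hypothetical usage):

from pathlib import Path

class StubBus:
    """Minimal stand-in implementing only what the downloader calls."""
    def emit_download_started(self, url):
        print("started", url)
    def emit_download_progress(self, url, downloaded, total):
        print(f"progress {downloaded}/{total}")
    def emit_download_completed(self, url, status_code, path):
        print("completed", status_code, path)

# Hypothetical usage (uncomment with a real URL and destination):
# from invokeai.backend.util import download_with_resume
# download_with_resume("https://example.com/model.safetensors", Path("/tmp"),
#                      event_bus=StubBus())
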
@@ -144,4 +144,19 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
       })
     );
   });
+
+  /**
+   * Model import started
+   */
+  socket.on('model_import_started', (import_path) => {
+    dispatch(
+      addToast(
+        makeToast({
+          title: `Importing model ${import_path}`,
+          status: 'info',
+          duration: 10000,
+        })
+      )
+    );
+  });
 };
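
Outside the web UI, any Socket.IO client can observe the same events. A hypothetical Python listener (host, port, and the default namespace are assumptions about the running server):

import socketio  # pip install python-socketio

sio = socketio.Client()

@sio.on("model_import_started")
def on_import_started(data):
    print("import started:", data.get("import_path"))

@sio.on("download_progress")
def on_progress(data):
    print(f'{data["downloaded_size"]}/{data["total_size"]} bytes of {data["url"]}')

sio.connect("http://localhost:9090")  # assumed address of the server
sio.wait()
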