mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
TUI installer functional; minor cosmetic work needed
This commit is contained in:
@ -36,6 +36,7 @@ class UnknownJobIDException(Exception):
|
||||
class ModelSourceMetadata(BaseModel):
|
||||
"""Information collected on a downloadable model from its source site."""
|
||||
|
||||
name: Optional[str] = Field(description="Human-readable name of this model")
|
||||
author: Optional[str] = Field(description="Author/creator of the model")
|
||||
description: Optional[str] = Field(description="Description of the model")
|
||||
license: Optional[str] = Field(description="Model license terms")
|
||||
|
@ -82,6 +82,7 @@ class DownloadQueue(DownloadQueueBase):
|
||||
_next_job_id: int = 0
|
||||
_sequence: int = 0 # This is for debugging and used to tag jobs in dequeueing order
|
||||
_requests: requests.sessions.Session
|
||||
_quiet: bool = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -89,6 +90,7 @@ class DownloadQueue(DownloadQueueBase):
|
||||
event_handlers: List[DownloadEventHandler] = [],
|
||||
requests_session: Optional[requests.sessions.Session] = None,
|
||||
config: Optional[InvokeAIAppConfig] = None,
|
||||
quiet: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize DownloadQueue.
|
||||
@ -105,6 +107,7 @@ class DownloadQueue(DownloadQueueBase):
|
||||
self._logger = InvokeAILogger.getLogger(config=config)
|
||||
self._event_handlers = event_handlers
|
||||
self._requests = requests_session or requests.Session()
|
||||
self._quiet = quiet
|
||||
|
||||
self._start_workers(max_parallel_dl)
|
||||
|
||||
@ -304,6 +307,8 @@ class DownloadQueue(DownloadQueueBase):
|
||||
|
||||
if job.status == DownloadJobStatus.ENQUEUED: # Don't do anything for non-enqueued jobs (shouldn't happen)
|
||||
# There should be a better way to dispatch on the job type
|
||||
if not self._quiet:
|
||||
self._logger.info(f"{job.source}: Downloading to {job.destination}")
|
||||
if isinstance(job, DownloadJobURL):
|
||||
self._download_with_resume(job)
|
||||
elif isinstance(job, DownloadJobRepoID):
|
||||
@ -336,6 +341,7 @@ class DownloadQueue(DownloadQueueBase):
|
||||
if match := re.match(CIVITAI_MODEL_DOWNLOAD + r"(\d+)", metadata_url):
|
||||
version = match.group(1)
|
||||
resp = self._requests.get(CIVITAI_VERSIONS_ENDPOINT + version).json()
|
||||
print(f"DEBUG: resp={resp}")
|
||||
metadata.thumbnail_url = metadata.thumbnail_url or resp["images"][0]["url"]
|
||||
metadata.description = metadata.description or (
|
||||
f"Trigger terms: {(', ').join(resp['trainedWords'])}"
|
||||
@ -418,7 +424,7 @@ class DownloadQueue(DownloadQueueBase):
|
||||
elif resp.status_code != 200:
|
||||
raise HTTPError(resp.reason)
|
||||
else:
|
||||
self._logger.info(f"{job.source}: Downloading {job.destination}")
|
||||
self._logger.debug(f"{job.source}: Downloading {job.destination}")
|
||||
|
||||
report_delta = job.total_bytes / 100 # report every 1% change
|
||||
last_report_bytes = 0
|
||||
@ -500,6 +506,7 @@ class DownloadQueue(DownloadQueueBase):
|
||||
job.subqueue = self.__class__(
|
||||
event_handlers=[subdownload_event],
|
||||
requests_session=self._requests,
|
||||
quiet=True,
|
||||
)
|
||||
try:
|
||||
repo_id = job.source
|
||||
@ -564,7 +571,8 @@ class DownloadQueue(DownloadQueueBase):
|
||||
(hf_hub_url(repo_id, filename=x.as_posix()), x.parent or Path("."), x.name, sizes[x.as_posix()])
|
||||
for x in self._select_variants(paths, variant)
|
||||
]
|
||||
metadata.license = metadata.license or model_info.cardData.get("license")
|
||||
if hasattr(model_info, "cardData"):
|
||||
metadata.license = metadata.license or model_info.cardData.get("license")
|
||||
metadata.tags = metadata.tags or model_info.tags
|
||||
metadata.author = metadata.author or model_info.author
|
||||
return urls
|
||||
|
@ -461,7 +461,11 @@ class ModelInstall(ModelInstallBase):
|
||||
|
||||
def delete(self, key: str): # noqa D102
|
||||
model = self._store.get_model(key)
|
||||
rmtree(model.path)
|
||||
path = self._app_config.models_path / model.path
|
||||
if path.is_dir():
|
||||
rmtree(path)
|
||||
else:
|
||||
path.unlink()
|
||||
self.unregister(key)
|
||||
|
||||
def conditionally_delete(self, key: str): # noqa D102
|
||||
@ -507,6 +511,7 @@ class ModelInstall(ModelInstallBase):
|
||||
info.source = str(job.source)
|
||||
metadata: ModelSourceMetadata = job.metadata
|
||||
info.description = metadata.description or f"Imported model {info.name}"
|
||||
info.name = metadata.name or info.name
|
||||
info.author = metadata.author
|
||||
info.tags = metadata.tags
|
||||
info.license = metadata.license
|
||||
|
@ -10,17 +10,13 @@ This is the npyscreen frontend to the model installation application.
|
||||
|
||||
import argparse
|
||||
import curses
|
||||
import logging
|
||||
import sys
|
||||
import textwrap
|
||||
import traceback
|
||||
from argparse import Namespace
|
||||
from dataclasses import dataclass, field
|
||||
from multiprocessing import Process
|
||||
from multiprocessing.connection import Connection, Pipe
|
||||
from pathlib import Path
|
||||
from shutil import get_terminal_size
|
||||
from typing import List, Optional
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import npyscreen
|
||||
import omegaconf
|
||||
@ -28,6 +24,7 @@ import torch
|
||||
from huggingface_hub import HfFolder
|
||||
from npyscreen import widget
|
||||
from pydantic import BaseModel
|
||||
from tqdm import tqdm
|
||||
|
||||
import invokeai.configs as configs
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
@ -69,9 +66,9 @@ ACCESS_TOKEN = HfFolder.get_token()
|
||||
|
||||
|
||||
class UnifiedModelInfo(BaseModel):
|
||||
name: str
|
||||
base_model: BaseModelType
|
||||
model_type: ModelType
|
||||
name: Optional[str] = None
|
||||
base_model: Optional[BaseModelType] = None
|
||||
model_type: Optional[ModelType] = None
|
||||
source: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
recommended: bool = False
|
||||
@ -92,6 +89,7 @@ def make_printable(s: str) -> str:
|
||||
|
||||
class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
"""Main form for interactive TUI."""
|
||||
|
||||
# for responsive resizing set to False, but this seems to cause a crash!
|
||||
FIX_MINIMUM_SIZE_WHEN_CREATED = True
|
||||
|
||||
@ -172,13 +170,6 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
|
||||
self.nextrely = bottom_of_table + 1
|
||||
|
||||
self.monitor = self.add_widget_intelligent(
|
||||
BufferBox,
|
||||
name="Log Messages",
|
||||
editable=False,
|
||||
max_height=6,
|
||||
)
|
||||
|
||||
self.nextrely += 1
|
||||
back_label = "BACK"
|
||||
cancel_label = "CANCEL"
|
||||
@ -326,7 +317,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
download_ids=self.add_widget_intelligent(
|
||||
TextBox,
|
||||
name="Additional URLs, or HuggingFace repo_ids to install (Space separated. Use shift-control-V to paste):",
|
||||
max_height=4,
|
||||
max_height=6,
|
||||
scroll_exit=True,
|
||||
editable=True,
|
||||
)
|
||||
@ -518,7 +509,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
||||
# models located in the 'download_ids" section
|
||||
for section in ui_sections:
|
||||
if downloads := section.get("download_ids"):
|
||||
selections.install_models.extend(downloads.value.split())
|
||||
models = [UnifiedModelInfo(source=x) for x in downloads.value.split()]
|
||||
selections.install_models.extend(models)
|
||||
|
||||
|
||||
class AddModelApplication(npyscreen.NPSAppManaged):
|
||||
@ -539,17 +531,6 @@ class AddModelApplication(npyscreen.NPSAppManaged):
|
||||
)
|
||||
|
||||
|
||||
class StderrToMessage:
|
||||
def __init__(self, connection: Connection):
|
||||
self.connection = connection
|
||||
|
||||
def write(self, data: str):
|
||||
self.connection.send_bytes(data.encode("utf-8"))
|
||||
|
||||
def flush(self):
|
||||
pass
|
||||
|
||||
|
||||
def list_models(installer: ModelInstall, model_type: ModelType):
|
||||
"""Print out all models of type model_type."""
|
||||
models = installer.store.search_by_name(model_type=model_type)
|
||||
@ -559,14 +540,34 @@ def list_models(installer: ModelInstall, model_type: ModelType):
|
||||
print(f"{model.name:40}{model.base_model:10}{path}")
|
||||
|
||||
|
||||
def tqdm_progress(job: ModelInstallJob):
|
||||
pass
|
||||
class TqdmProgress(object):
|
||||
_bars: Dict[int, tqdm] # the tqdm object
|
||||
_last: Dict[int, int] # last bytes downloaded
|
||||
|
||||
def __init__(self):
|
||||
self._bars = dict()
|
||||
self._last = dict()
|
||||
|
||||
def job_update(self, job: ModelInstallJob):
|
||||
job_id = job.id
|
||||
if job.status == "running":
|
||||
if job_id not in self._bars:
|
||||
dest = Path(job.destination).name
|
||||
self._bars[job_id] = tqdm(
|
||||
desc=dest,
|
||||
initial=0,
|
||||
total=job.total_bytes,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
)
|
||||
self._last[job_id] = 0
|
||||
self._bars[job_id].update(job.bytes - self._last[job_id])
|
||||
self._last[job_id] = job.bytes
|
||||
|
||||
|
||||
def add_or_delete(installer: ModelInstall, selections: InstallSelections):
|
||||
for model in selections.install_models:
|
||||
print(f"Installing {model.name}")
|
||||
metadata = ModelSourceMetadata(description=model.description)
|
||||
metadata = ModelSourceMetadata(description=model.description, name=model.name)
|
||||
installer.install(
|
||||
model.source,
|
||||
variant="fp16" if config.precision == "float16" else None,
|
||||
@ -594,7 +595,7 @@ def select_and_download_models(opt: Namespace):
|
||||
"""Prompt user for install/delete selections and execute."""
|
||||
precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
|
||||
config.precision = precision
|
||||
installer = ModelInstall(config=config, event_handlers=[tqdm_progress])
|
||||
installer = ModelInstall(config=config, event_handlers=[TqdmProgress().job_update])
|
||||
|
||||
if opt.list_models:
|
||||
list_models(installer, opt.list_models)
|
||||
|
Reference in New Issue
Block a user