TUI installer functional; minor cosmetic work needed

This commit is contained in:
Lincoln Stein
2023-09-20 21:41:45 -04:00
parent 3402cf6542
commit 3199409fd3
4 changed files with 51 additions and 36 deletions

View File

@ -36,6 +36,7 @@ class UnknownJobIDException(Exception):
class ModelSourceMetadata(BaseModel):
"""Information collected on a downloadable model from its source site."""
name: Optional[str] = Field(description="Human-readable name of this model")
author: Optional[str] = Field(description="Author/creator of the model")
description: Optional[str] = Field(description="Description of the model")
license: Optional[str] = Field(description="Model license terms")

View File

@ -82,6 +82,7 @@ class DownloadQueue(DownloadQueueBase):
_next_job_id: int = 0
_sequence: int = 0 # This is for debugging and used to tag jobs in dequeueing order
_requests: requests.sessions.Session
_quiet: bool = False
def __init__(
self,
@ -89,6 +90,7 @@ class DownloadQueue(DownloadQueueBase):
event_handlers: List[DownloadEventHandler] = [],
requests_session: Optional[requests.sessions.Session] = None,
config: Optional[InvokeAIAppConfig] = None,
quiet: bool = False,
):
"""
Initialize DownloadQueue.
@ -105,6 +107,7 @@ class DownloadQueue(DownloadQueueBase):
self._logger = InvokeAILogger.getLogger(config=config)
self._event_handlers = event_handlers
self._requests = requests_session or requests.Session()
self._quiet = quiet
self._start_workers(max_parallel_dl)
@ -304,6 +307,8 @@ class DownloadQueue(DownloadQueueBase):
if job.status == DownloadJobStatus.ENQUEUED: # Don't do anything for non-enqueued jobs (shouldn't happen)
# There should be a better way to dispatch on the job type
if not self._quiet:
self._logger.info(f"{job.source}: Downloading to {job.destination}")
if isinstance(job, DownloadJobURL):
self._download_with_resume(job)
elif isinstance(job, DownloadJobRepoID):
@ -336,6 +341,7 @@ class DownloadQueue(DownloadQueueBase):
if match := re.match(CIVITAI_MODEL_DOWNLOAD + r"(\d+)", metadata_url):
version = match.group(1)
resp = self._requests.get(CIVITAI_VERSIONS_ENDPOINT + version).json()
print(f"DEBUG: resp={resp}")
metadata.thumbnail_url = metadata.thumbnail_url or resp["images"][0]["url"]
metadata.description = metadata.description or (
f"Trigger terms: {(', ').join(resp['trainedWords'])}"
@ -418,7 +424,7 @@ class DownloadQueue(DownloadQueueBase):
elif resp.status_code != 200:
raise HTTPError(resp.reason)
else:
self._logger.info(f"{job.source}: Downloading {job.destination}")
self._logger.debug(f"{job.source}: Downloading {job.destination}")
report_delta = job.total_bytes / 100 # report every 1% change
last_report_bytes = 0
@ -500,6 +506,7 @@ class DownloadQueue(DownloadQueueBase):
job.subqueue = self.__class__(
event_handlers=[subdownload_event],
requests_session=self._requests,
quiet=True,
)
try:
repo_id = job.source
@ -564,7 +571,8 @@ class DownloadQueue(DownloadQueueBase):
(hf_hub_url(repo_id, filename=x.as_posix()), x.parent or Path("."), x.name, sizes[x.as_posix()])
for x in self._select_variants(paths, variant)
]
metadata.license = metadata.license or model_info.cardData.get("license")
if hasattr(model_info, "cardData"):
metadata.license = metadata.license or model_info.cardData.get("license")
metadata.tags = metadata.tags or model_info.tags
metadata.author = metadata.author or model_info.author
return urls

View File

@ -461,7 +461,11 @@ class ModelInstall(ModelInstallBase):
def delete(self, key: str): # noqa D102
model = self._store.get_model(key)
rmtree(model.path)
path = self._app_config.models_path / model.path
if path.is_dir():
rmtree(path)
else:
path.unlink()
self.unregister(key)
def conditionally_delete(self, key: str): # noqa D102
@ -507,6 +511,7 @@ class ModelInstall(ModelInstallBase):
info.source = str(job.source)
metadata: ModelSourceMetadata = job.metadata
info.description = metadata.description or f"Imported model {info.name}"
info.name = metadata.name or info.name
info.author = metadata.author
info.tags = metadata.tags
info.license = metadata.license

View File

@ -10,17 +10,13 @@ This is the npyscreen frontend to the model installation application.
import argparse
import curses
import logging
import sys
import textwrap
import traceback
from argparse import Namespace
from dataclasses import dataclass, field
from multiprocessing import Process
from multiprocessing.connection import Connection, Pipe
from pathlib import Path
from shutil import get_terminal_size
from typing import List, Optional
from typing import Dict, List, Optional, Tuple
import npyscreen
import omegaconf
@ -28,6 +24,7 @@ import torch
from huggingface_hub import HfFolder
from npyscreen import widget
from pydantic import BaseModel
from tqdm import tqdm
import invokeai.configs as configs
from invokeai.app.services.config import InvokeAIAppConfig
@ -69,9 +66,9 @@ ACCESS_TOKEN = HfFolder.get_token()
class UnifiedModelInfo(BaseModel):
name: str
base_model: BaseModelType
model_type: ModelType
name: Optional[str] = None
base_model: Optional[BaseModelType] = None
model_type: Optional[ModelType] = None
source: Optional[str] = None
description: Optional[str] = None
recommended: bool = False
@ -92,6 +89,7 @@ def make_printable(s: str) -> str:
class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
"""Main form for interactive TUI."""
# for responsive resizing set to False, but this seems to cause a crash!
FIX_MINIMUM_SIZE_WHEN_CREATED = True
@ -172,13 +170,6 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
self.nextrely = bottom_of_table + 1
self.monitor = self.add_widget_intelligent(
BufferBox,
name="Log Messages",
editable=False,
max_height=6,
)
self.nextrely += 1
back_label = "BACK"
cancel_label = "CANCEL"
@ -326,7 +317,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
download_ids=self.add_widget_intelligent(
TextBox,
name="Additional URLs, or HuggingFace repo_ids to install (Space separated. Use shift-control-V to paste):",
max_height=4,
max_height=6,
scroll_exit=True,
editable=True,
)
@ -518,7 +509,8 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
# models located in the 'download_ids" section
for section in ui_sections:
if downloads := section.get("download_ids"):
selections.install_models.extend(downloads.value.split())
models = [UnifiedModelInfo(source=x) for x in downloads.value.split()]
selections.install_models.extend(models)
class AddModelApplication(npyscreen.NPSAppManaged):
@ -539,17 +531,6 @@ class AddModelApplication(npyscreen.NPSAppManaged):
)
class StderrToMessage:
def __init__(self, connection: Connection):
self.connection = connection
def write(self, data: str):
self.connection.send_bytes(data.encode("utf-8"))
def flush(self):
pass
def list_models(installer: ModelInstall, model_type: ModelType):
"""Print out all models of type model_type."""
models = installer.store.search_by_name(model_type=model_type)
@ -559,14 +540,34 @@ def list_models(installer: ModelInstall, model_type: ModelType):
print(f"{model.name:40}{model.base_model:10}{path}")
def tqdm_progress(job: ModelInstallJob):
pass
class TqdmProgress(object):
_bars: Dict[int, tqdm] # the tqdm object
_last: Dict[int, int] # last bytes downloaded
def __init__(self):
self._bars = dict()
self._last = dict()
def job_update(self, job: ModelInstallJob):
job_id = job.id
if job.status == "running":
if job_id not in self._bars:
dest = Path(job.destination).name
self._bars[job_id] = tqdm(
desc=dest,
initial=0,
total=job.total_bytes,
unit="iB",
unit_scale=True,
)
self._last[job_id] = 0
self._bars[job_id].update(job.bytes - self._last[job_id])
self._last[job_id] = job.bytes
def add_or_delete(installer: ModelInstall, selections: InstallSelections):
for model in selections.install_models:
print(f"Installing {model.name}")
metadata = ModelSourceMetadata(description=model.description)
metadata = ModelSourceMetadata(description=model.description, name=model.name)
installer.install(
model.source,
variant="fp16" if config.precision == "float16" else None,
@ -594,7 +595,7 @@ def select_and_download_models(opt: Namespace):
"""Prompt user for install/delete selections and execute."""
precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
config.precision = precision
installer = ModelInstall(config=config, event_handlers=[tqdm_progress])
installer = ModelInstall(config=config, event_handlers=[TqdmProgress().job_update])
if opt.list_models:
list_models(installer, opt.list_models)