add/delete from command line working; training words downloaded

This commit is contained in:
Lincoln Stein
2023-09-21 18:18:35 -04:00
parent 30aea54f1a
commit c9cd418ed8
3 changed files with 39 additions and 4 deletions

View File

@ -9,9 +9,12 @@ from functools import total_ordering
from pathlib import Path
from typing import Callable, List, Optional, Union
import requests
from pydantic import BaseModel, Field
from pydantic.networks import AnyHttpUrl
from invokeai.app.services.config import InvokeAIAppConfig
# Used to distinguish between repo_id sources and URL sources
REPO_ID_RE = r"^[\w-]+/[.\w-]+$"
HTTP_RE = r"^https?://"
@ -107,6 +110,26 @@ class DownloadJobBase(BaseModel):
class DownloadQueueBase(ABC):
"""Abstract base class for managing model downloads."""
@abstractmethod
def __init__(
self,
max_parallel_dl: int = 5,
event_handlers: List[DownloadEventHandler] = [],
requests_session: Optional[requests.sessions.Session] = None,
config: Optional[InvokeAIAppConfig] = None,
quiet: bool = False,
):
"""
Initialize DownloadQueue.
:param max_parallel_dl: Number of simultaneous downloads allowed [5].
:param event_handler: Optional callable that will be called each time a job status changes.
:param requests_session: Optional requests.sessions.Session object, for unit tests.
:param config: InvokeAIAppConfig object, used to configure the logger and other options.
:param quiet: If true, don't log the start of download jobs. Useful for subrequests.
"""
pass
@abstractmethod
def create_download_job(
self,

View File

@ -359,11 +359,17 @@ class DownloadQueue(DownloadQueueBase):
metadata.author = metadata.author or resp["creator"]["username"]
metadata.tags = metadata.tags or resp["tags"]
metadata.thumbnail_url = metadata.thumbnail_url or resp["modelVersions"][0]["images"][0]["url"]
metadata.license = (
metadata.license
or f"allowCommercialUse={resp['allowCommercialUse']}; allowDerivatives={resp['allowDerivatives']}; allowNoCredit={resp['allowNoCredit']}"
)
first_version = resp["modelVersions"][0]
metadata.thumbnail_url = metadata.thumbnail_url or first_version.get("url")
metadata.description = metadata.description or (
f"Trigger terms: {(', ').join(first_version.get('trainedWords'))}"
if first_version.get("trainedWords")
else first_version.get("description")
)
except (HTTPError, KeyError, TypeError, JSONDecodeError) as excp:
self._logger.warn(excp)

View File

@ -79,7 +79,7 @@ class UnifiedModelInfo(BaseModel):
@dataclass
class InstallSelections:
install_models: List[UnifiedModelInfo] = field(default_factory=list)
remove_models: List[UnifiedModelInfo] = field(default_factory=list)
remove_models: List[str] = field(default_factory=list)
def make_printable(s: str) -> str:
@ -576,7 +576,11 @@ def add_or_delete(installer: ModelInstall, selections: InstallSelections):
)
for model in selections.remove_models:
base_model, model_type, model_name = model.split("/")
parts = model.split("/")
if len(parts) == 1:
base_model, model_type, model_name = (None, None, model)
else:
base_model, model_type, model_name = parts
matches = installer.store.search_by_name(base_model=base_model, model_type=model_type, model_name=model_name)
if len(matches) > 1:
print(f"{model} is ambiguous. Please use model_type:model_name (e.g. main:my_model) to disambiguate.")
@ -601,7 +605,9 @@ def select_and_download_models(opt: Namespace):
list_models(installer, opt.list_models)
elif opt.add or opt.delete:
selections = InstallSelections(install_models=opt.add, remove_models=opt.delete)
selections = InstallSelections(
install_models=[UnifiedModelInfo(source=x) for x in (opt.add or [])], remove_models=opt.delete or []
)
add_or_delete(installer, selections)
elif opt.default_only: