mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add/delete from command line working; training words downloaded
This commit is contained in:
@ -9,9 +9,12 @@ from functools import total_ordering
|
||||
from pathlib import Path
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
# Used to distinguish between repo_id sources and URL sources
|
||||
REPO_ID_RE = r"^[\w-]+/[.\w-]+$"
|
||||
HTTP_RE = r"^https?://"
|
||||
@ -107,6 +110,26 @@ class DownloadJobBase(BaseModel):
|
||||
class DownloadQueueBase(ABC):
|
||||
"""Abstract base class for managing model downloads."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
max_parallel_dl: int = 5,
|
||||
event_handlers: List[DownloadEventHandler] = [],
|
||||
requests_session: Optional[requests.sessions.Session] = None,
|
||||
config: Optional[InvokeAIAppConfig] = None,
|
||||
quiet: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize DownloadQueue.
|
||||
|
||||
:param max_parallel_dl: Number of simultaneous downloads allowed [5].
|
||||
:param event_handler: Optional callable that will be called each time a job status changes.
|
||||
:param requests_session: Optional requests.sessions.Session object, for unit tests.
|
||||
:param config: InvokeAIAppConfig object, used to configure the logger and other options.
|
||||
:param quiet: If true, don't log the start of download jobs. Useful for subrequests.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_download_job(
|
||||
self,
|
||||
|
@ -359,11 +359,17 @@ class DownloadQueue(DownloadQueueBase):
|
||||
|
||||
metadata.author = metadata.author or resp["creator"]["username"]
|
||||
metadata.tags = metadata.tags or resp["tags"]
|
||||
metadata.thumbnail_url = metadata.thumbnail_url or resp["modelVersions"][0]["images"][0]["url"]
|
||||
metadata.license = (
|
||||
metadata.license
|
||||
or f"allowCommercialUse={resp['allowCommercialUse']}; allowDerivatives={resp['allowDerivatives']}; allowNoCredit={resp['allowNoCredit']}"
|
||||
)
|
||||
first_version = resp["modelVersions"][0]
|
||||
metadata.thumbnail_url = metadata.thumbnail_url or first_version.get("url")
|
||||
metadata.description = metadata.description or (
|
||||
f"Trigger terms: {(', ').join(first_version.get('trainedWords'))}"
|
||||
if first_version.get("trainedWords")
|
||||
else first_version.get("description")
|
||||
)
|
||||
|
||||
except (HTTPError, KeyError, TypeError, JSONDecodeError) as excp:
|
||||
self._logger.warn(excp)
|
||||
|
@ -79,7 +79,7 @@ class UnifiedModelInfo(BaseModel):
|
||||
@dataclass
|
||||
class InstallSelections:
|
||||
install_models: List[UnifiedModelInfo] = field(default_factory=list)
|
||||
remove_models: List[UnifiedModelInfo] = field(default_factory=list)
|
||||
remove_models: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
def make_printable(s: str) -> str:
|
||||
@ -576,7 +576,11 @@ def add_or_delete(installer: ModelInstall, selections: InstallSelections):
|
||||
)
|
||||
|
||||
for model in selections.remove_models:
|
||||
base_model, model_type, model_name = model.split("/")
|
||||
parts = model.split("/")
|
||||
if len(parts) == 1:
|
||||
base_model, model_type, model_name = (None, None, model)
|
||||
else:
|
||||
base_model, model_type, model_name = parts
|
||||
matches = installer.store.search_by_name(base_model=base_model, model_type=model_type, model_name=model_name)
|
||||
if len(matches) > 1:
|
||||
print(f"{model} is ambiguous. Please use model_type:model_name (e.g. main:my_model) to disambiguate.")
|
||||
@ -601,7 +605,9 @@ def select_and_download_models(opt: Namespace):
|
||||
list_models(installer, opt.list_models)
|
||||
|
||||
elif opt.add or opt.delete:
|
||||
selections = InstallSelections(install_models=opt.add, remove_models=opt.delete)
|
||||
selections = InstallSelections(
|
||||
install_models=[UnifiedModelInfo(source=x) for x in (opt.add or [])], remove_models=opt.delete or []
|
||||
)
|
||||
add_or_delete(installer, selections)
|
||||
|
||||
elif opt.default_only:
|
||||
|
Reference in New Issue
Block a user