mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Commit: almost all type mismatches fixed
@@ -3,7 +3,7 @@
import pathlib
from enum import Enum
-from typing import List, Literal, Optional, Union
+from typing import List, Literal, Optional, Union, Tuple

from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter
@@ -22,14 +22,17 @@ from invokeai.backend.model_manager import (
from invokeai.backend.model_manager.download import DownloadJobStatus, UnknownJobIDException
from invokeai.backend.model_manager.merge import MergeInterpolationMethod

-from ..dependencies import ApiDependencies
+from invokeai.app.api.dependencies import ApiDependencies

models_router = APIRouter(prefix="/v1/models", tags=["models"])

# NOTE: The generic configuration classes defined in invokeai.backend.model_manager.config
-# such as "MainCheckpointConfig" are repackaged by code original written by Stalker
+# such as "MainCheckpointConfig" are repackaged by code originally written by Stalker
# into base-specific classes such as `abc.StableDiffusion1ModelCheckpointConfig`
# This is the reason for the calls to dict() followed by pydantic.parse_obj_as()

+# There are still numerous mypy errors here because it does not seem to like this
+# way of dynamically generating the typing hints below.
UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
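
The NOTE above refers to repackaging a generic config object into one of the base-specific classes by round-tripping through dict() and pydantic.parse_obj_as(), with the response types built as a Union over a tuple of config classes. A minimal sketch of that pattern, using made-up config classes rather than the real OPENAPI_MODEL_CONFIGS entries (pydantic v1 API):

    from typing import Union
    from pydantic import BaseModel, parse_obj_as

    # Hypothetical stand-ins for the generated per-base config classes.
    class SD1CheckpointConfig(BaseModel):
        base: str = "sd-1"
        path: str

    class SD2CheckpointConfig(BaseModel):
        base: str = "sd-2"
        path: str

    CONFIGS = (SD1CheckpointConfig, SD2CheckpointConfig)
    UpdateModelResponse = Union[CONFIGS]  # Union built dynamically from a tuple, as above

    generic = SD1CheckpointConfig(path="/models/foo.ckpt")
    # dict() then parse_obj_as() re-validates the data against the Union,
    # letting pydantic pick the matching base-specific class.
    repackaged = parse_obj_as(UpdateModelResponse, generic.dict())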
@@ -38,7 +41,7 @@ ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]


class ModelsList(BaseModel):
-models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
+models: List[Union[tuple(OPENAPI_MODEL_CONFIGS)]]


class ModelImportStatus(BaseModel):
@@ -3,9 +3,9 @@
from typing import Any, Optional

from invokeai.app.models.image import ProgressImage
-from invokeai.app.services.model_manager_service import ModelInfo, SubModelType
from invokeai.app.services.session_queue.session_queue_common import EnqueueBatchResult, SessionQueueItem
from invokeai.app.util.misc import get_timestamp
+from invokeai.backend.model_manager import ModelInfo, SubModelType
from invokeai.backend.model_manager.download import DownloadJobBase
from invokeai.backend.util.logging import InvokeAILogger

@@ -26,6 +26,7 @@ from invokeai.backend.model_manager import (
from invokeai.backend.model_manager.cache import CacheStats
from invokeai.backend.model_manager.download import DownloadJobBase
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
+from .events import EventServiceBase

from .config import InvokeAIAppConfig

@@ -33,11 +34,6 @@ if TYPE_CHECKING:
from ..invocations.baseinvocation import InvocationContext


-# "forward declaration" because of circular import issues
-class EventServiceBase:
-pass
-
-
class ModelManagerServiceBase(ABC):
"""Responsible for managing models on disk and in memory."""

@@ -197,8 +193,8 @@ class ModelManagerServiceBase(ABC):
def install_model(
self,
source: Union[str, Path, AnyHttpUrl],
-priority: int = 10,
model_attributes: Optional[Dict[str, Any]] = None,
+priority: Optional[int] = 10,
) -> ModelInstallJob:
"""Import a path, repo_id or URL. Returns an ModelInstallJob.

@@ -230,7 +226,7 @@ class ModelManagerServiceBase(ABC):
pass

@abstractmethod
-def wait_for_installs(self) -> Dict[str, str]:
+def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
"""
Wait for all pending installs to complete.

@@ -334,7 +330,7 @@ class ModelManagerService(ModelManagerServiceBase):
"""Responsible for managing models on disk and in memory."""

_loader: ModelLoad = Field(description="InvokeAIAppConfig object for the current process")
-_event_bus: "EventServiceBase" = Field(description="an event bus to send install events to", default=None)
+_event_bus: Optional[EventServiceBase] = Field(description="an event bus to send install events to", default=None)

def __init__(self, config: InvokeAIAppConfig, event_bus: Optional["EventServiceBase"] = None):
"""
@@ -345,8 +341,10 @@ class ModelManagerService(ModelManagerServiceBase):
installation and download events will be sent to the event bus.
"""
self._event_bus = event_bus
-handlers = [self._event_bus.emit_model_event] if self._event_bus else None
-self._loader = ModelLoad(config, event_handlers=handlers)
+kwargs: Dict[str, Any] = {}
+if self._event_bus:
+kwargs.update(event_handlers=[self._event_bus.emit_model_event])
+self._loader = ModelLoad(config, **kwargs)

def get_model(
self,
@@ -365,8 +363,8 @@ class ModelManagerService(ModelManagerServiceBase):
if context:
self._emit_load_event(
context=context,
-key=key,
-submodel_type=submodel_type,
+model_key=key,
+submodel=submodel_type,
model_info=model_info,
)

@@ -416,16 +414,18 @@ class ModelManagerService(ModelManagerServiceBase):
assertion error if the name already exists.
"""
self.logger.debug(f"add/update model {model_path}")
-return self._loader.installer.install(
-model_path,
-probe_override=model_attributes,
+return ModelInstallJob.parse_obj(
+self._loader.installer.install(
+model_path,
+probe_override=model_attributes,
+)
)

def install_model(
self,
source: Union[str, Path, AnyHttpUrl],
-priority: int = 10,
model_attributes: Optional[Dict[str, Any]] = None,
+priority: Optional[int] = 10,
) -> ModelInstallJob:
"""
Add a model using a path, repo_id or URL.

@@ -438,11 +438,13 @@ class ModelManagerService(ModelManagerServiceBase):
"""
self.logger.debug(f"add model {source}")
variant = "fp16" if self._loader.precision == "float16" else None
-return self._loader.installer.install(
-source,
-probe_override=model_attributes,
-variant=variant,
-priority=priority,
+return ModelInstallJob.parse_obj(
+self._loader.installer.install(
+source,
+probe_override=model_attributes,
+variant=variant,
+priority=priority,
+)
)

def list_install_jobs(self) -> List[ModelInstallJob]:
@@ -453,9 +455,9 @@ class ModelManagerService(ModelManagerServiceBase):

def id_to_job(self, id: int) -> ModelInstallJob:
"""Return the ModelInstallJob instance corresponding to the given job ID."""
-return self._loader.queue.id_to_job(id)
+return ModelInstallJob.parse_obj(self._loader.queue.id_to_job(id))

-def wait_for_installs(self) -> Dict[str, str]:
+def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
"""
Wait for all pending installs to complete.

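
Several changes above wrap return values in ModelInstallJob.parse_obj(...). In pydantic v1, parse_obj() re-validates plain field data against the target class, which is how a generic DownloadJobBase result is "upcast" to the richer install-job type. A rough sketch with illustrative models, not the project's real fields:

    from typing import Optional
    from pydantic import BaseModel

    class DownloadJob(BaseModel):          # stand-in for DownloadJobBase
        source: str
        status: str = "completed"

    class InstallJob(DownloadJob):         # stand-in for ModelInstallJob
        model_key: Optional[str] = None    # extra field added by the subclass

    raw = DownloadJob(source="https://example.com/model.safetensors")
    # parse_obj() accepts a dict; raw.dict() flattens the base job so the
    # subclass can validate it and fill in its own defaults.
    job = InstallJob.parse_obj(raw.dict())
    assert job.model_key is None and job.status == "completed"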
@@ -1,6 +1,4 @@
-"""
-Initialization file for invokeai.backend.model_manager.config
-"""
+"""Initialization file for invokeai.backend.model_manager.config."""
from .config import ( # noqa F401
BaseModelType,
InvalidModelConfigException,
@@ -1,5 +1,6 @@
"""
Manage a RAM cache of diffusion/transformer models for fast switching.
+
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
grows larger than a preset maximum, then the least recently used
model will be cleared and (re)loaded from disk when next needed.
@@ -22,11 +23,11 @@ import sys
from contextlib import suppress
from dataclasses import dataclass, field
from pathlib import Path
-from typing import Any, Dict, Optional, Type, Union, types
+from typing import Any, Dict, List, Optional, Type, Union

import torch

-import invokeai.backend.util.logging as logger
+from invokeai.backend.util import InvokeAILogger, Logger

from ..util import GIG
from ..util.devices import choose_torch_device
@@ -97,7 +98,7 @@ class ModelCache(object):
sequential_offload: bool = False,
lazy_offloading: bool = True,
sha_chunksize: int = 16777216,
-logger: types.ModuleType = logger,
+logger: Logger = InvokeAILogger.get_logger(),
):
"""
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
@@ -122,8 +123,8 @@ class ModelCache(object):
# used for stats collection
self.stats: Optional[CacheStats] = None

-self._cached_models = dict()
-self._cache_stack = list()
+self._cached_models: Dict[str, _CacheRecord] = dict()
+self._cache_stack: List[str] = list()

# Note that the combination of model_path and submodel_type
# are sufficient to generate a unique cache key. This key
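
Many one-line changes in this commit simply annotate containers that start out empty (dict(), list(), set()), because mypy cannot infer an element type from an empty constructor. A tiny illustration of the pattern, with placeholder names:

    from typing import Dict, List

    class Cache:
        def __init__(self) -> None:
            # Without the annotations, mypy flags "Need type annotation"
            # for the empty constructors and rejects later inserts.
            self._cached_models: Dict[str, object] = dict()
            self._cache_stack: List[str] = list()

        def remember(self, key: str, model: object) -> None:
            self._cached_models[key] = model
            self._cache_stack.append(key)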
@@ -221,8 +222,12 @@ class ModelCache(object):
return self.ModelLocker(self, key, cache_entry.model, gpu_load, cache_entry.size)

class ModelLocker(object):
+"""Context manager that locks models into VRAM."""
+
def __init__(self, cache, key, model, gpu_load, size_needed):
"""
+Initialize a context manager object that locks models into VRAM.
+
:param cache: The model_cache object
:param key: The key of the model to lock in GPU
:param model: The model to lock
@@ -240,6 +240,7 @@ AnyModelConfig = Union[
TextualInversionConfig,
ONNXSD1Config,
ONNXSD2Config,
+ModelConfigBase,
]


@@ -324,6 +325,8 @@ class ModelConfigFactory(object):

# TO DO: Move this somewhere else
class SilenceWarnings(object):
+"""Context manager to temporarily lower verbosity of diffusers & transformers warning messages."""
+
def __init__(self):
self.transformers_verbosity = transformers_logging.get_verbosity()
self.diffusers_verbosity = diffusers_logging.get_verbosity()
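
The diff only shows the first lines of SilenceWarnings, so the following is a guessed sketch of the usual shape of such a context manager — capture the current transformers/diffusers verbosity, drop it to errors-only on entry, restore it on exit — not necessarily the class body as it appears in the repository:

    import warnings
    from diffusers import logging as diffusers_logging
    from transformers import logging as transformers_logging

    class SilenceWarnings:
        """Temporarily reduce diffusers/transformers warning chatter (sketch only)."""

        def __init__(self) -> None:
            self.transformers_verbosity = transformers_logging.get_verbosity()
            self.diffusers_verbosity = diffusers_logging.get_verbosity()

        def __enter__(self) -> None:
            transformers_logging.set_verbosity_error()
            diffusers_logging.set_verbosity_error()
            warnings.simplefilter("ignore")

        def __exit__(self, *exc) -> None:
            transformers_logging.set_verbosity(self.transformers_verbosity)
            diffusers_logging.set_verbosity(self.diffusers_verbosity)
            warnings.simplefilter("default")

Used as "with SilenceWarnings(): ..." around model loading, it keeps the usual flood of loading warnings out of the log.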
@@ -19,9 +19,8 @@

import re
from contextlib import nullcontext
-from io import BytesIO
from pathlib import Path
-from typing import Optional, Union
+from typing import Dict, Optional, Union

import requests
import torch
@@ -1223,7 +1222,7 @@ def download_from_original_stable_diffusion_ckpt(
# scan model
scan_result = scan_file_path(checkpoint_path)
if scan_result.infected_files != 0:
-raise "The model {checkpoint_path} is potentially infected by malware. Aborting import."
+raise Exception("The model {checkpoint_path} is potentially infected by malware. Aborting import.")
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(checkpoint_path, map_location=device)
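
The change above replaces a bare string after raise with a real exception; in Python 3, raise requires an object deriving from BaseException, so the old line would itself fail with a TypeError at runtime (note the message still lacks an f-prefix, so {checkpoint_path} is not interpolated). A minimal illustration:

    try:
        raise "model is potentially infected"        # old pattern
    except TypeError as err:
        # Python 3 rejects non-exception objects:
        # "exceptions must derive from BaseException"
        print(err)

    checkpoint_path = "/tmp/model.ckpt"
    # the fixed pattern; add an f-prefix if the path should be interpolated
    raise Exception(f"The model {checkpoint_path} is potentially infected by malware. Aborting import.")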
@@ -1272,15 +1271,15 @@ def download_from_original_stable_diffusion_ckpt(
# only refiner xl has embedder and one text embedders
config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"

-original_config_file = BytesIO(requests.get(config_url).content)
+original_config_file = requests.get(config_url).text

original_config = OmegaConf.load(original_config_file)
-if original_config["model"]["params"].get("use_ema") is not None:
-extract_ema = original_config["model"]["params"]["use_ema"]
+if original_config.model["params"].get("use_ema") is not None:
+extract_ema = original_config.model["params"]["use_ema"]

if (
model_version in [BaseModelType.StableDiffusion2, BaseModelType.StableDiffusion1]
-and original_config["model"]["params"].get("parameterization") == "v"
+and original_config.model["params"].get("parameterization") == "v"
):
prediction_type = "v_prediction"
upcast_attention = True
@@ -1312,11 +1311,11 @@ def download_from_original_stable_diffusion_ckpt(
num_in_channels = 4

if "unet_config" in original_config.model.params:
-original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
+original_config.model["params"]["unet_config"]["params"]["in_channels"] = num_in_channels

if (
-"parameterization" in original_config["model"]["params"]
-and original_config["model"]["params"]["parameterization"] == "v"
+"parameterization" in original_config.model["params"]
+and original_config.model["params"]["parameterization"] == "v"
):
if prediction_type is None:
# NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"`
@@ -1437,7 +1436,7 @@ def download_from_original_stable_diffusion_ckpt(

if model_type == "FrozenOpenCLIPEmbedder":
config_name = "stabilityai/stable-diffusion-2"
-config_kwargs = {"subfolder": "text_encoder"}
+config_kwargs: Dict[str, Union[str, int]] = {"subfolder": "text_encoder"}

text_model = convert_open_clip_checkpoint(checkpoint, config_name, **config_kwargs)
tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "stable-diffusion-2-clip", subfolder="tokenizer")
@@ -1664,7 +1663,7 @@ def download_controlnet_from_original_ckpt(
# scan model
scan_result = scan_file_path(checkpoint_path)
if scan_result.infected_files != 0:
-raise "The model {checkpoint_path} is potentially infected by malware. Aborting import."
+raise Exception("The model {checkpoint_path} is potentially infected by malware. Aborting import.")
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(checkpoint_path, map_location=device)
@@ -1685,7 +1684,7 @@ def download_controlnet_from_original_ckpt(
original_config = OmegaConf.load(original_config_file)

if num_in_channels is not None:
-original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
+original_config.model["params"]["unet_config"]["params"]["in_channels"] = num_in_channels

if "control_stage_config" not in original_config.model.params:
raise ValueError("`control_stage_config` not present in original config")
@@ -1725,7 +1724,7 @@ def convert_ckpt_to_diffusers(
and in addition a path-like object indicating the location of the desired diffusers
model to be written.
"""
-pipe = download_from_original_stable_diffusion_ckpt(checkpoint_path, **kwargs)
+pipe = download_from_original_stable_diffusion_ckpt(str(checkpoint_path), **kwargs)

pipe.save_pretrained(
dump_path,
@@ -1743,6 +1742,6 @@ def convert_controlnet_to_diffusers(
and in addition a path-like object indicating the location of the desired diffusers
model to be written.
"""
-pipe = download_controlnet_from_original_ckpt(checkpoint_path, **kwargs)
+pipe = download_controlnet_from_original_ckpt(str(checkpoint_path), **kwargs)

pipe.save_pretrained(dump_path, safe_serialization=True)
@ -1,7 +1,5 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
Abstract base class for a multithreaded model download queue.
|
||||
"""
|
||||
"""Abstract base class for a multithreaded model download queue."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
@ -56,7 +54,7 @@ class DownloadJobBase(BaseModel):
|
||||
|
||||
priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
|
||||
id: int = Field(description="Numeric ID of this job", default=-1) # default id is a placeholder
|
||||
source: str = Field(description="URL or repo_id to download")
|
||||
source: Union[str, Path] = Field(description="URL or repo_id to download")
|
||||
destination: Path = Field(description="Destination of URL on local disk")
|
||||
metadata: ModelSourceMetadata = Field(
|
||||
description="Model metadata (source-specific)", default_factory=ModelSourceMetadata
|
||||
@ -84,7 +82,8 @@ class DownloadJobBase(BaseModel):
|
||||
|
||||
def add_event_handler(self, handler: DownloadEventHandler):
|
||||
"""Add an event handler to the end of the handlers list."""
|
||||
self.event_handlers.append(handler)
|
||||
if self.event_handlers is not None:
|
||||
self.event_handlers.append(handler)
|
||||
|
||||
def clear_event_handlers(self):
|
||||
"""Clear all event handlers."""
|
||||
@ -226,9 +225,7 @@ class DownloadQueueBase(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def prune_jobs(self):
|
||||
"""
|
||||
Prune completed and errored queue items from the job list.
|
||||
"""
|
||||
"""Prune completed and errored queue items from the job list."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
@@ -7,19 +7,18 @@ import shutil
import threading
import time
import traceback
from json import JSONDecodeError
from pathlib import Path
from queue import PriorityQueue
from typing import Dict, List, Optional, Set, Tuple, Union

import requests
from huggingface_hub import HfApi, hf_hub_url
-from pydantic import Field, ValidationError, validator
+from pydantic import Field, ValidationError, parse_obj_as, validator
from pydantic.networks import AnyHttpUrl
from requests import HTTPError

from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.backend.util.logging import InvokeAILogger
+from invokeai.backend.util import InvokeAILogger, Logger

from ..storage import DuplicateModelException
from . import HTTP_RE, REPO_ID_RE
@@ -56,19 +55,20 @@ class DownloadJobRepoID(DownloadJobBase):
"""Download repo ids."""

variant: Optional[str] = Field(description="Variant, such as 'fp16', to download")
+source: str = Field(description="URL or repo_id to download")

@validator("source")
@classmethod
def _validate_source(cls, v: str) -> str:
if not re.match(REPO_ID_RE, v):
-raise ValidationError(f"{v} invalid repo_id")
+raise ValidationError(f"{v} invalid repo_id", cls)
return v


class DownloadJobPath(DownloadJobBase):
"""Handle file paths."""

-source: Path = Field(description="Path to a file or directory to install")
+source: Union[str, Path] = Field(description="Path to a file or directory to install")


class DownloadQueue(DownloadQueueBase):
@@ -77,8 +77,8 @@ class DownloadQueue(DownloadQueueBase):
_jobs: Dict[int, DownloadJobBase]
_worker_pool: Set[threading.Thread]
_queue: PriorityQueue
-_lock: threading.Lock
-_logger: InvokeAILogger
+_lock: threading.RLock
+_logger: Logger
_event_handlers: List[DownloadEventHandler] = Field(default_factory=list)
_next_job_id: int = 0
_sequence: int = 0 # This is for debugging and used to tag jobs in dequeueing order
@@ -126,6 +126,7 @@ class DownloadQueue(DownloadQueueBase):
"""Create a download job and return its ID."""
kwargs = dict()

+cls = DownloadJobBase
if Path(source).exists():
cls = DownloadJobPath
elif re.match(REPO_ID_RE, str(source)):
@@ -137,7 +138,7 @@ class DownloadQueue(DownloadQueueBase):
raise NotImplementedError(f"Don't know what to do with this type of source: {source}")

job = cls(
-source=source,
+source=str(source),
destination=Path(destdir) / (filename or "."),
access_token=access_token,
event_handlers=(event_handlers or self._event_handlers),
@@ -179,7 +180,7 @@ class DownloadQueue(DownloadQueueBase):

def list_jobs(self) -> List[DownloadJobBase]:
"""List all the jobs."""
-return self._jobs.values()
+return list(self._jobs.values())

def change_priority(self, job: DownloadJobBase, delta: int):
"""Change the priority of a job. Smaller priorities run first."""
@@ -193,9 +194,7 @@ class DownloadQueue(DownloadQueueBase):
self._lock.release()

def prune_jobs(self):
-"""
-Prune completed and errored queue items from the job list.
-"""
+"""Prune completed and errored queue items from the job list."""
try:
to_delete = set()
self._lock.acquire()
@@ -334,7 +333,7 @@ class DownloadQueue(DownloadQueueBase):
elif isinstance(job, DownloadJobPath):
self._download_path(job)
else:
-raise NotImplementedError(f"Don't know what to do with this job: {job}")
+raise NotImplementedError(f"Don't know what to do with this job: {job}, type={type(job)}")

if job.status == DownloadJobStatus.CANCELLED:
self._cleanup_cancelled_job(job)
@@ -354,7 +353,7 @@ class DownloadQueue(DownloadQueueBase):
model = None

# a Civitai download URL
-if match := re.match(CIVITAI_MODEL_DOWNLOAD, metadata_url):
+if match := re.match(CIVITAI_MODEL_DOWNLOAD, str(metadata_url)):
version = match.group(1)
resp = self._requests.get(CIVITAI_VERSIONS_ENDPOINT + version).json()
metadata.thumbnail_url = metadata.thumbnail_url or resp["images"][0]["url"]
@@ -364,16 +363,16 @@ class DownloadQueue(DownloadQueueBase):
metadata_url = CIVITAI_MODEL_PAGE + str(resp["modelId"]) + f"?modelVersionId={version}"

# a Civitai model page with the version
-if match := re.match(CIVITAI_MODEL_PAGE_WITH_VERSION, metadata_url):
+if match := re.match(CIVITAI_MODEL_PAGE_WITH_VERSION, str(metadata_url)):
model = match.group(1)
version = int(match.group(2))
# and without
-elif match := re.match(CIVITAI_MODEL_PAGE + r"(\d+)", metadata_url):
+elif match := re.match(CIVITAI_MODEL_PAGE + r"(\d+)", str(metadata_url)):
model = match.group(1)
version = None

if not model:
-return url
+return parse_obj_as(AnyHttpUrl, url)

if model:
resp = self._requests.get(CIVITAI_MODELS_ENDPOINT + str(model)).json()
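
The `return parse_obj_as(AnyHttpUrl, url)` change uses pydantic v1's parse_obj_as helper to validate and coerce a plain string into the declared AnyHttpUrl return type rather than returning a bare str, for example:

    from pydantic import parse_obj_as
    from pydantic.networks import AnyHttpUrl

    url = parse_obj_as(AnyHttpUrl, "https://civitai.com/api/download/models/12345")
    print(type(url).__name__, url)   # AnyHttpUrl https://civitai.com/api/download/models/12345
    # An invalid value raises pydantic.ValidationError instead of silently passing through:
    # parse_obj_as(AnyHttpUrl, "not-a-url")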
@@ -419,16 +418,19 @@ class DownloadQueue(DownloadQueueBase):

if job.destination.is_dir():
try:
-file_name = re.search('filename="(.+)"', resp.headers["Content-Disposition"]).group(1)
+file_name = ""
+if match := re.search('filename="(.+)"', resp.headers["Content-Disposition"]):
+file_name = match.group(1)
+assert file_name != ""
self._validate_filename(
-job.destination, file_name
+job.destination.as_posix(), file_name
) # will raise a ValueError exception if file_name is suspicious
except ValueError:
self._logger.warning(
f"Invalid filename '{file_name}' returned by source {url}, using last component of URL instead"
)
file_name = os.path.basename(url)
-except KeyError:
+except (KeyError, AssertionError):
file_name = os.path.basename(url)
job.destination = job.destination / file_name
dest = job.destination
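
The rewrite above avoids calling .group(1) directly on re.search(), which returns None — and therefore raises AttributeError — when the Content-Disposition header carries no filename. Binding the match with the walrus operator and handling the fallback explicitly is the usual mypy-friendly pattern; a compact illustration with a made-up header value:

    import re

    header = 'attachment; filename="model.safetensors"'

    # Unsafe: re.search(...) may be None, so .group(1) can raise AttributeError.
    # name = re.search(r'filename="(.+)"', header).group(1)

    # Safe: bind the match first, then decide on a fallback.
    if match := re.search(r'filename="(.+)"', header):
        name = match.group(1)
    else:
        name = "download.bin"          # fallback, analogous to os.path.basename(url)
    print(name)                        # model.safetensors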
@@ -537,12 +539,13 @@ class DownloadQueue(DownloadQueueBase):
self._update_job_status(job, DownloadJobStatus.RUNNING)
return

-job.subqueue = self.__class__(
+subqueue = self.__class__(
event_handlers=[subdownload_event],
requests_session=self._requests,
quiet=True,
)
try:
+job = DownloadJobRepoID.parse_obj(job)
repo_id = job.source
variant = job.variant
if not job.metadata:
@@ -550,12 +553,12 @@ class DownloadQueue(DownloadQueueBase):
urls_to_download = self._get_repo_info(repo_id, variant=variant, metadata=job.metadata)
if job.destination.name != Path(repo_id).name:
job.destination = job.destination / Path(repo_id).name
-bytes_downloaded = dict()
+bytes_downloaded: Dict[int, int] = dict()
job.total_bytes = 0

for url, subdir, file, size in urls_to_download:
job.total_bytes += size
-job.subqueue.create_download_job(
+subqueue.create_download_job(
source=url,
destdir=job.destination / subdir,
filename=file,
@@ -569,6 +572,7 @@ class DownloadQueue(DownloadQueueBase):
self._update_job_status(job, DownloadJobStatus.ERROR)
self._logger.error(job.error)
finally:
+job.subqueue = subqueue
job.subqueue.join()
if job.status == DownloadJobStatus.RUNNING:
self._update_job_status(job, DownloadJobStatus.COMPLETED)
@@ -579,7 +583,7 @@ class DownloadQueue(DownloadQueueBase):
repo_id: str,
metadata: ModelSourceMetadata,
variant: Optional[str] = None,
-) -> Tuple[List[Tuple[AnyHttpUrl, Path, Path]], ModelSourceMetadata]:
+) -> List[Tuple[AnyHttpUrl, Path, Path, int]]:
"""
Given a repo_id and an optional variant, return list of URLs to download to get the model.
The metadata field will be updated with model metadata from HuggingFace.
@@ -602,7 +606,7 @@ class DownloadQueue(DownloadQueueBase):
paths = [x for x in paths if Path(x).parent.as_posix() in submodels]
paths.insert(0, "model_index.json")
urls = [
-(hf_hub_url(repo_id, filename=x.as_posix()), x.parent or Path("."), x.name, sizes[x.as_posix()])
+(hf_hub_url(repo_id, filename=x.as_posix()), x.parent or Path("."), Path(x.name), sizes[x.as_posix()])
for x in self._select_variants(paths, variant)
]
if hasattr(model_info, "cardData"):
@@ -614,7 +618,7 @@ class DownloadQueue(DownloadQueueBase):
def _select_variants(self, paths: List[str], variant: Optional[str] = None) -> Set[Path]:
"""Select the proper variant files from a list of HuggingFace repo_id paths."""
result = set()
-basenames = dict()
+basenames: Dict[Path, Path] = dict()
for p in paths:
path = Path(p)

@@ -57,9 +57,9 @@ class FastModelHash(object):
# only tally tensor files
if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
continue
-path = Path(root) / file
+path = (Path(root) / file).as_posix()
fast_hash = cls._hash_file(path)
-components.update({str(path): fast_hash})
+components.update({path: fast_hash})

# hash all the model hashes together, using alphabetic file order
md5 = hashlib.md5()
@@ -54,7 +54,7 @@ import tempfile
from abc import ABC, abstractmethod
from pathlib import Path
from shutil import move, rmtree
-from typing import Any, Callable, Dict, List, Optional, Set, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Type, Union

from pydantic import Field
from pydantic.networks import AnyHttpUrl
@@ -166,7 +166,7 @@ class ModelInstallBase(ABC):
pass

@abstractmethod
-def install_path(self, model_path: Union[Path, str], info: Optional[ModelProbeInfo] = None) -> str:
+def install_path(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None) -> str:
"""
Probe, register and install the model in the models directory.

@@ -174,7 +174,7 @@ class ModelInstallBase(ABC):
the models directory handled by InvokeAI.

:param model_path: Filesystem Path to the model.
-:param info: Optional ModelProbeInfo object. If not provided, model will be probed.
+:param overrides: Dictionary of model probe info fields that, if present, override probed values.
:returns id: The string ID of the installed model.
"""
pass
@@ -184,6 +184,7 @@ class ModelInstallBase(ABC):
self,
source: Union[str, Path, AnyHttpUrl],
inplace: bool = True,
+priority: int = 10,
variant: Optional[str] = None,
probe_override: Optional[Dict[str, Any]] = None,
metadata: Optional[ModelSourceMetadata] = None,
@@ -225,7 +226,7 @@ class ModelInstallBase(ABC):
pass

@abstractmethod
-def wait_for_installs(self) -> Dict[str, str]:
+def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
"""
Wait for all pending installs to complete.

@@ -293,7 +294,7 @@ class ModelInstallBase(ABC):
pass

@abstractmethod
-def sync_model_path(self, key) -> Path:
+def sync_model_path(self, key) -> ModelConfigBase:
"""
Move model into the location indicated by its basetype, type and name.

@@ -324,11 +325,11 @@ class ModelInstall(ModelInstallBase):
_logger: Logger
_store: ModelConfigStore
_download_queue: DownloadQueueBase
-_async_installs: Dict[str, str]
-_installed: Set[Path] = Field(default=set)
+_async_installs: Dict[Union[str, Path, AnyHttpUrl], Optional[str]]
+_installed: Set[str] = Field(default=set)
_tmpdir: Optional[tempfile.TemporaryDirectory] # used for downloads

-_legacy_configs = {
+_legacy_configs: Dict[BaseModelType, Dict[ModelVariantType, Union[str, dict]]] = {
BaseModelType.StableDiffusion1: {
ModelVariantType.Normal: "v1-inference.yaml",
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
@@ -357,13 +358,13 @@ class ModelInstall(ModelInstallBase):
config: Optional[InvokeAIAppConfig] = None,
logger: Optional[Logger] = None,
download: Optional[DownloadQueueBase] = None,
-event_handlers: Optional[List[DownloadEventHandler]] = None,
+event_handlers: List[DownloadEventHandler] = [],
): # noqa D107 - use base class docstrings
self._app_config = config or InvokeAIAppConfig.get_config()
self._logger = logger or InvokeAILogger.get_logger(config=self._app_config)
self._store = store or get_config_store(self._app_config.model_conf_path)
self._download_queue = download or DownloadQueue(config=self._app_config, event_handlers=event_handlers)
-self._async_installs = dict()
+self._async_installs: Dict[Union[str, Path, AnyHttpUrl], Union[str, None]] = dict()
self._installed = set()
self._tmpdir = None

@@ -403,7 +404,8 @@ class ModelInstall(ModelInstallBase):
)
# add 'main' specific fields
if info.model_type == ModelType.Main:
-registration_data.update(variant=info.variant_type)
+if info.variant_type:
+registration_data.update(variant=info.variant_type)
if info.format == ModelFormat.Checkpoint:
try:
config_file = self._legacy_configs[info.base_type][info.variant_type]
@@ -416,7 +418,7 @@ class ModelInstall(ModelInstallBase):
)
config_file = config_file[SchedulerPredictionType.VPrediction]
registration_data.update(
-config=Path(self._app_config.legacy_conf_dir, config_file).as_posix(),
+config=Path(self._app_config.legacy_conf_dir, str(config_file)).as_posix(),
)
except KeyError as exc:
raise InvalidModelException(
@@ -453,12 +455,12 @@ class ModelInstall(ModelInstallBase):
return old_path.replace(new_path)

def _probe_model(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None) -> ModelProbeInfo:
-info: ModelProbeInfo = ModelProbe.probe(model_path)
+info: ModelProbeInfo = ModelProbe.probe(Path(model_path))
if overrides: # used to override probe fields
for key, value in overrides.items():
try:
setattr(info, key, value) # skip validation errors
-except:
+except Exception:
pass
return info

@@ -488,11 +490,11 @@ class ModelInstall(ModelInstallBase):
self,
source: Union[str, Path, AnyHttpUrl],
inplace: bool = True,
-priority: int = 10,
variant: Optional[str] = None,
probe_override: Optional[Dict[str, Any]] = None,
metadata: Optional[ModelSourceMetadata] = None,
access_token: Optional[str] = None,
+priority: Optional[int] = 10,
) -> DownloadJobBase: # noqa D102
queue = self._download_queue

@@ -502,7 +504,8 @@ class ModelInstall(ModelInstallBase):
if inplace and Path(source).exists()
else self._complete_installation_handler
)
-job.probe_override = probe_override
+if isinstance(job, ModelInstallJob):
+job.probe_override = probe_override
if metadata:
job.metadata = metadata
job.add_event_handler(handler)
@@ -512,6 +515,7 @@ class ModelInstall(ModelInstallBase):
return job

def _complete_installation_handler(self, job: DownloadJobBase):
+job = ModelInstallJob.parse_obj(job) # this upcast should succeed
if job.status == "completed":
self._logger.info(f"{job.source}: Download finished with status {job.status}. Installing.")
model_id = self.install_path(job.destination, job.probe_override)
@@ -537,6 +541,7 @@ class ModelInstall(ModelInstallBase):
self._tmpdir = None

def _complete_registration_handler(self, job: DownloadJobBase):
+job = ModelInstallJob.parse_obj(job) # upcast should succeed
if job.status == "completed":
self._logger.info(f"{job.source}: Installing in place.")
model_id = self.register_path(job.destination, job.probe_override)
@@ -567,7 +572,7 @@ class ModelInstall(ModelInstallBase):
models_dir = self._app_config.models_path

if not old_path.is_relative_to(models_dir):
-return old_path
+return model

new_path = models_dir / model.base_model.value / model.model_type.value / model.name
self._logger.info(
@@ -592,7 +597,6 @@ class ModelInstall(ModelInstallBase):
# we simply probe and register/install it. The job does not actually do anything, but we
# create one anyway in order to have similar behavior for local files, URLs and repo_ids.
if Path(source).exists(): # a path that is already on disk
-source = Path(source)
destdir = source
return ModelInstallPathJob(source=source, destination=Path(destdir))

@@ -600,6 +604,7 @@ class ModelInstall(ModelInstallBase):
models_dir = self._app_config.models_path
self._tmpdir = self._tmpdir or tempfile.TemporaryDirectory(dir=models_dir)

+cls = ModelInstallJob
if re.match(REPO_ID_RE, str(source)):
cls = ModelInstallRepoIDJob
kwargs = dict(variant=variant)
@@ -609,10 +614,15 @@ class ModelInstall(ModelInstallBase):
else:
raise ValueError(f"'{source}' is not recognized as a local file, directory, repo_id or URL")
return cls(
-source=source, destination=Path(self._tmpdir.name), access_token=access_token, priority=priority, **kwargs
+source=str(source),
+destination=Path(self._tmpdir.name),
+access_token=access_token,
+priority=priority,
+**kwargs,
)

-def wait_for_installs(self) -> Dict[str, str]: # noqa D102
+def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
+"""Pause until all installation jobs have completed."""
self._download_queue.join()
id_map = self._async_installs
self._async_installs = dict()
@@ -634,9 +644,10 @@ class ModelInstall(ModelInstallBase):
dest_directory: Optional[Path] = None,
) -> ModelConfigBase:
"""
-Convert a checkpoint file into a diffusers folder, deleting the cached
-version and deleting the original checkpoint file if it is in the models
-directory.
+Convert a checkpoint file into a diffusers folder.
+
+It will delete the cached version ans well as the
+original checkpoint file if it is in the models directory.
:param key: Unique key of model.
:dest_directory: Optional place to put converted file. If not specified,
will be stored in the `models_dir`.
@@ -668,7 +679,7 @@ class ModelInstall(ModelInstallBase):
update = info.dict()
update.pop("config")
update["model_format"] = "diffusers"
-update["path"] = converted_model.location.as_posix()
+update["path"] = converted_model.location

if dest_directory:
new_diffusers_path = Path(dest_directory) / info.name
@@ -121,11 +121,7 @@ class ModelLoad(ModelLoadBase):
_cache_keys: dict
_models_file: Path

-def __init__(
-self,
-config: InvokeAIAppConfig,
-event_handlers: Optional[List[DownloadEventHandler]] = None,
-):
+def __init__(self, config: InvokeAIAppConfig, event_handlers: List[DownloadEventHandler] = []):
"""
Initialize ModelLoad object.

@@ -12,7 +12,7 @@ from diffusers.models import UNet2DConditionModel
from safetensors.torch import load_file
from transformers import CLIPTextModel, CLIPTokenizer

-from .models.lora import LoRAModel
+from .models.lora import LoRALayerBase, LoRAModel, LoRAModelRaw

"""
loras = [
@@ -87,7 +87,7 @@ class ModelPatcher:
def apply_lora_text_encoder(
cls,
text_encoder: CLIPTextModel,
-loras: List[Tuple[LoRAModel, float]],
+loras: List[Tuple[LoRAModelRaw, float]],
):
with cls.apply_lora(text_encoder, loras, "lora_te_"):
yield
@@ -97,7 +97,7 @@ class ModelPatcher:
def apply_sdxl_lora_text_encoder(
cls,
text_encoder: CLIPTextModel,
-loras: List[Tuple[LoRAModel, float]],
+loras: List[Tuple[LoRAModelRaw, float]],
):
with cls.apply_lora(text_encoder, loras, "lora_te1_"):
yield
@@ -107,7 +107,7 @@ class ModelPatcher:
def apply_sdxl_lora_text_encoder2(
cls,
text_encoder: CLIPTextModel,
-loras: List[Tuple[LoRAModel, float]],
+loras: List[Tuple[LoRAModelRaw, float]],
):
with cls.apply_lora(text_encoder, loras, "lora_te2_"):
yield
@@ -117,7 +117,7 @@ class ModelPatcher:
def apply_lora(
cls,
model: torch.nn.Module,
-loras: List[Tuple[LoRAModel, float]],
+loras: List[Tuple[LoRAModelRaw, float]],
prefix: str,
):
original_weights = dict()
@@ -337,7 +337,7 @@ class ONNXModelPatcher:
def apply_lora(
cls,
model: IAIOnnxRuntimeModel,
-loras: List[Tuple[LoRAModel, float]],
+loras: List[Tuple[LoRAModelRaw, torch.Tensor]],
prefix: str,
):
from .models.base import IAIOnnxRuntimeModel
@@ -348,7 +348,7 @@ class ONNXModelPatcher:
orig_weights = dict()

try:
-blended_loras = dict()
+blended_loras: Dict[str, torch.Tensor] = dict()

for lora, lora_weight in loras:
for layer_key, layer in lora.layers.items():
@@ -9,7 +9,7 @@ Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
import warnings
from enum import Enum
from pathlib import Path
-from typing import List, Optional
+from typing import List, Optional, Set

from diffusers import DiffusionPipeline
from diffusers import logging as dlogging
@@ -17,7 +17,8 @@ from diffusers import logging as dlogging
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig

-from . import ModelConfigBase, ModelConfigStore, ModelInstall, ModelType
+from . import BaseModelType, ModelConfigBase, ModelConfigStore, ModelInstall, ModelType
+from .config import MainConfig


class MergeInterpolationMethod(str, Enum):
@@ -102,11 +103,11 @@ class ModelMerger(object):
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
-model_paths = list()
+model_paths: List[Path] = list()
model_names = list()
config = self._config
store = self._store
-base_models = set()
+base_models: Set[BaseModelType] = set()
vae = None

assert (
@@ -115,6 +116,7 @@ class ModelMerger(object):

for key in model_keys:
info = store.get_model(key)
+assert isinstance(info, MainConfig)
model_names.append(info.name)
assert (
info.model_format == "diffusers"
@@ -33,7 +33,7 @@ class ModelProbeInfo(BaseModel):
base_type: BaseModelType
format: ModelFormat
hash: str
-variant_type: Optional[ModelVariantType] = ModelVariantType("normal")
+variant_type: ModelVariantType = ModelVariantType("normal")
prediction_type: Optional[SchedulerPredictionType] = SchedulerPredictionType("v_prediction")
upcast_attention: Optional[bool] = False
image_size: Optional[int] = None
@@ -114,7 +114,7 @@ class ModelProbe(ModelProbeBase):
cls,
model_path: Path,
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
-) -> Optional[ModelProbeInfo]:
+) -> ModelProbeInfo:
"""Probe model."""
try:
model_type = (
@@ -129,7 +129,7 @@ class ModelProbe(ModelProbeBase):
probe_class = cls.PROBES[format_type].get(model_type)

if not probe_class:
-return None
+raise InvalidModelException(f"Unable to determine model type for {model_path}")

probe = probe_class(model_path, prediction_type_helper)

@@ -160,7 +160,7 @@ class ModelProbe(ModelProbeBase):
else 512,
)
except Exception:
-raise
+raise InvalidModelException(f"Unable to determine model type for {model_path}")

return model_info

@@ -27,7 +27,7 @@ from typing import Callable, Optional, Set, Union

from pydantic import BaseModel, Field

-from invokeai.backend.util.logging import InvokeAILogger
+from invokeai.backend.util import InvokeAILogger, Logger

default_logger = InvokeAILogger.get_logger()

@@ -56,7 +56,7 @@ class ModelSearchBase(ABC, BaseModel):
on_model_found : Optional[Callable[[Path], bool]] = Field(default=None, description="Called when a model is found.") # noqa E221
on_search_completed : Optional[Callable[[Set[Path]], None]] = Field(default=None, description="Called when search is complete.") # noqa E221
stats : SearchStats = Field(default_factory=SearchStats, description="Summary statistics after search") # noqa E221
-logger : InvokeAILogger = Field(default=default_logger, description="InvokeAILogger instance.") # noqa E221
+logger : Logger = Field(default=default_logger, description="Logger instance.") # noqa E221
# fmt: on

class Config:
@@ -143,7 +143,7 @@ class ModelSearch(ModelSearchBase):
self.on_search_completed(self._models_found)

def search(self, directory: Union[Path, str]) -> Set[Path]:
-self._directory = directory
+self._directory = Path(directory)
self.stats = SearchStats() # zero out
self.search_started() # This will initialize _models_found to empty
self._walk_directory(directory)
@@ -155,7 +155,7 @@ class ModelSearch(ModelSearchBase):
# don't descend into directories that start with a "."
# to avoid the Mac .DS_STORE issue.
if str(Path(root).name).startswith("."):
-self._pruned_paths.add(root)
+self._pruned_paths.add(Path(root))
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
continue

@@ -1,6 +1,4 @@
-"""
-Initialization file for invokeai.backend.model_manager.storage
-"""
+"""Initialization file for invokeai.backend.model_manager.storage."""
import pathlib

from .base import ( # noqa F401
@@ -35,13 +35,11 @@ class ModelConfigStore(ABC):
@property
@abstractmethod
def version(self) -> str:
-"""
-Return the config file/database schema version.
-"""
+"""Return the config file/database schema version."""
pass

@abstractmethod
-def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> None:
+def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> None:
"""
Add a model to the database.

@@ -65,7 +63,7 @@ class ModelConfigStore(ABC):
pass

@abstractmethod
-def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
+def update_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
"""
Update the model, returning the updated version.

@@ -96,7 +94,7 @@ class ModelConfigStore(ABC):
pass

@abstractmethod
-def search_by_tag(self, tags: Set[str]) -> List[ModelConfigBase]:
+def search_by_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
"""
Return models containing all of the listed tags.

@@ -108,10 +106,8 @@ class ModelConfigStore(ABC):
def search_by_path(
self,
path: Union[str, Path],
-) -> Optional[ModelConfigBase]:
-"""
-Return the model having the indicated path.
-"""
+) -> Optional[AnyModelConfig]:
+"""Return the model having the indicated path."""
pass

@abstractmethod
@@ -120,7 +116,7 @@ class ModelConfigStore(ABC):
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
-) -> List[ModelConfigBase]:
+) -> List[AnyModelConfig]:
"""
Return models matching name, base and/or type.

@@ -133,8 +129,6 @@ class ModelConfigStore(ABC):
"""
pass

-def all_models(self) -> List[ModelConfigBase]:
-"""
-Return all the model configs in the database.
-"""
+def all_models(self) -> List[AnyModelConfig]:
+"""Return all the model configs in the database."""
return self.search_by_name()
@@ -52,14 +52,14 @@ def migrate_models_store(config: InvokeAIAppConfig):
except Exception as excp:
print(str(excp))

-model_info = store.get_model(new_key)
-if vae := stanza.get("vae") and isinstance(model_info, MainConfig):
-model_info.vae = (app_config.models_path / vae).as_posix()
-if model_config := stanza.get("config") and isinstance(model_info, MainCheckpointConfig):
-model_info.config = (app_config.root_path / model_config).as_posix()
-model_info.description = stanza.get("description")
-store.update_model(new_key, model_info)
-store.update_model(new_key, model_info)
+if new_key != "<NOKEY>":
+model_info = store.get_model(new_key)
+if (vae := stanza.get("vae")) and isinstance(model_info, MainConfig):
+model_info.vae = (app_config.models_path / vae).as_posix()
+if (model_config := stanza.get("config")) and isinstance(model_info, MainCheckpointConfig):
+model_info.config = (app_config.root_path / model_config).as_posix()
+model_info.description = stanza.get("description")
+store.update_model(new_key, model_info)

logger.info(f"Original version of models config file saved as {str(old_file) + '.orig'}")
shutil.move(old_file, str(old_file) + ".orig")
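
Besides gating the block on new_key, the rewrite above adds parentheses around the walrus assignments. Without them, `vae := stanza.get("vae") and isinstance(...)` binds vae to the value of the whole and-expression rather than to the stanza entry, so the subsequent assignment would write a boolean instead of a path. A small demonstration:

    stanza = {"vae": "sdxl/vae.safetensors"}
    model_info = object()

    # Original form: the walrus captures the whole `and` expression,
    # so vae ends up True/False/None instead of the path string.
    if vae := stanza.get("vae") and isinstance(model_info, object):
        print(repr(vae))        # True

    # Fixed form: parentheses make the walrus capture only stanza.get("vae").
    if (vae := stanza.get("vae")) and isinstance(model_info, object):
        print(repr(vae))        # 'sdxl/vae.safetensors'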
@@ -398,14 +398,14 @@ class ModelConfigStoreSQL(ModelConfigStore):
self._lock.release()
return count > 0

-def search_by_tag(self, tags: Set[str]) -> List[ModelConfigBase]:
+def search_by_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
"""Return models containing all of the listed tags."""
# rather than create a hairy SQL cross-product, we intersect
# tag results in a stepwise fashion at the python level.
results = []
try:
self._lock.acquire()
-matches = set()
+matches: Set[str] = set()
for tag in tags:
self._cursor.execute(
"""--sql
@@ -438,7 +438,7 @@ class ModelConfigStoreSQL(ModelConfigStore):
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
-) -> List[ModelConfigBase]:
+) -> List[AnyModelConfig]:
"""
Return models matching name, base and/or type.

@@ -479,7 +479,5 @@ class ModelConfigStoreSQL(ModelConfigStore):
return results

def search_by_path(self, path: Union[str, Path]) -> Optional[ModelConfigBase]:
-"""
-Return the model with the indicated path, or None..
-"""
+"""Return the model with the indicated path, or None."""
raise NotImplementedError("search_by_path not implemented in storage.sql")
@@ -48,7 +48,6 @@ from typing import List, Optional, Set, Union
import yaml
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
-from omegaconf.listconfig import ListConfig

from ..config import AnyModelConfig, BaseModelType, ModelConfigBase, ModelConfigFactory, ModelType
from .base import (
@@ -64,7 +63,7 @@ class ModelConfigStoreYAML(ModelConfigStore):
"""Implementation of the ModelConfigStore ABC using a YAML file."""

_filename: Path
-_config: Union[DictConfig, ListConfig]
+_config: DictConfig
_lock: threading.RLock

def __init__(self, config_file: Path):
@@ -74,7 +73,9 @@ class ModelConfigStoreYAML(ModelConfigStore):
self._lock = threading.RLock()
if not self._filename.exists():
self._initialize_yaml()
-self._config = OmegaConf.load(self._filename)
+config = OmegaConf.load(self._filename)
+assert isinstance(config, DictConfig)
+self._config = config
if str(self.version) != CONFIG_FILE_VERSION:
raise ConfigFileVersionMismatchException

@@ -101,7 +102,7 @@ class ModelConfigStoreYAML(ModelConfigStore):
@property
def version(self) -> str:
"""Return version of this config file/database."""
-return self._config["__metadata__"].get("version")
+return self._config.__metadata__.get("version")

def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> None:
"""
@@ -149,7 +149,7 @@ def _fast_safetensors_reader(path: str):
def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
if str(path).endswith(".safetensors"):
try:
-checkpoint = _fast_safetensors_reader(path)
+checkpoint = _fast_safetensors_reader(str(path))
except Exception:
# TODO: create issue for support "meta"?
checkpoint = safetensors.torch.load_file(path, device="cpu")
@@ -294,7 +294,7 @@ class InvokeAILegacyLogFormatter(InvokeAIFormatter):
}

def log_fmt(self, levelno: int) -> str:
-return self.FORMATS.get(levelno)
+return self.FORMATS[levelno]


class InvokeAIPlainLogFormatter(InvokeAIFormatter):
@@ -333,7 +333,7 @@ class InvokeAIColorLogFormatter(InvokeAIFormatter):
}

def log_fmt(self, levelno: int) -> str:
-return self.FORMATS.get(levelno)
+return self.FORMATS[levelno]


LOG_FORMATTERS = {
@@ -104,114 +104,6 @@ def get_obj_from_str(string, reload=False):
return getattr(importlib.import_module(module, package=None), cls)


-# DEAD CODE?
-def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
-# create dummy dataset instance
-
-# run prefetching
-if idx_to_fn:
-res = func(data, worker_id=idx)
-else:
-res = func(data)
-Q.put([idx, res])
-Q.put("Done")
-
-
-# DEAD CODE?
-def parallel_data_prefetch(
-func: callable,
-data,
-n_proc,
-target_data_type="ndarray",
-cpu_intensive=True,
-use_worker_id=False,
-):
-# if target_data_type not in ["ndarray", "list"]:
-# raise ValueError(
-# "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
-# )
-if isinstance(data, np.ndarray) and target_data_type == "list":
-raise ValueError("list expected but function got ndarray.")
-elif isinstance(data, abc.Iterable):
-if isinstance(data, dict):
-logger.warning(
-'"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
-)
-data = list(data.values())
-if target_data_type == "ndarray":
-data = np.asarray(data)
-else:
-data = list(data)
-else:
-raise TypeError(
-f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
-)
-
-if cpu_intensive:
-Q = mp.Queue(1000)
-proc = mp.Process
-else:
-Q = Queue(1000)
-proc = Thread
-# spawn processes
-if target_data_type == "ndarray":
-arguments = [[func, Q, part, i, use_worker_id] for i, part in enumerate(np.array_split(data, n_proc))]
-else:
-step = int(len(data) / n_proc + 1) if len(data) % n_proc != 0 else int(len(data) / n_proc)
-arguments = [
-[func, Q, part, i, use_worker_id]
-for i, part in enumerate([data[i : i + step] for i in range(0, len(data), step)])
-]
-processes = []
-for i in range(n_proc):
-p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
-processes += [p]
-
-# start processes
-logger.info("Start prefetching...")
-import time
-
-start = time.time()
-gather_res = [[] for _ in range(n_proc)]
-try:
-for p in processes:
-p.start()
-
-k = 0
-while k < n_proc:
-# get result
-res = Q.get()
-if res == "Done":
-k += 1
-else:
-gather_res[res[0]] = res[1]
-
-except Exception as e:
-logger.error("Exception: ", e)
-for p in processes:
-p.terminate()
-
-raise e
-finally:
-for p in processes:
-p.join()
-logger.info(f"Prefetching complete. [{time.time() - start} sec.]")
-
-if target_data_type == "ndarray":
-if not isinstance(gather_res[0], np.ndarray):
-return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
-
-# order outputs
-return np.concatenate(gather_res, axis=0)
-elif target_data_type == "list":
-out = []
-for r in gather_res:
-out.extend(r)
-return out
-else:
-return gather_res
-

def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
delta = (res[0] / shape[0], res[1] / shape[1])
d = (shape[0] // res[0], shape[1] // res[1])