almost all type mismatches fixed

This commit is contained in:
Lincoln Stein
2023-09-29 19:23:08 -04:00
parent cbf0310a2c
commit 208d390779
24 changed files with 185 additions and 282 deletions

View File

@ -3,7 +3,7 @@
import pathlib
from enum import Enum
from typing import List, Literal, Optional, Union
from typing import List, Literal, Optional, Union, Tuple
from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter
@ -22,14 +22,17 @@ from invokeai.backend.model_manager import (
from invokeai.backend.model_manager.download import DownloadJobStatus, UnknownJobIDException
from invokeai.backend.model_manager.merge import MergeInterpolationMethod
from ..dependencies import ApiDependencies
from invokeai.app.api.dependencies import ApiDependencies
models_router = APIRouter(prefix="/v1/models", tags=["models"])
# NOTE: The generic configuration classes defined in invokeai.backend.model_manager.config
# such as "MainCheckpointConfig" are repackaged by code original written by Stalker
# such as "MainCheckpointConfig" are repackaged by code originally written by Stalker
# into base-specific classes such as `abc.StableDiffusion1ModelCheckpointConfig`
# This is the reason for the calls to dict() followed by pydantic.parse_obj_as()
# There are still numerous mypy errors here because it does not seem to like this
# way of dynamically generating the typing hints below.
UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
@ -38,7 +41,7 @@ ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
class ModelsList(BaseModel):
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
models: List[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
class ModelImportStatus(BaseModel):

View File

@ -3,9 +3,9 @@
from typing import Any, Optional
from invokeai.app.models.image import ProgressImage
from invokeai.app.services.model_manager_service import ModelInfo, SubModelType
from invokeai.app.services.session_queue.session_queue_common import EnqueueBatchResult, SessionQueueItem
from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_manager import ModelInfo, SubModelType
from invokeai.backend.model_manager.download import DownloadJobBase
from invokeai.backend.util.logging import InvokeAILogger

View File

@ -26,6 +26,7 @@ from invokeai.backend.model_manager import (
from invokeai.backend.model_manager.cache import CacheStats
from invokeai.backend.model_manager.download import DownloadJobBase
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
from .events import EventServiceBase
from .config import InvokeAIAppConfig
@ -33,11 +34,6 @@ if TYPE_CHECKING:
from ..invocations.baseinvocation import InvocationContext
# "forward declaration" because of circular import issues
class EventServiceBase:
pass
class ModelManagerServiceBase(ABC):
"""Responsible for managing models on disk and in memory."""
@ -197,8 +193,8 @@ class ModelManagerServiceBase(ABC):
def install_model(
self,
source: Union[str, Path, AnyHttpUrl],
priority: int = 10,
model_attributes: Optional[Dict[str, Any]] = None,
priority: Optional[int] = 10,
) -> ModelInstallJob:
"""Import a path, repo_id or URL. Returns an ModelInstallJob.
@ -230,7 +226,7 @@ class ModelManagerServiceBase(ABC):
pass
@abstractmethod
def wait_for_installs(self) -> Dict[str, str]:
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
"""
Wait for all pending installs to complete.
@ -334,7 +330,7 @@ class ModelManagerService(ModelManagerServiceBase):
"""Responsible for managing models on disk and in memory."""
_loader: ModelLoad = Field(description="InvokeAIAppConfig object for the current process")
_event_bus: "EventServiceBase" = Field(description="an event bus to send install events to", default=None)
_event_bus: Optional[EventServiceBase] = Field(description="an event bus to send install events to", default=None)
def __init__(self, config: InvokeAIAppConfig, event_bus: Optional["EventServiceBase"] = None):
"""
@ -345,8 +341,10 @@ class ModelManagerService(ModelManagerServiceBase):
installation and download events will be sent to the event bus.
"""
self._event_bus = event_bus
handlers = [self._event_bus.emit_model_event] if self._event_bus else None
self._loader = ModelLoad(config, event_handlers=handlers)
kwargs: Dict[str, Any] = {}
if self._event_bus:
kwargs.update(event_handlers=[self._event_bus.emit_model_event])
self._loader = ModelLoad(config, **kwargs)
def get_model(
self,
@ -365,8 +363,8 @@ class ModelManagerService(ModelManagerServiceBase):
if context:
self._emit_load_event(
context=context,
key=key,
submodel_type=submodel_type,
model_key=key,
submodel=submodel_type,
model_info=model_info,
)
@ -416,16 +414,18 @@ class ModelManagerService(ModelManagerServiceBase):
assertion error if the name already exists.
"""
self.logger.debug(f"add/update model {model_path}")
return self._loader.installer.install(
model_path,
probe_override=model_attributes,
return ModelInstallJob.parse_obj(
self._loader.installer.install(
model_path,
probe_override=model_attributes,
)
)
def install_model(
self,
source: Union[str, Path, AnyHttpUrl],
priority: int = 10,
model_attributes: Optional[Dict[str, Any]] = None,
priority: Optional[int] = 10,
) -> ModelInstallJob:
"""
Add a model using a path, repo_id or URL.
@ -438,11 +438,13 @@ class ModelManagerService(ModelManagerServiceBase):
"""
self.logger.debug(f"add model {source}")
variant = "fp16" if self._loader.precision == "float16" else None
return self._loader.installer.install(
source,
probe_override=model_attributes,
variant=variant,
priority=priority,
return ModelInstallJob.parse_obj(
self._loader.installer.install(
source,
probe_override=model_attributes,
variant=variant,
priority=priority,
)
)
def list_install_jobs(self) -> List[ModelInstallJob]:
@ -453,9 +455,9 @@ class ModelManagerService(ModelManagerServiceBase):
def id_to_job(self, id: int) -> ModelInstallJob:
"""Return the ModelInstallJob instance corresponding to the given job ID."""
return self._loader.queue.id_to_job(id)
return ModelInstallJob.parse_obj(self._loader.queue.id_to_job(id))
def wait_for_installs(self) -> Dict[str, str]:
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
"""
Wait for all pending installs to complete.

View File

@ -1,6 +1,4 @@
"""
Initialization file for invokeai.backend.model_manager.config
"""
"""Initialization file for invokeai.backend.model_manager.config."""
from .config import ( # noqa F401
BaseModelType,
InvalidModelConfigException,

View File

@ -1,5 +1,6 @@
"""
Manage a RAM cache of diffusion/transformer models for fast switching.
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
grows larger than a preset maximum, then the least recently used
model will be cleared and (re)loaded from disk when next needed.
@ -22,11 +23,11 @@ import sys
from contextlib import suppress
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Optional, Type, Union, types
from typing import Any, Dict, List, Optional, Type, Union
import torch
import invokeai.backend.util.logging as logger
from invokeai.backend.util import InvokeAILogger, Logger
from ..util import GIG
from ..util.devices import choose_torch_device
@ -97,7 +98,7 @@ class ModelCache(object):
sequential_offload: bool = False,
lazy_offloading: bool = True,
sha_chunksize: int = 16777216,
logger: types.ModuleType = logger,
logger: Logger = InvokeAILogger.get_logger(),
):
"""
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
@ -122,8 +123,8 @@ class ModelCache(object):
# used for stats collection
self.stats: Optional[CacheStats] = None
self._cached_models = dict()
self._cache_stack = list()
self._cached_models: Dict[str, _CacheRecord] = dict()
self._cache_stack: List[str] = list()
# Note that the combination of model_path and submodel_type
# are sufficient to generate a unique cache key. This key
@ -221,8 +222,12 @@ class ModelCache(object):
return self.ModelLocker(self, key, cache_entry.model, gpu_load, cache_entry.size)
class ModelLocker(object):
"""Context manager that locks models into VRAM."""
def __init__(self, cache, key, model, gpu_load, size_needed):
"""
Initialize a context manager object that locks models into VRAM.
:param cache: The model_cache object
:param key: The key of the model to lock in GPU
:param model: The model to lock

View File

@ -240,6 +240,7 @@ AnyModelConfig = Union[
TextualInversionConfig,
ONNXSD1Config,
ONNXSD2Config,
ModelConfigBase,
]
@ -324,6 +325,8 @@ class ModelConfigFactory(object):
# TO DO: Move this somewhere else
class SilenceWarnings(object):
"""Context manager to temporarily lower verbosity of diffusers & transformers warning messages."""
def __init__(self):
self.transformers_verbosity = transformers_logging.get_verbosity()
self.diffusers_verbosity = diffusers_logging.get_verbosity()

View File

@ -19,9 +19,8 @@
import re
from contextlib import nullcontext
from io import BytesIO
from pathlib import Path
from typing import Optional, Union
from typing import Dict, Optional, Union
import requests
import torch
@ -1223,7 +1222,7 @@ def download_from_original_stable_diffusion_ckpt(
# scan model
scan_result = scan_file_path(checkpoint_path)
if scan_result.infected_files != 0:
raise "The model {checkpoint_path} is potentially infected by malware. Aborting import."
raise Exception("The model {checkpoint_path} is potentially infected by malware. Aborting import.")
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(checkpoint_path, map_location=device)
@ -1272,15 +1271,15 @@ def download_from_original_stable_diffusion_ckpt(
# only refiner xl has embedder and one text embedders
config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"
original_config_file = BytesIO(requests.get(config_url).content)
original_config_file = requests.get(config_url).text
original_config = OmegaConf.load(original_config_file)
if original_config["model"]["params"].get("use_ema") is not None:
extract_ema = original_config["model"]["params"]["use_ema"]
if original_config.model["params"].get("use_ema") is not None:
extract_ema = original_config.model["params"]["use_ema"]
if (
model_version in [BaseModelType.StableDiffusion2, BaseModelType.StableDiffusion1]
and original_config["model"]["params"].get("parameterization") == "v"
and original_config.model["params"].get("parameterization") == "v"
):
prediction_type = "v_prediction"
upcast_attention = True
@ -1312,11 +1311,11 @@ def download_from_original_stable_diffusion_ckpt(
num_in_channels = 4
if "unet_config" in original_config.model.params:
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
original_config.model["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
if (
"parameterization" in original_config["model"]["params"]
and original_config["model"]["params"]["parameterization"] == "v"
"parameterization" in original_config.model["params"]
and original_config.model["params"]["parameterization"] == "v"
):
if prediction_type is None:
# NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"`
@ -1437,7 +1436,7 @@ def download_from_original_stable_diffusion_ckpt(
if model_type == "FrozenOpenCLIPEmbedder":
config_name = "stabilityai/stable-diffusion-2"
config_kwargs = {"subfolder": "text_encoder"}
config_kwargs: Dict[str, Union[str, int]] = {"subfolder": "text_encoder"}
text_model = convert_open_clip_checkpoint(checkpoint, config_name, **config_kwargs)
tokenizer = CLIPTokenizer.from_pretrained(CONVERT_MODEL_ROOT / "stable-diffusion-2-clip", subfolder="tokenizer")
@ -1664,7 +1663,7 @@ def download_controlnet_from_original_ckpt(
# scan model
scan_result = scan_file_path(checkpoint_path)
if scan_result.infected_files != 0:
raise "The model {checkpoint_path} is potentially infected by malware. Aborting import."
raise Exception("The model {checkpoint_path} is potentially infected by malware. Aborting import.")
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(checkpoint_path, map_location=device)
@ -1685,7 +1684,7 @@ def download_controlnet_from_original_ckpt(
original_config = OmegaConf.load(original_config_file)
if num_in_channels is not None:
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
original_config.model["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
if "control_stage_config" not in original_config.model.params:
raise ValueError("`control_stage_config` not present in original config")
@ -1725,7 +1724,7 @@ def convert_ckpt_to_diffusers(
and in addition a path-like object indicating the location of the desired diffusers
model to be written.
"""
pipe = download_from_original_stable_diffusion_ckpt(checkpoint_path, **kwargs)
pipe = download_from_original_stable_diffusion_ckpt(str(checkpoint_path), **kwargs)
pipe.save_pretrained(
dump_path,
@ -1743,6 +1742,6 @@ def convert_controlnet_to_diffusers(
and in addition a path-like object indicating the location of the desired diffusers
model to be written.
"""
pipe = download_controlnet_from_original_ckpt(checkpoint_path, **kwargs)
pipe = download_controlnet_from_original_ckpt(str(checkpoint_path), **kwargs)
pipe.save_pretrained(dump_path, safe_serialization=True)

View File

@ -1,7 +1,5 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Abstract base class for a multithreaded model download queue.
"""
"""Abstract base class for a multithreaded model download queue."""
from abc import ABC, abstractmethod
from enum import Enum
@ -56,7 +54,7 @@ class DownloadJobBase(BaseModel):
priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
id: int = Field(description="Numeric ID of this job", default=-1) # default id is a placeholder
source: str = Field(description="URL or repo_id to download")
source: Union[str, Path] = Field(description="URL or repo_id to download")
destination: Path = Field(description="Destination of URL on local disk")
metadata: ModelSourceMetadata = Field(
description="Model metadata (source-specific)", default_factory=ModelSourceMetadata
@ -84,7 +82,8 @@ class DownloadJobBase(BaseModel):
def add_event_handler(self, handler: DownloadEventHandler):
"""Add an event handler to the end of the handlers list."""
self.event_handlers.append(handler)
if self.event_handlers is not None:
self.event_handlers.append(handler)
def clear_event_handlers(self):
"""Clear all event handlers."""
@ -226,9 +225,7 @@ class DownloadQueueBase(ABC):
@abstractmethod
def prune_jobs(self):
"""
Prune completed and errored queue items from the job list.
"""
"""Prune completed and errored queue items from the job list."""
pass
@abstractmethod

View File

@ -7,19 +7,18 @@ import shutil
import threading
import time
import traceback
from json import JSONDecodeError
from pathlib import Path
from queue import PriorityQueue
from typing import Dict, List, Optional, Set, Tuple, Union
import requests
from huggingface_hub import HfApi, hf_hub_url
from pydantic import Field, ValidationError, validator
from pydantic import Field, ValidationError, parse_obj_as, validator
from pydantic.networks import AnyHttpUrl
from requests import HTTPError
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.util import InvokeAILogger, Logger
from ..storage import DuplicateModelException
from . import HTTP_RE, REPO_ID_RE
@ -56,19 +55,20 @@ class DownloadJobRepoID(DownloadJobBase):
"""Download repo ids."""
variant: Optional[str] = Field(description="Variant, such as 'fp16', to download")
source: str = Field(description="URL or repo_id to download")
@validator("source")
@classmethod
def _validate_source(cls, v: str) -> str:
if not re.match(REPO_ID_RE, v):
raise ValidationError(f"{v} invalid repo_id")
raise ValidationError(f"{v} invalid repo_id", cls)
return v
class DownloadJobPath(DownloadJobBase):
"""Handle file paths."""
source: Path = Field(description="Path to a file or directory to install")
source: Union[str, Path] = Field(description="Path to a file or directory to install")
class DownloadQueue(DownloadQueueBase):
@ -77,8 +77,8 @@ class DownloadQueue(DownloadQueueBase):
_jobs: Dict[int, DownloadJobBase]
_worker_pool: Set[threading.Thread]
_queue: PriorityQueue
_lock: threading.Lock
_logger: InvokeAILogger
_lock: threading.RLock
_logger: Logger
_event_handlers: List[DownloadEventHandler] = Field(default_factory=list)
_next_job_id: int = 0
_sequence: int = 0 # This is for debugging and used to tag jobs in dequeueing order
@ -126,6 +126,7 @@ class DownloadQueue(DownloadQueueBase):
"""Create a download job and return its ID."""
kwargs = dict()
cls = DownloadJobBase
if Path(source).exists():
cls = DownloadJobPath
elif re.match(REPO_ID_RE, str(source)):
@ -137,7 +138,7 @@ class DownloadQueue(DownloadQueueBase):
raise NotImplementedError(f"Don't know what to do with this type of source: {source}")
job = cls(
source=source,
source=str(source),
destination=Path(destdir) / (filename or "."),
access_token=access_token,
event_handlers=(event_handlers or self._event_handlers),
@ -179,7 +180,7 @@ class DownloadQueue(DownloadQueueBase):
def list_jobs(self) -> List[DownloadJobBase]:
"""List all the jobs."""
return self._jobs.values()
return list(self._jobs.values())
def change_priority(self, job: DownloadJobBase, delta: int):
"""Change the priority of a job. Smaller priorities run first."""
@ -193,9 +194,7 @@ class DownloadQueue(DownloadQueueBase):
self._lock.release()
def prune_jobs(self):
"""
Prune completed and errored queue items from the job list.
"""
"""Prune completed and errored queue items from the job list."""
try:
to_delete = set()
self._lock.acquire()
@ -334,7 +333,7 @@ class DownloadQueue(DownloadQueueBase):
elif isinstance(job, DownloadJobPath):
self._download_path(job)
else:
raise NotImplementedError(f"Don't know what to do with this job: {job}")
raise NotImplementedError(f"Don't know what to do with this job: {job}, type={type(job)}")
if job.status == DownloadJobStatus.CANCELLED:
self._cleanup_cancelled_job(job)
@ -354,7 +353,7 @@ class DownloadQueue(DownloadQueueBase):
model = None
# a Civitai download URL
if match := re.match(CIVITAI_MODEL_DOWNLOAD, metadata_url):
if match := re.match(CIVITAI_MODEL_DOWNLOAD, str(metadata_url)):
version = match.group(1)
resp = self._requests.get(CIVITAI_VERSIONS_ENDPOINT + version).json()
metadata.thumbnail_url = metadata.thumbnail_url or resp["images"][0]["url"]
@ -364,16 +363,16 @@ class DownloadQueue(DownloadQueueBase):
metadata_url = CIVITAI_MODEL_PAGE + str(resp["modelId"]) + f"?modelVersionId={version}"
# a Civitai model page with the version
if match := re.match(CIVITAI_MODEL_PAGE_WITH_VERSION, metadata_url):
if match := re.match(CIVITAI_MODEL_PAGE_WITH_VERSION, str(metadata_url)):
model = match.group(1)
version = int(match.group(2))
# and without
elif match := re.match(CIVITAI_MODEL_PAGE + r"(\d+)", metadata_url):
elif match := re.match(CIVITAI_MODEL_PAGE + r"(\d+)", str(metadata_url)):
model = match.group(1)
version = None
if not model:
return url
return parse_obj_as(AnyHttpUrl, url)
if model:
resp = self._requests.get(CIVITAI_MODELS_ENDPOINT + str(model)).json()
@ -419,16 +418,19 @@ class DownloadQueue(DownloadQueueBase):
if job.destination.is_dir():
try:
file_name = re.search('filename="(.+)"', resp.headers["Content-Disposition"]).group(1)
file_name = ""
if match := re.search('filename="(.+)"', resp.headers["Content-Disposition"]):
file_name = match.group(1)
assert file_name != ""
self._validate_filename(
job.destination, file_name
job.destination.as_posix(), file_name
) # will raise a ValueError exception if file_name is suspicious
except ValueError:
self._logger.warning(
f"Invalid filename '{file_name}' returned by source {url}, using last component of URL instead"
)
file_name = os.path.basename(url)
except KeyError:
except (KeyError, AssertionError):
file_name = os.path.basename(url)
job.destination = job.destination / file_name
dest = job.destination
@ -537,12 +539,13 @@ class DownloadQueue(DownloadQueueBase):
self._update_job_status(job, DownloadJobStatus.RUNNING)
return
job.subqueue = self.__class__(
subqueue = self.__class__(
event_handlers=[subdownload_event],
requests_session=self._requests,
quiet=True,
)
try:
job = DownloadJobRepoID.parse_obj(job)
repo_id = job.source
variant = job.variant
if not job.metadata:
@ -550,12 +553,12 @@ class DownloadQueue(DownloadQueueBase):
urls_to_download = self._get_repo_info(repo_id, variant=variant, metadata=job.metadata)
if job.destination.name != Path(repo_id).name:
job.destination = job.destination / Path(repo_id).name
bytes_downloaded = dict()
bytes_downloaded: Dict[int, int] = dict()
job.total_bytes = 0
for url, subdir, file, size in urls_to_download:
job.total_bytes += size
job.subqueue.create_download_job(
subqueue.create_download_job(
source=url,
destdir=job.destination / subdir,
filename=file,
@ -569,6 +572,7 @@ class DownloadQueue(DownloadQueueBase):
self._update_job_status(job, DownloadJobStatus.ERROR)
self._logger.error(job.error)
finally:
job.subqueue = subqueue
job.subqueue.join()
if job.status == DownloadJobStatus.RUNNING:
self._update_job_status(job, DownloadJobStatus.COMPLETED)
@ -579,7 +583,7 @@ class DownloadQueue(DownloadQueueBase):
repo_id: str,
metadata: ModelSourceMetadata,
variant: Optional[str] = None,
) -> Tuple[List[Tuple[AnyHttpUrl, Path, Path]], ModelSourceMetadata]:
) -> List[Tuple[AnyHttpUrl, Path, Path, int]]:
"""
Given a repo_id and an optional variant, return list of URLs to download to get the model.
The metadata field will be updated with model metadata from HuggingFace.
@ -602,7 +606,7 @@ class DownloadQueue(DownloadQueueBase):
paths = [x for x in paths if Path(x).parent.as_posix() in submodels]
paths.insert(0, "model_index.json")
urls = [
(hf_hub_url(repo_id, filename=x.as_posix()), x.parent or Path("."), x.name, sizes[x.as_posix()])
(hf_hub_url(repo_id, filename=x.as_posix()), x.parent or Path("."), Path(x.name), sizes[x.as_posix()])
for x in self._select_variants(paths, variant)
]
if hasattr(model_info, "cardData"):
@ -614,7 +618,7 @@ class DownloadQueue(DownloadQueueBase):
def _select_variants(self, paths: List[str], variant: Optional[str] = None) -> Set[Path]:
"""Select the proper variant files from a list of HuggingFace repo_id paths."""
result = set()
basenames = dict()
basenames: Dict[Path, Path] = dict()
for p in paths:
path = Path(p)

View File

@ -57,9 +57,9 @@ class FastModelHash(object):
# only tally tensor files
if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
continue
path = Path(root) / file
path = (Path(root) / file).as_posix()
fast_hash = cls._hash_file(path)
components.update({str(path): fast_hash})
components.update({path: fast_hash})
# hash all the model hashes together, using alphabetic file order
md5 = hashlib.md5()

View File

@ -54,7 +54,7 @@ import tempfile
from abc import ABC, abstractmethod
from pathlib import Path
from shutil import move, rmtree
from typing import Any, Callable, Dict, List, Optional, Set, Union
from typing import Any, Callable, Dict, List, Optional, Set, Type, Union
from pydantic import Field
from pydantic.networks import AnyHttpUrl
@ -166,7 +166,7 @@ class ModelInstallBase(ABC):
pass
@abstractmethod
def install_path(self, model_path: Union[Path, str], info: Optional[ModelProbeInfo] = None) -> str:
def install_path(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None) -> str:
"""
Probe, register and install the model in the models directory.
@ -174,7 +174,7 @@ class ModelInstallBase(ABC):
the models directory handled by InvokeAI.
:param model_path: Filesystem Path to the model.
:param info: Optional ModelProbeInfo object. If not provided, model will be probed.
:param overrides: Dictionary of model probe info fields that, if present, override probed values.
:returns id: The string ID of the installed model.
"""
pass
@ -184,6 +184,7 @@ class ModelInstallBase(ABC):
self,
source: Union[str, Path, AnyHttpUrl],
inplace: bool = True,
priority: int = 10,
variant: Optional[str] = None,
probe_override: Optional[Dict[str, Any]] = None,
metadata: Optional[ModelSourceMetadata] = None,
@ -225,7 +226,7 @@ class ModelInstallBase(ABC):
pass
@abstractmethod
def wait_for_installs(self) -> Dict[str, str]:
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
"""
Wait for all pending installs to complete.
@ -293,7 +294,7 @@ class ModelInstallBase(ABC):
pass
@abstractmethod
def sync_model_path(self, key) -> Path:
def sync_model_path(self, key) -> ModelConfigBase:
"""
Move model into the location indicated by its basetype, type and name.
@ -324,11 +325,11 @@ class ModelInstall(ModelInstallBase):
_logger: Logger
_store: ModelConfigStore
_download_queue: DownloadQueueBase
_async_installs: Dict[str, str]
_installed: Set[Path] = Field(default=set)
_async_installs: Dict[Union[str, Path, AnyHttpUrl], Optional[str]]
_installed: Set[str] = Field(default=set)
_tmpdir: Optional[tempfile.TemporaryDirectory] # used for downloads
_legacy_configs = {
_legacy_configs: Dict[BaseModelType, Dict[ModelVariantType, Union[str, dict]]] = {
BaseModelType.StableDiffusion1: {
ModelVariantType.Normal: "v1-inference.yaml",
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
@ -357,13 +358,13 @@ class ModelInstall(ModelInstallBase):
config: Optional[InvokeAIAppConfig] = None,
logger: Optional[Logger] = None,
download: Optional[DownloadQueueBase] = None,
event_handlers: Optional[List[DownloadEventHandler]] = None,
event_handlers: List[DownloadEventHandler] = [],
): # noqa D107 - use base class docstrings
self._app_config = config or InvokeAIAppConfig.get_config()
self._logger = logger or InvokeAILogger.get_logger(config=self._app_config)
self._store = store or get_config_store(self._app_config.model_conf_path)
self._download_queue = download or DownloadQueue(config=self._app_config, event_handlers=event_handlers)
self._async_installs = dict()
self._async_installs: Dict[Union[str, Path, AnyHttpUrl], Union[str, None]] = dict()
self._installed = set()
self._tmpdir = None
@ -403,7 +404,8 @@ class ModelInstall(ModelInstallBase):
)
# add 'main' specific fields
if info.model_type == ModelType.Main:
registration_data.update(variant=info.variant_type)
if info.variant_type:
registration_data.update(variant=info.variant_type)
if info.format == ModelFormat.Checkpoint:
try:
config_file = self._legacy_configs[info.base_type][info.variant_type]
@ -416,7 +418,7 @@ class ModelInstall(ModelInstallBase):
)
config_file = config_file[SchedulerPredictionType.VPrediction]
registration_data.update(
config=Path(self._app_config.legacy_conf_dir, config_file).as_posix(),
config=Path(self._app_config.legacy_conf_dir, str(config_file)).as_posix(),
)
except KeyError as exc:
raise InvalidModelException(
@ -453,12 +455,12 @@ class ModelInstall(ModelInstallBase):
return old_path.replace(new_path)
def _probe_model(self, model_path: Union[Path, str], overrides: Optional[Dict[str, Any]] = None) -> ModelProbeInfo:
info: ModelProbeInfo = ModelProbe.probe(model_path)
info: ModelProbeInfo = ModelProbe.probe(Path(model_path))
if overrides: # used to override probe fields
for key, value in overrides.items():
try:
setattr(info, key, value) # skip validation errors
except:
except Exception:
pass
return info
@ -488,11 +490,11 @@ class ModelInstall(ModelInstallBase):
self,
source: Union[str, Path, AnyHttpUrl],
inplace: bool = True,
priority: int = 10,
variant: Optional[str] = None,
probe_override: Optional[Dict[str, Any]] = None,
metadata: Optional[ModelSourceMetadata] = None,
access_token: Optional[str] = None,
priority: Optional[int] = 10,
) -> DownloadJobBase: # noqa D102
queue = self._download_queue
@ -502,7 +504,8 @@ class ModelInstall(ModelInstallBase):
if inplace and Path(source).exists()
else self._complete_installation_handler
)
job.probe_override = probe_override
if isinstance(job, ModelInstallJob):
job.probe_override = probe_override
if metadata:
job.metadata = metadata
job.add_event_handler(handler)
@ -512,6 +515,7 @@ class ModelInstall(ModelInstallBase):
return job
def _complete_installation_handler(self, job: DownloadJobBase):
job = ModelInstallJob.parse_obj(job) # this upcast should succeed
if job.status == "completed":
self._logger.info(f"{job.source}: Download finished with status {job.status}. Installing.")
model_id = self.install_path(job.destination, job.probe_override)
@ -537,6 +541,7 @@ class ModelInstall(ModelInstallBase):
self._tmpdir = None
def _complete_registration_handler(self, job: DownloadJobBase):
job = ModelInstallJob.parse_obj(job) # upcast should succeed
if job.status == "completed":
self._logger.info(f"{job.source}: Installing in place.")
model_id = self.register_path(job.destination, job.probe_override)
@ -567,7 +572,7 @@ class ModelInstall(ModelInstallBase):
models_dir = self._app_config.models_path
if not old_path.is_relative_to(models_dir):
return old_path
return model
new_path = models_dir / model.base_model.value / model.model_type.value / model.name
self._logger.info(
@ -592,7 +597,6 @@ class ModelInstall(ModelInstallBase):
# we simply probe and register/install it. The job does not actually do anything, but we
# create one anyway in order to have similar behavior for local files, URLs and repo_ids.
if Path(source).exists(): # a path that is already on disk
source = Path(source)
destdir = source
return ModelInstallPathJob(source=source, destination=Path(destdir))
@ -600,6 +604,7 @@ class ModelInstall(ModelInstallBase):
models_dir = self._app_config.models_path
self._tmpdir = self._tmpdir or tempfile.TemporaryDirectory(dir=models_dir)
cls = ModelInstallJob
if re.match(REPO_ID_RE, str(source)):
cls = ModelInstallRepoIDJob
kwargs = dict(variant=variant)
@ -609,10 +614,15 @@ class ModelInstall(ModelInstallBase):
else:
raise ValueError(f"'{source}' is not recognized as a local file, directory, repo_id or URL")
return cls(
source=source, destination=Path(self._tmpdir.name), access_token=access_token, priority=priority, **kwargs
source=str(source),
destination=Path(self._tmpdir.name),
access_token=access_token,
priority=priority,
**kwargs,
)
def wait_for_installs(self) -> Dict[str, str]: # noqa D102
def wait_for_installs(self) -> Dict[Union[str, Path, AnyHttpUrl], Optional[str]]:
"""Pause until all installation jobs have completed."""
self._download_queue.join()
id_map = self._async_installs
self._async_installs = dict()
@ -634,9 +644,10 @@ class ModelInstall(ModelInstallBase):
dest_directory: Optional[Path] = None,
) -> ModelConfigBase:
"""
Convert a checkpoint file into a diffusers folder, deleting the cached
version and deleting the original checkpoint file if it is in the models
directory.
Convert a checkpoint file into a diffusers folder.
It will delete the cached version ans well as the
original checkpoint file if it is in the models directory.
:param key: Unique key of model.
:dest_directory: Optional place to put converted file. If not specified,
will be stored in the `models_dir`.
@ -668,7 +679,7 @@ class ModelInstall(ModelInstallBase):
update = info.dict()
update.pop("config")
update["model_format"] = "diffusers"
update["path"] = converted_model.location.as_posix()
update["path"] = converted_model.location
if dest_directory:
new_diffusers_path = Path(dest_directory) / info.name

View File

@ -121,11 +121,7 @@ class ModelLoad(ModelLoadBase):
_cache_keys: dict
_models_file: Path
def __init__(
self,
config: InvokeAIAppConfig,
event_handlers: Optional[List[DownloadEventHandler]] = None,
):
def __init__(self, config: InvokeAIAppConfig, event_handlers: List[DownloadEventHandler] = []):
"""
Initialize ModelLoad object.

View File

@ -12,7 +12,7 @@ from diffusers.models import UNet2DConditionModel
from safetensors.torch import load_file
from transformers import CLIPTextModel, CLIPTokenizer
from .models.lora import LoRAModel
from .models.lora import LoRALayerBase, LoRAModel, LoRAModelRaw
"""
loras = [
@ -87,7 +87,7 @@ class ModelPatcher:
def apply_lora_text_encoder(
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModel, float]],
loras: List[Tuple[LoRAModelRaw, float]],
):
with cls.apply_lora(text_encoder, loras, "lora_te_"):
yield
@ -97,7 +97,7 @@ class ModelPatcher:
def apply_sdxl_lora_text_encoder(
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModel, float]],
loras: List[Tuple[LoRAModelRaw, float]],
):
with cls.apply_lora(text_encoder, loras, "lora_te1_"):
yield
@ -107,7 +107,7 @@ class ModelPatcher:
def apply_sdxl_lora_text_encoder2(
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModel, float]],
loras: List[Tuple[LoRAModelRaw, float]],
):
with cls.apply_lora(text_encoder, loras, "lora_te2_"):
yield
@ -117,7 +117,7 @@ class ModelPatcher:
def apply_lora(
cls,
model: torch.nn.Module,
loras: List[Tuple[LoRAModel, float]],
loras: List[Tuple[LoRAModelRaw, float]],
prefix: str,
):
original_weights = dict()
@ -337,7 +337,7 @@ class ONNXModelPatcher:
def apply_lora(
cls,
model: IAIOnnxRuntimeModel,
loras: List[Tuple[LoRAModel, float]],
loras: List[Tuple[LoRAModelRaw, torch.Tensor]],
prefix: str,
):
from .models.base import IAIOnnxRuntimeModel
@ -348,7 +348,7 @@ class ONNXModelPatcher:
orig_weights = dict()
try:
blended_loras = dict()
blended_loras: Dict[str, torch.Tensor] = dict()
for lora, lora_weight in loras:
for layer_key, layer in lora.layers.items():

View File

@ -9,7 +9,7 @@ Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
import warnings
from enum import Enum
from pathlib import Path
from typing import List, Optional
from typing import List, Optional, Set
from diffusers import DiffusionPipeline
from diffusers import logging as dlogging
@ -17,7 +17,8 @@ from diffusers import logging as dlogging
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from . import ModelConfigBase, ModelConfigStore, ModelInstall, ModelType
from . import BaseModelType, ModelConfigBase, ModelConfigStore, ModelInstall, ModelType
from .config import MainConfig
class MergeInterpolationMethod(str, Enum):
@ -102,11 +103,11 @@ class ModelMerger(object):
**kwargs - the default DiffusionPipeline.get_config_dict kwargs:
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
"""
model_paths = list()
model_paths: List[Path] = list()
model_names = list()
config = self._config
store = self._store
base_models = set()
base_models: Set[BaseModelType] = set()
vae = None
assert (
@ -115,6 +116,7 @@ class ModelMerger(object):
for key in model_keys:
info = store.get_model(key)
assert isinstance(info, MainConfig)
model_names.append(info.name)
assert (
info.model_format == "diffusers"

View File

@ -33,7 +33,7 @@ class ModelProbeInfo(BaseModel):
base_type: BaseModelType
format: ModelFormat
hash: str
variant_type: Optional[ModelVariantType] = ModelVariantType("normal")
variant_type: ModelVariantType = ModelVariantType("normal")
prediction_type: Optional[SchedulerPredictionType] = SchedulerPredictionType("v_prediction")
upcast_attention: Optional[bool] = False
image_size: Optional[int] = None
@ -114,7 +114,7 @@ class ModelProbe(ModelProbeBase):
cls,
model_path: Path,
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
) -> Optional[ModelProbeInfo]:
) -> ModelProbeInfo:
"""Probe model."""
try:
model_type = (
@ -129,7 +129,7 @@ class ModelProbe(ModelProbeBase):
probe_class = cls.PROBES[format_type].get(model_type)
if not probe_class:
return None
raise InvalidModelException(f"Unable to determine model type for {model_path}")
probe = probe_class(model_path, prediction_type_helper)
@ -160,7 +160,7 @@ class ModelProbe(ModelProbeBase):
else 512,
)
except Exception:
raise
raise InvalidModelException(f"Unable to determine model type for {model_path}")
return model_info

View File

@ -27,7 +27,7 @@ from typing import Callable, Optional, Set, Union
from pydantic import BaseModel, Field
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.util import InvokeAILogger, Logger
default_logger = InvokeAILogger.get_logger()
@ -56,7 +56,7 @@ class ModelSearchBase(ABC, BaseModel):
on_model_found : Optional[Callable[[Path], bool]] = Field(default=None, description="Called when a model is found.") # noqa E221
on_search_completed : Optional[Callable[[Set[Path]], None]] = Field(default=None, description="Called when search is complete.") # noqa E221
stats : SearchStats = Field(default_factory=SearchStats, description="Summary statistics after search") # noqa E221
logger : InvokeAILogger = Field(default=default_logger, description="InvokeAILogger instance.") # noqa E221
logger : Logger = Field(default=default_logger, description="Logger instance.") # noqa E221
# fmt: on
class Config:
@ -143,7 +143,7 @@ class ModelSearch(ModelSearchBase):
self.on_search_completed(self._models_found)
def search(self, directory: Union[Path, str]) -> Set[Path]:
self._directory = directory
self._directory = Path(directory)
self.stats = SearchStats() # zero out
self.search_started() # This will initialize _models_found to empty
self._walk_directory(directory)
@ -155,7 +155,7 @@ class ModelSearch(ModelSearchBase):
# don't descend into directories that start with a "."
# to avoid the Mac .DS_STORE issue.
if str(Path(root).name).startswith("."):
self._pruned_paths.add(root)
self._pruned_paths.add(Path(root))
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
continue

View File

@ -1,6 +1,4 @@
"""
Initialization file for invokeai.backend.model_manager.storage
"""
"""Initialization file for invokeai.backend.model_manager.storage."""
import pathlib
from .base import ( # noqa F401

View File

@ -35,13 +35,11 @@ class ModelConfigStore(ABC):
@property
@abstractmethod
def version(self) -> str:
"""
Return the config file/database schema version.
"""
"""Return the config file/database schema version."""
pass
@abstractmethod
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> None:
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> None:
"""
Add a model to the database.
@ -65,7 +63,7 @@ class ModelConfigStore(ABC):
pass
@abstractmethod
def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
def update_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
"""
Update the model, returning the updated version.
@ -96,7 +94,7 @@ class ModelConfigStore(ABC):
pass
@abstractmethod
def search_by_tag(self, tags: Set[str]) -> List[ModelConfigBase]:
def search_by_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
"""
Return models containing all of the listed tags.
@ -108,10 +106,8 @@ class ModelConfigStore(ABC):
def search_by_path(
self,
path: Union[str, Path],
) -> Optional[ModelConfigBase]:
"""
Return the model having the indicated path.
"""
) -> Optional[AnyModelConfig]:
"""Return the model having the indicated path."""
pass
@abstractmethod
@ -120,7 +116,7 @@ class ModelConfigStore(ABC):
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
) -> List[ModelConfigBase]:
) -> List[AnyModelConfig]:
"""
Return models matching name, base and/or type.
@ -133,8 +129,6 @@ class ModelConfigStore(ABC):
"""
pass
def all_models(self) -> List[ModelConfigBase]:
"""
Return all the model configs in the database.
"""
def all_models(self) -> List[AnyModelConfig]:
"""Return all the model configs in the database."""
return self.search_by_name()

View File

@ -52,14 +52,14 @@ def migrate_models_store(config: InvokeAIAppConfig):
except Exception as excp:
print(str(excp))
model_info = store.get_model(new_key)
if vae := stanza.get("vae") and isinstance(model_info, MainConfig):
model_info.vae = (app_config.models_path / vae).as_posix()
if model_config := stanza.get("config") and isinstance(model_info, MainCheckpointConfig):
model_info.config = (app_config.root_path / model_config).as_posix()
model_info.description = stanza.get("description")
store.update_model(new_key, model_info)
store.update_model(new_key, model_info)
if new_key != "<NOKEY>":
model_info = store.get_model(new_key)
if (vae := stanza.get("vae")) and isinstance(model_info, MainConfig):
model_info.vae = (app_config.models_path / vae).as_posix()
if (model_config := stanza.get("config")) and isinstance(model_info, MainCheckpointConfig):
model_info.config = (app_config.root_path / model_config).as_posix()
model_info.description = stanza.get("description")
store.update_model(new_key, model_info)
logger.info(f"Original version of models config file saved as {str(old_file) + '.orig'}")
shutil.move(old_file, str(old_file) + ".orig")

View File

@ -398,14 +398,14 @@ class ModelConfigStoreSQL(ModelConfigStore):
self._lock.release()
return count > 0
def search_by_tag(self, tags: Set[str]) -> List[ModelConfigBase]:
def search_by_tag(self, tags: Set[str]) -> List[AnyModelConfig]:
"""Return models containing all of the listed tags."""
# rather than create a hairy SQL cross-product, we intersect
# tag results in a stepwise fashion at the python level.
results = []
try:
self._lock.acquire()
matches = set()
matches: Set[str] = set()
for tag in tags:
self._cursor.execute(
"""--sql
@ -438,7 +438,7 @@ class ModelConfigStoreSQL(ModelConfigStore):
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
) -> List[ModelConfigBase]:
) -> List[AnyModelConfig]:
"""
Return models matching name, base and/or type.
@ -479,7 +479,5 @@ class ModelConfigStoreSQL(ModelConfigStore):
return results
def search_by_path(self, path: Union[str, Path]) -> Optional[ModelConfigBase]:
"""
Return the model with the indicated path, or None..
"""
"""Return the model with the indicated path, or None."""
raise NotImplementedError("search_by_path not implemented in storage.sql")

View File

@ -48,7 +48,6 @@ from typing import List, Optional, Set, Union
import yaml
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from omegaconf.listconfig import ListConfig
from ..config import AnyModelConfig, BaseModelType, ModelConfigBase, ModelConfigFactory, ModelType
from .base import (
@ -64,7 +63,7 @@ class ModelConfigStoreYAML(ModelConfigStore):
"""Implementation of the ModelConfigStore ABC using a YAML file."""
_filename: Path
_config: Union[DictConfig, ListConfig]
_config: DictConfig
_lock: threading.RLock
def __init__(self, config_file: Path):
@ -74,7 +73,9 @@ class ModelConfigStoreYAML(ModelConfigStore):
self._lock = threading.RLock()
if not self._filename.exists():
self._initialize_yaml()
self._config = OmegaConf.load(self._filename)
config = OmegaConf.load(self._filename)
assert isinstance(config, DictConfig)
self._config = config
if str(self.version) != CONFIG_FILE_VERSION:
raise ConfigFileVersionMismatchException
@ -101,7 +102,7 @@ class ModelConfigStoreYAML(ModelConfigStore):
@property
def version(self) -> str:
"""Return version of this config file/database."""
return self._config["__metadata__"].get("version")
return self._config.__metadata__.get("version")
def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> None:
"""

View File

@ -149,7 +149,7 @@ def _fast_safetensors_reader(path: str):
def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
if str(path).endswith(".safetensors"):
try:
checkpoint = _fast_safetensors_reader(path)
checkpoint = _fast_safetensors_reader(str(path))
except Exception:
# TODO: create issue for support "meta"?
checkpoint = safetensors.torch.load_file(path, device="cpu")

View File

@ -294,7 +294,7 @@ class InvokeAILegacyLogFormatter(InvokeAIFormatter):
}
def log_fmt(self, levelno: int) -> str:
return self.FORMATS.get(levelno)
return self.FORMATS[levelno]
class InvokeAIPlainLogFormatter(InvokeAIFormatter):
@ -333,7 +333,7 @@ class InvokeAIColorLogFormatter(InvokeAIFormatter):
}
def log_fmt(self, levelno: int) -> str:
return self.FORMATS.get(levelno)
return self.FORMATS[levelno]
LOG_FORMATTERS = {

View File

@ -104,114 +104,6 @@ def get_obj_from_str(string, reload=False):
return getattr(importlib.import_module(module, package=None), cls)
# DEAD CODE?
def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
# create dummy dataset instance
# run prefetching
if idx_to_fn:
res = func(data, worker_id=idx)
else:
res = func(data)
Q.put([idx, res])
Q.put("Done")
# DEAD CODE?
def parallel_data_prefetch(
func: callable,
data,
n_proc,
target_data_type="ndarray",
cpu_intensive=True,
use_worker_id=False,
):
# if target_data_type not in ["ndarray", "list"]:
# raise ValueError(
# "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
# )
if isinstance(data, np.ndarray) and target_data_type == "list":
raise ValueError("list expected but function got ndarray.")
elif isinstance(data, abc.Iterable):
if isinstance(data, dict):
logger.warning(
'"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
)
data = list(data.values())
if target_data_type == "ndarray":
data = np.asarray(data)
else:
data = list(data)
else:
raise TypeError(
f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
)
if cpu_intensive:
Q = mp.Queue(1000)
proc = mp.Process
else:
Q = Queue(1000)
proc = Thread
# spawn processes
if target_data_type == "ndarray":
arguments = [[func, Q, part, i, use_worker_id] for i, part in enumerate(np.array_split(data, n_proc))]
else:
step = int(len(data) / n_proc + 1) if len(data) % n_proc != 0 else int(len(data) / n_proc)
arguments = [
[func, Q, part, i, use_worker_id]
for i, part in enumerate([data[i : i + step] for i in range(0, len(data), step)])
]
processes = []
for i in range(n_proc):
p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
processes += [p]
# start processes
logger.info("Start prefetching...")
import time
start = time.time()
gather_res = [[] for _ in range(n_proc)]
try:
for p in processes:
p.start()
k = 0
while k < n_proc:
# get result
res = Q.get()
if res == "Done":
k += 1
else:
gather_res[res[0]] = res[1]
except Exception as e:
logger.error("Exception: ", e)
for p in processes:
p.terminate()
raise e
finally:
for p in processes:
p.join()
logger.info(f"Prefetching complete. [{time.time() - start} sec.]")
if target_data_type == "ndarray":
if not isinstance(gather_res[0], np.ndarray):
return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
# order outputs
return np.concatenate(gather_res, axis=0)
elif target_data_type == "list":
out = []
for r in gather_res:
out.extend(r)
return out
else:
return gather_res
def rand_perlin_2d(shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
delta = (res[0] / shape[0], res[1] / shape[1])
d = (shape[0] // res[0], shape[1] // res[1])