Merge branch 'main' into fix/sdxl_controlnet

This commit is contained in:
Lincoln Stein
2023-08-17 15:40:51 -04:00
committed by GitHub
238 changed files with 11814 additions and 5370 deletions

View File

@ -21,12 +21,12 @@ import os
import sys
import hashlib
from contextlib import suppress
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Union, types, Optional, Type, Any
import torch
import logging
import invokeai.backend.util.logging as logger
from .models import BaseModelType, ModelType, SubModelType, ModelBase
@ -41,6 +41,18 @@ DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
GIG = 1073741824
@dataclass
class CacheStats(object):
hits: int = 0 # cache hits
misses: int = 0 # cache misses
high_watermark: int = 0 # amount of cache used
in_cache: int = 0 # number of models in cache
cleared: int = 0 # number of models cleared to make space
cache_size: int = 0 # total size of cache
# {submodel_key => size}
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
class ModelLocker(object):
"Forward declaration"
pass
@ -115,6 +127,9 @@ class ModelCache(object):
self.sha_chunksize = sha_chunksize
self.logger = logger
# used for stats collection
self.stats = None
self._cached_models = dict()
self._cache_stack = list()
@ -181,13 +196,14 @@ class ModelCache(object):
model_type=model_type,
submodel_type=submodel,
)
# TODO: lock for no copies on simultaneous calls?
cache_entry = self._cached_models.get(key, None)
if cache_entry is None:
self.logger.info(
f"Loading model {model_path}, type {base_model.value}:{model_type.value}{':'+submodel.value if submodel else ''}"
)
if self.stats:
self.stats.misses += 1
# this will remove older cached models until
# there is sufficient room to load the requested model
@ -201,6 +217,17 @@ class ModelCache(object):
cache_entry = _CacheRecord(self, model, mem_used)
self._cached_models[key] = cache_entry
else:
if self.stats:
self.stats.hits += 1
if self.stats:
self.stats.cache_size = self.max_cache_size * GIG
self.stats.high_watermark = max(self.stats.high_watermark, self._cache_size())
self.stats.in_cache = len(self._cached_models)
self.stats.loaded_model_sizes[key] = max(
self.stats.loaded_model_sizes.get(key, 0), model_info.get_size(submodel)
)
with suppress(Exception):
self._cache_stack.remove(key)
@ -280,14 +307,14 @@ class ModelCache(object):
"""
Given the HF repo id or path to a model on disk, returns a unique
hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs
:param model_path: Path to model file/directory on disk.
"""
return self._local_model_hash(model_path)
def cache_size(self) -> float:
"Return the current size of the cache, in GB"
current_cache_size = sum([m.size for m in self._cached_models.values()])
return current_cache_size / GIG
"""Return the current size of the cache, in GB."""
return self._cache_size() / GIG
def _has_cuda(self) -> bool:
return self.execution_device.type == "cuda"
@ -310,12 +337,15 @@ class ModelCache(object):
f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ = {cached_models}/{loaded_models}/{locked_models}"
)
def _cache_size(self) -> int:
return sum([m.size for m in self._cached_models.values()])
def _make_cache_room(self, model_size):
# calculate how much memory this model will require
# multiplier = 2 if self.precision==torch.float32 else 1
bytes_needed = model_size
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
current_size = sum([m.size for m in self._cached_models.values()])
current_size = self._cache_size()
if current_size + bytes_needed > maximum_size:
self.logger.debug(
@ -364,6 +394,8 @@ class ModelCache(object):
f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
)
current_size -= cache_entry.size
if self.stats:
self.stats.cleared += 1
del self._cache_stack[pos]
del self._cached_models[model_key]
del cache_entry

View File

@ -481,9 +481,19 @@ class ControlNetFolderProbe(FolderProbeBase):
with open(config_file, "r") as file:
config = json.load(file)
# no obvious way to distinguish between sd2-base and sd2-768
return (
BaseModelType.StableDiffusion1 if config["cross_attention_dim"] == 768 else BaseModelType.StableDiffusion2
dimension = config["cross_attention_dim"]
base_model = (
BaseModelType.StableDiffusion1
if dimension == 768
else BaseModelType.StableDiffusion2
if dimension == 1024
else BaseModelType.StableDiffusionXL
if dimension == 2048
else None
)
if not base_model:
raise InvalidModelException(f"Unable to determine model base for {self.folder_path}")
return base_model
class LoRAFolderProbe(FolderProbeBase):