mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
mostly ported to new manager API; needs testing
This commit is contained in:
parent
af8c7c7d29
commit
e0214a32bc
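
For orientation before the hunks: the heart of this commit is that model manager lookups now hand back an info object whose fields are read as attributes and whose .context is entered as a context manager to obtain the loaded pipeline, replacing the old dict-style access. A minimal sketch of that usage pattern, assembled from the hunks below (the wrapper function and the model name are illustrative, not part of the commit):

    # Sketch only: the access pattern the port moves to.
    def run_pipeline(model_manager, model_name: str):
        model_info = model_manager.get_model(model_name)

        # old style, removed below: model_info['model_name'], model_info['hash'], model_info['model']
        name = model_info.name
        digest = model_info.hash

        # .context is a ModelLocker: entering it moves the model onto the execution
        # device and bumps a per-key lock count; leaving it releases the lock again.
        with model_info.context as model:
            return name, digest, model
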
@@ -180,9 +180,10 @@ class TextToLatentsInvocation(BaseInvocation):

 def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
 model_info = choose_model(model_manager, self.model)
-model_name = model_info['model_name']
-model_hash = model_info['hash']
-model: StableDiffusionGeneratorPipeline = model_info['model']
+model_name = model_info.name
+model_hash = model_info.hash
+model_ctx: StableDiffusionGeneratorPipeline = model_info.context
+with model_ctx as model:
 model.scheduler = get_scheduler(
 model=model,
 scheduler_name=self.scheduler

@@ -200,11 +201,12 @@ class TextToLatentsInvocation(BaseInvocation):
 self.seamless_axes
 )

-return model
+return model_ctx


 def get_conditioning_data(self, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
 uc, c, extra_conditioning_info = get_uc_and_c_and_ec(self.prompt, model=model)
+print(f'DEBUG: uc.dtype={uc.dtype}, c.dtype={c.dtype}')
 conditioning_data = ConditioningData(
 uc,
 c,

@@ -230,11 +232,10 @@ class TextToLatentsInvocation(BaseInvocation):
 def step_callback(state: PipelineIntermediateState):
 self.dispatch_progress(context, source_node_id, state)

-model = self.get_model(context.services.model_manager)
+with self.get_model(context.services.model_manager) as model:
 conditioning_data = self.get_conditioning_data(model)

 # TODO: Verify the noise is the right size

 result_latents, result_attention_map_saver = model.latents_from_embeddings(
 latents=torch.zeros_like(noise, dtype=torch_dtype(model.device)),
 noise=noise,
@@ -284,7 +285,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
 def step_callback(state: PipelineIntermediateState):
 self.dispatch_progress(context, source_node_id, state)

-model = self.get_model(context.services.model_manager)
+with self.get_model(context.services.model_manager) as model:
 conditioning_data = self.get_conditioning_data(model)

 # TODO: Verify the noise is the right size

@@ -7,7 +7,7 @@ def choose_model(model_manager: ModelManager, model_name: str):
 if model_manager.valid_model(model_name):
 model = model_manager.get_model(model_name)
 else:
-model = model_manager.get_model()
-logger.warning(f"{model_name}' is not a valid model name. Using default model \'{model['model_name']}\' instead.")
+model = model_manager.get_model(model_manager.default_model())
+logger.warning(f"'{model_name}' is not a valid model name. Using default model \'{model.name}\' instead.")

 return model
@@ -47,22 +47,21 @@ def get_model_manager(config: Args, logger: types.ModuleType) -> ModelManager:
 else:
 embedding_path = None

-# migrate legacy models
-ModelManager.migrate_models()

 # creating the model manager
 try:
 device = torch.device(choose_torch_device())
-precision = 'float16' if config.precision=='float16' \
-else 'float32' if config.precision=='float32' \
-else choose_precision(device)
+if config.precision=="auto":
+precision = choose_precision(device)
+dtype = torch.float32 if precision=='float32' \
+else torch.float16

 model_manager = ModelManager(
-OmegaConf.load(config.conf),
-precision=precision,
+config.conf,
+precision=dtype,
 device_type=device,
 max_loaded_models=config.max_loaded_models,
-embedding_path = Path(embedding_path),
+# temporarily disabled until model manager stabilizes
+# embedding_path = Path(embedding_path),
 logger = logger,
 )
 except (FileNotFoundError, TypeError, AssertionError) as e:
@@ -10,7 +10,7 @@ from .generator import (
 Img2Img,
 Inpaint
 )
-from .model_management import ModelManager
+from .model_management import ModelManager, ModelCache, ModelStatus, SDModelType
 from .safety_checker import SafetyChecker
 from .args import Args
 from .globals import Globals

@@ -37,7 +37,7 @@ from .safety_checker import SafetyChecker
 from .prompting import get_uc_and_c_and_ec
 from .prompting.conditioning import log_tokenization
 from .stable_diffusion import HuggingFaceConceptsLibrary
-from .util import choose_precision, choose_torch_device
+from .util import choose_precision, choose_torch_device, torch_dtype

 def fix_func(orig):
 if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():

@@ -50,7 +50,6 @@ def fix_func(orig):
 return new_func
 return orig


 torch.rand = fix_func(torch.rand)
 torch.rand_like = fix_func(torch.rand_like)
 torch.randn = fix_func(torch.randn)
@@ -156,7 +155,6 @@ class Generate:
 weights=None,
 config=None,
 ):
-mconfig = OmegaConf.load(conf)
 self.height = None
 self.width = None
 self.model_manager = None

@@ -171,7 +169,7 @@ class Generate:
 self.seamless_axes = {"x", "y"}
 self.hires_fix = False
 self.embedding_path = embedding_path
-self.model = None # empty for now
+self.model_context = None # empty for now
 self.model_hash = None
 self.sampler = None
 self.device = None

@@ -219,12 +217,12 @@ class Generate:

 # model caching system for fast switching
 self.model_manager = ModelManager(
-mconfig,
+conf,
 self.device,
-self.precision,
+torch_dtype(self.device),
 max_loaded_models=max_loaded_models,
 sequential_offload=self.free_gpu_mem,
-embedding_path=Path(self.embedding_path),
+# embedding_path=Path(self.embedding_path),
 )
 # don't accept invalid models
 fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME
@@ -418,13 +416,14 @@ class Generate:
 with_variations = [] if with_variations is None else with_variations

 # will instantiate the model or return it from cache
-model = self.set_model(self.model_name)
+model_context = self.set_model(self.model_name)

 # self.width and self.height are set by set_model()
 # to the width and height of the image training set
 width = width or self.width
 height = height or self.height

+with model_context as model:
 if isinstance(model, DiffusionPipeline):
 configure_model_padding(model.unet, seamless, seamless_axes)
 configure_model_padding(model.vae, seamless, seamless_axes)

@@ -461,13 +460,13 @@ class Generate:

 if sampler_name and (sampler_name != self.sampler_name):
 self.sampler_name = sampler_name
-self._set_scheduler()
+self._set_scheduler(model)

 # apply the concepts library to the prompt
 prompt = self.huggingface_concepts_library.replace_concepts_with_triggers(
 prompt,
 lambda concepts: self.load_huggingface_concepts(concepts),
-self.model.textual_inversion_manager.get_all_trigger_strings(),
+model.textual_inversion_manager.get_all_trigger_strings(),
 )

 tic = time.time()

@@ -479,7 +478,7 @@ class Generate:
 try:
 uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
 prompt,
-model=self.model,
+model=model,
 skip_normalize_legacy_blend=skip_normalize,
 log_tokens=self.log_tokenization,
 )
@@ -662,9 +661,10 @@ class Generate:

 # used by multiple postfixers
 # todo: cross-attention control
+with self.model_context as model:
 uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
 prompt,
-model=self.model,
+model=model,
 skip_normalize_legacy_blend=opt.skip_normalize,
 log_tokens=log_tokenization,
 )

@@ -852,7 +852,8 @@ class Generate:
 cn = class_name
 module = importlib.import_module(mn)
 constructor = getattr(module, cn)
-return constructor(self.model, self.precision)
+with self.model_context as model:
+return constructor(model, self.precision)

 def load_model(self):
 """
@@ -869,8 +870,8 @@ class Generate:
 If the model fails to load for some reason, will attempt to load the previously-
 loaded model (if any). If that fallback fails, will raise an AssertionError
 """
-if self.model_name == model_name and self.model is not None:
-return self.model
+if self.model_name == model_name and self.model_context is not None:
+return self.model_context

 previous_model_name = self.model_name

@@ -881,11 +882,9 @@ class Generate:
 f'** "{model_name}" is not a known model name. Cannot change.'
 )

-cache.print_vram_usage()

 # have to get rid of all references to model in order
 # to free it from GPU memory
-self.model = None
+self.model_context = None
 self.sampler = None
 self.generators = {}
 gc.collect()
@@ -902,29 +901,33 @@ class Generate:
 raise e
 model_name = previous_model_name

-self.model = model_data["model"]
-self.width = model_data["width"]
-self.height = model_data["height"]
-self.model_hash = model_data["hash"]
+self.model_context = model_data.context
+self.width = 512
+self.height = 512
+self.model_hash = model_data.hash

 # uncache generators so they pick up new models
 self.generators = {}

 set_seed(random.randrange(0, np.iinfo(np.uint32).max))
 self.model_name = model_name
-self._set_scheduler() # requires self.model_name to be set first
-return self.model
+with self.model_context as model:
+self._set_scheduler(model) # requires self.model_name to be set first
+return self.model_context

 def load_huggingface_concepts(self, concepts: list[str]):
-self.model.textual_inversion_manager.load_huggingface_concepts(concepts)
+with self.model_context as model:
+model.textual_inversion_manager.load_huggingface_concepts(concepts)

 @property
 def huggingface_concepts_library(self) -> HuggingFaceConceptsLibrary:
-return self.model.textual_inversion_manager.hf_concepts_library
+with self.model_context as model:
+return model.textual_inversion_manager.hf_concepts_library

 @property
 def embedding_trigger_strings(self) -> List[str]:
-return self.model.textual_inversion_manager.get_all_trigger_strings()
+with self.model_context as model:
+return model.textual_inversion_manager.get_all_trigger_strings()

 def correct_colors(self, image_list, reference_image_path, image_callback=None):
 reference_image = Image.open(reference_image_path)
@@ -1044,8 +1047,8 @@ class Generate:
 def is_legacy_model(self, model_name) -> bool:
 return self.model_manager.is_legacy(model_name)

-def _set_scheduler(self):
-default = self.model.scheduler
+def _set_scheduler(self,model):
+default = model.scheduler

 # See https://github.com/huggingface/diffusers/issues/277#issuecomment-1371428672
 scheduler_map = dict(

@@ -1069,7 +1072,7 @@ class Generate:
 msg = (
 f"Setting Sampler to {self.sampler_name} ({sampler_class.__name__})"
 )
-self.sampler = sampler_class.from_config(self.model.scheduler.config)
+self.sampler = sampler_class.from_config(model.scheduler.config)
 else:
 msg = (
 f" Unsupported Sampler: {self.sampler_name} "+
@@ -123,9 +123,9 @@ class InvokeAIGenerator(metaclass=ABCMeta):
 generator_args.update(keyword_args)

 model_info = self.model_info
-model_name = model_info['model_name']
-model:StableDiffusionGeneratorPipeline = model_info['model']
-model_hash = model_info['hash']
+model_name = model_info.name
+model_hash = model_info.hash
+with model_info.context as model:
 scheduler: Scheduler = self.get_scheduler(
 model=model,
 scheduler_name=generator_args.get('scheduler')

@@ -275,7 +275,6 @@ class Embiggen(Txt2Img):
 from .embiggen import Embiggen
 return Embiggen


 class Generator:
 downsampling_factor: int
 latent_channels: int
@@ -2,4 +2,4 @@
 Initialization file for invokeai.backend.model_management
 """
 from .model_manager import ModelManager
-from .model_cache import ModelCache, ModelStatus
+from .model_cache import ModelCache, ModelStatus, SDModelType

@@ -78,6 +78,10 @@ class UnscannableModelException(Exception):
 "Raised when picklescan is unable to scan a legacy model file"
 pass

+class ModelLocker(object):
+"Forward declaration"
+pass
+
 class ModelCache(object):
 def __init__(
 self,
@@ -112,8 +116,6 @@ class ModelCache(object):
 self.loaded_models: set = set() # set of model keys loaded in GPU
 self.locked_models: Counter = Counter() # set of model keys locked in GPU


-@contextlib.contextmanager
 def get_model(
 self,
 repo_id_or_path: Union[str,Path],

@@ -124,7 +126,7 @@ class ModelCache(object):
 legacy_info: LegacyInfo=None,
 attach_model_part: Tuple[SDModelType, str] = (None,None),
 gpu_load: bool=True,
-)->Generator[ModelClass, None, None]:
+)->ModelLocker: # ?? what does it return
 '''
 Load and return a HuggingFace model wrapped in a context manager generator, with RAM caching.
 Use like this:
@@ -188,29 +190,45 @@ class ModelCache(object):
 if submodel:
 model = getattr(model, submodel.name)

-if gpu_load and hasattr(model,'to'):
-try:
-self.loaded_models.add(key)
-self.locked_models[key] += 1
-if self.lazy_offloading:
-self._offload_unlocked_models()
-self.logger.debug(f'Loading {key} into {self.execution_device}')
-model.to(self.execution_device) # move into GPU
-self._print_cuda_stats()
-yield model
-finally:
-self.locked_models[key] -= 1
-if not self.lazy_offloading:
-self._offload_unlocked_models()
-self._print_cuda_stats()
+return self.ModelLocker(self, key, model, gpu_load)
+
+class ModelLocker(object):
+def __init__(self, cache, key, model, gpu_load):
+self.gpu_load = gpu_load
+self.cache = cache
+self.key = key
+# This will keep a copy of the model in RAM until the locker
+# is garbage collected. Needs testing!
+self.model = model
+
+def __enter__(self)->ModelClass:
+cache = self.cache
+key = self.key
+model = self.model
+if self.gpu_load and hasattr(model,'to'):
+cache.loaded_models.add(key)
+cache.locked_models[key] += 1
+if cache.lazy_offloading:
+cache._offload_unlocked_models()
+cache.logger.debug(f'Loading {key} into {cache.execution_device}')
+model.to(cache.execution_device) # move into GPU
+cache._print_cuda_stats()
 else:
 # in the event that the caller wants the model in RAM, we
 # move it into CPU if it is in GPU and not locked
-if hasattr(model,'to') and (key in self.loaded_models
-and self.locked_models[key] == 0):
-model.to(self.storage_device)
-self.loaded_models.remove(key)
-yield model
+if hasattr(model,'to') and (key in cache.loaded_models
+and cache.locked_models[key] == 0):
+model.to(cache.storage_device)
+cache.loaded_models.remove(key)
+return model

+def __exit__(self, type, value, traceback):
+key = self.key
+cache = self.cache
+cache.locked_models[key] -= 1
+if not cache.lazy_offloading:
+cache._offload_unlocked_models()
+cache._print_cuda_stats()

 def attach_part(self,
 diffusers_model: StableDiffusionPipeline,
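
The docstring's "Use like this:" example is cut off in this view, so here is a hedged sketch of how the reworked get_model() appears to be meant to be used, based only on the __enter__/__exit__ logic above (the wrapper function and run_pipeline placeholder are illustrative, not part of the commit):

    def generate_with_cache(cache: ModelCache, repo_id_or_path: str):
        # get_model() now returns a ModelLocker instead of acting as a
        # @contextlib.contextmanager generator.
        locker = cache.get_model(repo_id_or_path)
        with locker as model:
            # __enter__ records the key in loaded_models, increments locked_models[key],
            # optionally offloads unlocked models, and moves the model to the execution device.
            result = run_pipeline(model)  # placeholder for the caller's real work
        # __exit__ decrements locked_models[key]; unlocked models can then be offloaded.
        return result

Because locked_models is a Counter rather than a set, the lock is effectively reference-counted, so nested or overlapping with-blocks on the same model should compose; that is the behaviour the "needs testing" caveat in the commit message refers to rather than anything verified here.
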
@@ -381,8 +399,9 @@ class ModelCache(object):
 revisions = [revision] if revision \
 else ['fp16','main'] if self.precision==torch.float16 \
 else ['main']
-extra_args = {'precision': self.precision} \
-if model_class in DiffusionClasses \
+extra_args = {'torch_dtype': self.precision,
+'safety_checker': None}\
+if model_class in DiffusionClasses\
 else {}

 # silence transformer and diffuser warnings
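
For context, the keys assembled into extra_args here line up with the standard diffusers loading call roughly as follows; the call site itself is outside this hunk and the repo id below is illustrative only:

    import torch
    from diffusers import StableDiffusionPipeline

    # Roughly what the new extra_args amounts to when self.precision is torch.float16
    # and model_class is a diffusers pipeline (sketch, not the actual call site).
    pipeline = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",  # illustrative repo id
        revision="fp16",
        torch_dtype=torch.float16,         # the old code passed this value under a 'precision' key
        safety_checker=None,               # newly disabled by default for DiffusionClasses
    )
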
@@ -69,7 +69,7 @@ class SDModelInfo():
 revision: str = None
 _cache: ModelCache = None

+@property
 def status(self)->ModelStatus:
 '''Return load status of this model as a model_cache.ModelStatus enum'''
 if not self._cache:
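
Pulling together the attribute accesses introduced across this commit (model_info.name, .hash, .context, plus the revision, _cache, and status members above), the info object appears to expose roughly the shape below; the sketch lists only fields this diff exercises and is not the actual class definition:

    from dataclasses import dataclass

    @dataclass
    class SDModelInfoSketch:          # hypothetical stand-in for SDModelInfo
        context: object               # ModelLocker; used as `with info.context as model:`
        name: str = None              # replaces model_info['model_name']
        hash: str = None              # replaces model_info['hash']
        revision: str = None          # shown verbatim in the hunk above
        _cache: object = None         # backing ModelCache, consulted by the status property
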
@@ -106,7 +106,7 @@ class ModelManager(object):
 config_path: Path,
 device_type: torch.device = CUDA_DEVICE,
 precision: torch.dtype = torch.float16,
-max_models=DEFAULT_MAX_MODELS,
+max_loaded_models=DEFAULT_MAX_MODELS,
 sequential_offload=False,
 logger: types.ModuleType = logger,
 ):

@@ -119,7 +119,7 @@ class ModelManager(object):
 self.config_path = config_path
 self.config = OmegaConf.load(self.config_path)
 self.cache = ModelCache(
-max_models=max_models,
+max_models=max_loaded_models,
 execution_device = device_type,
 precision = precision,
 sequential_offload = sequential_offload,
@@ -164,7 +164,7 @@ class ModelManager(object):
 if mconfig.get('vae'):
 legacy.vae_file = global_resolve_path(mconfig.vae)
 elif format=='diffusers':
-location = mconfig.repo_id
+location = mconfig.get('repo_id') or mconfig.get('path')
 revision = mconfig.get('revision')
 else:
 raise InvalidModelError(

@@ -7,6 +7,7 @@ get_uc_and_c_and_ec() get the conditioned and unconditioned latent, an

 """
 import re
+import torch
 from typing import Optional, Union

 from compel import Compel
@@ -78,7 +78,6 @@ class InvokeAIWebServer:
 mimetypes.add_type("application/javascript", ".js")
 mimetypes.add_type("text/css", ".css")
 # Socket IO
-logger = True if args.web_verbose else False
 engineio_logger = True if args.web_verbose else False
 max_http_buffer_size = 10000000

@@ -1278,11 +1277,12 @@ class InvokeAIWebServer:
 eventlet.sleep(0)

 parsed_prompt, _ = get_prompt_structure(generation_parameters["prompt"])
+with self.generate.model_context as model:
 tokens = (
 None
 if type(parsed_prompt) is Blend
 else get_tokens_for_prompt_object(
-self.generate.model.tokenizer, parsed_prompt
+model.tokenizer, parsed_prompt
 )
 )
 attention_maps_image_base64_url = (
@@ -109,9 +109,6 @@ def main():
 else:
 embedding_path = None

-# migrate legacy models
-ModelManager.migrate_models()

 # load the infile as a list of lines
 if opt.infile:
 try:

@@ -197,7 +194,7 @@ def main_loop(gen, opt):
 # changing the history file midstream when the output directory is changed.
 completer = get_completer(opt, models=gen.model_manager.list_models())
 set_default_output_dir(opt, completer)
-if gen.model:
+if gen.model_context:
 add_embedding_terms(gen, completer)
 output_cntr = completer.get_current_history_length() + 1
@@ -1080,7 +1077,8 @@ def add_embedding_terms(gen, completer):
 Called after setting the model, updates the autocompleter with
 any terms loaded by the embedding manager.
 """
-trigger_strings = gen.model.textual_inversion_manager.get_all_trigger_strings()
+with gen.model_context as model:
+trigger_strings = model.textual_inversion_manager.get_all_trigger_strings()
 completer.add_embedding_terms(trigger_strings)

@@ -1222,6 +1220,7 @@ def report_model_error(opt: Namespace, e: Exception):
 logger.warning(
 "This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
 )
+traceback.print_exc()
 yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
 if yes_to_all:
 logger.warning(