mostly ported to new manager API; needs testing

Lincoln Stein 2023-05-06 00:44:12 -04:00
parent af8c7c7d29
commit e0214a32bc
12 changed files with 353 additions and 332 deletions

View File

@@ -180,31 +180,33 @@ class TextToLatentsInvocation(BaseInvocation):
     def get_model(self, model_manager: ModelManager) -> StableDiffusionGeneratorPipeline:
         model_info = choose_model(model_manager, self.model)
-        model_name = model_info['model_name']
-        model_hash = model_info['hash']
-        model: StableDiffusionGeneratorPipeline = model_info['model']
-        model.scheduler = get_scheduler(
-            model=model,
-            scheduler_name=self.scheduler
-        )
+        model_name = model_info.name
+        model_hash = model_info.hash
+        model_ctx: StableDiffusionGeneratorPipeline = model_info.context
+        with model_ctx as model:
+            model.scheduler = get_scheduler(
+                model=model,
+                scheduler_name=self.scheduler
+            )
 
-        if isinstance(model, DiffusionPipeline):
-            for component in [model.unet, model.vae]:
-                configure_model_padding(component,
-                                        self.seamless,
-                                        self.seamless_axes
-                                        )
-        else:
-            configure_model_padding(model,
-                                    self.seamless,
-                                    self.seamless_axes
-                                    )
+            if isinstance(model, DiffusionPipeline):
+                for component in [model.unet, model.vae]:
+                    configure_model_padding(component,
+                                            self.seamless,
+                                            self.seamless_axes
+                                            )
+            else:
+                configure_model_padding(model,
+                                        self.seamless,
+                                        self.seamless_axes
+                                        )
 
-        return model
+        return model_ctx
 
     def get_conditioning_data(self, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
         uc, c, extra_conditioning_info = get_uc_and_c_and_ec(self.prompt, model=model)
-        print(f'DEBUG: uc.dtype={uc.dtype}, c.dtype={c.dtype}')
         conditioning_data = ConditioningData(
             uc,
             c,
@@ -230,18 +232,17 @@ class TextToLatentsInvocation(BaseInvocation):
         def step_callback(state: PipelineIntermediateState):
             self.dispatch_progress(context, source_node_id, state)
 
-        model = self.get_model(context.services.model_manager)
-        conditioning_data = self.get_conditioning_data(model)
+        with self.get_model(context.services.model_manager) as model:
+            conditioning_data = self.get_conditioning_data(model)
 
-        # TODO: Verify the noise is the right size
-
-        result_latents, result_attention_map_saver = model.latents_from_embeddings(
-            latents=torch.zeros_like(noise, dtype=torch_dtype(model.device)),
-            noise=noise,
-            num_inference_steps=self.steps,
-            conditioning_data=conditioning_data,
-            callback=step_callback
-        )
+            # TODO: Verify the noise is the right size
+            result_latents, result_attention_map_saver = model.latents_from_embeddings(
+                latents=torch.zeros_like(noise, dtype=torch_dtype(model.device)),
+                noise=noise,
+                num_inference_steps=self.steps,
+                conditioning_data=conditioning_data,
+                callback=step_callback
+            )
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         torch.cuda.empty_cache()
@@ -284,29 +285,29 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
         def step_callback(state: PipelineIntermediateState):
             self.dispatch_progress(context, source_node_id, state)
 
-        model = self.get_model(context.services.model_manager)
-        conditioning_data = self.get_conditioning_data(model)
+        with self.get_model(context.services.model_manager) as model:
+            conditioning_data = self.get_conditioning_data(model)
 
-        # TODO: Verify the noise is the right size
+            # TODO: Verify the noise is the right size
 
-        initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
-            latent, device=model.device, dtype=latent.dtype
-        )
+            initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
+                latent, device=model.device, dtype=latent.dtype
+            )
 
-        timesteps, _ = model.get_img2img_timesteps(
-            self.steps,
-            self.strength,
-            device=model.device,
-        )
+            timesteps, _ = model.get_img2img_timesteps(
+                self.steps,
+                self.strength,
+                device=model.device,
+            )
 
-        result_latents, result_attention_map_saver = model.latents_from_embeddings(
-            latents=initial_latents,
-            timesteps=timesteps,
-            noise=noise,
-            num_inference_steps=self.steps,
-            conditioning_data=conditioning_data,
-            callback=step_callback
-        )
+            result_latents, result_attention_map_saver = model.latents_from_embeddings(
+                latents=initial_latents,
+                timesteps=timesteps,
+                noise=noise,
+                num_inference_steps=self.steps,
+                conditioning_data=conditioning_data,
+                callback=step_callback
+            )
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         torch.cuda.empty_cache()
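
Taken together, these hunks change the invocations from holding a bare pipeline to holding a context object: get_model() now returns the model context, and the scheduler setup and denoising loop happen inside a with block. A minimal, self-contained sketch of that calling convention follows (FakeModel, fake_model_context and the bare get_model below are illustrative stand-ins, not the actual InvokeAI API):

    from contextlib import contextmanager

    class FakeModel:
        """Stand-in for a StableDiffusionGeneratorPipeline."""
        def __init__(self):
            self.device = "cpu"

    @contextmanager
    def fake_model_context(model: FakeModel):
        # emulates ModelLocker.__enter__: move the model onto the execution device
        model.device = "cuda"
        try:
            yield model
        finally:
            # emulates ModelLocker.__exit__: let the cache offload the model again
            model.device = "cpu"

    def get_model():
        # the ported get_model() hands back the *context*, not the bare model
        return fake_model_context(FakeModel())

    with get_model() as model:
        print(f"running the denoising loop on {model.device}")   # "cuda" inside the block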

View File

@@ -7,7 +7,7 @@ def choose_model(model_manager: ModelManager, model_name: str):
     if model_manager.valid_model(model_name):
         model = model_manager.get_model(model_name)
     else:
-        model = model_manager.get_model()
-        logger.warning(f"{model_name}' is not a valid model name. Using default model \'{model['model_name']}\' instead.")
+        model = model_manager.get_model(model_manager.default_model())
+        logger.warning(f"'{model_name}' is not a valid model name. Using default model \'{model.name}\' instead.")
 
     return model

View File

@@ -47,22 +47,21 @@ def get_model_manager(config: Args, logger: types.ModuleType) -> ModelManager:
     else:
         embedding_path = None
 
-    # migrate legacy models
-    ModelManager.migrate_models()
-
     # creating the model manager
     try:
         device = torch.device(choose_torch_device())
-        precision = 'float16' if config.precision=='float16' \
-            else 'float32' if config.precision=='float32' \
-            else choose_precision(device)
+        if config.precision=="auto":
+            precision = choose_precision(device)
+        dtype = torch.float32 if precision=='float32' \
+            else torch.float16
 
         model_manager = ModelManager(
-            OmegaConf.load(config.conf),
-            precision=precision,
+            config.conf,
+            precision=dtype,
             device_type=device,
             max_loaded_models=config.max_loaded_models,
-            embedding_path = Path(embedding_path),
+            # temporarily disabled until model manager stabilizes
+            # embedding_path = Path(embedding_path),
             logger = logger,
         )
     except (FileNotFoundError, TypeError, AssertionError) as e:
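
The precision handling above now resolves an "auto" setting against the device and passes a torch dtype (rather than a string) into ModelManager. A hedged sketch of that resolution, under the assumption that choose_precision() picks full precision on CPU/MPS and half precision on CUDA (resolve_dtype below is illustrative, not the project helper):

    import torch

    def resolve_dtype(requested: str, device: torch.device) -> torch.dtype:
        # "auto" defers to the device; anything else is taken literally
        if requested == "auto":
            requested = "float16" if device.type == "cuda" else "float32"
        return torch.float32 if requested == "float32" else torch.float16

    print(resolve_dtype("auto", torch.device("cpu")))      # torch.float32
    print(resolve_dtype("float16", torch.device("cpu")))   # torch.float16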

View File

@@ -10,7 +10,7 @@ from .generator import (
     Img2Img,
     Inpaint
 )
-from .model_management import ModelManager
+from .model_management import ModelManager, ModelCache, ModelStatus, SDModelType
 from .safety_checker import SafetyChecker
 from .args import Args
 from .globals import Globals

View File

@@ -37,7 +37,7 @@ from .safety_checker import SafetyChecker
 from .prompting import get_uc_and_c_and_ec
 from .prompting.conditioning import log_tokenization
 from .stable_diffusion import HuggingFaceConceptsLibrary
-from .util import choose_precision, choose_torch_device
+from .util import choose_precision, choose_torch_device, torch_dtype
 
 def fix_func(orig):
     if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
@@ -50,7 +50,6 @@ def fix_func(orig):
         return new_func
     return orig
 
-
 torch.rand = fix_func(torch.rand)
 torch.rand_like = fix_func(torch.rand_like)
 torch.randn = fix_func(torch.randn)
@@ -156,7 +155,6 @@ class Generate:
         weights=None,
         config=None,
     ):
-        mconfig = OmegaConf.load(conf)
         self.height = None
         self.width = None
         self.model_manager = None
@@ -171,7 +169,7 @@ class Generate:
         self.seamless_axes = {"x", "y"}
         self.hires_fix = False
         self.embedding_path = embedding_path
-        self.model = None  # empty for now
+        self.model_context = None  # empty for now
         self.model_hash = None
         self.sampler = None
         self.device = None
@@ -219,12 +217,12 @@ class Generate:
         # model caching system for fast switching
         self.model_manager = ModelManager(
-            mconfig,
+            conf,
             self.device,
-            self.precision,
+            torch_dtype(self.device),
             max_loaded_models=max_loaded_models,
             sequential_offload=self.free_gpu_mem,
-            embedding_path=Path(self.embedding_path),
+            # embedding_path=Path(self.embedding_path),
         )
 
         # don't accept invalid models
         fallback = self.model_manager.default_model() or FALLBACK_MODEL_NAME
@@ -418,170 +416,171 @@ class Generate:
         with_variations = [] if with_variations is None else with_variations
 
         # will instantiate the model or return it from cache
-        model = self.set_model(self.model_name)
+        model_context = self.set_model(self.model_name)
 
         # self.width and self.height are set by set_model()
         # to the width and height of the image training set
         width = width or self.width
         height = height or self.height
 
-        if isinstance(model, DiffusionPipeline):
-            configure_model_padding(model.unet, seamless, seamless_axes)
-            configure_model_padding(model.vae, seamless, seamless_axes)
-        else:
-            configure_model_padding(model, seamless, seamless_axes)
-
-        assert cfg_scale > 1.0, "CFG_Scale (-C) must be >1.0"
-        assert threshold >= 0.0, "--threshold must be >=0.0"
-        assert (
-            0.0 < strength <= 1.0
-        ), "img2img and inpaint strength can only work with 0.0 < strength < 1.0"
-        assert (
-            0.0 <= variation_amount <= 1.0
-        ), "-v --variation_amount must be in [0.0, 1.0]"
-        assert 0.0 <= perlin <= 1.0, "--perlin must be in [0.0, 1.0]"
-        assert (embiggen == None and embiggen_tiles == None) or (
-            (embiggen != None or embiggen_tiles != None) and init_img != None
-        ), "Embiggen requires an init/input image to be specified"
-
-        if len(with_variations) > 0 or variation_amount > 1.0:
-            assert seed is not None, "seed must be specified when using with_variations"
-            if variation_amount == 0.0:
-                assert (
-                    iterations == 1
-                ), "when using --with_variations, multiple iterations are only possible when using --variation_amount"
-            assert all(
-                0 <= weight <= 1 for _, weight in with_variations
-            ), f"variation weights must be in [0.0, 1.0]: got {[weight for _, weight in with_variations]}"
-
-        width, height, _ = self._resolution_check(width, height, log=True)
-        assert (
-            inpaint_replace >= 0.0 and inpaint_replace <= 1.0
-        ), "inpaint_replace must be between 0.0 and 1.0"
-
-        if sampler_name and (sampler_name != self.sampler_name):
-            self.sampler_name = sampler_name
-            self._set_scheduler()
-
-        # apply the concepts library to the prompt
-        prompt = self.huggingface_concepts_library.replace_concepts_with_triggers(
-            prompt,
-            lambda concepts: self.load_huggingface_concepts(concepts),
-            self.model.textual_inversion_manager.get_all_trigger_strings(),
-        )
-
-        tic = time.time()
-        if self._has_cuda():
-            torch.cuda.reset_peak_memory_stats()
-
-        results = list()
-
-        try:
-            uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
-                prompt,
-                model=self.model,
-                skip_normalize_legacy_blend=skip_normalize,
-                log_tokens=self.log_tokenization,
-            )
-
-            init_image, mask_image = self._make_images(
-                init_img,
-                init_mask,
-                width,
-                height,
-                fit=fit,
-                text_mask=text_mask,
-                invert_mask=invert_mask,
-                force_outpaint=force_outpaint,
-            )
-
-            # TODO: Hacky selection of operation to perform. Needs to be refactored.
-            generator = self.select_generator(
-                init_image, mask_image, embiggen, hires_fix, force_outpaint
-            )
-
-            generator.set_variation(self.seed, variation_amount, with_variations)
-            generator.use_mps_noise = use_mps_noise
-
-            results = generator.generate(
-                prompt,
-                iterations=iterations,
-                seed=self.seed,
-                sampler=self.sampler,
-                steps=steps,
-                cfg_scale=cfg_scale,
-                conditioning=(uc, c, extra_conditioning_info),
-                ddim_eta=ddim_eta,
-                image_callback=image_callback,  # called after the final image is generated
-                step_callback=step_callback,  # called after each intermediate image is generated
-                width=width,
-                height=height,
-                init_img=init_img,  # embiggen needs to manipulate from the unmodified init_img
-                init_image=init_image,  # notice that init_image is different from init_img
-                mask_image=mask_image,
-                strength=strength,
-                threshold=threshold,
-                perlin=perlin,
-                h_symmetry_time_pct=h_symmetry_time_pct,
-                v_symmetry_time_pct=v_symmetry_time_pct,
-                embiggen=embiggen,
-                embiggen_tiles=embiggen_tiles,
-                embiggen_strength=embiggen_strength,
-                inpaint_replace=inpaint_replace,
-                mask_blur_radius=mask_blur_radius,
-                safety_checker=self.safety_checker,
-                seam_size=seam_size,
-                seam_blur=seam_blur,
-                seam_strength=seam_strength,
-                seam_steps=seam_steps,
-                tile_size=tile_size,
-                infill_method=infill_method,
-                force_outpaint=force_outpaint,
-                inpaint_height=inpaint_height,
-                inpaint_width=inpaint_width,
-                enable_image_debugging=enable_image_debugging,
-                free_gpu_mem=self.free_gpu_mem,
-                clear_cuda_cache=self.clear_cuda_cache,
-            )
-
-            if init_color:
-                self.correct_colors(
-                    image_list=results,
-                    reference_image_path=init_color,
-                    image_callback=image_callback,
-                )
-
-            if upscale is not None or facetool_strength > 0:
-                self.upscale_and_reconstruct(
-                    results,
-                    upscale=upscale,
-                    upscale_denoise_str=upscale_denoise_str,
-                    facetool=facetool,
-                    strength=facetool_strength,
-                    codeformer_fidelity=codeformer_fidelity,
-                    save_original=save_original,
-                    image_callback=image_callback,
-                )
-        except KeyboardInterrupt:
-            # Clear the CUDA cache on an exception
-            self.clear_cuda_cache()
-            if catch_interrupts:
-                logger.warning("Interrupted** Partial results will be returned.")
-            else:
-                raise KeyboardInterrupt
-        except RuntimeError:
-            # Clear the CUDA cache on an exception
-            self.clear_cuda_cache()
-            print(traceback.format_exc(), file=sys.stderr)
-            logger.info("Could not generate image.")
-
-        toc = time.time()
-        logger.info("Usage stats:")
-        logger.info(f"{len(results)} image(s) generated in "+"%4.2fs" % (toc - tic))
-        self.print_cuda_stats()
+        with model_context as model:
+            if isinstance(model, DiffusionPipeline):
+                configure_model_padding(model.unet, seamless, seamless_axes)
+                configure_model_padding(model.vae, seamless, seamless_axes)
+            else:
+                configure_model_padding(model, seamless, seamless_axes)
+
+            assert cfg_scale > 1.0, "CFG_Scale (-C) must be >1.0"
+            assert threshold >= 0.0, "--threshold must be >=0.0"
+            assert (
+                0.0 < strength <= 1.0
+            ), "img2img and inpaint strength can only work with 0.0 < strength < 1.0"
+            assert (
+                0.0 <= variation_amount <= 1.0
+            ), "-v --variation_amount must be in [0.0, 1.0]"
+            assert 0.0 <= perlin <= 1.0, "--perlin must be in [0.0, 1.0]"
+            assert (embiggen == None and embiggen_tiles == None) or (
+                (embiggen != None or embiggen_tiles != None) and init_img != None
+            ), "Embiggen requires an init/input image to be specified"
+
+            if len(with_variations) > 0 or variation_amount > 1.0:
+                assert seed is not None, "seed must be specified when using with_variations"
+                if variation_amount == 0.0:
+                    assert (
+                        iterations == 1
+                    ), "when using --with_variations, multiple iterations are only possible when using --variation_amount"
+                assert all(
+                    0 <= weight <= 1 for _, weight in with_variations
+                ), f"variation weights must be in [0.0, 1.0]: got {[weight for _, weight in with_variations]}"
+
+            width, height, _ = self._resolution_check(width, height, log=True)
+            assert (
+                inpaint_replace >= 0.0 and inpaint_replace <= 1.0
+            ), "inpaint_replace must be between 0.0 and 1.0"
+
+            if sampler_name and (sampler_name != self.sampler_name):
+                self.sampler_name = sampler_name
+                self._set_scheduler(model)
+
+            # apply the concepts library to the prompt
+            prompt = self.huggingface_concepts_library.replace_concepts_with_triggers(
+                prompt,
+                lambda concepts: self.load_huggingface_concepts(concepts),
+                model.textual_inversion_manager.get_all_trigger_strings(),
+            )
+
+            tic = time.time()
+            if self._has_cuda():
+                torch.cuda.reset_peak_memory_stats()
+
+            results = list()
+
+            try:
+                uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
+                    prompt,
+                    model=model,
+                    skip_normalize_legacy_blend=skip_normalize,
+                    log_tokens=self.log_tokenization,
+                )
+
+                init_image, mask_image = self._make_images(
+                    init_img,
+                    init_mask,
+                    width,
+                    height,
+                    fit=fit,
+                    text_mask=text_mask,
+                    invert_mask=invert_mask,
+                    force_outpaint=force_outpaint,
+                )
+
+                # TODO: Hacky selection of operation to perform. Needs to be refactored.
+                generator = self.select_generator(
+                    init_image, mask_image, embiggen, hires_fix, force_outpaint
+                )
+
+                generator.set_variation(self.seed, variation_amount, with_variations)
+                generator.use_mps_noise = use_mps_noise
+
+                results = generator.generate(
+                    prompt,
+                    iterations=iterations,
+                    seed=self.seed,
+                    sampler=self.sampler,
+                    steps=steps,
+                    cfg_scale=cfg_scale,
+                    conditioning=(uc, c, extra_conditioning_info),
+                    ddim_eta=ddim_eta,
+                    image_callback=image_callback,  # called after the final image is generated
+                    step_callback=step_callback,  # called after each intermediate image is generated
+                    width=width,
+                    height=height,
+                    init_img=init_img,  # embiggen needs to manipulate from the unmodified init_img
+                    init_image=init_image,  # notice that init_image is different from init_img
+                    mask_image=mask_image,
+                    strength=strength,
+                    threshold=threshold,
+                    perlin=perlin,
+                    h_symmetry_time_pct=h_symmetry_time_pct,
+                    v_symmetry_time_pct=v_symmetry_time_pct,
+                    embiggen=embiggen,
+                    embiggen_tiles=embiggen_tiles,
+                    embiggen_strength=embiggen_strength,
+                    inpaint_replace=inpaint_replace,
+                    mask_blur_radius=mask_blur_radius,
+                    safety_checker=self.safety_checker,
+                    seam_size=seam_size,
+                    seam_blur=seam_blur,
+                    seam_strength=seam_strength,
+                    seam_steps=seam_steps,
+                    tile_size=tile_size,
+                    infill_method=infill_method,
+                    force_outpaint=force_outpaint,
+                    inpaint_height=inpaint_height,
+                    inpaint_width=inpaint_width,
+                    enable_image_debugging=enable_image_debugging,
+                    free_gpu_mem=self.free_gpu_mem,
+                    clear_cuda_cache=self.clear_cuda_cache,
+                )
+
+                if init_color:
+                    self.correct_colors(
+                        image_list=results,
+                        reference_image_path=init_color,
+                        image_callback=image_callback,
+                    )
+
+                if upscale is not None or facetool_strength > 0:
+                    self.upscale_and_reconstruct(
+                        results,
+                        upscale=upscale,
+                        upscale_denoise_str=upscale_denoise_str,
+                        facetool=facetool,
+                        strength=facetool_strength,
+                        codeformer_fidelity=codeformer_fidelity,
+                        save_original=save_original,
+                        image_callback=image_callback,
+                    )
+            except KeyboardInterrupt:
+                # Clear the CUDA cache on an exception
+                self.clear_cuda_cache()
+                if catch_interrupts:
+                    logger.warning("Interrupted** Partial results will be returned.")
+                else:
+                    raise KeyboardInterrupt
+            except RuntimeError:
+                # Clear the CUDA cache on an exception
+                self.clear_cuda_cache()
+                print(traceback.format_exc(), file=sys.stderr)
+                logger.info("Could not generate image.")
+
+            toc = time.time()
+            logger.info("Usage stats:")
+            logger.info(f"{len(results)} image(s) generated in "+"%4.2fs" % (toc - tic))
+            self.print_cuda_stats()
 
         return results
 
     def gather_cuda_stats(self):
@@ -662,12 +661,13 @@ class Generate:
         # used by multiple postfixers
         # todo: cross-attention control
-        uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
-            prompt,
-            model=self.model,
-            skip_normalize_legacy_blend=opt.skip_normalize,
-            log_tokens=log_tokenization,
-        )
+        with self.model_context as model:
+            uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
+                prompt,
+                model=model,
+                skip_normalize_legacy_blend=opt.skip_normalize,
+                log_tokens=log_tokenization,
+            )
 
         if tool in ("gfpgan", "codeformer", "upscale"):
             if tool == "gfpgan":
@@ -852,7 +852,8 @@ class Generate:
         cn = class_name
         module = importlib.import_module(mn)
         constructor = getattr(module, cn)
-        return constructor(self.model, self.precision)
+        with self.model_context as model:
+            return constructor(model, self.precision)
 
     def load_model(self):
         """
@@ -869,8 +870,8 @@ class Generate:
         If the model fails to load for some reason, will attempt to load the previously-
         loaded model (if any). If that fallback fails, will raise an AssertionError
         """
-        if self.model_name == model_name and self.model is not None:
-            return self.model
+        if self.model_name == model_name and self.model_context is not None:
+            return self.model_context
 
         previous_model_name = self.model_name
@@ -881,11 +882,9 @@ class Generate:
                 f'** "{model_name}" is not a known model name. Cannot change.'
             )
 
-        cache.print_vram_usage()
-
         # have to get rid of all references to model in order
         # to free it from GPU memory
-        self.model = None
+        self.model_context = None
         self.sampler = None
         self.generators = {}
         gc.collect()
@@ -902,29 +901,33 @@ class Generate:
                 raise e
             model_name = previous_model_name
 
-        self.model = model_data["model"]
-        self.width = model_data["width"]
-        self.height = model_data["height"]
-        self.model_hash = model_data["hash"]
+        self.model_context = model_data.context
+        self.width = 512
+        self.height = 512
+        self.model_hash = model_data.hash
 
         # uncache generators so they pick up new models
         self.generators = {}
 
         set_seed(random.randrange(0, np.iinfo(np.uint32).max))
         self.model_name = model_name
-        self._set_scheduler()  # requires self.model_name to be set first
-        return self.model
+        with self.model_context as model:
+            self._set_scheduler(model)  # requires self.model_name to be set first
+        return self.model_context
 
     def load_huggingface_concepts(self, concepts: list[str]):
-        self.model.textual_inversion_manager.load_huggingface_concepts(concepts)
+        with self.model_context as model:
+            model.textual_inversion_manager.load_huggingface_concepts(concepts)
 
     @property
     def huggingface_concepts_library(self) -> HuggingFaceConceptsLibrary:
-        return self.model.textual_inversion_manager.hf_concepts_library
+        with self.model_context as model:
+            return model.textual_inversion_manager.hf_concepts_library
 
     @property
     def embedding_trigger_strings(self) -> List[str]:
-        return self.model.textual_inversion_manager.get_all_trigger_strings()
+        with self.model_context as model:
+            return model.textual_inversion_manager.get_all_trigger_strings()
 
     def correct_colors(self, image_list, reference_image_path, image_callback=None):
         reference_image = Image.open(reference_image_path)
@@ -1044,8 +1047,8 @@ class Generate:
     def is_legacy_model(self, model_name) -> bool:
         return self.model_manager.is_legacy(model_name)
 
-    def _set_scheduler(self):
-        default = self.model.scheduler
+    def _set_scheduler(self, model):
+        default = model.scheduler
 
         # See https://github.com/huggingface/diffusers/issues/277#issuecomment-1371428672
         scheduler_map = dict(
@@ -1069,7 +1072,7 @@ class Generate:
             msg = (
                 f"Setting Sampler to {self.sampler_name} ({sampler_class.__name__})"
             )
-            self.sampler = sampler_class.from_config(self.model.scheduler.config)
+            self.sampler = sampler_class.from_config(model.scheduler.config)
         else:
             msg = (
                 f" Unsupported Sampler: {self.sampler_name} "+

View File

@@ -123,51 +123,51 @@ class InvokeAIGenerator(metaclass=ABCMeta):
         generator_args.update(keyword_args)
 
         model_info = self.model_info
-        model_name = model_info['model_name']
-        model:StableDiffusionGeneratorPipeline = model_info['model']
-        model_hash = model_info['hash']
-        scheduler: Scheduler = self.get_scheduler(
-            model=model,
-            scheduler_name=generator_args.get('scheduler')
-        )
-        uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)
-        gen_class = self._generator_class()
-        generator = gen_class(model, self.params.precision)
-        if self.params.variation_amount > 0:
-            generator.set_variation(generator_args.get('seed'),
-                                    generator_args.get('variation_amount'),
-                                    generator_args.get('with_variations')
-                                    )
-
-        if isinstance(model, DiffusionPipeline):
-            for component in [model.unet, model.vae]:
-                configure_model_padding(component,
-                                        generator_args.get('seamless',False),
-                                        generator_args.get('seamless_axes')
-                                        )
-        else:
-            configure_model_padding(model,
-                                    generator_args.get('seamless',False),
-                                    generator_args.get('seamless_axes')
-                                    )
-
-        iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
-        for i in iteration_count:
-            results = generator.generate(prompt,
-                                         conditioning=(uc, c, extra_conditioning_info),
-                                         step_callback=step_callback,
-                                         sampler=scheduler,
-                                         **generator_args,
-                                         )
-            output = InvokeAIGeneratorOutput(
-                image=results[0][0],
-                seed=results[0][1],
-                attention_maps_images=results[0][2],
-                model_hash = model_hash,
-                params=Namespace(model_name=model_name,**generator_args),
-            )
-            if callback:
-                callback(output)
-            yield output
+        model_name = model_info.name
+        model_hash = model_info.hash
+        with model_info.context as model:
+            scheduler: Scheduler = self.get_scheduler(
+                model=model,
+                scheduler_name=generator_args.get('scheduler')
+            )
+            uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)
+            gen_class = self._generator_class()
+            generator = gen_class(model, self.params.precision)
+            if self.params.variation_amount > 0:
+                generator.set_variation(generator_args.get('seed'),
+                                        generator_args.get('variation_amount'),
+                                        generator_args.get('with_variations')
+                                        )
+
+            if isinstance(model, DiffusionPipeline):
+                for component in [model.unet, model.vae]:
+                    configure_model_padding(component,
+                                            generator_args.get('seamless',False),
+                                            generator_args.get('seamless_axes')
+                                            )
+            else:
+                configure_model_padding(model,
+                                        generator_args.get('seamless',False),
+                                        generator_args.get('seamless_axes')
+                                        )
+
+            iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
+            for i in iteration_count:
+                results = generator.generate(prompt,
+                                             conditioning=(uc, c, extra_conditioning_info),
+                                             step_callback=step_callback,
+                                             sampler=scheduler,
+                                             **generator_args,
+                                             )
+                output = InvokeAIGeneratorOutput(
+                    image=results[0][0],
+                    seed=results[0][1],
+                    attention_maps_images=results[0][2],
+                    model_hash = model_hash,
+                    params=Namespace(model_name=model_name,**generator_args),
+                )
+                if callback:
+                    callback(output)
+                yield output
 
     @classmethod
@@ -275,7 +275,6 @@ class Embiggen(Txt2Img):
         from .embiggen import Embiggen
         return Embiggen
 
-
 class Generator:
     downsampling_factor: int
     latent_channels: int

View File

@@ -2,4 +2,4 @@
 Initialization file for invokeai.backend.model_management
 """
 from .model_manager import ModelManager
-from .model_cache import ModelCache, ModelStatus
+from .model_cache import ModelCache, ModelStatus, SDModelType

View File

@@ -78,6 +78,10 @@ class UnscannableModelException(Exception):
     "Raised when picklescan is unable to scan a legacy model file"
     pass
 
+class ModelLocker(object):
+    "Forward declaration"
+    pass
+
 class ModelCache(object):
     def __init__(
         self,
@@ -112,8 +116,6 @@ class ModelCache(object):
         self.loaded_models: set = set()   # set of model keys loaded in GPU
         self.locked_models: Counter = Counter()   # set of model keys locked in GPU
 
-    @contextlib.contextmanager
     def get_model(
             self,
             repo_id_or_path: Union[str,Path],
@@ -124,7 +126,7 @@ class ModelCache(object):
             legacy_info: LegacyInfo=None,
             attach_model_part: Tuple[SDModelType, str] = (None,None),
             gpu_load: bool=True,
-            )->Generator[ModelClass, None, None]:
+            )->ModelLocker:   # ?? what does it return
        '''
        Load and return a HuggingFace model wrapped in a context manager generator, with RAM caching.
        Use like this:
@@ -188,29 +190,45 @@ class ModelCache(object):
         if submodel:
             model = getattr(model, submodel.name)
 
-        if gpu_load and hasattr(model,'to'):
-            try:
-                self.loaded_models.add(key)
-                self.locked_models[key] += 1
-                if self.lazy_offloading:
-                    self._offload_unlocked_models()
-                self.logger.debug(f'Loading {key} into {self.execution_device}')
-                model.to(self.execution_device)  # move into GPU
-                self._print_cuda_stats()
-                yield model
-            finally:
-                self.locked_models[key] -= 1
-                if not self.lazy_offloading:
-                    self._offload_unlocked_models()
-                self._print_cuda_stats()
-        else:
-            # in the event that the caller wants the model in RAM, we
-            # move it into CPU if it is in GPU and not locked
-            if hasattr(model,'to') and (key in self.loaded_models
-                                        and self.locked_models[key] == 0):
-                model.to(self.storage_device)
-                self.loaded_models.remove(key)
-            yield model
+        return self.ModelLocker(self, key, model, gpu_load)
+
+    class ModelLocker(object):
+        def __init__(self, cache, key, model, gpu_load):
+            self.gpu_load = gpu_load
+            self.cache = cache
+            self.key = key
+            # This will keep a copy of the model in RAM until the locker
+            # is garbage collected. Needs testing!
+            self.model = model
+
+        def __enter__(self)->ModelClass:
+            cache = self.cache
+            key = self.key
+            model = self.model
+            if self.gpu_load and hasattr(model,'to'):
+                cache.loaded_models.add(key)
+                cache.locked_models[key] += 1
+                if cache.lazy_offloading:
+                    cache._offload_unlocked_models()
+                cache.logger.debug(f'Loading {key} into {cache.execution_device}')
+                model.to(cache.execution_device)  # move into GPU
+                cache._print_cuda_stats()
+            else:
+                # in the event that the caller wants the model in RAM, we
+                # move it into CPU if it is in GPU and not locked
+                if hasattr(model,'to') and (key in cache.loaded_models
+                                            and cache.locked_models[key] == 0):
+                    model.to(cache.storage_device)
+                    cache.loaded_models.remove(key)
+            return model
+
+        def __exit__(self, type, value, traceback):
+            key = self.key
+            cache = self.cache
+            cache.locked_models[key] -= 1
+            if not cache.lazy_offloading:
+                cache._offload_unlocked_models()
+            cache._print_cuda_stats()
 
     def attach_part(self,
                     diffusers_model: StableDiffusionPipeline,
@@ -381,10 +399,11 @@ class ModelCache(object):
         revisions = [revision] if revision \
             else ['fp16','main'] if self.precision==torch.float16 \
             else ['main']
-        extra_args = {'precision': self.precision} \
-            if model_class in DiffusionClasses \
-            else {}
+        extra_args = {'torch_dtype': self.precision,
+                      'safety_checker': None}\
+            if model_class in DiffusionClasses\
+            else {}
 
         # silence transformer and diffuser warnings
         with SilenceWarnings():
             for rev in revisions:
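
The nested ModelLocker above is what get_model() now returns: entering it loads and locks the model, exiting it unlocks so the cache may offload it. A self-contained sketch of the same bookkeeping with a toy cache (all names here are illustrative, not the real ModelCache API):

    from collections import Counter

    class ToyCache:
        def __init__(self):
            self.loaded = set()       # keys currently resident on the "GPU"
            self.locked = Counter()   # how many lockers currently hold each key

    class ToyLocker:
        """Load on __enter__, unlock (and become offloadable) on __exit__."""
        def __init__(self, cache, key, model):
            self.cache, self.key, self.model = cache, key, model

        def __enter__(self):
            self.cache.loaded.add(self.key)
            self.cache.locked[self.key] += 1
            return self.model

        def __exit__(self, exc_type, exc, tb):
            self.cache.locked[self.key] -= 1
            if self.cache.locked[self.key] == 0:
                self.cache.loaded.discard(self.key)   # now eligible for offload

    cache = ToyCache()
    with ToyLocker(cache, "sd-1.5", object()) as model:
        assert "sd-1.5" in cache.loaded
    assert cache.locked["sd-1.5"] == 0

As the inline comment in the hunk notes, keeping self.model as a plain attribute also pins a RAM reference until the locker itself is garbage collected; that behaviour is still flagged as needing testing.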

View File

@@ -69,7 +69,7 @@ class SDModelInfo():
     revision: str = None
     _cache: ModelCache = None
 
+    @property
     def status(self)->ModelStatus:
         '''Return load status of this model as a model_cache.ModelStatus enum'''
         if not self._cache:
@@ -106,7 +106,7 @@ class ModelManager(object):
                  config_path: Path,
                  device_type: torch.device = CUDA_DEVICE,
                  precision: torch.dtype = torch.float16,
-                 max_models=DEFAULT_MAX_MODELS,
+                 max_loaded_models=DEFAULT_MAX_MODELS,
                  sequential_offload=False,
                  logger: types.ModuleType = logger,
                  ):
@@ -119,7 +119,7 @@ class ModelManager(object):
         self.config_path = config_path
         self.config = OmegaConf.load(self.config_path)
         self.cache = ModelCache(
-            max_models=max_models,
+            max_models=max_loaded_models,
             execution_device = device_type,
             precision = precision,
             sequential_offload = sequential_offload,
@@ -164,7 +164,7 @@ class ModelManager(object):
             if mconfig.get('vae'):
                 legacy.vae_file = global_resolve_path(mconfig.vae)
         elif format=='diffusers':
-            location = mconfig.repo_id
+            location = mconfig.get('repo_id') or mconfig.get('path')
             revision = mconfig.get('revision')
         else:
             raise InvalidModelError(
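
The call-site changes throughout this commit (model_info.name, .hash, .context, and the status property above) assume the manager hands back a structured record instead of a dict. A hedged sketch of such a record, with field names taken from the hunks and everything else purely illustrative:

    from contextlib import contextmanager
    from dataclasses import dataclass
    from typing import Any

    @dataclass
    class ToyModelInfo:
        """Illustrative stand-in for the SDModelInfo record returned by the new manager API."""
        name: str
        hash: str
        model: Any = None

        @property
        def context(self):
            @contextmanager
            def _ctx():
                # the real context is a ModelLocker that manages GPU residency
                yield self.model
            return _ctx()

        @property
        def status(self) -> str:
            # the real property queries the cache and returns a ModelStatus enum
            return "in_ram" if self.model is not None else "unknown"

    info = ToyModelInfo(name="stable-diffusion-1.5", hash="deadbeef", model=object())
    with info.context as m:
        print(info.name, info.hash, info.status)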

View File

@@ -7,6 +7,7 @@ get_uc_and_c_and_ec()  get the conditioned and unconditioned latent, an
 """
 import re
+import torch
 from typing import Optional, Union
 
 from compel import Compel

View File

@@ -78,7 +78,6 @@ class InvokeAIWebServer:
         mimetypes.add_type("application/javascript", ".js")
         mimetypes.add_type("text/css", ".css")
 
         # Socket IO
-        logger = True if args.web_verbose else False
         engineio_logger = True if args.web_verbose else False
         max_http_buffer_size = 10000000
@@ -1278,13 +1277,14 @@ class InvokeAIWebServer:
             eventlet.sleep(0)
 
             parsed_prompt, _ = get_prompt_structure(generation_parameters["prompt"])
-            tokens = (
-                None
-                if type(parsed_prompt) is Blend
-                else get_tokens_for_prompt_object(
-                    self.generate.model.tokenizer, parsed_prompt
-                )
-            )
+            with self.generate.model_context as model:
+                tokens = (
+                    None
+                    if type(parsed_prompt) is Blend
+                    else get_tokens_for_prompt_object(
+                        model.tokenizer, parsed_prompt
+                    )
+                )
             attention_maps_image_base64_url = (
                 None
                 if attention_maps_image is None

View File

@@ -109,9 +109,6 @@ def main():
     else:
         embedding_path = None
 
-    # migrate legacy models
-    ModelManager.migrate_models()
-
     # load the infile as a list of lines
     if opt.infile:
         try:
@@ -197,7 +194,7 @@ def main_loop(gen, opt):
     # changing the history file midstream when the output directory is changed.
     completer = get_completer(opt, models=gen.model_manager.list_models())
     set_default_output_dir(opt, completer)
-    if gen.model:
+    if gen.model_context:
         add_embedding_terms(gen, completer)
     output_cntr = completer.get_current_history_length() + 1
@@ -1080,7 +1077,8 @@ def add_embedding_terms(gen, completer):
     Called after setting the model, updates the autocompleter with
     any terms loaded by the embedding manager.
     """
-    trigger_strings = gen.model.textual_inversion_manager.get_all_trigger_strings()
+    with gen.model_context as model:
+        trigger_strings = model.textual_inversion_manager.get_all_trigger_strings()
     completer.add_embedding_terms(trigger_strings)
@@ -1222,6 +1220,7 @@ def report_model_error(opt: Namespace, e: Exception):
     logger.warning(
         "This can be caused by a missing or corrupted models file, and can sometimes be fixed by (re)installing the models."
     )
+    traceback.print_exc()
     yes_to_all = os.environ.get("INVOKE_MODEL_RECONFIGURE")
     if yes_to_all:
         logger.warning(