adjust t2i to work with new model structure

This commit is contained in:
Lincoln Stein 2023-05-07 19:06:49 -04:00
parent 667171ed90
commit 4649920074
4 changed files with 25 additions and 13 deletions

View File

@ -206,7 +206,6 @@ class TextToLatentsInvocation(BaseInvocation):
def get_conditioning_data(self, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(self.prompt, model=model)
print(f'DEBUG: uc.dtype={uc.dtype}, c.dtype={c.dtype}')
conditioning_data = ConditioningData(
uc,
c,
@ -346,11 +345,11 @@ class LatentsToImageInvocation(BaseInvocation):
# TODO: this only really needs the vae
model_info = choose_model(context.services.model_manager, self.model)
model: StableDiffusionGeneratorPipeline = model_info['model']
with torch.inference_mode():
np_image = model.decode_latents(latents)
image = model.numpy_to_pil(np_image)[0]
with model_info.context as model:
with torch.inference_mode():
np_image = model.decode_latents(latents)
image = model.numpy_to_pil(np_image)[0]
image_type = ImageType.RESULT
image_name = context.services.images.create_name(

View File

@ -54,6 +54,10 @@ def get_model_manager(config: Args, logger: types.ModuleType) -> ModelManager:
precision = choose_precision(device)
dtype = torch.float32 if precision=='float32' \
else torch.float16
max_cache_size = config.max_cache_size \
if hasattr(config,'max_cache_size') \
else config.max_loaded_models * 2.5
model_manager = ModelManager(
config.conf,

View File

@ -557,12 +557,6 @@ class Args(object):
default=False,
help="Check for and blur potentially NSFW images. Use --no-nsfw_checker to disable.",
)
model_group.add_argument(
"--autoimport",
default=None,
type=str,
help="(DEPRECATED - NONFUNCTIONAL). Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly",
)
model_group.add_argument(
"--autoconvert",
default=None,
@ -765,6 +759,19 @@ class Args(object):
action="store_true",
help="Start InvokeAI GUI",
)
deprecated_group.add_argument(
"--autoimport",
default=None,
type=str,
help="(DEPRECATED - NONFUNCTIONAL). Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly",
)
deprecated_group.add_argument(
"--max_loaded_models",
dest="max_loaded_models",
type=int,
default=3,
help="Maximum number of models to keep in RAM cache (deprecated - use max_cache_size)",
)
return parser
# This creates the parser that processes commands on the invoke> command line

View File

@ -213,7 +213,9 @@ class ModelCache(object):
# this will remove older cached models until
# there is sufficient room to load the requested model
self._make_cache_room(key, model_type)
# clean memory to make MemoryUsage() more accurate
gc.collect()
with MemoryUsage() as usage:
model = self._load_model_from_storage(
repo_id_or_path=repo_id_or_path,