Mirror of https://github.com/invoke-ai/InvokeAI

commit 4649920074
parent 667171ed90

    adjust t2i to work with new model structure
@@ -206,7 +206,6 @@ class TextToLatentsInvocation(BaseInvocation):
 
     def get_conditioning_data(self, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
         uc, c, extra_conditioning_info = get_uc_and_c_and_ec(self.prompt, model=model)
-        print(f'DEBUG: uc.dtype={uc.dtype}, c.dtype={c.dtype}')
         conditioning_data = ConditioningData(
             uc,
             c,
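For orientation: get_uc_and_c_and_ec() returns the unconditioned (uc) and conditioned (c) prompt embeddings used for classifier-free guidance, which is what the removed debug print was inspecting. A generic sketch of how such a uc/c pair is consumed downstream, not InvokeAI's exact pipeline code:

import torch

def apply_cfg(noise_uc: torch.Tensor, noise_c: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # standard classifier-free guidance: move the prediction away from the
    # unconditioned result and toward the prompt-conditioned one
    return noise_uc + guidance_scale * (noise_c - noise_uc)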
@@ -346,11 +345,11 @@ class LatentsToImageInvocation(BaseInvocation):
 
         # TODO: this only really needs the vae
         model_info = choose_model(context.services.model_manager, self.model)
-        model: StableDiffusionGeneratorPipeline = model_info['model']
 
-        with torch.inference_mode():
-            np_image = model.decode_latents(latents)
-            image = model.numpy_to_pil(np_image)[0]
+        with model_info.context as model:
+            with torch.inference_mode():
+                np_image = model.decode_latents(latents)
+                image = model.numpy_to_pil(np_image)[0]
 
         image_type = ImageType.RESULT
         image_name = context.services.images.create_name(
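The substantive change in this hunk: rather than reaching into model_info['model'] directly, the model is now borrowed through model_info.context, so the cache controls how long it stays loaded. The real class lives in InvokeAI's model manager; the sketch below only illustrates the shape of the API the hunk relies on, and the class and method names are assumptions:

from contextlib import contextmanager

class ModelInfoSketch:
    """Illustrative stand-in for the record returned by choose_model()."""

    def __init__(self, model, cache):
        self._model = model
        self._cache = cache  # hypothetical cache handle with lock()/unlock()

    @property
    @contextmanager
    def context(self):
        # check the model out of the cache for the duration of the block,
        # then let the cache reclaim or offload it afterwards
        self._cache.lock(self._model)
        try:
            yield self._model
        finally:
            self._cache.unlock(self._model)

Because context is a property returning a context manager, "with model_info.context as model:" reads exactly as in the hunk above.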
@@ -54,6 +54,10 @@ def get_model_manager(config: Args, logger: types.ModuleType) -> ModelManager:
     precision = choose_precision(device)
     dtype = torch.float32 if precision=='float32' \
         else torch.float16
 
+    max_cache_size = config.max_cache_size \
+        if hasattr(config,'max_cache_size') \
+        else config.max_loaded_models * 2.5
+
     model_manager = ModelManager(
         config.conf,
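The fallback arithmetic lets configs that predate max_cache_size keep working: the cache budget is derived from the legacy max_loaded_models setting at 2.5 (presumably GB) per model. A quick standalone check of the same expression, with SimpleNamespace standing in for the real Args config:

from types import SimpleNamespace

# legacy config: only max_loaded_models is present
old_cfg = SimpleNamespace(max_loaded_models=3)
max_cache_size = old_cfg.max_cache_size if hasattr(old_cfg, 'max_cache_size') \
    else old_cfg.max_loaded_models * 2.5
print(max_cache_size)  # 7.5

# new config: an explicit max_cache_size takes precedence
new_cfg = SimpleNamespace(max_loaded_models=3, max_cache_size=6.0)
max_cache_size = new_cfg.max_cache_size if hasattr(new_cfg, 'max_cache_size') \
    else new_cfg.max_loaded_models * 2.5
print(max_cache_size)  # 6.0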
@@ -557,12 +557,6 @@ class Args(object):
             default=False,
             help="Check for and blur potentially NSFW images. Use --no-nsfw_checker to disable.",
         )
-        model_group.add_argument(
-            "--autoimport",
-            default=None,
-            type=str,
-            help="(DEPRECATED - NONFUNCTIONAL). Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly",
-        )
         model_group.add_argument(
             "--autoconvert",
             default=None,
@@ -765,6 +759,19 @@ class Args(object):
             action="store_true",
             help="Start InvokeAI GUI",
         )
+        deprecated_group.add_argument(
+            "--autoimport",
+            default=None,
+            type=str,
+            help="(DEPRECATED - NONFUNCTIONAL). Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly",
+        )
+        deprecated_group.add_argument(
+            "--max_loaded_models",
+            dest="max_loaded_models",
+            type=int,
+            default=3,
+            help="Maximum number of models to keep in RAM cache (deprecated - use max_cache_size)",
+        )
         return parser
 
    # This creates the parser that processes commands on the invoke> command line
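Together with the previous hunk, this moves --autoimport (and --max_loaded_models) into deprecated_group instead of deleting them, so old command lines keep parsing. argparse groups only affect how --help output is organized; the values still land in the parsed namespace, where newer code can remap or ignore them. A minimal standalone illustration, not the real Args class:

import argparse

parser = argparse.ArgumentParser()
deprecated_group = parser.add_argument_group('DEPRECATED OPTIONS')
deprecated_group.add_argument(
    "--max_loaded_models",
    dest="max_loaded_models",
    type=int,
    default=3,
    help="Maximum number of models to keep in RAM cache (deprecated - use max_cache_size)",
)

# an old invocation still parses cleanly...
opts = parser.parse_args(["--max_loaded_models", "4"])
# ...and downstream code can translate it, as get_model_manager() does above
max_cache_size = getattr(opts, 'max_cache_size', opts.max_loaded_models * 2.5)
print(max_cache_size)  # 10.0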
@@ -213,7 +213,9 @@ class ModelCache(object):
            # this will remove older cached models until
            # there is sufficient room to load the requested model
            self._make_cache_room(key, model_type)
 
+           # clean memory to make MemoryUsage() more accurate
+           gc.collect()
            with MemoryUsage() as usage:
                model = self._load_model_from_storage(
                    repo_id_or_path=repo_id_or_path,
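Why gc.collect() first: MemoryUsage presumably samples memory on enter and exit to estimate the newly loaded model's footprint, and dead objects left over from a just-evicted model would inflate that delta. A rough sketch of such a context manager; the real one belongs to InvokeAI's model cache, and the psutil-based RSS measurement here is an assumption:

import gc
import psutil

class MemoryUsageSketch:
    """Measure resident-memory growth across a with block."""

    def __enter__(self):
        self._start = psutil.Process().memory_info().rss
        return self

    def __exit__(self, *exc):
        self.mem_used = psutil.Process().memory_info().rss - self._start
        return False  # never swallow exceptions

gc.collect()  # drop dead objects first so the delta reflects only the new load
with MemoryUsageSketch() as usage:
    data = bytearray(50 * 1024 * 1024)  # stand-in for loading a model
print(f"~{usage.mem_used / 1e6:.0f} MB used")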