invoke-ai/InvokeAI (mirror of https://github.com/invoke-ai/InvokeAI)

commit 4649920074
parent 667171ed90

    adjust t2i to work with new model structure
@@ -206,7 +206,6 @@ class TextToLatentsInvocation(BaseInvocation):
 
     def get_conditioning_data(self, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
         uc, c, extra_conditioning_info = get_uc_and_c_and_ec(self.prompt, model=model)
-        print(f'DEBUG: uc.dtype={uc.dtype}, c.dtype={c.dtype}')
         conditioning_data = ConditioningData(
             uc,
             c,
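Note: uc and c are the unconditioned and conditioned text embeddings consumed by classifier-free guidance, which is why their dtypes had to agree; the removed line was a leftover debug print. As an illustration only (not InvokeAI's exact code), guidance typically combines the two noise predictions like this:

    import torch

    def cfg_combine(noise_uc: torch.Tensor, noise_c: torch.Tensor,
                    guidance_scale: float) -> torch.Tensor:
        # Standard classifier-free guidance: start from the unconditioned
        # prediction and push toward the conditioned one.
        return noise_uc + guidance_scale * (noise_c - noise_uc)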
@@ -346,8 +345,8 @@ class LatentsToImageInvocation(BaseInvocation):
 
         # TODO: this only really needs the vae
         model_info = choose_model(context.services.model_manager, self.model)
-        model: StableDiffusionGeneratorPipeline = model_info['model']
-        with torch.inference_mode():
-            np_image = model.decode_latents(latents)
-            image = model.numpy_to_pil(np_image)[0]
+        with model_info.context as model:
+            with torch.inference_mode():
+                np_image = model.decode_latents(latents)
+                image = model.numpy_to_pil(np_image)[0]
 
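Note: this is the core of the "new model structure" change. Instead of pulling the model object out of a dict (model_info['model']), the invocation enters model_info.context, which scopes the model's residence on the execution device to the block. A minimal sketch of that pattern, with a hypothetical load() method and device handling (the real context manager may differ):

    from contextlib import contextmanager

    @contextmanager
    def model_context(model_info):
        model = model_info.load()   # assumption: fetch from the RAM cache
        model.to("cuda")            # assumption: move weights to the GPU
        try:
            yield model
        finally:
            model.to("cpu")         # return weights to the cache on exit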
@@ -55,6 +55,10 @@ def get_model_manager(config: Args, logger: types.ModuleType) -> ModelManager:
     dtype = torch.float32 if precision=='float32' \
             else torch.float16
 
+    max_cache_size = config.max_cache_size \
+        if hasattr(config,'max_cache_size') \
+        else config.max_loaded_models * 2.5
+
     model_manager = ModelManager(
         config.conf,
         precision=dtype,
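Note: max_cache_size falls back to the deprecated max_loaded_models knob when the new setting is absent. A runnable check of the fallback arithmetic (units assumed to be GB of RAM):

    class Config:
        max_loaded_models = 3  # deprecated knob; no max_cache_size set

    config = Config()
    max_cache_size = config.max_cache_size \
        if hasattr(config, 'max_cache_size') \
        else config.max_loaded_models * 2.5
    print(max_cache_size)  # 7.5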
@@ -557,12 +557,6 @@ class Args(object):
             default=False,
             help="Check for and blur potentially NSFW images. Use --no-nsfw_checker to disable.",
         )
-        model_group.add_argument(
-            "--autoimport",
-            default=None,
-            type=str,
-            help="(DEPRECATED - NONFUNCTIONAL). Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly",
-        )
         model_group.add_argument(
             "--autoconvert",
             default=None,
@@ -765,6 +759,19 @@ class Args(object):
             action="store_true",
             help="Start InvokeAI GUI",
         )
+        deprecated_group.add_argument(
+            "--autoimport",
+            default=None,
+            type=str,
+            help="(DEPRECATED - NONFUNCTIONAL). Check the indicated directory for .ckpt/.safetensors weights files at startup and import directly",
+        )
+        deprecated_group.add_argument(
+            "--max_loaded_models",
+            dest="max_loaded_models",
+            type=int,
+            default=3,
+            help="Maximum number of models to keep in RAM cache (deprecated - use max_cache_size)",
+        )
         return parser
 
 # This creates the parser that processes commands on the invoke> command line
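Note: together with the previous hunk, this moves --autoimport out of model_group and parks it under deprecated_group alongside --max_loaded_models, so old command lines keep parsing while the help text steers users to max_cache_size. A self-contained sketch of the argparse pattern, assuming deprecated_group is an ordinary argument group:

    import argparse

    parser = argparse.ArgumentParser()
    deprecated_group = parser.add_argument_group("DEPRECATED OPTIONS")
    deprecated_group.add_argument(
        "--max_loaded_models",
        dest="max_loaded_models",
        type=int,
        default=3,
        help="(deprecated - use max_cache_size)",
    )
    args = parser.parse_args(["--max_loaded_models", "4"])
    assert args.max_loaded_models == 4  # old invocations still work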
@@ -214,6 +214,8 @@ class ModelCache(object):
         # there is sufficient room to load the requested model
         self._make_cache_room(key, model_type)
 
+        # clean memory to make MemoryUsage() more accurate
+        gc.collect()
         with MemoryUsage() as usage:
             model = self._load_model_from_storage(
                 repo_id_or_path=repo_id_or_path,
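Note: calling gc.collect() before entering MemoryUsage matters because freed-but-uncollected Python objects would otherwise distort the apparent memory delta of the model load. A hypothetical sketch of what such a context manager could look like, built on psutil (the real MemoryUsage class may differ):

    import gc
    import psutil  # assumption: resident-set size via psutil

    class MemoryUsage:
        def __enter__(self):
            self.start = psutil.Process().memory_info().rss
            return self

        def __exit__(self, *exc):
            self.size = psutil.Process().memory_info().rss - self.start
            return False

    gc.collect()  # as in the diff: settle memory before measuring
    with MemoryUsage() as usage:
        data = [0] * 1_000_000  # stand-in for loading model weights
    print(usage.size)           # bytes attributable to the block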