use heuristic to select RAM cache size during headless install; blackified

This commit is contained in:
Lincoln Stein 2023-09-25 19:18:58 -04:00 committed by Kent Keirsey
parent 0c97a1e7e7
commit d59e534cad
10 changed files with 39 additions and 40 deletions

View File

@ -344,12 +344,12 @@ class InvokeAiInstance:
auto_install = True auto_install = True
sys.argv = new_argv sys.argv = new_argv
import requests # to catch download exceptions
import messages import messages
import requests # to catch download exceptions
auto_install = auto_install or messages.user_wants_auto_configuration() auto_install = auto_install or messages.user_wants_auto_configuration()
if auto_install: if auto_install:
sys.argv.append('--yes') sys.argv.append("--yes")
else: else:
messages.introduction() messages.introduction()

View File

@ -7,7 +7,7 @@ import os
import platform import platform
from pathlib import Path from pathlib import Path
from prompt_toolkit import prompt, HTML from prompt_toolkit import HTML, prompt
from prompt_toolkit.completion import PathCompleter from prompt_toolkit.completion import PathCompleter
from prompt_toolkit.validation import Validator from prompt_toolkit.validation import Validator
from rich import box, print from rich import box, print
@ -97,13 +97,17 @@ def user_wants_auto_configuration() -> bool:
padding=(1, 1), padding=(1, 1),
) )
) )
choice = prompt(HTML("Choose <b>&lt;a&gt;</b>utomatic or <b>&lt;m&gt;</b>anual configuration [a/m] (a): "), choice = (
prompt(
HTML("Choose <b>&lt;a&gt;</b>utomatic or <b>&lt;m&gt;</b>anual configuration [a/m] (a): "),
validator=Validator.from_callable( validator=Validator.from_callable(
lambda n: n=='' or n.startswith(('a', 'A', 'm', 'M')), lambda n: n == "" or n.startswith(("a", "A", "m", "M")), error_message="Please select 'a' or 'm'"
error_message="Please select 'a' or 'm'"
), ),
) or 'a' )
return choice.lower().startswith('a') or "a"
)
return choice.lower().startswith("a")
def dest_path(dest=None) -> Path: def dest_path(dest=None) -> Path:
""" """

View File

@ -70,7 +70,6 @@ def get_literal_fields(field) -> list[Any]:
config = InvokeAIAppConfig.get_config() config = InvokeAIAppConfig.get_config()
Model_dir = "models" Model_dir = "models"
Default_config_file = config.model_conf_path Default_config_file = config.model_conf_path
SD_Configs = config.legacy_conf_path SD_Configs = config.legacy_conf_path
@ -458,7 +457,7 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
) )
self.add_widget_intelligent( self.add_widget_intelligent(
npyscreen.TitleFixedText, npyscreen.TitleFixedText,
name="Model RAM cache size (GB). Make this at least large enough to hold a single full model.", name="Model RAM cache size (GB). Make this at least large enough to hold a single full model (2GB for SD-1, 6GB for SDXL).",
begin_entry_at=0, begin_entry_at=0,
editable=False, editable=False,
color="CONTROL", color="CONTROL",
@ -651,8 +650,19 @@ def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Nam
return editApp.new_opts() return editApp.new_opts()
def default_ramcache() -> float:
"""Run a heuristic for the default RAM cache based on installed RAM."""
# Note that on my 64 GB machine, psutil.virtual_memory().total gives 62 GB,
# So we adjust everthing down a bit.
return (
15.0 if MAX_RAM >= 60 else 7.5 if MAX_RAM >= 30 else 4 if MAX_RAM >= 14 else 2.1
) # 2.1 is just large enough for sd 1.5 ;-)
def default_startup_options(init_file: Path) -> Namespace: def default_startup_options(init_file: Path) -> Namespace:
opts = InvokeAIAppConfig.get_config() opts = InvokeAIAppConfig.get_config()
opts.ram = default_ramcache()
return opts return opts

View File

@ -175,10 +175,7 @@ class InvokeAIDiffuserComponent:
dim=0, dim=0,
), ),
} }
( (encoder_hidden_states, encoder_attention_mask,) = self._concat_conditionings_for_batch(
encoder_hidden_states,
encoder_attention_mask,
) = self._concat_conditionings_for_batch(
conditioning_data.unconditioned_embeddings.embeds, conditioning_data.unconditioned_embeddings.embeds,
conditioning_data.text_embeddings.embeds, conditioning_data.text_embeddings.embeds,
) )
@ -240,10 +237,7 @@ class InvokeAIDiffuserComponent:
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0 wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
if wants_cross_attention_control: if wants_cross_attention_control:
( (unconditioned_next_x, conditioned_next_x,) = self._apply_cross_attention_controlled_conditioning(
unconditioned_next_x,
conditioned_next_x,
) = self._apply_cross_attention_controlled_conditioning(
sample, sample,
timestep, timestep,
conditioning_data, conditioning_data,
@ -251,10 +245,7 @@ class InvokeAIDiffuserComponent:
**kwargs, **kwargs,
) )
elif self.sequential_guidance: elif self.sequential_guidance:
( (unconditioned_next_x, conditioned_next_x,) = self._apply_standard_conditioning_sequentially(
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning_sequentially(
sample, sample,
timestep, timestep,
conditioning_data, conditioning_data,
@ -262,10 +253,7 @@ class InvokeAIDiffuserComponent:
) )
else: else:
( (unconditioned_next_x, conditioned_next_x,) = self._apply_standard_conditioning(
unconditioned_next_x,
conditioned_next_x,
) = self._apply_standard_conditioning(
sample, sample,
timestep, timestep,
conditioning_data, conditioning_data,

View File

@ -470,10 +470,7 @@ class TextualInversionDataset(Dataset):
if self.center_crop: if self.center_crop:
crop = min(img.shape[0], img.shape[1]) crop = min(img.shape[0], img.shape[1])
( (h, w,) = (
h,
w,
) = (
img.shape[0], img.shape[0],
img.shape[1], img.shape[1],
) )