(config) fix config file creation in edge cases

if the config directory is missing, initialize it using the standard
process of copying it over, instead of failing to create the config file

this can happen if the user is re-running the config script in a directory which
already has the init file, but no configs dir
This commit is contained in:
Eugene Brodsky 2023-01-15 04:09:13 -05:00
parent 02c530e200
commit d047e070b8
2 changed files with 43 additions and 34 deletions

View File

@ -91,7 +91,7 @@ class Installer:
venv_dir = self.mktemp_venv()
pip = get_venv_pip(Path(venv_dir.name))
cmd = [pip, "install", "--require-virtualenv"]
cmd = [pip, "install", "--require-virtualenv", "--use-pep517"]
cmd.extend(self.reqs)
try:

View File

@ -31,8 +31,9 @@ from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from tqdm import tqdm
from transformers import CLIPTokenizer, CLIPTextModel, AutoProcessor, CLIPSegForImageSegmentation
from tempfile import TemporaryFile
from ldm.invoke.globals import Globals, global_cache_dir
from ldm.invoke.globals import Globals, global_cache_dir, global_config_dir
from ldm.invoke.readline import generic_completer
warnings.filterwarnings('ignore')
@ -51,10 +52,9 @@ Weights_dir = 'ldm/stable-diffusion-v1/'
# the initial "configs" dir is now bundled with the `config` package
Dataset_path = Path(__file__).parent / "configs" / 'INITIAL_MODELS.yaml'
Default_config_file = Path(__file__).parent / "configs" / 'models.yaml'
SD_Configs = Path(__file__).parent / "configs" / 'stable-diffusion'
assert os.path.exists(Dataset_path),"The configs directory cannot be found. Please run this script from within the invokeai runtime directory."
Default_config_file = Path (global_config_dir()) / 'models.yaml'
SD_Configs = Path (global_config_dir()) / 'stable-diffusion'
Datasets = OmegaConf.load(Dataset_path)
completer = generic_completer(['yes','no'])
@ -434,20 +434,6 @@ def hf_download_with_resume(repo_id:str, model_dir:str, model_name:str, access_t
return None
return model_dest
# -----------------------------------------------------------------------------------
#---------------------------------------------
def is_huggingface_authenticated():
# huggingface_hub 0.10 API isn't great for this, it could be OSError, ValueError,
# maybe other things, not all end-user-friendly.
# noinspection PyBroadException
try:
response = hf_whoami()
if response.get('id') is not None:
return True
except Exception:
pass
return False
#---------------------------------------------
def download_with_progress_bar(model_url:str, model_dest:str, label:str='the'):
try:
@ -465,33 +451,56 @@ def download_with_progress_bar(model_url:str, model_dest:str, label:str='the'):
print(traceback.format_exc())
#---------------------------------------------
def update_config_file(successfully_downloaded:dict,opt:dict):
config_file = opt.config_file or Default_config_file
config_file = os.path.normpath(os.path.join(Globals.root,config_file))
def update_config_file(successfully_downloaded:dict, opt: dict):
yaml = new_config_file_contents(successfully_downloaded,config_file)
config_file = Path(opt.config_file) if opt.config_file is not None else Default_config_file
# In some cases (incomplete setup, etc), the default configs directory might be missing.
# Create it if it doesn't exist.
# this check is ignored if opt.config_file is specified - user is assumed to know what they
# are doing if they are passing a custom config file from elsewhere.
if config_file is Default_config_file and not config_file.parent.exists():
configs_src = Path(__file__).parent / "configs"
configs_dest = Path(Globals.root) / "configs"
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
yaml = new_config_file_contents(successfully_downloaded, config_file)
try:
backup = None
if os.path.exists(config_file):
print(f'** {config_file} exists. Renaming to {config_file}.orig')
os.replace(config_file,f'{config_file}.orig')
tmpfile = os.path.join(os.path.dirname(config_file),'new_config.tmp')
with open(tmpfile, 'w') as outfile:
outfile.write(Config_preamble)
outfile.write(yaml)
os.replace(tmpfile,config_file)
print(f'** {config_file.name} exists. Renaming to {config_file.stem}.yaml.orig')
backup=config_file.with_suffix(".yaml.orig")
## Ugh. Windows is unable to overwrite an existing backup file, raises a WinError 183
if sys.platform == "win32" and backup.is_file():
backup.unlink()
config_file.rename(backup)
with TemporaryFile() as tmp:
tmp.write(Config_preamble.encode())
tmp.write(yaml.encode())
with open(str(config_file.expanduser().resolve()), 'wb') as new_config:
tmp.seek(0)
new_config.write(tmp.read())
except Exception as e:
print(f'**Error creating config file {config_file}: {str(e)} **')
if backup is not None:
print(f'restoring previous config file')
## workaround, for WinError 183, see above
if sys.platform == "win32" and config_file.is_file():
config_file.unlink()
backup.rename(config_file)
return
print(f'Successfully created new configuration file {config_file}')
#---------------------------------------------
def new_config_file_contents(successfully_downloaded:dict, config_file:str)->str:
if os.path.exists(config_file):
conf = OmegaConf.load(config_file)
def new_config_file_contents(successfully_downloaded: dict, config_file: Path) -> str:
if config_file.exists():
conf = OmegaConf.load(str(config_file.expanduser().resolve()))
else:
conf = OmegaConf.create()
@ -796,7 +805,7 @@ def main():
'-c',
dest='config_file',
type=str,
default='./configs/models.yaml',
default=None,
help='path to configuration file to create')
parser.add_argument('--root_dir',
dest='root',