improve behavior of preload_models.py

- NEVER overwrite user's existing models.yaml
- Instead, merge its contents into new config file,
  and rename original to models.yaml.orig (with
  message)
- models.yaml has been removed from repository and renamed
  models.yaml.example
This commit is contained in:
Lincoln Stein 2022-10-31 11:08:19 -04:00
parent 5a95ce5625
commit 90cd791e76
2 changed files with 24 additions and 2 deletions

View File

@ -13,3 +13,16 @@ stable-diffusion-1.5:
height: 512
vae: ./models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
default: true
stable-diffusion-1.4:
description: Stable Diffusion inference model version 1.4
config: configs/stable-diffusion/v1-inference.yaml
weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
width: 512
height: 512
inpainting-1.5:
weights: models/ldm/stable-diffusion-v1/sd-v1-5-inpainting.ckpt
config: configs/stable-diffusion/v1-inpainting-inference.yaml
vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
description: RunwayML SD 1.5 model optimized for inpainting

View File

@ -321,22 +321,31 @@ def download_with_resume(repo_id:str, model_name:str, access_token:str)->bool:
#---------------------------------------------
def update_config_file(successfully_downloaded:dict):
try:
yaml = new_config_file_contents(successfully_downloaded)
try:
if os.path.exists(Config_file):
print(f'** {Config_file} exists. Renaming to {Config_file}.orig')
os.rename(Config_file,f'{Config_file}.orig')
tmpfile = os.path.join(os.path.dirname(Config_file),'new_config.tmp')
with open(tmpfile, 'w') as outfile:
outfile.write(Config_preamble)
outfile.write(yaml)
os.rename(tmpfile,Config_file)
except Exception as e:
print(f'**Error creating config file {Config_file}: {str(e)} **')
return
print(f'Successfully created new configuration file {Config_file}')
#---------------------------------------------
def new_config_file_contents(successfully_downloaded:dict)->str:
if os.path.exists(Config_file):
conf = OmegaConf.load(Config_file)
else:
conf = OmegaConf.create()
# find the VAE file, if there is one
vae = None