Reference model from configs/models.yaml

By supplying --model (defaulting to stable-diffusion-1.4) a user can specify which model to load.
Width/Height/Config Location/Weights Location are referenced from configs/models.yaml
This commit is contained in:
David Wager 2022-09-01 19:04:31 +01:00 committed by GitHub
parent db580ccefd
commit d319b8a762
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -13,11 +13,13 @@ import ldm.dream.readline
from ldm.dream.pngwriter import PngWriter, PromptFormatter
from ldm.dream.server import DreamServer, ThreadingDreamServer
from ldm.dream.image_util import make_grid
from omegaconf import OmegaConf
def main():
"""Initialize command-line parsers and the diffusion model"""
arg_parser = create_argv_parser()
opt = arg_parser.parse_args()
"""
if opt.laion400m:
# defaults suitable to the older latent diffusion weights
width = 256
@ -33,6 +35,15 @@ def main():
weights = opt.weights
else:
weights = f'models/ldm/stable-diffusion-v1/{opt.weights}.ckpt'
"""
try:
models = OmegaConf.load('configs/models.yaml')
width = models[opt.model].width
height = models[opt.model].height
config = models[opt.model].config
weights = models[opt.model].weights
except (FileNotFoundError, IOError, KeyError) as e:
print(f'{e}. Aborting.')
print('* Initializing, be patient...\n')
sys.path.append('.')
@ -426,6 +437,11 @@ def create_argv_parser():
default='model',
help='Indicates the Stable Diffusion model to use.',
)
parser.add_argument(
'--model',
default='stable-diffusion-1.4',
help='Indicates which Diffusion model to load.',
)
return parser
@ -540,4 +556,4 @@ def create_cmd_parser():
if __name__ == '__main__':
main()
main()