From d589ad96aac873e273435e1271628c488d15254a Mon Sep 17 00:00:00 2001
From: Lincoln Stein <lincoln.stein@gmail.com>
Date: Fri, 10 Feb 2023 15:06:37 -0500
Subject: [PATCH] fix two bugs in conversion of inpaint models from ckpt to
 diffusers models

- If CLI asked to convert the currently loaded model, the model would crash
  on the first rendering. CLI will now refuse to convert a model loaded
  in memory (probably a good idea in any case).

- CLI will offer the `v1-inpainting-inference.yaml` as the configuration
  file when importing an inpainting a .ckpt or .safetensors file that
  has "inpainting" in the name. Otherwise it offers `v1-inference.yaml`
  as the default.
---
 ldm/invoke/CLI.py | 30 +++++++++++++++++++-----------
 1 file changed, 19 insertions(+), 11 deletions(-)

diff --git a/ldm/invoke/CLI.py b/ldm/invoke/CLI.py
index fd61c7c8bf..2c204bb33f 100644
--- a/ldm/invoke/CLI.py
+++ b/ldm/invoke/CLI.py
@@ -58,12 +58,9 @@ def main():
     print(f'>> Internet connectivity is {Globals.internet_available}')
 
     if not args.conf:
-        if not os.path.exists(os.path.join(Globals.root,'configs','models.yaml')):
-            report_model_error(opt, e)
-            # print(f"\n** Error. The file {os.path.join(Globals.root,'configs','models.yaml')} could not be found.")
-            # print('** Please check the location of your invokeai directory and use the --root_dir option to point to the correct path.')
-            # print('** This script will now exit.')
-            # sys.exit(-1)
+        config_file = os.path.join(Globals.root,'configs','models.yaml')
+        if not os.path.exists(config_file):
+            report_model_error(opt, FileNotFoundError(f"The file {config_file} could not be found."))
 
     print(f'>> {ldm.invoke.__app_name__}, version {ldm.invoke.__version__}')
     print(f'>> InvokeAI runtime directory is "{Globals.root}"')
@@ -658,7 +655,9 @@ def import_ckpt_model(path_or_url: Union[Path, str], gen, opt, completer) -> Opt
         model_description=default_description
     )
     config_file = None
-    default = Path(Globals.root,'configs/stable-diffusion/v1-inference.yaml')
+    default = Path(Globals.root,'configs/stable-diffusion/v1-inpainting-inference.yaml') \
+        if re.search('inpaint',default_name, flags=re.IGNORECASE) \
+           else Path(Globals.root,'configs/stable-diffusion/v1-inference.yaml')
 
     completer.complete_extensions(('.yaml','.yml'))
     completer.set_line(str(default))
@@ -709,12 +708,21 @@ def _get_model_name_and_desc(model_manager,completer,model_name:str='',model_des
     model_description = input(f'Description for this model [{model_description}]: ').strip() or model_description
     return model_name, model_description
 
-def optimize_model(model_name_or_path:str, gen, opt, completer):
+def _is_inpainting(model_name_or_path: str)->bool:
+    if re.search('inpaint',model_name_or_path, flags=re.IGNORECASE):
+        return not input('Is this an inpainting model? [y] ').startswith(('n','N'))
+    else:
+        return not input('Is this an inpainting model? [n] ').startswith(('y','Y'))
+
+def optimize_model(model_name_or_path: str, gen, opt, completer):
     manager = gen.model_manager
     ckpt_path = None
     original_config_file = None
 
-    if (model_info := manager.model_info(model_name_or_path)):
+    if model_name_or_path == gen.model_name:
+        print("** Can't convert the active model. !switch to another model first. **")
+        return
+    elif (model_info := manager.model_info(model_name_or_path)):
         if 'weights' in model_info:
             ckpt_path = Path(model_info['weights'])
             original_config_file = Path(model_info['config'])
@@ -731,7 +739,7 @@ def optimize_model(model_name_or_path:str, gen, opt, completer):
             ckpt_path.stem,
             f'Converted model {ckpt_path.stem}'
         )
-        is_inpainting = input('Is this an inpainting model? [n] ').startswith(('y','Y'))
+        is_inpainting = _is_inpainting(model_name_or_path)
         original_config_file = Path(
             'configs',
             'stable-diffusion',
@@ -950,7 +958,7 @@ def prepare_image_metadata(
             print(f'** The filename format contains an unknown key \'{e.args[0]}\'. Will use {{prefix}}.{{seed}}.png\' instead')
             filename = f'{prefix}.{seed}.png'
         except IndexError:
-            print(f'** The filename format is broken or complete. Will use \'{{prefix}}.{{seed}}.png\' instead')
+            print("** The filename format is broken or complete. Will use '{prefix}.{seed}.png' instead")
             filename = f'{prefix}.{seed}.png'
 
     if opt.variation_amount > 0: