set readline root to ROOTDIR for model import

This commit is contained in:
Lincoln Stein 2022-11-28 18:34:42 +00:00
commit 84cd96decf
2 changed files with 9 additions and 4 deletions

View File

@ -180,7 +180,6 @@ class ModelCache(object):
attributes are incorrect or the model name is missing. attributes are incorrect or the model name is missing.
''' '''
omega = self.config omega = self.config
for field in ('description','weights','height','width','config'): for field in ('description','weights','height','width','config'):
assert field in model_attributes, f'required field {field} is missing' assert field in model_attributes, f'required field {field} is missing'
assert (clobber or model_name not in omega), f'attempt to overwrite existing model definition "{model_name}"' assert (clobber or model_name not in omega), f'attempt to overwrite existing model definition "{model_name}"'

View File

@ -13,6 +13,7 @@ import re
import atexit import atexit
from ldm.invoke.args import Args from ldm.invoke.args import Args
from ldm.invoke.concepts_lib import Concepts from ldm.invoke.concepts_lib import Concepts
from ldm.invoke.globals import Globals
# ---------------readline utilities--------------------- # ---------------readline utilities---------------------
try: try:
@ -133,7 +134,12 @@ class Completer(object):
self.matches= self._model_completions(text, state) self.matches= self._model_completions(text, state)
elif re.search(weight_regexp,buffer): elif re.search(weight_regexp,buffer):
self.matches = self._path_completions(text, state, WEIGHT_EXTENSIONS) self.matches = self._path_completions(
text,
state,
WEIGHT_EXTENSIONS,
default_dir=Globals.root,
)
elif re.search(text_regexp,buffer): elif re.search(text_regexp,buffer):
self.matches = self._path_completions(text, state, TEXT_EXTENSIONS) self.matches = self._path_completions(text, state, TEXT_EXTENSIONS)
@ -300,7 +306,7 @@ class Completer(object):
readline.redisplay() readline.redisplay()
self.linebuffer = None self.linebuffer = None
def _path_completions(self, text, state, extensions, shortcut_ok=True): def _path_completions(self, text, state, extensions, shortcut_ok=True, default_dir:str=''):
# separate the switch from the partial path # separate the switch from the partial path
match = re.search('^(-\w|--\w+=?)(.*)',text) match = re.search('^(-\w|--\w+=?)(.*)',text)
if match is None: if match is None:
@ -319,7 +325,7 @@ class Completer(object):
elif os.path.dirname(path) != '': elif os.path.dirname(path) != '':
dir = os.path.dirname(path) dir = os.path.dirname(path)
else: else:
dir = '' dir = default_dir if os.path.exists(default_dir) else ''
path= os.path.join(dir,path) path= os.path.join(dir,path)
dir_list = os.listdir(dir or '.') dir_list = os.listdir(dir or '.')