Bug Fix: Model import fixes (#1566)

These bug fixes address issues #1546 and #1547 .
This commit is contained in:
Damian Stewart 2022-11-28 21:46:27 +01:00 committed by GitHub
commit bc44ab786c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 17 additions and 9 deletions

View File

@ -840,6 +840,7 @@ class Generate:
model_data = cache.get_model(model_name) model_data = cache.get_model(model_name)
if model_data is None: # restore previous if model_data is None: # restore previous
model_data = cache.get_model(self.model_name) model_data = cache.get_model(self.model_name)
model_name = self.model_name # addresses Issue #1547
self.model = model_data['model'] self.model = model_data['model']
self.width = model_data['width'] self.width = model_data['width']

View File

@ -78,11 +78,12 @@ class ModelCache(object):
else: # we're about to load a new model, so potentially offload the least recently used one else: # we're about to load a new model, so potentially offload the least recently used one
try: try:
requested_model, width, height, hash = self._load_model(model_name) requested_model, width, height, hash = self._load_model(model_name)
self.models[model_name] = {} self.models[model_name] = {
self.models[model_name]['model'] = requested_model 'model': requested_model,
self.models[model_name]['width'] = width 'width': width,
self.models[model_name]['height'] = height 'height': height,
self.models[model_name]['hash'] = hash 'hash': hash,
}
except Exception as e: except Exception as e:
print(f'** model {model_name} could not be loaded: {str(e)}') print(f'** model {model_name} could not be loaded: {str(e)}')
@ -183,11 +184,11 @@ class ModelCache(object):
method will return True. Will fail with an assertion error if provided method will return True. Will fail with an assertion error if provided
attributes are incorrect or the model name is missing. attributes are incorrect or the model name is missing.
''' '''
omega = self.config
for field in ('description','weights','height','width','config'): for field in ('description','weights','height','width','config'):
assert field in model_attributes, f'required field {field} is missing' assert field in model_attributes, f'required field {field} is missing'
assert (clobber or model_name not in omega), f'attempt to overwrite existing model definition "{model_name}"' assert (clobber or model_name not in omega), f'attempt to overwrite existing model definition "{model_name}"'
omega = self.config
config = omega[model_name] if model_name in omega else {} config = omega[model_name] if model_name in omega else {}
for field in model_attributes: for field in model_attributes:
config[field] = model_attributes[field] config[field] = model_attributes[field]

View File

@ -13,6 +13,7 @@ import re
import atexit import atexit
from ldm.invoke.args import Args from ldm.invoke.args import Args
from ldm.invoke.concepts_lib import Concepts from ldm.invoke.concepts_lib import Concepts
from ldm.invoke.globals import Globals
# ---------------readline utilities--------------------- # ---------------readline utilities---------------------
try: try:
@ -133,7 +134,12 @@ class Completer(object):
self.matches= self._model_completions(text, state) self.matches= self._model_completions(text, state)
elif re.search(weight_regexp,buffer): elif re.search(weight_regexp,buffer):
self.matches = self._path_completions(text, state, WEIGHT_EXTENSIONS) self.matches = self._path_completions(
text,
state,
WEIGHT_EXTENSIONS,
default_dir=Globals.root,
)
elif re.search(text_regexp,buffer): elif re.search(text_regexp,buffer):
self.matches = self._path_completions(text, state, TEXT_EXTENSIONS) self.matches = self._path_completions(text, state, TEXT_EXTENSIONS)
@ -300,7 +306,7 @@ class Completer(object):
readline.redisplay() readline.redisplay()
self.linebuffer = None self.linebuffer = None
def _path_completions(self, text, state, extensions, shortcut_ok=True): def _path_completions(self, text, state, extensions, shortcut_ok=True, default_dir:str=''):
# separate the switch from the partial path # separate the switch from the partial path
match = re.search('^(-\w|--\w+=?)(.*)',text) match = re.search('^(-\w|--\w+=?)(.*)',text)
if match is None: if match is None:
@ -319,7 +325,7 @@ class Completer(object):
elif os.path.dirname(path) != '': elif os.path.dirname(path) != '':
dir = os.path.dirname(path) dir = os.path.dirname(path)
else: else:
dir = '' dir = default_dir if os.path.exists(default_dir) else ''
path= os.path.join(dir,path) path= os.path.join(dir,path)
dir_list = os.listdir(dir or '.') dir_list = os.listdir(dir or '.')