prevent crash when switching to an invalid model

This commit is contained in:
Lincoln Stein 2022-11-09 02:07:13 +00:00
parent b17ca0a5e7
commit 71ee44a827
2 changed files with 13 additions and 2 deletions

View File

@ -802,6 +802,10 @@ class Generate:
# the model cache does the loading and offloading # the model cache does the loading and offloading
cache = self.model_cache cache = self.model_cache
if not cache.valid_model(model_name):
print(f'** "{model_name}" is not a known model name. Please check your models.yaml file')
return self.model
cache.print_vram_usage() cache.print_vram_usage()
# have to get rid of all references to model in order # have to get rid of all references to model in order

View File

@ -41,15 +41,22 @@ class ModelCache(object):
self.stack = [] # this is an LRU FIFO self.stack = [] # this is an LRU FIFO
self.current_model = None self.current_model = None
def valid_model(self, model_name:str)->bool:
'''
Given a model name, returns True if it is a valid
identifier.
'''
return model_name in self.config
def get_model(self, model_name:str): def get_model(self, model_name:str):
''' '''
Given a model named identified in models.yaml, return Given a model named identified in models.yaml, return
the model object. If in RAM will load into GPU VRAM. the model object. If in RAM will load into GPU VRAM.
If on disk, will load from there. If on disk, will load from there.
''' '''
if model_name not in self.config: if not self.valid_model(model_name):
print(f'** "{model_name}" is not a known model name. Please check your models.yaml file') print(f'** "{model_name}" is not a known model name. Please check your models.yaml file')
return None return self.current_model
if self.current_model != model_name: if self.current_model != model_name:
if model_name not in self.models: # make room for a new one if model_name not in self.models: # make room for a new one