check the function signatures and add some easy annotations

Signed-off-by: devops117 <55235206+devops117@users.noreply.github.com>
This commit is contained in:
devops117 2022-11-22 11:42:46 +05:30 committed by Lincoln Stein
parent c15b839dd4
commit 229f782e3b
2 changed files with 15 additions and 15 deletions

View File

@ -557,8 +557,8 @@ def del_config(model_name:str, gen, opt, completer):
if model_name == current_model:
print("** Can't delete active model. !switch to another model first. **")
return
if gen.model_cache.del_model(model_name):
gen.model_cache.commit(opt.conf)
gen.model_cache.del_model(model_name)
gen.model_cache.commit(opt.conf)
print(f'** {model_name} deleted')
completer.del_model(model_name)

View File

@ -108,7 +108,7 @@ class ModelCache(object):
if self.config[model_name].get('default'):
return model_name
def set_default_model(self,model_name:str):
def set_default_model(self,model_name:str) -> None:
'''
Set the default model. The change will not take
effect until you call model_cache.commit()
@ -147,7 +147,7 @@ class ModelCache(object):
'description': description,
}}
def print_models(self):
def print_models(self) -> None:
'''
Print a table of models, their descriptions, and load status
'''
@ -158,7 +158,7 @@ class ModelCache(object):
line = f'\033[1m{line}\033[0m')
print(line)
def del_model(self, model_name:str):
def del_model(self, model_name:str) -> None:
'''
Delete the named model.
'''
@ -167,7 +167,7 @@ class ModelCache(object):
if model_name in self.stack:
self.stack.remove(model_name)
def add_model(self, model_name:str, model_attributes:dict, clobber=False):
def add_model(self, model_name:str, model_attributes:dict, clobber=False) -> None:
'''
Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite.
@ -269,7 +269,7 @@ class ModelCache(object):
return model, width, height, model_hash
def offload_model(self, model_name:str):
def offload_model(self, model_name:str) -> None:
'''
Offload the indicated model to CPU. Will call
_make_cache_room() to free space if needed.
@ -306,7 +306,7 @@ class ModelCache(object):
else:
print('>> Model Scanned. OK!!')
def _make_cache_room(self):
def _make_cache_room(self) -> None:
num_loaded_models = len(self.models)
if num_loaded_models >= self.max_loaded_models:
least_recent_model = self._pop_oldest_model()
@ -315,11 +315,11 @@ class ModelCache(object):
del self.models[least_recent_model]
gc.collect()
def print_vram_usage(self):
def print_vram_usage(self) -> None:
if self._has_cuda:
print('>> Current VRAM usage: ','%4.2fG' % (torch.cuda.memory_allocated() / 1e9))
def commit(self,config_file_path:str):
def commit(self,config_file_path:str) -> None:
'''
Write current configuration out to the indicated file.
'''
@ -330,7 +330,7 @@ class ModelCache(object):
outfile.write(yaml_str)
os.replace(tmpfile,config_file_path)
def preamble(self):
def preamble(self) -> str:
'''
Returns the preamble for the config file.
'''
@ -344,7 +344,7 @@ class ModelCache(object):
# was trained on.
''')
def _invalidate_cached_model(self,model_name:str):
def _invalidate_cached_model(self,model_name:str) -> None:
self.offload_model(model_name)
if model_name in self.stack:
self.stack.remove(model_name)
@ -376,7 +376,7 @@ class ModelCache(object):
'''
return self.stack.pop(0)
def _push_newest_model(self,model_name:str):
def _push_newest_model(self,model_name:str) -> None:
'''
Maintain a simple FIFO. First element is always the
least recent, and last element is always the most recent.
@ -385,10 +385,10 @@ class ModelCache(object):
self.stack.remove(model_name)
self.stack.append(model_name)
def _has_cuda(self):
def _has_cuda(self) -> bool:
return self.device.type == 'cuda'
def _cached_sha256(self,path,data):
def _cached_sha256(self,path,data) -> str | bytes:
dirname = os.path.dirname(path)
basename = os.path.basename(path)
base, _ = os.path.splitext(basename)