check the function signatures and add some easy annotations

Signed-off-by: devops117 <55235206+devops117@users.noreply.github.com>
This commit is contained in:
devops117 2022-11-22 11:42:46 +05:30 committed by Lincoln Stein
parent c15b839dd4
commit 229f782e3b
2 changed files with 15 additions and 15 deletions

View File

@ -557,8 +557,8 @@ def del_config(model_name:str, gen, opt, completer):
if model_name == current_model: if model_name == current_model:
print("** Can't delete active model. !switch to another model first. **") print("** Can't delete active model. !switch to another model first. **")
return return
if gen.model_cache.del_model(model_name): gen.model_cache.del_model(model_name)
gen.model_cache.commit(opt.conf) gen.model_cache.commit(opt.conf)
print(f'** {model_name} deleted') print(f'** {model_name} deleted')
completer.del_model(model_name) completer.del_model(model_name)

View File

@ -108,7 +108,7 @@ class ModelCache(object):
if self.config[model_name].get('default'): if self.config[model_name].get('default'):
return model_name return model_name
def set_default_model(self,model_name:str): def set_default_model(self,model_name:str) -> None:
''' '''
Set the default model. The change will not take Set the default model. The change will not take
effect until you call model_cache.commit() effect until you call model_cache.commit()
@ -147,7 +147,7 @@ class ModelCache(object):
'description': description, 'description': description,
}} }}
def print_models(self): def print_models(self) -> None:
''' '''
Print a table of models, their descriptions, and load status Print a table of models, their descriptions, and load status
''' '''
@ -158,7 +158,7 @@ class ModelCache(object):
line = f'\033[1m{line}\033[0m') line = f'\033[1m{line}\033[0m')
print(line) print(line)
def del_model(self, model_name:str): def del_model(self, model_name:str) -> None:
''' '''
Delete the named model. Delete the named model.
''' '''
@ -167,7 +167,7 @@ class ModelCache(object):
if model_name in self.stack: if model_name in self.stack:
self.stack.remove(model_name) self.stack.remove(model_name)
def add_model(self, model_name:str, model_attributes:dict, clobber=False): def add_model(self, model_name:str, model_attributes:dict, clobber=False) -> None:
''' '''
Update the named model with a dictionary of attributes. Will fail with an Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite. assertion error if the name already exists. Pass clobber=True to overwrite.
@ -269,7 +269,7 @@ class ModelCache(object):
return model, width, height, model_hash return model, width, height, model_hash
def offload_model(self, model_name:str): def offload_model(self, model_name:str) -> None:
''' '''
Offload the indicated model to CPU. Will call Offload the indicated model to CPU. Will call
_make_cache_room() to free space if needed. _make_cache_room() to free space if needed.
@ -306,7 +306,7 @@ class ModelCache(object):
else: else:
print('>> Model Scanned. OK!!') print('>> Model Scanned. OK!!')
def _make_cache_room(self): def _make_cache_room(self) -> None:
num_loaded_models = len(self.models) num_loaded_models = len(self.models)
if num_loaded_models >= self.max_loaded_models: if num_loaded_models >= self.max_loaded_models:
least_recent_model = self._pop_oldest_model() least_recent_model = self._pop_oldest_model()
@ -315,11 +315,11 @@ class ModelCache(object):
del self.models[least_recent_model] del self.models[least_recent_model]
gc.collect() gc.collect()
def print_vram_usage(self): def print_vram_usage(self) -> None:
if self._has_cuda: if self._has_cuda:
print('>> Current VRAM usage: ','%4.2fG' % (torch.cuda.memory_allocated() / 1e9)) print('>> Current VRAM usage: ','%4.2fG' % (torch.cuda.memory_allocated() / 1e9))
def commit(self,config_file_path:str): def commit(self,config_file_path:str) -> None:
''' '''
Write current configuration out to the indicated file. Write current configuration out to the indicated file.
''' '''
@ -330,7 +330,7 @@ class ModelCache(object):
outfile.write(yaml_str) outfile.write(yaml_str)
os.replace(tmpfile,config_file_path) os.replace(tmpfile,config_file_path)
def preamble(self): def preamble(self) -> str:
''' '''
Returns the preamble for the config file. Returns the preamble for the config file.
''' '''
@ -344,7 +344,7 @@ class ModelCache(object):
# was trained on. # was trained on.
''') ''')
def _invalidate_cached_model(self,model_name:str): def _invalidate_cached_model(self,model_name:str) -> None:
self.offload_model(model_name) self.offload_model(model_name)
if model_name in self.stack: if model_name in self.stack:
self.stack.remove(model_name) self.stack.remove(model_name)
@ -376,7 +376,7 @@ class ModelCache(object):
''' '''
return self.stack.pop(0) return self.stack.pop(0)
def _push_newest_model(self,model_name:str): def _push_newest_model(self,model_name:str) -> None:
''' '''
Maintain a simple FIFO. First element is always the Maintain a simple FIFO. First element is always the
least recent, and last element is always the most recent. least recent, and last element is always the most recent.
@ -385,10 +385,10 @@ class ModelCache(object):
self.stack.remove(model_name) self.stack.remove(model_name)
self.stack.append(model_name) self.stack.append(model_name)
def _has_cuda(self): def _has_cuda(self) -> bool:
return self.device.type == 'cuda' return self.device.type == 'cuda'
def _cached_sha256(self,path,data): def _cached_sha256(self,path,data) -> str | bytes:
dirname = os.path.dirname(path) dirname = os.path.dirname(path)
basename = os.path.basename(path) basename = os.path.basename(path)
base, _ = os.path.splitext(basename) base, _ = os.path.splitext(basename)