mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
check the function signatures and add some easy annotations
Signed-off-by: devops117 <55235206+devops117@users.noreply.github.com>
This commit is contained in:
parent c15b839dd4
commit 229f782e3b
@@ -557,8 +557,8 @@ def del_config(model_name:str, gen, opt, completer):
     if model_name == current_model:
         print("** Can't delete active model. !switch to another model first. **")
         return
-    if gen.model_cache.del_model(model_name):
-        gen.model_cache.commit(opt.conf)
-        print(f'** {model_name} deleted')
-        completer.del_model(model_name)
+    gen.model_cache.del_model(model_name)
+    gen.model_cache.commit(opt.conf)
+    print(f'** {model_name} deleted')
+    completer.del_model(model_name)
 
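A note on the caller change above, with a minimal runnable sketch (the function body here is illustrative, not from the commit): del_model is now annotated -> None, and a call that returns None can never satisfy an if, so the guard around the follow-up calls was dropped and its body dedented.

def del_model(model_name: str) -> None:
    print(f'** {model_name} deleted')   # side effect only; nothing is returned

result = del_model('some-model')
assert result is None                   # a `-> None` function yields None
if result:                              # always false, hence the removed guard
    print('unreachable')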
@@ -108,7 +108,7 @@ class ModelCache(object):
             if self.config[model_name].get('default'):
                 return model_name
 
-    def set_default_model(self,model_name:str):
+    def set_default_model(self,model_name:str) -> None:
         '''
         Set the default model. The change will not take
         effect until you call model_cache.commit()
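The body of set_default_model sits outside this hunk. As a hedged sketch only (the config layout is assumed from the surrounding diff), an implementation consistent with the docstring clears the 'default' flag everywhere, sets it on one entry, and leaves the disk write to a later commit():

def set_default_model(config: dict, model_name: str) -> None:
    # Illustrative only: mutate the in-memory config; commit() writes it out later.
    assert model_name in config, f"unknown model '{model_name}'"
    for name in config:
        config[name].pop('default', None)
    config[model_name]['default'] = True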
@@ -147,7 +147,7 @@ class ModelCache(object):
                 'description': description,
             }}
 
-    def print_models(self):
+    def print_models(self) -> None:
         '''
         Print a table of models, their descriptions, and load status
         '''
@@ -158,7 +158,7 @@ class ModelCache(object):
                 line = f'\033[1m{line}\033[0m'
             print(line)
 
-    def del_model(self, model_name:str):
+    def del_model(self, model_name:str) -> None:
         '''
         Delete the named model.
         '''
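The context lines above are the row-printing tail of print_models: a row is plain text unless it needs highlighting, in which case the ANSI escapes \033[1m and \033[0m wrap it in bold. A standalone sketch of the idiom (names and fields are illustrative):

def format_row(name: str, description: str, active: bool) -> str:
    line = f'{name:25s} {description}'
    # \033[1m enables bold; \033[0m resets all terminal attributes
    return f'\033[1m{line}\033[0m' if active else line

print(format_row('stable-diffusion-1.4', 'default model', active=True))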
@@ -167,7 +167,7 @@ class ModelCache(object):
         if model_name in self.stack:
             self.stack.remove(model_name)
 
-    def add_model(self, model_name:str, model_attributes:dict, clobber=False):
+    def add_model(self, model_name:str, model_attributes:dict, clobber=False) -> None:
         '''
         Update the named model with a dictionary of attributes. Will fail with an
         assertion error if the name already exists. Pass clobber=True to overwrite.
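A hedged sketch of the clobber convention the docstring names (the method body is not part of this hunk): refuse to overwrite an existing entry unless the caller opts in.

def add_entry(config: dict, name: str, attributes: dict, clobber: bool = False) -> None:
    # Mirrors 'will fail with an assertion error ... pass clobber=True to overwrite'
    assert clobber or name not in config, f"model '{name}' already exists; pass clobber=True to overwrite"
    config[name] = attributes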
@@ -269,7 +269,7 @@ class ModelCache(object):
 
         return model, width, height, model_hash
 
-    def offload_model(self, model_name:str):
+    def offload_model(self, model_name:str) -> None:
         '''
         Offload the indicated model to CPU. Will call
         _make_cache_room() to free space if needed.
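The offload_model body is also outside this hunk. Under the assumption that the cached objects are torch modules, offloading amounts to moving weights to the CPU and letting the allocator reclaim VRAM; a minimal sketch:

import gc
import torch

def offload(model: torch.nn.Module) -> None:
    model.to('cpu')               # nn.Module.to() moves parameters and buffers in place
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # hand freed blocks back to the driver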
@@ -306,7 +306,7 @@ class ModelCache(object):
         else:
             print('>> Model Scanned. OK!!')
 
-    def _make_cache_room(self):
+    def _make_cache_room(self) -> None:
         num_loaded_models = len(self.models)
         if num_loaded_models >= self.max_loaded_models:
             least_recent_model = self._pop_oldest_model()
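Read together with the next hunk, _make_cache_room is a small least-recently-used eviction: once the number of loaded models reaches the cap, the oldest name comes off the front of the stack and its entry is dropped. A compact sketch with illustrative names:

def make_room(models: dict, stack: list, max_loaded: int) -> None:
    if len(models) >= max_loaded:
        oldest = stack.pop(0)     # index 0 holds the least recently used name
        if oldest in models:
            del models[oldest]    # the real method follows this with gc.collect()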
@@ -315,11 +315,11 @@ class ModelCache(object):
             del self.models[least_recent_model]
             gc.collect()
 
-    def print_vram_usage(self):
+    def print_vram_usage(self) -> None:
         if self._has_cuda:
             print('>> Current VRAM usage: ','%4.2fG' % (torch.cuda.memory_allocated() / 1e9))
 
-    def commit(self,config_file_path:str):
+    def commit(self,config_file_path:str) -> None:
         '''
         Write current configuration out to the indicated file.
         '''
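A hedged footnote on print_vram_usage: torch.cuda.memory_allocated() reports the bytes currently occupied by tensors on the active device, so dividing by 1e9 yields decimal gigabytes, formatted to two decimal places:

import torch

if torch.cuda.is_available():
    print('>> Current VRAM usage: ', '%4.2fG' % (torch.cuda.memory_allocated() / 1e9))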
@@ -330,7 +330,7 @@ class ModelCache(object):
             outfile.write(yaml_str)
         os.replace(tmpfile,config_file_path)
 
-    def preamble(self):
+    def preamble(self) -> str:
         '''
         Returns the preamble for the config file.
         '''
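The context lines above show the write-then-replace idiom inside commit(): the YAML is written to a temporary file and os.replace() swaps it into place, so a crash mid-write never leaves a truncated config behind. A self-contained sketch (the tmpfile naming is illustrative):

import os

def commit_config(yaml_str: str, config_file_path: str) -> None:
    tmpfile = config_file_path + '.new'
    with open(tmpfile, 'w') as outfile:
        outfile.write(yaml_str)
    # os.replace() is atomic when both paths are on the same filesystem
    os.replace(tmpfile, config_file_path)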
@@ -344,7 +344,7 @@ class ModelCache(object):
         # was trained on.
         ''')
 
-    def _invalidate_cached_model(self,model_name:str):
+    def _invalidate_cached_model(self,model_name:str) -> None:
         self.offload_model(model_name)
         if model_name in self.stack:
             self.stack.remove(model_name)
@@ -376,7 +376,7 @@ class ModelCache(object):
         '''
         return self.stack.pop(0)
 
-    def _push_newest_model(self,model_name:str):
+    def _push_newest_model(self,model_name:str) -> None:
         '''
         Maintain a simple FIFO. First element is always the
         least recent, and last element is always the most recent.
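The docstring calls the stack a simple FIFO; the remove-then-append step (visible in the next hunk) actually keeps it in least-recently-used order: touching a name sends it to the back, so index 0 is always the least recent. A runnable sketch:

def push_newest(stack: list, model_name: str) -> None:
    if model_name in stack:
        stack.remove(model_name)
    stack.append(model_name)

stack = ['a', 'b', 'c']
push_newest(stack, 'a')
assert stack == ['b', 'c', 'a']   # 'b' is now the least recent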
@@ -385,10 +385,10 @@ class ModelCache(object):
             self.stack.remove(model_name)
         self.stack.append(model_name)
 
-    def _has_cuda(self):
+    def _has_cuda(self) -> bool:
         return self.device.type == 'cuda'
 
-    def _cached_sha256(self,path,data):
+    def _cached_sha256(self,path,data) -> str | bytes:
         dirname = os.path.dirname(path)
         basename = os.path.basename(path)
         base, _ = os.path.splitext(basename)
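Two hedged footnotes on this final hunk. First, the PEP 604 union in '-> str | bytes' only evaluates on Python 3.10+; on older interpreters it needs deferred annotations or typing.Union, as below. Second, _has_cuda is a plain method, so the bare truthiness test 'if self._has_cuda:' in the earlier hunk is always true; presumably the intended call is 'self._has_cuda()'.

from __future__ import annotations   # defers evaluation, so `str | bytes` parses on 3.7+
from typing import Union

def cached_sha256(path: str, data: bytes) -> Union[str, bytes]:
    ...                              # equivalent spelling that needs no deferral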
|