mirror of https://github.com/invoke-ai/InvokeAI
Add Model Scanning
This commit is contained in:
parent c212b74990
commit 2d6e0baa87
@@ -30,6 +30,7 @@ test-tube>=0.7.5
 torch-fidelity
 torchmetrics
 transformers==4.21.*
+picklescan==0.0.5
 git+https://github.com/openai/CLIP.git@main#egg=clip
 git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion
 git+https://github.com/invoke-ai/clipseg.git@relaxed-python-requirement#egg=clipseg
@@ -19,6 +19,7 @@ torch-fidelity
 torchvision==0.13.1 ; platform_system == 'Darwin'
 torchvision==0.13.1+cu116 ; platform_system == 'Linux' or platform_system == 'Windows'
 transformers
+picklescan==0.0.5
 https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip
 https://github.com/TencentARC/GFPGAN/archive/2eac2033893ca7f427f4035d80fe95b92649ac56.zip
 https://github.com/invoke-ai/k-diffusion/archive/7f16b2c33411f26b3eae78d10648d625cb0c1095.zip
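Both requirements files gain picklescan, the library this commit uses to inspect pickle-based checkpoint files for malicious payloads before they are deserialized. As a minimal sketch of the same call the loader makes (the 'model.ckpt' path is a placeholder, not a file that ships with the repo):

# minimal picklescan check, mirroring the new scan_model method below;
# 'model.ckpt' is a placeholder path
from picklescan.scanner import scan_file_path

result = scan_file_path('model.ckpt')
if result.infected_files != 0:
    print(f'issues found: {result.issues_count}')
else:
    print('checkpoint looks clean')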
@@ -200,6 +200,9 @@ class ModelCache(object):
         width = mconfig.width
         height = mconfig.height

+        # scan model
+        self.scan_model(model_name, weights)
+
         print(f'>> Loading {model_name} from {weights}')

         # for usage statistics
@@ -275,6 +278,32 @@ class ModelCache(object):
             gc.collect()
             if self._has_cuda():
                 torch.cuda.empty_cache()

+    def scan_model(self, model_name, checkpoint):
+        # scan model
+        from picklescan.scanner import scan_file_path
+        import sys
+        print(f'>> Scanning Model: {model_name}')
+        scan_result = scan_file_path(checkpoint)
+        if scan_result.infected_files != 0:
+            if scan_result.infected_files == 1:
+                print(f'\n### Issues Found In Model: {scan_result.issues_count}')
+                print('### WARNING: The model you are trying to load seems to be infected.')
+                print('### For your safety, InvokeAI will not load this model.')
+                print('### Please use checkpoints from trusted sources.')
+                print("### Exiting InvokeAI")
+                sys.exit()
+            else:
+                print('\n### WARNING: InvokeAI was unable to scan the model you are using.')
+                from ldm.util import ask_user
+                model_safe_check_fail = ask_user('Do you want to continue loading the model?', ['y', 'n'])
+                if model_safe_check_fail.lower() == 'y':
+                    pass
+                else:
+                    print("### Exiting InvokeAI")
+                    sys.exit()
+        else:
+            print('>> Model Scanned. OK!!')
+
     def _make_cache_room(self):
         num_loaded_models = len(self.models)
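For orientation, the load path this hunk creates is: scan first, deserialize only if the scan comes back clean. A condensed, hypothetical sketch of that ordering (the load_checkpoint helper below is illustrative only and is not part of the commit):

# hypothetical condensed flow, not the committed ModelCache code;
# the checkpoint is scanned before torch.load ever unpickles it
from picklescan.scanner import scan_file_path
import sys
import torch

def load_checkpoint(model_name, weights):
    print(f'>> Scanning Model: {model_name}')
    result = scan_file_path(weights)
    if result.infected_files != 0:
        print(f'### Issues Found In Model: {result.issues_count}')
        sys.exit()
    print(f'>> Loading {model_name} from {weights}')
    return torch.load(weights, map_location='cpu')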
@@ -235,3 +235,12 @@ def rand_perlin_2d(shape, res, device, fade = lambda t: 6*t**5 - 15*t**4 + 10*t**3):
     n11 = dot(tile_grads([1, None], [1, None]), [-1,-1]).to(device)
     t = fade(grid[:shape[0], :shape[1]])
     return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]).to(device)
+
+def ask_user(question: str, answers: list):
+    from itertools import chain, repeat
+    user_prompt = f'\n>> {question} {answers}: '
+    invalid_answer_msg = 'Invalid answer. Please try again.'
+    pose_question = chain([user_prompt], repeat('\n'.join([invalid_answer_msg, user_prompt])))
+    user_answers = map(input, pose_question)
+    valid_response = next(filter(answers.__contains__, user_answers))
+    return valid_response
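ask_user lazily maps input() over an endless chain of prompts and returns the first reply found in answers, so it keeps re-asking until a valid answer is typed; the membership check is exact, so 'Y' is rejected when ['y', 'n'] is passed. A small usage sketch mirroring the call made in ModelCache.scan_model:

# usage sketch; same call as in ModelCache.scan_model
from ldm.util import ask_user

reply = ask_user('Do you want to continue loading the model?', ['y', 'n'])
if reply == 'n':
    print('### Exiting InvokeAI')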