Add Model Scanning

This commit is contained in:
blessedcoolant 2022-11-16 22:34:37 +13:00 committed by Lincoln Stein
parent c212b74990
commit 2d6e0baa87
4 changed files with 40 additions and 0 deletions

View File

@ -30,6 +30,7 @@ test-tube>=0.7.5
torch-fidelity
torchmetrics
transformers==4.21.*
picklescan==0.0.5
git+https://github.com/openai/CLIP.git@main#egg=clip
git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion
git+https://github.com/invoke-ai/clipseg.git@relaxed-python-requirement#egg=clipseg

View File

@ -19,6 +19,7 @@ torch-fidelity
torchvision==0.13.1 ; platform_system == 'Darwin'
torchvision==0.13.1+cu116 ; platform_system == 'Linux' or platform_system == 'Windows'
transformers
picklescan==0.0.5
https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip
https://github.com/TencentARC/GFPGAN/archive/2eac2033893ca7f427f4035d80fe95b92649ac56.zip
https://github.com/invoke-ai/k-diffusion/archive/7f16b2c33411f26b3eae78d10648d625cb0c1095.zip

View File

@ -200,6 +200,9 @@ class ModelCache(object):
width = mconfig.width
height = mconfig.height
# scan model
self.scan_model(model_name, weights)
print(f'>> Loading {model_name} from {weights}')
# for usage statistics
@ -275,6 +278,32 @@ class ModelCache(object):
gc.collect()
if self._has_cuda():
torch.cuda.empty_cache()
def scan_model(self, model_name, checkpoint):
# scan model
from picklescan.scanner import scan_file_path
import sys
print(f'>> Scanning Model: {model_name}')
scan_result = scan_file_path(checkpoint)
if scan_result.infected_files != 0:
if scan_result.infected_files == 1:
print(f'\n### Issues Found In Model: {scan_result.issues_count}')
print('### WARNING: The model you are trying to load seems to be infected.')
print('### For your safety, InvokeAI will not load this model.')
print('### Please use checkpoints from trusted sources.')
print("### Exiting InvokeAI")
sys.exit()
else:
print('\n### WARNING: InvokeAI was unable to scan the model you are using.')
from ldm.util import ask_user
model_safe_check_fail = ask_user('Do you want to to continue loading the model?', ['y', 'n'])
if model_safe_check_fail.lower() == 'y':
pass
else:
print("### Exiting InvokeAI")
sys.exit()
else:
print('>> Model Scanned. OK!!')
def _make_cache_room(self):
num_loaded_models = len(self.models)

View File

@ -235,3 +235,12 @@ def rand_perlin_2d(shape, res, device, fade = lambda t: 6*t**5 - 15*t**4 + 10*t*
n11 = dot(tile_grads([1, None], [1, None]), [-1,-1]).to(device)
t = fade(grid[:shape[0], :shape[1]])
return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]).to(device)
def ask_user(question: str, answers: list):
from itertools import chain, repeat
user_prompt = f'\n>> {question} {answers}: '
invalid_answer_msg = 'Invalid answer. Please try again.'
pose_question = chain([user_prompt], repeat('\n'.join([invalid_answer_msg, user_prompt])))
user_answers = map(input, pose_question)
valid_response = next(filter(answers.__contains__, user_answers))
return valid_response