diff --git a/environments-and-requirements/requirements-base.txt b/environments-and-requirements/requirements-base.txt index 3f04e2a720..b8937605e3 100644 --- a/environments-and-requirements/requirements-base.txt +++ b/environments-and-requirements/requirements-base.txt @@ -30,6 +30,7 @@ test-tube>=0.7.5 torch-fidelity torchmetrics transformers==4.21.* +picklescan==0.0.5 git+https://github.com/openai/CLIP.git@main#egg=clip git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion git+https://github.com/invoke-ai/clipseg.git@relaxed-python-requirement#egg=clipseg diff --git a/installer/requirements.in b/installer/requirements.in index 71967f1cf2..cb1b7090ff 100644 --- a/installer/requirements.in +++ b/installer/requirements.in @@ -19,6 +19,7 @@ torch-fidelity torchvision==0.13.1 ; platform_system == 'Darwin' torchvision==0.13.1+cu116 ; platform_system == 'Linux' or platform_system == 'Windows' transformers +picklescan==0.0.5 https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip https://github.com/TencentARC/GFPGAN/archive/2eac2033893ca7f427f4035d80fe95b92649ac56.zip https://github.com/invoke-ai/k-diffusion/archive/7f16b2c33411f26b3eae78d10648d625cb0c1095.zip diff --git a/ldm/invoke/model_cache.py b/ldm/invoke/model_cache.py index d4007c46de..a0130e740b 100644 --- a/ldm/invoke/model_cache.py +++ b/ldm/invoke/model_cache.py @@ -200,6 +200,9 @@ class ModelCache(object): width = mconfig.width height = mconfig.height + # scan model + self.scan_model(model_name, weights) + print(f'>> Loading {model_name} from {weights}') # for usage statistics @@ -275,6 +278,32 @@ class ModelCache(object): gc.collect() if self._has_cuda(): torch.cuda.empty_cache() + + def scan_model(self, model_name, checkpoint): + # scan model + from picklescan.scanner import scan_file_path + import sys + print(f'>> Scanning Model: {model_name}') + scan_result = scan_file_path(checkpoint) + if scan_result.infected_files != 0: + if scan_result.infected_files == 1: + print(f'\n### Issues Found In Model: {scan_result.issues_count}') + print('### WARNING: The model you are trying to load seems to be infected.') + print('### For your safety, InvokeAI will not load this model.') + print('### Please use checkpoints from trusted sources.') + print("### Exiting InvokeAI") + sys.exit() + else: + print('\n### WARNING: InvokeAI was unable to scan the model you are using.') + from ldm.util import ask_user + model_safe_check_fail = ask_user('Do you want to to continue loading the model?', ['y', 'n']) + if model_safe_check_fail.lower() == 'y': + pass + else: + print("### Exiting InvokeAI") + sys.exit() + else: + print('>> Model Scanned. OK!!') def _make_cache_room(self): num_loaded_models = len(self.models) diff --git a/ldm/util.py b/ldm/util.py index f3ef0b606b..478c66b8b5 100644 --- a/ldm/util.py +++ b/ldm/util.py @@ -235,3 +235,12 @@ def rand_perlin_2d(shape, res, device, fade = lambda t: 6*t**5 - 15*t**4 + 10*t* n11 = dot(tile_grads([1, None], [1, None]), [-1,-1]).to(device) t = fade(grid[:shape[0], :shape[1]]) return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]).to(device) + +def ask_user(question: str, answers: list): + from itertools import chain, repeat + user_prompt = f'\n>> {question} {answers}: ' + invalid_answer_msg = 'Invalid answer. Please try again.' + pose_question = chain([user_prompt], repeat('\n'.join([invalid_answer_msg, user_prompt]))) + user_answers = map(input, pose_question) + valid_response = next(filter(answers.__contains__, user_answers)) + return valid_response