diff --git a/.gitignore b/.gitignore index 4cf76e78e7..b274423c95 100644 --- a/.gitignore +++ b/.gitignore @@ -235,3 +235,6 @@ update.sh # this may be present if the user created a venv invokeai + +# no longer stored in source directory +models \ No newline at end of file diff --git a/ldm/__init__.py b/ldm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/ldm/invoke/__init__.py b/ldm/invoke/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/ldm/invoke/txt2mask.py b/ldm/invoke/txt2mask.py index 6bdd1814d1..1fdec33685 100644 --- a/ldm/invoke/txt2mask.py +++ b/ldm/invoke/txt2mask.py @@ -29,10 +29,12 @@ work fine. import torch import numpy as np +import os from clipseg.clipseg import CLIPDensePredT from einops import rearrange, repeat from PIL import Image, ImageOps from torchvision import transforms +from ldm.invoke.globals import Globals CLIP_VERSION = 'ViT-B/16' CLIPSEG_WEIGHTS = 'models/clipseg/clipseg_weights/rd64-uni.pth' @@ -80,7 +82,11 @@ class Txt2Mask(object): self.model.eval() # initially we keep everything in cpu to conserve space self.model.to('cpu') - self.model.load_state_dict(torch.load(CLIPSEG_WEIGHTS_REFINED if refined else CLIPSEG_WEIGHTS, map_location=torch.device('cpu')), strict=False) + self.model.load_state_dict(torch.load(os.path.join(Globals.root,CLIPSEG_WEIGHTS_REFINED) + if refined + else os.path.join(Globals.root,CLIPSEG_WEIGHTS), + map_location=torch.device('cpu')), strict=False + ) @torch.no_grad() def segment(self, image, prompt:str) -> SegmentedGrayscale: diff --git a/ldm/models/__init__.py b/ldm/models/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/ldm/modules/__init__.py b/ldm/modules/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/setup.py b/setup.py index bc6c7cdc31..101a1ca47e 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ from setuptools import setup, find_packages setup( name='invoke-ai', - version='2.1.3', + version='2.1.4', description='InvokeAI text to image generation toolkit', packages=find_packages(), install_requires=[