From 018ccebd6f1ac7cdb1017bb4866e776e3c02adba Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 4 Dec 2023 19:07:25 -0500 Subject: [PATCH] make ModelLocalSource comparisons work across platforms --- .../model_install/model_install_base.py | 24 +++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index baed791e61..7f359772b3 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -31,19 +31,35 @@ class UnknownInstallJobException(Exception): class StringLikeSource(BaseModel): - """Base class for model sources, implements functions that lets the source be sorted and indexed.""" + """ + Base class for model sources, implements functions that lets the source be sorted and indexed. + + These shenanigans let this stuff work: + + source1 = LocalModelSource(path='C:/users/mort/foo.safetensors') + mydict = {source1: 'model 1'} + assert mydict['C:/users/mort/foo.safetensors'] == 'model 1' + assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1' + + source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors')) + assert source1 == source2 + assert source1 == 'C:/users/mort/foo.safetensors' + """ def __hash__(self) -> int: """Return hash of the path field, for indexing.""" return hash(str(self)) - def __lt__(self, other: Any) -> int: + def __lt__(self, other: object) -> int: """Return comparison of the stringified version, for sorting.""" return str(self) < str(other) - def __eq__(self, other: Any) -> bool: + def __eq__(self, other: object) -> bool: """Return equality on the stringified version.""" - return str(self) == str(other) + if isinstance(other, Path): + return str(self) == other.as_posix() + else: + return str(self) == str(other) class LocalModelSource(StringLikeSource):