make ModelLocalSource comparisons work across platforms

This commit is contained in:
Lincoln Stein 2023-12-04 19:07:25 -05:00
parent 620b2d477a
commit 018ccebd6f

View File

@ -31,19 +31,35 @@ class UnknownInstallJobException(Exception):
class StringLikeSource(BaseModel):
"""Base class for model sources, implements functions that lets the source be sorted and indexed."""
"""
Base class for model sources, implements functions that lets the source be sorted and indexed.
These shenanigans let this stuff work:
source1 = LocalModelSource(path='C:/users/mort/foo.safetensors')
mydict = {source1: 'model 1'}
assert mydict['C:/users/mort/foo.safetensors'] == 'model 1'
assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1'
source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors'))
assert source1 == source2
assert source1 == 'C:/users/mort/foo.safetensors'
"""
def __hash__(self) -> int:
"""Return hash of the path field, for indexing."""
return hash(str(self))
def __lt__(self, other: Any) -> int:
def __lt__(self, other: object) -> int:
"""Return comparison of the stringified version, for sorting."""
return str(self) < str(other)
def __eq__(self, other: Any) -> bool:
def __eq__(self, other: object) -> bool:
"""Return equality on the stringified version."""
return str(self) == str(other)
if isinstance(other, Path):
return str(self) == other.as_posix()
else:
return str(self) == str(other)
class LocalModelSource(StringLikeSource):