mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
blackify
This commit is contained in:
@ -285,9 +285,13 @@ class DownloadQueue(DownloadQueueBase):
|
||||
if job.destination.is_dir():
|
||||
try:
|
||||
file_name = re.search('filename="(.+)"', resp.headers["Content-Disposition"]).group(1)
|
||||
self._validate_filename(job.destination, file_name) # will raise a ValueError exception if file_name is suspicious
|
||||
self._validate_filename(
|
||||
job.destination, file_name
|
||||
) # will raise a ValueError exception if file_name is suspicious
|
||||
except ValueError:
|
||||
self._logger.warning(f"Invalid filename '{file_name}' returned by source {job.source}, using last component of URL instead")
|
||||
self._logger.warning(
|
||||
f"Invalid filename '{file_name}' returned by source {job.source}, using last component of URL instead"
|
||||
)
|
||||
file_name = os.path.basename(job.source)
|
||||
except KeyError:
|
||||
file_name = os.path.basename(job.source)
|
||||
@ -341,9 +345,9 @@ class DownloadQueue(DownloadQueueBase):
|
||||
self._update_job_status(job, DownloadJobStatus.ERROR)
|
||||
|
||||
def _validate_filename(self, directory: str, filename: str):
|
||||
if '/' in filename:
|
||||
if "/" in filename:
|
||||
raise ValueError
|
||||
if filename.startswith('..'):
|
||||
if filename.startswith(".."):
|
||||
raise ValueError
|
||||
if len(filename) > os.pathconf(directory, "PC_NAME_MAX"):
|
||||
raise ValueError
|
||||
|
@ -45,34 +45,37 @@ for i in ["12345", "9999", "54321"]:
|
||||
|
||||
# here are some malformed URLs to test
|
||||
# missing the content length
|
||||
session.mount("http://www.civitai.com/models/missing",
|
||||
TestAdapter(
|
||||
b"Missing content length",
|
||||
headers={
|
||||
"Content-Disposition" : 'filename="missing.txt"',
|
||||
}
|
||||
),
|
||||
)
|
||||
session.mount(
|
||||
"http://www.civitai.com/models/missing",
|
||||
TestAdapter(
|
||||
b"Missing content length",
|
||||
headers={
|
||||
"Content-Disposition": 'filename="missing.txt"',
|
||||
},
|
||||
),
|
||||
)
|
||||
# not found
|
||||
session.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404))
|
||||
# specifies a content disposition that may overwrite files in the parent directory
|
||||
session.mount("http://www.civitai.com/models/malicious",
|
||||
TestAdapter(
|
||||
b"Malicious URL",
|
||||
headers={
|
||||
"Content-Disposition" : 'filename="../badness.txt"',
|
||||
}
|
||||
),
|
||||
)
|
||||
session.mount(
|
||||
"http://www.civitai.com/models/malicious",
|
||||
TestAdapter(
|
||||
b"Malicious URL",
|
||||
headers={
|
||||
"Content-Disposition": 'filename="../badness.txt"',
|
||||
},
|
||||
),
|
||||
)
|
||||
# Would create a path that is too long
|
||||
session.mount("http://www.civitai.com/models/long",
|
||||
TestAdapter(
|
||||
b"Malicious URL",
|
||||
headers={
|
||||
"Content-Disposition" : f'filename="{"i"*1000}"',
|
||||
}
|
||||
),
|
||||
)
|
||||
session.mount(
|
||||
"http://www.civitai.com/models/long",
|
||||
TestAdapter(
|
||||
b"Malicious URL",
|
||||
headers={
|
||||
"Content-Disposition": f'filename="{"i"*1000}"',
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# mock HuggingFace URLs
|
||||
hf_sd2_paths = [
|
||||
@ -230,16 +233,15 @@ def test_bad_urls():
|
||||
job = queue.create_download_job(source="http://www.civitai.com/models/malicious", destdir=tmpdir)
|
||||
queue.join()
|
||||
assert job.status == "completed"
|
||||
assert job.destination == Path(tmpdir, 'malicious')
|
||||
assert Path(tmpdir, 'malicious').exists()
|
||||
assert job.destination == Path(tmpdir, "malicious")
|
||||
assert Path(tmpdir, "malicious").exists()
|
||||
|
||||
# Nor a destination that would exceed the maximum filename or path length
|
||||
job = queue.create_download_job(source="http://www.civitai.com/models/long", destdir=tmpdir)
|
||||
queue.join()
|
||||
assert job.status == "completed"
|
||||
assert job.destination == Path(tmpdir, 'long')
|
||||
assert Path(tmpdir, 'long').exists()
|
||||
|
||||
assert job.destination == Path(tmpdir, "long")
|
||||
assert Path(tmpdir, "long").exists()
|
||||
|
||||
# create a foreign job which will be invalid for the queue
|
||||
bad_job = DownloadJobBase(id=999, source="mock", destination="mock")
|
||||
@ -341,4 +343,3 @@ def test_pause_cancel_url(): # this one is tricky because of potential race con
|
||||
).exists(), "cancelled file should be deleted"
|
||||
|
||||
assert len(queue.list_jobs()) == 0
|
||||
|
||||
|
Reference in New Issue
Block a user