diff --git a/tests/test_api.py b/tests/test_api.py index 40144ba..7197da9 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -9,7 +9,7 @@ from twitchdl import twitch from twitchdl.commands.download import get_clip_authenticated_url from twitchdl.commands.videos import get_game_ids from twitchdl.exceptions import ConsoleError -from twitchdl.playlists import enumerate_vods, load_m3u8, parse_playlists +from twitchdl.playlists import parse_playlists TEST_CHANNEL = "bananasaurus_rex" @@ -37,10 +37,6 @@ def test_get_videos(): playlist_txt = httpx.get(playlist_url).text assert playlist_txt.startswith("#EXTM3U") - playlist_m3u8 = load_m3u8(playlist_txt) - vods = enumerate_vods(playlist_m3u8) - assert vods[0].path == "0.ts" - def test_get_clips(): """ diff --git a/tests/test_download.py b/tests/test_download.py new file mode 100644 index 0000000..1a870e5 --- /dev/null +++ b/tests/test_download.py @@ -0,0 +1,90 @@ +from decimal import Decimal + +from twitchdl.commands.download import filter_vods +from twitchdl.playlists import Vod + +VODS = [ + Vod(index=1, path="1.ts", duration=Decimal("10.0")), + Vod(index=2, path="2.ts", duration=Decimal("10.0")), + Vod(index=3, path="3.ts", duration=Decimal("10.0")), + Vod(index=4, path="4.ts", duration=Decimal("10.0")), + Vod(index=5, path="5.ts", duration=Decimal("10.0")), + Vod(index=6, path="6.ts", duration=Decimal("10.0")), + Vod(index=7, path="7.ts", duration=Decimal("10.0")), + Vod(index=8, path="8.ts", duration=Decimal("10.0")), + Vod(index=9, path="9.ts", duration=Decimal("10.0")), + Vod(index=10, path="10.ts", duration=Decimal("3.15")), +] + + +def test_filter_vods_no_start_no_end(): + vods, start_offset, duration = filter_vods(VODS, None, None) + assert vods == VODS + assert start_offset == Decimal("0") + assert duration == Decimal("93.15") + + +def test_filter_vods_start(): + # Zero offset + vods, start_offset, duration = filter_vods(VODS, 0, None) + assert [v.index for v in vods] == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + assert start_offset == Decimal("0") + assert duration == Decimal("93.15") + + # Mid-vod + vods, start_offset, duration = filter_vods(VODS, 13, None) + assert [v.index for v in vods] == [2, 3, 4, 5, 6, 7, 8, 9, 10] + assert start_offset == Decimal("3.0") + assert duration == Decimal("80.15") + + # Between vods + vods, start_offset, duration = filter_vods(VODS, 50, None) + assert [v.index for v in vods] == [6, 7, 8, 9, 10] + assert start_offset == Decimal("0") + assert duration == Decimal("43.15") + + # Close to end + vods, start_offset, duration = filter_vods(VODS, 93, None) + assert [v.index for v in vods] == [10] + assert start_offset == Decimal("3.0") + assert duration == Decimal("0.15") + + +def test_filter_vods_end(): + # Zero offset + vods, start_offset, duration = filter_vods(VODS, 0, None) + assert [v.index for v in vods] == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + assert start_offset == Decimal("0") + assert duration == Decimal("93.15") + + # Mid-vod + vods, start_offset, duration = filter_vods(VODS, None, 56) + assert [v.index for v in vods] == [1, 2, 3, 4, 5, 6] + assert start_offset == Decimal("0") + assert duration == Decimal("56") + + # Between vods + vods, start_offset, duration = filter_vods(VODS, None, 30) + assert [v.index for v in vods] == [1, 2, 3] + assert start_offset == Decimal("0") + assert duration == Decimal("30") + + +def test_filter_vods_start_end(): + # Zero offset + vods, start_offset, duration = filter_vods(VODS, 0, 0) + assert [v.index for v in vods] == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + assert start_offset == Decimal("0") + assert duration == Decimal("93.15") + + # Mid-vod + vods, start_offset, duration = filter_vods(VODS, 32, 56) + assert [v.index for v in vods] == [4, 5, 6] + assert start_offset == Decimal("2") + assert duration == Decimal("24") + + # Between vods + vods, start_offset, duration = filter_vods(VODS, 20, 60) + assert [v.index for v in vods] == [3, 4, 5, 6] + assert start_offset == Decimal("0") + assert duration == Decimal("40") diff --git a/twitchdl/commands/download.py b/twitchdl/commands/download.py index ffa843e..7b6a9df 100644 --- a/twitchdl/commands/download.py +++ b/twitchdl/commands/download.py @@ -5,9 +5,10 @@ import re import shutil import subprocess import tempfile +from decimal import Decimal from os import path from pathlib import Path -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple from urllib.parse import urlencode, urlparse import click @@ -20,10 +21,11 @@ from twitchdl.exceptions import ConsoleError from twitchdl.http import download_all from twitchdl.output import blue, bold, green, print_log, yellow from twitchdl.playlists import ( - enumerate_vods, + Vod, load_m3u8, make_join_playlist, parse_playlists, + parse_vods, select_playlist, ) from twitchdl.twitch import Chapter, Clip, ClipAccessToken, Video @@ -286,14 +288,8 @@ def _download_video(video_id: str, args: DownloadOptions) -> None: print_log("Fetching playlist...") vods_text = http_get(playlist.url) vods_m3u8 = load_m3u8(vods_text) - vods, start_offset, end_offset = enumerate_vods(vods_m3u8, start, end) - vods_duration = sum(v.duration for v in vods) - duration = vods_duration - start_offset - end_offset - - print(f"{vods_duration=}") - print(f"{start_offset=}") - print(f"{end_offset=}") - print(f"{duration=}") + all_vods = parse_vods(vods_m3u8) + vods, start_offset, duration = filter_vods(all_vods, start, end) if args.dry_run: click.echo("Dry run, video not downloaded.") @@ -342,6 +338,33 @@ def _download_video(video_id: str, args: DownloadOptions) -> None: click.echo(f"\nDownloaded: {green(target)}") +def filter_vods( + vods: List[Vod], start: Optional[int], end: Optional[int] +) -> Tuple[List[Vod], Decimal, Decimal]: + vod_start = Decimal(0) + start_offset = Decimal(0) + end_offset = Decimal(0) + filtered_vods: List[Vod] = [] + + for vod in vods: + vod_end = vod_start + vod.duration + + if (not start or vod_end > start) and (not end or vod_start < end): + filtered_vods.append(vod) + + if start and start > vod_start and start < vod_end: + start_offset = start - vod_start + + if end and end > vod_start and end < vod_end: + end_offset = vod_end - end + + vod_start = vod_end + + filtered_vod_duration = sum(v.duration for v in filtered_vods) + duration = filtered_vod_duration - start_offset - end_offset + return filtered_vods, start_offset, duration + + def http_get(url: str) -> str: response = httpx.get(url) response.raise_for_status() diff --git a/twitchdl/playlists.py b/twitchdl/playlists.py index 72fdbc6..ecbee26 100644 --- a/twitchdl/playlists.py +++ b/twitchdl/playlists.py @@ -3,7 +3,8 @@ Parse and manipulate m3u8 playlists. """ from dataclasses import dataclass -from typing import Generator, List, Optional, OrderedDict, Tuple +from decimal import Decimal +from typing import Generator, List, Optional, OrderedDict import click import m3u8 @@ -27,7 +28,7 @@ class Vod: """Ordinal number of the VOD in the playlist""" path: str """Path part of the VOD URL""" - duration: int + duration: Decimal """Segment duration in seconds""" @@ -53,37 +54,11 @@ def load_m3u8(playlist_m3u8: str) -> m3u8.M3U8: return m3u8.loads(playlist_m3u8) -def enumerate_vods( - document: m3u8.M3U8, - start: Optional[int] = None, - end: Optional[int] = None, -) -> Tuple[List[Vod], int, int]: - """Extract VODs for download from document.""" - vods = [] - vod_start = 0 - - # How much time needs to be taken off by ffmpeg when joining - start_offset = 0 - end_offset = 0 - - for index, segment in enumerate(document.segments): - vod_end = vod_start + segment.duration - - start_condition = not start or vod_end > start - end_condition = not end or vod_start < end - - if start_condition and end_condition: - vods.append(Vod(index, segment.uri, segment.duration)) - - if start and start > vod_start and start < vod_end: - start_offset = start - vod_start - - if end and end > vod_start and end < vod_end: - end_offset = vod_end - end - - vod_start = vod_end - - return vods, int(start_offset), int(end_offset) +def parse_vods(document: m3u8.M3U8) -> List[Vod]: + return [ + Vod(index, segment.uri, Decimal(segment.duration)) + for index, segment in enumerate(document.segments) + ] def make_join_playlist(