Extract playlist parsing code

This commit is contained in:
Ivan Habunek 2024-04-06 10:15:26 +02:00
parent d6390bc7a2
commit a7ad4d8dcc
No known key found for this signature in database
GPG Key ID: F5F0623FF5EBCB3D
5 changed files with 166 additions and 107 deletions

View File

@ -1,9 +1,10 @@
import click
import logging
import platform
import re
import sys
import click
from twitchdl import __version__
from twitchdl.entities import DownloadOptions
from twitchdl.twitch import ClipsPeriod, VideosSort, VideosType
@ -258,7 +259,7 @@ def download(
overwrite: bool,
output: str,
quality: str | None,
rate_limit: str | None,
rate_limit: int | None,
start: int | None,
max_workers: int,
):

View File

@ -7,19 +7,24 @@ import subprocess
import tempfile
from os import path
from pathlib import Path
from typing import List, Optional, OrderedDict
from urllib.parse import urlencode, urlparse
import click
import httpx
import m3u8
from twitchdl import twitch, utils
from twitchdl.download import download_file
from twitchdl.entities import DownloadOptions
from twitchdl.exceptions import ConsoleError
from twitchdl.http import download_all
from twitchdl.output import blue, bold, dim, green, print_log, yellow
from twitchdl.output import blue, bold, green, print_log, yellow
from twitchdl.playlists import (
enumerate_vods,
load_m3u8,
make_join_playlist,
parse_playlists,
select_playlist,
)
from twitchdl.twitch import Chapter, Clip, ClipAccessToken, Video
@ -40,48 +45,7 @@ def download_one(video: str, args: DownloadOptions):
raise ConsoleError(f"Invalid input: {video}")
def _parse_playlists(playlists_m3u8):
    """Yield a ``(name, description, uri)`` tuple per variant playlist.

    Variants whose resolution is None are yielded after all the others.
    For a variant with a resolution the name comes from its first media
    entry's ``name`` and the description is the resolution joined with "x";
    otherwise the name is the first media entry's ``group_id`` and the
    description is None.
    """
    document = m3u8.loads(playlists_m3u8)

    def _lacks_resolution(variant):
        return variant.stream_info.resolution is None

    # sorted() is stable, so variants keep their relative order within
    # each of the two groups.
    for variant in sorted(document.playlists, key=_lacks_resolution):
        size = variant.stream_info.resolution
        media = variant.media[0]
        if size:
            yield media.name, "x".join(map(str, size)), variant.uri
        else:
            yield media.group_id, None, variant.uri
def _get_playlist_by_name(playlists, quality):
if quality == "source":
_, _, uri = playlists[0]
return uri
for name, _, uri in playlists:
if name == quality:
return uri
available = ", ".join([name for (name, _, _) in playlists])
msg = f"Quality '{quality}' not found. Available qualities are: {available}"
raise ConsoleError(msg)
def _select_playlist_interactive(playlists):
    """List the available qualities, prompt for one and return its URI.

    ``playlists`` is a sequence of ``(name, resolution, uri)`` tuples; they
    are printed numbered from 1 and the number read from the user selects
    the entry.

    Raises ConsoleError when the number read does not refer to a listed
    entry.
    """
    click.echo("\nAvailable qualities:")
    for number, (name, resolution, _uri) in enumerate(playlists, start=1):
        if resolution:
            click.echo(f"{number}) {bold(name)} {dim(f'({resolution})')}")
        else:
            click.echo(f"{number}) {bold(name)}")

    no = utils.read_int("Choose quality", min=1, max=len(playlists) + 1, default=1)

    # The prompt's upper bound is one more than the number of entries, so
    # guard the lookup instead of letting an IndexError escape.
    # NOTE(review): if read_int treats `max` as inclusive, the bound passed
    # above should probably be len(playlists) - confirm against utils.
    if not 1 <= no <= len(playlists):
        raise ConsoleError(f"Invalid quality choice: {no}")

    _, _, uri = playlists[no - 1]
    return uri
def _join_vods(playlist_path: str, target: str, overwrite: bool, video):
def _join_vods(playlist_path: str, target: str, overwrite: bool, video: Video):
description = video["description"] or ""
description = description.strip()
@ -183,26 +147,6 @@ def _clip_target_filename(clip: Clip, args: DownloadOptions):
raise ConsoleError(f"Invalid key {e} used in --output. Supported keys are: {supported}")
def _get_vod_paths(playlist, start: Optional[int], end: Optional[int]) -> List[str]:
"""Extract unique VOD paths for download from playlist."""
files = []
vod_start = 0
for segment in playlist.segments:
vod_end = vod_start + segment.duration
# `vod_end > start` is used here becuase it's better to download a bit
# more than a bit less, similar for the end condition
start_condition = not start or vod_end > start
end_condition = not end or vod_start < end
if start_condition and end_condition and segment.uri not in files:
files.append(segment.uri)
vod_start = vod_end
return files
def _crete_temp_dir(base_uri: str) -> str:
"""Create a temp dir to store downloads if it doesn't exist."""
path = urlparse(base_uri).path.lstrip("/")
@ -291,7 +235,7 @@ def _download_clip(slug: str, args: DownloadOptions) -> None:
click.echo(f"Downloaded: {blue(target)}")
def _download_video(video_id, args: DownloadOptions) -> None:
def _download_video(video_id: str, args: DownloadOptions) -> None:
if args.start and args.end and args.end <= args.start:
raise ConsoleError("End time must be greater than start time")
@ -301,9 +245,7 @@ def _download_video(video_id, args: DownloadOptions) -> None:
if not video:
raise ConsoleError(f"Video {video_id} not found")
title = video["title"]
user = video["creator"]["displayName"]
click.echo(f"Found: {blue(title)} by {yellow(user)}")
click.echo(f"Found: {blue(video['title'])} by {yellow(video['creator']['displayName'])}")
target = _video_target_filename(video, args)
click.echo(f"Output: {blue(target)}")
@ -321,50 +263,33 @@ def _download_video(video_id, args: DownloadOptions) -> None:
access_token = twitch.get_access_token(video_id, auth_token=args.auth_token)
print_log("Fetching playlists...")
playlists_m3u8 = twitch.get_playlists(video_id, access_token)
playlists = list(_parse_playlists(playlists_m3u8))
playlist_uri = (
_get_playlist_by_name(playlists, args.quality)
if args.quality
else _select_playlist_interactive(playlists)
)
playlists_text = twitch.get_playlists(video_id, access_token)
playlists = parse_playlists(playlists_text)
playlist = select_playlist(playlists, args.quality)
print_log("Fetching playlist...")
response = httpx.get(playlist_uri)
response.raise_for_status()
playlist = m3u8.loads(response.text)
vods_text = http_get(playlist.url)
vods_m3u8 = load_m3u8(vods_text)
vods = enumerate_vods(vods_m3u8, start, end)
base_uri = re.sub("/[^/]+$", "/", playlist_uri)
base_uri = re.sub("/[^/]+$", "/", playlist.url)
target_dir = _crete_temp_dir(base_uri)
vod_paths = _get_vod_paths(playlist, start, end)
# Save playlists for debugging purposes
with open(path.join(target_dir, "playlists.m3u8"), "w") as f:
f.write(playlists_m3u8)
f.write(playlists_text)
with open(path.join(target_dir, "playlist.m3u8"), "w") as f:
f.write(response.text)
f.write(vods_text)
click.echo(
f"\nDownloading {len(vod_paths)} VODs using {args.max_workers} workers to {target_dir}"
)
sources = [base_uri + path for path in vod_paths]
targets = [os.path.join(target_dir, f"{k:05d}.ts") for k, _ in enumerate(vod_paths)]
click.echo(f"\nDownloading {len(vods)} VODs using {args.max_workers} workers to {target_dir}")
sources = [base_uri + vod.path for vod in vods]
targets = [os.path.join(target_dir, f"{vod.index:05d}.ts") for vod in vods]
asyncio.run(download_all(sources, targets, args.max_workers, rate_limit=args.rate_limit))
# Make a modified playlist which references downloaded VODs
# Keep only the downloaded segments and skip the rest
org_segments = playlist.segments.copy()
path_map = OrderedDict(zip(vod_paths, targets))
playlist.segments.clear()
for segment in org_segments:
if segment.uri in path_map:
segment.uri = path_map[segment.uri]
playlist.segments.append(segment)
playlist_path = path.join(target_dir, "playlist_downloaded.m3u8")
playlist.dump(playlist_path)
join_playlist = make_join_playlist(vods_m3u8, vods, targets)
join_playlist_path = path.join(target_dir, "playlist_downloaded.m3u8")
join_playlist.dump(join_playlist_path) # type: ignore
click.echo()
if args.no_join:
@ -377,7 +302,7 @@ def _download_video(video_id, args: DownloadOptions) -> None:
_concat_vods(targets, target)
else:
print_log("Joining files...")
_join_vods(playlist_path, target, args.overwrite, video)
_join_vods(join_playlist_path, target, args.overwrite, video)
click.echo()
@ -390,6 +315,12 @@ def _download_video(video_id, args: DownloadOptions) -> None:
click.echo(f"\nDownloaded: {green(target)}")
def http_get(url: str) -> str:
    """Fetch ``url`` with a GET request and return the response body text.

    Any exception raised by the request itself or by
    ``raise_for_status()`` propagates to the caller.
    """
    reply = httpx.get(url)
    # Fail before handing back the body of an error response.
    reply.raise_for_status()
    body = reply.text
    return body
def _determine_time_range(video_id: str, args: DownloadOptions):
if args.start or args.end:
return args.start, args.end

View File

@ -15,7 +15,7 @@ class DownloadOptions:
overwrite: bool
output: str
quality: str | None
rate_limit: str | None
rate_limit: int | None
start: int | None
max_workers: int

127
twitchdl/playlists.py Normal file
View File

@ -0,0 +1,127 @@
"""
Parse and manipulate m3u8 playlists.
"""
from dataclasses import dataclass
from typing import Generator, OrderedDict
import click
import m3u8
from twitchdl import utils
from twitchdl.output import bold, dim
@dataclass
class Playlist:
name: str
resolution: str | None
url: str
@dataclass
class Vod:
    """A single playlist segment selected for download."""

    index: int
    """Ordinal number of the VOD in the playlist"""

    path: str
    """Path part of the VOD URL"""

    # Filled from the playlist segment's duration, which is fractional in
    # practice, so this is a float rather than an int.
    duration: float
    """Segment duration in seconds"""
def parse_playlists(playlists_m3u8: str):
    """Parse playlists text into a list of Playlist entries.

    An entry with a resolution is named after its first media's ``name``
    and carries the resolution joined with "x"; an entry without one is
    named after its first media's ``group_id`` and has resolution None.
    Entries without a resolution are placed after all the others.
    """
    parsed = []
    for variant in load_m3u8(playlists_m3u8).playlists:
        size = variant.stream_info.resolution
        media = variant.media[0]
        if size:
            parsed.append(Playlist(media.name, "x".join(map(str, size)), variant.uri))
        else:
            parsed.append(Playlist(media.group_id, None, variant.uri))

    # Move audio to bottom, it has no resolution. The sort is stable, so
    # the remaining order is unchanged.
    parsed.sort(key=lambda entry: entry.resolution is None)
    return parsed
def load_m3u8(playlist_m3u8: str) -> m3u8.M3U8:
    """Parse m3u8 text with ``m3u8.loads`` and return the resulting document."""
    document = m3u8.loads(playlist_m3u8)
    return document
def enumerate_vods(document: m3u8.M3U8, start: int | None, end: int | None) -> list[Vod]:
    """Extract VODs for download from document.

    Segment positions are derived by accumulating segment durations in
    playlist order. A segment is kept when it overlaps the ``start``..``end``
    range; a falsy ``start`` or ``end`` leaves that side unbounded. Each
    kept segment becomes a Vod carrying its position in the playlist, its
    URI and its duration.
    """
    selected = []
    elapsed = 0

    for position, segment in enumerate(document.segments):
        segment_start, elapsed = elapsed, elapsed + segment.duration

        # Comparing the segment's end against `start` (and its beginning
        # against `end`) keeps partially overlapping segments, because it's
        # better to download a bit more than a bit less.
        reaches_start = not start or elapsed > start
        begins_before_end = not end or segment_start < end

        if reaches_start and begins_before_end:
            selected.append(Vod(position, segment.uri, segment.duration))

    return selected
def make_join_playlist(
    playlist: m3u8.M3U8,
    vods: list[Vod],
    targets: list[str],
) -> m3u8.M3U8:
    """
    Make a modified playlist which references downloaded VODs
    Keep only the downloaded segments and skip the rest

    ``vods`` and ``targets`` are paired by position: a segment whose URI
    equals a VOD's ``path`` has its URI replaced with the matching target.

    The given ``playlist`` is modified in place and the same object is
    returned.
    """
    # A plain dict keeps insertion order, no need for OrderedDict
    path_map = dict(zip([vod.path for vod in vods], targets))

    # Iterate over a copy because the live segment list is rebuilt below
    org_segments = playlist.segments.copy()
    playlist.segments.clear()
    for segment in org_segments:
        if segment.uri in path_map:
            segment.uri = path_map[segment.uri]
            playlist.segments.append(segment)

    return playlist
def select_playlist(playlists: list[Playlist], quality: str | None) -> Playlist:
    """Pick a playlist by ``quality`` name, or prompt when ``quality`` is None."""
    if quality is None:
        return select_playlist_interactive(playlists)
    return select_playlist_by_name(playlists, quality)
def select_playlist_by_name(playlists: list[Playlist], quality: str) -> Playlist:
    """Return the playlist whose name equals ``quality``.

    The special quality "source" selects the first playlist in the list.

    Raises click.ClickException listing the available names when no
    playlist matches.
    """
    if quality == "source":
        return playlists[0]

    matching = [candidate for candidate in playlists if candidate.name == quality]
    if matching:
        # First match wins, same as scanning in order.
        return matching[0]

    available = ", ".join(candidate.name for candidate in playlists)
    raise click.ClickException(
        f"Quality '{quality}' not found. Available qualities are: {available}"
    )
def select_playlist_interactive(playlists: list[Playlist]) -> Playlist:
    """List the available qualities, prompt for one and return that playlist.

    Playlists are printed numbered from 1 (with the resolution when there
    is one) and the number read from the user selects the entry.

    Raises click.ClickException when the number read does not refer to a
    listed playlist.
    """
    click.echo("\nAvailable qualities:")
    for number, playlist in enumerate(playlists, start=1):
        if playlist.resolution:
            click.echo(f"{number}) {bold(playlist.name)} {dim(f'({playlist.resolution})')}")
        else:
            click.echo(f"{number}) {bold(playlist.name)}")

    no = utils.read_int("Choose quality", min=1, max=len(playlists) + 1, default=1)

    # The prompt's upper bound is one more than the number of entries, so
    # guard the lookup instead of letting an IndexError escape.
    # NOTE(review): if read_int treats `max` as inclusive, the bound passed
    # above should probably be len(playlists) - confirm against utils.
    if not 1 <= no <= len(playlists):
        raise click.ClickException(f"Invalid quality choice: {no}")

    return playlists[no - 1]

View File

@ -403,7 +403,7 @@ def get_access_token(video_id: str, auth_token: str | None = None) -> AccessToke
raise
def get_playlists(video_id: str, access_token: AccessToken):
def get_playlists(video_id: str, access_token: AccessToken) -> str:
"""
For a given video return a playlist which contains possible video qualities.
"""