Skip to content

Commit

Permalink
refactor YTMusicBase, improve oauth doc (#727)
Browse files Browse the repository at this point in the history
* refactor YTMusicBase, improve oauth doc

* fix recursion issue

* fix bugs, fix coverage, add session test
  • Loading branch information
sigma67 authored Jan 19, 2025
1 parent af04fa4 commit a004bed
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 99 deletions.
3 changes: 2 additions & 1 deletion docs/source/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ After you have created the authentication JSON, you can instantiate the class:
.. code-block:: python
from ytmusicapi import YTMusic
ytmusic = YTMusic("oauth.json")
ytmusic = YTMusic("browser.json") # or, alternatively
ytmusic = YTMusic("oauth.json", oauth_credentials=OAuthCredentials(client_id=client_id, client_secret=client_secret)
With the :code:`ytmusic` instance you can now perform authenticated requests:
Expand Down
11 changes: 11 additions & 0 deletions tests/auth/test_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from ytmusicapi.auth.oauth import OAuthToken
from ytmusicapi.auth.types import AuthType
from ytmusicapi.exceptions import YTMusicUserError
from ytmusicapi.setup import main
from ytmusicapi.ytmusic import OAuthCredentials, YTMusic

Expand Down Expand Up @@ -114,9 +115,19 @@ def test_oauth_custom_client(
assert yt_alt_oauth.auth_type != AuthType.OAUTH_CUSTOM_CLIENT
with open(oauth_filepath) as f:
token_dict = json.load(f)

# oauth token dict entry and alt
yt_alt_oauth = YTMusic(token_dict, oauth_credentials=alt_oauth_credentials)
assert yt_alt_oauth.auth_type == AuthType.OAUTH_CUSTOM_CLIENT

# forgot to pass OAuth credentials - should raise
with pytest.raises(YTMusicUserError):
YTMusic(token_dict)

# oauth custom full
token_dict["authorization"] = "Bearer DKLEK23"
yt_alt_oauth = YTMusic(token_dict, oauth_credentials=alt_oauth_credentials)
assert yt_alt_oauth.auth_type == AuthType.OAUTH_CUSTOM_FULL

def test_alt_oauth_request(self, yt_alt_oauth: YTMusic, sample_video):
yt_alt_oauth.get_watch_playlist(sample_video)
14 changes: 14 additions & 0 deletions tests/test_ytmusic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from functools import partial

import pytest
import requests

from ytmusicapi import YTMusic
from ytmusicapi.exceptions import YTMusicUserError
Expand All @@ -12,3 +15,14 @@ def test_ytmusic_context():
def test_ytmusic_auth_error():
with pytest.raises(YTMusicUserError, match="Invalid auth"):
YTMusic(auth="def")


def test_ytmusic_session():
test_session = requests.Session()
test_session.request = partial(test_session.request, timeout=60)
ytmusic = YTMusic(requests_session=test_session)
assert ytmusic._session == test_session

ytmusic = YTMusic()
assert isinstance(ytmusic._session, requests.Session)
assert ytmusic._session != test_session
51 changes: 51 additions & 0 deletions ytmusicapi/auth/auth_parse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import json
from pathlib import Path
from typing import Optional, Union

from requests.structures import CaseInsensitiveDict

from ytmusicapi.auth.oauth import OAuthToken
from ytmusicapi.auth.types import AuthType
from ytmusicapi.exceptions import YTMusicUserError


def parse_auth_str(auth: Union[str, dict]) -> tuple[CaseInsensitiveDict, Optional[Path]]:
"""
:param auth: user-provided auth string or dict
:return: parsed header dict based on auth, optionally path to file if it auth was a path to a file
"""
auth_path: Optional[Path] = None
if isinstance(auth, str):
auth_str: str = auth
if auth.startswith("{"):
input_json = json.loads(auth_str)
elif (auth_path := Path(auth_str)).is_file():
with open(auth_path) as json_file:
input_json = json.load(json_file)
else:
raise YTMusicUserError("Invalid auth JSON string or file path provided.")
return CaseInsensitiveDict(input_json), auth_path

else:
return CaseInsensitiveDict(auth), auth_path


def determine_auth_type(auth_headers: CaseInsensitiveDict) -> AuthType:
"""
Determine the type of auth based on auth headers.
:param auth_headers: auth headers dict
:return: AuthType enum
"""
auth_type = AuthType.OAUTH_CUSTOM_CLIENT
if OAuthToken.is_oauth(auth_headers):
auth_type = AuthType.OAUTH_CUSTOM_CLIENT

if authorization := auth_headers.get("authorization"):
if "SAPISIDHASH" in authorization:
auth_type = AuthType.BROWSER
elif authorization.startswith("Bearer"):
auth_type = AuthType.OAUTH_CUSTOM_FULL

return auth_type
7 changes: 0 additions & 7 deletions ytmusicapi/auth/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,8 @@ class AuthType(int, Enum):

BROWSER = auto()

#: client auth via OAuth token refreshing
OAUTH_DEFAULT = auto()

#: YTM instance is using a non-default OAuth client (id & secret)
OAUTH_CUSTOM_CLIENT = auto()

#: allows fully formed OAuth headers to ignore browser auth refresh flow
OAUTH_CUSTOM_FULL = auto()

@classmethod
def oauth_types(cls) -> list["AuthType"]:
return [cls.OAUTH_DEFAULT, cls.OAUTH_CUSTOM_CLIENT, cls.OAUTH_CUSTOM_FULL]
148 changes: 57 additions & 91 deletions ytmusicapi/ytmusic.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
from ytmusicapi.mixins.watch import WatchMixin
from ytmusicapi.parsers.i18n import Parser

from .auth.oauth import OAuthCredentials, OAuthToken, RefreshingToken
from .auth.auth_parse import determine_auth_type, parse_auth_str
from .auth.oauth import OAuthCredentials, RefreshingToken
from .auth.oauth.token import Token
from .auth.types import AuthType
from .exceptions import YTMusicServerError, YTMusicUserError
Expand All @@ -46,7 +47,7 @@ def __init__(
self,
auth: Optional[Union[str, dict]] = None,
user: Optional[str] = None,
requests_session=True,
requests_session: Optional[requests.Session] = None,
proxies: Optional[dict[str, str]] = None,
language: str = "en",
location: str = "",
Expand All @@ -64,17 +65,14 @@ def __init__(
Otherwise the default account is used. You can retrieve the user ID
by going to https://myaccount.google.com/brandaccounts and selecting your brand account.
The user ID will be in the URL: https://myaccount.google.com/b/user_id/
:param requests_session: A Requests session object or a truthy value to create one.
:param requests_session: A Requests session object or None to create one.
Default sessions have a request timeout of 30s, which produces a requests.exceptions.ReadTimeout.
The timeout can be changed by passing your own Session object::
s = requests.Session()
s.request = functools.partial(s.request, timeout=3)
ytm = YTMusic(requests_session=s)
A falsy value disables sessions.
It is generally a good idea to keep sessions enabled for
performance reasons (connection pooling).
:param proxies: Optional. Proxy configuration in requests_ format_.
.. _requests: https://requests.readthedocs.io/
Expand All @@ -89,59 +87,30 @@ def __init__(
:param oauth_credentials: Optional. Used to specify a different oauth client to be
used for authentication flow.
"""

self._base_headers: Optional[CaseInsensitiveDict] = (
None #: for authless initializing requests during OAuth flow
)
self._headers: Optional[CaseInsensitiveDict] = None #: cache formed headers including auth

self.auth = auth #: raw auth
self._input_dict: CaseInsensitiveDict = (
CaseInsensitiveDict()
) #: parsed auth arg value in dictionary format

self.auth_type: AuthType = AuthType.UNAUTHORIZED

self._token: Token #: OAuth credential handler
self.oauth_credentials: Optional[OAuthCredentials] #: Client used for OAuth refreshing

self._session: requests.Session #: request session for connection pooling
#: request session for connection pooling
self._session = self._prepare_session(requests_session)
self.proxies: Optional[dict[str, str]] = proxies #: params for session modification

if isinstance(requests_session, requests.Session):
self._session = requests_session
else:
if requests_session: # Build a new session.
self._session = requests.Session()
self._session.request = partial(self._session.request, timeout=30) # type: ignore[method-assign]
else: # Use the Requests API module as a "session".
self._session = requests.api # type: ignore[assignment]

# see google cookie docs: https://policies.google.com/technologies/cookies
# value from https://github.com/yt-dlp/yt-dlp/blob/master/yt_dlp/extractor/youtube.py#L502
self.cookies = {"SOCS": "CAI"}
if self.auth is not None:
self.oauth_credentials = oauth_credentials
auth_path: Optional[Path] = None
if isinstance(self.auth, str):
auth_str: str = self.auth
if self.auth.startswith("{"):
input_json = json.loads(auth_str)
elif (auth_path := Path(auth_str)).is_file():
with open(auth_path) as json_file:
input_json = json.load(json_file)
else:
raise YTMusicUserError("Invalid auth JSON string or file path provided.")
self._input_dict = CaseInsensitiveDict(input_json)

else:
self._input_dict = CaseInsensitiveDict(self.auth)

if self.oauth_credentials is not None and OAuthToken.is_oauth(self._input_dict):

self._auth_headers: CaseInsensitiveDict = CaseInsensitiveDict()
self.auth_type = AuthType.UNAUTHORIZED
if auth is not None:
self._auth_headers, auth_path = parse_auth_str(auth)
self.auth_type = determine_auth_type(self._auth_headers)

self._token: Token
if self.auth_type == AuthType.OAUTH_CUSTOM_CLIENT:
if oauth_credentials is None:
raise YTMusicUserError(
"oauth JSON provided via auth argument, but oauth_credentials not provided."
"Please provide oauth_credentials as specified in the OAuth setup documentation."
)
#: OAuth credential handler
self._token = RefreshingToken(
credentials=self.oauth_credentials, _local_cache=auth_path, **self._input_dict
credentials=oauth_credentials, _local_cache=auth_path, **self._auth_headers
)
self.auth_type = AuthType.OAUTH_CUSTOM_CLIENT if oauth_credentials else AuthType.OAUTH_DEFAULT

# prepare context
self.context = initialize_context()
Expand Down Expand Up @@ -170,13 +139,6 @@ def __init__(
if user:
self.context["context"]["user"]["onBehalfOfUser"] = user

auth_headers = self._input_dict.get("authorization")
if auth_headers:
if "SAPISIDHASH" in auth_headers:
self.auth_type = AuthType.BROWSER
elif auth_headers.startswith("Bearer"):
self.auth_type = AuthType.OAUTH_CUSTOM_FULL

# sapsid, origin, and params all set once during init
self.params = YTM_PARAMS
if self.auth_type == AuthType.BROWSER:
Expand All @@ -190,40 +152,38 @@ def __init__(

@property
def base_headers(self) -> CaseInsensitiveDict:
if not self._base_headers:
if self.auth_type == AuthType.BROWSER or self.auth_type == AuthType.OAUTH_CUSTOM_FULL:
self._base_headers = self._input_dict
else:
self._base_headers = CaseInsensitiveDict(
{
"user-agent": USER_AGENT,
"accept": "*/*",
"accept-encoding": "gzip, deflate",
"content-type": "application/json",
"content-encoding": "gzip",
"origin": YTM_DOMAIN,
}
)

return self._base_headers
if self.auth_type == AuthType.BROWSER or self.auth_type == AuthType.OAUTH_CUSTOM_FULL:
return self._auth_headers

return CaseInsensitiveDict(
{
"user-agent": USER_AGENT,
"accept": "*/*",
"accept-encoding": "gzip, deflate",
"content-type": "application/json",
"content-encoding": "gzip",
"origin": YTM_DOMAIN,
}
)

@property
def headers(self) -> CaseInsensitiveDict:
# set on first use
if not self._headers:
self._headers = self.base_headers
headers = self.base_headers

if "X-Goog-Visitor-Id" not in headers:
headers.update(get_visitor_id(partial(self._send_get_request, use_base_headers=True)))

# keys updated each use, custom oauth implementations left untouched
if self.auth_type == AuthType.BROWSER:
self._headers["authorization"] = get_authorization(self.sapisid + " " + self.origin)
headers["authorization"] = get_authorization(self.sapisid + " " + self.origin)

# Do not set custom headers when using OAUTH_CUSTOM_FULL
# Full headers are provided by the downstream client in this scenario.
elif self.auth_type in [x for x in AuthType.oauth_types() if x != AuthType.OAUTH_CUSTOM_FULL]:
self._headers["authorization"] = self._token.as_auth()
self._headers["X-Goog-Request-Time"] = str(int(time.time()))
elif self.auth_type == AuthType.OAUTH_CUSTOM_CLIENT:
headers["authorization"] = self._token.as_auth()
headers["X-Goog-Request-Time"] = str(int(time.time()))

return self._headers
return headers

@contextmanager
def as_mobile(self) -> Iterator[None]:
Expand Down Expand Up @@ -259,13 +219,17 @@ def as_mobile(self) -> Iterator[None]:
# safely restore the old context
self.context["context"]["client"] = copied_context_client

def _prepare_session(self, requests_session: Optional[requests.Session]) -> requests.Session:
"""Prepare requests session or use user-provided requests_session"""
if isinstance(requests_session, requests.Session):
return requests_session
self._session = requests.Session()
self._session.request = partial(self._session.request, timeout=30) # type: ignore[method-assign]
return self._session

def _send_request(self, endpoint: str, body: dict, additionalParams: str = "") -> dict:
body.update(self.context)

# only required for post requests (?)
if self._headers and "X-Goog-Visitor-Id" not in self._headers:
self._headers.update(get_visitor_id(self._send_get_request))

response = self._session.post(
YTM_BASE_API + endpoint + self.params + additionalParams,
json=body,
Expand All @@ -280,19 +244,21 @@ def _send_request(self, endpoint: str, body: dict, additionalParams: str = "") -
raise YTMusicServerError(message + error)
return response_text

def _send_get_request(self, url: str, params: Optional[dict] = None) -> Response:
def _send_get_request(
self, url: str, params: Optional[dict] = None, use_base_headers: bool = False
) -> Response:
response = self._session.get(
url,
params=params,
# handle first-use x-goog-visitor-id fetching
headers=self.headers if self._headers else self.base_headers,
headers=self.base_headers if use_base_headers else self.headers,
proxies=self.proxies,
cookies=self.cookies,
)
return response

def _check_auth(self):
if not self.auth:
if self.auth_type == AuthType.UNAUTHORIZED:
raise YTMusicUserError("Please provide authentication before using this function")

def __enter__(self):
Expand Down

0 comments on commit a004bed

Please sign in to comment.