diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 0ba69a53..e4d06eb0 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -24,7 +24,8 @@ After you have created the authentication JSON, you can instantiate the class: .. code-block:: python from ytmusicapi import YTMusic - ytmusic = YTMusic("oauth.json") + ytmusic = YTMusic("browser.json") # or, alternatively + ytmusic = YTMusic("oauth.json", oauth_credentials=OAuthCredentials(client_id=client_id, client_secret=client_secret) With the :code:`ytmusic` instance you can now perform authenticated requests: diff --git a/tests/auth/test_oauth.py b/tests/auth/test_oauth.py index aae772f7..bea97c64 100644 --- a/tests/auth/test_oauth.py +++ b/tests/auth/test_oauth.py @@ -10,6 +10,7 @@ from ytmusicapi.auth.oauth import OAuthToken from ytmusicapi.auth.types import AuthType +from ytmusicapi.exceptions import YTMusicUserError from ytmusicapi.setup import main from ytmusicapi.ytmusic import OAuthCredentials, YTMusic @@ -114,9 +115,19 @@ def test_oauth_custom_client( assert yt_alt_oauth.auth_type != AuthType.OAUTH_CUSTOM_CLIENT with open(oauth_filepath) as f: token_dict = json.load(f) + # oauth token dict entry and alt yt_alt_oauth = YTMusic(token_dict, oauth_credentials=alt_oauth_credentials) assert yt_alt_oauth.auth_type == AuthType.OAUTH_CUSTOM_CLIENT + # forgot to pass OAuth credentials - should raise + with pytest.raises(YTMusicUserError): + YTMusic(token_dict) + + # oauth custom full + token_dict["authorization"] = "Bearer DKLEK23" + yt_alt_oauth = YTMusic(token_dict, oauth_credentials=alt_oauth_credentials) + assert yt_alt_oauth.auth_type == AuthType.OAUTH_CUSTOM_FULL + def test_alt_oauth_request(self, yt_alt_oauth: YTMusic, sample_video): yt_alt_oauth.get_watch_playlist(sample_video) diff --git a/tests/test_ytmusic.py b/tests/test_ytmusic.py index 1e5b0712..51ce16b6 100644 --- a/tests/test_ytmusic.py +++ b/tests/test_ytmusic.py @@ -1,4 +1,7 @@ +from functools import partial + import pytest +import requests from ytmusicapi import YTMusic from ytmusicapi.exceptions import YTMusicUserError @@ -12,3 +15,14 @@ def test_ytmusic_context(): def test_ytmusic_auth_error(): with pytest.raises(YTMusicUserError, match="Invalid auth"): YTMusic(auth="def") + + +def test_ytmusic_session(): + test_session = requests.Session() + test_session.request = partial(test_session.request, timeout=60) + ytmusic = YTMusic(requests_session=test_session) + assert ytmusic._session == test_session + + ytmusic = YTMusic() + assert isinstance(ytmusic._session, requests.Session) + assert ytmusic._session != test_session diff --git a/ytmusicapi/auth/auth_parse.py b/ytmusicapi/auth/auth_parse.py new file mode 100644 index 00000000..4d9abc1c --- /dev/null +++ b/ytmusicapi/auth/auth_parse.py @@ -0,0 +1,51 @@ +import json +from pathlib import Path +from typing import Optional, Union + +from requests.structures import CaseInsensitiveDict + +from ytmusicapi.auth.oauth import OAuthToken +from ytmusicapi.auth.types import AuthType +from ytmusicapi.exceptions import YTMusicUserError + + +def parse_auth_str(auth: Union[str, dict]) -> tuple[CaseInsensitiveDict, Optional[Path]]: + """ + + :param auth: user-provided auth string or dict + :return: parsed header dict based on auth, optionally path to file if it auth was a path to a file + """ + auth_path: Optional[Path] = None + if isinstance(auth, str): + auth_str: str = auth + if auth.startswith("{"): + input_json = json.loads(auth_str) + elif (auth_path := Path(auth_str)).is_file(): + with open(auth_path) as json_file: + input_json = json.load(json_file) + else: + raise YTMusicUserError("Invalid auth JSON string or file path provided.") + return CaseInsensitiveDict(input_json), auth_path + + else: + return CaseInsensitiveDict(auth), auth_path + + +def determine_auth_type(auth_headers: CaseInsensitiveDict) -> AuthType: + """ + Determine the type of auth based on auth headers. + + :param auth_headers: auth headers dict + :return: AuthType enum + """ + auth_type = AuthType.OAUTH_CUSTOM_CLIENT + if OAuthToken.is_oauth(auth_headers): + auth_type = AuthType.OAUTH_CUSTOM_CLIENT + + if authorization := auth_headers.get("authorization"): + if "SAPISIDHASH" in authorization: + auth_type = AuthType.BROWSER + elif authorization.startswith("Bearer"): + auth_type = AuthType.OAUTH_CUSTOM_FULL + + return auth_type diff --git a/ytmusicapi/auth/types.py b/ytmusicapi/auth/types.py index fcddd752..0e5cc529 100644 --- a/ytmusicapi/auth/types.py +++ b/ytmusicapi/auth/types.py @@ -10,15 +10,8 @@ class AuthType(int, Enum): BROWSER = auto() - #: client auth via OAuth token refreshing - OAUTH_DEFAULT = auto() - #: YTM instance is using a non-default OAuth client (id & secret) OAUTH_CUSTOM_CLIENT = auto() #: allows fully formed OAuth headers to ignore browser auth refresh flow OAUTH_CUSTOM_FULL = auto() - - @classmethod - def oauth_types(cls) -> list["AuthType"]: - return [cls.OAUTH_DEFAULT, cls.OAUTH_CUSTOM_CLIENT, cls.OAUTH_CUSTOM_FULL] diff --git a/ytmusicapi/ytmusic.py b/ytmusicapi/ytmusic.py index ce2fbd63..b96535b7 100644 --- a/ytmusicapi/ytmusic.py +++ b/ytmusicapi/ytmusic.py @@ -35,7 +35,8 @@ from ytmusicapi.mixins.watch import WatchMixin from ytmusicapi.parsers.i18n import Parser -from .auth.oauth import OAuthCredentials, OAuthToken, RefreshingToken +from .auth.auth_parse import determine_auth_type, parse_auth_str +from .auth.oauth import OAuthCredentials, RefreshingToken from .auth.oauth.token import Token from .auth.types import AuthType from .exceptions import YTMusicServerError, YTMusicUserError @@ -46,7 +47,7 @@ def __init__( self, auth: Optional[Union[str, dict]] = None, user: Optional[str] = None, - requests_session=True, + requests_session: Optional[requests.Session] = None, proxies: Optional[dict[str, str]] = None, language: str = "en", location: str = "", @@ -64,7 +65,7 @@ def __init__( Otherwise the default account is used. You can retrieve the user ID by going to https://myaccount.google.com/brandaccounts and selecting your brand account. The user ID will be in the URL: https://myaccount.google.com/b/user_id/ - :param requests_session: A Requests session object or a truthy value to create one. + :param requests_session: A Requests session object or None to create one. Default sessions have a request timeout of 30s, which produces a requests.exceptions.ReadTimeout. The timeout can be changed by passing your own Session object:: @@ -72,9 +73,6 @@ def __init__( s.request = functools.partial(s.request, timeout=3) ytm = YTMusic(requests_session=s) - A falsy value disables sessions. - It is generally a good idea to keep sessions enabled for - performance reasons (connection pooling). :param proxies: Optional. Proxy configuration in requests_ format_. .. _requests: https://requests.readthedocs.io/ @@ -89,59 +87,30 @@ def __init__( :param oauth_credentials: Optional. Used to specify a different oauth client to be used for authentication flow. """ - - self._base_headers: Optional[CaseInsensitiveDict] = ( - None #: for authless initializing requests during OAuth flow - ) - self._headers: Optional[CaseInsensitiveDict] = None #: cache formed headers including auth - - self.auth = auth #: raw auth - self._input_dict: CaseInsensitiveDict = ( - CaseInsensitiveDict() - ) #: parsed auth arg value in dictionary format - - self.auth_type: AuthType = AuthType.UNAUTHORIZED - - self._token: Token #: OAuth credential handler - self.oauth_credentials: Optional[OAuthCredentials] #: Client used for OAuth refreshing - - self._session: requests.Session #: request session for connection pooling + #: request session for connection pooling + self._session = self._prepare_session(requests_session) self.proxies: Optional[dict[str, str]] = proxies #: params for session modification - - if isinstance(requests_session, requests.Session): - self._session = requests_session - else: - if requests_session: # Build a new session. - self._session = requests.Session() - self._session.request = partial(self._session.request, timeout=30) # type: ignore[method-assign] - else: # Use the Requests API module as a "session". - self._session = requests.api # type: ignore[assignment] - # see google cookie docs: https://policies.google.com/technologies/cookies # value from https://github.com/yt-dlp/yt-dlp/blob/master/yt_dlp/extractor/youtube.py#L502 self.cookies = {"SOCS": "CAI"} - if self.auth is not None: - self.oauth_credentials = oauth_credentials - auth_path: Optional[Path] = None - if isinstance(self.auth, str): - auth_str: str = self.auth - if self.auth.startswith("{"): - input_json = json.loads(auth_str) - elif (auth_path := Path(auth_str)).is_file(): - with open(auth_path) as json_file: - input_json = json.load(json_file) - else: - raise YTMusicUserError("Invalid auth JSON string or file path provided.") - self._input_dict = CaseInsensitiveDict(input_json) - - else: - self._input_dict = CaseInsensitiveDict(self.auth) - - if self.oauth_credentials is not None and OAuthToken.is_oauth(self._input_dict): + + self._auth_headers: CaseInsensitiveDict = CaseInsensitiveDict() + self.auth_type = AuthType.UNAUTHORIZED + if auth is not None: + self._auth_headers, auth_path = parse_auth_str(auth) + self.auth_type = determine_auth_type(self._auth_headers) + + self._token: Token + if self.auth_type == AuthType.OAUTH_CUSTOM_CLIENT: + if oauth_credentials is None: + raise YTMusicUserError( + "oauth JSON provided via auth argument, but oauth_credentials not provided." + "Please provide oauth_credentials as specified in the OAuth setup documentation." + ) + #: OAuth credential handler self._token = RefreshingToken( - credentials=self.oauth_credentials, _local_cache=auth_path, **self._input_dict + credentials=oauth_credentials, _local_cache=auth_path, **self._auth_headers ) - self.auth_type = AuthType.OAUTH_CUSTOM_CLIENT if oauth_credentials else AuthType.OAUTH_DEFAULT # prepare context self.context = initialize_context() @@ -170,13 +139,6 @@ def __init__( if user: self.context["context"]["user"]["onBehalfOfUser"] = user - auth_headers = self._input_dict.get("authorization") - if auth_headers: - if "SAPISIDHASH" in auth_headers: - self.auth_type = AuthType.BROWSER - elif auth_headers.startswith("Bearer"): - self.auth_type = AuthType.OAUTH_CUSTOM_FULL - # sapsid, origin, and params all set once during init self.params = YTM_PARAMS if self.auth_type == AuthType.BROWSER: @@ -190,40 +152,38 @@ def __init__( @property def base_headers(self) -> CaseInsensitiveDict: - if not self._base_headers: - if self.auth_type == AuthType.BROWSER or self.auth_type == AuthType.OAUTH_CUSTOM_FULL: - self._base_headers = self._input_dict - else: - self._base_headers = CaseInsensitiveDict( - { - "user-agent": USER_AGENT, - "accept": "*/*", - "accept-encoding": "gzip, deflate", - "content-type": "application/json", - "content-encoding": "gzip", - "origin": YTM_DOMAIN, - } - ) - - return self._base_headers + if self.auth_type == AuthType.BROWSER or self.auth_type == AuthType.OAUTH_CUSTOM_FULL: + return self._auth_headers + + return CaseInsensitiveDict( + { + "user-agent": USER_AGENT, + "accept": "*/*", + "accept-encoding": "gzip, deflate", + "content-type": "application/json", + "content-encoding": "gzip", + "origin": YTM_DOMAIN, + } + ) @property def headers(self) -> CaseInsensitiveDict: - # set on first use - if not self._headers: - self._headers = self.base_headers + headers = self.base_headers + + if "X-Goog-Visitor-Id" not in headers: + headers.update(get_visitor_id(partial(self._send_get_request, use_base_headers=True))) # keys updated each use, custom oauth implementations left untouched if self.auth_type == AuthType.BROWSER: - self._headers["authorization"] = get_authorization(self.sapisid + " " + self.origin) + headers["authorization"] = get_authorization(self.sapisid + " " + self.origin) # Do not set custom headers when using OAUTH_CUSTOM_FULL # Full headers are provided by the downstream client in this scenario. - elif self.auth_type in [x for x in AuthType.oauth_types() if x != AuthType.OAUTH_CUSTOM_FULL]: - self._headers["authorization"] = self._token.as_auth() - self._headers["X-Goog-Request-Time"] = str(int(time.time())) + elif self.auth_type == AuthType.OAUTH_CUSTOM_CLIENT: + headers["authorization"] = self._token.as_auth() + headers["X-Goog-Request-Time"] = str(int(time.time())) - return self._headers + return headers @contextmanager def as_mobile(self) -> Iterator[None]: @@ -259,13 +219,17 @@ def as_mobile(self) -> Iterator[None]: # safely restore the old context self.context["context"]["client"] = copied_context_client + def _prepare_session(self, requests_session: Optional[requests.Session]) -> requests.Session: + """Prepare requests session or use user-provided requests_session""" + if isinstance(requests_session, requests.Session): + return requests_session + self._session = requests.Session() + self._session.request = partial(self._session.request, timeout=30) # type: ignore[method-assign] + return self._session + def _send_request(self, endpoint: str, body: dict, additionalParams: str = "") -> dict: body.update(self.context) - # only required for post requests (?) - if self._headers and "X-Goog-Visitor-Id" not in self._headers: - self._headers.update(get_visitor_id(self._send_get_request)) - response = self._session.post( YTM_BASE_API + endpoint + self.params + additionalParams, json=body, @@ -280,19 +244,21 @@ def _send_request(self, endpoint: str, body: dict, additionalParams: str = "") - raise YTMusicServerError(message + error) return response_text - def _send_get_request(self, url: str, params: Optional[dict] = None) -> Response: + def _send_get_request( + self, url: str, params: Optional[dict] = None, use_base_headers: bool = False + ) -> Response: response = self._session.get( url, params=params, # handle first-use x-goog-visitor-id fetching - headers=self.headers if self._headers else self.base_headers, + headers=self.base_headers if use_base_headers else self.headers, proxies=self.proxies, cookies=self.cookies, ) return response def _check_auth(self): - if not self.auth: + if self.auth_type == AuthType.UNAUTHORIZED: raise YTMusicUserError("Please provide authentication before using this function") def __enter__(self):