# Adapted from: https://github.com/encode/httpx/blob/master/httpx/_models.py,
# which is licensed under the BSD License.
# See https://github.com/encode/httpx/blob/master/LICENSE.md

__all__ = ["Cookies"]

import re
import time
import warnings
from dataclasses import dataclass
from http.cookiejar import Cookie, CookieJar
from http.cookies import _unquote
from typing import Optional, Union
from collections.abc import Iterator, MutableMapping
from urllib.parse import urlparse

from ..utils import CurlCffiWarning
from .errors import CookieConflict, RequestsError

# Anything the Cookies constructor (and Cookies.update) accepts.
CookieTypes = Union["Cookies", CookieJar, dict[str, str], list[tuple[str, str]]]


@dataclass
class CurlMorsel:
    """A single cookie in the tab-separated line format exchanged with curl.

    The seven tab-separated fields are: hostname, include-subdomains flag,
    path, secure flag, expiry, name and value. Boolean columns are spelled
    ``TRUE`` / ``FALSE``. HttpOnly cookies carry a ``#HttpOnly_`` prefix on
    the hostname. This class converts between that representation and
    ``http.cookiejar.Cookie``.
    """

    name: str
    value: str
    hostname: str = ""
    subdomains: bool = False  # curl's "include subdomains" column
    path: str = "/"
    secure: bool = False
    # Expiry in seconds since the epoch; 0 is treated as a session cookie
    # (see to_cookiejar_cookie).
    expires: int = 0
    http_only: bool = False

    @staticmethod
    def parse_bool(s):
        """Return True only for curl's literal ``"TRUE"`` marker."""
        return s == "TRUE"

    @staticmethod
    def dump_bool(s):
        """Render a truthy/falsy value as curl's ``"TRUE"`` / ``"FALSE"``."""
        return "TRUE" if s else "FALSE"

    @classmethod
    def from_curl_format(cls, set_cookie_line: bytes):
        """Parse one tab-separated cookie line produced by curl.

        Args:
            set_cookie_line: The raw line as bytes. It is decoded with the
                default (UTF-8) codec.

        Returns:
            A ``CurlMorsel`` whose value has been unquoted.

        Raises:
            ValueError: if the line does not split into exactly seven fields,
                or if the expiry column is not an integer.
        """
        (
            hostname,
            subdomains,
            path,
            secure,
            expires,
            name,
            value,
        ) = set_cookie_line.decode().split("\t")
        if hostname and hostname[0] == "#":
            # e.g. #HttpOnly_postman-echo.com
            http_only = True
            domain = hostname[10:]  # len("#HttpOnly_") == 10
        else:
            http_only = False
            domain = hostname
        return cls(
            hostname=domain,
            subdomains=cls.parse_bool(subdomains),
            path=path,
            secure=cls.parse_bool(secure),
            expires=int(expires),
            name=name,
            value=_unquote(value),
            http_only=http_only,
        )

    def to_curl_format(self):
        """Serialize this cookie to curl's tab-separated line format.

        The ``#HttpOnly_`` prefix is not emitted.

        Raises:
            RequestsError: if the cookie has no hostname, because curl needs
                a domain for every cookie line.
        """
        if not self.hostname:
            raise RequestsError(f"Domain not found for cookie {self.name}={self.value}")
        return "\t".join(
            [
                self.hostname,
                self.dump_bool(self.subdomains),
                self.path,
                self.dump_bool(self.secure),
                str(self.expires),
                self.name,
                self.value,
            ]
        )

    @classmethod
    def from_cookiejar_cookie(cls, cookie: Cookie):
        """Build a morsel from a stdlib ``Cookie``.

        A ``None`` value becomes ``""`` and a ``None`` expiry becomes ``0``.
        ``http_only`` is always False here; the jar cookie's ``rest``
        attributes are not consulted.
        """
        return cls(
            name=cookie.name,
            value=cookie.value or "",
            hostname=cookie.domain,
            subdomains=cookie.domain_specified,
            path=cookie.path,
            secure=cookie.secure,
            expires=int(cookie.expires or 0),
            http_only=False,
        )

    def to_cookiejar_cookie(self) -> Cookie:
        """Convert this morsel into a stdlib ``Cookie`` for storage in a jar.

        ``expires == 0`` maps to ``expires=None`` / ``discard=True``, i.e. a
        session cookie.
        """
        # the leading dot actually does not mean anything nowadays
        # https://stackoverflow.com/a/20884869/1061155
        # https://github.com/python/cpython/blob/d6555abfa7384b5a40435a11bdd2aa6bbf8f5cfc/Lib/http/cookiejar.py#L1535
        return Cookie(
            version=0,
            name=self.name,
            value=self.value,
            port=None,
            port_specified=False,
            domain=self.hostname,
            domain_specified=self.subdomains,
            domain_initial_dot=bool(self.hostname.startswith(".")),
            path=self.path,
            path_specified=bool(self.path),
            secure=self.secure,
            # using if explicitly to make it clear.
            expires=None if self.expires == 0 else self.expires,
            discard=self.expires == 0,
            comment=None,
            comment_url=None,
            rest=dict(http_only=f"{self.http_only}"),
            rfc2109=False,
        )


# Strips a trailing ":<port>" from a host string.
cut_port_re = re.compile(r":\d+$", re.ASCII)
# Matches hosts ending in ".<digits>", used to avoid appending ".local" to IPs.
IPV4_RE = re.compile(r"\.\d+$", re.ASCII)


class Cookies(MutableMapping[str, str]):
    """
    HTTP Cookies, as a mutable mapping.

    Backed by an ``http.cookiejar.CookieJar`` (``self.jar``). Mapping access
    is by cookie name only. Use ``get``/``set``/``delete`` with ``domain`` /
    ``path`` to address a specific cookie when names collide across domains.
    """

    def __init__(self, cookies: Optional[CookieTypes] = None) -> None:
        """Create a cookie store.

        ``None``, a dict, a list of ``(name, value)`` pairs, or another
        ``Cookies`` instance is copied into a fresh ``CookieJar``. Any other
        value, e.g. a ``CookieJar``, is used directly as the backing jar and
        is not copied.
        """
        if cookies is None or isinstance(cookies, dict):
            self.jar = CookieJar()
            if isinstance(cookies, dict):
                for key, value in cookies.items():
                    self.set(key, value)
        elif isinstance(cookies, list):
            self.jar = CookieJar()
            for key, value in cookies:
                self.set(key, value)
        elif isinstance(cookies, Cookies):
            self.jar = CookieJar()
            for cookie in cookies.jar:
                self.jar.set_cookie(cookie)
        else:
            self.jar = cookies

    def _eff_request_host(self, request) -> str:
        """
        Almost equivalent to the eff_request_host function in:
        https://github.com/python/cpython/blob/3.11/Lib/http/cookiejar.py#L636

        ``request`` must expose ``.url`` and ``.headers``. The result is
        lower-cased with any port removed. A dotless, non-numeric host gets
        ``.local`` appended.
        """
        host = urlparse(request.url)[1]
        if host == "":
            host = request.headers.get("Host", "")
        # remove port, if present
        host = cut_port_re.sub("", host, 1)
        host = host.lower()
        if host.find(".") == -1 and not IPV4_RE.search(host):
            host += ".local"
        return host

    def get_cookies_for_curl(self, request) -> list[CurlMorsel]:
        """Return every cookie in the jar as a ``CurlMorsel``.

        The process is similar to ``CookieJar.add_cookie_header``, but all
        cookies are loaded without domain/path filtering. Cookies without a
        domain get the effective host of ``request``. Expired cookies are
        cleared from the jar afterwards.
        """
        morsels = []
        # Hold the jar's own lock while reading, as CookieJar does internally.
        with self.jar._cookies_lock:  # type: ignore
            self.jar._policy._now = self._now = int(time.time())  # type: ignore
            for cookie in self.jar:
                morsel = CurlMorsel.from_cookiejar_cookie(cookie)
                if not morsel.hostname:
                    morsel.hostname = self._eff_request_host(request)
                morsels.append(morsel)
        self.jar.clear_expired_cookies()
        return morsels

    def update_cookies_from_curl(self, morsels: list[CurlMorsel]):
        """Store cookies reported by curl into the jar, then drop expired ones."""
        for morsel in morsels:
            cookie = morsel.to_cookiejar_cookie()
            self.jar.set_cookie(cookie)
        self.jar.clear_expired_cookies()

    def set(
        self, name: str, value: str, domain: str = "", path: str = "/", secure=False
    ) -> None:
        """
        Set a cookie value by name. May optionally include domain and path.

        Cookie-name prefixes are enforced, with a ``CurlCffiWarning`` when
        arguments are overridden:

        * ``__Secure-``: ``secure`` is forced to True.
        * ``__Host-``: ``secure`` is forced to True, ``domain`` is cleared and
          ``path`` is forced to ``/``.
        """
        # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie
        if name.startswith("__Secure-") and secure is False:
            warnings.warn(
                "`secure` changed to True for `__Secure-` prefixed cookies",
                CurlCffiWarning,
                stacklevel=2,
            )
            secure = True
        elif name.startswith("__Host-") and (secure is False or domain or path != "/"):
            warnings.warn(
                "`secure` changed to True, `domain` removed, `path` changed to `/` "
                "for `__Host-` prefixed cookies",
                CurlCffiWarning,
                stacklevel=2,
            )
            secure = True
            domain = ""
            path = "/"
        kwargs = {
            "version": 0,
            "name": name,
            "value": value,
            "port": None,
            "port_specified": False,
            "domain": domain,
            "domain_specified": bool(domain),
            "domain_initial_dot": domain.startswith("."),
            "path": path,
            "path_specified": bool(path),
            "secure": secure,
            "expires": None,
            "discard": True,
            "comment": None,
            "comment_url": None,
            "rest": {"HttpOnly": None},
            "rfc2109": False,
        }
        cookie = Cookie(**kwargs)
        self.jar.set_cookie(cookie)

    def get(  # type: ignore
        self,
        name: str,
        default: Optional[str] = None,
        domain: Optional[str] = None,
        path: Optional[str] = None,
    ) -> Optional[str]:
        """
        Get a cookie by name. May optionally include domain and path
        in order to specify exactly which cookie to retrieve.

        Returns ``default`` when no matching cookie is found.

        Raises:
            CookieConflict: if two matching cookies live on domains where
                neither is a suffix of the other and their values differ.
        """
        value = None
        matched_domain = ""
        for cookie in self.jar:
            if (
                cookie.name == name
                and (domain is None or cookie.domain == domain)
                and (path is None or cookie.path == path)
            ):
                # if cookies on two different domains do not share a same value
                if (
                    value is not None
                    and not matched_domain.endswith(cookie.domain)
                    and not str(cookie.domain).endswith(matched_domain)
                    and value != cookie.value
                ):
                    message = (
                        f"Multiple cookies exist with name={name} on "
                        f"{matched_domain} and {cookie.domain}, add domain "
                        "parameter to suppress this error."
                    )
                    raise CookieConflict(message)
                value = cookie.value
                matched_domain = cookie.domain or ""

        if value is None:
            return default
        return value

    def get_dict(
        self, domain: Optional[str] = None, path: Optional[str] = None
    ) -> dict:
        """
        Cookies with the same name on different domains may overwrite each
        other, do NOT use this function as a method of serialization.

        Returns a ``{name: value}`` dict of the cookies matching the optional
        ``domain`` / ``path`` filters. The last cookie iterated wins on name
        collisions.
        """
        ret = {}
        for cookie in self.jar:
            if (domain is None or cookie.domain == domain) and (
                path is None or cookie.path == path
            ):
                ret[cookie.name] = cookie.value
        return ret

    def delete(
        self,
        name: str,
        domain: Optional[str] = None,
        path: Optional[str] = None,
    ) -> None:
        """
        Delete a cookie by name. May optionally include domain and path
        in order to specify exactly which cookie to delete.

        When both ``domain`` and ``path`` are given, the call goes straight
        to ``CookieJar.clear``, which raises ``KeyError`` if no such cookie
        exists. Otherwise all matching cookies are removed, and a missing
        cookie is silently ignored.
        """
        if domain is not None and path is not None:
            return self.jar.clear(domain, path, name)

        # Collect first: the jar must not be mutated while iterating it.
        remove = [
            cookie
            for cookie in self.jar
            if cookie.name == name
            and (domain is None or cookie.domain == domain)
            and (path is None or cookie.path == path)
        ]

        for cookie in remove:
            self.jar.clear(cookie.domain, cookie.path, cookie.name)

    def clear(self, domain: Optional[str] = None, path: Optional[str] = None) -> None:
        """
        Delete all cookies. Optionally include a domain and path in
        order to only delete a subset of all the cookies.

        Raises:
            ValueError: if ``path`` is given without ``domain``.
                ``CookieJar.clear`` takes its arguments positionally, so a
                lone path would otherwise be misread as a domain.
        """
        args = []
        if domain is not None:
            args.append(domain)
        if path is not None:
            if domain is None:
                raise ValueError("`domain` must be given when `path` is given")
            args.append(path)
        self.jar.clear(*args)

    def update(self, cookies: Optional[CookieTypes] = None) -> None:  # type: ignore
        """Copy every cookie from ``cookies`` (any ``CookieTypes``) into this jar."""
        cookies = Cookies(cookies)
        for cookie in cookies.jar:
            self.jar.set_cookie(cookie)

    def __setitem__(self, name: str, value: str) -> None:
        return self.set(name, value)

    def __getitem__(self, name: str) -> str:
        value = self.get(name)
        if value is None:
            raise KeyError(name)
        return value

    def __delitem__(self, name: str) -> None:
        return self.delete(name)

    def __len__(self) -> int:
        # Counts every stored cookie, including same-named ones on different
        # domains/paths.
        return len(self.jar)

    def __iter__(self) -> Iterator[str]:
        # Yields cookie names; a name repeats if it exists on several domains.
        return (cookie.name for cookie in self.jar)

    def __bool__(self) -> bool:
        for _ in self.jar:
            return True
        return False

    def __repr__(self) -> str:
        cookies_repr = ", ".join(
            [
                f"<Cookie {cookie.name}={cookie.value} for {cookie.domain} />"
                for cookie in self.jar
            ]
        )
        return f"<Cookies[{cookies_repr}]>"