Import python venv for stability

This commit is contained in:
2026-02-15 21:24:16 -08:00
parent 1343e93a59
commit 7d784705c9
4997 changed files with 1628270 additions and 0 deletions
@@ -0,0 +1,171 @@
# Names exported by ``from <package> import *``. Every module-level HTTP verb
# shortcut defined in this module is listed, including ``trace`` and ``query``.
__all__ = [
    "Session",
    "AsyncSession",
    "BrowserType",
    "BrowserTypeLiteral",
    "CurlWsFlag",
    "request",
    "head",
    "get",
    "post",
    "put",
    "patch",
    "delete",
    "options",
    "trace",
    "query",
    "RequestsError",
    "Cookies",
    "Headers",
    "Request",
    "Response",
    "AsyncWebSocket",
    "WebSocket",
    "WebSocketError",
    "WebSocketClosed",
    "WebSocketTimeout",
    "WsCloseCode",
    "ExtraFingerprints",
    "CookieTypes",
    "HeaderTypes",
    "ProxySpec",
]
from typing import Optional, TYPE_CHECKING, TypedDict
from ..const import CurlWsFlag
from .cookies import Cookies, CookieTypes
from .errors import RequestsError
from .headers import Headers, HeaderTypes
from .impersonate import BrowserType, BrowserTypeLiteral, ExtraFingerprints
from .models import Request, Response
from .session import (
AsyncSession,
HttpMethod,
ProxySpec,
Session,
ThreadType,
RequestParams,
Unpack,
)
from .websockets import (
AsyncWebSocket,
WebSocket,
WebSocketClosed,
WebSocketError,
WebSocketTimeout,
WsCloseCode,
)
# ``SessionRequestParams`` extends ``RequestParams`` with the three keyword
# arguments that ``request`` declares explicitly (``thread``, ``curl_options``
# and ``debug``). The full definition is only built for static type checkers.
if TYPE_CHECKING:

    class SessionRequestParams(RequestParams, total=False):
        thread: Optional[ThreadType]
        curl_options: Optional[dict]
        debug: Optional[bool]

else:
    # At runtime the name is simply bound to ``TypedDict`` so that the
    # ``Unpack[SessionRequestParams]`` annotations below still evaluate.
    SessionRequestParams = TypedDict
def request(
    method: HttpMethod,
    url: str,
    thread: Optional[ThreadType] = None,
    curl_options: Optional[dict] = None,
    debug: Optional[bool] = None,
    **kwargs: Unpack[RequestParams],
) -> Response:
    """Send an http request using a temporary ``Session``.

    A new ``Session`` is created for this single call and closed again when the
    call returns.

    Parameters:
        method: http method for the request: GET/POST/PUT/DELETE etc.
        url: url for the requests.
        params: query string for the requests.
        data: form values(dict/list/tuple) or binary data to use in body,
            ``Content-Type: application/x-www-form-urlencoded`` will be added if a
            dict is given.
        json: json values to use in body, ``Content-Type: application/json`` will be
            added automatically.
        headers: headers to send.
        cookies: cookies to use.
        files: not supported, use ``multipart`` instead.
        auth: HTTP basic auth, a tuple of (username, password), only basic auth is
            supported.
        timeout: how many seconds to wait before giving up.
        allow_redirects: whether to allow redirection.
        max_redirects: max redirect counts, default 30, use -1 for unlimited.
        proxies: dict of proxies to use, prefer to use ``proxy`` if they are the
            same. format: ``{"http": proxy_url, "https": proxy_url}``.
        proxy: proxy to use, format: "http://user:pass@proxy_url".
            Can't be used with ``proxies`` parameter.
        proxy_auth: HTTP basic auth for proxy, a tuple of (username, password).
        verify: whether to verify https certs.
        referer: shortcut for setting referer header.
        accept_encoding: shortcut for setting accept-encoding header.
        content_callback: a callback function to receive response body.
            ``def callback(chunk: bytes) -> None:``
        impersonate: which browser version to impersonate.
        ja3: ja3 string to impersonate.
        akamai: akamai string to impersonate.
        extra_fp: extra fingerprints options, in complement to ja3 and akamai
            strings.
        thread: thread engine to use for working with other thread
            implementations. choices: eventlet, gevent.
        default_headers: whether to set default browser headers when impersonating.
        default_encoding: encoding for decoding response content if charset is not
            found in headers. Defaults to "utf-8". Can be set to a callable for
            automatic detection.
        quote: Set characters to be quoted, i.e. percent-encoded. Default safe
            string is ``!#$%&'()*+,/:;=?@[]~``. If set to a string, the characters
            will be removed from the safe string, thus quoted. If set to False, the
            url will be kept as is, without any automatic percent-encoding, you
            must encode the URL yourself.
        curl_options: extra curl options to use.
        http_version: limiting http version, defaults to http2.
        debug: print extra curl debug info. ``None`` is treated as ``False``.
        interface: which interface to use.
        cert: a tuple of (cert, key) filenames for client cert.
        stream: streaming the response, default False.
        max_recv_speed: maximum receive speed, bytes per second.
        multipart: upload files using the multipart format, see examples for
            details.
        discard_cookies: discard cookies from server. Default to False.

    Returns:
        A ``Response`` object.
    """
    if debug is None:
        debug = False
    session = Session(thread=thread, curl_options=curl_options, debug=debug)
    with session as active_session:
        return active_session.request(method=method, url=url, **kwargs)
def head(url: str, **kwargs: Unpack[SessionRequestParams]):
    """Send a ``HEAD`` request to ``url``.

    Every keyword argument is forwarded unchanged to ``request``.
    """
    return request("HEAD", url, **kwargs)
def get(url: str, **kwargs: Unpack[SessionRequestParams]):
    """Send a ``GET`` request to ``url``.

    Every keyword argument is forwarded unchanged to ``request``.
    """
    return request("GET", url, **kwargs)
def post(url: str, **kwargs: Unpack[SessionRequestParams]):
    """Send a ``POST`` request to ``url``.

    Every keyword argument is forwarded unchanged to ``request``.
    """
    return request("POST", url, **kwargs)
def put(url: str, **kwargs: Unpack[SessionRequestParams]):
    """Send a ``PUT`` request to ``url``.

    Every keyword argument is forwarded unchanged to ``request``.
    """
    return request("PUT", url, **kwargs)
def patch(url: str, **kwargs: Unpack[SessionRequestParams]):
    """Send a ``PATCH`` request to ``url``.

    Every keyword argument is forwarded unchanged to ``request``.
    """
    return request("PATCH", url, **kwargs)
def delete(url: str, **kwargs: Unpack[SessionRequestParams]):
    """Send a ``DELETE`` request to ``url``.

    Every keyword argument is forwarded unchanged to ``request``.
    """
    return request("DELETE", url, **kwargs)
def options(url: str, **kwargs: Unpack[SessionRequestParams]):
    """Send an ``OPTIONS`` request to ``url``.

    Every keyword argument is forwarded unchanged to ``request``.
    """
    return request("OPTIONS", url, **kwargs)
def trace(url: str, **kwargs: Unpack[SessionRequestParams]):
    """Send a ``TRACE`` request to ``url``.

    Every keyword argument is forwarded unchanged to ``request``.
    """
    return request("TRACE", url, **kwargs)
def query(url: str, **kwargs: Unpack[SessionRequestParams]):
    """Send a ``QUERY`` request to ``url``.

    Every keyword argument is forwarded unchanged to ``request``.
    """
    return request("QUERY", url, **kwargs)
@@ -0,0 +1,364 @@
# Adapted from: https://github.com/encode/httpx/blob/master/httpx/_models.py,
# which is licensed under the BSD License.
# See https://github.com/encode/httpx/blob/master/LICENSE.md
__all__ = ["Cookies"]
import re
import time
import warnings
from dataclasses import dataclass
from http.cookiejar import Cookie, CookieJar
from http.cookies import _unquote
from typing import Optional, Union
from collections.abc import Iterator, MutableMapping
from urllib.parse import urlparse
from ..utils import CurlCffiWarning
from .errors import CookieConflict, RequestsError
CookieTypes = Union["Cookies", CookieJar, dict[str, str], list[tuple[str, str]]]
@dataclass
class CurlMorsel:
    """One cookie, convertible between curl's cookie line and ``http.cookiejar``.

    The curl representation is a single line of seven tab-separated fields:
    hostname, subdomains flag, path, secure flag, expires, name and value.
    Flags are spelled ``TRUE``/``FALSE``. A hostname prefixed with
    ``#HttpOnly_`` marks an HttpOnly cookie.
    """

    name: str
    value: str
    hostname: str = ""
    subdomains: bool = False
    path: str = "/"
    secure: bool = False
    expires: int = 0
    http_only: bool = False

    @staticmethod
    def parse_bool(s):
        """Return ``True`` only when ``s`` is exactly the string ``"TRUE"``."""
        return s == "TRUE"

    @staticmethod
    def dump_bool(s):
        """Return ``"TRUE"`` for a truthy ``s`` and ``"FALSE"`` otherwise."""
        return "TRUE" if s else "FALSE"

    @classmethod
    def from_curl_format(cls, set_cookie_line: bytes):
        """Build a morsel from one tab-separated cookie line.

        Args:
            set_cookie_line: the encoded line with seven tab-separated fields.

        Returns:
            The parsed ``CurlMorsel``; the value is passed through ``_unquote``.

        Raises:
            ValueError: if the line has fewer than seven fields or ``expires``
                is not an integer.
        """
        # Split at most six times: the value is the last field, so any tab it
        # contains stays part of the value instead of breaking the unpacking.
        (
            hostname,
            subdomains,
            path,
            secure,
            expires,
            name,
            value,
        ) = set_cookie_line.decode().split("\t", 6)
        # e.g. #HttpOnly_postman-echo.com
        http_only_prefix = "#HttpOnly_"
        http_only = hostname.startswith(http_only_prefix)
        domain = hostname[len(http_only_prefix):] if http_only else hostname
        return cls(
            hostname=domain,
            subdomains=cls.parse_bool(subdomains),
            path=path,
            secure=cls.parse_bool(secure),
            expires=int(expires),
            name=name,
            value=_unquote(value),
            http_only=http_only,
        )

    def to_curl_format(self):
        """Serialize the morsel to a tab-separated cookie line.

        Raises:
            RequestsError: if the morsel has no hostname.
        """
        if not self.hostname:
            raise RequestsError(f"Domain not found for cookie {self.name}={self.value}")
        return "\t".join(
            [
                self.hostname,
                self.dump_bool(self.subdomains),
                self.path,
                self.dump_bool(self.secure),
                str(self.expires),
                self.name,
                self.value,
            ]
        )

    @classmethod
    def from_cookiejar_cookie(cls, cookie: Cookie):
        """Build a morsel from a ``http.cookiejar.Cookie``.

        A missing value becomes ``""`` and a missing expiry becomes ``0``. The
        HttpOnly flag is restored from the ``http_only`` non-standard attribute
        written by ``to_cookiejar_cookie``; cookies without it are not HttpOnly.
        """
        return cls(
            name=cookie.name,
            value=cookie.value or "",
            hostname=cookie.domain,
            subdomains=cookie.domain_specified,
            path=cookie.path,
            secure=cookie.secure,
            expires=int(cookie.expires or 0),
            http_only=cookie.get_nonstandard_attr("http_only") == "True",
        )

    def to_cookiejar_cookie(self) -> Cookie:
        """Convert the morsel to a ``http.cookiejar.Cookie``.

        An ``expires`` of ``0`` produces a cookie without expiry that is marked
        ``discard``. The HttpOnly flag is stored as the string ``"True"`` or
        ``"False"`` under the ``http_only`` non-standard attribute.
        """
        # the leading dot actually does not mean anything nowadays
        # https://stackoverflow.com/a/20884869/1061155
        # https://github.com/python/cpython/blob/d6555abfa7384b5a40435a11bdd2aa6bbf8f5cfc/Lib/http/cookiejar.py#L1535
        return Cookie(
            version=0,
            name=self.name,
            value=self.value,
            port=None,
            port_specified=False,
            domain=self.hostname,
            domain_specified=self.subdomains,
            domain_initial_dot=self.hostname.startswith("."),
            path=self.path,
            path_specified=bool(self.path),
            secure=self.secure,
            # using if explicitly to make it clear.
            expires=None if self.expires == 0 else self.expires,
            discard=self.expires == 0,
            comment=None,
            comment_url=None,
            rest={"http_only": f"{self.http_only}"},
            rfc2109=False,
        )
# Matches a trailing ``:<digits>`` (a port) so it can be removed from a host.
cut_port_re = re.compile(r":\d+$", re.ASCII)
# Matches a host that ends in ``.<digits>``; ``_eff_request_host`` uses it to
# avoid appending ``.local`` to such hosts.
IPV4_RE = re.compile(r"\.\d+$", re.ASCII)
class Cookies(MutableMapping[str, str]):
    """
    HTTP Cookies, as a mutable mapping.

    The cookies live in a ``http.cookiejar.CookieJar`` stored on ``self.jar``.
    Mapping access is by cookie name; ``get``, ``delete`` and ``clear`` accept
    an optional domain and path to narrow the selection.
    """

    def __init__(self, cookies: Optional[CookieTypes] = None) -> None:
        """Create the cookie store.

        Args:
            cookies: ``None`` for an empty store, a ``dict`` or a list of
                ``(name, value)`` pairs to set, another ``Cookies`` whose
                cookies are copied into a new jar, or any other object, which
                is used directly as the jar without copying.
        """
        if cookies is None or isinstance(cookies, dict):
            self.jar = CookieJar()
            if isinstance(cookies, dict):
                for key, value in cookies.items():
                    self.set(key, value)
        elif isinstance(cookies, list):
            self.jar = CookieJar()
            for key, value in cookies:
                self.set(key, value)
        elif isinstance(cookies, Cookies):
            self.jar = CookieJar()
            for cookie in cookies.jar:
                self.jar.set_cookie(cookie)
        else:
            self.jar = cookies

    def _eff_request_host(self, request) -> str:
        """
        Almost equivalent to the eff_request_host function in:
        https://github.com/python/cpython/blob/3.11/Lib/http/cookiejar.py#L636

        The host is taken from the request URL, falling back to the ``Host``
        header. Any port is removed, the host is lower-cased, and ``.local`` is
        appended when the host contains no dot.
        """
        host = urlparse(request.url)[1]
        if host == "":
            host = request.headers.get("Host", "")
        # remove port, if present
        host = cut_port_re.sub("", host, 1)
        host = host.lower()
        if host.find(".") == -1 and not IPV4_RE.search(host):
            host += ".local"
        return host

    def get_cookies_for_curl(self, request) -> list[CurlMorsel]:
        """Return every cookie in the jar as a ``CurlMorsel``.

        The process is similar to ``cookiejar.add_cookie_header``, but loads all
        cookies. A cookie without a domain is given the effective host of
        ``request``. Expired cookies are cleared from the jar before returning.
        """
        morsels = []
        # The jar's lock is a context manager, so it is released even if the
        # conversion of a cookie raises.
        with self.jar._cookies_lock:  # type: ignore
            self.jar._policy._now = self._now = int(time.time())  # type: ignore
            for cookie in self.jar:
                morsel = CurlMorsel.from_cookiejar_cookie(cookie)
                if not morsel.hostname:
                    morsel.hostname = self._eff_request_host(request)
                morsels.append(morsel)
        self.jar.clear_expired_cookies()
        return morsels

    def update_cookies_from_curl(self, morsels: list[CurlMorsel]):
        """Store ``morsels`` in the jar, then clear expired cookies."""
        for morsel in morsels:
            cookie = morsel.to_cookiejar_cookie()
            self.jar.set_cookie(cookie)
        self.jar.clear_expired_cookies()

    def set(
        self, name: str, value: str, domain: str = "", path: str = "/", secure=False
    ) -> None:
        """
        Set a cookie value by name. May optionally include domain and path.

        Cookies named with the ``__Secure-`` prefix are forced to
        ``secure=True``. Cookies named with the ``__Host-`` prefix are forced
        to ``secure=True``, an empty domain and path ``/``. A
        ``CurlCffiWarning`` is emitted whenever an argument is overridden.
        """
        # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie
        if name.startswith("__Secure-") and secure is False:
            warnings.warn(
                "`secure` changed to True for `__Secure-` prefixed cookies",
                CurlCffiWarning,
                stacklevel=2,
            )
            secure = True
        elif name.startswith("__Host-") and (secure is False or domain or path != "/"):
            # The message names ``secure``, the argument that is actually changed.
            warnings.warn(
                "`secure` changed to True, `domain` removed, `path` changed to `/` "
                "for `__Host-` prefixed cookies",
                CurlCffiWarning,
                stacklevel=2,
            )
            secure = True
            domain = ""
            path = "/"

        kwargs = {
            "version": 0,
            "name": name,
            "value": value,
            "port": None,
            "port_specified": False,
            "domain": domain,
            "domain_specified": bool(domain),
            "domain_initial_dot": domain.startswith("."),
            "path": path,
            "path_specified": bool(path),
            "secure": secure,
            "expires": None,
            "discard": True,
            "comment": None,
            "comment_url": None,
            "rest": {"HttpOnly": None},
            "rfc2109": False,
        }
        cookie = Cookie(**kwargs)
        self.jar.set_cookie(cookie)

    def get(  # type: ignore
        self,
        name: str,
        default: Optional[str] = None,
        domain: Optional[str] = None,
        path: Optional[str] = None,
    ) -> Optional[str]:
        """
        Get a cookie by name. May optionally include domain and path
        in order to specify exactly which cookie to retrieve.

        Returns:
            The value of the last matching cookie, or ``default`` if none match.

        Raises:
            CookieConflict: if two matching cookies have different values and
                neither domain is a suffix of the other.
        """
        value = None
        matched_domain = ""
        for cookie in self.jar:
            if (
                cookie.name == name
                and (domain is None or cookie.domain == domain)
                and (path is None or cookie.path == path)
            ):
                # if cookies on two different domains do not share a same value
                if (
                    value is not None
                    and not matched_domain.endswith(cookie.domain)
                    and not str(cookie.domain).endswith(matched_domain)
                    and value != cookie.value
                ):
                    message = (
                        f"Multiple cookies exist with name={name} on "
                        f"{matched_domain} and {cookie.domain}, add domain "
                        "parameter to suppress this error."
                    )
                    raise CookieConflict(message)
                value = cookie.value
                matched_domain = cookie.domain or ""

        if value is None:
            return default
        return value

    def get_dict(
        self, domain: Optional[str] = None, path: Optional[str] = None
    ) -> dict:
        """
        Return the cookies as a plain ``{name: value}`` dict, optionally
        restricted to a domain and/or path.

        Cookies with the same name on different domains may overwrite each other,
        do NOT use this function as a method of serialization.
        """
        ret = {}
        for cookie in self.jar:
            if (domain is None or cookie.domain == domain) and (
                path is None or cookie.path == path
            ):
                ret[cookie.name] = cookie.value
        return ret

    def delete(
        self,
        name: str,
        domain: Optional[str] = None,
        path: Optional[str] = None,
    ) -> None:
        """
        Delete a cookie by name. May optionally include domain and path
        in order to specify exactly which cookie to delete.

        When both domain and path are given the jar is asked to remove that
        exact cookie; otherwise every cookie matching the given criteria is
        removed.
        """
        if domain is not None and path is not None:
            return self.jar.clear(domain, path, name)

        remove = [
            cookie
            for cookie in self.jar
            if cookie.name == name
            and (domain is None or cookie.domain == domain)
            and (path is None or cookie.path == path)
        ]

        for cookie in remove:
            self.jar.clear(cookie.domain, cookie.path, cookie.name)

    def clear(self, domain: Optional[str] = None, path: Optional[str] = None) -> None:
        """
        Delete all cookies. Optionally include a domain and path in
        order to only delete a subset of all the cookies.

        A path may only be given together with a domain.
        """
        args = []
        if domain is not None:
            args.append(domain)
        if path is not None:
            assert domain is not None
            args.append(path)
        self.jar.clear(*args)

    def update(self, cookies: Optional[CookieTypes] = None) -> None:  # type: ignore
        """Add every cookie from ``cookies`` (any ``CookieTypes``) to the jar."""
        cookies = Cookies(cookies)
        for cookie in cookies.jar:
            self.jar.set_cookie(cookie)

    def __setitem__(self, name: str, value: str) -> None:
        return self.set(name, value)

    def __getitem__(self, name: str) -> str:
        value = self.get(name)
        if value is None:
            raise KeyError(name)
        return value

    def __delitem__(self, name: str) -> None:
        return self.delete(name)

    def __len__(self) -> int:
        return len(self.jar)

    def __iter__(self) -> Iterator[str]:
        return (cookie.name for cookie in self.jar)

    def __bool__(self) -> bool:
        # True as soon as the jar yields one cookie.
        for _ in self.jar:
            return True
        return False

    def __repr__(self) -> str:
        cookies_repr = ", ".join(
            f"<Cookie {cookie.name}={cookie.value} for {cookie.domain} />"
            for cookie in self.jar
        )
        return f"<Cookies[{cookies_repr}]>"
@@ -0,0 +1,7 @@
# for compatibility with 0.5.x
__all__ = ["CurlError", "RequestsError", "CookieConflict", "SessionClosed"]
from ..curl import CurlError
from .exceptions import CookieConflict, SessionClosed
from .exceptions import RequestException as RequestsError
@@ -0,0 +1,227 @@
# Apache 2.0 License
# Vendored from https://github.com/psf/requests/blob/main/src/requests/exceptions.py
# With our own additions
import json
from typing import Literal, Union
from ..const import CurlECode
from ..curl import CurlError
# Note IOError is an alias of OSError in Python 3.x
class RequestException(CurlError, OSError):
    """Base exception for curl_cffi.requests package.

    Args:
        msg: the error message, forwarded to the parent initializer.
        code: a ``CurlECode`` value, or ``0`` (the default); forwarded to the
            parent initializer.
        response: optional object stored on the exception as ``response``.
        *args: extra positional arguments forwarded to the parent initializer.
        **kwargs: extra keyword arguments forwarded to the parent initializer.
    """

    def __init__(
        self,
        msg,
        code: Union[CurlECode, Literal[0]] = 0,
        response=None,
        *args,
        **kwargs,
    ):
        super().__init__(msg, code, *args, **kwargs)
        # Kept on the instance so handlers can inspect it; ``None`` by default.
        self.response = response
class CookieConflict(RequestException):
    """The same cookie exists for different domains."""


class SessionClosed(RequestException):
    """The session has already been closed."""


class ImpersonateError(RequestException):
    """The impersonate config was wrong or impersonate failed."""


# not used
class InvalidJSONError(RequestException):
    """A JSON error occurred. Not used."""


# not used
class JSONDecodeError(InvalidJSONError, json.JSONDecodeError):
    """Couldn't decode the text into json. Not used."""


class HTTPError(RequestException):
    """An HTTP error occurred."""


class IncompleteRead(HTTPError):
    """Incomplete read of content."""


class ConnectionError(RequestException):
    """A connection error occurred."""


class DNSError(ConnectionError):
    """Could not resolve the host."""


class ProxyError(RequestException):
    """A proxy error occurred."""


class SSLError(ConnectionError):
    """An SSL error occurred."""


class CertificateVerifyError(SSLError):
    """Raised when certificate validation has failed."""


class Timeout(RequestException):
    """The request timed out."""


# not used
class ConnectTimeout(ConnectionError, Timeout):
    """The request timed out while trying to connect to the remote server.

    Requests that produced this error are safe to retry.

    Not used.
    """


# not used
class ReadTimeout(Timeout):
    """The server did not send any data in the allotted amount of time. Not used."""


# not used
class URLRequired(RequestException):
    """A valid URL is required to make a request. Not used."""


class TooManyRedirects(RequestException):
    """Too many redirects."""


# not used
class MissingSchema(RequestException, ValueError):
    """The URL scheme (e.g. http or https) is missing. Not used."""


class InvalidSchema(RequestException, ValueError):
    """The URL scheme provided is either invalid or unsupported."""


class InvalidURL(RequestException, ValueError):
    """The URL provided was somehow invalid."""


# not used
class InvalidHeader(RequestException, ValueError):
    """The header value provided was somehow invalid. Not used."""


# not used
class InvalidProxyURL(InvalidURL):
    """The proxy URL provided is invalid. Not used."""


# not used
class ChunkedEncodingError(RequestException):
    """The server declared chunked encoding but sent an invalid chunk. Not used."""


# not used
class ContentDecodingError(RequestException):
    """Failed to decode response content. Not used."""


# not used
class StreamConsumedError(RequestException, TypeError):
    """The content for this response was already consumed. Not used."""


# not supported
class RetryError(RequestException):
    """Custom retries logic failed. Not used."""


# not used
class UnrewindableBodyError(RequestException):
    """Requests encountered an error when trying to rewind a body. Not used."""


class InterfaceError(RequestException):
    """A specified outgoing interface could not be used."""
# Warnings
# TODO: use this warning as a base
class RequestsWarning(Warning):
    """Base warning for Requests. Not used."""


# not used
class FileModeWarning(RequestsWarning, DeprecationWarning):
    """A file was opened in text mode, but Requests determined its binary length.

    Not used.
    """


# not used
class RequestsDependencyWarning(RequestsWarning):
    """An imported dependency doesn't match the expected version range."""
# Maps a curl error code (or 0) to the exception class used for it.
# ``code2error`` falls back to ``RequestException`` for codes that are missing.
CODE2ERROR = {
    0: RequestException,
    CurlECode.UNSUPPORTED_PROTOCOL: InvalidSchema,
    CurlECode.URL_MALFORMAT: InvalidURL,
    CurlECode.COULDNT_RESOLVE_PROXY: ProxyError,
    CurlECode.COULDNT_RESOLVE_HOST: DNSError,
    CurlECode.COULDNT_CONNECT: ConnectionError,
    CurlECode.WEIRD_SERVER_REPLY: ConnectionError,
    CurlECode.REMOTE_ACCESS_DENIED: ConnectionError,
    CurlECode.HTTP2: HTTPError,
    CurlECode.HTTP_RETURNED_ERROR: HTTPError,
    CurlECode.WRITE_ERROR: RequestException,
    CurlECode.READ_ERROR: RequestException,
    CurlECode.OUT_OF_MEMORY: RequestException,
    CurlECode.OPERATION_TIMEDOUT: Timeout,
    CurlECode.SSL_CONNECT_ERROR: SSLError,
    CurlECode.INTERFACE_FAILED: InterfaceError,
    CurlECode.TOO_MANY_REDIRECTS: TooManyRedirects,
    CurlECode.UNKNOWN_OPTION: RequestException,
    CurlECode.SETOPT_OPTION_SYNTAX: RequestException,
    CurlECode.GOT_NOTHING: ConnectionError,
    CurlECode.SSL_ENGINE_NOTFOUND: SSLError,
    CurlECode.SSL_ENGINE_SETFAILED: SSLError,
    CurlECode.SEND_ERROR: ConnectionError,
    CurlECode.RECV_ERROR: ConnectionError,
    CurlECode.SSL_CERTPROBLEM: SSLError,
    CurlECode.SSL_CIPHER: SSLError,
    CurlECode.PEER_FAILED_VERIFICATION: CertificateVerifyError,
    CurlECode.BAD_CONTENT_ENCODING: HTTPError,
    CurlECode.SSL_ENGINE_INITFAILED: SSLError,
    CurlECode.SSL_CACERT_BADFILE: SSLError,
    CurlECode.SSL_CRL_BADFILE: SSLError,
    CurlECode.SSL_ISSUER_ERROR: SSLError,
    CurlECode.SSL_PINNEDPUBKEYNOTMATCH: SSLError,
    CurlECode.SSL_INVALIDCERTSTATUS: SSLError,
    CurlECode.HTTP2_STREAM: HTTPError,
    CurlECode.HTTP3: HTTPError,
    CurlECode.QUIC_CONNECT_ERROR: ConnectionError,
    CurlECode.PROXY: ProxyError,
    CurlECode.SSL_CLIENTCERT: SSLError,
    CurlECode.ECH_REQUIRED: SSLError,
    CurlECode.PARTIAL_FILE: IncompleteRead,
}
# credits: https://github.com/yt-dlp/yt-dlp/blob/master/yt_dlp/networking/_curlcffi.py#L241
# Unlicense
def code2error(code: Union[CurlECode, Literal[0]], msg: str):
    """Pick the exception class for a curl error.

    Args:
        code: the curl error code, or 0.
        msg: the error message that accompanies the code.

    Returns:
        ``ProxyError`` when the code is ``RECV_ERROR`` and the message contains
        ``"CONNECT"``; otherwise the class registered in ``CODE2ERROR``, or
        ``RequestException`` when the code is not registered.
    """
    recv_error_mentions_connect = code == CurlECode.RECV_ERROR and "CONNECT" in msg
    if not recv_error_mentions_connect:
        return CODE2ERROR.get(code, RequestException)
    return ProxyError
@@ -0,0 +1,347 @@
# Copied from: https://github.com/encode/httpx/blob/master/httpx/_models.py,
# which is licensed under the BSD License.
# See https://github.com/encode/httpx/blob/master/LICENSE.md
from collections.abc import (
ItemsView,
Iterable,
Iterator,
KeysView,
Mapping,
MutableMapping,
Sequence,
ValuesView,
)
from typing import Any, AnyStr, Optional, Union, cast
# Every shape accepted wherever headers are passed in: an existing ``Headers``
# object, a str or bytes mapping whose values may be ``None``, a sequence of
# ``(name, value)`` pairs, or a sequence of raw header lines.
HeaderTypes = Union[
    "Headers",
    Mapping[str, Optional[str]],
    Mapping[bytes, Optional[bytes]],
    Sequence[tuple[str, str]],
    Sequence[tuple[bytes, bytes]],
    Sequence[Union[str, bytes]],
]
def to_str(value: Union[str, bytes], encoding: str = "utf-8") -> str:
    """Return ``value`` as text, decoding it with ``encoding`` if it is not a str."""
    if isinstance(value, str):
        return value
    return value.decode(encoding)
# Lower-cased header names whose values are masked by
# ``obfuscate_sensitive_headers``.
SENSITIVE_HEADERS = {"authorization", "proxy-authorization"}


def obfuscate_sensitive_headers(
    items: Iterable[tuple[AnyStr, Optional[AnyStr]]],
) -> Iterator[tuple[AnyStr, Optional[AnyStr]]]:
    """Yield the ``(name, value)`` pairs with sensitive values masked.

    A pair whose lower-cased name is in ``SENSITIVE_HEADERS`` gets its value
    replaced by ``b"[secure]"`` when the value is bytes and by ``"[secure]"``
    otherwise. All other pairs are yielded unchanged.
    """
    for name, value in items:
        if to_str(name.lower()) not in SENSITIVE_HEADERS:
            yield name, value
        elif isinstance(value, bytes):
            yield name, b"[secure]"  # type: ignore
        else:
            yield name, "[secure]"  # type: ignore
def normalize_header_key(
    value: Union[str, bytes],
    lower: bool,
    encoding: Optional[str] = None,
) -> bytes:
    """
    Coerce str/bytes into a strictly byte-wise HTTP header key.

    Bytes are used as they are; text is encoded with ``encoding``, or with
    ASCII when no encoding is given. The result is lower-cased when ``lower``
    is true.
    """
    if isinstance(value, bytes):
        raw_key = value
    else:
        raw_key = value.encode(encoding or "ascii")
    if lower:
        return raw_key.lower()
    return raw_key
def normalize_header_value(
    value: Union[str, bytes, int, None], encoding: Optional[str] = None
) -> Union[bytes, None]:
    """
    Coerce str/bytes into a strictly byte-wise HTTP header value.

    ``None`` and bytes are returned unchanged. Integers are converted with
    ``str()`` and encoded with the default codec. Text is encoded with
    ``encoding``, or with latin-1 when no encoding is given.
    """
    if value is None or isinstance(value, bytes):
        return value
    if isinstance(value, int):
        digits = str(value)
        return digits.encode()
    # The default encoding for header value should be latin-1
    # See: RFC and https://github.com/python/cpython/blob/bc264eac3ad14dab748e33b3d714c2674872791f/Lib/http/client.py#L1309
    text = cast(str, value)
    return text.encode(encoding or "latin-1")
class Headers(MutableMapping[str, Optional[str]]):
    """
    HTTP headers, as a case-insensitive multi-dict.

    Each header is stored in ``_list`` as a ``(raw_key, lowercased_key, value)``
    triple of bytes; a value may be ``None``.
    """

    def __init__(
        self, headers: Optional[HeaderTypes] = None, encoding: Optional[str] = None
    ):
        """Build the header list from any ``HeaderTypes`` value.

        Args:
            headers: another ``Headers`` (its entries are copied), a mapping,
                or an iterable whose entries are either ``"Name: Value"`` lines
                (str or bytes) or ``(name, value)`` pairs. Falsy input gives an
                empty header list.
            encoding: codec used to encode text keys and values; when omitted,
                keys use ASCII and values use latin-1.

        Raises:
            TypeError: if ``headers`` is truthy but not iterable.
            ValueError: if a header line has no ``:`` or a pair does not have
                exactly two items.
        """
        self._list: list[tuple[bytes, bytes, Optional[bytes]]]
        if isinstance(headers, Headers):
            self._list = list(headers._list)
            encoding = encoding or headers.encoding
        elif not headers:
            self._list = []
        else:
            if isinstance(headers, Mapping):
                pairs = list(headers.items())
            else:
                # Any other iterable (list, tuple, ...) is read entry by entry,
                # so ``_list`` is always populated for supported input.
                pairs = [self._split_entry(entry) for entry in headers]
            self._list = [
                (
                    normalize_header_key(k, lower=False, encoding=encoding),
                    normalize_header_key(k, lower=True, encoding=encoding),
                    normalize_header_value(v, encoding),
                )
                for k, v in pairs
            ]

        self._encoding = encoding

    @staticmethod
    def _split_entry(entry):
        """Turn one sequence entry into a ``(name, value)`` pair.

        A str or bytes entry is split on its first colon and the value is
        stripped; any other entry is unpacked as a two-item pair.
        """
        if isinstance(entry, (str, bytes)):
            sep = ":" if isinstance(entry, str) else b":"
            name, value = entry.split(sep, maxsplit=1)  # pyright: ignore
            return name, value.strip()
        name, value = entry
        return name, value

    def _joined_items(self) -> dict[str, str]:
        """Return ``{lowercased_key: value}`` with repeated keys joined by ", ".

        A ``None`` value is rendered as the string ``"None"``.
        """
        joined: dict[str, str] = {}
        for _, key, value in self._list:
            str_key = key.decode(self.encoding)
            str_value = value.decode(self.encoding) if value is not None else "None"
            if str_key in joined:
                joined[str_key] += f", {str_value}"
            else:
                joined[str_key] = str_value
        return joined

    @property
    def encoding(self) -> str:
        """
        Header encoding is mandated as ascii, but we allow fallbacks to utf-8
        or iso-8859-1.

        When no encoding was set, the first of ascii and utf-8 that decodes
        every key and value is chosen and remembered; otherwise iso-8859-1.
        """
        if self._encoding is None:
            for encoding in ["ascii", "utf-8"]:
                for key, value in self.raw:
                    try:
                        key.decode(encoding)
                        if value is not None:
                            value.decode(encoding)
                    except UnicodeDecodeError:
                        break
                else:
                    # The else block runs if 'break' did not occur, meaning
                    # all values fitted the encoding.
                    self._encoding = encoding
                    break
            else:
                # The ISO-8859-1 encoding covers all 256 code points in a byte,
                # so will never raise decode errors.
                self._encoding = "iso-8859-1"
        return self._encoding

    @encoding.setter
    def encoding(self, value: str) -> None:
        self._encoding = value

    @property
    def raw(self) -> list[tuple[bytes, Optional[bytes]]]:
        """
        Returns a list of the raw header items, as byte pairs.
        """
        return [(raw_key, value) for raw_key, _, value in self._list]

    def keys(self) -> KeysView[str]:
        """Return the distinct lower-cased header names."""
        return {key.decode(self.encoding): None for _, key, _ in self._list}.keys()

    def values(self) -> ValuesView[Optional[str]]:
        """
        Return the header values, one per distinct key. Values of a repeated
        key are concatenated into a single comma separated value.
        """
        return self._joined_items().values()

    def items(self) -> ItemsView[str, Optional[str]]:
        """
        Return `(key, value)` items of headers. Concatenate headers
        into a single comma separated value when a key occurs multiple times.
        """
        return self._joined_items().items()

    def multi_items(self) -> list[tuple[str, Optional[str]]]:
        """
        Return a list of `(key, value)` pairs of headers. Allow multiple
        occurrences of the same key without concatenating into a single
        comma separated value.
        """
        return [
            (
                key.decode(self.encoding),
                value.decode(self.encoding) if value is not None else value,
            )
            for key, _, value in self._list
        ]

    def get(self, key: str, default: Any = None) -> Any:
        """
        Return a header value. If multiple occurrences of the header occur
        then concatenate them together with commas.
        """
        try:
            return self[key]
        except KeyError:
            return default

    def get_list(self, key: str, split_commas: bool = False) -> list[Optional[str]]:
        """
        Return a list of all header values for a given key.
        If `split_commas=True` is passed, then any comma separated header
        values are split into multiple return strings. A ``None`` value is
        returned as ``None`` in either mode.
        """
        get_header_key = key.lower().encode(self.encoding)

        values = [
            item_value.decode(self.encoding) if item_value is not None else item_value
            for _, item_key, item_value in self._list
            if item_key.lower() == get_header_key
        ]

        if not split_commas:
            return values

        split_values: list[Optional[str]] = []
        for value in values:
            if value is None:
                # Nothing to split; keep the placeholder instead of crashing.
                split_values.append(None)
            else:
                split_values.extend(item.strip() for item in value.split(","))
        return split_values

    def update(self, headers: Optional[HeaderTypes] = None) -> None:  # type: ignore
        """Replace every header present in ``headers`` and append the new ones."""
        headers = Headers(headers)
        for key in headers:
            if key in self:
                self.pop(key)
        self._list.extend(headers._list)

    def copy(self) -> "Headers":
        return Headers(self, encoding=self.encoding)

    def __getitem__(self, key: str) -> Optional[str]:
        """
        Return a single header value.
        If there are multiple headers with the same key, then we concatenate
        them with commas. See: https://tools.ietf.org/html/rfc7230#section-3.2.2
        """
        normalized_key = key.lower().encode(self.encoding)

        items = [
            header_value.decode(self.encoding)
            if header_value is not None
            else header_value
            for _, header_key, header_value in self._list
            if header_key == normalized_key
        ]

        if items == [None]:
            return None

        if items:
            return ", ".join([str(item) for item in items])

        raise KeyError(key)

    def __setitem__(self, key: str, value: Optional[str]) -> None:
        """
        Set the header `key` to `value`, removing any duplicate entries.
        Retains insertion order.
        """
        set_key = key.encode(self._encoding or "utf-8")
        set_value = (
            value.encode(self._encoding or "utf-8") if value is not None else value
        )
        lookup_key = set_key.lower()

        found_indexes = [
            idx
            for idx, (_, item_key, _) in enumerate(self._list)
            if item_key == lookup_key
        ]

        for idx in reversed(found_indexes[1:]):
            del self._list[idx]

        if found_indexes:
            idx = found_indexes[0]
            self._list[idx] = (set_key, lookup_key, set_value)
        else:
            self._list.append((set_key, lookup_key, set_value))

    def __delitem__(self, key: str) -> None:
        """
        Remove the header `key`.
        """
        del_key = key.lower().encode(self.encoding)

        pop_indexes = [
            idx
            for idx, (_, item_key, _) in enumerate(self._list)
            if item_key.lower() == del_key
        ]

        if not pop_indexes:
            raise KeyError(key)

        for idx in reversed(pop_indexes):
            del self._list[idx]

    def __contains__(self, key: Any) -> bool:
        header_key = key.lower().encode(self.encoding)
        return header_key in [key for _, key, _ in self._list]

    def __iter__(self) -> Iterator[Any]:
        return iter(self.keys())

    def __len__(self) -> int:
        return len(self._list)

    def __eq__(self, other: Any) -> bool:
        """Compare by lower-cased key and value, ignoring order.

        ``other`` is first coerced with ``Headers(other)``; anything that cannot
        be coerced compares unequal.
        """
        try:
            other_headers = Headers(other)
        except (AttributeError, TypeError, ValueError):
            return False

        def sort_key(pair):
            # ``None`` cannot be ordered against bytes, so sort on a flag plus
            # a bytes stand-in instead of on the raw value.
            key, value = pair
            return key, value is not None, value or b""

        self_list = [(key, value) for _, key, value in self._list]
        other_list = [(key, value) for _, key, value in other_headers._list]
        return sorted(self_list, key=sort_key) == sorted(other_list, key=sort_key)

    def __repr__(self) -> str:
        class_name = self.__class__.__name__

        encoding_str = ""
        if self.encoding != "ascii":
            encoding_str = f", encoding={self.encoding!r}"

        as_list = list(obfuscate_sensitive_headers(self.multi_items()))
        as_dict = dict(as_list)

        no_duplicate_keys = len(as_dict) == len(as_list)
        if no_duplicate_keys:
            return f"{class_name}({as_dict!r}{encoding_str})"
        return f"{class_name}({as_list!r}{encoding_str})"
@@ -0,0 +1,435 @@
import warnings
from dataclasses import dataclass
from enum import Enum
from typing import Literal, Optional, TypedDict
from ..const import CurlOpt, CurlSslVersion
from ..utils import CurlCffiWarning
# Every string accepted as an impersonation target: concrete versions, the
# aliases resolved by ``normalize_browser_type``, and deprecated spellings.
BrowserTypeLiteral = Literal[
    # Edge
    "edge99",
    "edge101",
    # Chrome
    "chrome99",
    "chrome100",
    "chrome101",
    "chrome104",
    "chrome107",
    "chrome110",
    "chrome116",
    "chrome119",
    "chrome120",
    "chrome123",
    "chrome124",
    "chrome131",
    "chrome133a",
    "chrome136",
    "chrome99_android",
    "chrome131_android",
    # Safari
    "safari153",
    "safari155",
    "safari170",
    "safari172_ios",
    "safari180",
    "safari180_ios",
    "safari184",
    "safari184_ios",
    "safari260",
    "safari260_ios",
    # Firefox
    "firefox133",
    "firefox135",
    "tor145",
    # alias
    "chrome",
    "edge",
    "safari",
    "safari_ios",
    "safari_beta",
    "safari_ios_beta",
    "chrome_android",
    "firefox",
    # ``tor`` is resolved by ``normalize_browser_type`` like the other aliases,
    # so it is accepted by the type as well.
    "tor",
    # deprecated aliases
    "safari15_3",
    "safari15_5",
    "safari17_0",
    "safari17_2_ios",
    "safari18_0",
    "safari18_0_ios",
    "safari18_4",
    "safari18_4_ios",
    # Canonical names
    # "edge_99",
    # "edge_101",
    # "safari_15.3_macos",
    # "safari_15.5_macos",
    # "safari_17.2_ios",
    # "safari_17.0_macos",
    # "safari_18.0_ios",
    # "safari_18.0_macos",
]
# Concrete target that each browser alias currently stands for.
DEFAULT_CHROME = "chrome136"
DEFAULT_EDGE = "edge101"
DEFAULT_SAFARI = "safari184"
DEFAULT_SAFARI_IOS = "safari184_ios"
DEFAULT_SAFARI_BETA = "safari260"
DEFAULT_SAFARI_IOS_BETA = "safari260_ios"
DEFAULT_CHROME_ANDROID = "chrome131_android"
DEFAULT_FIREFOX = "firefox135"
DEFAULT_TOR = "tor145"

# Alias -> concrete target, built from the constants above.
REAL_TARGET_MAP = {
    "chrome": DEFAULT_CHROME,
    "edge": DEFAULT_EDGE,
    "safari": DEFAULT_SAFARI,
    "safari_ios": DEFAULT_SAFARI_IOS,
    "safari_beta": DEFAULT_SAFARI_BETA,
    "safari_ios_beta": DEFAULT_SAFARI_IOS_BETA,
    "chrome_android": DEFAULT_CHROME_ANDROID,
    "firefox": DEFAULT_FIREFOX,
    "tor": DEFAULT_TOR,
}


def normalize_browser_type(item):
    """Resolve a browser alias to its concrete target.

    Args:
        item: an alias such as ``"chrome"``, or any other value.

    Returns:
        The concrete target for a known alias; any other value is returned
        unchanged.
    """
    try:
        return REAL_TARGET_MAP.get(item, item)
    except TypeError:
        # An unhashable value cannot be an alias; hand it back untouched.
        return item
class BrowserType(str, Enum):  # TODO: remove in version 1.x
    """Impersonation targets as a string-valued enum.

    Each member's value equals its name. Aliases such as ``chrome`` are not
    members; only concrete versions and some deprecated spellings are listed.
    """

    edge99 = "edge99"
    edge101 = "edge101"
    chrome99 = "chrome99"
    chrome100 = "chrome100"
    chrome101 = "chrome101"
    chrome104 = "chrome104"
    chrome107 = "chrome107"
    chrome110 = "chrome110"
    chrome116 = "chrome116"
    chrome119 = "chrome119"
    chrome120 = "chrome120"
    chrome123 = "chrome123"
    chrome124 = "chrome124"
    chrome131 = "chrome131"
    chrome133a = "chrome133a"
    chrome136 = "chrome136"
    chrome99_android = "chrome99_android"
    chrome131_android = "chrome131_android"
    safari153 = "safari153"
    safari155 = "safari155"
    safari170 = "safari170"
    safari172_ios = "safari172_ios"
    safari180 = "safari180"
    safari180_ios = "safari180_ios"
    safari184 = "safari184"
    safari184_ios = "safari184_ios"
    safari260 = "safari260"
    safari260_ios = "safari260_ios"
    firefox133 = "firefox133"
    firefox135 = "firefox135"
    tor145 = "tor145"

    # deprecated aliases
    safari15_3 = "safari15_3"
    safari15_5 = "safari15_5"
    safari17_0 = "safari17_0"
    safari17_2_ios = "safari17_2_ios"
    safari18_0 = "safari18_0"
    safari18_0_ios = "safari18_0_ios"
@dataclass
class ExtraFingerprints:
    """Extra TLS / HTTP/2 fingerprint settings.

    Instances are consumed by ``set_extra_fp``, which forwards each field to a
    curl option as noted on the fields below.
    """

    # Minimum TLS version; OR-ed with CurlSslVersion.MAX_DEFAULT for SSLVERSION.
    tls_min_version: int = CurlSslVersion.TLSv1_2
    # Forwarded as an int to TLS_GREASE.
    tls_grease: bool = False
    # Forwarded as an int to SSL_PERMUTE_EXTENSIONS.
    tls_permute_extensions: bool = False
    # Forwarded to SSL_CERT_COMPRESSION.
    tls_cert_compression: Literal["zlib", "brotli"] = "brotli"
    # Joined with "," and forwarded to SSL_SIG_HASH_ALGS; skipped when empty/None.
    tls_signature_algorithms: Optional[list[str]] = None
    # Forwarded to TLS_DELEGATED_CREDENTIALS; skipped when empty.
    tls_delegated_credential: str = ""
    # Forwarded to TLS_RECORD_SIZE_LIMIT; skipped when 0.
    tls_record_size_limit: int = 0
    # Forwarded to STREAM_WEIGHT.
    http2_stream_weight: int = 256
    # Forwarded to STREAM_EXCLUSIVE.
    http2_stream_exclusive: int = 1
    # Forwarded to HTTP2_NO_PRIORITY; only set when truthy.
    http2_no_priority: bool = False
class ExtraFpDict(TypedDict, total=False):
    """Dict form of :class:`ExtraFingerprints`.

    Declares the same keys and value types as the dataclass fields; every key is
    optional (``total=False``).
    """

    tls_min_version: int
    tls_grease: bool
    tls_permute_extensions: bool
    tls_cert_compression: Literal["zlib", "brotli"]
    tls_signature_algorithms: Optional[list[str]]
    tls_delegated_credential: str
    tls_record_size_limit: int
    http2_stream_weight: int
    http2_stream_exclusive: int
    http2_no_priority: bool
# TLS versions are encoded as 0xAABB, where AA is the major version and BB is the
# minor version. As of today, the major version is always 03.
# Maps the numeric wire value to the corresponding CurlSslVersion member.
TLS_VERSION_MAP = {
    0x0301: CurlSslVersion.TLSv1_0,  # 769
    0x0302: CurlSslVersion.TLSv1_1,  # 770
    0x0303: CurlSslVersion.TLSv1_2,  # 771
    0x0304: CurlSslVersion.TLSv1_3,  # 772
}
# A list of the possible cipher suite ids. Taken from
# http://www.iana.org/assignments/tls-parameters/tls-parameters.xml
# via BoringSSL
# Maps the numeric cipher suite id to its name.
TLS_CIPHER_NAME_MAP = {
    0x000A: "TLS_RSA_WITH_3DES_EDE_CBC_SHA",
    0x002F: "TLS_RSA_WITH_AES_128_CBC_SHA",
    0x0033: "TLS_DHE_RSA_WITH_AES_128_CBC_SHA",
    0x0035: "TLS_RSA_WITH_AES_256_CBC_SHA",
    0x0039: "TLS_DHE_RSA_WITH_AES_256_CBC_SHA",
    0x003C: "TLS_RSA_WITH_AES_128_CBC_SHA256",
    0x003D: "TLS_RSA_WITH_AES_256_CBC_SHA256",
    0x0067: "TLS_DHE_RSA_WITH_AES_128_CBC_SHA256",
    0x006B: "TLS_DHE_RSA_WITH_AES_256_CBC_SHA256",
    0x008C: "TLS_PSK_WITH_AES_128_CBC_SHA",
    0x008D: "TLS_PSK_WITH_AES_256_CBC_SHA",
    0x009C: "TLS_RSA_WITH_AES_128_GCM_SHA256",
    0x009D: "TLS_RSA_WITH_AES_256_GCM_SHA384",
    0x009E: "TLS_DHE_RSA_WITH_AES_128_GCM_SHA256",
    0x009F: "TLS_DHE_RSA_WITH_AES_256_GCM_SHA384",
    0x1301: "TLS_AES_128_GCM_SHA256",
    0x1302: "TLS_AES_256_GCM_SHA384",
    0x1303: "TLS_CHACHA20_POLY1305_SHA256",
    0xC008: "TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA",
    0xC009: "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA",
    0xC00A: "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA",
    0xC012: "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA",
    0xC013: "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA",
    0xC014: "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA",
    0xC023: "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256",
    0xC024: "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384",
    0xC027: "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256",
    0xC028: "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384",
    0xC02B: "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
    0xC02C: "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
    0xC02F: "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
    0xC030: "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
    0xC035: "TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA",
    0xC036: "TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA",
    0xCCA8: "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256",
    0xCCA9: "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256",
    0xCCAC: "TLS_ECDHE_PSK_WITH_CHACHA20_POLY1305_SHA256",
}
# RFC tls extensions: https://datatracker.ietf.org/doc/html/rfc6066
# IANA list: https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml
# Maps the numeric TLS extension id to its name. Commented-out entries record
# reserved, unassigned or deprecated id ranges for reference only.
TLS_EXTENSION_NAME_MAP = {
    0: "server_name",
    1: "max_fragment_length",
    2: "client_certificate_url",
    3: "trusted_ca_keys",
    4: "truncated_hmac",
    5: "status_request",
    6: "user_mapping",
    7: "client_authz",
    8: "server_authz",
    9: "cert_type",
    10: "supported_groups",  # (renamed from "elliptic_curves")
    11: "ec_point_formats",
    12: "srp",
    13: "signature_algorithms",
    14: "use_srtp",
    15: "heartbeat",
    16: "application_layer_protocol_negotiation",
    17: "status_request_v2",
    18: "signed_certificate_timestamp",
    19: "client_certificate_type",
    20: "server_certificate_type",
    21: "padding",
    22: "encrypt_then_mac",
    23: "extended_master_secret",
    24: "token_binding",
    25: "cached_info",
    26: "tls_lts",
    27: "compress_certificate",
    28: "record_size_limit",
    29: "pwd_protect",
    30: "pwd_clear",
    31: "password_salt",
    32: "ticket_pinning",
    33: "tls_cert_with_extern_psk",
    34: "delegated_credential",
    35: "session_ticket",  # (renamed from "SessionTicket TLS")
    36: "TLMSP",
    37: "TLMSP_proxying",
    38: "TLMSP_delegate",
    39: "supported_ekt_ciphers",
    # 40:"Reserved",
    41: "pre_shared_key",
    42: "early_data",
    43: "supported_versions",
    44: "cookie",
    45: "psk_key_exchange_modes",
    # 46:"Reserved",
    47: "certificate_authorities",
    48: "oid_filters",
    49: "post_handshake_auth",
    50: "signature_algorithms_cert",
    51: "key_share",
    52: "transparency_info",
    # 53:"connection_id", # (deprecated)
    54: "connection_id",
    55: "external_id_hash",
    56: "external_session_id",
    57: "quic_transport_parameters",
    58: "ticket_request",
    59: "dnssec_chain",
    60: "sequence_number_encryption_algorithms",
    61: "rrc",
    17513: "application_settings",  # BoringSSL private usage
    17613: "application_settings new",  # BoringSSL private usage
    # 62-2569:"Unassigned
    # 2570:"Reserved
    # 2571-6681:"Unassigned
    # 6682:"Reserved
    # 6683-10793:"Unassigned
    # 10794:"Reserved
    # 10795-14905:"Unassigned
    # 14906:"Reserved
    # 14907-19017:"Unassigned
    # 19018:"Reserved
    # 19019-23129:"Unassigned
    # 23130:"Reserved
    # 23131-27241:"Unassigned
    # 27242:"Reserved
    # 27243-31353:"Unassigned
    # 31354:"Reserved
    # 31355-35465:"Unassigned
    # 35466:"Reserved
    # 35467-39577:"Unassigned
    # 39578:"Reserved
    # 39579-43689:"Unassigned
    # 43690:"Reserved
    # 43691-47801:"Unassigned
    # 47802:"Reserved
    # 47803-51913:"Unassigned
    # 51914:"Reserved
    # 51915-56025:"Unassigned
    # 56026:"Reserved
    # 56027-60137:"Unassigned
    # 60138:"Reserved
    # 60139-64249:"Unassigned
    # 64250:"Reserved
    # 64251-64767:"Unassigned
    64768: "ech_outer_extensions",
    # 64769-65036:"Unassigned
    65037: "encrypted_client_hello",
    # 65038-65279:"Unassigned
    # 65280:"Reserved for Private Use
    65281: "renegotiation_info",
    # 65282-65535:"Reserved for Private Use
}
# Maps the numeric TLS curve/group id to its name.
TLS_EC_CURVES_MAP = {
    19: "P-192",
    21: "P-224",
    23: "P-256",
    24: "P-384",
    25: "P-521",
    29: "X25519",
    256: "ffdhe2048",
    257: "ffdhe3072",
    4588: "X25519MLKEM768",
    25497: "X25519Kyber768Draft00",
}
def toggle_extension(curl, extension_id: int, enable: bool):
    """Turn a single TLS extension on or off for ``curl``.

    Args:
        curl: handle whose ``setopt`` method receives the option changes.
        extension_id: numeric TLS extension id.
        enable: ``True`` to enable the extension, ``False`` to disable it.

    Raises:
        NotImplementedError: for ``server_name`` (0) and for any id that has no
            handling here.
    """
    # ECH
    if extension_id == 65037:
        curl.setopt(CurlOpt.ECH, "grease" if enable else "")
        return

    # compress certificate
    if extension_id == 27:
        if not enable:
            curl.setopt(CurlOpt.SSL_CERT_COMPRESSION, "")
            return
        # The extension id alone does not say which algorithm to use.
        warnings.warn(
            "Cert compression setting to brotli, "
            "you had better specify which to use: zlib/brotli",
            CurlCffiWarning,
            stacklevel=1,
        )
        curl.setopt(CurlOpt.SSL_CERT_COMPRESSION, "brotli")
        return

    # ALPS: application settings
    if extension_id == 17513:
        curl.setopt(CurlOpt.SSL_ENABLE_ALPS, 1 if enable else 0)
        return

    # ALPS with the new code point
    if extension_id == 17613:
        switch = 1 if enable else 0
        curl.setopt(CurlOpt.SSL_ENABLE_ALPS, switch)
        curl.setopt(CurlOpt.TLS_USE_NEW_ALPS_CODEPOINT, switch)
        return

    # server_name
    if extension_id == 0:
        raise NotImplementedError(
            "It's unlikely that the server_name(0) extension being changed."
        )

    # ALPN
    if extension_id == 16:
        curl.setopt(CurlOpt.SSL_ENABLE_ALPN, 1 if enable else 0)
        return

    # status_request: only ever switched on, disabling is a no-op
    if extension_id == 5:
        if enable:
            curl.setopt(CurlOpt.TLS_STATUS_REQUEST, 1)
        return

    # signed_certificate_timestamps: only ever switched on, disabling is a no-op
    if extension_id == 18:
        if enable:
            curl.setopt(CurlOpt.TLS_SIGNED_CERT_TIMESTAMPS, 1)
        return

    # session_ticket
    if extension_id == 35:
        curl.setopt(CurlOpt.SSL_ENABLE_TICKET, 1 if enable else 0)
        return

    # padding (21) should be ignored; 34 and 28 are firefox extensions that are
    # toggled by extra_fp instead.
    if extension_id == 21 or extension_id in [34, 28]:
        return

    raise NotImplementedError(
        f"This extension({extension_id}) can not be toggled for now, it may be "
        "updated later."
    )
@@ -0,0 +1,314 @@
from contextlib import suppress
import queue
import re
import warnings
from concurrent.futures import Future
from typing import Any, Callable, Optional, Union
from collections.abc import Awaitable
from ..curl import Curl
from ..utils import CurlCffiWarning
from .cookies import Cookies
from .exceptions import HTTPError, RequestException
from .headers import Headers
# Use orjson if present
try:
from orjson import loads
except ImportError:
from json import loads
with suppress(ImportError):
from markdownify import markdownify as md
import readability as rd
# Captures the charset token (word characters and "-") following "charset=".
CHARSET_RE = re.compile(r"charset=([\w-]+)")
# Unique sentinel object; streaming readers stop when they dequeue it.
STREAM_END = object()
def clear_queue(q: queue.Queue):
    """Drop every pending item from ``q`` and reset its task accounting.

    All three steps run while holding the queue's mutex: the underlying
    container is emptied, the unfinished-task counter is zeroed, and waiters on
    ``all_tasks_done`` are notified.
    """
    with q.mutex:
        q.queue.clear()
        q.unfinished_tasks = 0
        q.all_tasks_done.notify_all()
class Request:
    """Representing a sent request.

    Attributes:
        url: url of the request.
        headers: headers of the request.
        method: http method of the request.
    """

    def __init__(self, url: str, headers: Headers, method: str):
        # Plain value holder: the arguments are stored as given.
        self.url = url
        self.headers = headers
        self.method = method
class Response:
    """Contains information the server sends.

    Attributes:
        url: url used in the request.
        content: response body in bytes.
        text: response body in str.
        status_code: http status code.
        reason: http response reason, such as OK, Not Found.
        ok: is status_code in [200, 400)?
        headers: response headers.
        cookies: response cookies.
        elapsed: how many seconds the request cost.
        encoding: http body encoding.
        charset: alias for encoding.
        primary_ip: primary ip of the server.
        primary_port: primary port of the server.
        local_ip: local ip used in this connection.
        local_port: local port used in this connection.
        charset_encoding: encoding specified by the Content-Type header.
        default_encoding: encoding for decoding response content if charset is not found
            in headers. Defaults to "utf-8". Can be set to a callable for automatic
            detection.
        redirect_count: how many redirects happened.
        redirect_url: the final redirected url.
        http_version: http version used.
        history: history redirections, only headers are available.
    """

    def __init__(
        self,
        curl: Optional["Curl"] = None,
        request: Optional["Request"] = None,
    ):
        """Create a response populated with default values.

        Args:
            curl: curl handle associated with this response, if any.
            request: the request this response belongs to, if any.
        """
        self.curl = curl
        self.request = request
        self.url = ""
        self.content = b""
        self.status_code = 200
        self.reason = "OK"
        self.ok = True
        self.headers = Headers()
        self.cookies = Cookies()
        self.elapsed = 0.0
        self.default_encoding: Union[str, Callable[[bytes], str]] = "utf-8"
        self.redirect_count = 0
        self.redirect_url = ""
        self.http_version = 0
        self.primary_ip: str = ""
        self.primary_port: int = 0
        self.local_ip: str = ""
        self.local_port: int = 0
        self.history: list[dict[str, Any]] = []
        self.infos: dict[str, Any] = {}
        # Streaming state, left as None unless stream mode is in use.
        self.queue: Optional[queue.Queue] = None
        self.stream_task: Optional[Future] = None
        self.astream_task: Optional[Awaitable] = None
        self.quit_now = None

    @property
    def charset(self) -> str:
        """Alias for encoding."""
        return self.encoding

    @property
    def encoding(self) -> str:
        """
        Determines the encoding to decode byte content into text.

        The method follows a specific priority to decide the encoding:
        1. If ``.encoding`` has been explicitly set, it is used.
        2. The encoding specified by the ``charset`` parameter in the ``Content-Type``
            header.
        3. The encoding specified by the ``default_encoding`` attribute. This can either
            be a string (e.g., "utf-8") or a callable for charset autodetection.

        The result is cached on first access; "utf-8" is used when none of the
        above yields a value.
        """
        if not hasattr(self, "_encoding"):
            encoding = self.charset_encoding
            if encoding is None:
                if isinstance(self.default_encoding, str):
                    encoding = self.default_encoding
                elif callable(self.default_encoding):
                    encoding = self.default_encoding(self.content)
            self._encoding = encoding or "utf-8"
        return self._encoding

    @encoding.setter
    def encoding(self, value: str) -> None:
        """Override the encoding; rejected once ``text`` has been computed."""
        if hasattr(self, "_text"):
            raise ValueError("Cannot set encoding after text has been accessed")
        self._encoding = value

    @property
    def charset_encoding(self) -> Optional[str]:
        """Return the encoding, as specified by the Content-Type header."""
        content_type = self.headers.get("Content-Type")
        if content_type:
            charset_match = CHARSET_RE.search(content_type)
            return charset_match.group(1) if charset_match else None
        return None

    @property
    def text(self) -> str:
        """Response body decoded to ``str``; computed once and then cached."""
        if not hasattr(self, "_text"):
            if not self.content:
                self._text = ""
            else:
                self._text = self._decode(self.content)
        return self._text

    def markdown(self) -> str:
        """Return the main content of the response body converted to markdown.

        Relies on the optional ``readability`` and ``markdownify`` packages, whose
        imports are attempted at module load and skipped when unavailable.
        """
        doc = rd.Document(self.content)
        title = doc.title()
        summary = doc.summary(html_partial=True)
        body_as_md = md(f"<h1>{title}</h1><main>{summary}</main>")
        return body_as_md

    def _decode(self, content: bytes) -> str:
        """Decode ``content`` with ``self.encoding``, never raising on bad bytes.

        Undecodable bytes are replaced. If the encoding name is unknown, the
        content is decoded as UTF-8 (dropping a leading BOM) instead.
        """
        try:
            return content.decode(self.encoding, errors="replace")
        except (UnicodeDecodeError, LookupError):
            # The fallback must be lenient as well; a strict decode here would
            # raise for any body that is not valid UTF-8.
            return content.decode("utf-8-sig", errors="replace")

    def raise_for_status(self):
        """Raise an error if status code is not in [200, 400)"""
        if not self.ok:
            raise HTTPError(f"HTTP Error {self.status_code}: {self.reason}", 0, self)

    def iter_lines(self, chunk_size=None, decode_unicode=False, delimiter=None):
        """
        iterate streaming content line by line, separated by ``\\n``.

        Copied from: https://requests.readthedocs.io/en/latest/_modules/requests/models/
        which is under the License: Apache 2.0
        """
        pending = None

        for chunk in self.iter_content(
            chunk_size=chunk_size, decode_unicode=decode_unicode
        ):
            if pending is not None:
                chunk = pending + chunk
            lines = chunk.split(delimiter) if delimiter else chunk.splitlines()
            # Hold back a trailing partial line so it can be joined with the
            # next chunk.
            pending = (
                lines.pop()
                if lines and lines[-1] and chunk and lines[-1][-1] == chunk[-1]
                else None
            )

            yield from lines

        if pending is not None:
            yield pending

    def iter_content(self, chunk_size=None, decode_unicode=False):
        """
        iterate streaming content chunk by chunk in bytes.

        ``chunk_size`` is ignored (a warning is emitted) and ``decode_unicode``
        is not implemented. Requires stream mode, i.e. ``queue`` and ``curl`` set.
        """
        if chunk_size:
            warnings.warn(
                "chunk_size is ignored, there is no way to tell curl that.",
                CurlCffiWarning,
                stacklevel=2,
            )
        if decode_unicode:
            raise NotImplementedError()

        assert self.queue and self.curl, "stream mode is not enabled."

        while True:
            chunk = self.queue.get()

            # re-raise the exception if something wrong happened.
            if isinstance(chunk, RequestException):
                self.curl.reset()
                raise chunk

            # end of stream.
            if chunk is STREAM_END:
                break

            yield chunk

    def json(self, **kw):
        """return a parsed json object of the content."""
        return loads(self.content, **kw)

    def close(self):
        """Close the streaming connection, only valid in stream mode."""
        if self.quit_now:
            self.quit_now.set()
        if self.stream_task:
            self.stream_task.result()

    async def aiter_lines(self, chunk_size=None, decode_unicode=False, delimiter=None):
        """
        iterate streaming content line by line, separated by ``\\n``.

        Copied from: https://requests.readthedocs.io/en/latest/_modules/requests/models/
        which is under the License: Apache 2.0
        """
        pending = None

        async for chunk in self.aiter_content(
            chunk_size=chunk_size, decode_unicode=decode_unicode
        ):
            if pending is not None:
                chunk = pending + chunk
            lines = chunk.split(delimiter) if delimiter else chunk.splitlines()
            # Hold back a trailing partial line so it can be joined with the
            # next chunk.
            pending = (
                lines.pop()
                if lines and lines[-1] and chunk and lines[-1][-1] == chunk[-1]
                else None
            )

            for line in lines:
                yield line

        if pending is not None:
            yield pending

    async def aiter_content(self, chunk_size=None, decode_unicode=False):
        """
        iterate streaming content chunk by chunk in bytes.

        ``chunk_size`` is ignored (a warning is emitted) and ``decode_unicode``
        is not implemented. Requires stream mode, i.e. ``queue`` and ``curl`` set.
        """
        if chunk_size:
            warnings.warn(
                "chunk_size is ignored, there is no way to tell curl that.",
                CurlCffiWarning,
                stacklevel=2,
            )
        if decode_unicode:
            raise NotImplementedError()

        assert self.queue and self.curl, "stream mode is not enabled."

        while True:
            chunk = await self.queue.get()

            # re-raise the exception if something wrong happened.
            if isinstance(chunk, RequestException):
                await self.aclose()
                raise chunk

            # end of stream.
            if chunk is STREAM_END:
                await self.aclose()
                return

            yield chunk

    async def atext(self) -> str:
        """
        Return a decoded string.
        """
        return self._decode(await self.acontent())

    async def acontent(self) -> bytes:
        """wait and read the streaming content in one bytes object."""
        chunks = []
        async for chunk in self.aiter_content():
            chunks.append(chunk)
        return b"".join(chunks)

    async def aclose(self):
        """Close the streaming connection, only valid in stream mode."""
        if self.astream_task:
            await self.astream_task

    # It prints the status code of the response instead of the object's memory location.
    def __repr__(self) -> str:
        return f"<Response [{self.status_code}]>"
@@ -0,0 +1,698 @@
from __future__ import annotations
__all__ = ["HttpVersionLiteral", "set_curl_options", "not_set"]
import asyncio
import math
import queue
import warnings
from collections import Counter
from io import BytesIO
from json import dumps
from typing import TYPE_CHECKING, Any, Callable, Final, Literal, Optional, Union, cast
from urllib.parse import ParseResult, parse_qsl, quote, urlencode, urljoin, urlparse
from ..const import CurlHttpVersion, CurlOpt, CurlSslVersion
from ..curl import CURL_WRITEFUNC_ERROR, CurlMime
from ..utils import CurlCffiWarning
from .cookies import Cookies
from .exceptions import ImpersonateError, InvalidURL
from .headers import Headers
from .impersonate import (
TLS_CIPHER_NAME_MAP,
TLS_EC_CURVES_MAP,
TLS_VERSION_MAP,
ExtraFingerprints,
normalize_browser_type,
toggle_extension,
)
from .models import Request
if TYPE_CHECKING:
from ..curl import Curl
from .cookies import CookieTypes
from .headers import HeaderTypes
from .impersonate import BrowserTypeLiteral, ExtraFpDict
from .session import ProxySpec
# Literal type listing the HTTP method names used for ``method`` parameters.
HttpMethod = Literal[
    "GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "TRACE", "PATCH", "QUERY"
]
# Short string aliases that normalize_http_version translates to CurlHttpVersion.
HttpVersionLiteral = Literal["v1", "v2", "v2tls", "v2_prior_knowledge", "v3", "v3only"]

# Characters quote_path_and_params leaves unescaped, minus those in its quote_str.
SAFE_CHARS = set("!#$%&'()*+,/:;=?@[]~")

# Sentinel default (used for ``timeout``), distinct from an explicit ``None``.
not_set: Final[Any] = object()
# ruff: noqa: SIM116
def normalize_http_version(
    version: Union[CurlHttpVersion, HttpVersionLiteral],
) -> CurlHttpVersion:
    """Translate a short version alias into a ``CurlHttpVersion`` member.

    Recognized aliases are "v1", "v2", "v2tls", "v2_prior_knowledge", "v3" and
    "v3only". Any other value is returned unchanged.
    """
    if version == "v1":
        return CurlHttpVersion.V1_1
    if version == "v3":
        return CurlHttpVersion.V3
    if version == "v3only":
        return CurlHttpVersion.V3ONLY
    if version == "v2":
        return CurlHttpVersion.V2_0
    if version == "v2tls":
        return CurlHttpVersion.V2TLS
    if version == "v2_prior_knowledge":
        return CurlHttpVersion.V2_PRIOR_KNOWLEDGE

    # Not an alias: hand the value back as is.
    return version  # type: ignore
def is_absolute_url(url: str) -> bool:
    """Tell whether ``url`` carries both a scheme and a hostname."""
    parts = urlparse(url)
    if not parts.scheme:
        return False
    return bool(parts.hostname)
def quote_path_and_params(url: str, quote_str: str = ""):
    """Percent-encode the path and query string of ``url``.

    Characters in ``SAFE_CHARS`` are left as they are, except for those listed
    in ``quote_str``, which are encoded too. Scheme, netloc, params and fragment
    are carried over untouched.
    """
    keep = "".join(SAFE_CHARS - set(quote_str))
    parts = urlparse(url)
    query_pairs = parse_qsl(parts.query, keep_blank_values=True)
    requoted = ParseResult(
        scheme=parts.scheme,
        netloc=parts.netloc,
        path=quote(parts.path, safe=keep),
        params=parts.params,
        query=urlencode(query_pairs, doseq=True, safe=keep),
        fragment=parts.fragment,
    )
    return requoted.geturl()
def update_url_params(url: str, params: Union[dict, list, tuple]) -> str:
    """Merge ``params`` into the query string of ``url``.

    A key that occurs exactly once in the existing query and exactly once in
    ``params`` is updated in place; every other pair is appended. ``bool`` and
    ``dict`` values are serialized with ``json.dumps`` first.

    Args:
        url: string of target URL
        params: dict, or sequence of ``(key, value)`` pairs, to be added

    Returns:
        string with updated URL

    >>> url = 'http://stackoverflow.com/test?answers=true'
    >>> new_params = {'answers': False, 'data': ['some','values']}
    >>> update_url_params(url, new_params)
    'http://stackoverflow.com/test?answers=false&data=some&data=values'
    """
    # No need to unquote, since requote_uri will be called later.
    parts = urlparse(url)

    # parse_qsl yields a list of pairs, so repeated keys are preserved.
    query_pairs = parse_qsl(parts.query, keep_blank_values=True)
    existing_counts = Counter(pair[0] for pair in query_pairs)

    if isinstance(params, dict):
        params = list(params.items())
    incoming_counts = Counter(pair[0] for pair in params)

    for key, value in params:
        # Bool and Dict values should be converted to json-friendly values
        if isinstance(value, (bool, dict)):
            value = dumps(value)

        one_to_one = existing_counts.get(key) == 1 and incoming_counts.get(key) == 1
        if not one_to_one:
            query_pairs.append((key, value))
            continue

        # 1 to 1 mapping: overwrite the existing pair where it stands.
        query_pairs = [
            ((key, value) if pair[0] == key else pair) for pair in query_pairs
        ]

    # Rebuild the URL with the new query; every other component is unchanged.
    return ParseResult(
        scheme=parts.scheme,
        netloc=parts.netloc,
        path=parts.path,
        params=parts.params,
        query=urlencode(query_pairs, doseq=True),
        fragment=parts.fragment,
    ).geturl()
# Adapted from: https://github.com/psf/requests/blob/1ae6fc3137a11e11565ed22436aa1e77277ac98c/src%2Frequests%2Futils.py#L633-L682
# License: Apache 2.0

# The unreserved URI characters (RFC 3986)
UNRESERVED_SET = frozenset(
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"
)


def unquote_unreserved(uri: str) -> str:
    """Decode only those percent-escapes that stand for unreserved characters.

    Escapes of reserved, illegal and non-ASCII bytes stay encoded.

    Raises:
        InvalidURL: if an alphanumeric two-character escape is not valid hex.
    """
    head, *escaped = uri.split("%")
    pieces = [head]
    for segment in escaped:
        hex_pair = segment[:2]
        if len(hex_pair) == 2 and hex_pair.isalnum():
            try:
                char = chr(int(hex_pair, 16))
            except ValueError as e:
                raise InvalidURL(
                    f"Invalid percent-escape sequence: '{hex_pair}'"
                ) from e
            if char in UNRESERVED_SET:
                pieces.append(char + segment[2:])
                continue
        # Anything else keeps the "%" that split() removed.
        pieces.append(f"%{segment}")
    return "".join(pieces)
def requote_uri(uri: str) -> str:
    """Re-quote the given URI.

    The URI goes through an unquote/quote cycle so that it ends up fully and
    consistently quoted.
    """
    reserved = "!#$&'()*+,/:;=?@[]~|"
    try:
        # Unquote only the unreserved characters.
        unquoted = unquote_unreserved(uri)
    except InvalidURL:
        # The URI could not be unquoted, so it may hold stray '%' characters.
        # Quote it with '%' treated as unsafe so they get escaped properly.
        return quote(uri, safe=reserved)
    # Quote only illegal characters (reserved, unreserved and '%' are kept).
    return quote(unquoted, safe=reserved + "%")
# TODO: should we move this function to headers.py?
def update_header_line(
    header_lines: list[str], key: str, value: str, replace: bool = False
):
    """Ensure ``header_lines`` contains a line for header ``key``.

    The list is modified in place. Matching is case-insensitive and looks for
    the first line starting with ``key`` followed by ``:``. When no line
    matches, ``"key: value"`` is appended. When one matches it is overwritten
    only if ``replace`` is true.
    """
    prefix = key.lower() + ":"
    match_idx = next(
        (idx for idx, line in enumerate(header_lines) if line.lower().startswith(prefix)),
        None,
    )
    if match_idx is None:
        header_lines.append(f"{key}: {value}")
    elif replace:
        header_lines[match_idx] = f"{key}: {value}"
def peek_queue(q: queue.Queue, default=None):
    """Return the first item of ``q`` without removing it.

    Reads the queue's underlying ``queue`` container directly, without taking
    the queue's mutex. Returns ``default`` when the container is empty.
    """
    try:
        return q.queue[0]
    except IndexError:
        return default
def peek_aio_queue(q: asyncio.Queue, default=None):
    """Return the first item of an asyncio queue without removing it.

    Reads the private ``_queue`` container of ``q`` directly. Returns
    ``default`` when the container is empty.
    """
    try:
        return q._queue[0]  # type: ignore
    except IndexError:
        return default
def toggle_extensions_by_ids(curl: Curl, extension_ids):
    """Bring the enabled TLS extensions of ``curl`` in line with ``extension_ids``.

    Ids missing from the built-in default set are enabled first, then default
    ids absent from ``extension_ids`` are disabled. ``extension_ids`` must
    support set difference in both directions.
    """
    # TODO: find a better representation, rather than magic numbers
    default_enabled = {0, 10, 11, 13, 16, 23, 35, 43, 45, 51, 65281}

    missing_ids = extension_ids - default_enabled
    for wanted in missing_ids:
        toggle_extension(curl, wanted, enable=True)

    surplus_ids = default_enabled - extension_ids
    for unwanted in surplus_ids:
        toggle_extension(curl, unwanted, enable=False)
def set_ja3_options(curl: Curl, ja3: str, permute: bool = False):
    """
    Configure ``curl`` from a JA3 fingerprint string.

    The string holds five comma-separated fields: TLS version, ciphers,
    extensions, curves and curve formats; multi-valued fields use ``-``.

    Detailed explanation: https://engineering.salesforce.com/tls-fingerprinting-with-ja3-and-ja3s-247362855967/

    Args:
        curl: handle to configure.
        ja3: the JA3 string.
        permute: when true, the extension order from ``ja3`` is not applied.

    Raises:
        ImpersonateError: if a cipher id is unknown.
        KeyError: if the TLS version or a curve id is unknown.
    """
    version_field, cipher_field, extensions, curve_field, format_field = ja3.split(
        ","
    )

    # TLS version
    curl_tls_version = TLS_VERSION_MAP[int(version_field)]
    curl.setopt(CurlOpt.SSLVERSION, curl_tls_version | CurlSslVersion.MAX_DEFAULT)
    assert curl_tls_version == CurlSslVersion.TLSv1_2, "Only TLS v1.2 works for now."

    # ciphers: numeric ids are translated to names and joined with ":"
    cipher_names = []
    for cipher_id in map(int, cipher_field.split("-")):
        name = TLS_CIPHER_NAME_MAP.get(cipher_id)
        if not name:
            raise ImpersonateError(f"Cipher {hex(cipher_id)} is not found")
        cipher_names.append(name)
    curl.setopt(CurlOpt.SSL_CIPHER_LIST, ":".join(cipher_names))

    # extensions: a trailing padding(21) entry is dropped before processing
    if extensions.endswith("-21"):
        extensions = extensions[:-3]
        warnings.warn(
            "Padding(21) extension found in ja3 string, whether to add it should "
            "be managed by the SSL engine. The TLS client hello packet may contain "
            "or not contain this extension, any of which should be correct.",
            CurlCffiWarning,
            stacklevel=1,
        )
    extension_ids = {int(ext) for ext in extensions.split("-")}
    toggle_extensions_by_ids(curl, extension_ids)
    if not permute:
        curl.setopt(CurlOpt.TLS_EXTENSION_ORDER, extensions)

    # curves
    curve_names = [TLS_EC_CURVES_MAP[int(curve)] for curve in curve_field.split("-")]
    curl.setopt(CurlOpt.SSL_EC_CURVES, ":".join(curve_names))

    # curve formats
    assert int(format_field) == 0, "Only curve_formats == 0 is supported."
def set_akamai_options(curl: Curl, akamai: str):
    """
    Configure ``curl`` from an Akamai HTTP/2 fingerprint string.

    The string holds four ``|``-separated fields: settings, window update,
    streams and pseudo-header order.

    Detailed explanation: https://www.blackhat.com/docs/eu-17/materials/eu-17-Shuster-Passive-Fingerprinting-Of-HTTP2-Clients-wp.pdf
    """
    h2_settings, window_size, stream_spec, pseudo_order = akamai.split("|")

    # For compatibility with tls.peet.ws, which separates settings with commas.
    h2_settings = h2_settings.replace(",", ";")

    curl.setopt(CurlOpt.HTTP_VERSION, CurlHttpVersion.V2_0)
    curl.setopt(CurlOpt.HTTP2_SETTINGS, h2_settings)
    curl.setopt(CurlOpt.HTTP2_WINDOW_UPDATE, int(window_size))

    if stream_spec != "0":
        curl.setopt(CurlOpt.HTTP2_STREAMS, stream_spec)

    # m,a,s,p -> masp
    # curl-impersonate only accepts masp format, without commas.
    compact_order = "".join(pseudo_order.split(","))
    curl.setopt(CurlOpt.HTTP2_PSEUDO_HEADERS_ORDER, compact_order)
def set_extra_fp(curl: Curl, fp: ExtraFingerprints):
    """Apply the fields of ``fp`` to ``curl`` as curl options.

    Args:
        curl: handle to configure.
        fp: extra fingerprint settings to apply.
    """
    # Optional: only set when a non-empty list of algorithms is given.
    if fp.tls_signature_algorithms:
        curl.setopt(CurlOpt.SSL_SIG_HASH_ALGS, ",".join(fp.tls_signature_algorithms))

    # Always set, using the value from ``fp``.
    curl.setopt(CurlOpt.SSLVERSION, fp.tls_min_version | CurlSslVersion.MAX_DEFAULT)
    curl.setopt(CurlOpt.TLS_GREASE, int(fp.tls_grease))
    curl.setopt(CurlOpt.SSL_PERMUTE_EXTENSIONS, int(fp.tls_permute_extensions))
    curl.setopt(CurlOpt.SSL_CERT_COMPRESSION, fp.tls_cert_compression)
    curl.setopt(CurlOpt.STREAM_WEIGHT, fp.http2_stream_weight)
    curl.setopt(CurlOpt.STREAM_EXCLUSIVE, fp.http2_stream_exclusive)

    # Optional: only set when the field holds a truthy (non-default) value.
    if fp.tls_delegated_credential:
        curl.setopt(CurlOpt.TLS_DELEGATED_CREDENTIALS, fp.tls_delegated_credential)
    if fp.tls_record_size_limit:
        curl.setopt(CurlOpt.TLS_RECORD_SIZE_LIMIT, fp.tls_record_size_limit)
    if fp.http2_no_priority:
        curl.setopt(CurlOpt.HTTP2_NO_PRIORITY, fp.http2_no_priority)
def set_curl_options(
    curl: Curl,
    method: HttpMethod,
    url: str,
    *,
    params_list: list[Union[dict, list, tuple, None]] = [],  # noqa: B006
    base_url: Optional[str] = None,
    data: Optional[Union[dict[str, str], list[tuple], str, BytesIO, bytes]] = None,
    json: Optional[dict | list] = None,
    headers_list: list[Optional[HeaderTypes]] = [],  # noqa: B006
    cookies_list: list[Optional[CookieTypes]] = [],  # noqa: B006
    files: Optional[dict] = None,
    auth: Optional[tuple[str, str]] = None,
    timeout: Optional[Union[float, tuple[float, float], object]] = not_set,
    allow_redirects: Optional[bool] = True,
    max_redirects: Optional[int] = 30,
    proxies_list: list[Optional[ProxySpec]] = [],  # noqa: B006
    proxy: Optional[str] = None,
    proxy_auth: Optional[tuple[str, str]] = None,
    verify_list: list[Union[bool, str, None]] = [],  # noqa: B006
    referer: Optional[str] = None,
    accept_encoding: Optional[str] = "gzip, deflate, br, zstd",
    content_callback: Optional[Callable] = None,
    impersonate: Optional[Union[BrowserTypeLiteral, str]] = None,
    ja3: Optional[str] = None,
    akamai: Optional[str] = None,
    extra_fp: Optional[Union[ExtraFingerprints, ExtraFpDict]] = None,
    default_headers: bool = True,
    quote: Union[str, Literal[False]] = "",
    http_version: Optional[Union[CurlHttpVersion, HttpVersionLiteral]] = None,
    interface: Optional[str] = None,
    cert: Optional[Union[str, tuple[str, str]]] = None,
    stream: Optional[bool] = None,
    max_recv_speed: int = 0,
    multipart: Optional[CurlMime] = None,
    queue_class: Any = None,
    event_class: Any = None,
    curl_options: Optional[dict[CurlOpt, str]] = None,
):
    """Translate request arguments into options on ``curl``.

    Each ``*_list`` argument is unpacked into exactly two values, a base value
    followed by a per-call value, so callers must pass two-element lists. The
    empty-list defaults cannot be unpacked and are never mutated.

    Sections are applied in the order they appear below. ``curl_options`` is
    applied last so that it can override anything set earlier.

    Returns:
        A tuple ``(req, buffer, header_buffer, q, header_recved, quit_now)``:
        the ``Request`` built from the final url/headers/method, the body buffer
        (``None`` when streaming or when ``content_callback`` is used), the
        header buffer, and the stream queue and two events (all ``None`` unless
        ``stream`` is truthy).

    Raises:
        TypeError: if ``data`` has an unsupported type, or both ``proxy`` and
            per-call proxies are given.
        NotImplementedError: if ``files`` is given.
        ImpersonateError: if the impersonation target is rejected by ``curl``.
    """
    c = curl

    method = method.upper()  # type: ignore

    # method
    if method == "POST":
        c.setopt(CurlOpt.POST, 1)
    elif method != "GET":
        c.setopt(CurlOpt.CUSTOMREQUEST, method.encode())
    if method == "HEAD":
        c.setopt(CurlOpt.NOBODY, 1)

    # url
    base_params, params = params_list
    if base_params:
        url = update_url_params(url, base_params)
    if params:
        url = update_url_params(url, params)
    if base_url:
        url = urljoin(base_url, url)
    # A non-empty ``quote`` string names extra characters to escape;
    # ``quote=False`` skips requoting entirely.
    if quote:
        url = quote_path_and_params(url, quote_str=quote)
    if quote is not False:
        url = requote_uri(url)
    c.setopt(CurlOpt.URL, url.encode())

    # data/body/json
    if isinstance(data, (dict, list, tuple)):
        body = urlencode(data).encode()
    elif isinstance(data, str):
        body = data.encode()
    elif isinstance(data, BytesIO):
        body = data.read()
    elif isinstance(data, bytes):
        body = data
    elif data is None:
        body = b""
    else:
        raise TypeError("data must be dict/list/tuple, str, BytesIO or bytes")
    # json, when given, takes precedence over data for the body.
    if json is not None:
        body = dumps(json, separators=(",", ":")).encode()

    # Tell libcurl to be aware of bodies and related headers when,
    # 1. POST/PUT/PATCH, even if the body is empty, it's up to curl to decide what to do
    # 2. GET/DELETE with body, although it's against the RFC, some applications.
    #   e.g. Elasticsearch, use this.
    if body or method in ("POST", "PUT", "PATCH"):
        c.setopt(CurlOpt.POSTFIELDS, body)
        # necessary if body contains '\0'
        c.setopt(CurlOpt.POSTFIELDSIZE, len(body))
        if method == "GET":
            c.setopt(CurlOpt.CUSTOMREQUEST, method)

    # headers
    base_headers, headers = headers_list
    # let headers encoding take precedence over base headers encoding
    encoding = headers.encoding if isinstance(headers, Headers) else None
    h = Headers(base_headers, encoding=encoding)
    h.update(headers)

    # remove Host header if it's unnecessary, otherwise curl may get confused.
    # Host header will be automatically added by curl if it's not present.
    # https://github.com/lexiforest/curl_cffi/issues/119
    host_header = h.get("Host")
    if host_header is not None:
        u = urlparse(url)
        if host_header == u.netloc or host_header == u.hostname:
            h.pop("Host", None)

    # Make curl always include empty headers.
    # See: https://stackoverflow.com/a/32911474/1061155
    header_lines = []
    for k, v in h.multi_items():
        if v is None:
            header_lines.append(f"{k}:")  # Explicitly disable this header
        elif v == "":
            header_lines.append(f"{k};")  # Add an empty valued header
        else:
            header_lines.append(f"{k}: {v}")

    # Add content-type if missing
    if json is not None:
        update_header_line(header_lines, "Content-Type", "application/json")
    if isinstance(data, dict) and method != "POST":
        update_header_line(
            header_lines, "Content-Type", "application/x-www-form-urlencoded"
        )
    if isinstance(data, (str, bytes)):
        update_header_line(header_lines, "Content-Type", "application/octet-stream")

    # Never send `Expect` header.
    update_header_line(header_lines, "Expect", "", replace=True)

    c.setopt(CurlOpt.HTTPHEADER, [h.encode() for h in header_lines])

    req = Request(url, h, method)

    # cookies
    c.setopt(CurlOpt.COOKIEFILE, b"")  # always enable the curl cookie engine first
    c.setopt(CurlOpt.COOKIELIST, "ALL")  # remove all the old cookies first.
    base_cookies, cookies = cookies_list

    if base_cookies:
        for morsel in base_cookies.get_cookies_for_curl(req):  # type: ignore
            curl.setopt(CurlOpt.COOKIELIST, morsel.to_curl_format())
    if cookies:
        temp_cookies = Cookies(cookies)
        for morsel in temp_cookies.get_cookies_for_curl(req):
            curl.setopt(CurlOpt.COOKIELIST, morsel.to_curl_format())

    # files
    if files:
        raise NotImplementedError(
            "files is not supported, use `multipart`. See examples here: "
            "https://github.com/lexiforest/curl_cffi/blob/main/examples/upload.py"
        )

    # multipart
    if multipart:
        # multipart overrides postfields; dict ``data`` entries become extra parts
        for k, v in cast(dict, data or {}).items():
            multipart.addpart(name=k, data=v.encode() if isinstance(v, str) else v)
        c.setopt(CurlOpt.MIMEPOST, multipart._form)

    # auth
    if auth:
        username, password = auth
        c.setopt(CurlOpt.USERNAME, username.encode())  # pyright: ignore [reportPossiblyUnboundVariable=none]
        c.setopt(CurlOpt.PASSWORD, password.encode())  # pyright: ignore [reportPossiblyUnboundVariable=none]

    # timeout
    # When timeout is left at ``not_set`` neither branch below matches, so no
    # timeout option is set here.
    if timeout is None:
        timeout = 0  # indefinitely

    if isinstance(timeout, tuple):
        connect_timeout, read_timeout = timeout
        all_timeout = connect_timeout + read_timeout
        c.setopt(CurlOpt.CONNECTTIMEOUT_MS, int(connect_timeout * 1000))
        if not stream:
            c.setopt(CurlOpt.TIMEOUT_MS, int(all_timeout * 1000))
        else:
            # trick from: https://github.com/lexiforest/curl_cffi/issues/156
            c.setopt(CurlOpt.LOW_SPEED_LIMIT, 1)
            c.setopt(CurlOpt.LOW_SPEED_TIME, math.ceil(all_timeout))

    elif isinstance(timeout, (int, float)):
        if not stream:
            c.setopt(CurlOpt.TIMEOUT_MS, int(timeout * 1000))
        else:
            c.setopt(CurlOpt.CONNECTTIMEOUT_MS, int(timeout * 1000))
            c.setopt(CurlOpt.LOW_SPEED_LIMIT, 1)
            c.setopt(CurlOpt.LOW_SPEED_TIME, math.ceil(timeout))

    # allow_redirects
    c.setopt(CurlOpt.FOLLOWLOCATION, int(allow_redirects))  # type: ignore

    # max_redirects
    c.setopt(CurlOpt.MAXREDIRS, max_redirects)

    # proxies
    base_proxies, proxies = proxies_list
    if proxy and proxies:
        raise TypeError("Cannot specify both 'proxy' and 'proxies'")
    if proxy:
        proxies = {"all": proxy}
    if proxies is None:
        proxies = base_proxies

    if proxies:
        # Turn on proxy_credential_no_reuse, which has the following benefits:
        # 1. New connection will be made when proxy username changed
        # 2. New TLS session will be created based on proxy address, i.e. when accessing
        #    the same site with different proxies, TLS session won't leak previous IP.
        c.setopt(CurlOpt.PROXY_CREDENTIAL_NO_REUSE, 1)

        parts = urlparse(url)
        # Scheme-wide entry first, falling back to "all"; a host-specific entry
        # (checked next) wins over both.
        proxy = cast(Optional[str], proxies.get(parts.scheme, proxies.get("all")))
        if parts.hostname:
            proxy = (
                proxies.get(  # type: ignore
                    f"{parts.scheme}://{parts.hostname}",
                    proxies.get(f"all://{parts.hostname}"),
                )
                or proxy
            )

        if proxy is not None:
            c.setopt(CurlOpt.PROXY, proxy)

            if parts.scheme == "https":
                if proxy.startswith("https://"):
                    warnings.warn(
                        "Make sure you are using https over https proxy, otherwise, "
                        "the proxy prefix should be 'http://' not 'https://', "
                        "see: https://github.com/lexiforest/curl_cffi/issues/6",
                        CurlCffiWarning,
                        stacklevel=2,
                    )
                # For https site with http tunnel proxy, tell curl to enable tunneling
                if not proxy.startswith("socks"):
                    c.setopt(CurlOpt.HTTPPROXYTUNNEL, 1)

            # proxy_auth
            if proxy_auth:
                username, password = proxy_auth
                c.setopt(CurlOpt.PROXYUSERNAME, username.encode())
                c.setopt(CurlOpt.PROXYPASSWORD, password.encode())

    # verify
    base_verify, verify = verify_list
    # Verification is off when the per-call value is False, or when it is unset
    # and the base value is falsy.
    if verify is False or not base_verify and verify is None:
        c.setopt(CurlOpt.SSL_VERIFYPEER, 0)
        c.setopt(CurlOpt.SSL_VERIFYHOST, 0)

    # cert for this single request
    if isinstance(verify, str):
        c.setopt(CurlOpt.CAINFO, verify)

    # cert for the session
    if verify in (None, True) and isinstance(base_verify, str):
        c.setopt(CurlOpt.CAINFO, base_verify)

    # referer
    if referer:
        c.setopt(CurlOpt.REFERER, referer.encode())

    # accept_encoding
    if accept_encoding is not None:
        c.setopt(CurlOpt.ACCEPT_ENCODING, accept_encoding.encode())

    # cert
    if cert:
        if isinstance(cert, str):
            c.setopt(CurlOpt.SSLCERT, cert)
        else:
            cert, key = cert
            c.setopt(CurlOpt.SSLCERT, cert)
            c.setopt(CurlOpt.SSLKEY, key)

    # impersonate
    if impersonate:
        impersonate = normalize_browser_type(impersonate)
        ret = c.impersonate(impersonate, default_headers=default_headers)  # type: ignore
        if ret != 0:
            raise ImpersonateError(f"Impersonating {impersonate} is not supported")

    # extra_fp options
    if extra_fp:
        if isinstance(extra_fp, dict):
            extra_fp = ExtraFingerprints(**extra_fp)
        if impersonate:
            warnings.warn(
                "Extra fingerprints was altered after impersonated version was set.",
                CurlCffiWarning,
                stacklevel=1,
            )
        set_extra_fp(c, extra_fp)

    # ja3 string
    if ja3:
        if impersonate:
            warnings.warn(
                "JA3 fingerprint was altered after impersonated version was set.",
                CurlCffiWarning,
                stacklevel=1,
            )
        permute = False
        if isinstance(extra_fp, ExtraFingerprints) and extra_fp.tls_permute_extensions:
            permute = True
        # NOTE(review): a non-empty dict extra_fp was already converted to
        # ExtraFingerprints above and an empty dict has no such key, so this
        # branch looks unreachable -- confirm before removing.
        if isinstance(extra_fp, dict) and extra_fp.get("tls_permute_extensions"):
            permute = True
        set_ja3_options(c, ja3, permute=permute)

    # akamai string
    if akamai:
        if impersonate:
            warnings.warn(
                "Akamai fingerprint was altered after impersonated version was set.",
                CurlCffiWarning,
                stacklevel=1,
            )
        set_akamai_options(c, akamai)

    # http_version, after impersonate, which will change this to http2
    if http_version:
        http_version = normalize_http_version(http_version)
        c.setopt(CurlOpt.HTTP_VERSION, http_version)

    # Response body sink: a queue when streaming, the caller's callback when
    # given, otherwise an in-memory buffer.
    buffer = None
    q = None
    header_recved = None
    quit_now = None
    if stream:
        q = queue_class()
        header_recved = event_class()
        quit_now = event_class()

        def qput(chunk):
            # The first body chunk marks the headers as received.
            if not header_recved.is_set():
                header_recved.set()
            # Returning the error code makes the write callback fail once
            # quit_now has been set.
            if quit_now.is_set():
                return CURL_WRITEFUNC_ERROR
            q.put_nowait(chunk)
            return len(chunk)

        c.setopt(CurlOpt.WRITEFUNCTION, qput)
    elif content_callback is not None:
        c.setopt(CurlOpt.WRITEFUNCTION, content_callback)
    else:
        buffer = BytesIO()
        c.setopt(CurlOpt.WRITEDATA, buffer)
    header_buffer = BytesIO()
    c.setopt(CurlOpt.HEADERDATA, header_buffer)

    # interface
    if interface:
        c.setopt(CurlOpt.INTERFACE, interface.encode())

    # max_recv_speed
    # do not check, since 0 is a valid value to disable it
    c.setopt(CurlOpt.MAX_RECV_SPEED_LARGE, max_recv_speed)

    # set extra options, after all others, because it will alter some options
    if curl_options:
        for option, setting in curl_options.items():
            c.setopt(option, setting)

    return req, buffer, header_buffer, q, header_recved, quit_now
@@ -0,0 +1,839 @@
from __future__ import annotations
import asyncio
import struct
from enum import IntEnum
from functools import partial
from json import dumps, loads
from select import select
from typing import (
TYPE_CHECKING,
Any,
Callable,
Literal,
Optional,
TypeVar,
Union,
)
import warnings
from curl_cffi.utils import CurlCffiWarning
from ..aio import CURL_SOCKET_BAD, get_selector
from ..const import CurlECode, CurlInfo, CurlOpt, CurlWsFlag
from ..curl import Curl, CurlError
from .exceptions import SessionClosed, Timeout
from .utils import not_set, set_curl_options
if TYPE_CHECKING:
from typing_extensions import Self
from ..const import CurlHttpVersion
from ..curl import CurlWsFrame
from .cookies import CookieTypes
from .headers import HeaderTypes
from .impersonate import BrowserTypeLiteral, ExtraFingerprints, ExtraFpDict
from .session import AsyncSession, ProxySpec
T = TypeVar("T")
ON_DATA_T = Callable[["WebSocket", bytes, CurlWsFrame], None]
ON_MESSAGE_T = Callable[["WebSocket", Union[bytes, str]], None]
ON_ERROR_T = Callable[["WebSocket", CurlError], None]
ON_OPEN_T = Callable[["WebSocket"], None]
ON_CLOSE_T = Callable[["WebSocket", int, str], None]
# json.dumps is pre-bound to compact separators here, instead of passing
# ``separators`` at each call site, because a user-supplied ``dumps`` callable
# (see the ``dumps`` parameter of ``send_json``) may not accept that parameter.
dumps = partial(dumps, separators=(",", ":"))
class WsCloseCode(IntEnum):
    """Known WebSocket close status codes.

    See: https://www.iana.org/assignments/websocket/websocket.xhtml
    """

    # Codes in the 1000-1015 block
    OK = 1000
    GOING_AWAY = 1001
    PROTOCOL_ERROR = 1002
    UNSUPPORTED_DATA = 1003
    # Used locally as the code for a close frame that carries no status code
    UNKNOWN = 1005
    ABNORMAL_CLOSURE = 1006
    INVALID_DATA = 1007
    POLICY_VIOLATION = 1008
    MESSAGE_TOO_BIG = 1009
    MANDATORY_EXTENSION = 1010
    INTERNAL_ERROR = 1011
    SERVICE_RESTART = 1012
    TRY_AGAIN_LATER = 1013
    BAD_GATEWAY = 1014
    TLS_HANDSHAKE = 1015

    # Codes in the 3000+ block
    UNAUTHORIZED = 3000
    FORBIDDEN = 3003
    TIMEOUT = 3008
class WebSocketError(CurlError):
    """Error raised for WebSocket-specific failures.

    Args:
        message: description of the failure.
        code: WebSocket close code or curl error code attached to the error,
            defaults to ``0``.
    """

    def __init__(
        self,
        message: str,
        code: Union[WsCloseCode, CurlECode, Literal[0]] = 0,
    ):
        # Cooperative call: subclasses mix this class with other exception bases.
        super().__init__(message, code)  # type: ignore
class WebSocketClosed(WebSocketError, SessionClosed):
    """Raised when an operation is attempted on an already closed WebSocket."""
class WebSocketTimeout(WebSocketError, Timeout):
    """Raised when a WebSocket operation does not finish within its timeout."""
async def aselect(
    fd,
    mode: Literal["read", "write"] = "read",
    *,
    loop: asyncio.AbstractEventLoop,
    timeout: Optional[float] = None,
) -> bool:
    """Wait until ``fd`` is ready for reading or writing.

    Args:
        fd: file descriptor to watch.
        mode: ``"read"`` to wait for readability, ``"write"`` for writability.
        loop: event loop providing ``add_reader``/``add_writer``.
        timeout: seconds to wait, ``None`` to wait without a limit.

    Returns:
        ``True`` if ``fd`` became ready, ``False`` if the timeout expired first.

    Raises:
        ValueError: if ``mode`` is neither ``"read"`` nor ``"write"``.
    """
    if mode == "read":
        watch, unwatch = loop.add_reader, loop.remove_reader
    elif mode == "write":
        watch, unwatch = loop.add_writer, loop.remove_writer
    else:
        raise ValueError(f"Invalid mode: {mode}. Must be 'read' or 'write'")

    future = loop.create_future()

    def _on_ready() -> None:
        # The fd may fire again (or after a timeout cancelled the future) before
        # the watcher is removed; setting a result twice would raise
        # InvalidStateError inside the event loop.
        if not future.done():
            future.set_result(None)

    watch(fd, _on_ready)
    try:
        await asyncio.wait_for(future, timeout)
    except asyncio.TimeoutError:
        return False
    finally:
        # Always unregister, including on timeout and task cancellation.
        unwatch(fd)
    return True
class BaseWebSocket:
    """State and close-frame helpers shared by the WebSocket implementations."""

    def __init__(self, curl: Curl, *, autoclose: bool = True, debug: bool = False):
        """
        Args:
            curl: curl handle to use, or the ``not_set`` sentinel to have one created
                lazily by the :attr:`curl` property.
            autoclose: stored on the instance; subclasses use it to decide whether
                to close after receiving a close frame.
            debug: passed to ``Curl`` when the handle is created lazily.
        """
        self._curl: Curl = curl
        self.autoclose: bool = autoclose
        self._close_code: Optional[int] = None
        self._close_reason: Optional[str] = None
        self.debug = debug
        self.closed = False

    @property
    def curl(self):
        """The curl handle, created on first access if none was supplied."""
        if self._curl is not_set:
            self._curl = Curl(debug=self.debug)
        return self._curl

    @property
    def close_code(self) -> Optional[int]:
        """The WebSocket close code, if the connection has been closed."""
        return self._close_code

    @property
    def close_reason(self) -> Optional[str]:
        """The WebSocket close reason, if the connection has been closed."""
        return self._close_reason

    @staticmethod
    def _pack_close_frame(code: int, reason: bytes) -> bytes:
        """Build a close payload: 2-byte big-endian code followed by the reason."""
        return struct.pack("!H", code) + reason

    @staticmethod
    def _unpack_close_frame(frame: bytes) -> tuple[int, str]:
        """Split a close payload into ``(code, reason)``.

        A payload shorter than two bytes yields ``(WsCloseCode.UNKNOWN, "")``.

        Raises:
            WebSocketError: with ``INVALID_DATA`` if the reason is not valid UTF-8,
                or with ``PROTOCOL_ERROR`` if the code is not acceptable.
        """
        if len(frame) < 2:
            return WsCloseCode.UNKNOWN, ""

        try:
            code = struct.unpack_from("!H", frame)[0]
            reason = frame[2:].decode()
        except UnicodeDecodeError as e:
            raise WebSocketError(
                "Invalid close message", WsCloseCode.INVALID_DATA
            ) from e
        except Exception as e:
            raise WebSocketError(
                "Invalid close frame", WsCloseCode.PROTOCOL_ERROR
            ) from e

        # RFC 6455 section 7.4.2: 3000-3999 is for registered library/application
        # codes and 4000-4999 is for private use, so any value in that range is
        # legal even when it is not listed in WsCloseCode.
        if 3000 <= code <= 4999:
            return code, reason

        if code not in WsCloseCode._value2member_map_ or code == WsCloseCode.UNKNOWN:
            raise WebSocketError(
                f"Invalid close code: {code}", WsCloseCode.PROTOCOL_ERROR
            )
        return code, reason

    def terminate(self):
        """Terminate the underlying connection."""
        self.closed = True
        # Check the raw attribute: going through the ``curl`` property would
        # create a brand-new handle only to close it immediately.
        if self._curl is not not_set:
            self._curl.close()
# Event names used as keys of ``WebSocket._emitters`` and accepted by
# ``WebSocket._emit()``; each corresponds to an ``on_<event>`` constructor callback.
EventTypeLiteral = Literal["open", "close", "data", "message", "error"]
class WebSocket(BaseWebSocket):
    """A WebSocket implementation using libcurl.

    Messages can be pulled with :meth:`recv` (or by iterating over the instance),
    or dispatched to the ``on_*`` callbacks by :meth:`run_forever`.
    """

    def __init__(
        self,
        curl: Union[Curl, Any] = not_set,
        *,
        autoclose: bool = True,
        skip_utf8_validation: bool = False,
        debug: bool = False,
        on_open: Optional[ON_OPEN_T] = None,
        on_close: Optional[ON_CLOSE_T] = None,
        on_data: Optional[ON_DATA_T] = None,
        on_message: Optional[ON_MESSAGE_T] = None,
        on_error: Optional[ON_ERROR_T] = None,
    ):
        """
        Args:
            curl: curl handle to use; when left unset one is created on first use.
            autoclose: whether to close the WebSocket after receiving a close frame.
            skip_utf8_validation: whether to skip UTF-8 validation for text frames in
                run_forever().
            debug: print extra curl debug info.
            on_open: open callback, ``def on_open(ws)``
            on_close: close callback, ``def on_close(ws, code, reason)``
            on_data: raw data receive callback, ``def on_data(ws, data, frame)``
            on_message: message receive callback, ``def on_message(ws, message)``
            on_error: error callback, ``def on_error(ws, exception)``
        """
        super().__init__(curl=curl, autoclose=autoclose, debug=debug)
        self.skip_utf8_validation = skip_utf8_validation
        # Loop flag of run_forever(). Initialised here because send() also writes
        # it, possibly before run_forever() was ever started.
        self.keep_running = False

        callbacks: dict[EventTypeLiteral, Optional[Callable]] = {
            "open": on_open,
            "close": on_close,
            "data": on_data,
            "message": on_message,
            "error": on_error,
        }
        # Only callbacks that were actually supplied are registered.
        self._emitters: dict[EventTypeLiteral, Callable] = {
            event: callback for event, callback in callbacks.items() if callback
        }

    def __iter__(self) -> WebSocket:
        """Iterate over received messages; raises WebSocketClosed if closed."""
        if self.closed:
            raise WebSocketClosed("WebSocket is closed")
        return self

    def __next__(self) -> bytes:
        """Return the next message, stopping when a close frame is received."""
        msg, flags = self.recv()
        if flags & CurlWsFlag.CLOSE:
            raise StopIteration
        return msg

    def _emit(self, event_type: EventTypeLiteral, *args) -> None:
        """Invoke the callback registered for ``event_type``, if any.

        An exception raised by the callback is passed to the ``error`` callback
        when one is registered; otherwise a ``CurlCffiWarning`` is issued.
        """
        callback = self._emitters.get(event_type)
        if not callback:
            return

        try:
            callback(self, *args)
        except Exception as e:
            error_callback = self._emitters.get("error")
            if error_callback:
                error_callback(self, e)
            else:
                warnings.warn(
                    f"WebSocket callback '{event_type}' failed",
                    CurlCffiWarning,
                    stacklevel=2,
                )

    def connect(
        self,
        url: str,
        params: Optional[Union[dict, list, tuple]] = None,
        headers: Optional[HeaderTypes] = None,
        cookies: Optional[CookieTypes] = None,
        auth: Optional[tuple[str, str]] = None,
        timeout: Optional[Union[float, tuple[float, float], object]] = not_set,
        allow_redirects: bool = True,
        max_redirects: int = 30,
        proxies: Optional[ProxySpec] = None,
        proxy: Optional[str] = None,
        proxy_auth: Optional[tuple[str, str]] = None,
        verify: Optional[bool] = None,
        referer: Optional[str] = None,
        accept_encoding: Optional[str] = "gzip, deflate, br",
        impersonate: Optional[BrowserTypeLiteral] = None,
        ja3: Optional[str] = None,
        akamai: Optional[str] = None,
        extra_fp: Optional[Union[ExtraFingerprints, ExtraFpDict]] = None,
        default_headers: bool = True,
        quote: Union[str, Literal[False]] = "",
        http_version: Optional[CurlHttpVersion] = None,
        interface: Optional[str] = None,
        cert: Optional[Union[str, tuple[str, str]]] = None,
        max_recv_speed: int = 0,
        curl_options: Optional[dict[CurlOpt, str]] = None,
    ):
        """Connect to the WebSocket.

        libcurl automatically handles pings and pongs.
        ref: https://curl.se/libcurl/c/libcurl-ws.html

        Args:
            url: url for the requests.
            params: query string for the requests.
            headers: headers to send.
            cookies: cookies to use.
            auth: HTTP basic auth, a tuple of (username, password), only basic auth is
                supported.
            timeout: how many seconds to wait before giving up.
            allow_redirects: whether to allow redirection.
            max_redirects: max redirect counts, default 30, use -1 for unlimited.
            proxies: dict of proxies to use, prefer to use ``proxy`` if they are the
                same. format: ``{"http": proxy_url, "https": proxy_url}``.
            proxy: proxy to use, format: "http://user:pass@proxy_url".
                Can't be used with `proxies` parameter.
            proxy_auth: HTTP basic auth for proxy, a tuple of (username, password).
            verify: whether to verify https certs.
            referer: shortcut for setting referer header.
            accept_encoding: shortcut for setting accept-encoding header.
            impersonate: which browser version to impersonate.
            ja3: ja3 string to impersonate.
            akamai: akamai string to impersonate.
            extra_fp: extra fingerprints options, in complement to ja3 and akamai str.
            default_headers: whether to set default browser headers.
            quote: Set characters to be quoted, i.e. percent-encoded. Default safe
                string is ``!#$%&'()*+,/:;=?@[]~``. If set to a string, the character
                will be removed from the safe string, thus quoted. If set to False, the
                url will be kept as is, without any automatic percent-encoding, you must
                encode the URL yourself.
            http_version: limiting http version, defaults to http2.
            interface: which interface to use.
            cert: a tuple of (cert, key) filenames for client cert.
            max_recv_speed: maximum receive speed, bytes per second.
            curl_options: extra curl options to use.

        Returns:
            This WebSocket instance.
        """
        curl = self.curl

        # Each ``*_list`` argument is a two-item list whose first slot is None
        # here and whose second slot is the per-connection value.
        set_curl_options(
            curl=curl,
            method="GET",
            url=url,
            params_list=[None, params],
            headers_list=[None, headers],
            cookies_list=[None, cookies],
            auth=auth,
            timeout=timeout,
            allow_redirects=allow_redirects,
            max_redirects=max_redirects,
            proxies_list=[None, proxies],
            proxy=proxy,
            proxy_auth=proxy_auth,
            verify_list=[None, verify],
            referer=referer,
            accept_encoding=accept_encoding,
            impersonate=impersonate,
            ja3=ja3,
            akamai=akamai,
            extra_fp=extra_fp,
            default_headers=default_headers,
            quote=quote,
            http_version=http_version,
            interface=interface,
            max_recv_speed=max_recv_speed,
            cert=cert,
            curl_options=curl_options,
        )

        # Magic number defined in: https://curl.se/docs/websocket.html
        curl.setopt(CurlOpt.CONNECT_ONLY, 2)
        curl.perform()
        return self

    def recv_fragment(self) -> tuple[bytes, CurlWsFrame]:
        """Receive a single curl websocket fragment as bytes.

        When the fragment is a close frame, the close code and reason are stored
        on the instance and, if ``autoclose`` is set, the connection is closed.

        Raises:
            WebSocketClosed: if the WebSocket is already closed.
            WebSocketError: if a received close frame cannot be unpacked; the
                connection is closed with the error's code before re-raising.
        """
        if self.closed:
            raise WebSocketClosed("WebSocket is already closed")

        chunk, frame = self.curl.ws_recv()
        if frame.flags & CurlWsFlag.CLOSE:
            try:
                self._close_code, self._close_reason = self._unpack_close_frame(chunk)
            except WebSocketError as e:
                # Follow the spec to close the connection
                # Errors do not respect autoclose
                self._close_code = e.code
                self.close(e.code)
                raise
            if self.autoclose:
                self.close()

        return chunk, frame

    def recv(self) -> tuple[bytes, int]:
        """
        Receive a frame as bytes. libcurl splits frames into fragments, so we have to
        collect all the chunks for a frame.

        Returns:
            A tuple of the joined payload and the flags of the last fragment.

        Raises:
            WebSocketError: if curl reports no valid active socket.
        """
        chunks = []
        flags = 0

        sock_fd = self.curl.getinfo(CurlInfo.ACTIVESOCKET)
        if sock_fd == CURL_SOCKET_BAD:
            raise WebSocketError(
                "Invalid active socket", CurlECode.NO_CONNECTION_AVAILABLE
            )

        while True:
            try:
                # Try to receive the first fragment first
                chunk, frame = self.recv_fragment()
                flags = frame.flags
                chunks.append(chunk)
                if frame.bytesleft == 0 and flags & CurlWsFlag.CONT == 0:
                    break
            except CurlError as e:
                if e.code == CurlECode.AGAIN:
                    # According to https://curl.se/libcurl/c/curl_ws_recv.html
                    # > in real application: wait for socket here, e.g. using select()
                    _, _, _ = select([sock_fd], [], [], 0.5)
                else:
                    raise

        return b"".join(chunks), flags

    def recv_str(self) -> str:
        """Receive a text frame.

        Raises:
            WebSocketError: if the received frame is not flagged as text.
        """
        data, flags = self.recv()
        if not (flags & CurlWsFlag.TEXT):
            raise WebSocketError("Not valid text frame", WsCloseCode.INVALID_DATA)
        return data.decode()

    def recv_json(self, *, loads: Callable[[str], T] = loads) -> T:
        """Receive a JSON frame.

        Args:
            loads: JSON decoder, default is json.loads.
        """
        data = self.recv_str()
        return loads(data)

    def send(self, payload: Union[str, bytes], flags: CurlWsFlag = CurlWsFlag.BINARY):
        """Send a data frame.

        Args:
            payload: data to send.
            flags: flags for the frame.

        Returns:
            The number of bytes sent.

        Raises:
            WebSocketClosed: if the WebSocket is already closed.
            WebSocketError: if there is no valid active socket, or the socket does
                not become writable within 0.5 seconds after curl asks to retry.
        """
        if flags & CurlWsFlag.CLOSE:
            # Sending a close frame also stops a running run_forever() loop.
            self.keep_running = False

        if self.closed:
            raise WebSocketClosed("WebSocket is already closed")

        # curl expects bytes
        if isinstance(payload, str):
            payload = payload.encode()

        sock_fd = self.curl.getinfo(CurlInfo.ACTIVESOCKET)
        if sock_fd == CURL_SOCKET_BAD:
            raise WebSocketError(
                "Invalid active socket", CurlECode.NO_CONNECTION_AVAILABLE
            )

        # Loop checks for CurlECode.Again
        # https://curl.se/libcurl/c/curl_ws_send.html
        # NOTE(review): an empty payload never enters this loop, so nothing is
        # handed to curl in that case - confirm this is intended.
        offset = 0
        while offset < len(payload):
            current_buffer = payload[offset:]
            try:
                n_sent = self.curl.ws_send(current_buffer, flags)
            except CurlError as e:
                if e.code == CurlECode.AGAIN:
                    _, writeable, _ = select([], [sock_fd], [], 0.5)
                    if not writeable:
                        raise WebSocketError("Socket write timeout") from e
                    continue
                raise
            offset += n_sent

        return offset

    def send_binary(self, payload: bytes):
        """Send a binary frame.

        Args:
            payload: binary data to send.
        """
        return self.send(payload, CurlWsFlag.BINARY)

    def send_bytes(self, payload: bytes):
        """Send a binary frame, alias of :meth:`send_binary`.

        Args:
            payload: binary data to send.
        """
        return self.send(payload, CurlWsFlag.BINARY)

    def send_str(self, payload: str):
        """Send a text frame.

        Args:
            payload: text data to send.
        """
        return self.send(payload, CurlWsFlag.TEXT)

    def send_json(self, payload: Any, *, dumps: Callable[[Any], str] = dumps):
        """Send a JSON frame.

        Args:
            payload: data to send.
            dumps: JSON encoder, default is json.dumps.
        """
        return self.send_str(dumps(payload))

    def ping(self, payload: Union[str, bytes]):
        """Send a ping frame.

        Args:
            payload: data to send.
        """
        return self.send(payload, CurlWsFlag.PING)

    def run_forever(self, url: str = "", **kwargs):
        """Run the WebSocket forever. See :meth:`connect` for details on parameters.

        libcurl automatically handles pings and pongs.
        ref: https://curl.se/libcurl/c/libcurl-ws.html

        Fragments are reported through the ``data`` callback, complete text/binary
        messages through ``message``, and a received close frame through ``close``.
        Any curl error other than "try again" is reported through ``error``, the
        connection is closed if still open, and the error is re-raised.

        Args:
            url: when non-empty, :meth:`connect` is called with it first.
            **kwargs: forwarded to :meth:`connect`.
        """
        if url:
            self.connect(url, **kwargs)

        sock_fd = self.curl.getinfo(CurlInfo.ACTIVESOCKET)
        if sock_fd == CURL_SOCKET_BAD:
            raise WebSocketError(
                "Invalid active socket", CurlECode.NO_CONNECTION_AVAILABLE
            )

        self._emit("open")

        # Keep reading the messages and invoke callbacks
        # TODO: Reconnect logic
        chunks = []
        self.keep_running = True
        while self.keep_running:
            try:
                chunk, frame = self.recv_fragment()
                flags = frame.flags
                self._emit("data", chunk, frame)

                chunks.append(chunk)
                if not (frame.bytesleft == 0 and flags & CurlWsFlag.CONT == 0):
                    # More fragments of this message are still to come
                    continue

                # Avoid unnecessary computation
                if "message" in self._emitters:
                    # Concatenate collected chunks with the final message
                    msg = b"".join(chunks)

                    if (flags & CurlWsFlag.TEXT) and not self.skip_utf8_validation:
                        try:
                            msg = msg.decode()  # type: ignore
                        except UnicodeDecodeError as e:
                            self._close_code = WsCloseCode.INVALID_DATA
                            self.close(WsCloseCode.INVALID_DATA)
                            raise WebSocketError(
                                "Invalid UTF-8", WsCloseCode.INVALID_DATA
                            ) from e

                    if (flags & CurlWsFlag.BINARY) or (flags & CurlWsFlag.TEXT):
                        self._emit("message", msg)

                # Always reset for the next message, even without a message
                # callback, so the buffer cannot grow without bound.
                chunks = []

                if flags & CurlWsFlag.CLOSE:
                    self.keep_running = False
                    self._emit("close", self._close_code or 0, self._close_reason or "")

            except CurlError as e:
                if e.code == CurlECode.AGAIN:
                    # No data available yet: wait for the socket, then retry
                    _, _, _ = select([sock_fd], [], [], 0.5)
                else:
                    self._emit("error", e)
                    if not self.closed:
                        code = WsCloseCode.UNKNOWN
                        if isinstance(e, WebSocketError):
                            code = e.code
                        self.close(code)
                    raise

    def close(self, code: int = WsCloseCode.OK, message: bytes = b""):
        """Close the connection.

        Does nothing when no curl handle was ever supplied or created.

        Args:
            code: close code.
            message: close reason.
        """
        # Check the raw attribute: the ``curl`` property never returns ``not_set``
        # because it creates a handle on access.
        if self._curl is not_set:
            return

        # TODO: As per spec, we should wait for the server to close the connection
        # But this is not a requirement
        msg = self._pack_close_frame(code, message)
        self.send(msg, CurlWsFlag.CLOSE)
        # The only way to close the connection appears to be curl_easy_cleanup
        self.terminate()
class AsyncWebSocket(BaseWebSocket):
    """An async WebSocket implementation using libcurl."""

    def __init__(
        self,
        session: AsyncSession,
        curl: Curl,
        *,
        autoclose: bool = True,
        debug: bool = False,
    ):
        """
        Args:
            session: the session this WebSocket belongs to; it is notified when
                the WebSocket is terminated.
            curl: curl handle carrying the connection.
            autoclose: whether to close the WebSocket after receiving a close frame.
            debug: print extra curl debug info.
        """
        super().__init__(curl=curl, autoclose=autoclose, debug=debug)
        self.session = session
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._recv_lock = asyncio.Lock()
        self._send_lock = asyncio.Lock()

    @property
    def loop(self):
        """Loop returned by ``get_selector`` for the running loop, cached."""
        if self._loop is None:
            self._loop = get_selector(asyncio.get_running_loop())
        return self._loop

    def __aiter__(self) -> Self:
        """Iterate over received messages; raises WebSocketClosed if closed."""
        if self.closed:
            raise WebSocketClosed("WebSocket has been closed")
        return self

    async def __anext__(self) -> bytes:
        """Return the next message, stopping when a close frame is received."""
        msg, flags = await self.recv()
        if flags & CurlWsFlag.CLOSE:
            raise StopAsyncIteration
        return msg

    async def recv_fragment(
        self, *, timeout: Optional[float] = None
    ) -> tuple[bytes, CurlWsFrame]:
        """Receive a single frame as bytes.

        When the fragment is a close frame, the close code and reason are stored
        on the instance and, if ``autoclose`` is set, the close is echoed back.

        Args:
            timeout: how many seconds to wait before giving up.

        Raises:
            WebSocketClosed: if the WebSocket is already closed.
            TypeError: if another recv_fragment() call is already in progress.
            WebSocketTimeout: if the receive call does not finish within ``timeout``.
            WebSocketError: if a received close frame cannot be unpacked.
        """
        if self.closed:
            raise WebSocketClosed("WebSocket is closed")
        if self._recv_lock.locked():
            raise TypeError("Concurrent call to recv_fragment() is not allowed")

        async with self._recv_lock:
            try:
                chunk, frame = await asyncio.wait_for(
                    self.loop.run_in_executor(None, self.curl.ws_recv), timeout
                )
            except asyncio.TimeoutError as e:
                raise WebSocketTimeout("WebSocket recv_fragment() timed out") from e

            if frame.flags & CurlWsFlag.CLOSE:
                try:
                    code, message = self._unpack_close_frame(chunk)
                    self._close_code, self._close_reason = code, message
                except WebSocketError as e:
                    # Follow the spec to close the connection
                    # Errors do not respect autoclose
                    self._close_code = e.code
                    await self.close(e.code)
                    raise
                if self.autoclose:
                    await self.close(code, message.encode())

        return chunk, frame

    async def recv(self, *, timeout: Optional[float] = None) -> tuple[bytes, int]:
        """
        Receive a frame as bytes. libcurl splits frames into fragments, so we have to
        collect all the chunks for a frame.

        Args:
            timeout: how many seconds to wait before giving up.

        Returns:
            A tuple of the joined payload and the flags of the last fragment.

        Raises:
            WebSocketError: if curl reports no valid active socket.
            WebSocketTimeout: if no data becomes available within ``timeout``.
        """
        loop = self.loop
        chunks = []
        flags = 0

        sock_fd = await loop.run_in_executor(
            None, self.curl.getinfo, CurlInfo.ACTIVESOCKET
        )
        if sock_fd == CURL_SOCKET_BAD:
            raise WebSocketError(
                "Invalid active socket", CurlECode.NO_CONNECTION_AVAILABLE
            )

        while True:
            try:
                chunk, frame = await self.recv_fragment(timeout=timeout)
                flags = frame.flags
                chunks.append(chunk)
                if frame.bytesleft == 0 and flags & CurlWsFlag.CONT == 0:
                    break
            except CurlError as e:
                if e.code != CurlECode.AGAIN:
                    raise
                # No data yet: wait for the socket. If the wait itself times out,
                # give up instead of retrying forever, otherwise ``timeout`` would
                # never take effect on an idle connection.
                readable = await aselect(sock_fd, loop=loop, timeout=timeout)
                if not readable:
                    raise WebSocketTimeout("WebSocket recv() timed out") from e

        return b"".join(chunks), flags

    async def recv_str(self, *, timeout: Optional[float] = None) -> str:
        """Receive a text frame.

        Args:
            timeout: how many seconds to wait before giving up.

        Raises:
            WebSocketError: if the received frame is not flagged as text.
        """
        data, flags = await self.recv(timeout=timeout)
        if not (flags & CurlWsFlag.TEXT):
            # The frame type is wrong here; nothing has been decoded yet.
            raise WebSocketError("Not valid text frame", WsCloseCode.INVALID_DATA)
        return data.decode()

    async def recv_json(
        self, *, loads: Callable[[str], T] = loads, timeout: Optional[float] = None
    ) -> T:
        """Receive a JSON frame.

        Args:
            loads: JSON decoder, default is json.loads.
            timeout: how many seconds to wait before giving up.
        """
        data = await self.recv_str(timeout=timeout)
        return loads(data)

    async def send(
        self, payload: Union[str, bytes], flags: CurlWsFlag = CurlWsFlag.BINARY
    ):
        """Send a data frame.

        Args:
            payload: data to send.
            flags: flags for the frame.

        Returns:
            The number of bytes sent.

        Raises:
            WebSocketClosed: if the WebSocket is already closed.
            WebSocketError: if there is no valid active socket, or the socket does
                not become writable within 0.5 seconds after curl asks to retry.
        """
        if self.closed:
            raise WebSocketClosed("WebSocket is closed")

        # curl expects bytes
        if isinstance(payload, str):
            payload = payload.encode()

        sock_fd = await self.loop.run_in_executor(
            None, self.curl.getinfo, CurlInfo.ACTIVESOCKET
        )
        if sock_fd == CURL_SOCKET_BAD:
            raise WebSocketError(
                "Invalid active socket", CurlECode.NO_CONNECTION_AVAILABLE
            )

        # TODO: Why does concurrently sending fail
        async with self._send_lock:
            offset = 0
            # Loop checks for CurlECode.Again
            # https://curl.se/libcurl/c/curl_ws_send.html
            while offset < len(payload):
                current_buffer = payload[offset:]
                try:
                    n_sent = await self.loop.run_in_executor(
                        None, self.curl.ws_send, current_buffer, flags
                    )
                except CurlError as e:
                    if e.code == CurlECode.AGAIN:
                        writeable = await aselect(
                            sock_fd, mode="write", loop=self.loop, timeout=0.5
                        )
                        if not writeable:
                            raise WebSocketError("Socket write timeout") from e
                        continue
                    raise
                offset += n_sent

        return offset

    async def send_binary(self, payload: bytes):
        """Send a binary frame.

        Args:
            payload: binary data to send.
        """
        return await self.send(payload, CurlWsFlag.BINARY)

    async def send_bytes(self, payload: bytes):
        """Send a binary frame, alias of :meth:`send_binary`.

        Args:
            payload: binary data to send.
        """
        return await self.send(payload, CurlWsFlag.BINARY)

    async def send_str(self, payload: str):
        """Send a text frame.

        Args:
            payload: text data to send.
        """
        return await self.send(payload, CurlWsFlag.TEXT)

    async def send_json(self, payload: Any, *, dumps: Callable[[Any], str] = dumps):
        """Send a JSON frame.

        Args:
            payload: data to send.
            dumps: JSON encoder, default is json.dumps.
        """
        return await self.send_str(dumps(payload))

    async def ping(self, payload: Union[str, bytes]):
        """Send a ping frame.

        Args:
            payload: data to send.
        """
        return await self.send(payload, CurlWsFlag.PING)

    async def close(self, code: int = WsCloseCode.OK, message: bytes = b""):
        """Close the connection.

        Args:
            code: close code.
            message: close reason.
        """
        # TODO: As per spec, we should wait for the server to close the connection
        # But this is not a requirement
        msg = self._pack_close_frame(code, message)
        await self.send(msg, CurlWsFlag.CLOSE)
        # The only way to close the connection appears to be curl_easy_cleanup
        self.terminate()

    def terminate(self):
        """Terminate the underlying connection."""
        super().terminate()
        if not self.session._closed:
            # WebSocket curls CANNOT be reused
            self.session.push_curl(None)