diff --git a/httpx/_client.py b/httpx/_client.py index 13cd9336..710a864c 100644 --- a/httpx/_client.py +++ b/httpx/_client.py @@ -46,7 +46,7 @@ from ._types import ( TimeoutTypes, ) from ._urls import URL, QueryParams -from ._utils import URLPattern, get_environment_proxies +from ._utils import URLPattern, get_environment_proxies, build_url_pattern if typing.TYPE_CHECKING: import ssl # pragma: no cover @@ -695,7 +695,7 @@ class Client(BaseClient): transport=transport, ) self._mounts: dict[URLPattern, BaseTransport | None] = { - URLPattern(key): None + build_url_pattern(key): None if proxy is None else self._init_proxy_transport( proxy, @@ -710,7 +710,7 @@ class Client(BaseClient): } if mounts is not None: self._mounts.update( - {URLPattern(key): transport for key, transport in mounts.items()} + {build_url_pattern(key): transport for key, transport in mounts.items()} ) self._mounts = dict(sorted(self._mounts.items())) @@ -1410,7 +1410,7 @@ class AsyncClient(BaseClient): ) self._mounts: dict[URLPattern, AsyncBaseTransport | None] = { - URLPattern(key): None + build_url_pattern(key): None if proxy is None else self._init_proxy_transport( proxy, @@ -1425,7 +1425,7 @@ class AsyncClient(BaseClient): } if mounts is not None: self._mounts.update( - {URLPattern(key): transport for key, transport in mounts.items()} + {build_url_pattern(key): transport for key, transport in mounts.items()} ) self._mounts = dict(sorted(self._mounts.items())) diff --git a/httpx/_utils.py b/httpx/_utils.py index 7c8820cd..4588a6df 100644 --- a/httpx/_utils.py +++ b/httpx/_utils.py @@ -6,6 +6,8 @@ import re import typing from urllib.request import getproxies +from abc import abstractmethod + from ._types import PrimitiveData if typing.TYPE_CHECKING: # pragma: no cover @@ -123,24 +125,35 @@ def peek_filelike_length(stream: typing.Any) -> int | None: return length -class URLPattern: +class Pattern(typing.Protocol): + @abstractmethod + def matches(self, other: URL) -> bool: + pass + + @property + @abstractmethod + def priority(self) -> tuple[int, int, int]: + pass + + +class WildcardURLPattern(Pattern): """ A utility class currently used for making lookups against proxy keys... # Wildcard matching... - >>> pattern = URLPattern("all://") + >>> pattern = WildcardURLPattern("all://") >>> pattern.matches(httpx.URL("http://example.com")) True # Witch scheme matching... - >>> pattern = URLPattern("https://") + >>> pattern = WildcardURLPattern("https://") >>> pattern.matches(httpx.URL("https://example.com")) True >>> pattern.matches(httpx.URL("http://example.com")) False # With domain matching... - >>> pattern = URLPattern("https://example.com") + >>> pattern = WildcardURLPattern("https://example.com") >>> pattern.matches(httpx.URL("https://example.com")) True >>> pattern.matches(httpx.URL("http://example.com")) @@ -149,7 +162,7 @@ class URLPattern: False # Wildcard scheme, with domain matching... - >>> pattern = URLPattern("all://example.com") + >>> pattern = WildcardURLPattern("all://example.com") >>> pattern.matches(httpx.URL("https://example.com")) True >>> pattern.matches(httpx.URL("http://example.com")) @@ -158,7 +171,7 @@ class URLPattern: False # With port matching... - >>> pattern = URLPattern("https://example.com:1234") + >>> pattern = WildcardURLPattern("https://example.com:1234") >>> pattern.matches(httpx.URL("https://example.com:1234")) True >>> pattern.matches(httpx.URL("https://example.com")) @@ -229,7 +242,51 @@ class URLPattern: return self.priority < other.priority def __eq__(self, other: typing.Any) -> bool: - return isinstance(other, URLPattern) and self.pattern == other.pattern + return isinstance(other, WildcardURLPattern) and self.pattern == other.pattern + + +class IPNetPattern(Pattern): + def __init__(self, ip_net: str) -> None: + try: + addr, range = ip_net.split('/', 1) + if addr[0] == '[' and addr[-1] == ']': + addr = addr[1:-1] + ip_net = f'{addr}/{range}' + except ValueError: + pass # not a range + self.net = ipaddress.ip_network(ip_net) + + def matches(self, other: URL): + try: + return ipaddress.ip_address(other.host) in self.net + except ValueError: + return False + + @property + def priority(self) -> tuple[int, int, int]: + return -1, 0, 0 # higher priority than URLPatterns + + def __hash__(self) -> int: + return hash(self.net) + + def __lt__(self, other: URLPattern) -> bool: + return self.priority < other.priority + + def __eq__(self, other: typing.Any) -> bool: + return isinstance(other, IPNetPattern) and self.net == other.net + + +URLPattern = IPNetPattern | WildcardURLPattern + + +def build_url_pattern(pattern: str) -> URLPattern: + try: + proto, rest = pattern.split('://', 1) + if proto == 'all' and '/' in rest: + return IPNetPattern(rest) + except ValueError: # covers .split() and IPNetPattern + pass + return WildcardURLPattern(pattern) def is_ipv4_hostname(hostname: str) -> bool: @@ -245,4 +302,4 @@ def is_ipv6_hostname(hostname: str) -> bool: ipaddress.IPv6Address(hostname.split("/")[0]) except Exception: return False - return True + return True \ No newline at end of file diff --git a/tests/test_utils.py b/tests/test_utils.py index 276db172..046818a2 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -6,7 +6,7 @@ import random import pytest import httpx -from httpx._utils import URLPattern, get_environment_proxies +from httpx._utils import build_url_pattern, get_environment_proxies @pytest.mark.parametrize( @@ -128,24 +128,30 @@ def test_get_environment_proxies(environment, proxies): ("http://", "https://example.com", False), ("all://", "https://example.com:123", True), ("", "https://example.com:123", True), + ('all://192.168.0.0/24', 'http://192.168.0.1', True), + ('all://192.168.0.0/24', 'https://192.168.1.1', False), + ('all://[2001:db8:abcd:0012::]/64', 'http://[2001:db8:abcd:12::1]', True), + ('all://[2001:db8:abcd:0012::]/64', 'http://[2001:db8:abcd:13::1]:8080', False), ], ) def test_url_matches(pattern, url, expected): - pattern = URLPattern(pattern) + pattern = build_url_pattern(pattern) assert pattern.matches(httpx.URL(url)) == expected def test_pattern_priority(): matchers = [ - URLPattern("all://"), - URLPattern("http://"), - URLPattern("http://example.com"), - URLPattern("http://example.com:123"), + build_url_pattern("all://"), + build_url_pattern("http://"), + build_url_pattern("http://example.com"), + build_url_pattern("http://example.com:123"), + build_url_pattern("192.168.0.1/16"), ] random.shuffle(matchers) assert sorted(matchers) == [ - URLPattern("http://example.com:123"), - URLPattern("http://example.com"), - URLPattern("http://"), - URLPattern("all://"), + build_url_pattern("192.168.0.1/16"), + build_url_pattern("http://example.com:123"), + build_url_pattern("http://example.com"), + build_url_pattern("http://"), + build_url_pattern("all://"), ]