PYTHON-5021 - Fix usages of getaddrinfo to be non-blocking (#2059)
This commit is contained in:
parent
8fa6750a7e
commit
e4d84494c3
@ -38,6 +38,7 @@ from pymongo.asynchronous.auth_oidc import (
|
||||
_authenticate_oidc,
|
||||
_get_authenticator,
|
||||
)
|
||||
from pymongo.asynchronous.helpers import _getaddrinfo
|
||||
from pymongo.auth_shared import (
|
||||
MongoCredential,
|
||||
_authenticate_scram_start,
|
||||
@ -177,15 +178,22 @@ def _auth_key(nonce: str, username: str, password: str) -> str:
|
||||
return md5hash.hexdigest()
|
||||
|
||||
|
||||
def _canonicalize_hostname(hostname: str, option: str | bool) -> str:
|
||||
async def _canonicalize_hostname(hostname: str, option: str | bool) -> str:
|
||||
"""Canonicalize hostname following MIT-krb5 behavior."""
|
||||
# https://github.com/krb5/krb5/blob/d406afa363554097ac48646a29249c04f498c88e/src/util/k5test.py#L505-L520
|
||||
if option in [False, "none"]:
|
||||
return hostname
|
||||
|
||||
af, socktype, proto, canonname, sockaddr = socket.getaddrinfo(
|
||||
hostname, None, 0, 0, socket.IPPROTO_TCP, socket.AI_CANONNAME
|
||||
)[0]
|
||||
af, socktype, proto, canonname, sockaddr = (
|
||||
await _getaddrinfo(
|
||||
hostname,
|
||||
None,
|
||||
family=0,
|
||||
type=0,
|
||||
proto=socket.IPPROTO_TCP,
|
||||
flags=socket.AI_CANONNAME,
|
||||
)
|
||||
)[0] # type: ignore[index]
|
||||
|
||||
# For forward just to resolve the cname as dns.lookup() will not return it.
|
||||
if option == "forward":
|
||||
@ -213,7 +221,7 @@ async def _authenticate_gssapi(credentials: MongoCredential, conn: AsyncConnecti
|
||||
# Starting here and continuing through the while loop below - establish
|
||||
# the security context. See RFC 4752, Section 3.1, first paragraph.
|
||||
host = props.service_host or conn.address[0]
|
||||
host = _canonicalize_hostname(host, props.canonicalize_host_name)
|
||||
host = await _canonicalize_hostname(host, props.canonicalize_host_name)
|
||||
service = props.service_name + "@" + host
|
||||
if props.service_realm is not None:
|
||||
service = service + "@" + props.service_realm
|
||||
|
||||
@ -15,7 +15,9 @@
|
||||
"""Miscellaneous pieces that need to be synchronized."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import builtins
|
||||
import socket
|
||||
import sys
|
||||
from typing import (
|
||||
Any,
|
||||
@ -68,6 +70,24 @@ def _handle_reauth(func: F) -> F:
|
||||
return cast(F, inner)
|
||||
|
||||
|
||||
async def _getaddrinfo(
|
||||
host: Any, port: Any, **kwargs: Any
|
||||
) -> list[
|
||||
tuple[
|
||||
socket.AddressFamily,
|
||||
socket.SocketKind,
|
||||
int,
|
||||
str,
|
||||
tuple[str, int] | tuple[str, int, int, int],
|
||||
]
|
||||
]:
|
||||
if not _IS_SYNC:
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.getaddrinfo(host, port, **kwargs) # type: ignore[return-value]
|
||||
else:
|
||||
return socket.getaddrinfo(host, port, **kwargs)
|
||||
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
anext = builtins.anext
|
||||
aiter = builtins.aiter
|
||||
|
||||
@ -40,7 +40,7 @@ from typing import (
|
||||
from bson import DEFAULT_CODEC_OPTIONS
|
||||
from pymongo import _csot, helpers_shared
|
||||
from pymongo.asynchronous.client_session import _validate_session_write_concern
|
||||
from pymongo.asynchronous.helpers import _handle_reauth
|
||||
from pymongo.asynchronous.helpers import _getaddrinfo, _handle_reauth
|
||||
from pymongo.asynchronous.network import command, receive_message
|
||||
from pymongo.common import (
|
||||
MAX_BSON_SIZE,
|
||||
@ -783,7 +783,7 @@ class AsyncConnection:
|
||||
)
|
||||
|
||||
|
||||
def _create_connection(address: _Address, options: PoolOptions) -> socket.socket:
|
||||
async def _create_connection(address: _Address, options: PoolOptions) -> socket.socket:
|
||||
"""Given (host, port) and PoolOptions, connect and return a socket object.
|
||||
|
||||
Can raise socket.error.
|
||||
@ -814,7 +814,7 @@ def _create_connection(address: _Address, options: PoolOptions) -> socket.socket
|
||||
family = socket.AF_UNSPEC
|
||||
|
||||
err = None
|
||||
for res in socket.getaddrinfo(host, port, family, socket.SOCK_STREAM):
|
||||
for res in await _getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM): # type: ignore[attr-defined]
|
||||
af, socktype, proto, dummy, sa = res
|
||||
# SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited
|
||||
# number of platforms (newer Linux and *BSD). Starting with CPython 3.4
|
||||
@ -863,7 +863,7 @@ async def _configured_socket(
|
||||
|
||||
Sets socket's SSL and timeout options.
|
||||
"""
|
||||
sock = _create_connection(address, options)
|
||||
sock = await _create_connection(address, options)
|
||||
ssl_context = options._ssl_context
|
||||
|
||||
if ssl_context is None:
|
||||
|
||||
@ -45,6 +45,7 @@ from pymongo.synchronous.auth_oidc import (
|
||||
_authenticate_oidc,
|
||||
_get_authenticator,
|
||||
)
|
||||
from pymongo.synchronous.helpers import _getaddrinfo
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.hello import Hello
|
||||
@ -180,9 +181,16 @@ def _canonicalize_hostname(hostname: str, option: str | bool) -> str:
|
||||
if option in [False, "none"]:
|
||||
return hostname
|
||||
|
||||
af, socktype, proto, canonname, sockaddr = socket.getaddrinfo(
|
||||
hostname, None, 0, 0, socket.IPPROTO_TCP, socket.AI_CANONNAME
|
||||
)[0]
|
||||
af, socktype, proto, canonname, sockaddr = (
|
||||
_getaddrinfo(
|
||||
hostname,
|
||||
None,
|
||||
family=0,
|
||||
type=0,
|
||||
proto=socket.IPPROTO_TCP,
|
||||
flags=socket.AI_CANONNAME,
|
||||
)
|
||||
)[0] # type: ignore[index]
|
||||
|
||||
# For forward just to resolve the cname as dns.lookup() will not return it.
|
||||
if option == "forward":
|
||||
|
||||
@ -15,7 +15,9 @@
|
||||
"""Miscellaneous pieces that need to be synchronized."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import builtins
|
||||
import socket
|
||||
import sys
|
||||
from typing import (
|
||||
Any,
|
||||
@ -68,6 +70,24 @@ def _handle_reauth(func: F) -> F:
|
||||
return cast(F, inner)
|
||||
|
||||
|
||||
def _getaddrinfo(
|
||||
host: Any, port: Any, **kwargs: Any
|
||||
) -> list[
|
||||
tuple[
|
||||
socket.AddressFamily,
|
||||
socket.SocketKind,
|
||||
int,
|
||||
str,
|
||||
tuple[str, int] | tuple[str, int, int, int],
|
||||
]
|
||||
]:
|
||||
if not _IS_SYNC:
|
||||
loop = asyncio.get_running_loop()
|
||||
return loop.getaddrinfo(host, port, **kwargs) # type: ignore[return-value]
|
||||
else:
|
||||
return socket.getaddrinfo(host, port, **kwargs)
|
||||
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
next = builtins.next
|
||||
iter = builtins.iter
|
||||
|
||||
@ -84,7 +84,7 @@ from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.socket_checker import SocketChecker
|
||||
from pymongo.ssl_support import HAS_SNI, SSLError
|
||||
from pymongo.synchronous.client_session import _validate_session_write_concern
|
||||
from pymongo.synchronous.helpers import _handle_reauth
|
||||
from pymongo.synchronous.helpers import _getaddrinfo, _handle_reauth
|
||||
from pymongo.synchronous.network import command, receive_message
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -812,7 +812,7 @@ def _create_connection(address: _Address, options: PoolOptions) -> socket.socket
|
||||
family = socket.AF_UNSPEC
|
||||
|
||||
err = None
|
||||
for res in socket.getaddrinfo(host, port, family, socket.SOCK_STREAM):
|
||||
for res in _getaddrinfo(host, port, family=family, type=socket.SOCK_STREAM): # type: ignore[attr-defined]
|
||||
af, socktype, proto, dummy, sa = res
|
||||
# SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited
|
||||
# number of platforms (newer Linux and *BSD). Starting with CPython 3.4
|
||||
|
||||
@ -275,10 +275,10 @@ class TestGSSAPI(AsyncPyMongoTestCase):
|
||||
async def test_gssapi_canonicalize_host_name(self):
|
||||
# Test the low level method.
|
||||
assert GSSAPI_HOST is not None
|
||||
result = _canonicalize_hostname(GSSAPI_HOST, "forward")
|
||||
result = await _canonicalize_hostname(GSSAPI_HOST, "forward")
|
||||
if "compute-1.amazonaws.com" not in result:
|
||||
self.assertEqual(result, GSSAPI_HOST)
|
||||
result = _canonicalize_hostname(GSSAPI_HOST, "forwardAndReverse")
|
||||
result = await _canonicalize_hostname(GSSAPI_HOST, "forwardAndReverse")
|
||||
self.assertEqual(result, GSSAPI_HOST)
|
||||
|
||||
# Use the equivalent named CANONICALIZE_HOST_NAME.
|
||||
|
||||
Loading…
Reference in New Issue
Block a user