PYTHON-2560 Retry KMS requests on transient errors (#2024)

This commit is contained in:
Shane Harvey 2024-12-03 16:16:47 -08:00 committed by GitHub
parent ce1c49a668
commit ff2f95987f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 260 additions and 40 deletions

View File

@ -19,6 +19,7 @@ import asyncio
import contextlib
import enum
import socket
import time as time # noqa: PLC0414 # needed in sync version
import uuid
import weakref
from copy import deepcopy
@ -63,7 +64,11 @@ from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.cursor import AsyncCursor
from pymongo.asynchronous.database import AsyncDatabase
from pymongo.asynchronous.mongo_client import AsyncMongoClient
from pymongo.asynchronous.pool import _configured_socket, _raise_connection_failure
from pymongo.asynchronous.pool import (
_configured_socket,
_get_timeout_details,
_raise_connection_failure,
)
from pymongo.common import CONNECT_TIMEOUT
from pymongo.daemon import _spawn_daemon
from pymongo.encryption_options import AutoEncryptionOpts, RangeOpts
@ -72,7 +77,7 @@ from pymongo.errors import (
EncryptedCollectionError,
EncryptionError,
InvalidOperation,
PyMongoError,
NetworkTimeout,
ServerSelectionTimeoutError,
)
from pymongo.network_layer import BLOCKING_IO_ERRORS, async_sendall
@ -88,6 +93,9 @@ from pymongo.write_concern import WriteConcern
if TYPE_CHECKING:
from pymongocrypt.mongocrypt import MongoCryptKmsContext
from pymongo.pyopenssl_context import _sslConn
from pymongo.typings import _Address
_IS_SYNC = False
@ -103,6 +111,13 @@ _DATA_KEY_OPTS: CodecOptions[dict[str, Any]] = CodecOptions(
_KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument)
async def _connect_kms(address: _Address, opts: PoolOptions) -> Union[socket.socket, _sslConn]:
try:
return await _configured_socket(address, opts)
except Exception as exc:
_raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts))
@contextlib.contextmanager
def _wrap_encryption_errors() -> Iterator[None]:
"""Context manager to wrap encryption related errors."""
@ -166,8 +181,8 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
None, # crlfile
False, # allow_invalid_certificates
False, # allow_invalid_hostnames
False,
) # disable_ocsp_endpoint_check
False, # disable_ocsp_endpoint_check
)
# CSOT: set timeout for socket creation.
connect_timeout = max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0.001)
opts = PoolOptions(
@ -175,9 +190,13 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
socket_timeout=connect_timeout,
ssl_context=ctx,
)
host, port = parse_host(endpoint, _HTTPS_PORT)
address = parse_host(endpoint, _HTTPS_PORT)
sleep_u = kms_context.usleep
if sleep_u:
sleep_sec = float(sleep_u) / 1e6
await asyncio.sleep(sleep_sec)
try:
conn = await _configured_socket((host, port), opts)
conn = await _connect_kms(address, opts)
try:
await async_sendall(conn, message)
while kms_context.bytes_needed > 0:
@ -194,20 +213,29 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
if not data:
raise OSError("KMS connection closed")
kms_context.feed(data)
# Async raises an OSError instead of returning empty bytes
except OSError as err:
raise OSError("KMS connection closed") from err
except BLOCKING_IO_ERRORS:
raise socket.timeout("timed out") from None
except MongoCryptError:
raise # Propagate MongoCryptError errors directly.
except Exception as exc:
# Wrap I/O errors in PyMongo exceptions.
if isinstance(exc, BLOCKING_IO_ERRORS):
exc = socket.timeout("timed out")
_raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts))
finally:
conn.close()
except (PyMongoError, MongoCryptError):
raise # Propagate pymongo errors directly.
except asyncio.CancelledError:
raise
except Exception as error:
# Wrap I/O errors in PyMongo exceptions.
_raise_connection_failure((host, port), error)
except MongoCryptError:
raise # Propagate MongoCryptError errors directly.
except Exception as exc:
remaining = _csot.remaining()
if isinstance(exc, NetworkTimeout) or (remaining is not None and remaining <= 0):
raise
# Mark this attempt as failed and defer to libmongocrypt to retry.
try:
kms_context.fail()
except MongoCryptError as final_err:
exc = MongoCryptError(
f"{final_err}, last attempt failed with: {exc}", final_err.code
)
raise exc from final_err
async def collection_info(self, database: str, filter: bytes) -> Optional[bytes]:
"""Get the collection info for a namespace.

View File

@ -19,6 +19,7 @@ import asyncio
import contextlib
import enum
import socket
import time as time # noqa: PLC0414 # needed in sync version
import uuid
import weakref
from copy import deepcopy
@ -67,7 +68,7 @@ from pymongo.errors import (
EncryptedCollectionError,
EncryptionError,
InvalidOperation,
PyMongoError,
NetworkTimeout,
ServerSelectionTimeoutError,
)
from pymongo.network_layer import BLOCKING_IO_ERRORS, sendall
@ -80,7 +81,11 @@ from pymongo.synchronous.collection import Collection
from pymongo.synchronous.cursor import Cursor
from pymongo.synchronous.database import Database
from pymongo.synchronous.mongo_client import MongoClient
from pymongo.synchronous.pool import _configured_socket, _raise_connection_failure
from pymongo.synchronous.pool import (
_configured_socket,
_get_timeout_details,
_raise_connection_failure,
)
from pymongo.typings import _DocumentType, _DocumentTypeArg
from pymongo.uri_parser import parse_host
from pymongo.write_concern import WriteConcern
@ -88,6 +93,9 @@ from pymongo.write_concern import WriteConcern
if TYPE_CHECKING:
from pymongocrypt.mongocrypt import MongoCryptKmsContext
from pymongo.pyopenssl_context import _sslConn
from pymongo.typings import _Address
_IS_SYNC = True
@ -103,6 +111,13 @@ _DATA_KEY_OPTS: CodecOptions[dict[str, Any]] = CodecOptions(
_KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument)
def _connect_kms(address: _Address, opts: PoolOptions) -> Union[socket.socket, _sslConn]:
try:
return _configured_socket(address, opts)
except Exception as exc:
_raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts))
@contextlib.contextmanager
def _wrap_encryption_errors() -> Iterator[None]:
"""Context manager to wrap encryption related errors."""
@ -166,8 +181,8 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore[misc]
None, # crlfile
False, # allow_invalid_certificates
False, # allow_invalid_hostnames
False,
) # disable_ocsp_endpoint_check
False, # disable_ocsp_endpoint_check
)
# CSOT: set timeout for socket creation.
connect_timeout = max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0.001)
opts = PoolOptions(
@ -175,9 +190,13 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore[misc]
socket_timeout=connect_timeout,
ssl_context=ctx,
)
host, port = parse_host(endpoint, _HTTPS_PORT)
address = parse_host(endpoint, _HTTPS_PORT)
sleep_u = kms_context.usleep
if sleep_u:
sleep_sec = float(sleep_u) / 1e6
time.sleep(sleep_sec)
try:
conn = _configured_socket((host, port), opts)
conn = _connect_kms(address, opts)
try:
sendall(conn, message)
while kms_context.bytes_needed > 0:
@ -194,20 +213,29 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore[misc]
if not data:
raise OSError("KMS connection closed")
kms_context.feed(data)
# Async raises an OSError instead of returning empty bytes
except OSError as err:
raise OSError("KMS connection closed") from err
except BLOCKING_IO_ERRORS:
raise socket.timeout("timed out") from None
except MongoCryptError:
raise # Propagate MongoCryptError errors directly.
except Exception as exc:
# Wrap I/O errors in PyMongo exceptions.
if isinstance(exc, BLOCKING_IO_ERRORS):
exc = socket.timeout("timed out")
_raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts))
finally:
conn.close()
except (PyMongoError, MongoCryptError):
raise # Propagate pymongo errors directly.
except asyncio.CancelledError:
raise
except Exception as error:
# Wrap I/O errors in PyMongo exceptions.
_raise_connection_failure((host, port), error)
except MongoCryptError:
raise # Propagate MongoCryptError errors directly.
except Exception as exc:
remaining = _csot.remaining()
if isinstance(exc, NetworkTimeout) or (remaining is not None and remaining <= 0):
raise
# Mark this attempt as failed and defer to libmongocrypt to retry.
try:
kms_context.fail()
except MongoCryptError as final_err:
exc = MongoCryptError(
f"{final_err}, last attempt failed with: {exc}", final_err.code
)
raise exc from final_err
def collection_info(self, database: str, filter: bytes) -> Optional[bytes]:
"""Get the collection info for a namespace.

View File

@ -17,6 +17,8 @@ from __future__ import annotations
import base64
import copy
import http.client
import json
import os
import pathlib
import re
@ -91,6 +93,7 @@ from pymongo.errors import (
WriteError,
)
from pymongo.operations import InsertOne, ReplaceOne, UpdateOne
from pymongo.ssl_support import get_ssl_context
from pymongo.write_concern import WriteConcern
_IS_SYNC = False
@ -1366,9 +1369,8 @@ class TestCustomEndpoint(AsyncEncryptionIntegrationTest):
"key": ("arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-4bd9-9f25-e30687b580d0"),
"endpoint": "kms.us-east-1.amazonaws.com:12345",
}
with self.assertRaisesRegex(EncryptionError, "kms.us-east-1.amazonaws.com:12345") as ctx:
with self.assertRaisesRegex(EncryptionError, "kms.us-east-1.amazonaws.com:12345"):
await self.client_encryption.create_data_key("aws", master_key=master_key)
self.assertIsInstance(ctx.exception.cause, AutoReconnect)
@unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set")
async def test_05_aws_endpoint_wrong_region(self):
@ -2853,6 +2855,86 @@ class TestRangeQueryDefaultsProse(AsyncEncryptionIntegrationTest):
assert len(payload) > len(self.payload_defaults)
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#24-kms-retry-tests
class TestKmsRetryProse(AsyncEncryptionIntegrationTest):
@unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set")
async def asyncSetUp(self):
await super().asyncSetUp()
# 1, create client with only tlsCAFile.
providers: dict = copy.deepcopy(ALL_KMS_PROVIDERS)
providers["azure"]["identityPlatformEndpoint"] = "127.0.0.1:9003"
providers["gcp"]["endpoint"] = "127.0.0.1:9003"
kms_tls_opts = {
p: {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM} for p in providers
}
self.client_encryption = self.create_client_encryption(
providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts
)
async def http_post(self, path, data=None):
# Note, the connection to the mock server needs to be closed after
# each request because the server is single threaded.
ctx: ssl.SSLContext = get_ssl_context(
CLIENT_PEM, # certfile
None, # passphrase
CA_PEM, # ca_certs
None, # crlfile
False, # allow_invalid_certificates
False, # allow_invalid_hostnames
False, # disable_ocsp_endpoint_check
)
conn = http.client.HTTPSConnection("127.0.0.1:9003", context=ctx)
try:
if data is not None:
headers = {"Content-type": "application/json"}
body = json.dumps(data)
else:
headers = {}
body = None
conn.request("POST", path, body, headers)
res = conn.getresponse()
res.read()
finally:
conn.close()
async def _test(self, provider, master_key):
await self.http_post("/reset")
# Case 1: createDataKey and encrypt with TCP retry
await self.http_post("/set_failpoint/network", {"count": 1})
key_id = await self.client_encryption.create_data_key(provider, master_key=master_key)
await self.http_post("/set_failpoint/network", {"count": 1})
await self.client_encryption.encrypt(
123, Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id
)
# Case 2: createDataKey and encrypt with HTTP retry
await self.http_post("/set_failpoint/http", {"count": 1})
key_id = await self.client_encryption.create_data_key(provider, master_key=master_key)
await self.http_post("/set_failpoint/http", {"count": 1})
await self.client_encryption.encrypt(
123, Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id
)
# Case 3: createDataKey fails after too many retries
await self.http_post("/set_failpoint/network", {"count": 4})
with self.assertRaisesRegex(EncryptionError, "KMS request failed after"):
await self.client_encryption.create_data_key(provider, master_key=master_key)
async def test_kms_retry(self):
await self._test("aws", {"region": "foo", "key": "bar", "endpoint": "127.0.0.1:9003"})
await self._test("azure", {"keyVaultEndpoint": "127.0.0.1:9003", "keyName": "foo"})
await self._test(
"gcp",
{
"projectId": "foo",
"location": "bar",
"keyRing": "baz",
"keyName": "qux",
"endpoint": "127.0.0.1:9003",
},
)
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#automatic-data-encryption-keys
class TestAutomaticDecryptionKeys(AsyncEncryptionIntegrationTest):
@async_client_context.require_no_standalone

View File

@ -17,6 +17,8 @@ from __future__ import annotations
import base64
import copy
import http.client
import json
import os
import pathlib
import re
@ -88,6 +90,7 @@ from pymongo.errors import (
WriteError,
)
from pymongo.operations import InsertOne, ReplaceOne, UpdateOne
from pymongo.ssl_support import get_ssl_context
from pymongo.synchronous import encryption
from pymongo.synchronous.encryption import Algorithm, ClientEncryption, QueryType
from pymongo.synchronous.mongo_client import MongoClient
@ -1360,9 +1363,8 @@ class TestCustomEndpoint(EncryptionIntegrationTest):
"key": ("arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-4bd9-9f25-e30687b580d0"),
"endpoint": "kms.us-east-1.amazonaws.com:12345",
}
with self.assertRaisesRegex(EncryptionError, "kms.us-east-1.amazonaws.com:12345") as ctx:
with self.assertRaisesRegex(EncryptionError, "kms.us-east-1.amazonaws.com:12345"):
self.client_encryption.create_data_key("aws", master_key=master_key)
self.assertIsInstance(ctx.exception.cause, AutoReconnect)
@unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set")
def test_05_aws_endpoint_wrong_region(self):
@ -2835,6 +2837,86 @@ class TestRangeQueryDefaultsProse(EncryptionIntegrationTest):
assert len(payload) > len(self.payload_defaults)
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#24-kms-retry-tests
class TestKmsRetryProse(EncryptionIntegrationTest):
@unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set")
def setUp(self):
super().setUp()
# 1, create client with only tlsCAFile.
providers: dict = copy.deepcopy(ALL_KMS_PROVIDERS)
providers["azure"]["identityPlatformEndpoint"] = "127.0.0.1:9003"
providers["gcp"]["endpoint"] = "127.0.0.1:9003"
kms_tls_opts = {
p: {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM} for p in providers
}
self.client_encryption = self.create_client_encryption(
providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts
)
def http_post(self, path, data=None):
# Note, the connection to the mock server needs to be closed after
# each request because the server is single threaded.
ctx: ssl.SSLContext = get_ssl_context(
CLIENT_PEM, # certfile
None, # passphrase
CA_PEM, # ca_certs
None, # crlfile
False, # allow_invalid_certificates
False, # allow_invalid_hostnames
False, # disable_ocsp_endpoint_check
)
conn = http.client.HTTPSConnection("127.0.0.1:9003", context=ctx)
try:
if data is not None:
headers = {"Content-type": "application/json"}
body = json.dumps(data)
else:
headers = {}
body = None
conn.request("POST", path, body, headers)
res = conn.getresponse()
res.read()
finally:
conn.close()
def _test(self, provider, master_key):
self.http_post("/reset")
# Case 1: createDataKey and encrypt with TCP retry
self.http_post("/set_failpoint/network", {"count": 1})
key_id = self.client_encryption.create_data_key(provider, master_key=master_key)
self.http_post("/set_failpoint/network", {"count": 1})
self.client_encryption.encrypt(
123, Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id
)
# Case 2: createDataKey and encrypt with HTTP retry
self.http_post("/set_failpoint/http", {"count": 1})
key_id = self.client_encryption.create_data_key(provider, master_key=master_key)
self.http_post("/set_failpoint/http", {"count": 1})
self.client_encryption.encrypt(
123, Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id
)
# Case 3: createDataKey fails after too many retries
self.http_post("/set_failpoint/network", {"count": 4})
with self.assertRaisesRegex(EncryptionError, "KMS request failed after"):
self.client_encryption.create_data_key(provider, master_key=master_key)
def test_kms_retry(self):
self._test("aws", {"region": "foo", "key": "bar", "endpoint": "127.0.0.1:9003"})
self._test("azure", {"keyVaultEndpoint": "127.0.0.1:9003", "keyName": "foo"})
self._test(
"gcp",
{
"projectId": "foo",
"location": "bar",
"keyRing": "baz",
"keyName": "qux",
"endpoint": "127.0.0.1:9003",
},
)
# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#automatic-data-encryption-keys
class TestAutomaticDecryptionKeys(EncryptionIntegrationTest):
@client_context.require_no_standalone