PYTHON-4260 Lazily load optional imports (#1550)

This commit is contained in:
Steven Silvester 2024-03-25 12:55:41 -05:00 committed by GitHub
parent 5e49363c97
commit 42a08c4a34
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 201 additions and 92 deletions

View File

@ -2071,6 +2071,21 @@ tasks:
bash $SCRIPT -p $CONFIG -h ${github_commit} -o "mongodb" -n "mongo-python-driver" bash $SCRIPT -p $CONFIG -h ${github_commit} -o "mongodb" -n "mongo-python-driver"
echo '{"results": [{ "status": "PASS", "test_file": "Build", "log_raw": "Test completed" } ]}' > ${PROJECT_DIRECTORY}/test-results.json echo '{"results": [{ "status": "PASS", "test_file": "Build", "log_raw": "Test completed" } ]}' > ${PROJECT_DIRECTORY}/test-results.json
- name: "check-import-time"
tags: ["pr"]
commands:
- command: shell.exec
type: test
params:
shell: "bash"
working_dir: src
script: |
${PREPARE_SHELL}
set -x
export BASE_SHA=${revision}
export HEAD_SHA=${github_commit}
bash .evergreen/run-import-time-test.sh
axes: axes:
# Choice of distro # Choice of distro
- id: platform - id: platform
@ -3046,6 +3061,12 @@ buildvariants:
tasks: tasks:
- name: "assign-pr-reviewer" - name: "assign-pr-reviewer"
- name: rhel8-import-time
display_name: Import Time Check
run_on: rhel87-small
tasks:
- name: "check-import-time"
- name: Release - name: Release
display_name: Release display_name: Release
batchtime: 20160 # 14 days batchtime: 20160 # 14 days

View File

@ -0,0 +1,31 @@
#!/bin/bash -ex
set -o errexit # Exit the script with error if any of the commands fail
set -x
. .evergreen/utils.sh
if [ -z "$PYTHON_BINARY" ]; then
PYTHON_BINARY=$(find_python3)
fi
# Use the previous commit if this was not a PR run.
if [ "$BASE_SHA" == "$HEAD_SHA" ]; then
BASE_SHA=$(git rev-parse HEAD~1)
fi
function get_import_time() {
local log_file
createvirtualenv "$PYTHON_BINARY" import-venv
python -m pip install -q ".[aws,encryption,gssapi,ocsp,snappy,zstd]"
# Import once to cache modules
python -c "import pymongo"
log_file="pymongo-$1.log"
python -X importtime -c "import pymongo" 2> $log_file
}
get_import_time $HEAD_SHA
git checkout $BASE_SHA
get_import_time $BASE_SHA
git checkout $HEAD_SHA
python tools/compare_import_time.py $HEAD_SHA $BASE_SHA

View File

@ -248,7 +248,9 @@ if [ -n "$COVERAGE" ] && [ "$PYTHON_IMPL" = "CPython" ]; then
fi fi
if [ -n "$GREEN_FRAMEWORK" ]; then if [ -n "$GREEN_FRAMEWORK" ]; then
python -m pip install $GREEN_FRAMEWORK # Install all optional deps to ensure lazy imports are getting patched.
python -m pip install -q ".[aws,encryption,gssapi,ocsp,snappy,zstd]"
python -m pip install $GREEN_FRAMEWORK
fi fi
# Show the installed packages # Show the installed packages

View File

@ -74,6 +74,12 @@ Unavoidable breaking changes
>>> dict_to_SON(data_as_dict) >>> dict_to_SON(data_as_dict)
SON([('driver', SON([('name', 'PyMongo'), ('version', '4.7.0.dev0')])), ('os', SON([('type', 'Darwin'), ('name', 'Darwin'), ('architecture', 'arm64'), ('version', '14.3')])), ('platform', 'CPython 3.11.6.final.0')]) SON([('driver', SON([('name', 'PyMongo'), ('version', '4.7.0.dev0')])), ('os', SON([('type', 'Darwin'), ('name', 'Darwin'), ('architecture', 'arm64'), ('version', '14.3')])), ('platform', 'CPython 3.11.6.final.0')])
- PyMongo now uses `lazy imports <https://docs.python.org/3/library/importlib.html#implementing-lazy-imports>`_ for external dependencies.
If you are relying on any kind of monkey-patching of the standard library, you may need to explicitly import those external libraries in addition
to ``pymongo`` before applying the patch. Note that we test with ``gevent`` and ``eventlet`` patching, and those continue to work.
- The "aws" extra now requires minimum version of ``1.1.0`` for ``pymongo_auth_aws``.
Changes in Version 4.6.2 Changes in Version 4.6.2
------------------------ ------------------------

View File

@ -17,12 +17,14 @@ from __future__ import annotations
import json import json
from typing import Any, Optional from typing import Any, Optional
from urllib.request import Request, urlopen
def _get_azure_response( def _get_azure_response(
resource: str, client_id: Optional[str] = None, timeout: float = 5 resource: str, client_id: Optional[str] = None, timeout: float = 5
) -> dict[str, Any]: ) -> dict[str, Any]:
# Deferred import to save overall import time.
from urllib.request import Request, urlopen
url = "http://169.254.169.254/metadata/identity/oauth2/token" url = "http://169.254.169.254/metadata/identity/oauth2/token"
url += "?api-version=2018-02-01" url += "?api-version=2018-02-01"
url += f"&resource={resource}" url += f"&resource={resource}"

View File

@ -21,9 +21,10 @@ import time
from collections import deque from collections import deque
from contextlib import AbstractContextManager from contextlib import AbstractContextManager
from contextvars import ContextVar, Token from contextvars import ContextVar, Token
from typing import Any, Callable, Deque, MutableMapping, Optional, TypeVar, cast from typing import TYPE_CHECKING, Any, Callable, Deque, MutableMapping, Optional, TypeVar, cast
from pymongo.write_concern import WriteConcern if TYPE_CHECKING:
from pymongo.write_concern import WriteConcern
TIMEOUT: ContextVar[Optional[float]] = ContextVar("TIMEOUT", default=None) TIMEOUT: ContextVar[Optional[float]] = ContextVar("TIMEOUT", default=None)
RTT: ContextVar[float] = ContextVar("RTT", default=0.0) RTT: ContextVar[float] = ContextVar("RTT", default=0.0)

38
pymongo/_lazy_import.py Normal file
View File

@ -0,0 +1,38 @@
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you
# may not use this file except in compliance with the License. You
# may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
from __future__ import annotations
import importlib.util
import sys
from types import ModuleType
def lazy_import(name: str) -> ModuleType:
"""Lazily import a module by name
From https://docs.python.org/3/library/importlib.html#implementing-lazy-imports
"""
try:
spec = importlib.util.find_spec(name)
except ValueError:
raise ModuleNotFoundError(name=name) from None
if spec is None:
raise ModuleNotFoundError(name=name)
assert spec is not None
loader = importlib.util.LazyLoader(spec.loader) # type:ignore[arg-type]
spec.loader = loader
module = importlib.util.module_from_spec(spec)
sys.modules[name] = module
loader.exec_module(module)
return module

View File

@ -15,38 +15,16 @@
"""MONGODB-AWS Authentication helpers.""" """MONGODB-AWS Authentication helpers."""
from __future__ import annotations from __future__ import annotations
try: from pymongo._lazy_import import lazy_import
import pymongo_auth_aws # type:ignore[import]
from pymongo_auth_aws import (
AwsCredential,
AwsSaslContext,
PyMongoAuthAwsError,
)
try:
pymongo_auth_aws = lazy_import("pymongo_auth_aws")
_HAVE_MONGODB_AWS = True _HAVE_MONGODB_AWS = True
except ImportError: except ImportError:
class AwsSaslContext: # type: ignore
def __init__(self, credentials: MongoCredential):
pass
_HAVE_MONGODB_AWS = False _HAVE_MONGODB_AWS = False
try:
from pymongo_auth_aws.auth import ( # type:ignore[import]
set_cached_credentials,
set_use_cached_credentials,
)
# Enable credential caching. from typing import TYPE_CHECKING, Any, Mapping, Type
set_use_cached_credentials(True)
except ImportError:
def set_cached_credentials(_creds: Optional[AwsCredential]) -> None:
pass
from typing import TYPE_CHECKING, Any, Mapping, Optional, Type
import bson import bson
from bson.binary import Binary from bson.binary import Binary
@ -58,21 +36,6 @@ if TYPE_CHECKING:
from pymongo.pool import Connection from pymongo.pool import Connection
class _AwsSaslContext(AwsSaslContext): # type: ignore
# Dependency injection:
def binary_type(self) -> Type[Binary]:
"""Return the bson.binary.Binary type."""
return Binary
def bson_encode(self, doc: Mapping[str, Any]) -> bytes:
"""Encode a dictionary to BSON."""
return bson.encode(doc)
def bson_decode(self, data: _ReadableBuffer) -> Mapping[str, Any]:
"""Decode BSON to a dictionary."""
return bson.decode(data)
def _authenticate_aws(credentials: MongoCredential, conn: Connection) -> None: def _authenticate_aws(credentials: MongoCredential, conn: Connection) -> None:
"""Authenticate using MONGODB-AWS.""" """Authenticate using MONGODB-AWS."""
if not _HAVE_MONGODB_AWS: if not _HAVE_MONGODB_AWS:
@ -84,9 +47,23 @@ def _authenticate_aws(credentials: MongoCredential, conn: Connection) -> None:
if conn.max_wire_version < 9: if conn.max_wire_version < 9:
raise ConfigurationError("MONGODB-AWS authentication requires MongoDB version 4.4 or later") raise ConfigurationError("MONGODB-AWS authentication requires MongoDB version 4.4 or later")
class AwsSaslContext(pymongo_auth_aws.AwsSaslContext): # type: ignore
# Dependency injection:
def binary_type(self) -> Type[Binary]:
"""Return the bson.binary.Binary type."""
return Binary
def bson_encode(self, doc: Mapping[str, Any]) -> bytes:
"""Encode a dictionary to BSON."""
return bson.encode(doc)
def bson_decode(self, data: _ReadableBuffer) -> Mapping[str, Any]:
"""Decode BSON to a dictionary."""
return bson.decode(data)
try: try:
ctx = _AwsSaslContext( ctx = AwsSaslContext(
AwsCredential( pymongo_auth_aws.AwsCredential(
credentials.username, credentials.username,
credentials.password, credentials.password,
credentials.mechanism_properties.aws_session_token, credentials.mechanism_properties.aws_session_token,
@ -108,14 +85,14 @@ def _authenticate_aws(credentials: MongoCredential, conn: Connection) -> None:
if res["done"]: if res["done"]:
# SASL complete. # SASL complete.
break break
except PyMongoAuthAwsError as exc: except pymongo_auth_aws.PyMongoAuthAwsError as exc:
# Clear the cached credentials if we hit a failure in auth. # Clear the cached credentials if we hit a failure in auth.
set_cached_credentials(None) pymongo_auth_aws.set_cached_credentials(None)
# Convert to OperationFailure and include pymongo-auth-aws version. # Convert to OperationFailure and include pymongo-auth-aws version.
raise OperationFailure( raise OperationFailure(
f"{exc} (pymongo-auth-aws version {pymongo_auth_aws.__version__})" f"{exc} (pymongo-auth-aws version {pymongo_auth_aws.__version__})"
) from None ) from None
except Exception: except Exception:
# Clear the cached credentials if we hit a failure in auth. # Clear the cached credentials if we hit a failure in auth.
set_cached_credentials(None) pymongo_auth_aws.set_cached_credentials(None)
raise raise

View File

@ -16,16 +16,19 @@ from __future__ import annotations
import warnings import warnings
from typing import Any, Iterable, Optional, Union from typing import Any, Iterable, Optional, Union
try: from pymongo._lazy_import import lazy_import
import snappy # type:ignore[import] from pymongo.hello import HelloCompat
from pymongo.monitoring import _SENSITIVE_COMMANDS
try:
snappy = lazy_import("snappy")
_HAVE_SNAPPY = True _HAVE_SNAPPY = True
except ImportError: except ImportError:
# python-snappy isn't available. # python-snappy isn't available.
_HAVE_SNAPPY = False _HAVE_SNAPPY = False
try: try:
import zlib zlib = lazy_import("zlib")
_HAVE_ZLIB = True _HAVE_ZLIB = True
except ImportError: except ImportError:
@ -33,15 +36,11 @@ except ImportError:
_HAVE_ZLIB = False _HAVE_ZLIB = False
try: try:
from zstandard import ZstdCompressor, ZstdDecompressor zstandard = lazy_import("zstandard")
_HAVE_ZSTD = True _HAVE_ZSTD = True
except ImportError: except ImportError:
_HAVE_ZSTD = False _HAVE_ZSTD = False
from pymongo.hello import HelloCompat
from pymongo.monitoring import _SENSITIVE_COMMANDS
_SUPPORTED_COMPRESSORS = {"snappy", "zlib", "zstd"} _SUPPORTED_COMPRESSORS = {"snappy", "zlib", "zstd"}
_NO_COMPRESSION = {HelloCompat.CMD, HelloCompat.LEGACY_CMD} _NO_COMPRESSION = {HelloCompat.CMD, HelloCompat.LEGACY_CMD}
_NO_COMPRESSION.update(_SENSITIVE_COMMANDS) _NO_COMPRESSION.update(_SENSITIVE_COMMANDS)
@ -138,7 +137,7 @@ class ZstdContext:
def compress(data: bytes) -> bytes: def compress(data: bytes) -> bytes:
# ZstdCompressor is not thread safe. # ZstdCompressor is not thread safe.
# TODO: Use a pool? # TODO: Use a pool?
return ZstdCompressor().compress(data) return zstandard.ZstdCompressor().compress(data)
def decompress(data: bytes, compressor_id: int) -> bytes: def decompress(data: bytes, compressor_id: int) -> bytes:
@ -153,6 +152,6 @@ def decompress(data: bytes, compressor_id: int) -> bytes:
elif compressor_id == ZstdContext.compressor_id: elif compressor_id == ZstdContext.compressor_id:
# ZstdDecompressor is not thread safe. # ZstdDecompressor is not thread safe.
# TODO: Use a pool? # TODO: Use a pool?
return ZstdDecompressor().decompress(data) return zstandard.ZstdDecompressor().decompress(data)
else: else:
raise ValueError("Unknown compressorId %d" % (compressor_id,)) raise ValueError("Unknown compressorId %d" % (compressor_id,))

View File

@ -15,6 +15,7 @@
"""Exceptions raised by PyMongo.""" """Exceptions raised by PyMongo."""
from __future__ import annotations from __future__ import annotations
from ssl import SSLCertVerificationError as _CertificateError # noqa: F401
from typing import TYPE_CHECKING, Any, Iterable, Mapping, Optional, Sequence, Union from typing import TYPE_CHECKING, Any, Iterable, Mapping, Optional, Sequence, Union
from bson.errors import InvalidDocument from bson.errors import InvalidDocument
@ -22,17 +23,6 @@ from bson.errors import InvalidDocument
if TYPE_CHECKING: if TYPE_CHECKING:
from pymongo.typings import _DocumentOut from pymongo.typings import _DocumentOut
try:
# CPython 3.7+
from ssl import SSLCertVerificationError as _CertificateError
except ImportError:
try:
from ssl import CertificateError as _CertificateError
except ImportError:
class _CertificateError(ValueError): # type: ignore
pass
class PyMongoError(Exception): class PyMongoError(Exception):
"""Base class for all PyMongo exceptions.""" """Base class for all PyMongo exceptions."""

View File

@ -25,14 +25,10 @@ from errno import EINTR as _EINTR
from ipaddress import ip_address as _ip_address from ipaddress import ip_address as _ip_address
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union
from cryptography.x509 import load_der_x509_certificate as _load_der_x509_certificate
from OpenSSL import SSL as _SSL from OpenSSL import SSL as _SSL
from OpenSSL import crypto as _crypto from OpenSSL import crypto as _crypto
from service_identity import CertificateError as _SICertificateError
from service_identity import VerificationError as _SIVerificationError
from service_identity.pyopenssl import verify_hostname as _verify_hostname
from service_identity.pyopenssl import verify_ip_address as _verify_ip_address
from pymongo._lazy_import import lazy_import
from pymongo.errors import ConfigurationError as _ConfigurationError from pymongo.errors import ConfigurationError as _ConfigurationError
from pymongo.errors import _CertificateError # type:ignore[attr-defined] from pymongo.errors import _CertificateError # type:ignore[attr-defined]
from pymongo.ocsp_cache import _OCSPCache from pymongo.ocsp_cache import _OCSPCache
@ -41,6 +37,10 @@ from pymongo.socket_checker import SocketChecker as _SocketChecker
from pymongo.socket_checker import _errno_from_exception from pymongo.socket_checker import _errno_from_exception
from pymongo.write_concern import validate_boolean from pymongo.write_concern import validate_boolean
_x509 = lazy_import("cryptography.x509")
_service_identity = lazy_import("service_identity")
_service_identity_pyopenssl = lazy_import("service_identity.pyopenssl")
if TYPE_CHECKING: if TYPE_CHECKING:
from ssl import VerifyMode from ssl import VerifyMode
@ -340,7 +340,7 @@ class SSLContext:
if encoding == "x509_asn": if encoding == "x509_asn":
if trust is True or oid in trust: if trust is True or oid in trust:
cert_store.add_cert( cert_store.add_cert(
_crypto.X509.from_cryptography(_load_der_x509_certificate(cert)) _crypto.X509.from_cryptography(_x509.load_der_x509_certificate(cert))
) )
def load_default_certs(self) -> None: def load_default_certs(self) -> None:
@ -406,9 +406,12 @@ class SSLContext:
if self.check_hostname and server_hostname is not None: if self.check_hostname and server_hostname is not None:
try: try:
if _is_ip_address(server_hostname): if _is_ip_address(server_hostname):
_verify_ip_address(ssl_conn, server_hostname) _service_identity_pyopenssl.verify_ip_address(ssl_conn, server_hostname)
else: else:
_verify_hostname(ssl_conn, server_hostname) _service_identity_pyopenssl.verify_hostname(ssl_conn, server_hostname)
except (_SICertificateError, _SIVerificationError) as exc: except (
_service_identity.SICertificateError,
_service_identity.SIVerificationError,
) as exc:
raise _CertificateError(str(exc)) from None raise _CertificateError(str(exc)) from None
return ssl_conn return ssl_conn

View File

@ -17,18 +17,19 @@ from __future__ import annotations
import ipaddress import ipaddress
import random import random
from typing import Any, Optional, Union from typing import TYPE_CHECKING, Any, Optional, Union
try:
from dns import resolver
_HAVE_DNSPYTHON = True
except ImportError:
_HAVE_DNSPYTHON = False
from pymongo._lazy_import import lazy_import
from pymongo.common import CONNECT_TIMEOUT from pymongo.common import CONNECT_TIMEOUT
from pymongo.errors import ConfigurationError from pymongo.errors import ConfigurationError
if TYPE_CHECKING:
from dns import resolver
else:
resolver = lazy_import("dns.resolver")
_HAVE_DNSPYTHON = True
# dnspython can return bytes or str from various parts # dnspython can return bytes or str from various parts
# of its API depending on version. We always want str. # of its API depending on version. We always want str.

View File

@ -45,7 +45,7 @@ dependencies = [
[project.optional-dependencies] [project.optional-dependencies]
aws = [ aws = [
"pymongo-auth-aws<2.0.0", "pymongo-auth-aws>=1.1.0,<2.0.0",
] ]
encryption = [ encryption = [
"pymongo[aws]", "pymongo[aws]",
@ -207,6 +207,7 @@ dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?)|dummy.*)$"
"test/*.py" = ["PT", "E402", "PLW", "SIM", "E741", "PTH", "S", "B904", "E722", "T201", "test/*.py" = ["PT", "E402", "PLW", "SIM", "E741", "PTH", "S", "B904", "E722", "T201",
"RET", "ARG", "F405", "B028", "PGH001", "B018", "F403", "RUF015", "E731", "B007", "RET", "ARG", "F405", "B028", "PGH001", "B018", "F403", "RUF015", "E731", "B007",
"UP031", "F401", "B023", "F811"] "UP031", "F401", "B023", "F811"]
"tools/*.py" = ["T201"]
"green_framework_test.py" = ["T201"] "green_framework_test.py" = ["T201"]
[tool.coverage.run] [tool.coverage.run]

View File

@ -0,0 +1,37 @@
# Copyright 2024-Present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import sys
base_sha = sys.argv[-1]
head_sha = sys.argv[-2]
def get_total_time(sha: str) -> int:
with open(f"pymongo-{sha}.log") as fid:
last_line = fid.readlines()[-1]
return int(last_line.split()[4])
base_time = get_total_time(base_sha)
curr_time = get_total_time(head_sha)
# Check if we got 20% or more slower.
change = int((curr_time - base_time) / base_time * 100)
if change > 20:
print(f"PyMongo import got {change} percent worse")
sys.exit(1)
print(f"Import time changed by {change} percent")

View File

@ -35,7 +35,7 @@ for dirname in ["pymongo", "bson", "gridfs"]:
missing.append(path) missing.append(path)
if missing: if missing:
print(f"Missing '{pattern}' import in:") # noqa: T201 print(f"Missing '{pattern}' import in:")
for item in missing: for item in missing:
print(item) # noqa: T201 print(item)
sys.exit(1) sys.exit(1)

View File

@ -35,7 +35,7 @@ if os.environ.get("ENSURE_UNIVERSAL2") == "1":
parent_dir = Path(pymongo.__path__[0]).parent parent_dir = Path(pymongo.__path__[0]).parent
for pkg in ["pymongo", "bson", "grifs"]: for pkg in ["pymongo", "bson", "grifs"]:
for so_file in Path(f"{parent_dir}/{pkg}").glob("*.so"): for so_file in Path(f"{parent_dir}/{pkg}").glob("*.so"):
print(f"Checking universal2 compatibility in {so_file}...") # noqa: T201 print(f"Checking universal2 compatibility in {so_file}...")
output = subprocess.check_output(["file", so_file]) # noqa: S603, S607 output = subprocess.check_output(["file", so_file]) # noqa: S603, S607
if "arm64" not in output.decode("utf-8"): if "arm64" not in output.decode("utf-8"):
sys.exit("Universal wheel was not compiled with arm64 support") sys.exit("Universal wheel was not compiled with arm64 support")