PYTHON-4260 Lazily load optional imports (#1550)
This commit is contained in:
parent
5e49363c97
commit
42a08c4a34
@ -2071,6 +2071,21 @@ tasks:
|
||||
bash $SCRIPT -p $CONFIG -h ${github_commit} -o "mongodb" -n "mongo-python-driver"
|
||||
echo '{"results": [{ "status": "PASS", "test_file": "Build", "log_raw": "Test completed" } ]}' > ${PROJECT_DIRECTORY}/test-results.json
|
||||
|
||||
- name: "check-import-time"
|
||||
tags: ["pr"]
|
||||
commands:
|
||||
- command: shell.exec
|
||||
type: test
|
||||
params:
|
||||
shell: "bash"
|
||||
working_dir: src
|
||||
script: |
|
||||
${PREPARE_SHELL}
|
||||
set -x
|
||||
export BASE_SHA=${revision}
|
||||
export HEAD_SHA=${github_commit}
|
||||
bash .evergreen/run-import-time-test.sh
|
||||
|
||||
axes:
|
||||
# Choice of distro
|
||||
- id: platform
|
||||
@ -3046,6 +3061,12 @@ buildvariants:
|
||||
tasks:
|
||||
- name: "assign-pr-reviewer"
|
||||
|
||||
- name: rhel8-import-time
|
||||
display_name: Import Time Check
|
||||
run_on: rhel87-small
|
||||
tasks:
|
||||
- name: "check-import-time"
|
||||
|
||||
- name: Release
|
||||
display_name: Release
|
||||
batchtime: 20160 # 14 days
|
||||
|
||||
31
.evergreen/run-import-time-test.sh
Executable file
31
.evergreen/run-import-time-test.sh
Executable file
@ -0,0 +1,31 @@
|
||||
#!/bin/bash -ex
|
||||
|
||||
set -o errexit # Exit the script with error if any of the commands fail
|
||||
set -x
|
||||
|
||||
. .evergreen/utils.sh
|
||||
|
||||
if [ -z "$PYTHON_BINARY" ]; then
|
||||
PYTHON_BINARY=$(find_python3)
|
||||
fi
|
||||
|
||||
# Use the previous commit if this was not a PR run.
|
||||
if [ "$BASE_SHA" == "$HEAD_SHA" ]; then
|
||||
BASE_SHA=$(git rev-parse HEAD~1)
|
||||
fi
|
||||
|
||||
function get_import_time() {
|
||||
local log_file
|
||||
createvirtualenv "$PYTHON_BINARY" import-venv
|
||||
python -m pip install -q ".[aws,encryption,gssapi,ocsp,snappy,zstd]"
|
||||
# Import once to cache modules
|
||||
python -c "import pymongo"
|
||||
log_file="pymongo-$1.log"
|
||||
python -X importtime -c "import pymongo" 2> $log_file
|
||||
}
|
||||
|
||||
get_import_time $HEAD_SHA
|
||||
git checkout $BASE_SHA
|
||||
get_import_time $BASE_SHA
|
||||
git checkout $HEAD_SHA
|
||||
python tools/compare_import_time.py $HEAD_SHA $BASE_SHA
|
||||
@ -248,7 +248,9 @@ if [ -n "$COVERAGE" ] && [ "$PYTHON_IMPL" = "CPython" ]; then
|
||||
fi
|
||||
|
||||
if [ -n "$GREEN_FRAMEWORK" ]; then
|
||||
python -m pip install $GREEN_FRAMEWORK
|
||||
# Install all optional deps to ensure lazy imports are getting patched.
|
||||
python -m pip install -q ".[aws,encryption,gssapi,ocsp,snappy,zstd]"
|
||||
python -m pip install $GREEN_FRAMEWORK
|
||||
fi
|
||||
|
||||
# Show the installed packages
|
||||
|
||||
@ -74,6 +74,12 @@ Unavoidable breaking changes
|
||||
>>> dict_to_SON(data_as_dict)
|
||||
SON([('driver', SON([('name', 'PyMongo'), ('version', '4.7.0.dev0')])), ('os', SON([('type', 'Darwin'), ('name', 'Darwin'), ('architecture', 'arm64'), ('version', '14.3')])), ('platform', 'CPython 3.11.6.final.0')])
|
||||
|
||||
- PyMongo now uses `lazy imports <https://docs.python.org/3/library/importlib.html#implementing-lazy-imports>`_ for external dependencies.
|
||||
If you are relying on any kind of monkey-patching of the standard library, you may need to explicitly import those external libraries in addition
|
||||
to ``pymongo`` before applying the patch. Note that we test with ``gevent`` and ``eventlet`` patching, and those continue to work.
|
||||
|
||||
- The "aws" extra now requires minimum version of ``1.1.0`` for ``pymongo_auth_aws``.
|
||||
|
||||
Changes in Version 4.6.2
|
||||
------------------------
|
||||
|
||||
|
||||
@ -17,12 +17,14 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
|
||||
def _get_azure_response(
|
||||
resource: str, client_id: Optional[str] = None, timeout: float = 5
|
||||
) -> dict[str, Any]:
|
||||
# Deferred import to save overall import time.
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
url = "http://169.254.169.254/metadata/identity/oauth2/token"
|
||||
url += "?api-version=2018-02-01"
|
||||
url += f"&resource={resource}"
|
||||
|
||||
@ -21,9 +21,10 @@ import time
|
||||
from collections import deque
|
||||
from contextlib import AbstractContextManager
|
||||
from contextvars import ContextVar, Token
|
||||
from typing import Any, Callable, Deque, MutableMapping, Optional, TypeVar, cast
|
||||
from typing import TYPE_CHECKING, Any, Callable, Deque, MutableMapping, Optional, TypeVar, cast
|
||||
|
||||
from pymongo.write_concern import WriteConcern
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
TIMEOUT: ContextVar[Optional[float]] = ContextVar("TIMEOUT", default=None)
|
||||
RTT: ContextVar[float] = ContextVar("RTT", default=0.0)
|
||||
|
||||
38
pymongo/_lazy_import.py
Normal file
38
pymongo/_lazy_import.py
Normal file
@ -0,0 +1,38 @@
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you
|
||||
# may not use this file except in compliance with the License. You
|
||||
# may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
from types import ModuleType
|
||||
|
||||
|
||||
def lazy_import(name: str) -> ModuleType:
|
||||
"""Lazily import a module by name
|
||||
|
||||
From https://docs.python.org/3/library/importlib.html#implementing-lazy-imports
|
||||
"""
|
||||
try:
|
||||
spec = importlib.util.find_spec(name)
|
||||
except ValueError:
|
||||
raise ModuleNotFoundError(name=name) from None
|
||||
if spec is None:
|
||||
raise ModuleNotFoundError(name=name)
|
||||
assert spec is not None
|
||||
loader = importlib.util.LazyLoader(spec.loader) # type:ignore[arg-type]
|
||||
spec.loader = loader
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[name] = module
|
||||
loader.exec_module(module)
|
||||
return module
|
||||
@ -15,38 +15,16 @@
|
||||
"""MONGODB-AWS Authentication helpers."""
|
||||
from __future__ import annotations
|
||||
|
||||
try:
|
||||
import pymongo_auth_aws # type:ignore[import]
|
||||
from pymongo_auth_aws import (
|
||||
AwsCredential,
|
||||
AwsSaslContext,
|
||||
PyMongoAuthAwsError,
|
||||
)
|
||||
from pymongo._lazy_import import lazy_import
|
||||
|
||||
try:
|
||||
pymongo_auth_aws = lazy_import("pymongo_auth_aws")
|
||||
_HAVE_MONGODB_AWS = True
|
||||
except ImportError:
|
||||
|
||||
class AwsSaslContext: # type: ignore
|
||||
def __init__(self, credentials: MongoCredential):
|
||||
pass
|
||||
|
||||
_HAVE_MONGODB_AWS = False
|
||||
|
||||
try:
|
||||
from pymongo_auth_aws.auth import ( # type:ignore[import]
|
||||
set_cached_credentials,
|
||||
set_use_cached_credentials,
|
||||
)
|
||||
|
||||
# Enable credential caching.
|
||||
set_use_cached_credentials(True)
|
||||
except ImportError:
|
||||
|
||||
def set_cached_credentials(_creds: Optional[AwsCredential]) -> None:
|
||||
pass
|
||||
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Optional, Type
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Type
|
||||
|
||||
import bson
|
||||
from bson.binary import Binary
|
||||
@ -58,21 +36,6 @@ if TYPE_CHECKING:
|
||||
from pymongo.pool import Connection
|
||||
|
||||
|
||||
class _AwsSaslContext(AwsSaslContext): # type: ignore
|
||||
# Dependency injection:
|
||||
def binary_type(self) -> Type[Binary]:
|
||||
"""Return the bson.binary.Binary type."""
|
||||
return Binary
|
||||
|
||||
def bson_encode(self, doc: Mapping[str, Any]) -> bytes:
|
||||
"""Encode a dictionary to BSON."""
|
||||
return bson.encode(doc)
|
||||
|
||||
def bson_decode(self, data: _ReadableBuffer) -> Mapping[str, Any]:
|
||||
"""Decode BSON to a dictionary."""
|
||||
return bson.decode(data)
|
||||
|
||||
|
||||
def _authenticate_aws(credentials: MongoCredential, conn: Connection) -> None:
|
||||
"""Authenticate using MONGODB-AWS."""
|
||||
if not _HAVE_MONGODB_AWS:
|
||||
@ -84,9 +47,23 @@ def _authenticate_aws(credentials: MongoCredential, conn: Connection) -> None:
|
||||
if conn.max_wire_version < 9:
|
||||
raise ConfigurationError("MONGODB-AWS authentication requires MongoDB version 4.4 or later")
|
||||
|
||||
class AwsSaslContext(pymongo_auth_aws.AwsSaslContext): # type: ignore
|
||||
# Dependency injection:
|
||||
def binary_type(self) -> Type[Binary]:
|
||||
"""Return the bson.binary.Binary type."""
|
||||
return Binary
|
||||
|
||||
def bson_encode(self, doc: Mapping[str, Any]) -> bytes:
|
||||
"""Encode a dictionary to BSON."""
|
||||
return bson.encode(doc)
|
||||
|
||||
def bson_decode(self, data: _ReadableBuffer) -> Mapping[str, Any]:
|
||||
"""Decode BSON to a dictionary."""
|
||||
return bson.decode(data)
|
||||
|
||||
try:
|
||||
ctx = _AwsSaslContext(
|
||||
AwsCredential(
|
||||
ctx = AwsSaslContext(
|
||||
pymongo_auth_aws.AwsCredential(
|
||||
credentials.username,
|
||||
credentials.password,
|
||||
credentials.mechanism_properties.aws_session_token,
|
||||
@ -108,14 +85,14 @@ def _authenticate_aws(credentials: MongoCredential, conn: Connection) -> None:
|
||||
if res["done"]:
|
||||
# SASL complete.
|
||||
break
|
||||
except PyMongoAuthAwsError as exc:
|
||||
except pymongo_auth_aws.PyMongoAuthAwsError as exc:
|
||||
# Clear the cached credentials if we hit a failure in auth.
|
||||
set_cached_credentials(None)
|
||||
pymongo_auth_aws.set_cached_credentials(None)
|
||||
# Convert to OperationFailure and include pymongo-auth-aws version.
|
||||
raise OperationFailure(
|
||||
f"{exc} (pymongo-auth-aws version {pymongo_auth_aws.__version__})"
|
||||
) from None
|
||||
except Exception:
|
||||
# Clear the cached credentials if we hit a failure in auth.
|
||||
set_cached_credentials(None)
|
||||
pymongo_auth_aws.set_cached_credentials(None)
|
||||
raise
|
||||
|
||||
@ -16,16 +16,19 @@ from __future__ import annotations
|
||||
import warnings
|
||||
from typing import Any, Iterable, Optional, Union
|
||||
|
||||
try:
|
||||
import snappy # type:ignore[import]
|
||||
from pymongo._lazy_import import lazy_import
|
||||
from pymongo.hello import HelloCompat
|
||||
from pymongo.monitoring import _SENSITIVE_COMMANDS
|
||||
|
||||
try:
|
||||
snappy = lazy_import("snappy")
|
||||
_HAVE_SNAPPY = True
|
||||
except ImportError:
|
||||
# python-snappy isn't available.
|
||||
_HAVE_SNAPPY = False
|
||||
|
||||
try:
|
||||
import zlib
|
||||
zlib = lazy_import("zlib")
|
||||
|
||||
_HAVE_ZLIB = True
|
||||
except ImportError:
|
||||
@ -33,15 +36,11 @@ except ImportError:
|
||||
_HAVE_ZLIB = False
|
||||
|
||||
try:
|
||||
from zstandard import ZstdCompressor, ZstdDecompressor
|
||||
|
||||
zstandard = lazy_import("zstandard")
|
||||
_HAVE_ZSTD = True
|
||||
except ImportError:
|
||||
_HAVE_ZSTD = False
|
||||
|
||||
from pymongo.hello import HelloCompat
|
||||
from pymongo.monitoring import _SENSITIVE_COMMANDS
|
||||
|
||||
_SUPPORTED_COMPRESSORS = {"snappy", "zlib", "zstd"}
|
||||
_NO_COMPRESSION = {HelloCompat.CMD, HelloCompat.LEGACY_CMD}
|
||||
_NO_COMPRESSION.update(_SENSITIVE_COMMANDS)
|
||||
@ -138,7 +137,7 @@ class ZstdContext:
|
||||
def compress(data: bytes) -> bytes:
|
||||
# ZstdCompressor is not thread safe.
|
||||
# TODO: Use a pool?
|
||||
return ZstdCompressor().compress(data)
|
||||
return zstandard.ZstdCompressor().compress(data)
|
||||
|
||||
|
||||
def decompress(data: bytes, compressor_id: int) -> bytes:
|
||||
@ -153,6 +152,6 @@ def decompress(data: bytes, compressor_id: int) -> bytes:
|
||||
elif compressor_id == ZstdContext.compressor_id:
|
||||
# ZstdDecompressor is not thread safe.
|
||||
# TODO: Use a pool?
|
||||
return ZstdDecompressor().decompress(data)
|
||||
return zstandard.ZstdDecompressor().decompress(data)
|
||||
else:
|
||||
raise ValueError("Unknown compressorId %d" % (compressor_id,))
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
"""Exceptions raised by PyMongo."""
|
||||
from __future__ import annotations
|
||||
|
||||
from ssl import SSLCertVerificationError as _CertificateError # noqa: F401
|
||||
from typing import TYPE_CHECKING, Any, Iterable, Mapping, Optional, Sequence, Union
|
||||
|
||||
from bson.errors import InvalidDocument
|
||||
@ -22,17 +23,6 @@ from bson.errors import InvalidDocument
|
||||
if TYPE_CHECKING:
|
||||
from pymongo.typings import _DocumentOut
|
||||
|
||||
try:
|
||||
# CPython 3.7+
|
||||
from ssl import SSLCertVerificationError as _CertificateError
|
||||
except ImportError:
|
||||
try:
|
||||
from ssl import CertificateError as _CertificateError
|
||||
except ImportError:
|
||||
|
||||
class _CertificateError(ValueError): # type: ignore
|
||||
pass
|
||||
|
||||
|
||||
class PyMongoError(Exception):
|
||||
"""Base class for all PyMongo exceptions."""
|
||||
|
||||
@ -25,14 +25,10 @@ from errno import EINTR as _EINTR
|
||||
from ipaddress import ip_address as _ip_address
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union
|
||||
|
||||
from cryptography.x509 import load_der_x509_certificate as _load_der_x509_certificate
|
||||
from OpenSSL import SSL as _SSL
|
||||
from OpenSSL import crypto as _crypto
|
||||
from service_identity import CertificateError as _SICertificateError
|
||||
from service_identity import VerificationError as _SIVerificationError
|
||||
from service_identity.pyopenssl import verify_hostname as _verify_hostname
|
||||
from service_identity.pyopenssl import verify_ip_address as _verify_ip_address
|
||||
|
||||
from pymongo._lazy_import import lazy_import
|
||||
from pymongo.errors import ConfigurationError as _ConfigurationError
|
||||
from pymongo.errors import _CertificateError # type:ignore[attr-defined]
|
||||
from pymongo.ocsp_cache import _OCSPCache
|
||||
@ -41,6 +37,10 @@ from pymongo.socket_checker import SocketChecker as _SocketChecker
|
||||
from pymongo.socket_checker import _errno_from_exception
|
||||
from pymongo.write_concern import validate_boolean
|
||||
|
||||
_x509 = lazy_import("cryptography.x509")
|
||||
_service_identity = lazy_import("service_identity")
|
||||
_service_identity_pyopenssl = lazy_import("service_identity.pyopenssl")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ssl import VerifyMode
|
||||
|
||||
@ -340,7 +340,7 @@ class SSLContext:
|
||||
if encoding == "x509_asn":
|
||||
if trust is True or oid in trust:
|
||||
cert_store.add_cert(
|
||||
_crypto.X509.from_cryptography(_load_der_x509_certificate(cert))
|
||||
_crypto.X509.from_cryptography(_x509.load_der_x509_certificate(cert))
|
||||
)
|
||||
|
||||
def load_default_certs(self) -> None:
|
||||
@ -406,9 +406,12 @@ class SSLContext:
|
||||
if self.check_hostname and server_hostname is not None:
|
||||
try:
|
||||
if _is_ip_address(server_hostname):
|
||||
_verify_ip_address(ssl_conn, server_hostname)
|
||||
_service_identity_pyopenssl.verify_ip_address(ssl_conn, server_hostname)
|
||||
else:
|
||||
_verify_hostname(ssl_conn, server_hostname)
|
||||
except (_SICertificateError, _SIVerificationError) as exc:
|
||||
_service_identity_pyopenssl.verify_hostname(ssl_conn, server_hostname)
|
||||
except (
|
||||
_service_identity.SICertificateError,
|
||||
_service_identity.SIVerificationError,
|
||||
) as exc:
|
||||
raise _CertificateError(str(exc)) from None
|
||||
return ssl_conn
|
||||
|
||||
@ -17,18 +17,19 @@ from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import random
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
try:
|
||||
from dns import resolver
|
||||
|
||||
_HAVE_DNSPYTHON = True
|
||||
except ImportError:
|
||||
_HAVE_DNSPYTHON = False
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from pymongo._lazy_import import lazy_import
|
||||
from pymongo.common import CONNECT_TIMEOUT
|
||||
from pymongo.errors import ConfigurationError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dns import resolver
|
||||
else:
|
||||
resolver = lazy_import("dns.resolver")
|
||||
|
||||
_HAVE_DNSPYTHON = True
|
||||
|
||||
|
||||
# dnspython can return bytes or str from various parts
|
||||
# of its API depending on version. We always want str.
|
||||
|
||||
@ -45,7 +45,7 @@ dependencies = [
|
||||
|
||||
[project.optional-dependencies]
|
||||
aws = [
|
||||
"pymongo-auth-aws<2.0.0",
|
||||
"pymongo-auth-aws>=1.1.0,<2.0.0",
|
||||
]
|
||||
encryption = [
|
||||
"pymongo[aws]",
|
||||
@ -207,6 +207,7 @@ dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?)|dummy.*)$"
|
||||
"test/*.py" = ["PT", "E402", "PLW", "SIM", "E741", "PTH", "S", "B904", "E722", "T201",
|
||||
"RET", "ARG", "F405", "B028", "PGH001", "B018", "F403", "RUF015", "E731", "B007",
|
||||
"UP031", "F401", "B023", "F811"]
|
||||
"tools/*.py" = ["T201"]
|
||||
"green_framework_test.py" = ["T201"]
|
||||
|
||||
[tool.coverage.run]
|
||||
|
||||
37
tools/compare_import_time.py
Normal file
37
tools/compare_import_time.py
Normal file
@ -0,0 +1,37 @@
|
||||
# Copyright 2024-Present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
|
||||
base_sha = sys.argv[-1]
|
||||
head_sha = sys.argv[-2]
|
||||
|
||||
|
||||
def get_total_time(sha: str) -> int:
|
||||
with open(f"pymongo-{sha}.log") as fid:
|
||||
last_line = fid.readlines()[-1]
|
||||
return int(last_line.split()[4])
|
||||
|
||||
|
||||
base_time = get_total_time(base_sha)
|
||||
curr_time = get_total_time(head_sha)
|
||||
|
||||
# Check if we got 20% or more slower.
|
||||
change = int((curr_time - base_time) / base_time * 100)
|
||||
if change > 20:
|
||||
print(f"PyMongo import got {change} percent worse")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Import time changed by {change} percent")
|
||||
@ -35,7 +35,7 @@ for dirname in ["pymongo", "bson", "gridfs"]:
|
||||
missing.append(path)
|
||||
|
||||
if missing:
|
||||
print(f"Missing '{pattern}' import in:") # noqa: T201
|
||||
print(f"Missing '{pattern}' import in:")
|
||||
for item in missing:
|
||||
print(item) # noqa: T201
|
||||
print(item)
|
||||
sys.exit(1)
|
||||
|
||||
@ -35,7 +35,7 @@ if os.environ.get("ENSURE_UNIVERSAL2") == "1":
|
||||
parent_dir = Path(pymongo.__path__[0]).parent
|
||||
for pkg in ["pymongo", "bson", "grifs"]:
|
||||
for so_file in Path(f"{parent_dir}/{pkg}").glob("*.so"):
|
||||
print(f"Checking universal2 compatibility in {so_file}...") # noqa: T201
|
||||
print(f"Checking universal2 compatibility in {so_file}...")
|
||||
output = subprocess.check_output(["file", so_file]) # noqa: S603, S607
|
||||
if "arm64" not in output.decode("utf-8"):
|
||||
sys.exit("Universal wheel was not compiled with arm64 support")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user